94 changes: 93 additions & 1 deletion skyrl-tx/tests/tinker/test_engine.py
@@ -1,12 +1,13 @@
from cloudpathlib import AnyPath
from datetime import datetime, timedelta, timezone

import pytest
from sqlmodel import Session, SQLModel

from tx.tinker.engine import TinkerEngine
from tx.tinker.config import EngineConfig
from tx.tinker import types
from tx.tinker.db_models import SessionDB, ModelDB
from tx.tinker.db_models import SessionDB, ModelDB, FutureDB, RequestStatus


BASE_MODEL = "trl-internal-testing/tiny-Qwen3ForCausalLM"
@@ -80,3 +81,94 @@ def test_cleanup_stale_sessions():
# Run cleanup and assert one model was unloaded
assert engine.cleanup_stale_sessions() == 1
assert not engine.backend.has_model(model_id)


class TestMaxMicroBatches:
"""Tests for max_micro_batches limiting in find_batchable_model_passes."""

@staticmethod
def _make_request_data(num_sequences: int) -> dict:
"""Create a ForwardBackwardInput request data with the given number of sequences."""
data = []
for _ in range(num_sequences):
data.append(
{
"model_input": {"chunks": [{"tokens": [1, 2, 3]}]},
"loss_fn_inputs": {
"target_tokens": {"data": [2, 3, 4]},
"weights": {"data": [1.0, 1.0, 1.0]},
"advantages": {"data": [0.0, 0.0, 0.0]},
"logprobs": {"data": [0.0, 0.0, 0.0]},
},
}
)
return {"data": data, "loss_fn": "cross_entropy"}

@staticmethod
def _create_engine(train_micro_batch_size: int, max_micro_batches: int) -> TinkerEngine:
"""Create an engine with the given micro batch configuration."""
config = EngineConfig(
base_model=BASE_MODEL,
checkpoints_base=AnyPath(""),
backend_config={
"max_lora_adapters": 4,
"max_lora_rank": 32,
"train_micro_batch_size": train_micro_batch_size,
},
max_micro_batches=max_micro_batches,
database_url="sqlite:///:memory:",
)
engine = TinkerEngine(config)
SQLModel.metadata.create_all(engine.db_engine)
return engine

def _add_requests(self, engine: TinkerEngine, sequence_counts: list[int]):
"""Add FORWARD_BACKWARD requests with the given sequence counts."""
with Session(engine.db_engine) as session:
for num_sequences in sequence_counts:
session.add(
FutureDB(
request_type=types.RequestType.FORWARD_BACKWARD,
model_id="model1",
request_data=self._make_request_data(num_sequences),
status=RequestStatus.PENDING,
)
)
session.commit()

@pytest.mark.parametrize(
"train_micro_batch_size,max_micro_batches,sequence_counts,expected_count",
[
# Gradient accumulation mode: ceil(16/4) + ceil(20/4) = 4 + 5 = 9 <= 10; adding ceil(8/4) = 2 would exceed the limit
(4, 10, [16, 20, 8], 2),
# Full batch mode: each request counts as 1, so only 3 of the 4 requests fit when max_micro_batches=3
(0, 3, [100, 200, 50, 75], 3),
# Disabled: all requests included when max_micro_batches=0
(4, 0, [50] * 10, 10),
],
ids=["gradient_accumulation", "full_batch_mode", "disabled"],
)
def test_micro_batch_limiting(self, train_micro_batch_size, max_micro_batches, sequence_counts, expected_count):
"""Test that micro batches are limited correctly under different configurations."""
engine = self._create_engine(train_micro_batch_size, max_micro_batches)
self._add_requests(engine, sequence_counts)

with Session(engine.db_engine) as session:
result = engine.find_batchable_model_passes(session, types.RequestType.FORWARD_BACKWARD)

assert len(result) == expected_count

def test_always_includes_at_least_one_request(self):
"""Test that at least one request is always included even if it exceeds the limit."""
# train_micro_batch_size=4, max_micro_batches=10
# Request with 100 sequences = ceil(100/4) = 25 micro batches > 10
# Should still be included to avoid starvation
engine = self._create_engine(train_micro_batch_size=4, max_micro_batches=10)
self._add_requests(engine, [100])

with Session(engine.db_engine) as session:
result = engine.find_batchable_model_passes(session, types.RequestType.FORWARD_BACKWARD)

assert len(result) == 1
_, req_data = list(result.values())[0]
assert len(req_data.data) == 100
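
As a side note on the arithmetic these parametrized cases rely on: the limiter counts ceil(num_sequences / train_micro_batch_size) micro batches per request when gradient accumulation is enabled, and 1 otherwise. A minimal standalone sketch of that counting (count_micro_batches is an illustrative helper, not part of this PR):

import math

def count_micro_batches(num_sequences: int, micro_batch_size: int) -> int:
    # Mirrors the engine's counting in the diff below: ceil division
    # under gradient accumulation, one unit per request otherwise.
    if micro_batch_size > 0:
        return math.ceil(num_sequences / micro_batch_size)
    return 1

# "gradient_accumulation" case: 16 and 20 sequences fit (4 + 5 = 9 <= 10),
# but the 8-sequence request (2 more micro batches) would push past 10.
assert count_micro_batches(16, 4) + count_micro_batches(20, 4) == 9
assert 9 + count_micro_batches(8, 4) > 10
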
4 changes: 4 additions & 0 deletions skyrl-tx/tx/tinker/config.py
@@ -51,6 +51,10 @@ class EngineConfig(BaseModel):
default=300,
description="Seconds without heartbeat before session is considered stale. Set to -1 to disable cleanup.",
)
max_micro_batches: int = Field(
default=64,
description="Maximum number of micro batches per forward/forward_backward batch. Limits how many are processed before returning results to clients. Set to 0 to disable.",
)


def convert_env_var(env_name: str, env_value: str, expected_type: type):
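
As a usage note, the new knob is set like any other EngineConfig field; a minimal sketch (the surrounding arguments mirror the test setup above and are illustrative, not prescriptive):

config = EngineConfig(
    base_model="trl-internal-testing/tiny-Qwen3ForCausalLM",
    checkpoints_base=AnyPath(""),
    backend_config={"train_micro_batch_size": 4},
    max_micro_batches=16,  # cap per forward/forward_backward batch; 0 disables the limit
    database_url="sqlite:///:memory:",
)
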
21 changes: 21 additions & 0 deletions skyrl-tx/tx/tinker/engine.py
@@ -1,6 +1,7 @@
"""Background engine for processing training requests."""

import argparse
import math
import time
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
@@ -270,6 +271,26 @@ def find_batchable_model_passes(
# Filter: only include ops that come before their model's barrier
batchable = [op for op in ops if op.model_id not in barriers or op.request_id < barriers[op.model_id]]

# Limit total micro batches if configured
if self.config.max_micro_batches > 0 and isinstance(self.backend, JaxBackend):
micro_batch_size = self.backend.config.train_micro_batch_size
limited = []
total_micro_batches = 0
for op in batchable:
num_sequences = len(op.request_data.get("data", []))
if micro_batch_size > 0:
# Gradient accumulation enabled: count actual micro batches
num_micro_batches = math.ceil(num_sequences / micro_batch_size)
else:
# Full batch mode: each request is processed as one unit
num_micro_batches = 1
# Always include at least one request to avoid starvation
if limited and total_micro_batches + num_micro_batches > self.config.max_micro_batches:
break
limited.append(op)
total_micro_batches += num_micro_batches
batchable = limited
Comment on lines +275 to +292
Contributor

Severity: medium

This block of code for limiting micro-batches is tightly coupled to the JaxBackend implementation. Specifically, the isinstance(self.backend, JaxBackend) check and accessing self.backend.config.train_micro_batch_size directly violate the Open/Closed Principle. This makes it difficult to extend the engine with other backends that might also support micro-batching, as it would require modifying this core engine logic.

A better design would be to abstract this calculation away from the engine. The backend should be responsible for reporting how many micro-batches a given operation will consume.

I'd suggest adding a method to the AbstractBackend interface, for example:

# In AbstractBackend
def get_micro_batch_count(self, op: FutureDB) -> int:
    """Calculates the number of micro-batches for a given operation."""
    # Default to 1 for backends that don't support micro-batching
    return 1

And JaxBackend would implement it:

# In JaxBackend
def get_micro_batch_count(self, op: FutureDB) -> int:
    if self.config.train_micro_batch_size > 0:
        num_sequences = len(op.request_data.get("data", []))
        return math.ceil(num_sequences / self.config.train_micro_batch_size)
    return 1

Then, this block in the engine could be simplified and decoupled:

        if self.config.max_micro_batches > 0:
            limited = []
            total_micro_batches = 0
            for op in batchable:
                num_micro_batches = self.backend.get_micro_batch_count(op)
                # Always include at least one request to avoid starvation
                if limited and total_micro_batches + num_micro_batches > self.config.max_micro_batches:
                    break
                limited.append(op)
                total_micro_batches += num_micro_batches
            batchable = limited

This would make the engine more maintainable and extensible.

Contributor Author

We can do this when we add a second backend. cc @pcmoritz


return {
str(f.request_id): (f.model_id, types.ForwardBackwardInput.model_validate(f.request_data))
for f in batchable
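
For reference, the mapping returned here is keyed by the request ID as a string, with (model_id, ForwardBackwardInput) values, as exercised in the tests above. A caller can unpack it roughly like so (illustrative sketch, not part of this PR; process() is a hypothetical stand-in):

for request_id, (model_id, fb_input) in result.items():
    # fb_input is a validated types.ForwardBackwardInput; fb_input.data
    # holds the request's sequences (see the starvation test above).
    process(model_id, fb_input.data)
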