From e8708dc990daa6ac392bc89e7d6d2b69a33689f9 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 26 Jan 2026 15:40:43 -0800 Subject: [PATCH 1/2] [tx] Add max_micro_batches config to limit batch size in engine Add max_micro_batches config (default: 64) to EngineConfig to limit how many micro batches are processed before returning results to clients. This prevents long wait times when clients send large numbers of requests. The limit behavior depends on train_micro_batch_size: - When > 0: counts micro batches as ceil(sequences / micro_batch_size) - When = 0 (full batch mode): each request counts as 1 Always includes at least one request to avoid starvation. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/tests/tinker/test_engine.py | 86 +++++++++++++++++++++++++++- skyrl-tx/tx/tinker/config.py | 4 ++ skyrl-tx/tx/tinker/engine.py | 21 +++++++ 3 files changed, 110 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tests/tinker/test_engine.py b/skyrl-tx/tests/tinker/test_engine.py index 3319a8c3e..8756db3ca 100644 --- a/skyrl-tx/tests/tinker/test_engine.py +++ b/skyrl-tx/tests/tinker/test_engine.py @@ -1,12 +1,13 @@ from cloudpathlib import AnyPath from datetime import datetime, timedelta, timezone +import pytest from sqlmodel import Session, SQLModel from tx.tinker.engine import TinkerEngine from tx.tinker.config import EngineConfig from tx.tinker import types -from tx.tinker.db_models import SessionDB, ModelDB +from tx.tinker.db_models import SessionDB, ModelDB, FutureDB, RequestStatus BASE_MODEL = "trl-internal-testing/tiny-Qwen3ForCausalLM" @@ -80,3 +81,86 @@ def test_cleanup_stale_sessions(): # Run cleanup and assert one model was unloaded assert engine.cleanup_stale_sessions() == 1 assert not engine.backend.has_model(model_id) + + +class TestMaxMicroBatches: + """Tests for max_micro_batches limiting in find_batchable_model_passes.""" + + @staticmethod + def _make_request_data(num_sequences: int) -> dict: + """Create a ForwardBackwardInput request data with the given number of sequences.""" + data = [] + for _ in range(num_sequences): + data.append({ + "model_input": {"chunks": [{"tokens": [1, 2, 3]}]}, + "loss_fn_inputs": { + "target_tokens": {"data": [2, 3, 4]}, + "weights": {"data": [1.0, 1.0, 1.0]}, + "advantages": {"data": [0.0, 0.0, 0.0]}, + "logprobs": {"data": [0.0, 0.0, 0.0]}, + }, + }) + return {"data": data, "loss_fn": "cross_entropy"} + + @staticmethod + def _create_engine(train_micro_batch_size: int, max_micro_batches: int) -> TinkerEngine: + """Create an engine with the given micro batch configuration.""" + config = EngineConfig( + base_model=BASE_MODEL, + checkpoints_base=AnyPath(""), + backend_config={"max_lora_adapters": 4, "max_lora_rank": 32, "train_micro_batch_size": train_micro_batch_size}, + max_micro_batches=max_micro_batches, + database_url="sqlite:///:memory:", + ) + engine = TinkerEngine(config) + SQLModel.metadata.create_all(engine.db_engine) + return engine + + def _add_requests(self, engine: TinkerEngine, sequence_counts: list[int]): + """Add FORWARD_BACKWARD requests with the given sequence counts.""" + with Session(engine.db_engine) as session: + for num_sequences in sequence_counts: + session.add(FutureDB( + request_type=types.RequestType.FORWARD_BACKWARD, + model_id="model1", + request_data=self._make_request_data(num_sequences), + status=RequestStatus.PENDING, + )) + session.commit() + + @pytest.mark.parametrize( + "train_micro_batch_size,max_micro_batches,sequence_counts,expected_count", + [ + # Gradient accumulation mode: ceil(16/4) + ceil(20/4) = 4 + 5 = 9 
<= 10, ceil(8/4) = 2 would exceed + (4, 10, [16, 20, 8], 2), + # Full batch mode: each request counts as 1, so 3 requests fit in max_micro_batches=3 + (0, 3, [100, 200, 50, 75], 3), + # Disabled: all requests included when max_micro_batches=0 + (4, 0, [50] * 10, 10), + ], + ids=["gradient_accumulation", "full_batch_mode", "disabled"], + ) + def test_micro_batch_limiting(self, train_micro_batch_size, max_micro_batches, sequence_counts, expected_count): + """Test that micro batches are limited correctly under different configurations.""" + engine = self._create_engine(train_micro_batch_size, max_micro_batches) + self._add_requests(engine, sequence_counts) + + with Session(engine.db_engine) as session: + result = engine.find_batchable_model_passes(session, types.RequestType.FORWARD_BACKWARD) + + assert len(result) == expected_count + + def test_always_includes_at_least_one_request(self): + """Test that at least one request is always included even if it exceeds the limit.""" + # train_micro_batch_size=4, max_micro_batches=10 + # Request with 100 sequences = ceil(100/4) = 25 micro batches > 10 + # Should still be included to avoid starvation + engine = self._create_engine(train_micro_batch_size=4, max_micro_batches=10) + self._add_requests(engine, [100]) + + with Session(engine.db_engine) as session: + result = engine.find_batchable_model_passes(session, types.RequestType.FORWARD_BACKWARD) + + assert len(result) == 1 + _, req_data = list(result.values())[0] + assert len(req_data.data) == 100 diff --git a/skyrl-tx/tx/tinker/config.py b/skyrl-tx/tx/tinker/config.py index e126e5499..ab11a3d33 100644 --- a/skyrl-tx/tx/tinker/config.py +++ b/skyrl-tx/tx/tinker/config.py @@ -51,6 +51,10 @@ class EngineConfig(BaseModel): default=300, description="Seconds without heartbeat before session is considered stale. Set to -1 to disable cleanup.", ) + max_micro_batches: int = Field( + default=64, + description="Maximum number of micro batches per forward/forward_backward batch. Limits how many are processed before returning results to clients. 
Set to 0 to disable.", + ) def convert_env_var(env_name: str, env_value: str, expected_type: type): diff --git a/skyrl-tx/tx/tinker/engine.py b/skyrl-tx/tx/tinker/engine.py index b7f4b2917..2ba174a73 100644 --- a/skyrl-tx/tx/tinker/engine.py +++ b/skyrl-tx/tx/tinker/engine.py @@ -1,6 +1,7 @@ """Background engine for processing training requests.""" import argparse +import math import time from contextlib import contextmanager from datetime import datetime, timedelta, timezone @@ -270,6 +271,26 @@ def find_batchable_model_passes( # Filter: only include ops that come before their model's barrier batchable = [op for op in ops if op.model_id not in barriers or op.request_id < barriers[op.model_id]] + # Limit total micro batches if configured + if self.config.max_micro_batches > 0 and isinstance(self.backend, JaxBackend): + micro_batch_size = self.backend.config.train_micro_batch_size + limited = [] + total_micro_batches = 0 + for op in batchable: + num_sequences = len(op.request_data.get("data", [])) + if micro_batch_size > 0: + # Gradient accumulation enabled: count actual micro batches + num_micro_batches = math.ceil(num_sequences / micro_batch_size) + else: + # Full batch mode: each request is processed as one unit + num_micro_batches = 1 + # Always include at least one request to avoid starvation + if limited and total_micro_batches + num_micro_batches > self.config.max_micro_batches: + break + limited.append(op) + total_micro_batches += num_micro_batches + batchable = limited + return { str(f.request_id): (f.model_id, types.ForwardBackwardInput.model_validate(f.request_data)) for f in batchable From be102bb8ea049ac32f8214be1379e0dc69430d51 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 26 Jan 2026 17:26:03 -0800 Subject: [PATCH 2/2] lint --- skyrl-tx/tests/tinker/test_engine.py | 40 +++++++++++++++++----------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/skyrl-tx/tests/tinker/test_engine.py b/skyrl-tx/tests/tinker/test_engine.py index 8756db3ca..706ace219 100644 --- a/skyrl-tx/tests/tinker/test_engine.py +++ b/skyrl-tx/tests/tinker/test_engine.py @@ -91,15 +91,17 @@ def _make_request_data(num_sequences: int) -> dict: """Create a ForwardBackwardInput request data with the given number of sequences.""" data = [] for _ in range(num_sequences): - data.append({ - "model_input": {"chunks": [{"tokens": [1, 2, 3]}]}, - "loss_fn_inputs": { - "target_tokens": {"data": [2, 3, 4]}, - "weights": {"data": [1.0, 1.0, 1.0]}, - "advantages": {"data": [0.0, 0.0, 0.0]}, - "logprobs": {"data": [0.0, 0.0, 0.0]}, - }, - }) + data.append( + { + "model_input": {"chunks": [{"tokens": [1, 2, 3]}]}, + "loss_fn_inputs": { + "target_tokens": {"data": [2, 3, 4]}, + "weights": {"data": [1.0, 1.0, 1.0]}, + "advantages": {"data": [0.0, 0.0, 0.0]}, + "logprobs": {"data": [0.0, 0.0, 0.0]}, + }, + } + ) return {"data": data, "loss_fn": "cross_entropy"} @staticmethod @@ -108,7 +110,11 @@ def _create_engine(train_micro_batch_size: int, max_micro_batches: int) -> Tinke config = EngineConfig( base_model=BASE_MODEL, checkpoints_base=AnyPath(""), - backend_config={"max_lora_adapters": 4, "max_lora_rank": 32, "train_micro_batch_size": train_micro_batch_size}, + backend_config={ + "max_lora_adapters": 4, + "max_lora_rank": 32, + "train_micro_batch_size": train_micro_batch_size, + }, max_micro_batches=max_micro_batches, database_url="sqlite:///:memory:", ) @@ -120,12 +126,14 @@ def _add_requests(self, engine: TinkerEngine, sequence_counts: list[int]): """Add FORWARD_BACKWARD requests with the given sequence 
counts.""" with Session(engine.db_engine) as session: for num_sequences in sequence_counts: - session.add(FutureDB( - request_type=types.RequestType.FORWARD_BACKWARD, - model_id="model1", - request_data=self._make_request_data(num_sequences), - status=RequestStatus.PENDING, - )) + session.add( + FutureDB( + request_type=types.RequestType.FORWARD_BACKWARD, + model_id="model1", + request_data=self._make_request_data(num_sequences), + status=RequestStatus.PENDING, + ) + ) session.commit() @pytest.mark.parametrize(