
Conversation

@raulchen
Contributor

Add a max_micro_batches config (default: 64) to EngineConfig to limit how many micro batches are processed before returning results to clients. This prevents long wait times and keeps retrieve_future requests from piling up when clients send large numbers of requests.

The limit behavior depends on train_micro_batch_size:

  • When > 0: counts micro batches as ceil(sequences / micro_batch_size)
  • When = 0 (full batch mode): each request counts as 1

Always includes at least one request to avoid starvation.
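
As a rough illustration of the counting rule above (not code from this PR; count_micro_batches and the numbers are hypothetical):

import math

def count_micro_batches(num_sequences: int, micro_batch_size: int) -> int:
    # Hypothetical helper mirroring the rule described above.
    if micro_batch_size > 0:
        # Gradient accumulation: one micro batch per micro_batch_size sequences.
        return math.ceil(num_sequences / micro_batch_size)
    # Full batch mode: the whole request counts as one unit.
    return 1

# With train_micro_batch_size=8, a request with 20 sequences counts as
# ceil(20 / 8) = 3 micro batches; with train_micro_batch_size=0 it counts as 1.
assert count_micro_batches(20, 8) == 3
assert count_micro_batches(20, 0) == 1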


Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@vercel

vercel bot commented Jan 26, 2026

@raulchen is attempting to deploy a commit to the Tyler's projects Team on Vercel.

A member of the Team first needs to authorize it.

@raulchen changed the title from "[tx] Add max_micro_batches config to limit batch size in engine" to "[tx] Add max_micro_batches to limit batch size for forward/forward_backward requests" on Jan 26, 2026
Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces a max_micro_batches configuration to the engine to limit the number of micro-batches processed in a single batch. This is a useful feature to prevent long wait times for clients. The implementation includes the necessary configuration, engine logic, and a comprehensive set of tests to validate the new behavior.

My main feedback is on a design choice in the engine logic that creates a tight coupling with the JaxBackend. I've left a comment with a suggestion to refactor this for better maintainability and extensibility. Otherwise, the changes look good and the tests are well-written.

Comment on lines +275 to +292
        if self.config.max_micro_batches > 0 and isinstance(self.backend, JaxBackend):
            micro_batch_size = self.backend.config.train_micro_batch_size
            limited = []
            total_micro_batches = 0
            for op in batchable:
                num_sequences = len(op.request_data.get("data", []))
                if micro_batch_size > 0:
                    # Gradient accumulation enabled: count actual micro batches
                    num_micro_batches = math.ceil(num_sequences / micro_batch_size)
                else:
                    # Full batch mode: each request is processed as one unit
                    num_micro_batches = 1
                # Always include at least one request to avoid starvation
                if limited and total_micro_batches + num_micro_batches > self.config.max_micro_batches:
                    break
                limited.append(op)
                total_micro_batches += num_micro_batches
            batchable = limited

Severity: medium

This block of code for limiting micro-batches is tightly coupled to the JaxBackend implementation. Specifically, the isinstance(self.backend, JaxBackend) check and accessing self.backend.config.train_micro_batch_size directly violate the Open/Closed Principle. This makes it difficult to extend the engine with other backends that might also support micro-batching, as it would require modifying this core engine logic.

A better design would be to abstract this calculation away from the engine. The backend should be responsible for reporting how many micro-batches a given operation will consume.

I'd suggest adding a method to the AbstractBackend interface, for example:

# In AbstractBackend
def get_micro_batch_count(self, op: FutureDB) -> int:
    """Calculates the number of micro-batches for a given operation."""
    # Default to 1 for backends that don't support micro-batching
    return 1

And JaxBackend would implement it:

# In JaxBackend
def get_micro_batch_count(self, op: FutureDB) -> int:
    if self.config.train_micro_batch_size > 0:
        num_sequences = len(op.request_data.get("data", []))
        return math.ceil(num_sequences / self.config.train_micro_batch_size)
    return 1

Then, this block in the engine could be simplified and decoupled:

        if self.config.max_micro_batches > 0:
            limited = []
            total_micro_batches = 0
            for op in batchable:
                num_micro_batches = self.backend.get_micro_batch_count(op)
                # Always include at least one request to avoid starvation
                if limited and total_micro_batches + num_micro_batches > self.config.max_micro_batches:
                    break
                limited.append(op)
                total_micro_batches += num_micro_batches
            batchable = limited

This would make the engine more maintainable and extensible.

@raulchen
Contributor Author


We can do this when we add a second backend cc @pcmoritz

@pcmoritz pcmoritz added the tx label Jan 27, 2026