[tx] Add max_micro_batches to limit batch size for forward/forward_backward requests #968
base: main
Conversation
Add max_micro_batches config (default: 64) to EngineConfig to limit how many micro batches are processed before returning results to clients. This prevents long wait times when clients send large numbers of requests.

The limit behavior depends on train_micro_batch_size:
- When > 0: counts micro batches as ceil(sequences / micro_batch_size)
- When = 0 (full batch mode): each request counts as 1

Always includes at least one request to avoid starvation.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
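For a concrete sense of the limit, here is a small sketch with made-up numbers (100 sequences per request, train_micro_batch_size=8) showing how a request's micro-batch count is derived and how many such requests fit under the default cap of 64:

```python
import math

# Assumed, illustrative numbers: a forward_backward request carrying 100
# sequences with train_micro_batch_size=8 counts as ceil(100 / 8) = 13
# micro batches. With max_micro_batches=64, four such requests
# (4 * 13 = 52) fit in one batch; a fifth (52 + 13 = 65) would exceed the
# cap and waits for the next batch.
print(math.ceil(100 / 8))  # 13
```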
@raulchen is attempting to deploy a commit to the Tyler's projects Team on Vercel. A member of the Team first needs to authorize it.
Code Review
This pull request introduces a max_micro_batches configuration to the engine to limit the number of micro-batches processed in a single batch. This is a useful feature to prevent long wait times for clients. The implementation includes the necessary configuration, engine logic, and a comprehensive set of tests to validate the new behavior.
My main feedback is on a design choice in the engine logic that creates a tight coupling with the JaxBackend. I've left a comment with a suggestion to refactor this for better maintainability and extensibility. Otherwise, the changes look good and the tests are well-written.
```python
if self.config.max_micro_batches > 0 and isinstance(self.backend, JaxBackend):
    micro_batch_size = self.backend.config.train_micro_batch_size
    limited = []
    total_micro_batches = 0
    for op in batchable:
        num_sequences = len(op.request_data.get("data", []))
        if micro_batch_size > 0:
            # Gradient accumulation enabled: count actual micro batches
            num_micro_batches = math.ceil(num_sequences / micro_batch_size)
        else:
            # Full batch mode: each request is processed as one unit
            num_micro_batches = 1
        # Always include at least one request to avoid starvation
        if limited and total_micro_batches + num_micro_batches > self.config.max_micro_batches:
            break
        limited.append(op)
        total_micro_batches += num_micro_batches
    batchable = limited
```
This block of code for limiting micro-batches is tightly coupled to the JaxBackend implementation. Specifically, the isinstance(self.backend, JaxBackend) check and accessing self.backend.config.train_micro_batch_size directly violate the Open/Closed Principle. This makes it difficult to extend the engine with other backends that might also support micro-batching, as it would require modifying this core engine logic.
A better design would be to abstract this calculation away from the engine. The backend should be responsible for reporting how many micro-batches a given operation will consume.
I'd suggest adding a method to the AbstractBackend interface, for example:
```python
# In AbstractBackend
def get_micro_batch_count(self, op: FutureDB) -> int:
    """Calculates the number of micro-batches for a given operation."""
    # Default to 1 for backends that don't support micro-batching
    return 1
```

And JaxBackend would implement it:
```python
# In JaxBackend
def get_micro_batch_count(self, op: FutureDB) -> int:
    if self.config.train_micro_batch_size > 0:
        num_sequences = len(op.request_data.get("data", []))
        return math.ceil(num_sequences / self.config.train_micro_batch_size)
    return 1
```

Then, this block in the engine could be simplified and decoupled:
```python
if self.config.max_micro_batches > 0:
    limited = []
    total_micro_batches = 0
    for op in batchable:
        num_micro_batches = self.backend.get_micro_batch_count(op)
        # Always include at least one request to avoid starvation
        if limited and total_micro_batches + num_micro_batches > self.config.max_micro_batches:
            break
        limited.append(op)
        total_micro_batches += num_micro_batches
    batchable = limited
```

This would make the engine more maintainable and extensible.
We can do this when we add a second backend. cc @pcmoritz
Add max_micro_batches config (default: 64) to EngineConfig to limit how many micro batches are processed before returning results to clients. This prevents long wait times, as well as retrieve_future requests piling up, when clients send large numbers of requests.

The limit behavior depends on train_micro_batch_size:
- When > 0: counts micro batches as ceil(sequences / micro_batch_size)
- When = 0 (full batch mode): each request counts as 1

Always includes at least one request to avoid starvation.
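As a rough illustration of the starvation guard described above (the per-request counts are made up), a single request that by itself exceeds max_micro_batches is still scheduled alone rather than being skipped:

```python
# Sketch of the limiting loop with hypothetical per-request micro-batch counts.
max_micro_batches = 64
request_micro_batches = [80, 10, 10]  # the first request alone exceeds the cap
limited, total = [], 0
for n in request_micro_batches:
    # Stop once adding another request would exceed the cap, but always keep
    # at least one request so an oversized request is never starved.
    if limited and total + n > max_micro_batches:
        break
    limited.append(n)
    total += n
print(limited)  # [80] -> the oversized request still runs; later ones wait
```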