[tx] Per-layer gradient checkpointing with stacked decoder layers #996
base: main
Conversation
Compute lm_head projection in chunks to avoid materializing the full [B*T, V] logits tensor. Key changes: - Add compute_logits flag to model.__call__ (skip lm_head when False) - Add lm_head weight to CausalLMOutput for external computation - Implement chunked logprobs with jax.lax.map (default chunk_size=1024) - Add loss_chunk_size config option Memory savings: O(B*T*V) -> O(chunk_size*V) for logits tensor. For Qwen3-4B with V=151k, 8k seq: ~19GB -> ~300MB peak logits memory. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
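A minimal sketch of the chunking idea described in this commit, under assumed names (hidden, lm_head_w, target_ids, and the helper itself are illustrative, not the PR's actual API) and omitting padding when B*T is not a multiple of the chunk size:

```python
import jax
import jax.numpy as jnp


def chunked_logprobs(hidden, lm_head_w, target_ids, chunk_size=1024):
    """Per-token logprobs without materializing the full [B*T, V] logits tensor."""
    flat_h = hidden.reshape(-1, hidden.shape[-1])  # [B*T, H]
    flat_t = target_ids.reshape(-1)                # [B*T]

    def one_chunk(chunk):
        h_chunk, t_chunk = chunk
        logits = h_chunk @ lm_head_w               # only [chunk_size, V] is live at once
        logz = jax.nn.logsumexp(logits, axis=-1)
        tok = jnp.take_along_axis(logits, t_chunk[:, None], axis=-1)[:, 0]
        return tok - logz

    # jax.lax.map runs the chunks sequentially, so peak logits memory is O(chunk_size * V).
    h_chunks = flat_h.reshape(-1, chunk_size, flat_h.shape[-1])
    t_chunks = flat_t.reshape(-1, chunk_size)
    return jax.lax.map(one_chunk, (h_chunks, t_chunks)).reshape(target_ids.shape)
```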
…ze<=0 The chunked cross-entropy path computes logits via direct matmul with lm_head weight, bypassing LoRA adapters. This is incorrect when train_unembed=True since LoRA should be applied to lm_head. Changes: - Rename is_training to skip_logits for clarity - Add _use_chunked_loss flag to backend - Automatically switch to non-chunked mode when: - train_unembed=True (requires LoRA on lm_head) - loss_chunk_size <= 0 (config-based disable) - Non-chunked path uses pre-computed logits with LoRA correctly applied
Recompute activations during backward to save memory. Only one layer's activations are held at a time during backward pass, reducing peak memory by ~num_layers factor. - Add gradient_checkpointing config to ModelConfig - Apply jax.checkpoint per-layer when is_training=True - Rename compute_logits to is_training (controls both logits and checkpointing) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…euse Add _forward_layers_checkpointed() using jax.lax.fori_loop so XLA compiles ONE loop body and reuses buffers during backward recomputation. With a Python loop, XLA unrolls N separate checkpoint regions and can't optimize buffer reuse across them. Only enabled when gradient_checkpointing=True. Without checkpointing, activations are stored anyway, so fori_loop's buffer reuse doesn't help and its weight stacking overhead makes it worse. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
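A rough sketch of per-layer rematerialization inside a single compiled loop body, under assumed names (a stacked_params pytree with a leading (num_layers, ...) axis and a layer_fn callback are placeholders, not the repo's actual helpers):

```python
import jax


def forward_layers_checkpointed(stacked_params, hidden, layer_fn):
    """Apply num_layers identical layers, recomputing each layer's activations in backward."""
    num_layers = jax.tree_util.tree_leaves(stacked_params)[0].shape[0]

    @jax.checkpoint  # rematerialize this layer's activations during the backward pass
    def one_layer(h, i):
        params_i = jax.tree_util.tree_map(lambda p: p[i], stacked_params)
        return layer_fn(params_i, h)

    # One compiled loop body, so XLA can reuse buffers across iterations instead of
    # unrolling num_layers separate checkpoint regions.
    return jax.lax.fori_loop(0, num_layers, lambda i, h: one_layer(h, i), hidden)
```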
- test_jax_backend.py: extend test_gradient_checkpointing to verify gradients match - test_models_common.py: add common tests for Llama3/Qwen3 (output, hidden_states, edge cases)
Handle edge case where self.layers is empty to prevent IndexError. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Resolve conflicts in llama3.py and qwen3.py - Integrate LogitsProcessor from main - Move chunked logprobs computation to LogitsProcessor.compute_chunked_logprobs
- Add LogitsProcessor.compute_logprobs() that handles both chunked and non-chunked paths - Add _logits_to_logprobs() and _compute_chunked_logprobs() as private helpers - Simplify jax.py to single compute_logprobs call
- LogitsProcessor is now a standalone utility with three static methods: compute_logits(), compute_logprobs(), logits_to_logprobs() - Model forward() returns only hidden_states (removed logits computation) - Simplified CausalLMOutput: removed logits and lm_head fields - Generator uses LogitsProcessor for all logits/logprobs computation - Backend uses LogitsProcessor.compute_logprobs() with chunking - Updated tests to use new LogitsProcessor API Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Create CausalLMBase class with compute_logits/compute_logprobs methods - Models expose wrapper methods instead of direct LogitsProcessor access - Update generator and jax.py backend to use model methods - LogitsProcessor is now internal implementation detail Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Replace _has_train_unembed flag with _train_unembed_mask array - Check at runtime if any adapter in batch needs LoRA on lm_head - Use jax.lax.cond to choose chunked vs non-chunked path - Handle adapter reuse correctly (reset mask on delete) - Remove unused _use_chunked_loss flag Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
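A hedged sketch of the runtime dispatch described above (chunked_fn, dense_fn, and train_unembed_mask are placeholder names; jax.lax.cond requires both branches to return outputs of identical shape and dtype):

```python
import jax
import jax.numpy as jnp


def dispatch_logprob_path(hidden, lm_head_w, targets, train_unembed_mask, chunked_fn, dense_fn):
    """Pick the loss path at runtime based on a per-adapter mask."""
    # If any adapter in the batch trains the unembedding, LoRA must be applied to
    # lm_head, so fall back to the non-chunked path that uses the full logits.
    needs_dense = jnp.any(train_unembed_mask)
    return jax.lax.cond(
        needs_dense,
        lambda args: dense_fn(*args),
        lambda args: chunked_fn(*args),
        (hidden, lm_head_w, targets),
    )
```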
- Replace abstract property with __init__(lm_head) in base class - Subclasses explicitly call CausalLMBase.__init__(self, lm_head) - Fix test to support multiple adapters for mixed train_unembed test Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add _is_stacked_layer_param helper to distinguish stacked vs non-stacked paths - Update load_safetensors/save_safetensors to handle both formats - Add num_layers argument to load_safetensors calls - Use Auto axis types in test mesh to avoid sharding errors - Update KV cache assertions for stacked array format Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add KVCache.update() to stack list-based KV outputs from non-stacked models - Add _is_stacked_path() in lora.py to correctly index LoRA params These workarounds allow DeepSeekV3 to work with the new stacked layer format used by Qwen3/Llama3, without modifying the DeepSeekV3 model itself. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This reverts commit 801458b.
- Split DeepseekV3DecoderLayer into DenseDecoderLayer and MoEDecoderLayer - Use create_stacked_layers/forward_layers for both layer groups - Add _get_layer_group_info for HF weight loading with layer offsets - Update LoRA adapter indexing to handle dense_layers/moe_layers paths - Fix dtype preservation in MoE routing weights - Update tests for stacked adapter extraction This enables gradient checkpointing and unified forward pass for DeepSeekV3, matching the architecture used by Qwen3/Llama3. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
These methods were added to distinguish training/inference paths but are no longer needed with the unified forward_layers approach. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Rename test_eval_mode_uses_standard_path to test_kv_cache_with_checkpointing - Clarify dtype cast comment in DeepSeekV3 MoE routing Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1. Remove unused _is_layer_param function from tx/utils/models.py 2. Remove unused num_layers parameter from load_safetensors/save_safetensors 3. Add is_stacked_lora_path() shared utility for LoRA adapter indexing 4. Create tests/models/lora_test_utils.py with shared test helpers: - get_adapter_params, get_out_of_rank_params, verify_params_unchanged - get_moe_out_of_rank_params for MoE-specific rank handling 5. Update all test files to use shared utilities Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Update test_jax_backend.py to use stacked layer indexing: - layers.self_attn.q_proj instead of layers[0].self_attn.q_proj - Access adapter params with [layer_idx, adapter_idx] instead of [adapter_idx] Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Code Review
This pull request introduces a major refactoring to use stacked weights for model layers, which is a significant architectural improvement for performance in JAX. The changes are extensive, replacing nnx.List of layers with nnx.vmap-created stacked modules and using jax.lax.scan for an efficient forward pass. Key logic is well-encapsulated in the new tx/models/utils.py file. The refactoring is consistently applied across different model architectures (Llama3, Qwen3, DeepseekV3), and the test suite has been updated to reflect these changes, including the addition of a shared test utility module and new tests for gradient checkpointing. The code quality is high and the changes are well-executed.
The get_mean and reset_adapter methods assumed gradients had shape (num_adapters, ...), but stacked layers have shape (num_layers, num_adapters, ...). Use is_stacked_lora_path to detect and index correctly for each case. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
/gemini review
Code Review
This pull request introduces a significant and well-executed refactoring to implement per-layer gradient checkpointing and stacked decoder layers, aiming to reduce memory usage during training. The core logic is cleanly encapsulated in the new tx/models/utils.py, which provides a unified forward_layers function for both training and inference. The changes are consistently propagated across all models, with careful handling of the heterogeneous layers in DeepSeekV3. The utilities for LoRA parameter management and weight loading/saving are correctly generalized for the new stacked architecture. The test suite has been impressively updated to validate the correctness of this new approach, including checks for gradient matching. I have one minor suggestion to improve type safety. Overall, this is a high-quality and impactful contribution.
@@ -404,18 +408,13 @@ def __call__(

class DeepseekV3DecoderLayer(nnx.Module):
The __call__ method of this base class uses self.mlp, but mlp is not defined or hinted in the base class itself. This can be confusing for readers and may be flagged by static type checkers. Adding a class-level type hint for mlp would improve readability and maintainability.
For example:

class DeepseekV3DecoderLayer(nnx.Module):
    mlp: nnx.Module
    ...

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add parametrized test for is_stacked_lora_path covering stacked (layers, dense_layers, moe_layers) and non-stacked paths - Add roundtrip test for extract/insert_adapter_state with stacked layers - Add DeepSeekV3 gradient checkpointing test for split stacking Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
from tx.tinker.types import LoraConfig

def _get_sharding_spec(arr: jax.Array):
It would be good to get rid of this function (so we also don't depend on the jax.core.Tracer class, which looks more like an internal class).
The best way to handle this, I think, is to introduce a required sharding argument in the constructors of the LoRA{Layer} classes and then call nnx.with_partitioning(...) inside the classes. This way the sharding is always required and we can even get rid of the assert.
I'm happy to make a separate PR with this refactor (do let me know if you see any problems with it).
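A minimal sketch of the suggested pattern, assuming hypothetical names (LoRALinear, lora_a, lora_b) rather than the repo's actual LoRA classes:

```python
from flax import nnx
import jax.numpy as jnp


class LoRALinear(nnx.Module):
    """Sharding is passed to the constructor and attached via nnx.with_partitioning,
    so no runtime sharding inspection (or Tracer check) is needed."""

    def __init__(self, in_features: int, out_features: int, rank: int,
                 *, sharding: tuple, rngs: nnx.Rngs):
        init_a = nnx.with_partitioning(nnx.initializers.lecun_normal(), sharding)
        self.lora_a = nnx.Param(init_a(rngs.params(), (in_features, rank)))
        # lora_b would get its own partitioning spec in a full implementation.
        self.lora_b = nnx.Param(jnp.zeros((rank, out_features)))
```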
return hidden_states, updated_cache

class DeepseekV3DenseDecoderLayer(DeepseekV3DecoderLayer):
Can we simplify this a little and get rid of these classes? What I have in mind is to just have the class DeepseekV3DecoderLayer from above and do
class DeepseekV3DecoderLayer(nnx.Module):
    def __init__(self, mlp_cls: Callable[..., nnx.Module], config: DeepseekV3Config, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None:
        # ...
        self.mlp = mlp_cls(config, dtype=dtype, rngs=rngs)

and then, down below where the stacked layers are created, do

def create_dense_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer:
    return DeepseekV3DecoderLayer(DeepseekV3MLP, config, dtype=dtype, rngs=rngs)

if output_hidden_states:
    all_hidden_states.append(hidden_states)

# Merge KV caches from dense and MoE layers
This merits a KVCache.concatenate() function, I think; it could even be implemented generically:

@staticmethod
def concatenate(*kv_caches: "KVCache") -> "KVCache":
    return KVCache(
        keys=jnp.concatenate([kv_cache.keys for kv_cache in kv_caches]),
        values=jnp.concatenate([kv_cache.values for kv_cache in kv_caches]),
        cache_position=kv_caches[0].cache_position,
    )
The reasoning is that we try to keep this kind of complexity out of the modeling files as much as possible.
for layer_idx, layer in enumerate(self.layers):
    if output_hidden_states:
        all_hidden_states.append(hidden_states)

# Split KV cache for dense and MoE layers
Similar to the suggestion below, it would be nice to write this as something like

dense_kv_cache, moe_kv_cache = kv_cache.split(self.num_dense_layers, self.num_moe_layers)

If you wanted, you could probably implement this generically by lifting https://docs.jax.dev/en/latest/_autosummary/jax.numpy.split.html, something like

def split(kv_cache: KVCache, indices_or_sections):
    keys_list = jnp.split(kv_cache.keys, indices_or_sections, axis=0)
    values_list = jnp.split(kv_cache.values, indices_or_sections, axis=0)
    return [
        KVCache(keys=keys, values=values, cache_position=kv_cache.cache_position)
        for (keys, values) in zip(keys_list, values_list)
    ]

but I'm not sure if that's actually better. Otherwise just supporting the case of two splits is totally fine.
return None

def _adapter_index(is_stacked: bool, adapter_index: int):
I think we should change this function to
def get_adapter_idx(path: str, adapter_index: int):
    is_stacked = _is_stacked(path)
    return (slice(None), adapter_index) if is_stacked else (adapter_index,)

and then use it everywhere in the PR. I believe all the call sites of _adapter_index are already of this form, and there are lots more call sites where we can get rid of a pattern like

if is_stacked:
    # Process stacked weights
else:
    # Process unstacked weights

This can always be done as

idx = get_adapter_idx(path, adapter_index)
# Process weights with weights[idx, ...]

raise ValueError("The 'learning_rate' key must be provided in optimizer_args.")
def _lora_slice(is_stacked: bool, adapter_index: int, rank: int, is_lora_A: bool) -> tuple:
Do we actually need this function? Can't we just do things like
idx = get_adapter_idx(path, adapter_index)
p.at[idx, ..., :, :rank]
p.at[idx, ..., :rank, :]

if "layers" not in path_strs:
    return False
layers_idx = path_strs.index("layers")
if layers_idx + 1 < len(path_strs) and path_strs[layers_idx + 1].isdigit():
Do you know why this case is needed? If it isn't, we only need one method here and can always use the logic of is_stacked_lora_path.
My rationale for this question is, if we always stack the layers that are stackable, a case like ('model', 'layers', '0', 'self_attn', ...) should never happen, right?
Summary
Implement per-layer gradient checkpointing using jax.lax.scan with permanently stacked decoder layer weights via nnx.vmap. This reduces peak memory by a factor of ~num_layers during training while maintaining a unified code path for training and inference.
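A condensed sketch of how these pieces fit together, assuming simplified signatures (the real helpers in tx/models/utils.py also thread the KV cache, attention masks, and per-adapter state):

```python
from functools import partial

import jax
from flax import nnx


def create_stacked_layers(layer_cls, num_layers, *args, rngs: nnx.Rngs, **kwargs):
    """Build one module whose parameters carry a leading (num_layers, ...) axis."""
    @nnx.split_rngs(splits=num_layers)
    @partial(nnx.vmap, in_axes=(0,), out_axes=0)
    def make(rngs: nnx.Rngs):
        return layer_cls(*args, rngs=rngs, **kwargs)

    return make(rngs)


def forward_layers(stacked_layers, hidden, *, gradient_checkpointing=False):
    """Run all layers with jax.lax.scan over the stacked weight axis."""
    graphdef, state = nnx.split(stacked_layers)

    def body(h, layer_state):
        layer = nnx.merge(graphdef, layer_state)  # rebuild one layer from its slice
        return layer(h), None

    if gradient_checkpointing:
        body = jax.checkpoint(body)  # recompute each layer's activations in backward
    hidden, _ = jax.lax.scan(body, hidden, state)
    return hidden
```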
Key Changes
1. Per-layer Gradient Checkpointing
- jax.lax.scan with jax.checkpoint to recompute activations during the backward pass
- Enabled via gradient_checkpointing=True in the model config
2. Stacked Layer Weights
- nnx.vmap → parameters of shape (num_layers, ...)
- forward_layers() function for both training and inference
- Stacked KV cache of shape (num_layers, batch, seq, heads, dim)
3. DeepSeekV3 Split Stacking
- dense_layers for the initial layers, moe_layers for the MoE layers
Files Changed
- tx/models/utils.py - New: create_stacked_layers(), forward_layers()
- tx/models/{llama3,qwen3,deepseekv3}.py - Use stacked layers
- tx/layers/lora.py - Stacked LoRA indexing
- tx/utils/models.py - Stack/unstack for HF checkpoint compatibility
- tx/utils/generator.py - Stacked KV cache
- tx/tinker/backends/jax.py - Fix gradient accumulation for stacked params
Test plan