Conversation

Contributor

@raulchen raulchen commented Jan 30, 2026

Summary

Implement per-layer gradient checkpointing using jax.lax.scan, with decoder layer weights permanently stacked via nnx.vmap. This reduces peak memory during training by roughly a factor of num_layers while maintaining a unified code path for training and inference.

Key Changes

1. Per-layer Gradient Checkpointing

  • Use jax.lax.scan with jax.checkpoint to recompute activations during backward pass
  • XLA compiles ONE loop body and reuses buffers, unlike Python loops which unroll N separate checkpoint regions
  • Enable via gradient_checkpointing=True in the model config (see the sketch below)
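A minimal, self-contained sketch of the mechanism (illustrative names, not the PR's actual forward_layers): the layer body is wrapped in jax.checkpoint and driven by jax.lax.scan over weights stacked along a leading num_layers axis, so XLA traces the body once and recomputes each layer's activations during the backward pass.

import jax
import jax.numpy as jnp

def scan_layers(stacked_params, hidden, layer_fn, gradient_checkpointing=True):
    # Every leaf of stacked_params has a leading (num_layers, ...) axis.
    body = jax.checkpoint(layer_fn) if gradient_checkpointing else layer_fn

    def step(carry, params_i):
        # One compiled loop body; this layer's activations are recomputed
        # in the backward pass instead of being stored.
        return body(params_i, carry), None

    hidden, _ = jax.lax.scan(step, hidden, stacked_params)
    return hidden

# Toy usage: 4 "layers", each a plain matmul.
params = {"w": jnp.ones((4, 16, 16))}
x = jnp.ones((2, 16))
y = scan_layers(params, x, lambda p, h: jnp.tanh(h @ p["w"]))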

2. Stacked Layer Weights

  • Stack decoder layer weights at initialization using nnx.vmap → shape (num_layers, ...)
  • Eliminates runtime stacking overhead (weights already in stacked format)
  • Single forward_layers() function for both training and inference
  • KV cache uses the stacked format (num_layers, batch, seq, heads, dim); see the sketch below
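A hypothetical sketch of the stacking step (the PR's create_stacked_layers() in tx/models/utils.py may differ in details): nnx.split_rngs plus nnx.vmap build num_layers copies of a layer class whose parameters share a leading (num_layers, ...) axis.

from flax import nnx

def create_stacked_layers(layer_cls, config, num_layers, rngs: nnx.Rngs):
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def _make(rngs: nnx.Rngs):
        return layer_cls(config, rngs=rngs)

    # Every parameter of the returned module already has shape (num_layers, ...),
    # so nothing needs to be stacked at runtime.
    return _make(rngs)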

3. DeepSeekV3 Split Stacking

  • Handles heterogeneous layers (dense MLP vs MoE) with separate stacks
  • dense_layers for initial layers, moe_layers for MoE layers
  • KV caches are merged after the forward pass (illustrated in the sketch below)
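A toy, self-contained illustration of the split-stack flow (shapes and names are invented for the example, not taken from the DeepSeekV3 code): two homogeneous stacks are scanned back to back, and their per-layer outputs, standing in for KV cache entries, are concatenated along the leading layer axis afterwards.

import jax
import jax.numpy as jnp

def run_stack(stacked_w, hidden):
    def step(carry, w):
        out = jnp.tanh(carry @ w)
        return out, out          # second output stands in for this layer's KV entry
    return jax.lax.scan(step, hidden, stacked_w)

dense_w = jnp.stack([jnp.eye(8)] * 3)   # 3 dense layers
moe_w = jnp.stack([jnp.eye(8)] * 5)     # 5 MoE layers
x = jnp.ones((2, 8))

x, dense_kv = run_stack(dense_w, x)
x, moe_kv = run_stack(moe_w, x)
merged_kv = jnp.concatenate([dense_kv, moe_kv], axis=0)   # layer axis: 3 + 5 = 8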

Files Changed

  • tx/models/utils.py - New: create_stacked_layers(), forward_layers()
  • tx/models/{llama3,qwen3,deepseekv3}.py - Use stacked layers
  • tx/layers/lora.py - Stacked LoRA indexing
  • tx/utils/models.py - Stack/unstack for HF checkpoint compatibility
  • tx/utils/generator.py - Stacked KV cache
  • tx/tinker/backends/jax.py - Fix gradient accumulation for stacked params

Test plan

  • Forward outputs match with/without checkpointing
  • Gradients match with/without checkpointing
  • All model tests pass (37)
  • All tinker tests pass (19)
  • DeepSeekV3 EP=2 tests pass

raulchen and others added 30 commits January 20, 2026 18:55
Compute lm_head projection in chunks to avoid materializing the full
[B*T, V] logits tensor. Key changes:

- Add compute_logits flag to model.__call__ (skip lm_head when False)
- Add lm_head weight to CausalLMOutput for external computation
- Implement chunked logprobs with jax.lax.map (default chunk_size=1024)
- Add loss_chunk_size config option

Memory savings: O(B*T*V) -> O(chunk_size*V) for logits tensor.
For Qwen3-4B with V=151k, 8k seq: ~19GB -> ~300MB peak logits memory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…ze<=0

The chunked cross-entropy path computes logits via direct matmul with
lm_head weight, bypassing LoRA adapters. This is incorrect when
train_unembed=True since LoRA should be applied to lm_head.

Changes:
- Rename is_training to skip_logits for clarity
- Add _use_chunked_loss flag to backend
- Automatically switch to non-chunked mode when:
  - train_unembed=True (requires LoRA on lm_head)
  - loss_chunk_size <= 0 (config-based disable)
- Non-chunked path uses pre-computed logits with LoRA correctly applied
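A hedged sketch of the chunked-logprob computation described in these commits (function and argument names are assumptions, not the backend's API): tokens are flattened, padded to a multiple of the chunk size, and jax.lax.map applies the lm_head projection one chunk at a time, so only a (chunk_size, V) logits slab is live at once.

import jax
import jax.numpy as jnp

def chunked_logprobs(hidden, lm_head_w, target_ids, chunk_size=1024):
    # hidden: (N, D) flattened tokens; lm_head_w: (D, V); target_ids: (N,)
    def per_chunk(args):
        h, t = args
        logits = h @ lm_head_w                          # only (chunk_size, V) is live
        logz = jax.nn.logsumexp(logits, axis=-1)
        tgt = jnp.take_along_axis(logits, t[:, None], axis=-1)[:, 0]
        return tgt - logz                               # log p(target token)

    n, d = hidden.shape
    pad = (-n) % chunk_size
    h = jnp.pad(hidden, ((0, pad), (0, 0))).reshape(-1, chunk_size, d)
    t = jnp.pad(target_ids, (0, pad)).reshape(-1, chunk_size)
    out = jax.lax.map(per_chunk, (h, t))                # sequential over chunks
    return out.reshape(-1)[:n]

Note that this sketch projects with the raw lm_head weight, which is exactly why the commit above falls back to the non-chunked path when train_unembed=True.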
Recompute activations during backward to save memory. Only one layer's
activations are held at a time during backward pass, reducing peak
memory by ~num_layers factor.

- Add gradient_checkpointing config to ModelConfig
- Apply jax.checkpoint per-layer when is_training=True
- Rename compute_logits to is_training (controls both logits and checkpointing)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…euse

Add _forward_layers_checkpointed() using jax.lax.fori_loop so XLA compiles
ONE loop body and reuses buffers during backward recomputation. With a
Python loop, XLA unrolls N separate checkpoint regions and can't optimize
buffer reuse across them.

Only enabled when gradient_checkpointing=True. Without checkpointing,
activations are stored anyway, so fori_loop's buffer reuse doesn't help
and its weight stacking overhead makes it worse.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- test_jax_backend.py: extend test_gradient_checkpointing to verify gradients match
- test_models_common.py: add common tests for Llama3/Qwen3 (output, hidden_states, edge cases)
Handle edge case where self.layers is empty to prevent IndexError.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Resolve conflicts in llama3.py and qwen3.py
- Integrate LogitsProcessor from main
- Move chunked logprobs computation to LogitsProcessor.compute_chunked_logprobs
- Add LogitsProcessor.compute_logprobs() that handles both chunked and non-chunked paths
- Add _logits_to_logprobs() and _compute_chunked_logprobs() as private helpers
- Simplify jax.py to single compute_logprobs call
- LogitsProcessor is now a standalone utility with three static methods:
  compute_logits(), compute_logprobs(), logits_to_logprobs()
- Model forward() returns only hidden_states (removed logits computation)
- Simplified CausalLMOutput: removed logits and lm_head fields
- Generator uses LogitsProcessor for all logits/logprobs computation
- Backend uses LogitsProcessor.compute_logprobs() with chunking
- Updated tests to use new LogitsProcessor API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Create CausalLMBase class with compute_logits/compute_logprobs methods
- Models expose wrapper methods instead of direct LogitsProcessor access
- Update generator and jax.py backend to use model methods
- LogitsProcessor is now internal implementation detail

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Replace _has_train_unembed flag with _train_unembed_mask array
- Check at runtime if any adapter in batch needs LoRA on lm_head
- Use jax.lax.cond to choose chunked vs non-chunked path
- Handle adapter reuse correctly (reset mask on delete)
- Remove unused _use_chunked_loss flag

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
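A rough sketch of that runtime switch (shapes and names are assumptions): a per-adapter boolean mask is indexed with the batch's adapter ids, and jax.lax.cond selects either the full-logits path (needed when LoRA touches lm_head) or the chunked path.

import jax
import jax.numpy as jnp

def _token_logprobs(logits, targets):
    return jax.nn.log_softmax(logits)[jnp.arange(targets.shape[0]), targets]

def compute_logprobs(hidden, lm_head_w, targets, train_unembed_mask, adapter_ids,
                     chunk_size=1024):
    # True if any adapter used in this batch applies LoRA to lm_head.
    needs_full = jnp.any(train_unembed_mask[adapter_ids])

    def non_chunked(h, t):
        # Stand-in for the LoRA-aware path: materialize the full logits once.
        return _token_logprobs(h @ lm_head_w, t)

    def chunked(h, t):
        n = h.shape[0]  # assumed divisible by chunk_size for brevity
        h = h.reshape(-1, chunk_size, h.shape[-1])
        t = t.reshape(-1, chunk_size)
        out = jax.lax.map(
            lambda args: _token_logprobs(args[0] @ lm_head_w, args[1]), (h, t)
        )
        return out.reshape(n)

    # Both branches return the same shape/dtype, as lax.cond requires.
    return jax.lax.cond(needs_full, non_chunked, chunked, hidden, targets)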
- Replace abstract property with __init__(lm_head) in base class
- Subclasses explicitly call CausalLMBase.__init__(self, lm_head)
- Fix test to support multiple adapters for mixed train_unembed test

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
raulchen and others added 12 commits January 29, 2026 17:27
- Add _is_stacked_layer_param helper to distinguish stacked vs non-stacked paths
- Update load_safetensors/save_safetensors to handle both formats
- Add num_layers argument to load_safetensors calls
- Use Auto axis types in test mesh to avoid sharding errors
- Update KV cache assertions for stacked array format

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
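To make "handle both formats" concrete, here is a toy conversion between per-layer HF checkpoint keys and the stacked in-memory layout (the key pattern and helper names are hypothetical, not the actual tx/utils/models.py code):

import re
import jax.numpy as jnp

def stack_hf_weights(hf_weights: dict, num_layers: int) -> dict:
    # 'model.layers.{i}.rest' entries -> one (num_layers, ...) array per 'rest'.
    pattern = re.compile(r"^model\.layers\.(\d+)\.(.+)$")
    per_layer, out = {}, {}
    for name, w in hf_weights.items():
        m = pattern.match(name)
        if m:
            per_layer.setdefault(m.group(2), {})[int(m.group(1))] = w
        else:
            out[name] = w
    for rest, by_layer in per_layer.items():
        out["model.layers." + rest] = jnp.stack(
            [by_layer[i] for i in range(num_layers)], axis=0
        )
    return out

def unstack_for_save(params: dict) -> dict:
    # Inverse: split the leading layer axis back into per-layer entries.
    out = {}
    for name, w in params.items():
        if name.startswith("model.layers."):
            rest = name[len("model.layers."):]
            for i in range(w.shape[0]):
                out[f"model.layers.{i}.{rest}"] = w[i]
        else:
            out[name] = w
    return out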
- Add KVCache.update() to stack list-based KV outputs from non-stacked models
- Add _is_stacked_path() in lora.py to correctly index LoRA params

These workarounds allow DeepSeekV3 to work with the new stacked layer format
used by Qwen3/Llama3, without modifying the DeepSeekV3 model itself.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Split DeepseekV3DecoderLayer into DenseDecoderLayer and MoEDecoderLayer
- Use create_stacked_layers/forward_layers for both layer groups
- Add _get_layer_group_info for HF weight loading with layer offsets
- Update LoRA adapter indexing to handle dense_layers/moe_layers paths
- Fix dtype preservation in MoE routing weights
- Update tests for stacked adapter extraction

This enables gradient checkpointing and unified forward pass for DeepSeekV3,
matching the architecture used by Qwen3/Llama3.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
These methods were added to distinguish training/inference paths
but are no longer needed with the unified forward_layers approach.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Rename test_eval_mode_uses_standard_path to test_kv_cache_with_checkpointing
- Clarify dtype cast comment in DeepSeekV3 MoE routing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1. Remove unused _is_layer_param function from tx/utils/models.py
2. Remove unused num_layers parameter from load_safetensors/save_safetensors
3. Add is_stacked_lora_path() shared utility for LoRA adapter indexing
4. Create tests/models/lora_test_utils.py with shared test helpers:
   - get_adapter_params, get_out_of_rank_params, verify_params_unchanged
   - get_moe_out_of_rank_params for MoE-specific rank handling
5. Update all test files to use shared utilities

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Update test_jax_backend.py to use stacked layer indexing:
- layers.self_attn.q_proj instead of layers[0].self_attn.q_proj
- Access adapter params with [layer_idx, adapter_idx] instead of [adapter_idx]

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a major refactoring to use stacked weights for model layers, which is a significant architectural improvement for performance in JAX. The changes are extensive, replacing nnx.List of layers with nnx.vmap-created stacked modules and using jax.lax.scan for an efficient forward pass. Key logic is well-encapsulated in the new tx/models/utils.py file. The refactoring is consistently applied across different model architectures (Llama3, Qwen3, DeepseekV3), and the test suite has been updated to reflect these changes, including the addition of a shared test utility module and new tests for gradient checkpointing. The code quality is high and the changes are well-executed.

raulchen and others added 3 commits January 30, 2026 12:32
The get_mean and reset_adapter methods assumed gradients had shape
(num_adapters, ...), but stacked layers have shape (num_layers,
num_adapters, ...). Use is_stacked_lora_path to detect and index
correctly for each case.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@raulchen raulchen changed the title from "[WIP] Stack weights" to "[tx] Per-layer gradient checkpointing with stacked decoder layers" on Jan 30, 2026
@raulchen
Contributor Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant and well-executed refactoring to implement per-layer gradient checkpointing and stacked decoder layers, aiming to reduce memory usage during training. The core logic is cleanly encapsulated in the new tx/models/utils.py, which provides a unified forward_layers function for both training and inference. The changes are consistently propagated across all models, with careful handling of the heterogeneous layers in DeepSeekV3. The utilities for LoRA parameter management and weight loading/saving are correctly generalized for the new stacked architecture. The test suite has been impressively updated to validate the correctness of this new approach, including checks for gradient matching. I have one minor suggestion to improve type safety. Overall, this is a high-quality and impactful contribution.

@@ -404,18 +408,13 @@ def __call__(


class DeepseekV3DecoderLayer(nnx.Module):
Contributor


medium

The __call__ method of this base class uses self.mlp, but mlp is not defined or hinted in the base class itself. This can be confusing for readers and may be flagged by static type checkers. Adding a class-level type hint for mlp would improve readability and maintainability.

For example:

class DeepseekV3DecoderLayer(nnx.Module):
    mlp: nnx.Module
    ...

raulchen and others added 3 commits January 30, 2026 13:19
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add parametrized test for is_stacked_lora_path covering stacked
  (layers, dense_layers, moe_layers) and non-stacked paths
- Add roundtrip test for extract/insert_adapter_state with stacked layers
- Add DeepSeekV3 gradient checkpointing test for split stacking

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
from tx.tinker.types import LoraConfig


def _get_sharding_spec(arr: jax.Array):
Collaborator


It would be good to get rid of this function (so we also don't depend on the jax.core.Tracer class, which looks more like an internal class).

The best way to handle this I think is to introduce a required sharding argument in the constructors of the classes LoRA{Layer} and then call nnx.with_partitioning(...) inside the classes. This way the sharding is always required and we can even get rid of the assert.

I'm happy to make a separate PR with the refactor (do let me know if you see any problems with this).
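A sketch of what that refactor could look like (constructor shape and names are only illustrative): the LoRA layer takes a required sharding argument and wraps its initializers with nnx.with_partitioning, so partitioning metadata is attached at construction time and the Tracer-based _get_sharding_spec is no longer needed.

from flax import nnx

class LoRALinear(nnx.Module):
    def __init__(self, in_dim: int, out_dim: int, rank: int, *,
                 sharding: tuple, rngs: nnx.Rngs):
        # In practice lora_a and lora_b would likely need different specs;
        # a single tuple is used here only to keep the sketch short.
        a_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), sharding)
        b_init = nnx.with_partitioning(nnx.initializers.zeros_init(), sharding)
        self.lora_a = nnx.Param(a_init(rngs.params(), (in_dim, rank)))
        self.lora_b = nnx.Param(b_init(rngs.params(), (rank, out_dim)))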

return hidden_states, updated_cache


class DeepseekV3DenseDecoderLayer(DeepseekV3DecoderLayer):
Collaborator


Can we simplify this a little and get rid of these classes? What I have in mind is to just have the class DeepseekV3DecoderLayer from above and do

class DeepseekV3DecoderLayer(nnx.Module):
    def __init__(self, mlp_cls: Callable[..., nnx.Module], config: DeepseekV3Config,
                 *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None:
        # ...
        self.mlp = mlp_cls(config, dtype=dtype, rngs=rngs)

and then down below where the stacked layers are created, do

def create_dense_layer(rngs: nnx.Rngs) -> DeepseekV3DecoderLayer:
    return DeepseekV3DecoderLayer(DeepseekV3MLP, config, dtype=dtype, rngs=rngs)

if output_hidden_states:
    all_hidden_states.append(hidden_states)

# Merge KV caches from dense and MoE layers
Collaborator


This merits a KVCache.concatenate() function, I think. It could even be implemented generically like

@staticmethod
def concatenate(*kv_caches: KVCache) -> KVCache:
    return KVCache(
        keys=jnp.concatenate([kv_cache.keys for kv_cache in kv_caches], axis=0),
        values=jnp.concatenate([kv_cache.values for kv_cache in kv_caches], axis=0),
        cache_position=kv_caches[0].cache_position,
    )

The reasoning is that we try to keep this kind of complexity out of the modeling files as much as possible.

for layer_idx, layer in enumerate(self.layers):
    if output_hidden_states:
        all_hidden_states.append(hidden_states)
# Split KV cache for dense and MoE layers
Collaborator


Similar to the below, it would be nice to write this as something like

dense_kv_cache, moe_kv_cache = kv_cache.split(self.num_dense_layers, self.num_moe_layers)

If you wanted, you could probably implement this generically by lifting https://docs.jax.dev/en/latest/_autosummary/jax.numpy.split.html, something like

def split(kv_cache: KVCache, indices_or_sections):
    keys_list = jnp.split(kv_cache.keys, indices_or_sections, axis=0)
    values_list = jnp.split(kv_cache.values, indices_or_sections, axis=0)
    return [
        KVCache(keys=keys, values=values, cache_position=kv_cache.cache_position)
        for keys, values in zip(keys_list, values_list)
    ]

but not sure if that's actually better. Otherwise just supporting the case of two splits is totally fine.

return None


def _adapter_index(is_stacked: bool, adapter_index: int):
Collaborator

@pcmoritz pcmoritz Jan 30, 2026


I think we should change this function to

def get_adapter_idx(path: str, adapter_index: int):
    is_stacked = _is_stacked(path)
    return (slice(None), adapter_index) if is_stacked else (adapter_index,)

and then use it everywhere in the PR. I believe all the call sites of _adapter_index are already of this form, and there are lots more call sites where we can get rid of a pattern like

if is_stacked:
     # Process stacked weights
else:
     # Process unstacked weights

This can always be done like

idx = get_adapter_idx(path, adapter_index)
# Process weights with weights[idx]

raise ValueError("The 'learning_rate' key must be provided in optimizer_args.")


def _lora_slice(is_stacked: bool, adapter_index: int, rank: int, is_lora_A: bool) -> tuple:
Collaborator


Do we actually need this function? Can't we just do things like

idx = get_adapter_idx(path, adapter_index)
p.at[idx, ..., :, :rank]
p.at[idx, ..., :rank, :]

if "layers" not in path_strs:
return False
layers_idx = path_strs.index("layers")
if layers_idx + 1 < len(path_strs) and path_strs[layers_idx + 1].isdigit():
Collaborator


Do you know why this case is needed? If it isn't, we only need one method here and can always use the logic of is_stacked_lora_path.

My rationale for this question is, if we always stack the layers that are stackable, a case like ('model', 'layers', '0', 'self_attn', ...) should never happen, right?
