Conversation

@tanmaysachan
Contributor

tanmaysachan commented Jan 29, 2026

Adds the GLM4 MoE and non-MoE variants.
Most of the code is borrowed from the DeepseekV3 implementation; the MoE and non-MoE variants live under a common class abstraction borrowed from Qwen3.

  • Parity
  • LoRA GPU tests
  • End-to-end training

Tested end-to-end training on an A100.

GPU tests:

rootdir: /workspace/SkyRL/skyrl-tx
configfile: pyproject.toml
plugins: jaxtyping-0.3.6, anyio-4.12.1, forked-1.6.0
collected 2 items

tests/models/test_glm4_lora_training.py::test_lora_training_moe_rank_normalized The argument trust_remote_code is to be used with Auto classes. It has no effect here and is ignored.
You are using a model of type glm4_moe to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
Fetching 1 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 10866.07it/s]
Step 0: loss = 11.9481
Step 1: loss = 11.9478
Step 2: loss = 11.9471
Step 3: loss = 11.9459
Step 4: loss = 11.9446
Step 5: loss = 11.9432
Step 6: loss = 11.9411
Step 7: loss = 11.9400
Step 8: loss = 11.9383
Step 9: loss = 11.9367
PASSED
tests/models/test_glm4_lora_training.py::test_lora_training_high_rank The argument trust_remote_code is to be used with Auto classes. It has no effect here and is ignored.
You are using a model of type glm4_moe to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
Fetching 1 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4104.02it/s]
Step 0: loss = 11.9481
Step 1: loss = 11.9478
Step 2: loss = 11.9471
Step 3: loss = 11.9459
Step 4: loss = 11.9446
Step 5: loss = 11.9432
Step 6: loss = 11.9411
Step 7: loss = 11.9400
Step 8: loss = 11.9383
Step 9: loss = 11.9367
PASSED

========================================================== 2 passed in 186.46s (0:03:06) ==========================================================

@gemini-code-assist
Contributor

gemini-code-assist bot left a comment


Code Review

This pull request introduces support for GLM4 models, including both dense and Mixture-of-Experts (MoE) variants. The implementation is well-structured, borrowing from existing model patterns in the repository. The changes include the core model logic, configuration updates, and a comprehensive set of tests for parity with the HuggingFace implementation. My review focuses on a few minor opportunities to improve code clarity and remove redundancies. Overall, this is a solid contribution.


# Verify dense model uses Glm4MLP in all layers
for layer in dense_model.model.layers:
    from tx.models.glm4 import Glm4MLP

medium

Importing inside a function (and in this case, a loop) is inefficient and against Python style guidelines (PEP 8). A similar local import is present on line 188. These should be moved to the top of the file. You can add Glm4MLP to the existing import on line 13 and remove both local import statements.
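A minimal sketch of the suggested change (the attribute name mlp and the surrounding test structure are assumptions here, not the PR's actual lines):

    # Top of tests/models/test_glm4_lora_training.py, merged into the existing tx.models.glm4 import
    from tx.models.glm4 import Glm4MLP

    # Later in the test, the loop no longer needs a local import
    for layer in dense_model.model.layers:
        assert isinstance(layer.mlp, Glm4MLP)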

        self.e_score_correction_bias = nnx.Variable(jnp.zeros(config.n_routed_experts, dtype=jnp.float32))

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        hidden_states = hidden_states.reshape(-1, self.config.hidden_size)

medium

The input hidden_states to this function is already flattened to (num_tokens, hidden_size) by the caller in Glm4MoE.__call__ (line 343). This reshape operation is therefore redundant and can be removed for a minor performance improvement and cleaner code.
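A quick illustration of why the reshape is a no-op when the input is already flattened:

    import jax.numpy as jnp

    hidden_size = 4
    x = jnp.ones((6, hidden_size))                      # already (num_tokens, hidden_size)
    assert x.reshape(-1, hidden_size).shape == x.shape  # the reshape changes nothing and can be dropped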

Comment on lines +347 to +349
        shared_output = self.shared_experts(
            hidden_states_flat.reshape(batch_size, seq_len, hidden_size), adapter_indices
        ).reshape(-1, hidden_size)

medium

The expression hidden_states_flat.reshape(batch_size, seq_len, hidden_size) is equivalent to the original hidden_states tensor defined on line 335. You can simplify this by passing hidden_states directly to self.shared_experts, which improves readability.

        shared_output = self.shared_experts(hidden_states, adapter_indices).reshape(-1, hidden_size)

@pcmoritz
Collaborator

In order to implement zai-org/GLM-4.7-Flash, shouldn't we just reuse the DeepseekV3 implementation? I think they are the same architecture (see e.g. https://huggingface.co/zai-org/GLM-4.7-Flash).

I just tried the following diff

diff --git a/skyrl-tx/tx/models/deepseekv3.py b/skyrl-tx/tx/models/deepseekv3.py
index a2e48abd..30f03755 100644
--- a/skyrl-tx/tx/models/deepseekv3.py
+++ b/skyrl-tx/tx/models/deepseekv3.py
@@ -559,3 +559,5 @@ class DeepseekV3ForCausalLM(nnx.Module, ModelForCausalLM, GeneratorMixin, Logits
             kv_cache=outputs.kv_cache,
             hidden_states=outputs.hidden_states,
         )
+
+Glm4MoeLiteForCausalLM = DeepseekV3ForCausalLM
\ No newline at end of file

and with

uv run --extra gpu --extra tinker -m tx.tinker.api --base-model zai-org/GLM-4.7-Flash  --backend-config '{"max_lora_adapters": 2, "max_lora_rank": 1, "tensor_parallel_size": 8, "train_micro_batch_size": 1, "shard_attention_heads": false}'

it seems to load. Ideally we would use expert_parallel_size here, but that doesn't currently work since deepseekv3.py doesn't implement the ep axis. I'd say let's first get that working well before implementing the non-flash GLM architecture; what do you think?

My main motivation for suggesting zai-org/GLM-4.7-Flash was to find a good small model with the DeepSeek architecture, so people can use the DeepSeek code in a smaller setting like 8xH100 :)
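For context on the ep axis mentioned above, here is a purely hypothetical JAX sketch of what an expert-parallel mesh axis could look like; the axis names, sizes, and weight layout are assumptions, not skyrl-tx's actual setup:

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Assumed layout: 8 devices arranged as (ep=2, tp=4).
    devices = np.array(jax.devices()[:8]).reshape(2, 4)
    mesh = Mesh(devices, axis_names=("ep", "tp"))

    # Expert weights of shape (num_experts, hidden, intermediate) would shard the expert
    # dimension over "ep" and the intermediate dimension over "tp".
    expert_sharding = NamedSharding(mesh, P("ep", None, "tp"))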

@tanmaysachan
Contributor Author

tanmaysachan commented Jan 30, 2026

@pcmoritz Oh okay.

GLM uses a different attention (non-MLA) and also has a non-MoE variant, based on the HuggingFace implementation.

Should these variations live in the deepseek code itself, based on the fields available in the config?

Using this as a reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4/modeling_glm4.py
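As a sketch of that config-based dispatch (a hypothetical factory, not the PR's code; field names like n_routed_experts and first_k_dense_replace are assumed to be present in the MoE config, and the constructors are simplified):

    from tx.models.glm4 import Glm4MLP, Glm4MoE  # classes from this PR; constructor args simplified below

    def make_mlp(config, layer_idx: int):
        # Dense configs have no routed experts; MoE configs keep the first
        # first_k_dense_replace layers dense, like DeepseekV3.
        has_experts = getattr(config, "n_routed_experts", None) is not None
        first_moe_layer = getattr(config, "first_k_dense_replace", 0)
        if has_experts and layer_idx >= first_moe_layer:
            return Glm4MoE(config)
        return Glm4MLP(config)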

Edit: Made the mistake of assuming Flash was just a smaller version; it turns out it's a different architecture as well 😅
Using the MoE Lite as the reference now, which matches DeepSeek's exactly: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py

@tanmaysachan
Contributor Author

I'll keep this PR separate in that case and get the ep axis working with deepseek

@pcmoritz
Collaborator

pcmoritz commented Jan 30, 2026

That (making a separate PR for fixing the ep axis) sounds great to me, thank you! With the config and changes I posted above, I was able to run

uv run --with Pillow --with wandb==0.24.0 sl_loop.py base_url=http://localhost:8000/ model_name="zai-org/GLM-4.7-Flash" lora_rank=1 max_length=512

btw, this was with this branch of the tinker-cookbook: https://github.com/pcmoritz/tinker-cookbook/tree/glm-4 (which might not be fully correct; I need to check it some more).

Once we get native expert parallelism, I think it will be even faster (based on my experience with Qwen3 MoE), so that will be exciting.

pcmoritz added the tx label on Jan 30, 2026