[tx] Add GLM4.7 #989
base: main
Conversation
Code Review
This pull request introduces support for GLM4 models, including both dense and Mixture-of-Experts (MoE) variants. The implementation is well-structured, borrowing from existing model patterns in the repository. The changes include the core model logic, configuration updates, and a comprehensive set of tests for parity with the HuggingFace implementation. My review focuses on a few minor opportunities to improve code clarity and remove redundancies. Overall, this is a solid contribution.
skyrl-tx/tests/models/test_glm4.py (outdated)
```python
# Verify dense model uses Glm4MLP in all layers
for layer in dense_model.model.layers:
    from tx.models.glm4 import Glm4MLP
```
Importing inside a function (and in this case, a loop) is inefficient and against Python style guidelines (PEP 8). A similar local import is present on line 188. These should be moved to the top of the file. You can add Glm4MLP to the existing import on line 13 and remove both local import statements.
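A minimal sketch of the suggested change, assuming the check lives in a small helper in the test file and that each decoder layer exposes its feed-forward block as `layer.mlp` (the helper name and the attribute are illustrative assumptions, not taken from the actual file):

```python
# Top of skyrl-tx/tests/models/test_glm4.py: add Glm4MLP to the existing import
# instead of importing it inside the loop.
from tx.models.glm4 import Glm4MLP


def assert_all_layers_dense(dense_model):
    # Verify the dense model uses Glm4MLP in all layers, with no local import.
    for layer in dense_model.model.layers:
        assert isinstance(layer.mlp, Glm4MLP)
```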
```python
self.e_score_correction_bias = nnx.Variable(jnp.zeros(config.n_routed_experts, dtype=jnp.float32))

def __call__(self, hidden_states: jax.Array) -> jax.Array:
    hidden_states = hidden_states.reshape(-1, self.config.hidden_size)
```
```python
shared_output = self.shared_experts(
    hidden_states_flat.reshape(batch_size, seq_len, hidden_size), adapter_indices
).reshape(-1, hidden_size)
```
The expression hidden_states_flat.reshape(batch_size, seq_len, hidden_size) is equivalent to the original hidden_states tensor defined on line 335. You can simplify this by passing hidden_states directly to self.shared_experts, which improves readability.
```python
shared_output = self.shared_experts(hidden_states, adapter_indices).reshape(-1, hidden_size)
```
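As a quick sanity check of the reviewer's claim, here is a small standalone snippet (shapes chosen arbitrarily for illustration) showing that flattening and reshaping back reproduces the original tensor, so the round-trip reshape is redundant:

```python
import jax.numpy as jnp

batch_size, seq_len, hidden_size = 2, 4, 8
hidden_states = jnp.arange(batch_size * seq_len * hidden_size, dtype=jnp.float32).reshape(
    batch_size, seq_len, hidden_size
)
hidden_states_flat = hidden_states.reshape(-1, hidden_size)

# Reshaping the flattened tensor back yields exactly the original hidden_states,
# so the original tensor can be passed to the shared experts directly.
assert jnp.array_equal(
    hidden_states_flat.reshape(batch_size, seq_len, hidden_size), hidden_states
)
```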
In order to implement this, I just tried a diff and with it the model seems to load. Ideally we reuse the existing DeepseekV3 implementation; my main motivation when suggesting to implement it that way is to avoid duplicating the model code.
@pcmoritz Oh okay. GLM uses a different attention (non-MLA) and has a non-MoE variant as well, based on the HuggingFace implementation. Should these variations live in the DeepSeek code itself, switched on fields available in the config? Using this as reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4/modeling_glm4.py

Edit: I made the mistake of assuming flash was just a smaller version; it turns out to be a different architecture as well 😅
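As a hedged sketch of what config-based dispatch could look like, assuming fields named like those in the HuggingFace GLM4 MoE configuration (`n_routed_experts`, `first_k_dense_replace`); whether the tx config class exposes the same fields is an assumption:

```python
def is_moe_layer(config, layer_idx: int) -> bool:
    # Dense GLM4 configs have no routed experts; MoE configs keep the first
    # `first_k_dense_replace` layers as dense MLPs (field names assumed from the
    # HuggingFace GLM4 MoE configuration, not verified against tx's config class).
    n_routed_experts = getattr(config, "n_routed_experts", None)
    if n_routed_experts is None:
        return False
    return layer_idx >= getattr(config, "first_k_dense_replace", 0)


# A decoder layer could then pick its MLP block accordingly, e.g.:
#   self.mlp = Glm4MoE(config) if is_moe_layer(config, layer_idx) else Glm4MLP(config)
```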
I'll keep this PR separate in that case, and get the ep axis working with DeepSeek.
That (making a separate PR for fixing the ep axis) sounds great to me, thank you! With the config and changes I posted above, I was able to run it, by the way, with this branch of the tinker-cookbook: https://github.com/pcmoritz/tinker-cookbook/tree/glm-4 (which might not be fully correct; I need to check it some more). Once we get native expert parallelism it will be even faster, I think (based on my experience with Qwen3 MoE), so that will be exciting.
Adds the GLM4 MoE and non-MoE variants.
Most of the code is borrowed from the DeepseekV3 implementation; the MoE and non-MoE variants sit under a common class abstraction borrowed from Qwen3.
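A self-contained toy illustrating the kind of shared abstraction described: one decoder-layer class that selects its MLP block per layer, so dense and MoE variants share the same code. All class names and shapes below are placeholders for illustration, not the actual tx classes:

```python
import jax
import jax.numpy as jnp
from flax import nnx


class DenseMLP(nnx.Module):
    def __init__(self, hidden: int, *, rngs: nnx.Rngs):
        self.proj = nnx.Linear(hidden, hidden, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.proj(x)


class ToyMoE(nnx.Module):
    def __init__(self, hidden: int, n_experts: int, *, rngs: nnx.Rngs):
        self.experts = [nnx.Linear(hidden, hidden, rngs=rngs) for _ in range(n_experts)]

    def __call__(self, x: jax.Array) -> jax.Array:
        # Trivial "routing": average all experts (a real MoE routes per token).
        return sum(expert(x) for expert in self.experts) / len(self.experts)


class ToyDecoderLayer(nnx.Module):
    def __init__(self, hidden: int, use_moe: bool, *, rngs: nnx.Rngs):
        # The same layer class serves both variants; only the MLP block differs.
        self.mlp = ToyMoE(hidden, 4, rngs=rngs) if use_moe else DenseMLP(hidden, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        return x + self.mlp(x)


layer = ToyDecoderLayer(8, use_moe=True, rngs=nnx.Rngs(0))
print(layer(jnp.ones((2, 4, 8))).shape)  # (2, 4, 8)
```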
Tested end-to-end training on an A100.
GPU tests: