[JAX] Add an MoE Block (Layer) that compounds router, permutation, groupedGEMM and communication #2912
tdophung wants to merge 10 commits into NVIDIA:main from
Conversation
Greptile Summary

Confidence Score: 4/5. Safe to merge for the no-EP and default align_size paths; the align_size>0 code path remains blocked by xfail tests and an unresolved backward-pass issue from a prior review. The core MoEBlock logic, ragged-A2A protocol, and pure-JAX dispatch/combine are well-reasoned and thoroughly tested. The C++ assertion relaxations are conservative and consistent. The main concern carried from a prior review is the hardcoded align_size=128 in the token_combine backward, which would produce wrong gradients once align_size>0 is enabled end-to-end; that path is currently gated behind xfail tests and is not addressed here.

Important Files Changed
- transformer_engine/jax/permutation.py — backward rule for the padded token_combine uses a hardcoded alignment that must match the forward-pass value
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["__call__ [B,S,H]"] --> B["_gate: inputs @ gate_kernel"]
    B --> C["_route_topk: fused_topk_with_score_function"]
    C --> D{expert_parallelism_axis?}
    D -- None --> E[_forward_no_ep]
    D -- set --> F[_forward_a2a_ep / shard_map]
    E --> E1[tokens_per_expert = sum routing_map]
    E1 --> E2[_compute_aux_loss optional]
    E1 --> E3[_global_permute]
    E3 --> E4[_expert_ffn: grouped_dense x3 + act]
    E4 --> E5[_global_combine]
    F --> F1[_a2a_body per EP shard]
    F1 --> F2[all_gather expert_bias]
    F2 --> F3[_route_topk local]
    F3 --> F4[aux: all_gather logits + re-run topk]
    F3 --> F5[_global_permute]
    F5 --> F6[all_gather group_sizes]
    F6 --> F7[compute_ragged_a2a_params]
    F7 --> F8[ragged_all_to_all fwd]
    F8 --> F9[local_permute_after_a2a]
    F9 --> F10[_expert_ffn E/num_ep groups]
    F10 --> F11[local_unpermute_before_a2a]
    F11 --> F12[compute_reverse_ragged_a2a_params]
    F12 --> F13[ragged_all_to_all rev]
    F13 --> F14[_global_combine]
    E5 --> G[return output aux_loss]
    F14 --> G
```
Reviews (4): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
Signed-off-by: tdophung <tdophung@nvidia.com>
…ody single GPU vs. multi GPU Signed-off-by: tdophung <tdophung@nvidia.com>
…e and single device initial params in the MoEBlock. Tests should pass now Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
```python
def _compute_aux_loss(
    self,
    logits_2d: jnp.ndarray,
) -> Optional[jnp.ndarray]:
    """Compute the MoE auxiliary load-balancing loss.

    The score-for-aux kernel has no data dependency on the main
    routing kernel, so XLA can overlap them on the GPU.

    ``logits_2d`` should be the *full* logits tensor over the global
    token batch -- under EP the caller is responsible for
    :func:`jax.lax.all_gather` ing the logits before calling this so
    the aux_loss formula
    ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])``
    sees the global ``T`` and the global ``tokens_per_expert``.
    """
    if self.aux_loss_coeff <= 0.0:
        return None
    aux_scores, aux_routing_map = fused_topk_with_score_function(
        logits_2d.astype(jnp.float32),
        topk=self.num_experts_per_tok,
        score_function=self.score_function,
        compute_aux_scores=True,
    )
    aux_tokens_per_expert = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0)
    return fused_moe_aux_loss(
        aux_scores.astype(jnp.float32),
        aux_tokens_per_expert,
        topk=self.num_experts_per_tok,
        coeff=self.aux_loss_coeff,
    )
```
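The aux-loss formula in the docstring can be checked by hand on a tiny batch. The sketch below evaluates `loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])` with hand-made probabilities and routing (illustrative values only, not TE's fused kernel):

```python
import jax.numpy as jnp

# Tiny worked example of the docstring's aux-loss formula:
# loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t, i]) * tokens[i])
E, k, coeff = 4, 2, 0.01
probs = jnp.array([[0.4, 0.3, 0.2, 0.1],
                   [0.1, 0.2, 0.3, 0.4]])   # [T, E] routing probabilities
routing_map = jnp.array([[1, 1, 0, 0],
                         [0, 0, 1, 1]])     # top-2 expert selections per token
T = probs.shape[0]
tokens_per_expert = routing_map.sum(axis=0)  # [1, 1, 1, 1]
loss = (E * coeff / (k * T ** 2)) * jnp.sum(probs.sum(axis=0) * tokens_per_expert)
```

With this perfectly balanced routing the loss evaluates to `E * coeff / (2 * k)`, which is its minimum for a fixed total probability mass.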
Aux-loss tokens_per_expert is inconsistent with actual grouped-topk routing
When num_groups > 0 and group_topk > 0 (DeepSeek-style routing), fused_topk_with_score_function(..., compute_aux_scores=True) intentionally ignores those parameters and runs a clean standard top-k. The returned aux_routing_map therefore reflects different expert selections than the actual routing_map produced by _route_topk, causing aux_tokens_per_expert = sum(aux_routing_map, axis=0) to count a different token–expert distribution. Any user who combines num_groups > 0 + group_topk > 0 + aux_loss_coeff > 0 silently trains with a wrong auxiliary objective. The existing test_group_topk_deepseek test does not catch this because it leaves aux_loss_coeff at its default of 0.0.
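The mismatch is easy to see with a toy routing: grouped top-k first restricts each token to its best groups, then picks top-k inside them, so it can choose different experts than a plain top-k over all logits. The sketch below is a pure-JAX illustration (hypothetical helpers, not TE's `fused_topk_with_score_function`):

```python
import jax.numpy as jnp

def plain_topk_map(logits, k):
    # One-hot routing map of the global top-k experts per token.
    t = logits.shape[0]
    idx = jnp.argsort(logits, axis=-1)[:, -k:]
    return jnp.zeros(logits.shape, dtype=jnp.int32).at[jnp.arange(t)[:, None], idx].set(1)

def grouped_topk_map(logits, k, num_groups, group_topk):
    # DeepSeek-style: keep only the best `group_topk` groups, then top-k inside them.
    t, e = logits.shape
    g = logits.reshape(t, num_groups, e // num_groups)
    top_groups = jnp.argsort(g.max(axis=-1), axis=-1)[:, -group_topk:]
    mask = jnp.zeros((t, num_groups), dtype=bool).at[jnp.arange(t)[:, None], top_groups].set(True)
    masked = jnp.where(mask[:, :, None], g, -jnp.inf).reshape(t, e)
    return plain_topk_map(masked, k)

# Experts {0, 1} form group 0 and {2, 3} form group 1.
logits = jnp.array([[3.0, 0.1, 2.9, 0.2]])
plain = plain_topk_map(logits, k=2)                                  # selects experts 0 and 2
grouped = grouped_topk_map(logits, k=2, num_groups=2, group_topk=1)  # selects experts 0 and 1
```

Summing each map over tokens gives different `tokens_per_expert`, which is exactly the inconsistency described above when the aux path runs a clean top-k but the real routing is grouped.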
```python
sharded_variables = jax.jit(sharded_block.init, out_shardings=out_shardings)(
    init_key, inputs
)
(sharded_loss, (sharded_output, sharded_aux)), sharded_grads = jax.value_and_grad(
```
SR: Do we want to jax.jit this function before calling it?
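For reference, the suggestion amounts to wrapping the `value_and_grad` transform in `jax.jit` so forward and backward compile once and stay fused. A minimal sketch with an illustrative loss (not the test's actual loss function):

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x):
    # Toy quadratic loss standing in for the real sharded loss.
    return jnp.sum((x @ w) ** 2)

# jit the value_and_grad-wrapped function before calling it.
grad_step = jax.jit(jax.value_and_grad(loss_fn))

w = jnp.eye(2)
x = jnp.ones((1, 2))
loss, grads = grad_step(w, x)
```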
```python
assert aux_loss is None, "aux_loss should be None when aux_loss_coeff=0"


@pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"])
def test_backward_grad(self, permutation_backend):
```
nit SR: rename to test_backward_grad_is_finite_and_nonzero or something similar to indicate this test doesn't compare with a reference impl
```python
for name in ("gate_kernel", "wi_0", "wi_1", "wo"):
    g_pj = _unwrap_partitioned(grads_pj["params"][name])
    g_tr = _unwrap_partitioned(grads_tr["params"][name])
    assert jnp.allclose(g_pj, g_tr, atol=1e-1, rtol=1e-1), (
```
1e-1 is a pretty high tolerance for most of our tests. What error values of atol and rtol do you typically get from these tests and is that error difference expected between jax/triton backends?
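One way to answer this is to report the observed error rather than only pass/fail at a fixed tolerance. A small hypothetical helper (name and form are illustrative) that could be logged alongside the assertion:

```python
import jax.numpy as jnp

def max_errors(a, b):
    # Returns (max absolute error, max relative error) between two arrays,
    # useful for picking a justified atol/rtol instead of a blanket 1e-1.
    a = jnp.asarray(a, jnp.float32)
    b = jnp.asarray(b, jnp.float32)
    abs_err = jnp.max(jnp.abs(a - b))
    rel_err = jnp.max(jnp.abs(a - b) / (jnp.abs(b) + 1e-12))
    return float(abs_err), float(rel_err)

a_err, r_err = max_errors(jnp.array([1.0, 2.0]), jnp.array([1.0, 2.1]))
```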
```python
@pytest.mark.xfail(
    reason=(
        "TE grouped_dense FFI asserts sum(group_sizes) == M at "
```
This assertion is only for the V1 grouped GEMM. For the V2 grouped GEMM, sum(group_sizes) < M is supported. Can you try the following?
- Enforce the V2 grouped GEMM with NVTE_JAX_ENFORCE_V2_GROUPED_GEMM=1, similar to what we do in TransformerEngine/tests/jax/utils.py, line 1681 at 4b6923d.
- Remove the functools cache on this function so it can change.
- Run this test within a try/except. If it runs, great. If the except clause catches a runtime error containing the string here, then pytest.skip("V2 grouped gemm is not supported").
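A sketch of that suggestion is below. The env-var name comes from the comment; the matched substring `"sum(group_sizes)"` and the wrapper name are assumptions for illustration:

```python
import os
import pytest

# Force the V2 grouped GEMM path, per the review comment.
os.environ["NVTE_JAX_ENFORCE_V2_GROUPED_GEMM"] = "1"

def run_or_skip_on_v1(test_fn):
    # Run the test body; skip (rather than fail) if the V1 FFI assertion fires.
    try:
        test_fn()
    except RuntimeError as err:
        if "sum(group_sizes)" in str(err):  # assumed error-string match
            pytest.skip("V2 grouped gemm is not supported")
        raise
```

Unrelated `RuntimeError`s are re-raised, so only the known V1 assertion turns into a skip.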
```
* Fused, Triton-backed ``token_dispatch`` / ``token_combine`` - uses the
  Triton kernels in ``transformer_engine.jax.triton_extensions.permutation``.
* Unfused, pure-JAX ``unfused_token_dispatch`` / ``unfused_token_combine`` -
```
I feel like "unfused" implies slower, when in practice this approach is faster, at least in MaxText. Would "triton" or "pure_jax", like we have in the tests, fit better? What do you think?
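For context, the "pure_jax" dispatch/combine pair is essentially an argsort-based permutation and its inverse. A minimal sketch in that spirit (illustrative only, not the `unfused_token_dispatch` implementation):

```python
import jax.numpy as jnp

def dispatch(tokens, expert_ids):
    # jnp.argsort is stable by default, so token order is preserved within an expert.
    order = jnp.argsort(expert_ids)
    return tokens[order], order

def combine(permuted, order):
    # Applying the inverse permutation restores the original token order.
    inverse = jnp.argsort(order)
    return permuted[inverse]

tokens = jnp.arange(8.0).reshape(4, 2)   # 4 tokens, hidden size 2
expert_ids = jnp.array([1, 0, 1, 0])     # each token's chosen expert
permuted, order = dispatch(tokens, expert_ids)
restored = combine(permuted, order)
```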
```python
inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes)

_, _, hidden_size = inputs.shape
params = self._make_params(hidden_size)
```
Do we need this "params" dictionary? With nn.compact we can define each param right before its usage inline, and nn.compact will handle collecting all params for us. Most of our other modules in module.py define params inline instead of in a dictionary upfront.
```python
# Gate
# ------------------------------------------------------------------

def _gate(self, inputs: jnp.ndarray, gate_kernel: jnp.ndarray) -> jnp.ndarray:
```
Is this gating proj GEMM something we should quantize?
```python
    inputs_2d: jnp.ndarray,
    sparse_probs: jnp.ndarray,
    routing_map: jnp.ndarray,
) -> dict:
```
Can you replace the dict with a dataclass with the same fields?
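Something like the sketch below would do; the class and field names are hypothetical, since the dict's actual keys are not shown in this diff hunk:

```python
from dataclasses import dataclass
import jax.numpy as jnp

@dataclass
class PermuteResult:
    # Hypothetical fields standing in for whatever the dict currently holds.
    permuted_inputs: jnp.ndarray
    permuted_probs: jnp.ndarray
    group_sizes: jnp.ndarray

res = PermuteResult(
    permuted_inputs=jnp.zeros((8, 4)),
    permuted_probs=jnp.zeros((8,)),
    group_sizes=jnp.array([3, 5]),
)
```

If the result needs to cross a `jax.jit` or `shard_map` boundary as a pytree, `flax.struct.dataclass` would be the pytree-registered equivalent.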
```python
sparse_probs, routing_map = self._route_topk(logits_2d, params.get("expert_bias"))
aux_loss = self._compute_aux_loss(logits_2d)
perm = self._global_permute(inputs_2d, sparse_probs, routing_map)
expert_outputs = self._expert_ffn(
```
Since the grouped quant and grouped GEMM do not have custom partitioning rules, I think this being outside of shard_map will either raise an error about missing partitioning rules or silently replicate
```python
captured[name] = params[name]
in_specs[name] = P(ep_axis, None)


def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]:
```
This is a fairly large function inside another big function. Can we move this to an outer scope or is this required to capture something?
…int in C++ files, make FP8 works. Tested with current scaling Signed-off-by: JAX Toolbox <jax@nvidia.com>
```cpp
namespace transformer_engine::detail {

namespace {
```
NOTE: I added this because I ran into cudaErrorInvalidResourceHandle (at cast.cu:112 in nvte_multi_tensor_quantize) when trying to launch with 1 process in an independent script (that imports the TE MoE block) to test MoEBlock with data type FP8. This was because the global CUDA stream/event pool was created lazily via std::call_once, which leaves the resources bound to whichever device arrives first.
I fixed this by caching per cudaGetDevice() result in an unordered map. Let me know if there is any reason why we should not do this. @jberchtold-nvidia
```diff
  } else {
-   NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k,
-              ", got sum(group_sizes)=", sum_group_sizes);
+   NVTE_CHECK(sum_group_sizes <= k, "Unexpected group_sizes! sum(group_sizes)=", sum_group_sizes,
```
Just pointing at this so reviewers also pay attention to this change (which differs from the initial version prior to the other comments). When I tested with FP8, which is when the padding to align_size took effect, I started to see these checks firing, so I relaxed them, because I think garbage data on dim m should be allowed when there is worst-case padding.
This is for the V1 grouped GEMM FFI, not the one we are using now that binds to cuBLASLt. The V2 does support k >= group sizes and I have tested the V2 with it thoroughly. In theory, I think relaxing this constraint for V1 shouldn't cause issues, but I have not tested it so I am not sure.
In your message above, is FP8 = tensor-scaled FP8? If so, the reason that triggers this assertion is we don't support tensor-scaled FP8 for the V2 grouped quant + GEMM so it falls through to the old V1 codepath.
Yes, I think I tested with current scaling, which is why it hit this V1 implementation. I have changed to MXFP8_1D_SCALING in my test script now. Good catch.
Description
Most of the MoE building-block integration work has been deeply coupled with MaxText development. This MoE block isolates that work from MaxText and provides more room for experimentation.
MoEBlock is a self-contained Flax-Linen module that wires together TE's fused router, pluggable token-dispatch backends (pure-JAX argsort or Triton sort_chunks_by_index), grouped_dense-based expert FFN, and ragged all-to-all (A2Av) expert parallelism via shard_map. This first iteration starts with ring-of-experts EP, sharding on the batch dimension for FSDP, cuBLASLt groupedGEMM, and two permutation backends: pure JAX or Triton kernels. The block also exposes pluggable knobs for: weight layout (wi_kernel_axes / wo_kernel_axes), permutation backend, A2A vs. no-EP (single-GPU) path, data-parallelism axes for true FSDP (batch sharded across (ep, fsdp) simultaneously), top-K with optional grouped/sigmoid scoring (for DSv3 workloads), and optional auxiliary load-balancing loss.

Fixes #2895
Type of change
Changes
- transformer_engine/jax/flax/moe.py -- MoEBlock Linen module: gate -> fused topk -> global permute -> A2A EP shard_map (ragged_a2a fwd, local permute, 3x grouped GEMM SwiGLU FFN, local unpermute, ragged_a2a rev) -> global combine.
- transformer_engine/jax/permutation.py -- A2A param helpers (compute_ragged_all_to_all_params, compute_reverse_ragged_all_to_all_params, local_permute_after_a2a, local_unpermute_before_a2a) and the pure-JAX unfused_token_dispatch / unfused_token_combine paths with custom VJPs.
- tests/jax/test_moe_block.py -- single-device shape, backward, cross-backend equivalence, aux-loss, group-topk, JIT determinism.
- tests/jax/test_distributed_moe_block.py -- EP=2 x FSDP=2 mesh test using the canonical Flax-Linen sharded-init pattern (eval_shape -> get_partition_spec -> logical_to_mesh_sharding -> jit(init, out_shardings=...)) and data_parallelism_axes=("fsdp",) to exercise true FSDP (batch sharded across both axes).

Checklist: