
[JAX] Add an MoE Block (Layer) that composes router, permutation, grouped GEMM and communication #2912

Open

tdophung wants to merge 10 commits into NVIDIA:main from tdophung:teddy/moe_block

Conversation

@tdophung
Collaborator

@tdophung tdophung commented Apr 21, 2026

Description

Most of the MoE building-block integration work has been deeply coupled with MaxText development. This MoE block isolates that work from MaxText and provides more room for experimentation. MoEBlock is a self-contained Flax-Linen module that wires together TE's fused router, pluggable token-dispatch backends (pure-JAX argsort or Triton sort_chunks_by_index), a grouped_dense-based expert FFN, and ragged-all-to-all (A2Av) expert parallelism via shard_map.

This first iteration starts with ring-of-experts EP, sharding on the batch dimension for FSDP, cuBLASLt grouped GEMM, and two permutation backends: pure JAX or Triton kernels. The block also exposes pluggable knobs for weight layout (wi_kernel_axes / wo_kernel_axes), permutation backend, A2A vs. no-EP (single-GPU) path, data-parallelism axes for true FSDP (batch sharded across (ep, fsdp) simultaneously), top-K with optional grouped/sigmoid scoring (for DeepSeek-V3-style workloads), and an optional auxiliary load-balancing loss.
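A hypothetical usage sketch follows; the constructor field names marked below are assumptions inferred from this PR's tests and flowchart, not confirmed API:

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import MoEBlock

# Hedged sketch -- field names are assumptions, not confirmed API.
block = MoEBlock(
    num_experts=8,                    # assumed field name
    num_experts_per_tok=2,            # top-k (referenced in _compute_aux_loss)
    permutation_backend="pure_jax",   # or "triton", per the test parametrize
    data_parallelism_axes=("fsdp",),  # from tests/jax/test_distributed_moe_block.py
    aux_loss_coeff=0.01,              # > 0 enables the auxiliary loss
)
inputs = jnp.zeros((4, 128, 1024), dtype=jnp.bfloat16)   # [B, S, H]
variables = block.init(jax.random.PRNGKey(0), inputs)
output, aux_loss = block.apply(variables, inputs)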

Fixes #2895

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • New transformer_engine/jax/flax/moe.py -- MoEBlock Linen module:
    gate -> fused topk -> global permute -> A2A EP shard_map (ragged_a2a fwd, local permute, 3x grouped GEMM SwiGLU FFN, local unpermute, ragged_a2a rev) -> global combine (a pure-JAX sketch of the no-EP path follows this list).
  • Extended transformer_engine/jax/permutation.py with A2A param helpers (compute_ragged_all_to_all_params, compute_reverse_ragged_all_to_all_params, local_permute_after_a2a, local_unpermute_before_a2a) and the pure-JAX unfused_token_dispatch / unfused_token_combine paths
    with custom VJPs.
  • tests/jax/test_moe_block.py -- single-device shape, backward,
    cross-backend equivalence, aux-loss, group-topk, JIT determinism.
  • tests/jax/test_distributed_moe_block.py -- EP=2 x FSDP=2 mesh test using the canonical Flax-Linen sharded-init pattern (eval_shape -> get_partition_spec -> logical_to_mesh_sharding -> jit(init, out_shardings=...)) and data_parallelism_axes=("fsdp",) to exercise true FSDP (batch sharded across both axes).
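To make the pipeline concrete, here is a minimal pure-JAX sketch of the no-EP path. It is a reference emulation only: the real module uses fused top-k and grouped GEMM kernels, while this uses jax.lax.top_k and a masked per-expert loop, and all signatures are illustrative assumptions.

import jax
import jax.numpy as jnp

def moe_forward_no_ep(x, gate_kernel, wi_0, wi_1, wo, topk):
    # x: [T, H] flattened tokens; gate_kernel: [H, E];
    # wi_0 / wi_1: [E, H, F]; wo: [E, F, H].
    logits = x @ gate_kernel                                    # _gate
    probs = jax.nn.softmax(logits.astype(jnp.float32), axis=-1)
    topk_probs, topk_idx = jax.lax.top_k(probs, topk)           # _route_topk
    flat_expert = topk_idx.reshape(-1)                          # [T * topk]
    order = jnp.argsort(flat_expert)                            # _global_permute
    permuted = jnp.repeat(x, topk, axis=0)[order]
    sorted_expert = flat_expert[order]
    # _expert_ffn: SwiGLU per expert -- a grouped GEMM in the real module,
    # emulated here with a masked dense loop for clarity.
    out = jnp.zeros_like(permuted)
    for e in range(gate_kernel.shape[1]):
        h = jax.nn.silu(permuted @ wi_0[e]) * (permuted @ wi_1[e])
        out = jnp.where((sorted_expert == e)[:, None], h @ wo[e], out)
    # _global_combine: undo the permutation and weight by router probs.
    unpermuted = jnp.zeros_like(out).at[order].set(out)
    weighted = unpermuted * topk_probs.reshape(-1, 1).astype(x.dtype)
    return weighted.reshape(x.shape[0], topk, -1).sum(axis=1)   # [T, H]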

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@tdophung tdophung marked this pull request as ready for review May 5, 2026 21:47
@greptile-apps
Contributor

greptile-apps Bot commented May 5, 2026

Greptile Summary

This PR introduces MoEBlock, a self-contained Flax Linen module that composes TE's fused router, a pluggable token-dispatch backend (pure-JAX argsort or Triton), grouped_dense-based expert FFN, and an optional ragged-all-to-all expert-parallelism path via shard_map. It also extends permutation.py with pure-JAX dispatch/combine (with custom VJPs), ragged-A2A parameter helpers, and a fix for the previously-reported aux-loss correctness bug under grouped top-k routing.

  • moe.py adds the full MoEBlock with four composable stages plus separate _forward_no_ep and _forward_a2a_ep paths; the EP body correctly performs a single all_gather for aux-loss and computes recv_buffer_rows using the B/dp_size * S * topk worst-case formula (a worked example follows this list).
  • permutation.py adds pure_jax_token_dispatch/pure_jax_token_combine with custom VJPs, routing_map_to_selected_experts, and four ragged-A2A helpers.
  • C++ changes relax sum(group_sizes) == m to <= m in grouped GEMM and quantize; multi_stream.cpp fixes a multi-device CUDA stream/event initialization race.
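For instance (illustrative numbers, not from the PR): with global batch B = 8, dp_size = 4, sequence length S = 2048, and topk = 2, each EP rank over-allocates recv_buffer_rows = (8 / 4) * 2048 * 2 = 8192 rows, covering the worst case where every routed token copy lands on that rank's experts.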

Confidence Score: 4/5

Safe to merge for the no-EP and default align_size paths; the align_size>0 code path remains blocked by xfail tests and an unresolved backward-pass issue from a prior review.

The core MoEBlock logic, ragged-A2A protocol, and pure-JAX dispatch/combine are well-reasoned and thoroughly tested. The C++ assertion relaxations are conservative and consistent. The main concern carried from a prior review is the hardcoded align_size=128 in the token_combine backward, which would produce wrong gradients once align_size>0 is enabled end-to-end — that path is currently gated behind xfail tests but is not addressed here.

transformer_engine/jax/permutation.py — backward rule for the padded token_combine uses a hardcoded alignment that must match the forward pass value

Important Files Changed

  • transformer_engine/jax/flax/moe.py -- New MoEBlock Flax module wiring router, permutation, grouped GEMM, and optional ragged-A2A EP; well structured with clear stage separation, but has a misleading error message in the batch-divisibility check.
  • transformer_engine/jax/permutation.py -- Adds pure-JAX dispatch/combine with custom VJPs, ragged-A2A helpers, and routing_map_to_selected_experts; the backward pass has a hardcoded align_size=128 that must match the forward (currently xfail-gated).
  • transformer_engine/common/util/multi_stream.cpp -- Fixes multi-device CUDA stream/event initialization by keying the pool on the active device ID instead of a process-global call_once.
  • transformer_engine/jax/csrc/extensions/gemm.cpp -- Relaxes the grouped GEMM assertion to sum(group_sizes) <= m, enabling over-allocated ragged recv buffers.
  • tests/jax/test_moe_block.py -- Single-device tests covering shape, backward, backend equivalence, aux loss, group-topk, and JIT determinism.
  • tests/jax/test_distributed_moe_block.py -- EP=2 x FSDP=2 distributed test using the canonical Flax sharded-init pattern.
  • transformer_engine/jax/flax/__init__.py -- Adds MoEBlock to the public exports.
  • transformer_engine/jax/csrc/extensions/quantization.cpp -- Relaxes the grouped quantize assertion analogously to gemm.cpp.
  • transformer_engine/jax/cpp_extensions/gemm.py -- Removes @cache from _should_enforce_v2_grouped_gemm so monkeypatch.setenv works in tests.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["__call__ [B,S,H]"] --> B["_gate: inputs @ gate_kernel"]
    B --> C["_route_topk: fused_topk_with_score_function"]
    C --> D{expert_parallelism_axis?}
    D -- None --> E[_forward_no_ep]
    D -- set --> F[_forward_a2a_ep / shard_map]
    E --> E1[tokens_per_expert = sum routing_map]
    E1 --> E2[_compute_aux_loss optional]
    E1 --> E3[_global_permute]
    E3 --> E4[_expert_ffn: grouped_dense x3 + act]
    E4 --> E5[_global_combine]
    F --> F1[_a2a_body per EP shard]
    F1 --> F2[all_gather expert_bias]
    F2 --> F3[_route_topk local]
    F3 --> F4[aux: all_gather logits + re-run topk]
    F3 --> F5[_global_permute]
    F5 --> F6[all_gather group_sizes]
    F6 --> F7[compute_ragged_a2a_params]
    F7 --> F8[ragged_all_to_all fwd]
    F8 --> F9[local_permute_after_a2a]
    F9 --> F10[_expert_ffn E/num_ep groups]
    F10 --> F11[local_unpermute_before_a2a]
    F11 --> F12[compute_reverse_ragged_a2a_params]
    F12 --> F13[ragged_all_to_all rev]
    F13 --> F14[_global_combine]
    E5 --> G[return output aux_loss]
    F14 --> G

Reviews (4): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment thread transformer_engine/pytorch/permutation.py
Comment thread transformer_engine/jax/flax/moe.py
tdophung added 6 commits May 5, 2026 16:35
Signed-off-by: tdophung <tdophung@nvidia.com>
…ody single GPU vs. multi GPU

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…e and single device initial params in the MoEBlock. Tests should pass now

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung force-pushed the teddy/moe_block branch from 8a838f3 to 6aeb491 on May 5, 2026 23:44
pre-commit-ci Bot and others added 2 commits May 5, 2026 23:45
Signed-off-by: tdophung <tdophung@nvidia.com>
Comment on lines +427 to +457
    def _compute_aux_loss(
        self,
        logits_2d: jnp.ndarray,
    ) -> Optional[jnp.ndarray]:
        """Compute the MoE auxiliary load-balancing loss.

        The score-for-aux kernel has no data dependency on the main
        routing kernel, so XLA can overlap them on the GPU.

        ``logits_2d`` should be the *full* logits tensor over the global
        token batch -- under EP the caller is responsible for
        :func:`jax.lax.all_gather` ing the logits before calling this so
        the aux_loss formula
        ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])``
        sees the global ``T`` and the global ``tokens_per_expert``.
        """
        if self.aux_loss_coeff <= 0.0:
            return None
        aux_scores, aux_routing_map = fused_topk_with_score_function(
            logits_2d.astype(jnp.float32),
            topk=self.num_experts_per_tok,
            score_function=self.score_function,
            compute_aux_scores=True,
        )
        aux_tokens_per_expert = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0)
        return fused_moe_aux_loss(
            aux_scores.astype(jnp.float32),
            aux_tokens_per_expert,
            topk=self.num_experts_per_tok,
            coeff=self.aux_loss_coeff,
        )
Contributor

P1 Aux loss tokens_per_expert is inconsistent with actual grouped-topk routing

When num_groups > 0 and group_topk > 0 (DeepSeek-style routing), fused_topk_with_score_function(..., compute_aux_scores=True) intentionally ignores those parameters and runs a clean standard top-k. The returned aux_routing_map therefore reflects different expert selections than the actual routing_map produced by _route_topk, causing aux_tokens_per_expert = sum(aux_routing_map, axis=0) to count a different token–expert distribution. Any user who combines num_groups > 0 + group_topk > 0 + aux_loss_coeff > 0 silently trains with a wrong auxiliary objective. The existing test_group_topk_deepseek test does not catch this because it leaves aux_loss_coeff at its default of 0.0.
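A minimal regression-test sketch that would exercise this combination (hypothetical test; the MoEBlock field names, including num_groups and group_topk, are assumptions inferred from this comment and the PR's existing tests):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import MoEBlock

def test_group_topk_aux_loss_consistency():
    inputs = jax.random.normal(jax.random.PRNGKey(1), (2, 64, 128), jnp.float32)
    block = MoEBlock(
        num_experts=8,          # assumed field name
        num_experts_per_tok=2,
        num_groups=4,           # grouped (DeepSeek-style) routing ...
        group_topk=2,           # ... which the aux path currently ignores
        aux_loss_coeff=0.01,    # non-zero, unlike test_group_topk_deepseek
    )
    variables = block.init(jax.random.PRNGKey(0), inputs)
    _, aux_loss = block.apply(variables, inputs)
    # A stronger check would compare tokens_per_expert between the real
    # routing_map and the aux path's aux_routing_map to expose the mismatch.
    assert aux_loss is not None and jnp.isfinite(aux_loss)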

Comment thread tests/jax/test_distributed_moe_block.py Outdated
sharded_variables = jax.jit(sharded_block.init, out_shardings=out_shardings)(
    init_key, inputs
)
(sharded_loss, (sharded_output, sharded_aux)), sharded_grads = jax.value_and_grad(
Collaborator

SR: Do we want to jax.jit this function before calling it?

Comment thread tests/jax/test_moe_block.py Outdated
        assert aux_loss is None, "aux_loss should be None when aux_loss_coeff=0"

    @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"])
    def test_backward_grad(self, permutation_backend):
Collaborator

nit SR: rename to test_backward_grad_is_finite_and_nonzero or something similar to indicate this test doesn't compare with a reference impl

Comment thread tests/jax/test_moe_block.py Outdated
for name in ("gate_kernel", "wi_0", "wi_1", "wo"):
    g_pj = _unwrap_partitioned(grads_pj["params"][name])
    g_tr = _unwrap_partitioned(grads_tr["params"][name])
    assert jnp.allclose(g_pj, g_tr, atol=1e-1, rtol=1e-1), (
Collaborator

1e-1 is a pretty high tolerance for most of our tests. What error values of atol and rtol do you typically get from these tests and is that error difference expected between jax/triton backends?

Comment thread tests/jax/test_moe_block.py Outdated

@pytest.mark.xfail(
    reason=(
        "TE grouped_dense FFI asserts sum(group_sizes) == M at "
Collaborator

This assertion is only for the V1 grouped GEMM. For the V2 grouped GEMM, sum(group_sizes) < M is supported. Can you try the following? (A sketch of step 3 follows the list.)

  1. Enforce the V2 grouped GEMM with NVTE_JAX_ENFORCE_V2_GROUPED_GEMM=1, similar to what we do here:
    def use_jax_gemm(enabled=False):
  2. Remove the functools cache on this function so it can change:
    def _should_enforce_v2_grouped_gemm() -> bool:
  3. Run this test within a try/except. If it runs, great. If the except catches a runtime error that contains the string here, then pytest.skip("V2 grouped gemm is not supported").
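A sketch of the suggested pattern (the helper run_moe_block_with_padding and the matched error string are placeholders, not real names from this repo):

import pytest

def test_padded_group_sizes(monkeypatch):
    # Force the V2 grouped GEMM, which tolerates sum(group_sizes) < M.
    monkeypatch.setenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "1")
    try:
        run_moe_block_with_padding()  # hypothetical helper for this test
    except RuntimeError as e:
        if "sum(group_sizes)" in str(e):  # placeholder for the real message
            pytest.skip("V2 grouped gemm is not supported")
        raise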

Comment thread transformer_engine/jax/permutation.py Outdated

* Fused, Triton-backed ``token_dispatch`` / ``token_combine`` - uses the
  Triton kernels in ``transformer_engine.jax.triton_extensions.permutation``.
* Unfused, pure-JAX ``unfused_token_dispatch`` / ``unfused_token_combine`` -
Collaborator

I feel like "unfused" implies slower, when in practice this approach is faster at least in MaxText. Would "triton" or "pure_jax" like we have in the tests fit better. What do you think?

Comment thread transformer_engine/jax/flax/moe.py Outdated
inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes)

_, _, hidden_size = inputs.shape
params = self._make_params(hidden_size)
Collaborator

Do we need this "params" dictionary? With nn.compact we can define each param right before its use inline, and nn.compact will handle collecting all params for us. Most of our other modules in module.py define params inline instead of in a dictionary upfront.
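For reference, the inline style looks like this (a minimal sketch; shapes and initializers are illustrative, not the module's actual definitions):

import flax.linen as nn
import jax.numpy as jnp

class MoEBlockInline(nn.Module):
    num_experts: int

    @nn.compact
    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
        hidden_size = inputs.shape[-1]
        # Defined right before use; nn.compact registers it automatically.
        gate_kernel = self.param(
            "gate_kernel",
            nn.initializers.lecun_normal(),
            (hidden_size, self.num_experts),
            jnp.float32,
        )
        return inputs @ gate_kernel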

Comment thread transformer_engine/jax/flax/moe.py Outdated
    # Gate
    # ------------------------------------------------------------------

    def _gate(self, inputs: jnp.ndarray, gate_kernel: jnp.ndarray) -> jnp.ndarray:
Collaborator

Is this gating proj GEMM something we should quantize?

Comment thread transformer_engine/jax/flax/moe.py Outdated
        inputs_2d: jnp.ndarray,
        sparse_probs: jnp.ndarray,
        routing_map: jnp.ndarray,
    ) -> dict:
Collaborator

Can you replace the dict with a dataclass with the same fields?
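Something like the following (field names are guesses at what the permutation dict currently holds; flax.struct.dataclass would keep the result a JAX pytree so it can cross jit/shard_map boundaries):

import jax.numpy as jnp
from flax import struct

@struct.dataclass
class PermutationResult:
    permuted_tokens: jnp.ndarray      # tokens sorted by destination expert
    sorted_indices: jnp.ndarray       # argsort order, needed for combine
    tokens_per_expert: jnp.ndarray    # group sizes for the grouped GEMM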

sparse_probs, routing_map = self._route_topk(logits_2d, params.get("expert_bias"))
aux_loss = self._compute_aux_loss(logits_2d)
perm = self._global_permute(inputs_2d, sparse_probs, routing_map)
expert_outputs = self._expert_ffn(
Collaborator

Since the grouped quant and grouped GEMM do not have custom partitioning rules, I think this being outside of shard_map will either raise an error about missing partitioning rules or silently replicate
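One hedged way to address this is to keep the grouped-quant/GEMM call inside the shard_map body so it only ever sees local shards (the mesh construction and local_ffn below are illustrative stand-ins, not this module's code):

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("ep",))

def local_ffn(x):                  # stand-in for the grouped-GEMM FFN body
    return x * 2.0

sharded_ffn = shard_map(
    local_ffn,
    mesh=mesh,
    in_specs=(P("ep", None),),     # tokens sharded across the "ep" axis
    out_specs=P("ep", None),
    check_rep=False,               # custom kernels lack replication rules
)
y = sharded_ffn(jnp.ones((8, 4))) # leading dim divisible by the mesh size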

Comment thread transformer_engine/jax/flax/moe.py Outdated
captured[name] = params[name]
in_specs[name] = P(ep_axis, None)

def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]:
Collaborator

This is a fairly large function inside another big function. Can we move this to an outer scope or is this required to capture something?

nvjax and others added 2 commits May 7, 2026 15:18
…int in C++ files, make FP8 works. Tested with current scaling

Signed-off-by: JAX Toolbox <jax@nvidia.com>


namespace transformer_engine::detail {

namespace {
Collaborator Author

NOTE: I added this because I ran into cudaErrorInvalidResourceHandle (at cast.cu:112 in nvte_multi_tensor_quantize) when launching with 1 process in an independent script (that imports the TE MoE block) to test MoEBlock with data type FP8. This was because the global cudaStream/Event pool was created lazily via std::call_once, which leaves the resources bound to whichever device arrives first.

I fixed this by caching per cudaGetDevice() in an unordered_map. Let me know if there is any reason why we should not do this. @jberchtold-nvidia

} else {
-  NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k,
-             ", got sum(group_sizes)=", sum_group_sizes);
+  NVTE_CHECK(sum_group_sizes <= k, "Unexpected group_sizes! sum(group_sizes)=", sum_group_sizes,
Collaborator Author

Just pointing at this so reviewers also pay attention to this change (which differs from the initial version discussed in the earlier comments). When I tested with FP8, which is when the padding to align_size takes effect, these checks started firing, so I relaxed them: I think we should allow garbage data along dim m when there is worst-case padding.

Collaborator

@jberchtold-nvidia jberchtold-nvidia May 8, 2026

This is for the V1 grouped GEMM FFI, not the one we are using now that binds to cuBLASLt. The V2 does support k >= sum(group_sizes) and I have tested the V2 with it thoroughly. In theory, relaxing this constraint for V1 shouldn't cause issues, but I have not tested it, so I am not sure.

In your message above, is FP8 = tensor-scaled FP8? If so, the reason that triggers this assertion is that we don't support tensor-scaled FP8 for the V2 grouped quant + GEMM, so it falls through to the old V1 codepath.

Collaborator Author

Yes, I think I tested with current scaling, which is why it hit this V1 implementation. I have changed to MXFP8_1D_SCALING in my test script now. Good catch.



Development

Successfully merging this pull request may close these issues.

[JAX] Create initial MoE Block

3 participants