[Draft][TRTLLM-12950][feat] Add MegaMoECuteDsl NVFP4 MoE backend #14608

Open
xxi-nv wants to merge 1 commit into NVIDIA:main from xxi-nv:megamoe_cutedsl_nvfp4_v2

Conversation

@xxi-nv
Collaborator

@xxi-nv xxi-nv commented May 27, 2026

⚠️ Status: PREVIEW / DO NOT USE / DO NOT MERGE

This PR introduces the MegaMoECuteDsl NVFP4 MoE backend as a preview / WIP drop.
It is NOT a finished product and is not safe to use as a real serving backend:

  • The kernel ABI is not finalized. The ported Sm100MegaMoEKernel still hard-codes
    alpha = 1.0 and norm_const = 1.0 in the FC1/FC2 epilogue path. Real NVFP4
    checkpoints (where fc31_alpha / fc2_alpha / fc2_input_scale are not 1) cannot
    produce correct results until the kernel ABI is extended to thread these scales through.
    See tensorrt_llm/_torch/modules/fused_moe/mega_moe/MEGAMOE_CUTEDSL_DESIGN.md
    "NVFP4 scale and alpha ABI" for the full hard gate.
  • A v1 _check_v1_alpha_gate in NVFP4MegaMoECuteDslMethod.process_weights_after_loading
    rejects any checkpoint whose alpha tensors differ from 1.0; the factory falls back to
    CutlassFusedMoE / CuteDslFusedMoE in that case so production paths are never silently wrong.
  • Accuracy testing has NOT been done. The current test runs only validate that the
    capability gate, weight pipeline, and op registration wire up end-to-end on a GB200
    with cu13 cutlass-dsl available; numerical correctness against a reference is gated
    on the kernel ABI extension above.
  • Do not select model_config.moe_backend = MEGAMOE_CUTEDSL in any production model
    config.
    The factory will fall back via can_implement / alpha gate as designed,
    but the backend itself is staged for follow-up work, not for use.
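
The gate-and-fallback behavior described in the bullets above can be sketched as follows. This is a minimal illustration, not the code in this PR: `check_v1_alpha_gate`, `select_backend`, the dictionary input, and the tolerance are hypothetical stand-ins, and only the environment variable name, the `NotImplementedError`, and the fallback target are taken from the description.

```python
import os

# Env var name as quoted in the PR description.
BYPASS_ENV = "TRTLLM_MEGAMOE_CUTEDSL_BYPASS_V1_ALPHA_GATE"


def check_v1_alpha_gate(alphas, tol=1e-6):
    """Hypothetical stand-in for _check_v1_alpha_gate: reject any alpha != 1.0."""
    if os.environ.get(BYPASS_ENV) == "1":
        return  # test-only bypass; the kernel would still run with alpha = 1
    for name, values in alphas.items():
        if any(abs(v - 1.0) > tol for v in values):
            raise NotImplementedError(
                f"{name} differs from 1.0; the v1 kernel ABI cannot apply it"
            )


def select_backend(alphas):
    """Hypothetical factory step: fall back when the gate rejects the checkpoint."""
    try:
        check_v1_alpha_gate(alphas)
    except NotImplementedError:
        return "CutlassFusedMoE"
    return "MegaMoECuteDsl"
```

With the bypass unset, a checkpoint whose alphas are all 1.0 keeps the preview backend, and any other checkpoint is routed to the fallback rather than producing silently wrong output.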

This PR is intentionally landed as a preview so the MegaMoE port + backend / quant
method / scheduler integration / NVSHMEM-equivalent symmetric memory provider can
evolve in tree without blocking the kernel ABI work. Please do not merge this PR until
the kernel ABI extension and accuracy validation follow-ups land.

Summary

  • New backend MegaMoECuteDsl under
    tensorrt_llm/_torch/modules/fused_moe/mega_moe/, wrapping the ported
    Sm100MegaMoEKernel (fused dispatch + FC1 + activation + FC2 + combine on SM100/103).
  • Standard CuteDSL torch op + TunableRunner pattern in
    tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py (Sm100MegaMoENvfp4Runner,
    MegaMoeSymmMemProvider, torch.ops.trtllm.cute_dsl_megamoe_nvfp4_blackwell).
  • Ported MegaMoE kernel package at
    tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ (16 files, package-relative imports).
  • New NVFP4MegaMoECuteDslMethod quant method: 16-atom gate/up interleave +
    to_blocked swizzle + flat per-slot SF stack; CPU shared-staging + EPLB
    fix-up registration for the four mega_fc{1,2}_weight{,_sf} derived parameters.
  • Multi-rank symmetric-memory provider built on PyTorch torch.distributed._symmetric_memory
    (cuMem-based NVSHMEM equivalent already used in the codebase via SymmetricMemoryAllReduce).
  • FusedCommMoEScheduler refactor: zero-token chunks now go through
    backend.quantize_input(...) unconditionally (each fused-comm backend owns its own
    empty-tensor layout, scheduler stays layout-agnostic).
  • Factory + ConfigurableMoE wiring (create_moe.py, __init__.py allowlist).
  • MoE Developer Guide updates: backend file map, capability matrix, FUSED_COMM
    scheduler invariants.
  • Shared test helper migration: MoeBackendType.MEGAMOE alias removed; both variants
    spelled out as MEGAMOE_DEEPGEMM and MEGAMOE_CUTEDSL with paired
    should_skip_megamoe_{deepgemm,cutedsl} helpers.
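
To illustrate the scheduler refactor listed above (zero-token chunks go through `quantize_input` unconditionally), here is a toy sketch. `ToyBackend` and `run_chunks` are hypothetical, and the packing arithmetic is a placeholder rather than NVFP4 quantization; the only point is that the scheduler has no special case for empty chunks.

```python
class ToyBackend:
    """Hypothetical backend; only the role of quantize_input comes from the PR."""

    def quantize_input(self, tokens):
        # The backend alone decides what an empty chunk looks like in its layout.
        packed = [round(t * 2) for t in tokens]  # placeholder for real packing
        scales = [1.0] * len(tokens)             # placeholder scale factors
        return packed, scales


def run_chunks(backend, chunks):
    """Scheduler side: every chunk, empty or not, takes the same path."""
    return [backend.quantize_input(chunk) for chunk in chunks]
```

Because the empty chunk is handled inside the backend, the scheduler never needs to know the shape or dtype of an empty packed tensor.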

Hard gates / open work tracked in MEGAMOE_CUTEDSL_DESIGN.md

  1. Kernel ABI extension for fc31_alpha / fc2_alpha / fc2_input_scale.
  2. Production memory budget for form A vs form B in-kernel top-k reduction.
  3. Accuracy validation against a CUTLASS / CuteDSL fp16 reference once the
    ABI extension is in place.
  4. Multi-rank smoke / functional coverage for hidden sizes that are 32-aligned
    but not 64-aligned (1568, 1632, 2080) — current test matrix uses only 64-aligned
    hidden sizes so the symmetric-memory activation_sf row stride path is not
    exercised.
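
A quick way to see why the hidden sizes in item 4 matter is sketched below. The 16-element NVFP4 block size is a property of the format; the 4-column padding of the scale-factor layout is an assumption made for illustration and is not stated in this PR, and `describe_hidden_size` is a hypothetical helper.

```python
NVFP4_BLOCK_SIZE = 16  # elements covered by one NVFP4 scale factor
SF_COLUMN_ALIGN = 4    # ASSUMED column padding of the scale-factor layout


def describe_hidden_size(hidden):
    """Report alignment and the (assumed) padded scale-factor column count."""
    sf_cols = hidden // NVFP4_BLOCK_SIZE
    padded = -(-sf_cols // SF_COLUMN_ALIGN) * SF_COLUMN_ALIGN  # round up
    return {
        "aligned_32": hidden % 32 == 0,
        "aligned_64": hidden % 64 == 0,
        "sf_cols": sf_cols,
        "padded_sf_cols": padded,
    }
```

Under that assumption, a 64-aligned hidden size needs no column padding, while 1568, 1632, and 2080 each do, which is the case the current test matrix does not reach.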

Test plan / current state on GB200 (cu13 cutlass-dsl)

All MegaMoECuteDsl-related cases collected by -k megamoe_cutedsl were run on a
GB200 node (nvl72151-T01) inside the OCI cu13 cutlass-dsl container, using a
prebuilt artifact symlink (no local C++ build).

Total: 59 cases collected (10 backend + 42 module single-rank + 7 module multi-rank).
Outcome: 0 PASSED / 24 FAILED / 35 SKIPPED.

Failure types (all map to the same v1 limitation, not a regression in this PR):

  • Class A — 10/10 backend FAILED: synthetic NVFP4 weights have fc31_alpha ≈ 1e-7;
    the backend test sets TRTLLM_MEGAMOE_CUTEDSL_BYPASS_V1_ALPHA_GATE=1 to bypass the
    alpha gate, so the kernel runs with alpha = 1 and the output diverges by ~2×10⁷
    from the reference (accuracy mismatch 100%). This is the documented v1 ABI limitation.
  • Class B — 9/42 module single-rank FAILED: _check_v1_alpha_gate raises
    NotImplementedError because the single-rank module test path does not currently set the
    bypass env var (only the multi-rank path does). Same root cause as Class A; just visible
    earlier in the lifecycle.
  • Class C — 5/7 module multi-rank FAILED: bypass env set, kernel runs, accuracy
    mismatch (same root cause as Class A).
  • Skipped: 33 single-rank cases (non-Renormalize / non-DeepSeekV3 routings filtered
    by should_skip_megamoe_cutedsl, plus the e256_k6_h4096_i2048 shape) and 2
    multi-rank cases (e256_k6_h4096_i2048 large shape).

None of the failures indicate a regression introduced by this PR — they all map to
the same kernel ABI / alpha gate gap.
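
As a cross-check, the per-class counts reported above can be tallied against the stated totals. The dictionary below only restates the numbers from this section.

```python
# Counts copied from the failure breakdown above.
results = {
    "backend": {"failed": 10, "skipped": 0, "total": 10},
    "module single-rank": {"failed": 9, "skipped": 33, "total": 42},
    "module multi-rank": {"failed": 5, "skipped": 2, "total": 7},
}

# Each group must be fully accounted for by failed + skipped (0 passed).
for name, r in results.items():
    assert r["failed"] + r["skipped"] == r["total"], name

totals = {
    key: sum(r[key] for r in results.values())
    for key in ("failed", "skipped", "total")
}
print(totals)
```

The sums match the reported outcome of 24 failed and 35 skipped out of 59 collected.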

Things that will land in follow-up PRs (NOT in this preview)

  • Kernel ABI extension (per-expert alpha + norm_const) and the accompanying
    numerical-correctness tests.
  • Single-rank module test fixture: set TRTLLM_MEGAMOE_CUTEDSL_BYPASS_V1_ALPHA_GATE
    symmetrically with the multi-rank fixture so Class B cases reach the kernel.
  • Hidden-size coverage (1568 / 1632 / 2080) for the multi-rank symmetric-memory
    activation_sf row-stride path.
  • Form B in-kernel top-k reduction once memory budget for form A is measured.

Reviewer guidance

This is a preview. Reviewers can read code-level structure (backend boundary,
quantization-method boundary, scheduler refactor, MoE developer guide), but
please do NOT advance to merge until the kernel ABI extension and accuracy
validation land.

Summary by CodeRabbit

  • New Features

    • Introduced MegaMoE CuteDSL backend for NVFP4 fused kernel execution.
    • Added NVFP4 quantization method for CuteDSL-based MegaMoE inference.
    • Extended FusedCommMoEScheduler to handle empty token batches.
  • Documentation

    • Added MegaMoE CuteDSL design specification and architecture details.
    • Updated MoE Developer Guide with new backend constraints and capabilities.
  • Tests

    • Extended unit test coverage for MegaMoE DeepGEMM and CuteDSL backends.

@xxi-nv xxi-nv requested review from a team as code owners May 27, 2026 03:08
@xxi-nv xxi-nv requested review from hyukn, mlefeb01, tburt-nv and yuxianq May 27, 2026 03:08
@xxi-nv xxi-nv force-pushed the megamoe_cutedsl_nvfp4_v2 branch from aa3a318 to ff947a8 Compare May 27, 2026 03:17
@coderabbitai
Contributor

coderabbitai Bot commented May 27, 2026

📝 Walkthrough


This PR introduces a complete MegaMoE CuteDSL NVFP4 fused-communication backend that enables multi-rank expert routing with symmetric memory, token communication overlapping, and persistent scheduling. The implementation spans ~18K lines across kernel DSL code, custom op registration, backend class, quantization methods, and test infrastructure, featuring a flattened kernel package architecture with lazy imports to defer expensive CuteDSL symbol resolution.

Changes

MegaMoE CuteDSL NVFP4 Backend

| Layer | File(s) | Summary |
|---|---|---|
| Tooling setup | .pre-commit-config.yaml, legacy-files.txt, pyproject.toml, ruff-legacy.toml, tensorrt_llm/_torch/autotuner.py | Configuration files updated to include mega_moe_nvfp4 kernel package files in linting/formatting hooks. FP8 tensor generation added to autotuner with uint8-view pattern. |
| Kernel package init & core infrastructure | tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py, config.py, megamoe_constants.py, blocked_scale.py, contract.py | Package docstring and lazy-loading interface; DSV4Config configuration structures; blocked-scale tensor swizzling for NVFP4 atoms; finite coordinate Space and Contract mapping framework for RMEM tensor handoff. |
| Device-side persistent scheduler & work-tile | moe_persistent_scheduler.py | MoEWorkTileInfo payload with MLIR serialization; MoESchedExtension interface for work-tile enrichment; static and dynamic (CLC-based) persistent schedulers with short-side-first raster and expert boundary caching. |
| Device-side utilities & TMA descriptors | moe_utils.py, grid_sync.py, iket_compat.py, ptx_helpers.py, sf_swizzle.py | PTX-level grid-sync barrier; TMA 1D load/store helpers; DSMEM cluster communication; pointer/address conversions; scale-factor atom layout swizzling; IKET dialect compatibility wrapper; online TMA descriptor workspace and grouping constructors. |
| Fused FC1+FC2 scheduler | fc1_fc2_fuse_sched.py | Persistent tile scheduler for fused fc1+fc2 with group→phase→expert state machine, greedy group formation via group_hint, and atomic-counter load-balancing option. |
| Dispatch kernel & scheduler extension | dispatch_kernel.py, custom_ext.py | 3-stage dispatch flow (prep/barrier/pull) with TMA token pulls and cross-rank NVLink signaling. SwapAB-aware scheduler extension with phase-decoded peek-ready bits for fc1/fc2 readiness. |
| MegaMoE kernel launcher | megamoe_kernel.py | Workspace region layout, token communication hooks, TokenCommArgs serialization, and integration of dispatch/fc1/fc2 phases with NVLink barrier ordering. |
| Torch custom op & runner | tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py, __init__.py | Registers cute_dsl_megamoe_nvfp4_blackwell op with tactic validation, symmetric-memory provider, local workspace caching, and Sm100MegaMoENvfp4Runner for kernel compilation/execution via autotuner. |
| MegaMoECuteDsl backend | tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_cute_dsl.py | Backend class with can_implement() gating, EP process-group resolution, weight creation/loading, quantize_input() for NVFP4 packing and FP8 SF padding, and run_moe() with kernel invocation and form-A reduction. |
| NVFP4 weight lifecycle | tensorrt_llm/_torch/modules/fused_moe/quantization.py | NVFP4MegaMoECuteDslMethod with MegaMoE-format weight registration, expert weight staging, blocked-scale transformation via to_blocked, EPLB shared staging, and v1 alpha gating. |
| Module & factory integration | mega_moe/__init__.py, create_moe.py, configurable_moe.py, moe_scheduler.py | Module exports updated for CuteDSL classes/methods. Factory adds pretrained capability helper and backend selection. ConfigurableMoE defers validation. Scheduler delegates quantize_input to backend for zero-token chunks. DeepGemm adds empty-input handling. |
| Test infrastructure | tests/unittest/.../moe_test_utils.py, test_moe_backend.py, test_moe_module.py, tests/microbenchmarks/.../backend.py, utils.py | Test utilities split MEGAMOE into MEGAMOE_DEEPGEMM and MEGAMOE_CUTEDSL with backend-specific skip predicates. Backend/module tests updated with CuteDSL-specific multi-GPU and EPLB generators. Benchmark tests use explicit backend enum. |
| Documentation | MOE_DEVELOPER_GUIDE.md, mega_moe/MEGAMOE_CUTEDSL_DESIGN.md | Comprehensive design spec for CuteDSL backend integration, hard gates, call chain, runner/autotuner design, quantization spec, and test plan. Guide updated with backend details, capability matrix, and FUSED_COMM anti-patterns. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

The PR is large (18K+ lines) and heterogeneous, spanning device-side CUDA kernel DSL code, host-side PyTorch backend infrastructure, quantization pipelines, and test coverage. Key areas of complexity: persistent scheduler state machine with multiple advancement modes; dispatch kernel 3-stage barrier with cross-rank NVLink signaling; atomic-counter load-balancing option; MLIR serialization/deserialization for DSL interoperability; symmetric-memory provider caching; tactic validation and compile caching; weight transformation pipeline with blocked-scale swizzling; v1 alpha gating. Many changes are implementations of analogous patterns (e.g., scheduler variants, TMA descriptor constructors, skip predicates), which reduces per-file review friction but requires understanding the underlying design principles. The design doc and developer guide are essential context for review.

Suggested reviewers

  • nv-guomingz
  • syuoni
  • leslie-fang25
  • yuxianq
  • mingyangHao
  • tomeras91

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 7

🧹 Nitpick comments (11)
tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md (1)

152-152: ⚡ Quick win

Clarify symmetric-memory provider terminology.

The description states "requires CUDA 13 Cutlass DSL runtime (PR #14354) and NVSHMEM provider". Based on the design doc (phase 0, lines 938-939), the implementation uses torch.distributed._symmetric_memory (cuMem-based), not an external NVSHMEM package. Consider revising to "symmetric memory provider" or "torch.distributed._symmetric_memory" to avoid implying a dependency on nvshmem4py-cu13, which was discussed in the design doc (lines 533-534) but not chosen for v1.

📝 Proposed clarification
-| `mega_moe/mega_moe_cute_dsl.py` | `MegaMoECuteDsl` | SM100/SM103 | NVFP4 via ported CuteDSL `Sm100MegaMoEKernel` fused dispatch+FC1+act+FC2+combine kernel; requires CUDA 13 Cutlass DSL runtime (PR `#14354`) and NVSHMEM provider (hard gate — see `mega_moe/MEGAMOE_CUTEDSL_DESIGN.md`); v1 alpha=1 product gate enforced in `post_load_weights` | `FUSED_COMM` |
+| `mega_moe/mega_moe_cute_dsl.py` | `MegaMoECuteDsl` | SM100/SM103 | NVFP4 via ported CuteDSL `Sm100MegaMoEKernel` fused dispatch+FC1+act+FC2+combine kernel; requires CUDA 13 Cutlass DSL runtime (PR `#14354`) and symmetric memory provider (`torch.distributed._symmetric_memory`); v1 alpha=1 product gate enforced in `post_load_weights` | `FUSED_COMM` |

Based on learnings from MEGAMOE_CUTEDSL_DESIGN.md phase 0: the implementation uses PyTorch's torch.distributed._symmetric_memory (cuMem-based NVSHMEM-equivalent), not an external NVSHMEM package.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md` at line 152,
Update the README entry describing mega_moe/mega_moe_cute_dsl.py /
MegaMoECuteDsl to avoid implying an external NVSHMEM package: change "NVSHMEM
provider" to "symmetric memory provider (torch.distributed._symmetric_memory)"
or similar wording that explicitly names PyTorch's cuMem-based provider; ensure
the note about requiring CUDA 13 Cutlass DSL runtime and the product gate
references remain unchanged so readers know runtime and gating requirements.
tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py (1)

706-706: 💤 Low value

Consider ClassVar annotation for shared mutable class attribute.

The kernel_cache is intentionally shared across all instances for compile caching. Adding ClassVar makes this intent explicit and silences the RUF012 warning:

+from typing import ClassVar
+
 class Sm100MegaMoENvfp4Runner(TunableRunner):
     ...
     # Module-scope compile cache shared by every runner instance.
-    kernel_cache: dict = {}
+    kernel_cache: ClassVar[dict] = {}
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py` at line 706,
The class-level mutable cache kernel_cache is intended to be shared across
instances but lacks an explicit ClassVar annotation; update the declaration of
kernel_cache in the class to be typed as typing.ClassVar[dict] (and add from
typing import ClassVar if not already imported) so the intent is explicit and
the RUF012 warning is silenced while preserving the shared compile cache
behavior.
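
A minimal sketch of the annotation suggested above, using a hypothetical `Runner` class in place of `Sm100MegaMoENvfp4Runner`; the string stored in the cache is a placeholder for a compiled kernel.

```python
from typing import ClassVar


class Runner:
    """Hypothetical stand-in for a TunableRunner subclass."""

    # Shared by every instance on purpose: compiled kernels are keyed by tactic.
    kernel_cache: ClassVar[dict] = {}

    def get_kernel(self, tactic):
        if tactic not in self.kernel_cache:
            # Placeholder for an expensive compile step.
            self.kernel_cache[tactic] = f"compiled<{tactic}>"
        return self.kernel_cache[tactic]
```

`ClassVar` does not change runtime behavior; it documents that the dictionary belongs to the class, so a kernel compiled through one runner instance is visible to all others.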
tests/unittest/_torch/modules/moe/test_moe_backend.py (2)

568-577: ⚡ Quick win

Add a companion negative test for the v1 alpha gate.

Setting TRTLLM_MEGAMOE_CUTEDSL_BYPASS_V1_ALPHA_GATE=1 for this path is useful for end-to-end execution, but it also removes direct coverage that the default gate rejects non-1 alpha checkpoints. Please add a small paired test with the env var unset that asserts rejection.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unittest/_torch/modules/moe/test_moe_backend.py` around lines 568 -
577, Add a companion negative test that mirrors the MEGAMOE_CUTEDSL branch but
ensures the v1 alpha gate is enforced: when backend_type ==
MoeBackendType.MEGAMOE_CUTEDSL, explicitly ensure
TRTLLM_MEGAMOE_CUTEDSL_BYPASS_V1_ALPHA_GATE is unset (use monkeypatch.delenv
with raising=False), then exercise the same load -> post-load ->
run_moe path used in the positive case and assert that the load/post-load step
fails (raises/rejects) due to non-1 alpha values; place this next to the
existing positive branch in
tests/unittest/_torch/modules/moe/test_moe_backend.py so it uses the same setup
and failure assertion to prove the gate blocks non-1 alpha checkpoints.
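
The paired positive/negative test could look roughly like the sketch below. `load_with_gate` is a hypothetical stand-in for the real load -> post-load path, and the sketch uses `unittest.mock.patch.dict` only to stay dependency-free.

```python
import os
from unittest import mock

BYPASS_ENV = "TRTLLM_MEGAMOE_CUTEDSL_BYPASS_V1_ALPHA_GATE"


def load_with_gate(fc31_alpha):
    """Hypothetical stand-in for the load -> post-load path."""
    if os.environ.get(BYPASS_ENV) != "1" and fc31_alpha != 1.0:
        raise NotImplementedError("v1 alpha gate: fc31_alpha must be 1.0")
    return "loaded"


def test_alpha_gate_rejects_without_bypass():
    # Rebuild the environment without the bypass variable for this test only.
    env = {k: v for k, v in os.environ.items() if k != BYPASS_ENV}
    with mock.patch.dict(os.environ, env, clear=True):
        try:
            load_with_gate(1e-7)
        except NotImplementedError:
            return
        raise AssertionError("gate accepted a non-1 alpha checkpoint")


def test_alpha_gate_bypassed_with_env():
    with mock.patch.dict(os.environ, {BYPASS_ENV: "1"}):
        assert load_with_gate(1e-7) == "loaded"
```

In the actual pytest suite, `monkeypatch.delenv(BYPASS_ENV, raising=False)` together with `pytest.raises(NotImplementedError)` would express the negative case more compactly.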

307-315: QA integration list updates are unnecessary for this change-set.

This file only expands unit-test backend coverage; no tests/integration/defs/ entries were added or materially changed, so tests/integration/test_lists/qa/* updates are not required in this PR.

As per coding guidelines: "If the PR only touches unittest/ or narrow unit scope, say explicitly whether QA list updates are unnecessary or optional."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unittest/_torch/modules/moe/test_moe_backend.py` around lines 307 -
315, Add a brief explicit note that QA integration list updates are unnecessary
because this change only expands unit-test backend coverage and does not modify
tests/integration/defs; place the comment immediately above
BACKEND_TYPES_TO_TEST (referencing BACKEND_TYPES_TO_TEST, MoeBackendType.*
entries) so reviewers and release/QA scripts know no
tests/integration/test_lists/qa/* edits are required.
tests/microbenchmarks/bench_moe/utils.py (1)

66-79: 💤 Low value

_ensure_dist_for_megamoe only handles MEGAMOE_DEEPGEMM, not MEGAMOE_CUTEDSL.

This function checks only for MEGAMOE_DEEPGEMM but test_moe_module.py:_ensure_dist_for_megamoe (lines 139-143) checks for both MegaMoE variants. If bench_moe is ever used with MEGAMOE_CUTEDSL, the distributed process group won't be initialized.

Since MEGAMOE_CUTEDSL isn't in the benchmark's MoeBackendType enum yet (preview phase), this is acceptable for now but should be updated when CuteDSL is added to the benchmark registry.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/microbenchmarks/bench_moe/utils.py` around lines 66 - 79, The
early-return in _ensure_dist_for_megamoe only recognizes
MoeBackendType.MEGAMOE_DEEPGEMM, so the NCCL ProcessGroup will not be
initialized for the other MegaMoE variant; update _ensure_dist_for_megamoe to
also accept the MEGAMOE_CUTEDSL backend (or generally any future MegaMoE enum
value) by checking for both MoeBackendType.MEGAMOE_DEEPGEMM.value and
MoeBackendType.MEGAMOE_CUTEDSL.value (or by matching a common prefix/enum
category), keeping the rest of the initialization (CUDA check, env vars,
dist.init_process_group) unchanged; reference function name
_ensure_dist_for_megamoe and the test file
test_moe_module.py:_ensure_dist_for_megamoe to ensure parity with tests and add
a TODO comment to revisit when CuteDSL is formally added to the benchmark
registry.
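
One way to keep the benchmark helper in step with the unit-test helper is to test membership in a set of MegaMoE variants instead of comparing against a single value. The enum below is a hypothetical subset (the benchmark enum does not yet contain `MEGAMOE_CUTEDSL`, and the string values are assumed), and `needs_dist_init` is an illustrative name.

```python
from enum import Enum


class MoeBackendType(Enum):
    # Hypothetical subset; string values are assumed for illustration.
    CUTLASS = "CUTLASS"
    MEGAMOE_DEEPGEMM = "MEGAMOE_DEEPGEMM"
    MEGAMOE_CUTEDSL = "MEGAMOE_CUTEDSL"


# Every MegaMoE variant needs a process group before the backend is built.
MEGAMOE_BACKENDS = frozenset(
    {MoeBackendType.MEGAMOE_DEEPGEMM.value, MoeBackendType.MEGAMOE_CUTEDSL.value}
)


def needs_dist_init(backend_name: str) -> bool:
    """Return True when the named backend requires distributed initialization."""
    return backend_name in MEGAMOE_BACKENDS
```

Adding a future MegaMoE variant then only requires extending the set, not touching the early-return logic.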
tests/microbenchmarks/bench_moe/backend.py (1)

121-124: 💤 Low value

Consider adding MEGAMOE_CUTEDSL to the benchmark registry.

The MoeBackendType enum and get_backend_class dispatch only include MEGAMOE_DEEPGEMM. The CuteDSL variant is missing, which means bench_moe cannot benchmark the new backend.

This is acceptable for the preview phase, but should be addressed before the feature graduates to production.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/microbenchmarks/bench_moe/backend.py` around lines 121 - 124, Add the
missing CuteDSL backend to the registry by extending the MoeBackendType enum
with a MEGAMOE_CUTEDSL member and updating the get_backend_class dispatch to
handle it: import the CuteDSL backend implementation (e.g., the MegaMoE CuteDSL
class) alongside MegaMoEDeepGemm and return that class when backend_type ==
MoeBackendType.MEGAMOE_CUTEDSL; ensure you reference the new enum value
(MoeBackendType.MEGAMOE_CUTEDSL) and the CuteDSL class name (e.g.,
MegaMoECuteDSL) in the dispatch so bench_moe can instantiate the CuteDSL
backend.
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py (2)

1926-1927: ⚡ Quick win

Prefix unused variable with underscore.

The n_idx variable from the CLC response unpacking is never used. Prefix it with underscore to indicate it's intentionally ignored.

Suggested fix
-        m_idx, n_idx, l_idx, is_valid = cute.arch.clc_response(
+        m_idx, _n_idx, l_idx, is_valid = cute.arch.clc_response(
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py`
around lines 1926 - 1927, The unpacking of the CLC response in the call to
cute.arch.clc_response currently assigns an unused variable n_idx; change that
identifier to a prefixed name (e.g., _n_idx) so it signals intentional ignore.
Locate the line where m_idx, n_idx, l_idx, is_valid =
cute.arch.clc_response(...) (inside moe_persistent_scheduler.py) and replace
n_idx with _n_idx in the tuple assignment; leave m_idx, l_idx and is_valid
unchanged.

7-7: 💤 Low value

Consider using built-in types instead of typing module.

Per coding guidelines, prefer built-in list, tuple over typing.List, typing.Tuple for Python 3.10+. Also prefer X | None over Optional[X].

This is a low-priority cleanup given the file size.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py`
at line 7, Replace typing.List, typing.Tuple, and typing.Optional usages with
built-in generic types and PEP 604 unions: change annotations like List[X] ->
list[X], Tuple[A, B] -> tuple[A, B], and Optional[T] -> T | None; remove List,
Optional, Tuple from the import line in moe_persistent_scheduler.py and
keep/import only Literal if still needed; update all function/method signatures
and variable annotations that reference List, Tuple, or Optional accordingly
(search for usages of List, Tuple, Optional in the file and replace them with
list, tuple, and | None forms).
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py (1)

33-33: 💤 Low value

Prefer built-in list over typing.List.

Per coding guidelines, prefer built-in types like list over legacy typing.List for Python 3.10+.

Suggested fix
-from typing import List
+from collections.abc import Sequence

Then update function signatures:

-def cat_byte_reinterpretable_tensors(tensors: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
+def cat_byte_reinterpretable_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py` at line
33, Replace the legacy typing.List import and usages with the built-in list
type: remove "from typing import List" and update all annotations in this module
(e.g., function signatures, return types, and variable annotations that
reference List) to use the native list[...] form (for example change "List[T]"
to "list[T]"); ensure functions/methods in this file that reference List (search
for symbols like any function or class definitions using List) are updated
accordingly and the typing import is deleted.
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py (1)

5-13: 💤 Low value

Consider UPPER_SNAKE_CASE for module-level constants.

Per coding guidelines, Python constants should use UPPER_SNAKE_CASE (e.g., NVFP4_BLOCK_SIZE, SF_PADDING_BLOCK). The current PascalCase naming is internally consistent but deviates from the project convention.

Given this is a preview PR, this can be deferred if the naming aligns with an upstream source being ported.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py`
around lines 5 - 13, Module-level constants use PascalCase; rename them to
UPPER_SNAKE_CASE to match project conventions: change Nvfp4BlockSize ->
NVFP4_BLOCK_SIZE, SfPaddingBlock -> SF_PADDING_BLOCK, TmaLeadingDimByteAlign ->
TMA_LEADING_DIM_BYTE_ALIGN, Nvfp4E2M1Max -> NVFP4_E2M1_MAX, Fp8E4M3FNMax ->
FP8_E4M3_FN_MAX, SupportedMmaTileM -> SUPPORTED_MMA_TILE_M, SupportedMmaTileN ->
SUPPORTED_MMA_TILE_N, and update all internal references/usages of these symbols
accordingly; if external consumers may rely on the old names, add short-lived
aliases mapping the old names to the new constants in the same module to
preserve backward compatibility (e.g., Nvfp4BlockSize = NVFP4_BLOCK_SIZE) and
add a TODO to remove aliases later.
tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py (1)

100-100: 💤 Low value

Modernize type hints to use built-in types.

The coding guidelines prefer built-in types (list, dict, tuple) over legacy typing equivalents, and X | None over Optional[X]. This file imports and uses List, Dict, Tuple, Optional throughout.

♻️ Suggested import and usage update
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any

Then update usages throughout, e.g.:

  • List[ir.Value] → list[ir.Value]
  • Dict[str, int] → dict[str, int]
  • Tuple[int, ...] → tuple[int, ...]
  • Optional[int] → int | None
  • Type[cutlass.Numeric] → type[cutlass.Numeric]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py` at
line 100, Replace legacy typing generics with built-in generics and PEP 604
unions: remove List, Dict, Tuple, Optional, Type from the typing import and use
list, dict, tuple, X | None, and type[...] in their places; e.g., change imports
on the line with "from typing import Any, Dict, List, Optional, Tuple, Type" to
keep only Any (if needed) and update all occurrences like
List[ir.Value]→list[ir.Value], Dict[str,int]→dict[str,int],
Tuple[int,...]→tuple[int,...], Optional[int]→int | None, and
Type[cutlass.Numeric]→type[cutlass.Numeric] throughout megamoe_kernel.py (search
for those symbols: List, Dict, Tuple, Optional, Type and replace accordingly).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py`:
- Around line 346-352: The code builds sf_layout = tile_atom_to_shape_SF(...)
but then throws away the swizzled layout by calling
cute.make_layout(sf_layout.shape, stride=stride) when constructing the tensor;
update the cute.make_tensor call for the "sfa" branch (and the analogous
"sfb"/"sfc" branches) to pass the full sf_layout (the layout object returned by
tile_atom_to_shape_SF) into cute.make_tensor instead of reconstructing a plain
strided layout, so the atom-swizzled layout is preserved for the descriptor
path.

In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py`:
- Around line 890-909: The mbarrier expect_tx must be set before issuing the
async cluster store to avoid a race where the store completes before the peer is
armed; move the mbarrier_arrive_expect_tx_on_peer(...) call to precede
store_i32_to_peer_cluster_smem_async(...) in the path handling lane_idx <
Int32(cluster_size) so the peer's expect_tx is armed (matching the Int32(4) tx
size) before the st.async.shared::cluster write; reference
store_i32_to_peer_cluster_smem_async, mbarrier_arrive_expect_tx_on_peer, and
consumer_wait when making the reorder.

In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py`:
- Around line 954-955: The constructor/signature should not allow expert_cnt or
dependency to be None because construct_and_write() immediately uses dependency
and self.expert_cnt; update the signatures (remove Optional[...] = None and
default None) so expert_cnt and dependency are required parameters, and/or add
an early validation in construct_and_write() that raises a ValueError if
dependency is None or self.expert_cnt is None (referencing construct_and_write,
self.expert_cnt, and dependency), and apply the same change to the other
occurrence around lines 1014-1021 so both call sites enforce non-None values.

In `@tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_cute_dsl.py`:
- Around line 436-454: Replace the assertion-based argument checks in
MegaMoECuteDsl's initializer with explicit validation that raises ValueError:
change the two assert checks that reference self.tp_size and self.cluster_size
to if ...: raise ValueError(...) and similarly change the assert that checks
self.ep_size == self.parallel_size when self.use_dp and self.parallel_size > 1
to an if that raises ValueError; keep the existing error messages (adjusted to
the ValueError) and keep the surrounding logic for num_slots/ep_size as-is.
- Around line 708-711: The input validation in MegaMoECuteDsl.load_weights uses
an assert which is skipped with Python -O; change this to an explicit check and
raise a ValueError when the length is not 1 (e.g., if len(weights) != 1: raise
ValueError("MegaMoECuteDsl.load_weights expects a single-element list, got
{len(weights)} entries.")) so callers always get a clear exception; update the
same error message currently used by the assert.

In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py`:
- Around line 3535-3549: The code crashes when per_slot is empty because
stack_byte_reinterpretable_tensors is called with an empty list; in
_build_mega_sf (and similarly in the other block), detect the zero-slot case
(e.g., if num_slots == 0 or per_slot == []) before stacking and return an
appropriately shaped empty tensor (shape (num_slots, flat_size), i.e.
(0, flat_size), with dtype=torch.uint8, since stacked does not exist yet, and
device=device) so downstream callers expecting a (num_slots, flat_size) tensor
do not fail; implement this early-return right after constructing/validating
per_slot and before calling stack_byte_reinterpretable_tensors.
- Around line 3205-3212: The code assumes a 2x gated expansion
(expand_intermediate == 2 * intermediate) but never enforces it, causing
mis-sized MegaMoE buffers; add an explicit validation early in the weight
construction path (e.g., in create_weights) that checks
expand_intermediate_size_per_partition == 2 * intermediate and raise a clear
exception if not, and mirror the same guard in fc1_sf_flat_size (or call the
validated value) so fc1_sf_flat_size, the [w3|w1] split/interleave logic and
_build_mega_format_buffers all receive a guaranteed 2x expansion; reference
expand_intermediate_size_per_partition, create_weights, fc1_sf_flat_size, and
_build_mega_format_buffers when adding the assertion and error message.

---

Nitpick comments:
In `@tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py`:
- Line 706: The class-level mutable cache kernel_cache is intended to be shared
across instances but lacks an explicit ClassVar annotation; update the
declaration of kernel_cache in the class to be typed as typing.ClassVar[dict]
(and add from typing import ClassVar if not already imported) so the intent is
explicit and the RUF012 warning is silenced while preserving the shared compile
cache behavior.
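
A minimal sketch of the annotation this nitpick asks for. Only the `kernel_cache` name comes from the comment above; the `Runner` class, its `get_kernel` method, and the cache key shape are hypothetical stand-ins, not the real runner in cute_dsl_megamoe_custom_op.py.

```python
from typing import Any, ClassVar


class Runner:
    """Hypothetical stand-in for the runner class (assumption, not the real code)."""

    # ClassVar states that the dict is deliberately shared by all instances,
    # which is what Ruff's RUF012 asks for on mutable class attributes.
    kernel_cache: ClassVar[dict[tuple, Any]] = {}

    def get_kernel(self, key: tuple) -> Any:
        # Compile once per key; later instances reuse the cached entry.
        if key not in self.kernel_cache:
            self.kernel_cache[key] = f"compiled{key}"  # placeholder for a compiled kernel
        return self.kernel_cache[key]


first, second = Runner(), Runner()
first.get_kernel((128, 64))
shared = (128, 64) in second.kernel_cache  # the entry is visible through any instance
```

The annotation does not change runtime behavior; it only records that the sharing is intentional.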

In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py`:
- Line 33: Replace the legacy typing.List import and usages with the built-in
list type: remove "from typing import List" and update all annotations in this
module (e.g., function signatures, return types, and variable annotations that
reference List) to use the native list[...] form (for example change "List[T]"
to "list[T]"); ensure functions/methods in this file that reference List (search
for symbols like any function or class definitions using List) are updated
accordingly and the typing import is deleted.

In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py`:
- Around line 5-13: Module-level constants use PascalCase; rename them to
UPPER_SNAKE_CASE to match project conventions: change Nvfp4BlockSize ->
NVFP4_BLOCK_SIZE, SfPaddingBlock -> SF_PADDING_BLOCK, TmaLeadingDimByteAlign ->
TMA_LEADING_DIM_BYTE_ALIGN, Nvfp4E2M1Max -> NVFP4_E2M1_MAX, Fp8E4M3FNMax ->
FP8_E4M3_FN_MAX, SupportedMmaTileM -> SUPPORTED_MMA_TILE_M, SupportedMmaTileN ->
SUPPORTED_MMA_TILE_N, and update all internal references/usages of these symbols
accordingly; if external consumers may rely on the old names, add short-lived
aliases mapping the old names to the new constants in the same module to
preserve backward compatibility (e.g., Nvfp4BlockSize = NVFP4_BLOCK_SIZE) and
add a TODO to remove aliases later.

In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py`:
- Line 100: Replace legacy typing generics with built-in generics and PEP 604
unions: remove List, Dict, Tuple, Optional, Type from the typing import and use
list, dict, tuple, X | None, and type[...] in their places; e.g., change imports
on the line with "from typing import Any, Dict, List, Optional, Tuple, Type" to
keep only Any (if needed) and update all occurrences like
List[ir.Value]→list[ir.Value], Dict[str,int]→dict[str,int],
Tuple[int,...]→tuple[int,...], Optional[int]→int | None, and
Type[cutlass.Numeric]→type[cutlass.Numeric] throughout megamoe_kernel.py (search
for those symbols: List, Dict, Tuple, Optional, Type and replace accordingly).

In
`@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py`:
- Around line 1926-1927: The unpacking of the CLC response in the call to
cute.arch.clc_response currently assigns an unused variable n_idx; change that
identifier to a prefixed name (e.g., _n_idx) so it signals intentional ignore.
Locate the line where m_idx, n_idx, l_idx, is_valid =
cute.arch.clc_response(...) (inside moe_persistent_scheduler.py) and replace
n_idx with _n_idx in the tuple assignment; leave m_idx, l_idx and is_valid
unchanged.
- Line 7: Replace typing.List, typing.Tuple, and typing.Optional usages with
built-in generic types and PEP 604 unions: change annotations like List[X] ->
list[X], Tuple[A, B] -> tuple[A, B], and Optional[T] -> T | None; remove List,
Optional, Tuple from the import line in moe_persistent_scheduler.py and
keep/import only Literal if still needed; update all function/method signatures
and variable annotations that reference List, Tuple, or Optional accordingly
(search for usages of List, Tuple, Optional in the file and replace them with
list, tuple, and | None forms).

In `@tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md`:
- Line 152: Update the README entry describing mega_moe/mega_moe_cute_dsl.py /
MegaMoECuteDsl to avoid implying an external NVSHMEM package: change "NVSHMEM
provider" to "symmetric memory provider (torch.distributed._symmetric_memory)"
or similar wording that explicitly names PyTorch's cuMem-based provider; ensure
the note about requiring CUDA 13 Cutlass DSL runtime and the product gate
references remain unchanged so readers know runtime and gating requirements.

In `@tests/microbenchmarks/bench_moe/backend.py`:
- Around line 121-124: Add the missing CuteDSL backend to the registry by
extending the MoeBackendType enum with a MEGAMOE_CUTEDSL member and updating the
get_backend_class dispatch to handle it: import the CuteDSL backend
implementation (e.g., the MegaMoE CuteDSL class) alongside MegaMoEDeepGemm and
return that class when backend_type == MoeBackendType.MEGAMOE_CUTEDSL; ensure
you reference the new enum value (MoeBackendType.MEGAMOE_CUTEDSL) and the
CuteDSL class name (e.g., MegaMoECuteDsl) in the dispatch so bench_moe can
instantiate the CuteDSL backend.
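
The registry change described above could look roughly like the sketch below. The enum string values, the empty placeholder classes, and the dictionary-based dispatch are assumptions for illustration; the real `get_backend_class` in bench_moe may be structured differently.

```python
from enum import Enum


class MoeBackendType(Enum):
    # String values are assumed here; only the member names come from the review.
    MEGAMOE_DEEPGEMM = "MEGAMOE_DEEPGEMM"
    MEGAMOE_CUTEDSL = "MEGAMOE_CUTEDSL"


class MegaMoEDeepGemm:
    """Placeholder for the real DeepGEMM backend class."""


class MegaMoECuteDsl:
    """Placeholder for the real CuteDSL backend class."""


_BACKEND_CLASSES = {
    MoeBackendType.MEGAMOE_DEEPGEMM: MegaMoEDeepGemm,
    MoeBackendType.MEGAMOE_CUTEDSL: MegaMoECuteDsl,
}


def get_backend_class(backend_type: MoeBackendType) -> type:
    """Map a backend enum member to its implementation class."""
    try:
        return _BACKEND_CLASSES[backend_type]
    except KeyError:
        raise ValueError(f"Unsupported MoE backend: {backend_type}") from None


selected = get_backend_class(MoeBackendType.MEGAMOE_CUTEDSL)
```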

In `@tests/microbenchmarks/bench_moe/utils.py`:
- Around line 66-79: The early-return in _ensure_dist_for_megamoe only
recognizes MoeBackendType.MEGAMOE_DEEPGEMM, so the NCCL ProcessGroup will not be
initialized for the other MegaMoE variant; update _ensure_dist_for_megamoe to
also accept the MEGAMOE_CUTEDSL backend (or generally any future MegaMoE enum
value) by checking for both MoeBackendType.MEGAMOE_DEEPGEMM.value and
MoeBackendType.MEGAMOE_CUTEDSL.value (or by matching a common prefix/enum
category), keeping the rest of the initialization (CUDA check, env vars,
dist.init_process_group) unchanged; reference function name
_ensure_dist_for_megamoe and the test file
test_moe_module.py:_ensure_dist_for_megamoe to ensure parity with tests and add
a TODO comment to revisit when CuteDSL is formally added to the benchmark
registry.

In `@tests/unittest/_torch/modules/moe/test_moe_backend.py`:
- Around line 568-577: Add a companion negative test that mirrors the
MEGAMOE_CUTEDSL branch but ensures the v1 alpha gate is enforced: when
backend_type == MoeBackendType.MEGAMOE_CUTEDSL, explicitly ensure
TRTLLM_MEGAMOE_CUTEDSL_BYPASS_V1_ALPHA_GATE is unset (use monkeypatch.delenv
with raising=False), then exercise the same load -> post-load ->
run_moe path used in the positive case and assert that the load/post-load step
fails (raises/rejects) due to non-1 alpha values; place this next to the
existing positive branch in
tests/unittest/_torch/modules/moe/test_moe_backend.py so it uses the same setup
and failure assertion to prove the gate blocks non-1 alpha checkpoints.
- Around line 307-315: Add a brief explicit note that QA integration list
updates are unnecessary because this change only expands unit-test backend
coverage and does not modify tests/integration/defs; place the comment
immediately above BACKEND_TYPES_TO_TEST (referencing BACKEND_TYPES_TO_TEST,
MoeBackendType.* entries) so reviewers and release/QA scripts know no
tests/integration/test_lists/qa/* edits are required.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: def9f1a9-762b-42f1-b770-485392227501

📥 Commits

Reviewing files that changed from the base of the PR and between 5dd96d6 and aa3a318.

📒 Files selected for processing (40)
  • .pre-commit-config.yaml
  • legacy-files.txt
  • pyproject.toml
  • ruff-legacy.toml
  • tensorrt_llm/_torch/autotuner.py
  • tensorrt_llm/_torch/custom_ops/__init__.py
  • tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/config.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dispatch_kernel.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.py
  • tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.py
  • tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md
  • tensorrt_llm/_torch/modules/fused_moe/__init__.py
  • tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
  • tensorrt_llm/_torch/modules/fused_moe/create_moe.py
  • tensorrt_llm/_torch/modules/fused_moe/mega_moe/MEGAMOE_CUTEDSL_DESIGN.md
  • tensorrt_llm/_torch/modules/fused_moe/mega_moe/__init__.py
  • tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_cute_dsl.py
  • tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_deepgemm.py
  • tensorrt_llm/_torch/modules/fused_moe/moe_scheduler.py
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
  • tests/microbenchmarks/bench_moe/backend.py
  • tests/microbenchmarks/bench_moe/utils.py
  • tests/unittest/_torch/modules/moe/moe_test_utils.py
  • tests/unittest/_torch/modules/moe/test_moe_backend.py
  • tests/unittest/_torch/modules/moe/test_moe_module.py

Comment on lines +346 to +352
elif cutlass.const_expr(tensor_name == "sfa"):
    real = cute.domain_offset((0, 0, expert_idx),
                              gmem_tensor_in_moe_view)
    per_expert_shape = (shape[0], shape[1], c1)  # type: ignore[index]
    sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
    real = cute.make_tensor(
        real.iterator, cute.make_layout(sf_layout.shape, stride=stride))

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Preserve the swizzled sf_layout from tile_atom_to_shape_SF when building sfa/sfb/sfc tensors

custom_ext.py computes sf_layout = tile_atom_to_shape_SF(...) for sfa/sfb/sfc, but then discards the layout object and rebuilds the tensor using only cute.make_layout(sf_layout.shape, stride=stride), which can drop the atom-swizzled layout details expected by the descriptor path.

Suggested fix: pass the full sf_layout into cute.make_tensor(...) instead of rebuilding a plain strided view.

Suggested diff
         elif cutlass.const_expr(tensor_name == "sfa"):
             real = cute.domain_offset((0, 0, expert_idx),
                                       gmem_tensor_in_moe_view)
             per_expert_shape = (shape[0], shape[1], c1)  # type: ignore[index]
             sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
-            real = cute.make_tensor(
-                real.iterator, cute.make_layout(sf_layout.shape, stride=stride))
+            real = cute.make_tensor(real.iterator, sf_layout)
             return (real, None)

         elif cutlass.const_expr(tensor_name == "sfb"):
             real = cute.domain_offset((sf_token_offset, 0, 0),
                                       gmem_tensor_in_moe_view)
             per_expert_shape = (shape[0], shape[1], c1)  # type: ignore[index]
             sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
-            real = cute.make_tensor(
-                real.iterator, cute.make_layout(sf_layout.shape, stride=stride))
+            real = cute.make_tensor(real.iterator, sf_layout)
             return (real, None)
...
         elif cutlass.const_expr(tensor_name == "sfc"):
             real = cute.domain_offset((sf_token_offset, 0, 0),
                                       gmem_tensor_in_moe_view)
             per_expert_shape = (shape[0], shape[1], c1)  # type: ignore[index]
             sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
-            real = cute.make_tensor(
-                real.iterator, cute.make_layout(sf_layout.shape, stride=stride))
+            real = cute.make_tensor(real.iterator, sf_layout)
             return (real, None)

Also applies to: 355-362, 371-379

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py` around
lines 346 - 352, The code builds sf_layout = tile_atom_to_shape_SF(...) but then
throws away the swizzled layout by calling cute.make_layout(sf_layout.shape,
stride=stride) when constructing the tensor; update the cute.make_tensor call
for the "sfa" branch (and the analogous "sfb"/"sfc" branches) to pass the full
sf_layout (the layout object returned by tile_atom_to_shape_SF) into
cute.make_tensor instead of reconstructing a plain strided layout, so the
atom-swizzled layout is preserved for the descriptor path.

Comment on lines +890 to +909
# DSMEM fan-out: lanes [0, cluster_size) each write to one peer
# CTA. Each lane targets a distinct peer (lane_idx == peer rank).
if lane_idx < Int32(cluster_size):
    store_i32_to_peer_cluster_smem_async(
        ds.broadcast_ptr,
        atomic_idx,
        full_barrier_ptr,
        lane_idx,
        loc=loc,
        ip=ip,
    )
    # Set expect_tx on the peer mbarrier to match the 4-byte
    # store above; pairs with the consumer_wait below.
    mbarrier_arrive_expect_tx_on_peer(
        full_barrier_ptr,
        Int32(4),
        lane_idx,
        loc=loc,
        ip=ip,
    )

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Arm the peer mbarrier before issuing the async cluster store.

store_i32_to_peer_cluster_smem_async() is documented to require the peer expect_tx to be set first, but this path does the st.async.shared::cluster...complete_tx before mbarrier_arrive_expect_tx_on_peer(). If the store wins that race, the atomic-counter scheduler can miss the completion and deadlock in consumer_wait().

Suggested fix
             if lane_idx < Int32(cluster_size):
-                store_i32_to_peer_cluster_smem_async(
-                    ds.broadcast_ptr,
-                    atomic_idx,
-                    full_barrier_ptr,
-                    lane_idx,
-                    loc=loc,
-                    ip=ip,
-                )
-                # Set expect_tx on the peer mbarrier to match the 4-byte
-                # store above; pairs with the consumer_wait below.
                 mbarrier_arrive_expect_tx_on_peer(
                     full_barrier_ptr,
                     Int32(4),
                     lane_idx,
                     loc=loc,
                     ip=ip,
                 )
+                store_i32_to_peer_cluster_smem_async(
+                    ds.broadcast_ptr,
+                    atomic_idx,
+                    full_barrier_ptr,
+                    lane_idx,
+                    loc=loc,
+                    ip=ip,
+                )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py`
around lines 890 - 909, The mbarrier expect_tx must be set before issuing the
async cluster store to avoid a race where the store completes before the peer is
armed; move the mbarrier_arrive_expect_tx_on_peer(...) call to precede
store_i32_to_peer_cluster_smem_async(...) in the path handling lane_idx <
Int32(cluster_size) so the peer's expect_tx is armed (matching the Int32(4) tx
size) before the st.async.shared::cluster write; reference
store_i32_to_peer_cluster_smem_async, mbarrier_arrive_expect_tx_on_peer, and
consumer_wait when making the reorder.

Comment on lines +954 to +955
expert_cnt: Optional[Union[Int32, int]] = None,
) -> None:

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Make expert_cnt and dependency required here.

construct_and_write() immediately unpacks dependency and does arithmetic with self.expert_cnt, so the current Optional[...] = None defaults blow up as soon as a caller follows the advertised signature. Either require both inputs at construction/call time or fail fast with a ValueError before entering the descriptor loop.

Also applies to: 1014-1021

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py` around
lines 954 - 955, The constructor/signature should not allow expert_cnt or
dependency to be None because construct_and_write() immediately uses dependency
and self.expert_cnt; update the signatures (remove Optional[...] = None and
default None) so expert_cnt and dependency are required parameters, and/or add
an early validation in construct_and_write() that raises a ValueError if
dependency is None or self.expert_cnt is None (referencing construct_and_write,
self.expert_cnt, and dependency), and apply the same change to the other
occurrence around lines 1014-1021 so both call sites enforce non-None values.
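
A minimal sketch of the two options combined: required parameters plus a fail-fast check. The `DescriptorWriter` class and the arithmetic inside `construct_and_write` are hypothetical; only the names `expert_cnt`, `dependency`, and `construct_and_write` come from the review comment.

```python
class DescriptorWriter:
    """Hypothetical stand-in for the class under review in moe_utils.py."""

    def __init__(self, expert_cnt: int) -> None:
        # Required parameter: there is no `= None` default to fall into.
        self.expert_cnt = expert_cnt

    def construct_and_write(self, dependency: tuple[int, int]) -> list[int]:
        # Fail fast if a caller still passes None explicitly.
        if dependency is None or self.expert_cnt is None:
            raise ValueError(
                "construct_and_write requires non-None expert_cnt and dependency")
        start, stride = dependency  # unpacking is now known to be safe
        return [start + stride * i for i in range(self.expert_cnt)]


offsets = DescriptorWriter(expert_cnt=3).construct_and_write((10, 2))
```

With this shape, a missing argument surfaces as a `TypeError` at the call site and an explicit `None` as a `ValueError`, both before the descriptor loop starts.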

Comment on lines +436 to +454
assert self.tp_size == 1, (
    f"MegaMoECuteDsl is EP-only in v1 (moe_tp_size=1); got tp_size={self.tp_size}."
)
assert self.cluster_size == 1, (
    f"MegaMoECuteDsl assumes cluster_size=1; got cluster_size={self.cluster_size}."
)
if self.num_slots % max(self.ep_size, 1) != 0:
    raise ValueError(
        f"MegaMoECuteDsl requires num_slots ({self.num_slots}) "
        f"divisible by ep_size ({self.ep_size})."
    )

if self.use_dp and self.parallel_size > 1:
    assert self.ep_size == self.parallel_size, (
        f"MegaMoECuteDsl with enable_attention_dp=True requires "
        f"ep_size == parallel_size (got ep_size={self.ep_size}, "
        f"parallel_size={self.parallel_size}). ADP > EP would "
        f"require an outer allgather + reducescatter wrapper."
    )

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Replace assert with if ... raise ValueError for API validation.

Lines 436-441 and 448-454 use assert statements to validate constructor arguments. Since Python's -O flag disables assertions, these checks could silently pass in optimized production builds. The adjacent validation on lines 442-446 already correctly uses if ... raise ValueError.

Proposed fix
-        assert self.tp_size == 1, (
-            f"MegaMoECuteDsl is EP-only in v1 (moe_tp_size=1); got tp_size={self.tp_size}."
-        )
-        assert self.cluster_size == 1, (
-            f"MegaMoECuteDsl assumes cluster_size=1; got cluster_size={self.cluster_size}."
-        )
+        if self.tp_size != 1:
+            raise ValueError(
+                f"MegaMoECuteDsl is EP-only in v1 (moe_tp_size=1); got tp_size={self.tp_size}."
+            )
+        if self.cluster_size != 1:
+            raise ValueError(
+                f"MegaMoECuteDsl assumes cluster_size=1; got cluster_size={self.cluster_size}."
+            )
         if self.num_slots % max(self.ep_size, 1) != 0:
             raise ValueError(
                 f"MegaMoECuteDsl requires num_slots ({self.num_slots}) "
                 f"divisible by ep_size ({self.ep_size})."
             )

-        if self.use_dp and self.parallel_size > 1:
-            assert self.ep_size == self.parallel_size, (
+        if self.use_dp and self.parallel_size > 1 and self.ep_size != self.parallel_size:
+            raise ValueError(
                 f"MegaMoECuteDsl with enable_attention_dp=True requires "
                 f"ep_size == parallel_size (got ep_size={self.ep_size}, "
                 f"parallel_size={self.parallel_size}). ADP > EP would "
                 f"require an outer allgather + reducescatter wrapper."
             )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_cute_dsl.py` around
lines 436 - 454, Replace the assertion-based argument checks in MegaMoECuteDsl's
initializer with explicit validation that raises ValueError: change the two
assert checks that reference self.tp_size and self.cluster_size to if ...: raise
ValueError(...) and similarly change the assert that checks self.ep_size ==
self.parallel_size when self.use_dp and self.parallel_size > 1 to an if that
raises ValueError; keep the existing error messages (adjusted to the ValueError)
and keep the surrounding logic for num_slots/ep_size as-is.
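
The difference between the two validation styles can be reproduced without restarting the interpreter, because `compile(..., optimize=1)` strips `assert` statements the same way `python -O` does. The two `check_*` functions below are hypothetical reductions of the constructor checks, not the real code.

```python
SOURCE = '''
def check_with_assert(tp_size):
    assert tp_size == 1, f"got tp_size={tp_size}"
    return "accepted"

def check_with_raise(tp_size):
    if tp_size != 1:
        raise ValueError(f"got tp_size={tp_size}")
    return "accepted"
'''

# optimize=1 compiles the source as `python -O` would: asserts are removed.
namespace = {}
exec(compile(SOURCE, "<demo>", "exec", optimize=1), namespace)

# The assert-based check lets the invalid value through.
assert_result = namespace["check_with_assert"](2)

# The explicit check still rejects it.
try:
    namespace["check_with_raise"](2)
    raise_result = "accepted"
except ValueError as err:
    raise_result = f"rejected: {err}"
```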

Comment on lines +708 to +711
assert len(weights) == 1, (
    "MegaMoECuteDsl.load_weights expects a single-element list, "
    f"got {len(weights)} entries."
)

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Replace assert with if ... raise ValueError for input validation.

Same pattern as the constructor: this validates caller-provided input, which should not be silently skipped when Python runs with -O.

Proposed fix
-        assert len(weights) == 1, (
+        if len(weights) != 1:
+            raise ValueError(
             "MegaMoECuteDsl.load_weights expects a single-element list, "
             f"got {len(weights)} entries."
         )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_cute_dsl.py` around
lines 708 - 711, The input validation in MegaMoECuteDsl.load_weights uses an
assert which is skipped with Python -O; change this to an explicit check and
raise a ValueError when the length is not 1 (e.g., if len(weights) != 1: raise
ValueError("MegaMoECuteDsl.load_weights expects a single-element list, got
{len(weights)} entries.")) so callers always get a clear exception; update the
same error message currently used by the assert.

Comment on lines +3205 to +3212
def fc1_sf_flat_size(cls, intermediate: int, hidden: int) -> int:
    """``round_up(expand_intermediate, SfPaddingBlock=128) *
    round_up(ceil(hidden / 16), 4)`` -- matches kernel_fc12.py:880-890.
    ``expand_intermediate = 2 * intermediate``.
    """
    expand_intermediate = intermediate * 2
    return (cls._round_up_int(expand_intermediate, 128) *
            cls._round_up_int(cls._ceil_div_int(hidden, 16), 4))

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Assert the 2x gated expansion invariant up front.

This implementation hard-codes FC1 as 2 * intermediate in both the derived-SF sizing and the [w3 | w1] split/interleave path, but create_weights() never validates that expand_intermediate_size_per_partition actually satisfies that contract. On any non-2x MoE shape, the MegaMoE buffers are mis-sized and _build_mega_format_buffers() will mis-pack or fail at load time instead of rejecting the config immediately.

💡 Suggested guard
     def create_weights(self, module: torch.nn.Module):
+        if (module.expand_intermediate_size_per_partition !=
+                2 * module.intermediate_size_per_partition):
+            raise NotImplementedError(
+                "NVFP4MegaMoECuteDslMethod currently requires gated 2x intermediate expansion."
+            )
+
         weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
         self.block_scales_vec_size = torch.iinfo(
             self.block_scales_dtype).bits // 8

Also applies to: 3241-3282, 3582-3609

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py` around lines 3205 -
3212, The code assumes a 2x gated expansion (expand_intermediate == 2 *
intermediate) but never enforces it, causing mis-sized MegaMoE buffers; add an
explicit validation early in the weight construction path (e.g., in
create_weights) that checks expand_intermediate_size_per_partition == 2 *
intermediate and raise a clear exception if not, and mirror the same guard in
fc1_sf_flat_size (or call the validated value) so fc1_sf_flat_size, the [w3|w1]
split/interleave logic and _build_mega_format_buffers all receive a guaranteed
2x expansion; reference expand_intermediate_size_per_partition, create_weights,
fc1_sf_flat_size, and _build_mega_format_buffers when adding the assertion and
error message.
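
To make the sizing formula and the proposed guard concrete, here is a standalone sketch. It is not the method from quantization.py: the free functions, the extra `expand_intermediate` parameter, and the example shapes are assumptions chosen for illustration; the formula itself is the one quoted in the docstring above.

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def round_up(value: int, multiple: int) -> int:
    return ceil_div(value, multiple) * multiple


def fc1_sf_flat_size(intermediate: int, expand_intermediate: int, hidden: int) -> int:
    """round_up(expand_intermediate, 128) * round_up(ceil(hidden / 16), 4)."""
    # Guard suggested by the review: reject shapes that are not gated 2x.
    if expand_intermediate != 2 * intermediate:
        raise NotImplementedError(
            "gated 2x intermediate expansion is required, got "
            f"expand_intermediate={expand_intermediate}, intermediate={intermediate}")
    return round_up(expand_intermediate, 128) * round_up(ceil_div(hidden, 16), 4)


# Already aligned: 4096 rows * 448 scale columns.
aligned = fc1_sf_flat_size(intermediate=2048, expand_intermediate=4096, hidden=7168)

# Needs padding on both axes: 200 -> 256 rows, ceil(100 / 16) = 7 -> 8 columns.
padded = fc1_sf_flat_size(intermediate=100, expand_intermediate=200, hidden=100)
```

Without the guard, a non-2x shape would silently produce a size for `2 * intermediate` rows rather than for the rows actually present.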

Comment on lines +3535 to +3549
per_slot: List[torch.Tensor] = []
for slot_idx in range(num_slots):
    sf_fp8 = raw_sf[slot_idx].view(torch.float8_e4m3fn)
    per_slot.append(to_blocked(sf_fp8).view(torch.uint8))
stacked = stack_byte_reinterpretable_tensors(per_slot,
                                             dim=0).contiguous()
if stacked.shape[-1] == flat_size:
    return stacked
# Pad zero on the tail so the output shape matches the
# registered Parameter shape.
out = torch.zeros((num_slots, flat_size),
                  dtype=stacked.dtype,
                  device=device)
out[:, :stacked.shape[-1]] = stacked
return out

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Handle zero-slot staging before stacking per-slot SF tensors.

When online EPLB is enabled, local_shared_load_expert_ids can be empty, but _build_mega_shared_staging() still calls _build_mega_format_buffers(). In that case _build_mega_sf() builds per_slot = [] and unconditionally passes it to stack_byte_reinterpretable_tensors(...), which turns the "no shared experts" case into a load-time crash.

💡 Suggested early return
     def _build_mega_sf(raw_sf: torch.Tensor, *, num_slots: int,
                        gate_up_interleave_intermediate: Optional[int],
                        n_pairs: Optional[int], expand_intermediate: int,
                        flat_size: int) -> torch.Tensor:
         """Build a flattened, blocked-swizzled NVFP4 SF tensor per slot.
@@
         from ...cute_dsl_kernels.mega_moe_nvfp4 import (
             stack_byte_reinterpretable_tensors, to_blocked)

         device = raw_sf.device
+        if num_slots == 0:
+            return torch.empty((0, flat_size),
+                               dtype=torch.uint8,
+                               device=device)
+
         sf_cols = raw_sf.shape[-1]  # int32 units

Also applies to: 3665-3706

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py` around lines 3535 -
3549, The code crashes when per_slot is empty because
stack_byte_reinterpretable_tensors is called with an empty list; in
_build_mega_sf (and similarly in the other block), detect the zero-slot case
(e.g., if num_slots == 0 or per_slot == []) before stacking and return an
appropriately shaped empty tensor (shape (num_slots, flat_size), i.e.
(0, flat_size), with dtype=torch.uint8, since stacked does not exist yet, and
device=device) so downstream callers expecting a (num_slots, flat_size) tensor
do not fail; implement this early-return right after constructing/validating
per_slot and before calling stack_byte_reinterpretable_tensors.

@xxi-nv xxi-nv changed the title [None][feat] Add MegaMoECuteDsl NVFP4 MoE backend (preview, do not use) [Draft][None][feat] Add MegaMoECuteDsl NVFP4 MoE backend May 27, 2026
@xxi-nv xxi-nv force-pushed the megamoe_cutedsl_nvfp4_v2 branch from ff947a8 to 41b3aa8 Compare May 27, 2026 03:22
@xxi-nv xxi-nv changed the title [Draft][None][feat] Add MegaMoECuteDsl NVFP4 MoE backend [Draft][TRTLLM-12950][feat] Add MegaMoECuteDsl NVFP4 MoE backend May 27, 2026
Introduces MegaMoECuteDsl, a fused-communication MoE backend that runs the ported MegaMoE NVFP4 CuteDSL kernel (Sm100MegaMoEKernel) on SM100/SM103. The kernel fuses dispatch + FC1 + SwiGLU + FC2 + combine in a single launch via the in-kernel NVLink dispatch barrier. Single-rank degenerate path and multi-rank EP path are both wired through a unified always-pad-to-max_tokens_per_rank staging contract.

Components:
- New backend tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_cute_dsl.py with build-time symmetric-memory provider rendezvous, local staging cache, FUSED_COMM scheduler binding, and quantize_input that pads SF columns to round_up(ceil(hidden/16), 4) to match the kernel TMA contract.
- New torch custom op torch.ops.trtllm.cute_dsl_megamoe_nvfp4_blackwell registered in tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py. Hosts Sm100MegaMoENvfp4Runner (TunableRunner) with PARALLEL distributed tuning so every EP rank converges on the same compiled tactic for every chunk (required for the in-kernel dispatch barrier), MegaMoeSymmMemProvider with zero-init buffer, candidate tactic enumeration sweeping {static, atomic_counter} load balance modes, and a stricter IS_MEGAMOE_OP_AVAILABLE probe so half-installed cutlass-dsl wheels do not break the rest of custom_ops. The runner allocates local_workspace via torch.zeros and calls local_workspace.zero_() on every forward in addition to the existing shared_workspace.zero_(), so the cached buffer cannot feed garbage into the kernel's Int32 atomic counters (l1_arrival_count, fc1_done_counter, grid_sync_counter); a negative Int32 in any counter slot would otherwise make the in-kernel spin_wait (v >= positive_threshold) impossible to satisfy and hang the kernel at 100% SM / 0% memory bandwidth. The op returns None and the caller uses the in-place mutated combine_output directly, because torch.library forbids the return value from aliasing any mutated input.
- Ported CuteDSL kernel package at tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ (16 files) with package-relative imports, plus blocked_scale.py extracted from upstream runner_fc12.py.
- New NVFP4MegaMoECuteDslMethod quant method in tensorrt_llm/_torch/modules/fused_moe/quantization.py: 16-atom gate/up interleave + to_blocked swizzle + flattened-stack staging for mega_fc1_weight / mega_fc1_weight_sf / mega_fc2_weight / mega_fc2_weight_sf. Builds CPU shared-staging buffers for all four derived parameters and registers them with the load balancer through register_all_parameter_slot_and_to_fix_weight_fns so both static AND dynamic EPLB migrate the mega-format bytes atomically with the underlying NVFP4 raw weights + scales.
- FusedCommMoEScheduler now also calls backend.quantize_input on zero-token chunks (the DeepGEMM backend is updated in parallel), so each fused-comm backend owns its own empty-tensor layout.
- create_moe.py factory and ConfigurableMoE allowlist updated; MoEDeveloperGuide adds Backend Capability Matrix entries, FUSED_COMM anti-patterns, and the autotuner tactic representation reference.
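
The SF-column padding rule quoted for quantize_input, round_up(ceil(hidden/16), 4), can be checked with a few lines of arithmetic. The sketch below is illustrative only: the helper names (ceil_div, round_up, padded_sf_columns) are hypothetical and not taken from this PR, and the 16-element scale-factor block and 4-column alignment are simply read off that formula.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division for non-negative a and positive b."""
    return -(-a // b)


def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple of `multiple`."""
    return ceil_div(x, multiple) * multiple


def padded_sf_columns(hidden: int, sf_vec_size: int = 16, align: int = 4) -> int:
    """Hypothetical helper: number of scale-factor columns after padding.

    One scale factor covers `sf_vec_size` hidden elements (16 in the formula
    above); the column count is then padded to a multiple of `align` (4).
    """
    return round_up(ceil_div(hidden, sf_vec_size), align)


print(padded_sf_columns(7168))  # 448: already a multiple of 4, no padding
print(padded_sf_columns(100))   # 8: ceil(100 / 16) = 7, padded up to 8
```

A hidden size that is not a multiple of 64 is the case where the padding actually changes the column count.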

Tests:
- Drops the asymmetric MoeBackendType.MEGAMOE alias; both variants are spelled out as MEGAMOE_DEEPGEMM and MEGAMOE_CUTEDSL with should_skip_megamoe_deepgemm / should_skip_megamoe_cutedsl helpers.
- New focused tests in test_moe_backend.py: kernel-package import, to_blocked roundtrip, can_implement positive/negative, quantize_input zero-token, multi-rank symm-provider gate, alpha-gate, sf-byte-width helper, atomic_counter sweep, dynamic-EPLB-now-supported, and a byte-equivalence test for the FC1 gate/up 16-atom interleave that catches any gate-vs-up swap.
- test_moe_module.py adds factory + scheduler wiring coverage and updates the shared dist helper to recognise MEGAMOE_CUTEDSL.
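
As a rough illustration of the condition the alpha-gate test guards, the sketch below flags any scale tensor that contains a value other than 1.0. It is an assumption-laden stand-in, not the PR's _check_v1_alpha_gate: the function names are hypothetical, plain Python sequences replace tensors, and the exact-equality comparison is a guess at the gate's strictness.

```python
from typing import Dict, List, Sequence


def failing_alpha_tensors(scales: Dict[str, Sequence[float]]) -> List[str]:
    """Return the names of scale tensors holding any value other than 1.0."""
    return [
        name
        for name, values in scales.items()
        if any(v != 1.0 for v in values)
    ]


def passes_alpha_gate(scales: Dict[str, Sequence[float]]) -> bool:
    """True only when every per-expert scale is exactly 1.0."""
    return not failing_alpha_tensors(scales)


# Hypothetical per-expert values for a two-expert checkpoint.
unit_scales = {"fc31_alpha": [1.0, 1.0], "fc2_alpha": [1.0, 1.0]}
real_scales = {"fc31_alpha": [1.0, 1.0], "fc2_alpha": [1.0, 0.5]}

print(passes_alpha_gate(unit_scales))      # True
print(failing_alpha_tensors(real_scales))  # ['fc2_alpha']
```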

Hard gates documented in MEGAMOE_CUTEDSL_DESIGN.md:
- Kernel ABI hard-codes per-expert alpha / fc2_input_scale to 1.0; v1 alpha gate in NVFP4MegaMoECuteDslMethod rejects non-1.0 checkpoint values until the upstream kernel ABI is extended.
- Launch contract requires T == max_tokens_per_rank on every call; the backend stages the real token rows, then pads the topk_idx tail with -1 so that the launched T always equals max_tokens_per_rank (dispatch_prep skips negative expert ids).
- IS_MEGAMOE_OP_AVAILABLE strict probe protects the rest of custom_ops on partial cutlass-dsl installs.
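
The pad-to-max_tokens_per_rank launch contract can be pictured with plain Python lists. This is a minimal sketch under stated assumptions: pad_topk_idx and count_tokens_per_expert are hypothetical names, nested lists stand in for device tensors, and the counting loop only mimics the "skip negative expert ids" behaviour attributed to dispatch_prep.

```python
from typing import List


def pad_topk_idx(
    topk_idx: List[List[int]], max_tokens_per_rank: int, top_k: int
) -> List[List[int]]:
    """Pad expert ids to exactly max_tokens_per_rank rows; tail rows are all -1."""
    num_real = len(topk_idx)
    if num_real > max_tokens_per_rank:
        raise ValueError("chunk has more tokens than max_tokens_per_rank")
    if any(len(row) != top_k for row in topk_idx):
        raise ValueError("every token row must hold top_k expert ids")
    tail = [[-1] * top_k for _ in range(max_tokens_per_rank - num_real)]
    return [list(row) for row in topk_idx] + tail


def count_tokens_per_expert(padded: List[List[int]], num_experts: int) -> List[int]:
    """Count routed slots per expert, ignoring padded (negative) ids."""
    counts = [0] * num_experts
    for row in padded:
        for expert_id in row:
            if expert_id < 0:  # padded slot: contributes no work
                continue
            counts[expert_id] += 1
    return counts


padded = pad_topk_idx([[0, 2], [1, 2]], max_tokens_per_rank=4, top_k=2)
print(padded)                              # [[0, 2], [1, 2], [-1, -1], [-1, -1]]
print(count_tokens_per_expert(padded, 3))  # [1, 1, 2]
```

The zero-token case falls out of the same rule: an empty chunk becomes max_tokens_per_rank rows of -1 and routes nothing.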

Known follow-ups:
- Per-slot to_blocked calls can be batched to improve performance on high-volume models.
- Real GPU E2E run is blocked on OCI worktree rebuild for ABI parity with the new C++ bindings.

Designed and documented under tensorrt_llm/_torch/modules/fused_moe/mega_moe/MEGAMOE_CUTEDSL_DESIGN.md.

Signed-off-by: xxi <xxi@nvidia.com>
@xxi-nv xxi-nv force-pushed the megamoe_cutedsl_nvfp4_v2 branch from 41b3aa8 to 9134a9b Compare May 27, 2026 05:05