[Draft][TRTLLM-12950][feat] Add MegaMoECuteDsl NVFP4 MoE backend #14608
[Draft][TRTLLM-12950][feat] Add MegaMoECuteDsl NVFP4 MoE backend #14608xxi-nv wants to merge 1 commit into
Conversation
aa3a318 to
ff947a8
Compare
📝 WalkthroughWalkthroughThis PR introduces a complete MegaMoE CuteDSL NVFP4 fused-communication backend that enables multi-rank expert routing with symmetric memory, token communication overlapping, and persistent scheduling. The implementation spans ~18K lines across kernel DSL code, custom op registration, backend class, quantization methods, and test infrastructure, featuring a flattened kernel package architecture with lazy imports to defer expensive CuteDSL symbol resolution. ChangesMegaMoE CuteDSL NVFP4 Backend
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes The PR is large (18K+ lines) and heterogeneous, spanning device-side CUDA kernel DSL code, host-side PyTorch backend infrastructure, quantization pipelines, and test coverage. Key areas of complexity: persistent scheduler state machine with multiple advancement modes; dispatch kernel 3-stage barrier with cross-rank NVLink signaling; atomic-counter load-balancing option; MLIR serialization/deserialization for DSL interoperability; symmetric-memory provider caching; tactic validation and compile caching; weight transformation pipeline with blocked-scale swizzling; v1 alpha gating. Many changes are implementations of analogous patterns (e.g., scheduler variants, TMA descriptor constructors, skip predicates), which reduces per-file review friction but requires understanding the underlying design principles. The design doc and developer guide are essential context for review. Suggested reviewers
✨ Finishing Touches🧪 Generate unit tests (beta)
|
There was a problem hiding this comment.
Actionable comments posted: 7
🧹 Nitpick comments (11)
tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md (1)
152-152: ⚡ Quick winClarify symmetric-memory provider terminology.
The description states "requires CUDA 13 Cutlass DSL runtime (PR
#14354) and NVSHMEM provider". Based on the design doc (phase 0, line 938-939), the implementation usestorch.distributed._symmetric_memory(cuMem-based), not an external NVSHMEM package. Consider revising to "symmetric memory provider" or "torch.distributed._symmetric_memory" to avoid implying a dependency onnvshmem4py-cu13, which was discussed in the design doc (lines 533-534) but not chosen for v1.📝 Proposed clarification
-| `mega_moe/mega_moe_cute_dsl.py` | `MegaMoECuteDsl` | SM100/SM103 | NVFP4 via ported CuteDSL `Sm100MegaMoEKernel` fused dispatch+FC1+act+FC2+combine kernel; requires CUDA 13 Cutlass DSL runtime (PR `#14354`) and NVSHMEM provider (hard gate — see `mega_moe/MEGAMOE_CUTEDSL_DESIGN.md`); v1 alpha=1 product gate enforced in `post_load_weights` | `FUSED_COMM` | +| `mega_moe/mega_moe_cute_dsl.py` | `MegaMoECuteDsl` | SM100/SM103 | NVFP4 via ported CuteDSL `Sm100MegaMoEKernel` fused dispatch+FC1+act+FC2+combine kernel; requires CUDA 13 Cutlass DSL runtime (PR `#14354`) and symmetric memory provider (`torch.distributed._symmetric_memory`); v1 alpha=1 product gate enforced in `post_load_weights` | `FUSED_COMM` |Based on learnings from MEGAMOE_CUTEDSL_DESIGN.md phase 0: the implementation uses PyTorch's
torch.distributed._symmetric_memory(cuMem-based NVSHMEM-equivalent), not an external NVSHMEM package.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md` at line 152, Update the README entry describing mega_moe/mega_moe_cute_dsl.py / MegaMoECuteDsl to avoid implying an external NVSHMEM package: change "NVSHMEM provider" to "symmetric memory provider (torch.distributed._symmetric_memory)" or similar wording that explicitly names PyTorch's cuMem-based provider; ensure the note about requiring CUDA 13 Cutlass DSL runtime and the product gate references remain unchanged so readers know runtime and gating requirements.tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py (1)
706-706: 💤 Low valueConsider
ClassVarannotation for shared mutable class attribute.The
kernel_cacheis intentionally shared across all instances for compile caching. AddingClassVarmakes this intent explicit and silences the RUF012 warning:+from typing import ClassVar + class Sm100MegaMoENvfp4Runner(TunableRunner): ... # Module-scope compile cache shared by every runner instance. - kernel_cache: dict = {} + kernel_cache: ClassVar[dict] = {}🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py` at line 706, The class-level mutable cache kernel_cache is intended to be shared across instances but lacks an explicit ClassVar annotation; update the declaration of kernel_cache in the class to be typed as typing.ClassVar[dict] (and add from typing import ClassVar if not already imported) so the intent is explicit and the RUF012 warning is silenced while preserving the shared compile cache behavior.tests/unittest/_torch/modules/moe/test_moe_backend.py (2)
568-577: ⚡ Quick winAdd a companion negative test for the v1 alpha gate.
Setting
TRTLLM_MEGAMOE_CUTEDSL_BYPASS_V1_ALPHA_GATE=1for this path is useful for end-to-end execution, but it also removes direct coverage that the default gate rejects non-1 alpha checkpoints. Please add a small paired test with the env var unset that asserts rejection.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/modules/moe/test_moe_backend.py` around lines 568 - 577, Add a companion negative test that mirrors the MEGAMOE_CUTEDSL branch but ensures the v1 alpha gate is enforced: when backend_type == MoeBackendType.MEGAMOE_CUTEDSL, explicitly ensure TR TLLM_MEGAMOE_CUTEDSL_BYPASS_V1_ALPHA_GATE is unset (use monkeypatch.delenv or monkeypatch.setenv with None), then exercise the same load -> post-load -> run_moe path used in the positive case and assert that the load/post-load step fails (raises/rejects) due to non-1 alpha values; place this next to the existing positive branch in tests/unittest/_torch/modules/moe/test_moe_backend.py so it uses the same setup and failure assertion to prove the gate blocks non-1 alpha checkpoints.
307-315: QA integration list updates are unnecessary for this change-set.This file only expands unit-test backend coverage; no
tests/integration/defs/entries were added or materially changed, sotests/integration/test_lists/qa/*updates are not required in this PR.As per coding guidelines: "If the PR only touches unittest/ or narrow unit scope, say explicitly whether QA list updates are unnecessary or optional."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/modules/moe/test_moe_backend.py` around lines 307 - 315, Add a brief explicit note that QA integration list updates are unnecessary because this change only expands unit-test backend coverage and does not modify tests/integration/defs; place the comment immediately above BACKEND_TYPES_TO_TEST (referencing BACKEND_TYPES_TO_TEST, MoeBackendType.* entries) so reviewers and release/QA scripts know no tests/integration/test_lists/qa/* edits are required.tests/microbenchmarks/bench_moe/utils.py (1)
66-79: 💤 Low value
_ensure_dist_for_megamoeonly handlesMEGAMOE_DEEPGEMM, notMEGAMOE_CUTEDSL.This function checks only for
MEGAMOE_DEEPGEMMbuttest_moe_module.py:_ensure_dist_for_megamoe(lines 139-143) checks for both MegaMoE variants. Ifbench_moeis ever used withMEGAMOE_CUTEDSL, the distributed process group won't be initialized.Since
MEGAMOE_CUTEDSLisn't in the benchmark'sMoeBackendTypeenum yet (preview phase), this is acceptable for now but should be updated when CuteDSL is added to the benchmark registry.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/microbenchmarks/bench_moe/utils.py` around lines 66 - 79, The early-return in _ensure_dist_for_megamoe only recognizes MoeBackendType.MEGAMOE_DEEPGEMM, so the NCCL ProcessGroup will not be initialized for the other MegaMoE variant; update _ensure_dist_for_megamoe to also accept the MEGAMOE_CUTEDSL backend (or generally any future MegaMoE enum value) by checking for both MoeBackendType.MEGAMOE_DEEPGEMM.value and MoeBackendType.MEGAMOE_CUTEDSL.value (or by matching a common prefix/enum category), keeping the rest of the initialization (CUDA check, env vars, dist.init_process_group) unchanged; reference function name _ensure_dist_for_megamoe and the test file test_moe_module.py:_ensure_dist_for_megamoe to ensure parity with tests and add a TODO comment to revisit when CuteDSL is formally added to the benchmark registry.tests/microbenchmarks/bench_moe/backend.py (1)
121-124: 💤 Low valueConsider adding
MEGAMOE_CUTEDSLto the benchmark registry.The
MoeBackendTypeenum andget_backend_classdispatch only includeMEGAMOE_DEEPGEMM. The CuteDSL variant is missing, which meansbench_moecannot benchmark the new backend.This is acceptable for the preview phase, but should be addressed before the feature graduates to production.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/microbenchmarks/bench_moe/backend.py` around lines 121 - 124, Add the missing CuteDSL backend to the registry by extending the MoeBackendType enum with a MEGAMOE_CUTEDSL member and updating the get_backend_class dispatch to handle it: import the CuteDSL backend implementation (e.g., the MegaMoE CuteDSL class) alongside MegaMoEDeepGemm and return that class when backend_type == MoeBackendType.MEGAMOE_CUTEDSL; ensure you reference the new enum value (MoeBackendType.MEGAMOE_CUTEDSL) and the CuteDSL class name (e.g., MegaMoECuteDSL) in the dispatch so bench_moe can instantiate the CuteDSL backend.tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py (2)
1926-1927: ⚡ Quick winPrefix unused variable with underscore.
The
n_idxvariable from the CLC response unpacking is never used. Prefix it with underscore to indicate it's intentionally ignored.Suggested fix
- m_idx, n_idx, l_idx, is_valid = cute.arch.clc_response( + m_idx, _n_idx, l_idx, is_valid = cute.arch.clc_response(🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py` around lines 1926 - 1927, The unpacking of the CLC response in the call to cute.arch.clc_response currently assigns an unused variable n_idx; change that identifier to a prefixed name (e.g., _n_idx) so it signals intentional ignore. Locate the line where m_idx, n_idx, l_idx, is_valid = cute.arch.clc_response(...) (inside moe_persistent_scheduler.py) and replace n_idx with _n_idx in the tuple assignment; leave m_idx, l_idx and is_valid unchanged.
7-7: 💤 Low valueConsider using built-in types instead of
typingmodule.Per coding guidelines, prefer built-in
list,tupleovertyping.List,typing.Tuplefor Python 3.10+. Also preferX | NoneoverOptional[X].This is a low-priority cleanup given the file size.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py` at line 7, Replace typing.List, typing.Tuple, and typing.Optional usages with built-in generic types and PEP 604 unions: change annotations like List[X] -> list[X], Tuple[A, B] -> tuple[A, B], and Optional[T] -> T | None; remove List, Optional, Tuple from the import line in moe_persistent_scheduler.py and keep/import only Literal if still needed; update all function/method signatures and variable annotations that reference List, Tuple, or Optional accordingly (search for usages of List, Tuple, Optional in the file and replace them with list, tuple, and | None forms).tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py (1)
33-33: 💤 Low valuePrefer built-in
listovertyping.List.Per coding guidelines, prefer built-in types like
listover legacytyping.Listfor Python 3.10+.Suggested fix
-from typing import List +from collections.abc import SequenceThen update function signatures:
-def cat_byte_reinterpretable_tensors(tensors: List[torch.Tensor], dim: int = 0) -> torch.Tensor: +def cat_byte_reinterpretable_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py` at line 33, Replace the legacy typing.List import and usages with the built-in list type: remove "from typing import List" and update all annotations in this module (e.g., function signatures, return types, and variable annotations that reference List) to use the native list[...] form (for example change "List[T]" to "list[T]"); ensure functions/methods in this file that reference List (search for symbols like any function or class definitions using List) are updated accordingly and the typing import is deleted.tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py (1)
5-13: 💤 Low valueConsider UPPER_SNAKE_CASE for module-level constants.
Per coding guidelines, Python constants should use UPPER_SNAKE_CASE (e.g.,
NVFP4_BLOCK_SIZE,SF_PADDING_BLOCK). The current PascalCase naming is internally consistent but deviates from the project convention.Given this is a preview PR, this can be deferred if the naming aligns with an upstream source being ported.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py` around lines 5 - 13, Module-level constants use PascalCase; rename them to UPPER_SNAKE_CASE to match project conventions: change Nvfp4BlockSize -> NVFP4_BLOCK_SIZE, SfPaddingBlock -> SF_PADDING_BLOCK, TmaLeadingDimByteAlign -> TMA_LEADING_DIM_BYTE_ALIGN, Nvfp4E2M1Max -> NVFP4_E2M1_MAX, Fp8E4M3FNMax -> FP8_E4M3_FN_MAX, SupportedMmaTileM -> SUPPORTED_MMA_TILE_M, SupportedMmaTileN -> SUPPORTED_MMA_TILE_N, and update all internal references/usages of these symbols accordingly; if external consumers may rely on the old names, add short-lived aliases mapping the old names to the new constants in the same module to preserve backward compatibility (e.g., Nvfp4BlockSize = NVFP4_BLOCK_SIZE) and add a TODO to remove aliases later.tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py (1)
100-100: 💤 Low valueModernize type hints to use built-in types.
The coding guidelines prefer built-in types (
list,dict,tuple) over legacytypingequivalents, andX | NoneoverOptional[X]. This file imports and usesList,Dict,Tuple,Optionalthroughout.♻️ Suggested import and usage update
-from typing import Any, Dict, List, Optional, Tuple, Type +from typing import AnyThen update usages throughout, e.g.:
List[ir.Value]→list[ir.Value]Dict[str, int]→dict[str, int]Tuple[int, ...]→tuple[int, ...]Optional[int]→int | NoneType[cutlass.Numeric]→type[cutlass.Numeric]🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py` at line 100, Replace legacy typing generics with built-in generics and PEP 604 unions: remove List, Dict, Tuple, Optional, Type from the typing import and use list, dict, tuple, X | None, and type[...] in their places; e.g., change imports on the line with "from typing import Any, Dict, List, Optional, Tuple, Type" to keep only Any (if needed) and update all occurrences like List[ir.Value]→list[ir.Value], Dict[str,int]→dict[str,int], Tuple[int,...]→tuple[int,...], Optional[int]→int | None, and Type[cutlass.Numeric]→type[cutlass.Numeric] throughout megamoe_kernel.py (search for those symbols: List, Dict, Tuple, Optional, Type and replace accordingly).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py`:
- Around line 346-352: The code builds sf_layout = tile_atom_to_shape_SF(...)
but then throws away the swizzled layout by calling
cute.make_layout(sf_layout.shape, stride=stride) when constructing the tensor;
update the cute.make_tensor call for the "sfa" branch (and the analogous
"sfb"/"sfc" branches) to pass the full sf_layout (the layout object returned by
tile_atom_to_shape_SF) into cute.make_tensor instead of reconstructing a plain
strided layout, so the atom-swizzled layout is preserved for the descriptor
path.
In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py`:
- Around line 890-909: The mbarrier expect_tx must be set before issuing the
async cluster store to avoid a race where the store completes before the peer is
armed; move the mbarrier_arrive_expect_tx_on_peer(...) call to precede
store_i32_to_peer_cluster_smem_async(...) in the path handling lane_idx <
Int32(cluster_size) so the peer's expect_tx is armed (matching the Int32(4) tx
size) before the st.async.shared::cluster write; reference
store_i32_to_peer_cluster_smem_async, mbarrier_arrive_expect_tx_on_peer, and
consumer_wait when making the reorder.
In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py`:
- Around line 954-955: The constructor/signature should not allow expert_cnt or
dependency to be None because construct_and_write() immediately uses dependency
and self.expert_cnt; update the signatures (remove Optional[...] = None and
default None) so expert_cnt and dependency are required parameters, and/or add
an early validation in construct_and_write() that raises a ValueError if
dependency is None or self.expert_cnt is None (referencing construct_and_write,
self.expert_cnt, and dependency), and apply the same change to the other
occurrence around lines 1014-1021 so both call sites enforce non-None values.
In `@tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_cute_dsl.py`:
- Around line 436-454: Replace the assertion-based argument checks in
MegaMoECuteDsl's initializer with explicit validation that raises ValueError:
change the two assert checks that reference self.tp_size and self.cluster_size
to if ...: raise ValueError(...) and similarly change the assert that checks
self.ep_size == self.parallel_size when self.use_dp and self.parallel_size > 1
to an if that raises ValueError; keep the existing error messages (adjusted to
the ValueError) and keep the surrounding logic for num_slots/ep_size as-is.
- Around line 708-711: The input validation in MegaMoECuteDsl.load_weights uses
an assert which is skipped with Python -O; change this to an explicit check and
raise a ValueError when the length is not 1 (e.g., if len(weights) != 1: raise
ValueError("MegaMoECuteDsl.load_weights expects a single-element list, got
{len(weights)} entries.")) so callers always get a clear exception; update the
same error message currently used by the assert.
In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py`:
- Around line 3535-3549: The code crashes when per_slot is empty because
stack_byte_reinterpretable_tensors is called with an empty list; in
_build_mega_sf (and similarly in the other block), detect the zero-slot case
(e.g., if num_slots == 0 or per_slot == []) before stacking and return an
appropriately shaped empty/padded tensor (torch.zeros with shape (0, flat_size)
or (num_slots, flat_size), matching dtype=stacked.dtype/torch.uint8 and
device=device) so downstream callers expecting a (num_slots, flat_size) tensor
do not fail; implement this early-return right after constructing/validating
per_slot and before calling stack_byte_reinterpretable_tensors.
- Around line 3205-3212: The code assumes a 2x gated expansion
(expand_intermediate == 2 * intermediate) but never enforces it, causing
mis-sized MegaMoE buffers; add an explicit validation early in the weight
construction path (e.g., in create_weights) that checks
expand_intermediate_size_per_partition == 2 * intermediate and raise a clear
exception if not, and mirror the same guard in fc1_sf_flat_size (or call the
validated value) so fc1_sf_flat_size, the [w3|w1] split/interleave logic and
_build_mega_format_buffers all receive a guaranteed 2x expansion; reference
expand_intermediate_size_per_partition, create_weights, fc1_sf_flat_size, and
_build_mega_format_buffers when adding the assertion and error message.
---
Nitpick comments:
In `@tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py`:
- Line 706: The class-level mutable cache kernel_cache is intended to be shared
across instances but lacks an explicit ClassVar annotation; update the
declaration of kernel_cache in the class to be typed as typing.ClassVar[dict]
(and add from typing import ClassVar if not already imported) so the intent is
explicit and the RUF012 warning is silenced while preserving the shared compile
cache behavior.
In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.py`:
- Line 33: Replace the legacy typing.List import and usages with the built-in
list type: remove "from typing import List" and update all annotations in this
module (e.g., function signatures, return types, and variable annotations that
reference List) to use the native list[...] form (for example change "List[T]"
to "list[T]"); ensure functions/methods in this file that reference List (search
for symbols like any function or class definitions using List) are updated
accordingly and the typing import is deleted.
In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.py`:
- Around line 5-13: Module-level constants use PascalCase; rename them to
UPPER_SNAKE_CASE to match project conventions: change Nvfp4BlockSize ->
NVFP4_BLOCK_SIZE, SfPaddingBlock -> SF_PADDING_BLOCK, TmaLeadingDimByteAlign ->
TMA_LEADING_DIM_BYTE_ALIGN, Nvfp4E2M1Max -> NVFP4_E2M1_MAX, Fp8E4M3FNMax ->
FP8_E4M3_FN_MAX, SupportedMmaTileM -> SUPPORTED_MMA_TILE_M, SupportedMmaTileN ->
SUPPORTED_MMA_TILE_N, and update all internal references/usages of these symbols
accordingly; if external consumers may rely on the old names, add short-lived
aliases mapping the old names to the new constants in the same module to
preserve backward compatibility (e.g., Nvfp4BlockSize = NVFP4_BLOCK_SIZE) and
add a TODO to remove aliases later.
In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.py`:
- Line 100: Replace legacy typing generics with built-in generics and PEP 604
unions: remove List, Dict, Tuple, Optional, Type from the typing import and use
list, dict, tuple, X | None, and type[...] in their places; e.g., change imports
on the line with "from typing import Any, Dict, List, Optional, Tuple, Type" to
keep only Any (if needed) and update all occurrences like
List[ir.Value]→list[ir.Value], Dict[str,int]→dict[str,int],
Tuple[int,...]→tuple[int,...], Optional[int]→int | None, and
Type[cutlass.Numeric]→type[cutlass.Numeric] throughout megamoe_kernel.py (search
for those symbols: List, Dict, Tuple, Optional, Type and replace accordingly).
In
`@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.py`:
- Around line 1926-1927: The unpacking of the CLC response in the call to
cute.arch.clc_response currently assigns an unused variable n_idx; change that
identifier to a prefixed name (e.g., _n_idx) so it signals intentional ignore.
Locate the line where m_idx, n_idx, l_idx, is_valid =
cute.arch.clc_response(...) (inside moe_persistent_scheduler.py) and replace
n_idx with _n_idx in the tuple assignment; leave m_idx, l_idx and is_valid
unchanged.
- Line 7: Replace typing.List, typing.Tuple, and typing.Optional usages with
built-in generic types and PEP 604 unions: change annotations like List[X] ->
list[X], Tuple[A, B] -> tuple[A, B], and Optional[T] -> T | None; remove List,
Optional, Tuple from the import line in moe_persistent_scheduler.py and
keep/import only Literal if still needed; update all function/method signatures
and variable annotations that reference List, Tuple, or Optional accordingly
(search for usages of List, Tuple, Optional in the file and replace them with
list, tuple, and | None forms).
In `@tensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.md`:
- Line 152: Update the README entry describing mega_moe/mega_moe_cute_dsl.py /
MegaMoECuteDsl to avoid implying an external NVSHMEM package: change "NVSHMEM
provider" to "symmetric memory provider (torch.distributed._symmetric_memory)"
or similar wording that explicitly names PyTorch's cuMem-based provider; ensure
the note about requiring CUDA 13 Cutlass DSL runtime and the product gate
references remain unchanged so readers know runtime and gating requirements.
In `@tests/microbenchmarks/bench_moe/backend.py`:
- Around line 121-124: Add the missing CuteDSL backend to the registry by
extending the MoeBackendType enum with a MEGAMOE_CUTEDSL member and updating the
get_backend_class dispatch to handle it: import the CuteDSL backend
implementation (e.g., the MegaMoE CuteDSL class) alongside MegaMoEDeepGemm and
return that class when backend_type == MoeBackendType.MEGAMOE_CUTEDSL; ensure
you reference the new enum value (MoeBackendType.MEGAMOE_CUTEDSL) and the
CuteDSL class name (e.g., MegaMoECuteDSL) in the dispatch so bench_moe can
instantiate the CuteDSL backend.
In `@tests/microbenchmarks/bench_moe/utils.py`:
- Around line 66-79: The early-return in _ensure_dist_for_megamoe only
recognizes MoeBackendType.MEGAMOE_DEEPGEMM, so the NCCL ProcessGroup will not be
initialized for the other MegaMoE variant; update _ensure_dist_for_megamoe to
also accept the MEGAMOE_CUTEDSL backend (or generally any future MegaMoE enum
value) by checking for both MoeBackendType.MEGAMOE_DEEPGEMM.value and
MoeBackendType.MEGAMOE_CUTEDSL.value (or by matching a common prefix/enum
category), keeping the rest of the initialization (CUDA check, env vars,
dist.init_process_group) unchanged; reference function name
_ensure_dist_for_megamoe and the test file
test_moe_module.py:_ensure_dist_for_megamoe to ensure parity with tests and add
a TODO comment to revisit when CuteDSL is formally added to the benchmark
registry.
In `@tests/unittest/_torch/modules/moe/test_moe_backend.py`:
- Around line 568-577: Add a companion negative test that mirrors the
MEGAMOE_CUTEDSL branch but ensures the v1 alpha gate is enforced: when
backend_type == MoeBackendType.MEGAMOE_CUTEDSL, explicitly ensure TR
TLLM_MEGAMOE_CUTEDSL_BYPASS_V1_ALPHA_GATE is unset (use monkeypatch.delenv or
monkeypatch.setenv with None), then exercise the same load -> post-load ->
run_moe path used in the positive case and assert that the load/post-load step
fails (raises/rejects) due to non-1 alpha values; place this next to the
existing positive branch in
tests/unittest/_torch/modules/moe/test_moe_backend.py so it uses the same setup
and failure assertion to prove the gate blocks non-1 alpha checkpoints.
- Around line 307-315: Add a brief explicit note that QA integration list
updates are unnecessary because this change only expands unit-test backend
coverage and does not modify tests/integration/defs; place the comment
immediately above BACKEND_TYPES_TO_TEST (referencing BACKEND_TYPES_TO_TEST,
MoeBackendType.* entries) so reviewers and release/QA scripts know no
tests/integration/test_lists/qa/* edits are required.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: def9f1a9-762b-42f1-b770-485392227501
📒 Files selected for processing (40)
.pre-commit-config.yamllegacy-files.txtpyproject.tomlruff-legacy.tomltensorrt_llm/_torch/autotuner.pytensorrt_llm/_torch/custom_ops/__init__.pytensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/__init__.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/blocked_scale.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/config.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/contract.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/dispatch_kernel.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/epilogue.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/grid_sync.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/iket_compat.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/kernel_fc12.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_constants.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/megamoe_kernel.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_persistent_scheduler.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ptx_helpers.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sf_swizzle.pytensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/sym_buffer.pytensorrt_llm/_torch/modules/fused_moe/MOE_DEVELOPER_GUIDE.mdtensorrt_llm/_torch/modules/fused_moe/__init__.pytensorrt_llm/_torch/modules/fused_moe/configurable_moe.pytensorrt_llm/_torch/modules/fused_moe/create_moe.pytensorrt_llm/_torch/modules/fused_moe/mega_moe/MEGAMOE_CUTEDSL_DESIGN.mdtensorrt_llm/_torch/modules/fused_moe/mega_moe/__init__.pytensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_cute_dsl.pytensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_deepgemm.pytensorrt_llm/_torch/modules/fused_moe/moe_scheduler.pytensorrt_llm/_torch/modules/fused_moe/quantization.pytests/microbenchmarks/bench_moe/backend.pytests/microbenchmarks/bench_moe/utils.pytests/unittest/_torch/modules/moe/moe_test_utils.pytests/unittest/_torch/modules/moe/test_moe_backend.pytests/unittest/_torch/modules/moe/test_moe_module.py
| elif cutlass.const_expr(tensor_name == "sfa"): | ||
| real = cute.domain_offset((0, 0, expert_idx), | ||
| gmem_tensor_in_moe_view) | ||
| per_expert_shape = (shape[0], shape[1], c1) # type: ignore[index] | ||
| sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size) | ||
| real = cute.make_tensor( | ||
| real.iterator, cute.make_layout(sf_layout.shape, stride=stride)) |
There was a problem hiding this comment.
Preserve the swizzled sf_layout from tile_atom_to_shape_SF when building sfa/sfb/sfc tensors
custom_ext.py computes sf_layout = tile_atom_to_shape_SF(...) for sfa/sfb/sfc, but then discards the layout object and rebuilds the tensor using only cute.make_layout(sf_layout.shape, stride=stride), which can drop the atom-swizzled layout details expected by the descriptor path.
Suggested fix: pass the full sf_layout into cute.make_tensor(...) instead of rebuilding a plain strided view.
Suggested diff
elif cutlass.const_expr(tensor_name == "sfa"):
real = cute.domain_offset((0, 0, expert_idx),
gmem_tensor_in_moe_view)
per_expert_shape = (shape[0], shape[1], c1) # type: ignore[index]
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
- real = cute.make_tensor(
- real.iterator, cute.make_layout(sf_layout.shape, stride=stride))
+ real = cute.make_tensor(real.iterator, sf_layout)
return (real, None)
elif cutlass.const_expr(tensor_name == "sfb"):
real = cute.domain_offset((sf_token_offset, 0, 0),
gmem_tensor_in_moe_view)
per_expert_shape = (shape[0], shape[1], c1) # type: ignore[index]
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
- real = cute.make_tensor(
- real.iterator, cute.make_layout(sf_layout.shape, stride=stride))
+ real = cute.make_tensor(real.iterator, sf_layout)
return (real, None)
...
elif cutlass.const_expr(tensor_name == "sfc"):
real = cute.domain_offset((sf_token_offset, 0, 0),
gmem_tensor_in_moe_view)
per_expert_shape = (shape[0], shape[1], c1) # type: ignore[index]
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
- real = cute.make_tensor(
- real.iterator, cute.make_layout(sf_layout.shape, stride=stride))
+ real = cute.make_tensor(real.iterator, sf_layout)
return (real, None)Also applies to: 355-362, 371-379
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/custom_ext.py` around
lines 346 - 352, The code builds sf_layout = tile_atom_to_shape_SF(...) but then
throws away the swizzled layout by calling cute.make_layout(sf_layout.shape,
stride=stride) when constructing the tensor; update the cute.make_tensor call
for the "sfa" branch (and the analogous "sfb"/"sfc" branches) to pass the full
sf_layout (the layout object returned by tile_atom_to_shape_SF) into
cute.make_tensor instead of reconstructing a plain strided layout, so the
atom-swizzled layout is preserved for the descriptor path.
| # DSMEM fan-out: lanes [0, cluster_size) each write to one peer | ||
| # CTA. Each lane targets a distinct peer (lane_idx == peer rank). | ||
| if lane_idx < Int32(cluster_size): | ||
| store_i32_to_peer_cluster_smem_async( | ||
| ds.broadcast_ptr, | ||
| atomic_idx, | ||
| full_barrier_ptr, | ||
| lane_idx, | ||
| loc=loc, | ||
| ip=ip, | ||
| ) | ||
| # Set expect_tx on the peer mbarrier to match the 4-byte | ||
| # store above; pairs with the consumer_wait below. | ||
| mbarrier_arrive_expect_tx_on_peer( | ||
| full_barrier_ptr, | ||
| Int32(4), | ||
| lane_idx, | ||
| loc=loc, | ||
| ip=ip, | ||
| ) |
There was a problem hiding this comment.
Arm the peer mbarrier before issuing the async cluster store.
store_i32_to_peer_cluster_smem_async() is documented to require the peer expect_tx to be set first, but this path does the st.async.shared::cluster...complete_tx before mbarrier_arrive_expect_tx_on_peer(). If the store wins that race, the atomic-counter scheduler can miss the completion and deadlock in consumer_wait().
Suggested fix
if lane_idx < Int32(cluster_size):
- store_i32_to_peer_cluster_smem_async(
- ds.broadcast_ptr,
- atomic_idx,
- full_barrier_ptr,
- lane_idx,
- loc=loc,
- ip=ip,
- )
- # Set expect_tx on the peer mbarrier to match the 4-byte
- # store above; pairs with the consumer_wait below.
mbarrier_arrive_expect_tx_on_peer(
full_barrier_ptr,
Int32(4),
lane_idx,
loc=loc,
ip=ip,
)
+ store_i32_to_peer_cluster_smem_async(
+ ds.broadcast_ptr,
+ atomic_idx,
+ full_barrier_ptr,
+ lane_idx,
+ loc=loc,
+ ip=ip,
+ )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/fc1_fc2_fuse_sched.py`
around lines 890 - 909, The mbarrier expect_tx must be set before issuing the
async cluster store to avoid a race where the store completes before the peer is
armed; move the mbarrier_arrive_expect_tx_on_peer(...) call to precede
store_i32_to_peer_cluster_smem_async(...) in the path handling lane_idx <
Int32(cluster_size) so the peer's expect_tx is armed (matching the Int32(4) tx
size) before the st.async.shared::cluster write; reference
store_i32_to_peer_cluster_smem_async, mbarrier_arrive_expect_tx_on_peer, and
consumer_wait when making the reorder.
| expert_cnt: Optional[Union[Int32, int]] = None, | ||
| ) -> None: |
There was a problem hiding this comment.
Make expert_cnt and dependency required here.
construct_and_write() immediately unpacks dependency and does arithmetic with self.expert_cnt, so the current Optional[...] = None defaults blow up as soon as a caller follows the advertised signature. Either require both inputs at construction/call time or fail fast with a ValueError before entering the descriptor loop.
Also applies to: 1014-1021
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/moe_utils.py` around
lines 954 - 955, The constructor/signature should not allow expert_cnt or
dependency to be None because construct_and_write() immediately uses dependency
and self.expert_cnt; update the signatures (remove Optional[...] = None and
default None) so expert_cnt and dependency are required parameters, and/or add
an early validation in construct_and_write() that raises a ValueError if
dependency is None or self.expert_cnt is None (referencing construct_and_write,
self.expert_cnt, and dependency), and apply the same change to the other
occurrence around lines 1014-1021 so both call sites enforce non-None values.
| assert self.tp_size == 1, ( | ||
| f"MegaMoECuteDsl is EP-only in v1 (moe_tp_size=1); got tp_size={self.tp_size}." | ||
| ) | ||
| assert self.cluster_size == 1, ( | ||
| f"MegaMoECuteDsl assumes cluster_size=1; got cluster_size={self.cluster_size}." | ||
| ) | ||
| if self.num_slots % max(self.ep_size, 1) != 0: | ||
| raise ValueError( | ||
| f"MegaMoECuteDsl requires num_slots ({self.num_slots}) " | ||
| f"divisible by ep_size ({self.ep_size})." | ||
| ) | ||
|
|
||
| if self.use_dp and self.parallel_size > 1: | ||
| assert self.ep_size == self.parallel_size, ( | ||
| f"MegaMoECuteDsl with enable_attention_dp=True requires " | ||
| f"ep_size == parallel_size (got ep_size={self.ep_size}, " | ||
| f"parallel_size={self.parallel_size}). ADP > EP would " | ||
| f"require an outer allgather + reducescatter wrapper." | ||
| ) |
There was a problem hiding this comment.
Replace assert with if ... raise ValueError for API validation.
Lines 436-441 and 448-454 use assert statements to validate constructor arguments. Since Python's -O flag disables assertions, these checks could silently pass in optimized production builds. The adjacent validation on lines 442-446 already correctly uses if ... raise ValueError.
Proposed fix
- assert self.tp_size == 1, (
- f"MegaMoECuteDsl is EP-only in v1 (moe_tp_size=1); got tp_size={self.tp_size}."
- )
- assert self.cluster_size == 1, (
- f"MegaMoECuteDsl assumes cluster_size=1; got cluster_size={self.cluster_size}."
- )
+ if self.tp_size != 1:
+ raise ValueError(
+ f"MegaMoECuteDsl is EP-only in v1 (moe_tp_size=1); got tp_size={self.tp_size}."
+ )
+ if self.cluster_size != 1:
+ raise ValueError(
+ f"MegaMoECuteDsl assumes cluster_size=1; got cluster_size={self.cluster_size}."
+ )
if self.num_slots % max(self.ep_size, 1) != 0:
raise ValueError(
f"MegaMoECuteDsl requires num_slots ({self.num_slots}) "
f"divisible by ep_size ({self.ep_size})."
)
- if self.use_dp and self.parallel_size > 1:
- assert self.ep_size == self.parallel_size, (
+ if self.use_dp and self.parallel_size > 1 and self.ep_size != self.parallel_size:
+ raise ValueError(
f"MegaMoECuteDsl with enable_attention_dp=True requires "
f"ep_size == parallel_size (got ep_size={self.ep_size}, "
f"parallel_size={self.parallel_size}). ADP > EP would "
f"require an outer allgather + reducescatter wrapper."
)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_cute_dsl.py` around
lines 436 - 454, Replace the assertion-based argument checks in MegaMoECuteDsl's
initializer with explicit validation that raises ValueError: change the two
assert checks that reference self.tp_size and self.cluster_size to if ...: raise
ValueError(...) and similarly change the assert that checks self.ep_size ==
self.parallel_size when self.use_dp and self.parallel_size > 1 to an if that
raises ValueError; keep the existing error messages (adjusted to the ValueError)
and keep the surrounding logic for num_slots/ep_size as-is.
| assert len(weights) == 1, ( | ||
| "MegaMoECuteDsl.load_weights expects a single-element list, " | ||
| f"got {len(weights)} entries." | ||
| ) |
There was a problem hiding this comment.
Replace assert with if ... raise ValueError for input validation.
Same pattern as the constructor: this validates caller-provided input, which should not be silently skipped when Python runs with -O.
Proposed fix
- assert len(weights) == 1, (
+ if len(weights) != 1:
+ raise ValueError(
"MegaMoECuteDsl.load_weights expects a single-element list, "
f"got {len(weights)} entries."
)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_cute_dsl.py` around
lines 708 - 711, The input validation in MegaMoECuteDsl.load_weights uses an
assert which is skipped with Python -O; change this to an explicit check and
raise a ValueError when the length is not 1 (e.g., if len(weights) != 1: raise
ValueError("MegaMoECuteDsl.load_weights expects a single-element list, got
{len(weights)} entries.")) so callers always get a clear exception; update the
same error message currently used by the assert.
| def fc1_sf_flat_size(cls, intermediate: int, hidden: int) -> int: | ||
| """``round_up(expand_intermediate, SfPaddingBlock=128) * | ||
| round_up(ceil(hidden / 16), 4)`` -- matches kernel_fc12.py:880-890. | ||
| ``expand_intermediate = 2 * intermediate``. | ||
| """ | ||
| expand_intermediate = intermediate * 2 | ||
| return (cls._round_up_int(expand_intermediate, 128) * | ||
| cls._round_up_int(cls._ceil_div_int(hidden, 16), 4)) |
There was a problem hiding this comment.
Assert the 2x gated expansion invariant up front.
This implementation hard-codes FC1 as 2 * intermediate in both the derived-SF sizing and the [w3 | w1] split/interleave path, but create_weights() never validates that expand_intermediate_size_per_partition actually satisfies that contract. On any non-2x MoE shape, the MegaMoE buffers are mis-sized and _build_mega_format_buffers() will mis-pack or fail at load time instead of rejecting the config immediately.
💡 Suggested guard
def create_weights(self, module: torch.nn.Module):
+ if (module.expand_intermediate_size_per_partition !=
+ 2 * module.intermediate_size_per_partition):
+ raise NotImplementedError(
+ "NVFP4MegaMoECuteDslMethod currently requires gated 2x intermediate expansion."
+ )
+
weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
self.block_scales_vec_size = torch.iinfo(
self.block_scales_dtype).bits // 8Also applies to: 3241-3282, 3582-3609
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py` around lines 3205 -
3212, The code assumes a 2x gated expansion (expand_intermediate == 2 *
intermediate) but never enforces it, causing mis-sized MegaMoE buffers; add an
explicit validation early in the weight construction path (e.g., in
create_weights) that checks expand_intermediate_size_per_partition == 2 *
intermediate and raise a clear exception if not, and mirror the same guard in
fc1_sf_flat_size (or call the validated value) so fc1_sf_flat_size, the [w3|w1]
split/interleave logic and _build_mega_format_buffers all receive a guaranteed
2x expansion; reference expand_intermediate_size_per_partition, create_weights,
fc1_sf_flat_size, and _build_mega_format_buffers when adding the assertion and
error message.
| per_slot: List[torch.Tensor] = [] | ||
| for slot_idx in range(num_slots): | ||
| sf_fp8 = raw_sf[slot_idx].view(torch.float8_e4m3fn) | ||
| per_slot.append(to_blocked(sf_fp8).view(torch.uint8)) | ||
| stacked = stack_byte_reinterpretable_tensors(per_slot, | ||
| dim=0).contiguous() | ||
| if stacked.shape[-1] == flat_size: | ||
| return stacked | ||
| # Pad zero on the tail so the output shape matches the | ||
| # registered Parameter shape. | ||
| out = torch.zeros((num_slots, flat_size), | ||
| dtype=stacked.dtype, | ||
| device=device) | ||
| out[:, :stacked.shape[-1]] = stacked | ||
| return out |
There was a problem hiding this comment.
Handle zero-slot staging before stacking per-slot SF tensors.
When online EPLB is enabled, local_shared_load_expert_ids can be empty, but _build_mega_shared_staging() still calls _build_mega_format_buffers(). In that case _build_mega_sf() builds per_slot = [] and unconditionally passes it to stack_byte_reinterpretable_tensors(...), which turns the "no shared experts" case into a load-time crash.
💡 Suggested early return
def _build_mega_sf(raw_sf: torch.Tensor, *, num_slots: int,
gate_up_interleave_intermediate: Optional[int],
n_pairs: Optional[int], expand_intermediate: int,
flat_size: int) -> torch.Tensor:
"""Build a flattened, blocked-swizzled NVFP4 SF tensor per slot.
@@
from ...cute_dsl_kernels.mega_moe_nvfp4 import (
stack_byte_reinterpretable_tensors, to_blocked)
device = raw_sf.device
+ if num_slots == 0:
+ return torch.empty((0, flat_size),
+ dtype=torch.uint8,
+ device=device)
+
sf_cols = raw_sf.shape[-1] # int32 unitsAlso applies to: 3665-3706
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tensorrt_llm/_torch/modules/fused_moe/quantization.py` around lines 3535 -
3549, The code crashes when per_slot is empty because
stack_byte_reinterpretable_tensors is called with an empty list; in
_build_mega_sf (and similarly in the other block), detect the zero-slot case
(e.g., if num_slots == 0 or per_slot == []) before stacking and return an
appropriately shaped empty/padded tensor (torch.zeros with shape (0, flat_size)
or (num_slots, flat_size), matching dtype=stacked.dtype/torch.uint8 and
device=device) so downstream callers expecting a (num_slots, flat_size) tensor
do not fail; implement this early-return right after constructing/validating
per_slot and before calling stack_byte_reinterpretable_tensors.
ff947a8 to
41b3aa8
Compare
Introduces MegaMoECuteDsl, a fused-communication MoE backend that runs the ported MegaMoE NVFP4 CuteDSL kernel (Sm100MegaMoEKernel) on SM100/SM103. The kernel fuses dispatch + FC1 + SwiGLU + FC2 + combine in a single launch via the in-kernel NVLink dispatch barrier. Single-rank degenerate path and multi-rank EP path are both wired through a unified always-pad-to-max_tokens_per_rank staging contract.
Components:
- New backend tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_cute_dsl.py with build-time symmetric-memory provider rendezvous, local staging cache, FUSED_COMM scheduler binding, and quantize_input that pads SF columns to round_up(ceil(hidden/16), 4) to match the kernel TMA contract.
- New torch custom op torch.ops.trtllm.cute_dsl_megamoe_nvfp4_blackwell registered in tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py. Hosts Sm100MegaMoENvfp4Runner (TunableRunner) with PARALLEL distributed tuning so every EP rank converges on the same compiled tactic for every chunk (required for the in-kernel dispatch barrier), MegaMoeSymmMemProvider with zero-init buffer, candidate tactic enumeration sweeping {static, atomic_counter} load balance modes, and a stricter IS_MEGAMOE_OP_AVAILABLE probe so half-installed cutlass-dsl wheels do not break the rest of custom_ops. The runner allocates local_workspace via torch.zeros and calls local_workspace.zero_() on every forward in addition to the existing shared_workspace.zero_(), so the cached buffer cannot feed garbage into the kernel's Int32 atomic counters (l1_arrival_count, fc1_done_counter, grid_sync_counter); a negative Int32 in any counter slot would otherwise make the in-kernel spin_wait (v >= positive_threshold) impossible to satisfy and hang the kernel at 100% SM / 0% memory bandwidth. The op returns None and the caller uses the in-place mutated combine_output directly, because torch.library forbids the return value from aliasing any mutated input.
- Ported CuteDSL kernel package at tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/ (16 files) with package-relative imports, plus blocked_scale.py extracted from upstream runner_fc12.py.
- New NVFP4MegaMoECuteDslMethod quant method in tensorrt_llm/_torch/modules/fused_moe/quantization.py: 16-atom gate/up interleave + to_blocked swizzle + flattened-stack staging for mega_fc1_weight / mega_fc1_weight_sf / mega_fc2_weight / mega_fc2_weight_sf. Builds CPU shared-staging buffers for all four derived parameters and registers them with the load balancer through register_all_parameter_slot_and_to_fix_weight_fns so both static AND dynamic EPLB migrate the mega-format bytes atomically with the underlying NVFP4 raw weights + scales.
- FusedCommMoEScheduler now calls backend.quantize_input on zero-token chunks too (DG backend updated in parallel) so each fused-comm backend owns its own empty-tensor layout.
- create_moe.py factory and ConfigurableMoE allowlist updated; MoEDeveloperGuide adds Backend Capability Matrix entries, FUSED_COMM anti-patterns, and the autotuner tactic representation reference.
Tests:
- Drops the asymmetric MoeBackendType.MEGAMOE alias; both variants are spelled out as MEGAMOE_DEEPGEMM and MEGAMOE_CUTEDSL with should_skip_megamoe_deepgemm / should_skip_megamoe_cutedsl helpers.
- New focused tests in test_moe_backend.py: kernel-package import, to_blocked roundtrip, can_implement positive/negative, quantize_input zero-token, multi-rank symm-provider gate, alpha-gate, sf-byte-width helper, atomic_counter sweep, dynamic-EPLB-now-supported, and a byte-equivalence test for the FC1 gate/up 16-atom interleave that catches any gate-vs-up swap.
- test_moe_module.py adds factory + scheduler wiring coverage and updates the shared dist helper to recognise MEGAMOE_CUTEDSL.
Hard gates documented in MEGAMOE_CUTEDSL_DESIGN.md:
- Kernel ABI hard-codes per-expert alpha / fc2_input_scale to 1.0; v1 alpha gate in NVFP4MegaMoECuteDslMethod rejects non-1.0 checkpoint values until the upstream kernel ABI is extended.
- Launch contract requires T == max_tokens_per_rank on every call; backend stages real T rows then pads topk_idx tail to -1 (dispatch_prep skips negative expert ids).
- IS_MEGAMOE_OP_AVAILABLE strict probe protects the rest of custom_ops on partial cutlass-dsl installs.
Known follow-ups:
- per-slot to_blocked perf can be batched (high-volume models).
- Real GPU E2E run is blocked on OCI worktree rebuild for ABI parity with the new C++ bindings.
Designed and documented under tensorrt_llm/_torch/modules/fused_moe/mega_moe/MEGAMOE_CUTEDSL_DESIGN.md.
Signed-off-by: xxi <xxi@nvidia.com>
41b3aa8 to
9134a9b
Compare
This PR introduces the MegaMoECuteDsl NVFP4 MoE backend as a preview / WIP drop.
It is NOT a finished product and is not safe to use as a real serving backend:
Sm100MegaMoEKernelstill hard-codesalpha = 1.0andnorm_const = 1.0in the FC1/FC2 epilogue path. Real NVFP4checkpoints (where
fc31_alpha/fc2_alpha/fc2_input_scaleare not 1) cannotproduce correct results until the kernel ABI is extended to thread these scales through.
See
tensorrt_llm/_torch/modules/fused_moe/mega_moe/MEGAMOE_CUTEDSL_DESIGN.md→"NVFP4 scale and alpha ABI" for the full hard gate.
_check_v1_alpha_gateinNVFP4MegaMoECuteDslMethod.process_weights_after_loadingrejects any checkpoint whose alpha tensors differ from 1.0; the factory falls back to
CutlassFusedMoE / CuteDslFusedMoE in that case so production paths are never silently wrong.
capability gate, weight pipeline, and op registration wire up end-to-end on a GB200
with cu13 cutlass-dsl available; numerical correctness against a reference is gated
on the kernel ABI extension above.
model_config.moe_backend = MEGAMOE_CUTEDSLin any production modelconfig. The factory will fall back via
can_implement/ alpha gate as designed,but the backend itself is staged for follow-up work, not for use.
This PR is intentionally landed as a preview so the MegaMoE port + backend / quant
method / scheduler integration / NVSHMEM-equivalent symmetric memory provider can
evolve in tree without blocking the kernel ABI work. Please do not merge this PR until
the kernel ABI extension and accuracy validation follow-ups land.
Summary
MegaMoECuteDslundertensorrt_llm/_torch/modules/fused_moe/mega_moe/, wrapping the portedSm100MegaMoEKernel(fused dispatch + FC1 + activation + FC2 + combine on SM100/103).tensorrt_llm/_torch/custom_ops/cute_dsl_megamoe_custom_op.py(Sm100MegaMoENvfp4Runner,MegaMoeSymmMemProvider,torch.ops.trtllm.cute_dsl_megamoe_nvfp4_blackwell).tensorrt_llm/_torch/cute_dsl_kernels/mega_moe_nvfp4/(16 files, package-relative imports).NVFP4MegaMoECuteDslMethodquant method: 16-atom gate/up interleave +to_blockedswizzle + flat per-slot SF stack; CPU shared-staging + EPLBfix-up registration for the four
mega_fc{1,2}_weight{,_sf}derived parameters.torch.distributed._symmetric_memory(cuMem-based NVSHMEM equivalent already used in the codebase via
SymmetricMemoryAllReduce).FusedCommMoESchedulerrefactor: zero-token chunks now go throughbackend.quantize_input(...)unconditionally (each fused-comm backend owns its ownempty-tensor layout, scheduler stays layout-agnostic).
ConfigurableMoEwiring (create_moe.py,__init__.pyallowlist).scheduler invariants.
MoeBackendType.MEGAMOEalias removed; both variantsspelled out as
MEGAMOE_DEEPGEMMandMEGAMOE_CUTEDSLwith pairedshould_skip_megamoe_{deepgemm,cutedsl}helpers.Hard gates / open work tracked in MEGAMOE_CUTEDSL_DESIGN.md
fc31_alpha/fc2_alpha/fc2_input_scale.ABI extension is in place.
but not 64-aligned (1568, 1632, 2080) — current test matrix uses only 64-aligned
hidden sizes so the symmetric-memory
activation_sfrow stride path is notexercised.
Test plan / current state on GB200 (cu13 cutlass-dsl)
All MegaMoECuteDsl-related cases collected by
-k megamoe_cutedslwere run on aGB200 node (
nvl72151-T01) inside the OCI cu13 cutlass-dsl container, using aprebuilt artifact symlink (no local C++ build).
Total: 59 cases collected (10 backend + 42 module single-rank + 7 module multi-rank).
Outcome: 0 PASSED / 24 FAILED / 35 SKIPPED.
Failure types (all map to the same v1 limitation, not a regression in this PR):
fc31_alpha ≈ 1e-7;the backend test sets
TRTLLM_MEGAMOE_CUTEDSL_BYPASS_V1_ALPHA_GATE=1to bypass thealpha gate, so the kernel runs with
alpha = 1and the output diverges by ~2×10⁷from the reference (accuracy mismatch 100%). This is the documented v1 ABI limitation.
_check_v1_alpha_gateraisesNotImplementedErrorbecause single-rank module test path does not currently set thebypass env var (only the multi-rank path does). Same root cause as Class A; just visible
earlier in the lifecycle.
mismatch (same root cause as Class A).
by
should_skip_megamoe_cutedsl, plus thee256_k6_h4096_i2048shape) and 2multi-rank cases (
e256_k6_h4096_i2048large shape).None of the failures indicate a regression introduced by this PR — they all map to
the same kernel ABI / alpha gate gap.
Things that will land in follow-up PRs (NOT in this preview)
numerical-correctness tests.
TRTLLM_MEGAMOE_CUTEDSL_BYPASS_V1_ALPHA_GATEsymmetrically with the multi-rank fixture so Class B cases reach the kernel.
activation_sfrow-stride path.Reviewer guidance
This is a preview. Reviewers can read code-level structure (backend boundary,
quantization-method boundary, scheduler refactor, MoE developer guide), but
please do NOT advance to merge until the kernel ABI extension and accuracy
validation land.
Summary by CodeRabbit
New Features
Documentation
Tests