
[None][perf] AutoDeploy: fuse FP8 quant into allreduce+residual+RMSNorm#14629

Draft
MrGeva wants to merge 1 commit into NVIDIA:main from nv-auto-deploy:egeva/ad_fp8_ar_fusion-clean


MrGeva (Collaborator) commented May 27, 2026

Adds a new transform fuse_allreduce_residual_rmsnorm_quant_fp8 that runs after fuse_allreduce_residual_rmsnorm and folds the downstream per-tensor FP8 input quantize into the AllReduce kernel by switching from AllReduceFusionOp.RESIDUAL_RMS_NORM to
RESIDUAL_RMS_NORM_QUANT_FP8. This matches what the PyTorch backend does for FP8 dense models (see modeling_llama.py).
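The rewrite described above can be sketched on a toy graph. This is only an illustration, not the transform in this PR: the `Node` class, the `fused_allreduce` op name, the `scale` attribute, and the "sole consumer" safety check are all assumptions made for the example. Only the fusion-op names and the quantize/linear op names come from the description.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical toy IR node; not a TensorRT-LLM or torch.fx type."""
    name: str
    op: str
    inputs: list = field(default_factory=list)  # names of producer nodes
    attrs: dict = field(default_factory=dict)


def fuse_quant_into_allreduce(nodes):
    """Fold a per-tensor FP8 quantize into the fused allreduce feeding it."""
    by_name = {n.name: n for n in nodes}
    users = {}
    for n in nodes:
        for src in n.inputs:
            users.setdefault(src, []).append(n)

    removed = set()
    for n in nodes:
        if n.op != "static_quantize_e4m3_per_tensor":
            continue
        producer = by_name.get(n.inputs[0])
        if producer is None or producer.op != "fused_allreduce":
            continue
        if producer.attrs.get("fusion_op") != "RESIDUAL_RMS_NORM":
            continue
        # Assumption: only rewrite when the quantize is the sole consumer
        # of the normalized output, so nobody still needs the bf16 tensor.
        if len(users[producer.name]) != 1:
            continue
        producer.attrs["fusion_op"] = "RESIDUAL_RMS_NORM_QUANT_FP8"
        producer.attrs["scale"] = n.attrs["scale"]
        # Rewire consumers of the quantize node to read the fused output.
        for user in users.get(n.name, []):
            user.inputs = [producer.name if i == n.name else i for i in user.inputs]
        removed.add(n.name)
    return [n for n in nodes if n.name not in removed]


graph = [
    Node("ar", "fused_allreduce", ["attn_out", "residual"],
         {"fusion_op": "RESIDUAL_RMS_NORM"}),
    Node("q", "static_quantize_e4m3_per_tensor", ["ar"], {"scale": 0.05}),
    Node("fc", "trtllm_fp8_prequant_linear", ["q"]),
]
fused_graph = fuse_quant_into_allreduce(graph)
```

After the pass, the toy graph has two nodes instead of three, and the linear reads directly from the fused allreduce node.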

For each layer with TP > 1 we previously executed:
    allreduce + residual + rmsnorm  -> 1 fused kernel (bf16 out)
    static_quantize_e4m3_per_tensor -> 1 extra kernel (fp8 out)
    trtllm_fp8_prequant_linear      -> matmul
With the transform the first two become a single AllReduce-fused kernel emitting FP8 directly into the next linear's input.
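As a rough numerical reference for why the two-kernel and one-kernel paths should agree, here is a pure-Python sketch. It assumes the quantize divides by a static per-tensor scale and clamps to the e4m3 finite range (448); real kernels additionally round to the FP8 grid and run in bf16, which this sketch does not model. All function names are hypothetical.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite magnitude of the e4m3 (e4m3fn) format


def allreduce_sum(shards):
    """Elementwise sum of the per-rank partial results."""
    return [sum(vals) for vals in zip(*shards)]


def rms(x, eps=1e-6):
    return math.sqrt(sum(v * v for v in x) / len(x) + eps)


def clamp_fp8(v):
    # Assumption: clamp only; rounding to the e4m3 grid is not modeled.
    return max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v))


def unfused(shards, residual, weight, scale):
    """Two steps: (allreduce + residual + rmsnorm), then a separate quantize."""
    hidden = [a + r for a, r in zip(allreduce_sum(shards), residual)]
    denom = rms(hidden)
    normed = [w * v / denom for v, w in zip(hidden, weight)]  # "bf16 out"
    quant = [clamp_fp8(v / scale) for v in normed]            # "fp8 out"
    return quant, hidden


def fused(shards, residual, weight, scale):
    """One step: the quantize scale is folded into the normalization."""
    hidden = [a + r for a, r in zip(allreduce_sum(shards), residual)]
    inv = 1.0 / (rms(hidden) * scale)
    quant = [clamp_fp8(w * v * inv) for v, w in zip(hidden, weight)]
    return quant, hidden


shards = [[0.5, -1.0, 2.0, 0.25], [1.5, 0.5, -1.0, 0.75]]  # TP=2 partials
residual = [0.1, 0.2, -0.3, 0.4]
weight = [1.0, 0.9, 1.1, 1.0]
q_ref, h_ref = unfused(shards, residual, weight, scale=0.02)
q_fused, h_fused = fused(shards, residual, weight, scale=0.02)
```

Both paths return the same updated residual stream and, up to floating-point rounding, the same quantized values; the fused path simply never materializes the intermediate normalized tensor.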

Llama-3.1-8B-Instruct-FP8 on B200 TP=2, ISL=OSL=1000 (vs PT-backend):

  • c=1: 2.42 -> 2.32 ms (-0.10) -- AD now ahead of PT (2.35)
  • c=2: 2.42 -> 2.32 ms (-0.10) -- AD ahead of PT (2.37)
  • c=16: 2.85 -> 2.74 ms (-0.11) -- half of the residual gap
  • c=256: 24.10 -> 27.01 ms (+2.91) -- regression
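To put the numbers above on a common footing, the deltas can be restated as relative changes. The snippet uses only the before/after values listed above, keyed by the c value; nothing else is measured or assumed.

```python
# (before_ms, after_ms) per c value, copied from the results above.
results_ms = {
    1: (2.42, 2.32),
    2: (2.42, 2.32),
    16: (2.85, 2.74),
    256: (24.10, 27.01),
}


def summarize(results):
    """Return {c: (delta_ms, percent_change)} rounded for readability."""
    summary = {}
    for c, (before, after) in results.items():
        delta = after - before
        summary[c] = (round(delta, 2), round(100.0 * delta / before, 1))
    return summary


summary = summarize(results_ms)
for c, (delta, pct) in summary.items():
    print(f"c={c}: {delta:+.2f} ms ({pct:+.1f}%)")
```

This gives roughly -4.1% at c=1 and c=2, -3.9% at c=16, and +12.1% at c=256.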

Because the FP8-fused AllReduce kernel saves a kernel launch but is slower at large batch sizes (see the c=256 regression above), the transform is left enabled: false in default.yaml and must be opted into per registry yaml. Accuracy on MMLU/GSM8K is unchanged from the non-fused FP8 path.
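The opt-in behaviour can be pictured as a default config merged with a per-model override. The sketch below is an assumption about the shape of that merge, not the actual AutoDeploy config loader: the dict layout, the helper names, and the default state of `fuse_allreduce_residual_rmsnorm` are hypothetical; only the transform names and the `enabled: false` default for the new transform come from the description.

```python
import copy

# Hypothetical stand-in for the relevant part of default.yaml.
DEFAULT_TRANSFORMS = {
    "fuse_allreduce_residual_rmsnorm": {"enabled": True},  # assumed default
    "fuse_allreduce_residual_rmsnorm_quant_fp8": {"enabled": False},
}


def resolve_transforms(defaults, overrides):
    """Merge per-model overrides onto the defaults without mutating them."""
    merged = copy.deepcopy(defaults)
    for name, cfg in overrides.items():
        merged.setdefault(name, {}).update(cfg)
    return merged


def enabled_transforms(cfg):
    return [name for name, c in cfg.items() if c.get("enabled", False)]


# Hypothetical stand-in for a registry yaml that opts in to the new transform.
registry_override = {
    "fuse_allreduce_residual_rmsnorm_quant_fp8": {"enabled": True},
}

default_on = enabled_transforms(resolve_transforms(DEFAULT_TRANSFORMS, {}))
opted_in = enabled_transforms(
    resolve_transforms(DEFAULT_TRANSFORMS, registry_override)
)
```

With no override the new transform stays off; with the override it is enabled alongside the existing fusion, which matches the "run after `fuse_allreduce_residual_rmsnorm`" ordering only if the real pipeline orders them that way.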


Signed-off-by: egeva <19514940+MrGeva@users.noreply.github.com>