[None][perf] AutoDeploy: fuse FP8 quant into allreduce+residual+RMSNorm#14629
Draft
MrGeva wants to merge 1 commit into
Draft
[None][perf] AutoDeploy: fuse FP8 quant into allreduce+residual+RMSNorm#14629MrGeva wants to merge 1 commit into
MrGeva wants to merge 1 commit into
Conversation
Adds a new transform ``fuse_allreduce_residual_rmsnorm_quant_fp8`` that
runs after ``fuse_allreduce_residual_rmsnorm`` and folds the downstream
per-tensor FP8 input quantize into the AllReduce kernel by switching
from ``AllReduceFusionOp.RESIDUAL_RMS_NORM`` to
``RESIDUAL_RMS_NORM_QUANT_FP8``. This matches what the PyTorch backend
does for FP8 dense models (see ``modeling_llama.py``).
For each layer with TP > 1 we previously executed:
allreduce + residual + rmsnorm -> 1 fused kernel (bf16 out)
static_quantize_e4m3_per_tensor -> 1 extra kernel (fp8 out)
trtllm_fp8_prequant_linear -> matmul
With the transform the first two become a single AllReduce-fused
kernel emitting FP8 directly into the next linear's input.
Llama-3.1-8B-Instruct-FP8 on B200 TP=2, ISL=OSL=1000 (vs PT-backend):
- c=1: 2.42 -> 2.32 ms (-0.10) -- AD now ahead of PT (2.35)
- c=2: 2.42 -> 2.32 ms (-0.10) -- AD ahead of PT (2.37)
- c=16: 2.85 -> 2.74 ms (-0.11) -- half of the residual gap
- c=256: 24.10 -> 27.01 ms (+2.91) -- regression
Because the FP8-fused AllReduce kernel trades launch savings for
less-optimal large-batch work, the transform is left ``enabled: false``
in ``default.yaml`` and opt-in per registry yaml. Accuracy on MMLU/GSM8K
is unchanged from the non-fused FP8 path.
Signed-off-by: egeva <19514940+MrGeva@users.noreply.github.com>
1 task
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds a new transform
fuse_allreduce_residual_rmsnorm_quant_fp8that runs afterfuse_allreduce_residual_rmsnormand folds the downstream per-tensor FP8 input quantize into the AllReduce kernel by switching fromAllReduceFusionOp.RESIDUAL_RMS_NORMtoRESIDUAL_RMS_NORM_QUANT_FP8. This matches what the PyTorch backend does for FP8 dense models (seemodeling_llama.py).For each layer with TP > 1 we previously executed:
allreduce + residual + rmsnorm -> 1 fused kernel (bf16 out)
static_quantize_e4m3_per_tensor -> 1 extra kernel (fp8 out)
trtllm_fp8_prequant_linear -> matmul
With the transform the first two become a single AllReduce-fused
kernel emitting FP8 directly into the next linear's input.
Llama-3.1-8B-Instruct-FP8 on B200 TP=2, ISL=OSL=1000 (vs PT-backend):
Because the FP8-fused AllReduce kernel trades launch savings for less-optimal large-batch work, the transform is left
enabled: falseindefault.yamland opt-in per registry yaml. Accuracy on MMLU/GSM8K is unchanged from the non-fused FP8 path.@coderabbitai summary
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either
api-compatibleorapi-breaking. Forapi-breaking, includeBREAKINGin the PR title.Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.