[2/N] Simplify KDTrainer and enhance ModelOptHFTrainer#1191
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughThis PR integrates logit-level knowledge distillation with enhanced trainer infrastructure: it introduces ChangesTrainer Infrastructure & KDTrainer Redesign
Example Wiring & Argument Updates
Test Coverage
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 5 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #1191 +/- ##
==========================================
+ Coverage 76.43% 76.88% +0.45%
==========================================
Files 488 488
Lines 54115 54386 +271
==========================================
+ Hits 41362 41817 +455
+ Misses 12753 12569 -184
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
e349a90 to
b762870
Compare
97759a4 to
bfc343c
Compare
cc45203 to
9dd1732
Compare
cjluo-nv
left a comment
There was a problem hiding this comment.
Summary: This PR simplifies the HF knowledge distillation trainer by removing mtd.convert() class-swap in favor of explicit teacher forwarding, enhances ModelOptHFTrainer with Liger fused loss, per-parameter LRs, parameter freezing, and refactors the llm_qat example to use YAML configs with a new ModelOptArgParser.
Issues Found:
-
[Correctness] CRITICAL:
recipe.ptq_cfgdoes not exist — should berecipe.quantize
ModelOptPTQRecipe(inmodelopt/recipe/config.py:69) exposes its quant config via thequantizeattribute, notptq_cfg. This will raiseAttributeErrorat runtime in two places:examples/llm_qat/simple_qat_train.py:126—model = mtq.quantize(model, recipe.ptq_cfg, calibrate)modelopt/torch/quantization/plugins/transformers_trainer.py:217—return recipe.ptq_cfg
The correct usage is already in
examples/llm_qat/quantize.py:77(ptq_cfg = recipe.quantize), confirming this is a copy-paste error. Existing usage inexamples/llm_ptq/hf_ptq.pyalso usesrecipe.quantize. -
[Correctness]
LMLogitsLoss.forwarddouble-sums — returns scalar instead of per-token losses
LogitsDistillationLoss.forwardwithreduction="none"already sums over the vocab dimension (line 64 oflosses.py), returning shape(B*S,). The newLMLogitsLoss.forwardthen does another.sum(dim=-1)on this 1D tensor, collapsing it to a scalar. This makes the ignore-index masking in_standard_kd_lossa no-op — padding tokens contribute equally to the loss. -
[Correctness] Inconsistent causal shift between standard and Liger KD paths
_liger_kd_lossapplies the standard causal LM shift (hidden_states[..., :-1, :],labels[..., 1:]) before computing JSD. But_standard_kd_lossapplies no shift — it computes KL-div on allB*Spositions and masks with unshifted labels. While comparing student/teacher at the same position is valid, the different masking alignment means the two paths produce semantically different losses for the same input. This will be surprising when toggling--use_liger_kernel. -
[Correctness]
_forward_redirectdoesn't restoremodule.forwardon failure
If themodule(dummy)call raises before enteringwrapped_forward(e.g., FSDP pre-forward hook fails),module.forwardremains patched towrapped_forward. Atry/finallywould be safer. -
[Tests] Low coverage on core library files
Codecov reports ~20% patch coverage fortransformers.py(165 missing lines) andhuggingface.py(82 missing lines). The new Liger fused loss path,_forward_redirect,_sharded_liger_compute, parameter freezing, andsave_dtyperewriting have no unit test coverage. These are critical code paths for distributed training correctness. -
[Correctness]
save_dtypedefaults to"bfloat16"instead of preserving original model dtype
The oldQATTrainersaved the model's original dtype (self._original_dtype). The newModelOptHFTrainer.save_modelhardcodessave_dtype="bfloat16"by default. For models originally infloat16, this silently changes the config.json dtype, which may affect downstream inference engines. -
[Readability] Misleading shape comments in
LMLogitsLoss
The comment# (B*S, V)on thesuper().forward()call is wrong — the parent returns(B*S,)whenreduction="none". This directly led to bug #2.
Suggestions:
- The
liger-kernel>=0.5.0addition topyproject.toml[hf]extras makes it a hard install dependency for all HF users. Since usage is guarded by--use_liger_kernel, consider making it an optional extra ([liger]or[hf-liger]) to avoid install issues in constrained environments. - The pre-commit hook for
generate-arguments-mduseslanguage: systemand runspython examples/llm_qat/train.py --generate_docs. This requires all modelopt dependencies to be installed in the pre-commit environment, which will fail for most contributors. Considerlanguage: pythonwith explicit dependencies, or making it a manual step. _dataset_cache(module-level mutable dict indataset_utils.py) acts as an in-memory cache but is never evicted. In long-running processes or notebooks, this could hold large datasets in memory indefinitely.
Overall Assessment: The architectural direction is sound — removing mtd.convert() class-swap in favor of explicit teacher forwarding is a significant simplification. The ModelOptArgParser and YAML-based config system is a good developer experience improvement. However, there are two critical correctness bugs (recipe.ptq_cfg and LMLogitsLoss double-sum) that must be fixed before merge, plus the KD loss path inconsistency warrants discussion.
| def _update_config_json_dtype(self, output_dir: str, dtype_str: str | None) -> None: | ||
| """Rewrite <output_dir>/config.json 'dtype' (preferred) or 'torch_dtype' to dtype_str.""" | ||
| cfg_path = os.path.join(output_dir, "config.json") | ||
| if not os.path.isfile(cfg_path): |
There was a problem hiding this comment.
save_dtype defaults to "bfloat16" which silently changes the dtype for float16 models. Consider defaulting to None and falling back to the model's original dtype when not explicitly set.
There was a problem hiding this comment.
Partial fix landed in 1f1c2507 (dataclass default set to None). Follow-up fix staged locally: _update_config_json_dtype now early-returns when dtype_str is None, so the original model dtype written by super().save_model() is preserved. The getattr fallback in save_model is also aligned to None for consistency. Will be included in the next push.
99da38e to
7718d64
Compare
4472470 to
1088304
Compare
b93b8d3 to
2a23765
Compare
|
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 4
🧹 Nitpick comments (3)
modelopt/torch/distill/plugins/huggingface.py (3)
16-34: 💤 Low valueMissing
__all__for public API definition.Per coding guidelines, define the public API with
__all__at the top of each Python module. This module exportsDistillArguments,DistillArgsWithTeacherModel,KDTrainer, andIGNORE_INDEX.Suggested addition after imports
IGNORE_INDEX = nn.CrossEntropyLoss().ignore_index + +__all__ = [ + "IGNORE_INDEX", + "DistillArguments", + "DistillArgsWithTeacherModel", + "KDTrainer", +]As per coding guidelines: "Define the public API with
__all__at the top of each Python module."🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/distill/plugins/huggingface.py` around lines 16 - 34, Add a module-level __all__ to explicitly declare the public API: include "DistillArguments", "DistillArgsWithTeacherModel", "KDTrainer", and "IGNORE_INDEX". Place the __all__ definition near the top of the file after the imports (above IGNORE_INDEX) so static analysis and consumers know which names are exported, and ensure the string names exactly match the classes/variables defined in this module.
110-110: ⚡ Quick winReplace
assertwithValueErrorfor runtime validation.
assertstatements can be stripped in optimized Python (-Oflag), making this check ineffective. Use explicitraisefor required arguments.Suggested fix
- assert distill_args is not None, "`distill_args` is required for distillation." + if distill_args is None: + raise ValueError("`distill_args` is required for distillation.")🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/distill/plugins/huggingface.py` at line 110, The assertion "assert distill_args is not None, '`distill_args` is required for distillation.'" should be replaced with an explicit runtime check that raises a ValueError so it can't be bypassed with Python -O; locate the check around the distillation setup (the `distill_args` validation) and change it to: if distill_args is None: raise ValueError("`distill_args` is required for distillation.") ensuring the message matches the original assertion text.
124-124: ⚡ Quick winReplace
assertwithValueErrorfor teacher model validation.Same concern as above—
assertcan be optimized away.Suggested fix
- assert teacher is not None, "`distill_args.teacher_model` is required." + if teacher is None: + raise ValueError("`distill_args.teacher_model` is required.")🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/distill/plugins/huggingface.py` at line 124, Replace the runtime-unsafe assert in the huggingface plugin with an explicit exception: where the code currently does `assert teacher is not None, "...distill_args.teacher_model..."`, change it to raise a ValueError when the teacher model is missing (e.g., check `if teacher is None: raise ValueError("`distill_args.teacher_model` is required.")`). Update the check in the same scope that references teacher/distill_args.teacher_model to ensure deterministic validation.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/llm_qat/arguments.py`:
- Around line 97-100: The option use_liger_kernel currently defaults to True
which can cause an ImportError if the liger-kernel package is not installed;
change the field in examples/llm_qat/arguments.py to default=False, or add a
guarded fallback in the code that uses it (in the Transformers plugin where
_liger_loss_func is defined) by wrapping the import of liger_kernel in a
try/except and disabling/ignoring use_liger_kernel when the import fails so
training continues safely without that dependency.
In `@examples/llm_qat/llama_factory/llama_factory.py`:
- Around line 216-219: The teacher model is loaded without honoring
trust_remote_code; update the call to
transformers.AutoModelForCausalLM.from_pretrained that creates teacher_model so
it passes trust_remote_code=model_args.trust_remote_code (or the equivalent
flag), and ensure model_args is accessible in that scope (e.g., capture
model_args in the closure or pass it into CustomTrainer.__init__) before
assigning modelopt_trainer_args["distill_args"] = {"teacher_model":
teacher_model}.
In `@examples/llm_qat/quantize.py`:
- Around line 74-79: The call to
transformers.AutoModelForCausalLM.from_pretrained is passing
dtype=torch.bfloat16 which is incorrect for HF; change the kwarg name to
torch_dtype (i.e., use torch_dtype=torch.bfloat16) in the model =
transformers.AutoModelForCausalLM.from_pretrained(...) invocation so the model
loads with the intended precision; update the same pattern wherever
from_pretrained is used (e.g., in quantize.py model creation and parallel
occurrences).
In `@examples/llm_qat/train.py`:
- Around line 74-78: The student model instantiation uses the wrong/
inconsistent kwarg name `dtype=torch.bfloat16`; update the call to
transformers.AutoModelForCausalLM.from_pretrained(...) to use
`torch_dtype=torch.bfloat16` (matching the teacher model usage) so the
HuggingFace loader receives the correct parameter; adjust the call that builds
`model` (using `model_args.model_name_or_path` and `model_kwargs`) to replace
`dtype` with `torch_dtype`.
---
Nitpick comments:
In `@modelopt/torch/distill/plugins/huggingface.py`:
- Around line 16-34: Add a module-level __all__ to explicitly declare the public
API: include "DistillArguments", "DistillArgsWithTeacherModel", "KDTrainer", and
"IGNORE_INDEX". Place the __all__ definition near the top of the file after the
imports (above IGNORE_INDEX) so static analysis and consumers know which names
are exported, and ensure the string names exactly match the classes/variables
defined in this module.
- Line 110: The assertion "assert distill_args is not None, '`distill_args` is
required for distillation.'" should be replaced with an explicit runtime check
that raises a ValueError so it can't be bypassed with Python -O; locate the
check around the distillation setup (the `distill_args` validation) and change
it to: if distill_args is None: raise ValueError("`distill_args` is required for
distillation.") ensuring the message matches the original assertion text.
- Line 124: Replace the runtime-unsafe assert in the huggingface plugin with an
explicit exception: where the code currently does `assert teacher is not None,
"...distill_args.teacher_model..."`, change it to raise a ValueError when the
teacher model is missing (e.g., check `if teacher is None: raise
ValueError("`distill_args.teacher_model` is required.")`). Update the check in
the same scope that references teacher/distill_args.teacher_model to ensure
deterministic validation.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: eb7e9688-e418-42db-a881-c765b8eb72cc
📒 Files selected for processing (20)
CHANGELOG.rstexamples/llm_distill/main.pyexamples/llm_qat/ARGUMENTS.mdexamples/llm_qat/README.mdexamples/llm_qat/arguments.pyexamples/llm_qat/configs/train/finetune.yamlexamples/llm_qat/configs/train/lr_config_example.yamlexamples/llm_qat/configs/train/qad_nvfp4.yamlexamples/llm_qat/configs/train/qat_nvfp4.yamlexamples/llm_qat/configs/train/qlora_nvfp4.yamlexamples/llm_qat/llama_factory/llama_factory.pyexamples/llm_qat/quantize.pyexamples/llm_qat/requirements.txtexamples/llm_qat/train.pymodelopt/torch/distill/plugins/huggingface.pymodelopt/torch/opt/plugins/transformers.pymodelopt/torch/quantization/plugins/transformers_trainer.pypyproject.tomltests/examples/llm_qat/test_llm_qat.pytests/unit/torch/opt/plugins/test_lr_config.py
There was a problem hiding this comment.
♻️ Duplicate comments (1)
examples/llm_qat/train.py (1)
109-113:⚠️ Potential issue | 🟡 Minor | ⚡ Quick winIncorrect parameter name:
dtypeshould betorch_dtype.The HuggingFace
AutoModelForCausalLM.from_pretrained()API expectstorch_dtype, notdtype. While the student model at line 76 has the same issue (flagged in a prior review), this teacher model loading at line 111 is newly changed and uses the same incorrect parameter name.Suggested fix
teacher = transformers.AutoModelForCausalLM.from_pretrained( distill_args.teacher_model, - dtype=torch.bfloat16, + torch_dtype=torch.bfloat16, **model_kwargs, )transformers AutoModelForCausalLM.from_pretrained torch_dtype parameter🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/llm_qat/train.py` around lines 109 - 113, The call to transformers.AutoModelForCausalLM.from_pretrained for the teacher model (variable teacher) is using the wrong parameter name `dtype`; change it to `torch_dtype` so the loader receives the correct argument (i.e., replace `dtype=torch.bfloat16` with `torch_dtype=torch.bfloat16`); also search for other from_pretrained calls (e.g., the student model) and make the same change to keep parameter names consistent.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Duplicate comments:
In `@examples/llm_qat/train.py`:
- Around line 109-113: The call to
transformers.AutoModelForCausalLM.from_pretrained for the teacher model
(variable teacher) is using the wrong parameter name `dtype`; change it to
`torch_dtype` so the loader receives the correct argument (i.e., replace
`dtype=torch.bfloat16` with `torch_dtype=torch.bfloat16`); also search for other
from_pretrained calls (e.g., the student model) and make the same change to keep
parameter names consistent.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 4b065c1a-7299-4705-b4f9-ad850c90eed1
📒 Files selected for processing (3)
examples/llm_qat/train.pymodelopt/torch/distill/plugins/huggingface.pytests/examples/llm_qat/test_llm_qat.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/examples/llm_qat/test_llm_qat.py
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
Re-review: most prior critical comments are resolved in this revision — recipe.ptq_cfg → recipe.quantize migration via resolve_quant_cfg_from_args (fixed), LMLogitsLoss double-sum (class removed entirely, KD now calls LogitsDistillationLoss directly), causal-shift inconsistency between standard and Liger KD (_standard_kd_loss now shifts to [..., :-1, :] to match _liger_kd_loss), _forward_redirect failure cleanup (now wrapped in try/except), CE validation loss (compute_loss now runs CE in eval and records KD as eval_kd_loss secondary metric), and unit tests added for both KD and lr_config. Two prior LGTMs (mxinO, shengliangxu).
Flagging for human sign-off because of one outstanding behavior change and the breadth of the PR:
- 💬
save_dtypedefault — author replied: "Partial fix landed in 1f1c250 (dataclass default set to None). Follow-up fix staged locally… Will be included in the next push." As of the latest diff,ModelOptTrainerArguments.save_dtypestill defaults to"bfloat16"(modelopt/torch/opt/plugins/transformers.py L244), so float16 models that go throughModelOptHFTrainer.save_modelwill silently have theirconfig.jsondtype rewritten to bfloat16. The promised follow-up to default toNoneand preserve the original dtype hasn't landed yet — please confirm before merge. - Architectural surface: this PR introduces
ModelOptTrainerArguments(new public API), per-parameter LR YAML config,_forward_redirect, model-agnostic Liger fused-CE/JSD loss with DDP/ZeRO-3 support, and removesmtd.convert()from the HF KD path. The PR body justifies the design choices (model-agnostic vs HF's fixed-architecture Liger, ZeRO-3 vs FSDP-only, lazy teacher loading bug fix), and thelr_configis a thin pattern→kwargs map rather than a new DSL — design looks reasonable, but the public-API breadth and theKDTrainersemantics shift (no moreDistillationModel, no intermediate-layer distillation on the HF path) warrant human eyes. GPU example tests (tests/examples/llm_qat/,tests/examples/llm_distill/) are still unchecked in the PR's own test plan. - Minor:
_record_eval_kd_lossweights each batch's KD loss byfind_batch_size(inputs)rather than non-ignore token count, so the reportedeval_kd_lossis a batch-weighted mean that can drift from a true token-weighted mean when sequence padding varies — likely fine for monitoring but worth a comment.
|
/claude review |
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
Re-review: all prior critical correctness comments are resolved in this revision — recipe.ptq_cfg → recipe.quantize (both simple_qat_train.py and transformers_trainer.py); LMLogitsLoss double-sum (class removed; KD now calls LogitsDistillationLoss directly); causal-shift inconsistency between standard/Liger KD (_standard_kd_loss now shifts to [..., :-1, :]); _forward_redirect failure cleanup (try/except restores module.forward and re-raises); CE validation loss (compute_loss runs CE in eval and records KD as eval_kd_loss secondary metric); unit tests added for KD and lr_config. Two prior LGTMs (mxinO, shengliangxu).
Flagging for human sign-off because of architectural breadth and unchecked example tests:
- 💬
save_dtypedefault — author replied: "Partial fix landed in 1f1c250; follow-up will be in next push." Now fully resolved: thesave_dtypedataclass field is gone entirely, andModelOptHFTrainer.__init__derives_original_dtypefrommodel.config.dtype/torch_dtype;_update_config_json_dtypeearly-returns onNone, so float16 models are no longer silently rewritten to bfloat16. - Architectural surface: this PR introduces
ModelOptTrainerArguments(new public API), per-parameter LR YAML config,_forward_redirect, and model-agnostic Liger fused-CE/JSD loss with DDP/ZeRO-3 support, and removesmtd.convert()from the HF KD path (DistillationModel/intermediate-layer distillation no longer available on HF path). The PR body justifies the design choices (model-agnostic vs HF's fixed-architecture Liger, ZeRO-3 vs FSDP-only, lazy teacher loading bug fix), andlr_configis a thin pattern→kwargs map rather than a new DSL — design looks reasonable, but the public-API breadth and theKDTrainersemantic shift warrant human eyes. - GPU example tests (
tests/examples/llm_qat/,tests/examples/llm_distill/) remain unchecked in the PR's own test plan checkboxes — please confirm before merge. - Minor:
_record_eval_kd_lossweights each batch's KD loss byfind_batch_size(inputs)rather than non-ignore token count, so reportedeval_kd_lossis batch-weighted and can drift from a token-weighted mean when sequence padding varies — likely fine for monitoring but worth a comment.
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Summary
This PR simplifies the HuggingFace knowledge distillation trainer and enhances the base
ModelOptHFTrainerwith Liger fused loss, per-parameter learning rates, and training utilities.Model-agnostic Liger kernel fused loss
Adds custom Liger kernel integration in
ModelOptHFTrainerthat extends HuggingFace's built-in support in three ways:lm_head, unlike HF's Liger which only supports a fixed set of model architectures.KDTrainerextends fused loss to knowledge distillation viaLigerFusedLinearJSDfor fused lm_head + Jensen-Shannon divergence.Liger kernel memory sweep (Qwen3-1.7B, 2×H100 FSDP2, NVFP4+FP8_KV)
Max per-GPU batch size before OOM at each sequence length:
QAT (no teacher)
QAD (with teacher)
Liger fused loss enables 2-4× larger batch sizes at long context lengths by avoiding the materialization of the full logit tensor.
ModelOptHFTrainer enhancements
ModelOptTrainerArgumentswith--trainable_params,--frozen_params,--lr_config,--save_dtype, and--manual_gcflagslr_config)_prepare_modeland_update_config_json_dtypepromoted to base classKDTrainer simplification + fix
Removes
mtd.convert()and theDistillationModelin-place class-swap for the HF path. The teacher model now lives directly on the trainer and is forwarded explicitly insidecompute_kd_loss_func. This eliminates:mtd.convert()in-place class swap and DynamicModule wrappinghide_teacher_model/hide_loss_modulescontext managers for checkpointingsave_modelandQADTrainer._quantize_modeloverridesBug fix: The previous
DistillationModel/mtd.convert()approach did not support CPU RAM-efficient loading for QAD. The teacher model had to be fully loaded on GPU before wrapping, which doubled peak memory during initialization. The new approach loads the teacher lazily on the trainer, enabling standard HF device-map and low-cpu-mem-usage loading.Only logit-level distillation is supported for the HF path. The core
DistillationModel/mtd.convert()API remains for Megatron and advanced intermediate-layer distillation use cases.Test plan
pytest tests/unit/torch/distill/(29 passed)pytest tests/unit/torch/opt/plugins/test_hf_patching.py(2 passed)pytest tests/unit/torch/opt/plugins/test_lr_config.pypytest tests/examples/llm_qat/(QAT, QAD, LoRA QAT, QLoRA)pytest tests/examples/llm_distill/🤖 Generated with Claude Code
Summary by CodeRabbit
Release Notes
New Features
ModelOptHFTrainerfor distributed causal language models with JSD distillation loss support.ModelOptTrainerArgumentswith new training CLI flags: per-parameter learning rates via YAML, parameter freezing, and manual garbage collection.Documentation
Tests