[2/N] Simplify KDTrainer and enhance ModelOptHFTrainer #1191

Merged
realAsma merged 3 commits into main from asma/new-qat-2 on Jun 5, 2026

@realAsma
Contributor

@realAsma realAsma commented Apr 7, 2026

Summary

This PR simplifies the HuggingFace knowledge distillation trainer and enhances the base ModelOptHFTrainer with Liger fused loss, per-parameter learning rates, and training utilities.

Model-agnostic Liger kernel fused loss

Adds custom Liger kernel integration in ModelOptHFTrainer that extends HuggingFace's built-in support in three ways:

  1. Model-agnostic: Works with any causal LM that has an lm_head, unlike HF's built-in Liger integration, which supports only a fixed set of model architectures.
  2. DeepSpeed ZeRO-3 support: HF's Liger integration only works with FSDP. ModelOpt adds distributed param gathering for DeepSpeed ZeRO-3 and DDP as well.
  3. KD loss support: KDTrainer extends fused loss to knowledge distillation via LigerFusedLinearJSD for fused lm_head + Jensen-Shannon divergence.

Liger kernel memory sweep (Qwen3-1.7B, 2×H100 FSDP2, NVFP4+FP8_KV)

Max per-GPU batch size before OOM at each sequence length:

QAT (no teacher)

| Seq Length | 512 | 1024 | 2048 | 4096 | 8192 | 16384 |
|---|---|---|---|---|---|---|
| Liger | 16 | 16 | 16 | 16 | 8 | 4 |
| No Liger | 16 | 16 | 8 | 4 | 2 | OOM |

QAD (with teacher)

| Seq Length | 512 | 1024 | 2048 | 4096 | 8192 | 16384 |
|---|---|---|---|---|---|---|
| Liger | 16 | 16 | 8 | 4 | 2 | 1 |
| No Liger | 8 | 4 | 2 | 1 | OOM | OOM |

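Liger fused loss enables 2-4× larger batch sizes at long context lengths by avoiding materialization of the full logit tensor. It also lets the longest sequence lengths run (16384 for QAT; 8192 and 16384 for QAD) where the non-Liger path goes OOM.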
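
To see why the full logit tensor dominates memory at long context, the following back-of-the-envelope sketch computes its size. The vocabulary size of 151,936 is an assumption about the Qwen3 tokenizer, and the 4-byte element size assumes logits are upcast to float32 for the loss; neither number comes from this PR, and `logits_bytes` is a hypothetical helper.

```python
def logits_bytes(batch_size: int, seq_len: int, vocab_size: int, bytes_per_elem: int) -> int:
    """Size in bytes of a dense (batch, seq, vocab) logit tensor."""
    return batch_size * seq_len * vocab_size * bytes_per_elem


VOCAB_SIZE = 151_936  # assumption: Qwen3 tokenizer vocabulary size
BYTES_FP32 = 4        # assumption: logits upcast to float32 before the loss
GIB = 1024 ** 3

for seq_len in (2048, 8192, 16384):
    size = logits_bytes(1, seq_len, VOCAB_SIZE, BYTES_FP32)
    print(f"seq_len={seq_len:>6}: {size / GIB:.2f} GiB of logits per sample")
```

Under these assumptions the dense logits grow linearly with both batch size and sequence length, which is roughly consistent with the non-Liger batch size halving each time the sequence length doubles in the sweep above.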

ModelOptHFTrainer enhancements

  • ModelOptTrainerArguments with --trainable_params, --frozen_params, --lr_config, --save_dtype, and --manual_gc flags
  • Per-parameter learning rate support via YAML config (lr_config)
  • _prepare_model and _update_config_json_dtype promoted to base class
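
As a rough illustration of per-parameter learning rates, the sketch below groups parameter names by fnmatch pattern into optimizer-style parameter groups. It is not the trainer's implementation: `build_param_groups` is a hypothetical helper, it works on names rather than tensors, and the first-match-wins rule is an assumption. The dict passed as `lr_config` stands in for the parsed YAML file.

```python
from fnmatch import fnmatchcase


def build_param_groups(param_names, lr_config, default_lr):
    """Group parameter names by the first pattern that matches (assumed rule)."""
    groups = {pattern: {"params": [], **kwargs} for pattern, kwargs in lr_config.items()}
    fallback = {"params": [], "lr": default_lr}
    for name in param_names:
        for pattern in lr_config:
            if fnmatchcase(name, pattern):
                groups[pattern]["params"].append(name)
                break
        else:
            # No pattern matched: the parameter keeps the global learning rate.
            fallback["params"].append(name)
    result = [group for group in groups.values() if group["params"]]
    if fallback["params"]:
        result.append(fallback)
    return result


names = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "lm_head.weight",
]
lr_config = {
    "*lm_head*": {"lr": 1e-5},
    "*self_attn*": {"lr": 5e-5, "weight_decay": 0.0},
}
param_groups = build_param_groups(names, lr_config, default_lr=1e-4)
for group in param_groups:
    print(group)
```

In a real trainer each group would hold parameter tensors and be passed to the optimizer constructor; names are used here only to keep the example dependency-free.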

KDTrainer simplification + fix

Removes mtd.convert() and the DistillationModel in-place class-swap for the HF path. The teacher model now lives directly on the trainer and is forwarded explicitly inside compute_kd_loss_func. This eliminates:

  • mtd.convert() in-place class swap and DynamicModule wrapping
  • Forward hooks for capturing intermediate outputs
  • hide_teacher_model / hide_loss_modules context managers for checkpointing
  • Deferred initialization branching (FSDP2 vs DDP/DeepSpeed)
  • save_model and QADTrainer._quantize_model overrides

Bug fix: The previous DistillationModel/mtd.convert() approach did not support CPU RAM-efficient loading for QAD. The teacher model had to be fully loaded on GPU before wrapping, which doubled peak memory during initialization. The new approach loads the teacher lazily on the trainer, enabling standard HF loading with device_map and low_cpu_mem_usage.

Only logit-level distillation is supported for the HF path. The core DistillationModel/mtd.convert() API remains for Megatron and advanced intermediate-layer distillation use cases.
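
For readers unfamiliar with logit-level distillation, here is a minimal pure-Python sketch of a per-token KD loss with a causal shift and ignore-index masking. It is an illustration under stated assumptions, not the KDTrainer code: the KL direction (teacher to student), the way temperature is applied, and the mean reduction are all assumptions, and `kd_loss` is a hypothetical name.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def log_softmax(logits, temperature):
    """Numerically stable log-softmax of temperature-scaled logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in scaled))
    return [x - log_z for x in scaled]


def token_kl(student_logits, teacher_logits, temperature):
    """KL(teacher || student) over the vocabulary for one position (assumed direction)."""
    s = log_softmax(student_logits, temperature)
    t = log_softmax(teacher_logits, temperature)
    return sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t, s))


def kd_loss(student, teacher, labels, temperature=1.0):
    """Mean per-token KL over positions whose next-token label is not ignored."""
    # Causal shift: logits at position i are scored against the label at i + 1.
    shifted = zip(student[:-1], teacher[:-1], labels[1:])
    losses = [token_kl(s, t, temperature) for s, t, y in shifted if y != IGNORE_INDEX]
    return sum(losses) / len(losses) if losses else 0.0


student = [[1.0, 2.0], [0.5, 0.5], [3.0, 1.0]]
teacher = [[2.0, 1.0], [0.5, 0.5], [1.0, 3.0]]
print(kd_loss(student, teacher, labels=[5, 7, 9], temperature=2.0))
```

Each element of `student` and `teacher` is the vocabulary-sized logit vector for one sequence position; a tensor implementation would do the same thing with slicing and a boolean mask.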

Test plan

  • pytest tests/unit/torch/distill/ (29 passed)
  • pytest tests/unit/torch/opt/plugins/test_hf_patching.py (2 passed)
  • pytest tests/unit/torch/opt/plugins/test_lr_config.py
  • Pre-commit hooks pass
  • GPU example tests: pytest tests/examples/llm_qat/ (QAT, QAD, LoRA QAT, QLoRA)
  • GPU distill example: pytest tests/examples/llm_distill/

🤖 Generated with Claude Code

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Liger fused loss support in ModelOptHFTrainer for distributed causal language models with JSD distillation loss support.
    • Introduced ModelOptTrainerArguments with new training CLI flags: per-parameter learning rates via YAML, parameter freezing, and manual garbage collection.
    • Simplified knowledge distillation trainer with logit-level distillation support.
  • Documentation

    • Updated example configurations and documentation with new training options and defaults.
    • Added learning rate configuration example guide.
  • Tests

    • Added test coverage for distillation training and per-parameter optimizer configuration.

@realAsma realAsma requested review from a team as code owners April 7, 2026 21:40
@realAsma realAsma requested review from Edwardf0t1 and removed request for a team April 7, 2026 21:40
@coderabbitai
Contributor

coderabbitai Bot commented Apr 7, 2026

Walkthrough

This PR integrates logit-level knowledge distillation with enhanced trainer infrastructure: it introduces ModelOptTrainerArguments for per-parameter learning rate and parameter-freezing control, expands ModelOptHFTrainer with lr_config loading and Liger fused-loss routing, and redesigns KDTrainer to use explicit teacher models with standard and fused distillation paths. Updated examples, configs, docs, and comprehensive tests follow.

Changes

Trainer Infrastructure & KDTrainer Redesign

Layer / File(s) Summary
Trainer argument contract and ModelOptHFTrainer expansion
modelopt/torch/opt/plugins/transformers.py
ModelOptTrainerArguments dataclass introduces CLI-configurable parameter freezing/trainability globs, YAML lr_config path loading, save_dtype config rewriting, manual GC triggering, and Liger CE label smoothing. ModelOptHFTrainer now accepts trainer_args and lr_config, applies parameter requires_grad settings via glob matching, groups optimizer parameters by fnmatch patterns with per-pattern kwargs, routes fused loss computation through _forward_redirect for FSDP2/DeepSpeed gathering, temporarily patches lm_head.forward to identity during fused computation, and rewrites saved config.json with save_dtype. Manual GC is triggered in training/prediction/checkpoint-load steps when enabled.
KDTrainer logit-distillation redesign
modelopt/torch/distill/plugins/huggingface.py
Refactors KDTrainer to require explicit teacher_model in distill_args instead of model wrapping. DistillArguments adds temperature and liger_jsd_beta fields. New DistillArgsWithTeacherModel type enforces pre-loaded teacher. KDTrainer validates criterion=logits_loss, enforces FSDP2, freezes teacher gradients, and prepares teacher via accelerator. Standard KD loss computes causal logit-shifted per-token distillation with IGNORE_INDEX masking. Fused liger-kernel KD patches lm_head to identity, uses LigerFusedLinearJSD, and handles FSDP/DeepSpeed via explicit parameter gathering. Evaluation accumulates KD loss as separate *_kd_loss metric.
QAT/QAD trainer alignment
modelopt/torch/quantization/plugins/transformers_trainer.py
Removes QATTrainer._update_config_json_dtype method and call site; dtype rewriting now handled by ModelOptHFTrainer.save_model. Changes QATTrainer.evaluate model prep from accelerator.prepare dummy-optimizer hack to _prepare_model(self.model). Removes QADTrainer._quantize_model override; quantization now uses inherited base implementation.

Example Wiring & Argument Updates

Layer / File(s) Summary
Example argument schemas and script wiring
examples/llm_qat/arguments.py, examples/llm_qat/train.py, examples/llm_qat/quantize.py, examples/llm_qat/llama_factory/llama_factory.py, examples/llm_distill/main.py
ModelArguments adds attn_implementation field and increases model_max_length to 8192. TrainingArguments inherits ModelOptTrainerArguments and adds use_liger_kernel. Train/quantize scripts create model_kwargs dict and conditionally pass attn_implementation to from_pretrained. Train.py removes FSDP2 distillation warning, sets lora_config earlier, builds distill_config dict with teacher/temperature/criterion/liger_jsd_beta, and selects QADTrainer vs QATTrainer based on distillation flag. Llama_factory caches model init kwargs and reuses them for teacher loading; distill_args now contains only teacher_model. Llm_distill removes LMLogitsLoss import and kd_config dict, passing only distill_args={"teacher_model": teacher_model}.
Training configs, lr_config example, and dependencies
examples/llm_qat/configs/train/*.yaml, examples/llm_qat/requirements.txt
All training configs increase model_max_length to 8192, add use_liger_kernel: true and manual_gc: true, remove gradient_checkpointing: true, and set attn_implementation: flash_attention_2. New lr_config_example.yaml documents fnmatch pattern-to-optimizer-kwargs mapping with sample patterns for lm_head, self_attn, mlp, and embed_tokens. Adds liger-kernel dependency.
Documentation
examples/llm_qat/ARGUMENTS.md, examples/llm_qat/README.md, CHANGELOG.rst
ARGUMENTS.md documents new CLI flags including temperature, liger_jsd_beta, attn_implementation, trainable_params, frozen_params, lr_config, save_dtype, manual_gc, and liger_ce_label_smoothing. README.md clarifies YAML/CLI argument sourcing, replaces QAD example with DistillArguments-based distill_args, adds quantization format table, and sets advanced-config section id. CHANGELOG.rst records three new v0.44 features: Liger fused loss + JSD, ModelOptTrainerArguments with training flags, and KDTrainer simplification.

Test Coverage

Layer / File(s) Summary
Example integration test refactoring and new QAD coverage
tests/examples/llm_qat/test_llm_qat.py
Introduces FAST_TRAIN_ARGS constant unifying fast-training CLI overrides. Refactors _run_quantize and _run_train helpers to accept explicit config path argument and pass it to quantize.py/train.py via --config. Updates all test calls to supply config paths (qat_nvfp4.yaml, qlora_nvfp4.yaml, etc.). Adds new end-to-end test_qwen3_qad_nvfp4 parameterized over FSDP2 and DeepSpeed, quantizing a student model and running QAD training with distillation.
Unit tests for KDTrainer logit-distillation
tests/unit/torch/distill/plugins/test_huggingface.py
Adds comprehensive KD unit tests with minimal _TinyCausalLM and _ToyDataset helpers. Tests verify that KDTrainer.compute_loss produces standard KD loss matching manual logits-distillation computation, that evaluate() returns CE-based eval_loss with separate eval_kd_loss metric, and that missing-label fallback uses mean reduction.
Unit tests for lr_config pattern-matching
tests/unit/torch/opt/plugins/test_lr_config.py
Adds unit test module validating ModelOptHFTrainer.lr_config behavior with minimal TinyModel including named submodules for parameter-group matching. Covers YAML loading with valid/invalid shapes, fnmatch pattern matching against parameter names, per-pattern optimizer kwargs propagation (lr, weight_decay, betas, eps), and fallback to global learning rate for unmatched parameters. Tests accept both dict and YAML-path lr_config sources.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • NVIDIA/Model-Optimizer#1172: Both PRs modify modelopt/torch/opt/plugins/transformers.py's argument/dataclass infrastructure so that the YAML/CLI parsing groundwork added by #1172 is extended in this PR with new training-specific ModelOptTrainerArguments and ModelOptHFTrainer behavior.

Suggested reviewers

  • kevalmorabia97
  • h-guo18
  • Edwardf0t1
  • shengliangxu
  • jenchen13
🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 45.05% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (5 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title '[2/N] Simplify KDTrainer and enhance ModelOptHFTrainer' directly reflects the main changes: KDTrainer simplification and ModelOptHFTrainer enhancements across multiple files. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | No security violations: no unsafe torch/numpy.load, trust_remote_code from user parameter, no eval/exec on input, no # nosec comments, liger-kernel has BSD 2-CLAUSE license. |

@realAsma realAsma requested review from ChenhanYu and shengliangxu and removed request for a team April 7, 2026 21:40
@codecov

codecov Bot commented Apr 7, 2026

Codecov Report

❌ Patch coverage is 88.12155% with 43 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.88%. Comparing base (115cae2) to head (0dde6c8).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| modelopt/torch/opt/plugins/transformers.py | 83.25% | 37 Missing ⚠️ |
| modelopt/torch/distill/plugins/huggingface.py | 96.42% | 5 Missing ⚠️ |
| ...torch/quantization/plugins/transformers_trainer.py | 0.00% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1191      +/-   ##
==========================================
+ Coverage   76.43%   76.88%   +0.45%     
==========================================
  Files         488      488              
  Lines       54115    54386     +271     
==========================================
+ Hits        41362    41817     +455     
+ Misses      12753    12569     -184     
| Flag | Coverage Δ | |
|---|---|---|
| examples | 42.75% <75.41%> (+0.68%) | ⬆️ |
| gpu | 58.41% <19.88%> (-1.41%) | ⬇️ |
| regression | 14.89% <19.88%> (-0.25%) | ⬇️ |
| unit | 54.00% <58.83%> (+0.09%) | ⬆️ |

@realAsma realAsma force-pushed the asma/new-qat-2 branch 2 times, most recently from e349a90 to b762870 Compare April 7, 2026 22:36
@realAsma realAsma force-pushed the asma/new-qat-1 branch 2 times, most recently from 97759a4 to bfc343c Compare April 8, 2026 16:03
@realAsma realAsma force-pushed the asma/new-qat-1 branch 3 times, most recently from cc45203 to 9dd1732 Compare April 9, 2026 18:48
Collaborator

@cjluo-nv cjluo-nv left a comment


Summary: This PR simplifies the HF knowledge distillation trainer by removing the mtd.convert() class-swap in favor of explicit teacher forwarding. It enhances ModelOptHFTrainer with Liger fused loss, per-parameter LRs, and parameter freezing, and it refactors the llm_qat example to use YAML configs with a new ModelOptArgParser.

Issues Found:

  1. [Correctness] CRITICAL: recipe.ptq_cfg does not exist — should be recipe.quantize
    ModelOptPTQRecipe (in modelopt/recipe/config.py:69) exposes its quant config via the quantize attribute, not ptq_cfg. This will raise AttributeError at runtime in two places:

    • examples/llm_qat/simple_qat_train.py:126: model = mtq.quantize(model, recipe.ptq_cfg, calibrate)
    • modelopt/torch/quantization/plugins/transformers_trainer.py:217: return recipe.ptq_cfg

    The correct usage is already in examples/llm_qat/quantize.py:77 (ptq_cfg = recipe.quantize), confirming this is a copy-paste error. Existing usage in examples/llm_ptq/hf_ptq.py also uses recipe.quantize.

  2. [Correctness] LMLogitsLoss.forward double-sums — returns scalar instead of per-token losses
    LogitsDistillationLoss.forward with reduction="none" already sums over the vocab dimension (line 64 of losses.py), returning shape (B*S,). The new LMLogitsLoss.forward then does another .sum(dim=-1) on this 1D tensor, collapsing it to a scalar. This makes the ignore-index masking in _standard_kd_loss a no-op — padding tokens contribute equally to the loss.

  3. [Correctness] Inconsistent causal shift between standard and Liger KD paths
    _liger_kd_loss applies the standard causal LM shift (hidden_states[..., :-1, :], labels[..., 1:]) before computing JSD. But _standard_kd_loss applies no shift — it computes KL-div on all B*S positions and masks with unshifted labels. While comparing student/teacher at the same position is valid, the different masking alignment means the two paths produce semantically different losses for the same input. This will be surprising when toggling --use_liger_kernel.

  4. [Correctness] _forward_redirect doesn't restore module.forward on failure
    If the module(dummy) call raises before entering wrapped_forward (e.g., FSDP pre-forward hook fails), module.forward remains patched to wrapped_forward. A try/finally would be safer.

  5. [Tests] Low coverage on core library files
    Codecov reports ~20% patch coverage for transformers.py (165 missing lines) and huggingface.py (82 missing lines). The new Liger fused loss path, _forward_redirect, _sharded_liger_compute, parameter freezing, and save_dtype rewriting have no unit test coverage. These are critical code paths for distributed training correctness.

  6. [Correctness] save_dtype defaults to "bfloat16" instead of preserving original model dtype
    The old QATTrainer saved the model's original dtype (self._original_dtype). The new ModelOptHFTrainer.save_model hardcodes save_dtype="bfloat16" by default. For models originally in float16, this silently changes the config.json dtype, which may affect downstream inference engines.

  7. [Readability] Misleading shape comments in LMLogitsLoss
    The comment # (B*S, V) on the super().forward() call is wrong — the parent returns (B*S,) when reduction="none". This directly led to bug #2.
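
To make issue 4 concrete, here is a minimal sketch of the try/finally pattern using a tiny stand-in class instead of torch.nn.Module. `Module` and `forward_redirect` are hypothetical names for illustration, and the placeholder input is an assumption; the actual `_forward_redirect` may be structured differently.

```python
class Module:
    """Tiny stand-in for torch.nn.Module: calling it dispatches to self.forward."""

    def forward(self, x):
        return x * 2

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def forward_redirect(module, fn, *args):
    """Run fn through module.__call__ and always restore module.forward afterwards."""
    original_forward = module.forward

    def wrapped_forward(*_unused):
        return fn(*args)

    module.forward = wrapped_forward
    try:
        # Placeholder input: wrapped_forward ignores it and calls fn instead.
        return module(None)
    finally:
        # Runs even if module(...) raises before wrapped_forward is reached.
        module.forward = original_forward


m = Module()
print(forward_redirect(m, lambda a, b: a + b, 1, 2))
print(m(5))  # forward is back to the original behaviour
```

The point of the `finally` clause is that a failure anywhere inside the call, including in hooks that run before `forward`, cannot leave the module permanently patched.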

Suggestions:

  • The liger-kernel>=0.5.0 addition to pyproject.toml [hf] extras makes it a hard install dependency for all HF users. Since usage is guarded by --use_liger_kernel, consider making it an optional extra ([liger] or [hf-liger]) to avoid install issues in constrained environments.
  • The pre-commit hook for generate-arguments-md uses language: system and runs python examples/llm_qat/train.py --generate_docs. This requires all modelopt dependencies to be installed in the pre-commit environment, which will fail for most contributors. Consider language: python with explicit dependencies, or making it a manual step.
  • _dataset_cache (module-level mutable dict in dataset_utils.py) acts as an in-memory cache but is never evicted. In long-running processes or notebooks, this could hold large datasets in memory indefinitely.
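
One possible shape for the cache-eviction suggestion, sketched with the standard library only. `BoundedCache` and its `max_entries` limit are hypothetical and are not part of dataset_utils.py; a least-recently-used policy is just one assumption about what eviction could look like.

```python
from collections import OrderedDict


class BoundedCache:
    """In-memory cache that evicts the least recently used entry when full."""

    def __init__(self, max_entries):
        self.max_entries = max_entries
        self._data = OrderedDict()

    def get(self, key, loader):
        if key in self._data:
            self._data.move_to_end(key)  # mark as most recently used
            return self._data[key]
        value = loader(key)
        self._data[key] = value
        if len(self._data) > self.max_entries:
            self._data.popitem(last=False)  # drop the oldest entry
        return value

    def __contains__(self, key):
        return key in self._data

    def __len__(self):
        return len(self._data)


loads = []


def fake_loader(name):
    loads.append(name)
    return f"dataset:{name}"


cache = BoundedCache(max_entries=2)
for name in ("a", "b", "a", "c"):
    cache.get(name, fake_loader)
print(sorted(k for k in ("a", "b", "c") if k in cache))
```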

Overall Assessment: The architectural direction is sound — removing mtd.convert() class-swap in favor of explicit teacher forwarding is a significant simplification. The ModelOptArgParser and YAML-based config system is a good developer experience improvement. However, there are two critical correctness bugs (recipe.ptq_cfg and LMLogitsLoss double-sum) that must be fixed before merge, plus the KD loss path inconsistency warrants discussion.
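
The effect of the double-sum described in issue 2 can be reproduced with plain Python lists. This is an emulation under assumptions, not the library code: the scalar is repeated per position to mimic tensor broadcasting against the mask, and the numbers are made up for illustration.

```python
IGNORE_INDEX = -100


def masked_mean(per_token_losses, labels):
    """Average the losses of positions whose label is not IGNORE_INDEX."""
    kept = [loss for loss, y in zip(per_token_losses, labels) if y != IGNORE_INDEX]
    return sum(kept) / len(kept)


# Per-vocabulary loss terms for three token positions; the last one is padding.
elementwise = [[0.1, 0.2], [0.3, 0.4], [5.0, 5.0]]
labels = [4, 7, IGNORE_INDEX]

# First sum over the vocabulary: one loss per token, shape (B*S,).
per_token = [sum(row) for row in elementwise]

# A second sum collapses the per-token losses into a single scalar.
collapsed = sum(per_token)
# Emulate broadcasting that scalar against the (B*S,) mask.
broadcast = [collapsed] * len(labels)

correct = masked_mean(per_token, labels)
buggy = masked_mean(broadcast, labels)
print(f"correct={correct:.3f} buggy={buggy:.3f}")
```

In the buggy variant the padding position's loss is already folded into the scalar before masking, so the mask no longer excludes it.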

Comment thread examples/llm_qat/simple_qat_train.py Outdated
Comment thread modelopt/torch/quantization/plugins/transformers_trainer.py Outdated
Comment thread modelopt/torch/distill/plugins/huggingface.py Outdated
Comment thread modelopt/torch/distill/plugins/huggingface.py
Comment thread modelopt/torch/distill/plugins/huggingface.py
Comment thread modelopt/torch/opt/plugins/transformers.py
def _update_config_json_dtype(self, output_dir: str, dtype_str: str | None) -> None:
    """Rewrite <output_dir>/config.json 'dtype' (preferred) or 'torch_dtype' to dtype_str."""
    cfg_path = os.path.join(output_dir, "config.json")
    if not os.path.isfile(cfg_path):
Collaborator


save_dtype defaults to "bfloat16" which silently changes the dtype for float16 models. Consider defaulting to None and falling back to the model's original dtype when not explicitly set.

Contributor Author


Partial fix landed in 1f1c2507 (dataclass default set to None). Follow-up fix staged locally: _update_config_json_dtype now early-returns when dtype_str is None, so the original model dtype written by super().save_model() is preserved. The getattr fallback in save_model is also aligned to None for consistency. Will be included in the next push.
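
A standalone sketch of the behaviour described in this reply, for illustration only. The function below is a hypothetical free-function version; everything past the `os.path.isfile` check is an assumption about how the rewrite might work, not the code in the PR.

```python
import json
import os
import tempfile


def update_config_json_dtype(output_dir, dtype_str):
    """Rewrite the dtype recorded in config.json; do nothing when dtype_str is None."""
    if dtype_str is None:
        return False  # keep whatever dtype the regular save already wrote
    cfg_path = os.path.join(output_dir, "config.json")
    if not os.path.isfile(cfg_path):
        return False
    with open(cfg_path) as f:
        cfg = json.load(f)
    # Assumption: prefer the 'dtype' key and fall back to 'torch_dtype'.
    key = "dtype" if "dtype" in cfg else "torch_dtype"
    cfg[key] = dtype_str
    with open(cfg_path, "w") as f:
        json.dump(cfg, f, indent=2)
    return True


def read_dtype(output_dir):
    with open(os.path.join(output_dir, "config.json")) as f:
        cfg = json.load(f)
    return cfg.get("dtype", cfg.get("torch_dtype"))


with tempfile.TemporaryDirectory() as out_dir:
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump({"torch_dtype": "float16"}, f)
    update_config_json_dtype(out_dir, None)
    dtype_when_unset = read_dtype(out_dir)
    update_config_json_dtype(out_dir, "bfloat16")
    dtype_when_set = read_dtype(out_dir)

print(dtype_when_unset, dtype_when_set)
```

With `None` as the default, a float16 checkpoint keeps its recorded dtype unless the user explicitly asks for a different one.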

Comment thread .pre-commit-config.yaml
@realAsma realAsma force-pushed the asma/new-qat-1 branch 2 times, most recently from 4472470 to 1088304 Compare April 14, 2026 14:12
@realAsma realAsma force-pushed the asma/new-qat-1 branch 2 times, most recently from b93b8d3 to 2a23765 Compare June 2, 2026 22:22
Base automatically changed from asma/new-qat-1 to main June 3, 2026 01:25
Comment thread examples/llm_qat/ARGUMENTS.md
@github-actions
Contributor

github-actions Bot commented Jun 3, 2026

PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-06-05 15:47 UTC

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (3)
modelopt/torch/distill/plugins/huggingface.py (3)

16-34: 💤 Low value

Missing __all__ for public API definition.

Per coding guidelines, define the public API with __all__ at the top of each Python module. This module exports DistillArguments, DistillArgsWithTeacherModel, KDTrainer, and IGNORE_INDEX.

Suggested addition after imports
 IGNORE_INDEX = nn.CrossEntropyLoss().ignore_index
+
+__all__ = [
+    "IGNORE_INDEX",
+    "DistillArguments",
+    "DistillArgsWithTeacherModel",
+    "KDTrainer",
+]

As per coding guidelines: "Define the public API with __all__ at the top of each Python module."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/distill/plugins/huggingface.py` around lines 16 - 34, Add a
module-level __all__ to explicitly declare the public API: include
"DistillArguments", "DistillArgsWithTeacherModel", "KDTrainer", and
"IGNORE_INDEX". Place the __all__ definition near the top of the file after the
imports (above IGNORE_INDEX) so static analysis and consumers know which names
are exported, and ensure the string names exactly match the classes/variables
defined in this module.

110-110: ⚡ Quick win

Replace assert with ValueError for runtime validation.

assert statements can be stripped in optimized Python (-O flag), making this check ineffective. Use explicit raise for required arguments.

Suggested fix
-        assert distill_args is not None, "`distill_args` is required for distillation."
+        if distill_args is None:
+            raise ValueError("`distill_args` is required for distillation.")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/distill/plugins/huggingface.py` at line 110, The assertion
"assert distill_args is not None, '`distill_args` is required for
distillation.'" should be replaced with an explicit runtime check that raises a
ValueError so it can't be bypassed with Python -O; locate the check around the
distillation setup (the `distill_args` validation) and change it to: if
distill_args is None: raise ValueError("`distill_args` is required for
distillation.") ensuring the message matches the original assertion text.

124-124: ⚡ Quick win

Replace assert with ValueError for teacher model validation.

Same concern as above—assert can be optimized away.

Suggested fix
-        assert teacher is not None, "`distill_args.teacher_model` is required."
+        if teacher is None:
+            raise ValueError("`distill_args.teacher_model` is required.")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/distill/plugins/huggingface.py` at line 124, Replace the
runtime-unsafe assert in the huggingface plugin with an explicit exception:
where the code currently does `assert teacher is not None,
"...distill_args.teacher_model..."`, change it to raise a ValueError when the
teacher model is missing (e.g., check `if teacher is None: raise
ValueError("`distill_args.teacher_model` is required.")`). Update the check in
the same scope that references teacher/distill_args.teacher_model to ensure
deterministic validation.
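
The difference between the two validation styles can be checked in-process with the built-in `compile()` and its `optimize` argument, where level 1 corresponds to running with `python -O`. The function names below are hypothetical and exist only for this demonstration.

```python
SOURCE = '''
def require_with_assert(distill_args):
    assert distill_args is not None, "distill_args is required for distillation."
    return "check skipped"


def require_with_raise(distill_args):
    if distill_args is None:
        raise ValueError("distill_args is required for distillation.")
    return "check skipped"
'''


def load(optimize):
    """Compile SOURCE at the given optimization level (0 = default, 1 = like -O)."""
    namespace = {}
    exec(compile(SOURCE, "<validation-demo>", "exec", optimize=optimize), namespace)
    return namespace


def outcome(fn):
    """Return the exception name raised for a missing argument, or the return value."""
    try:
        return fn(None)
    except (AssertionError, ValueError) as exc:
        return type(exc).__name__


default_build = load(0)
optimized_build = load(1)
for name in ("require_with_assert", "require_with_raise"):
    print(name, outcome(default_build[name]), outcome(optimized_build[name]))
```

Only the explicit `raise` keeps rejecting the missing argument at both optimization levels.
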
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/llm_qat/arguments.py`:
- Around line 97-100: The option use_liger_kernel currently defaults to True
which can cause an ImportError if the liger-kernel package is not installed;
change the field in examples/llm_qat/arguments.py to default=False, or add a
guarded fallback in the code that uses it (in the Transformers plugin where
_liger_loss_func is defined) by wrapping the import of liger_kernel in a
try/except and disabling/ignoring use_liger_kernel when the import fails so
training continues safely without that dependency.

In `@examples/llm_qat/llama_factory/llama_factory.py`:
- Around line 216-219: The teacher model is loaded without honoring
trust_remote_code; update the call to
transformers.AutoModelForCausalLM.from_pretrained that creates teacher_model so
it passes trust_remote_code=model_args.trust_remote_code (or the equivalent
flag), and ensure model_args is accessible in that scope (e.g., capture
model_args in the closure or pass it into CustomTrainer.__init__) before
assigning modelopt_trainer_args["distill_args"] = {"teacher_model":
teacher_model}.

In `@examples/llm_qat/quantize.py`:
- Around line 74-79: The call to
transformers.AutoModelForCausalLM.from_pretrained is passing
dtype=torch.bfloat16 which is incorrect for HF; change the kwarg name to
torch_dtype (i.e., use torch_dtype=torch.bfloat16) in the model =
transformers.AutoModelForCausalLM.from_pretrained(...) invocation so the model
loads with the intended precision; update the same pattern wherever
from_pretrained is used (e.g., in quantize.py model creation and parallel
occurrences).

In `@examples/llm_qat/train.py`:
- Around line 74-78: The student model instantiation uses the wrong or
inconsistent kwarg name `dtype=torch.bfloat16`; update the call to
transformers.AutoModelForCausalLM.from_pretrained(...) to use
`torch_dtype=torch.bfloat16` (matching the teacher model usage) so the
HuggingFace loader receives the correct parameter; adjust the call that builds
`model` (using `model_args.model_name_or_path` and `model_kwargs`) to replace
`dtype` with `torch_dtype`.
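
For the first inline comment above (the `use_liger_kernel` default), a guarded fallback could look roughly like the sketch below. `resolve_use_liger_kernel` is a hypothetical helper and the warning text is made up; only the `liger_kernel` import name comes from the review comment.

```python
import importlib.util
import warnings


def resolve_use_liger_kernel(requested, package="liger_kernel"):
    """Return True only if the fused path was requested and the package is importable."""
    if not requested:
        return False
    if importlib.util.find_spec(package) is None:
        warnings.warn(f"{package} is not installed; falling back to the unfused loss path.")
        return False
    return True


# The result depends on whether liger-kernel is installed in the environment.
print(resolve_use_liger_kernel(True))
```

Checking with `find_spec` avoids importing the package just to test for its presence; a try/except around the import, as the comment suggests, would work equally well.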

📥 Commits

Reviewing files that changed from the base of the PR and between 862ed5e and 8a4091a.

📒 Files selected for processing (20)
  • CHANGELOG.rst
  • examples/llm_distill/main.py
  • examples/llm_qat/ARGUMENTS.md
  • examples/llm_qat/README.md
  • examples/llm_qat/arguments.py
  • examples/llm_qat/configs/train/finetune.yaml
  • examples/llm_qat/configs/train/lr_config_example.yaml
  • examples/llm_qat/configs/train/qad_nvfp4.yaml
  • examples/llm_qat/configs/train/qat_nvfp4.yaml
  • examples/llm_qat/configs/train/qlora_nvfp4.yaml
  • examples/llm_qat/llama_factory/llama_factory.py
  • examples/llm_qat/quantize.py
  • examples/llm_qat/requirements.txt
  • examples/llm_qat/train.py
  • modelopt/torch/distill/plugins/huggingface.py
  • modelopt/torch/opt/plugins/transformers.py
  • modelopt/torch/quantization/plugins/transformers_trainer.py
  • pyproject.toml
  • tests/examples/llm_qat/test_llm_qat.py
  • tests/unit/torch/opt/plugins/test_lr_config.py

Comment thread examples/llm_qat/arguments.py
Comment thread examples/llm_qat/llama_factory/llama_factory.py
Comment thread examples/llm_qat/quantize.py
Comment thread examples/llm_qat/train.py
Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
examples/llm_qat/train.py (1)

109-113: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Incorrect parameter name: dtype should be torch_dtype.

The HuggingFace AutoModelForCausalLM.from_pretrained() API expects torch_dtype, not dtype. While the student model at line 76 has the same issue (flagged in a prior review), this teacher model loading at line 111 is newly changed and uses the same incorrect parameter name.

Suggested fix
     teacher = transformers.AutoModelForCausalLM.from_pretrained(
         distill_args.teacher_model,
-        dtype=torch.bfloat16,
+        torch_dtype=torch.bfloat16,
         **model_kwargs,
     )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/llm_qat/train.py` around lines 109 - 113, The call to
transformers.AutoModelForCausalLM.from_pretrained for the teacher model
(variable teacher) is using the wrong parameter name `dtype`; change it to
`torch_dtype` so the loader receives the correct argument (i.e., replace
`dtype=torch.bfloat16` with `torch_dtype=torch.bfloat16`); also search for other
from_pretrained calls (e.g., the student model) and make the same change to keep
parameter names consistent.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 4b065c1a-7299-4705-b4f9-ad850c90eed1

📥 Commits

Reviewing files that changed from the base of the PR and between 8a4091a and c27a7fa.

📒 Files selected for processing (3)
  • examples/llm_qat/train.py
  • modelopt/torch/distill/plugins/huggingface.py
  • tests/examples/llm_qat/test_llm_qat.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/examples/llm_qat/test_llm_qat.py

@realAsma realAsma requested a review from AAnoosheh June 3, 2026 23:25
Comment thread pyproject.toml Outdated
@realAsma realAsma requested a review from a team June 4, 2026 22:13
@realAsma realAsma enabled auto-merge (squash) June 4, 2026 22:23

@shengliangxu shengliangxu left a comment


LGTM


@cjluo-nv cjluo-nv left a comment


Bot review — DM the bot to share feedback.

Re-review: most prior critical comments are resolved in this revision, including the recipe.ptq_cfg → recipe.quantize migration via resolve_quant_cfg_from_args (fixed), LMLogitsLoss double-sum (class removed entirely, KD now calls LogitsDistillationLoss directly), causal-shift inconsistency between standard and Liger KD (_standard_kd_loss now shifts to [..., :-1, :] to match _liger_kd_loss), _forward_redirect failure cleanup (now wrapped in try/except), CE validation loss (compute_loss now runs CE in eval and records KD as eval_kd_loss secondary metric), and unit tests added for both KD and lr_config. Two prior LGTMs (mxinO, shengliangxu).
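
For readers unfamiliar with the causal-shift item above, here is a minimal sketch of what slicing to `[..., :-1, :]` means, using plain Python lists instead of tensors. It is not the PR's code: `shift_kd_inputs` is a hypothetical name, and pairing the shifted logits with `labels[1:]` for masking is an assumption about how ignored positions are handled.

```python
def shift_kd_inputs(student_logits, teacher_logits, labels):
    """Align student and teacher logits with next-token targets.

    Each argument is a list indexed by sequence position (length T).
    Position t predicts token t+1, so the last logit row has no target
    and the first label has no prediction; both are dropped.
    """
    if not (len(student_logits) == len(teacher_logits) == len(labels)):
        raise ValueError("all inputs must cover the same positions")
    # Same slice on both models, so the KD loss compares matching positions.
    return student_logits[:-1], teacher_logits[:-1], labels[1:]


# T = 3 positions, vocabulary of 2 tokens; -100 marks an ignored label.
student = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
teacher = [[0.2, 0.8], [0.6, 0.4], [0.5, 0.5]]
labels = [-100, 1, 0]

s_shift, t_shift, y_shift = shift_kd_inputs(student, teacher, labels)
```

If one KD path applied this slice and the other did not, the two would average over different position sets, which is the inconsistency the review describes as fixed.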

Flagging for human sign-off because of one outstanding behavior change and the breadth of the PR:

  • 💬 save_dtype default — author replied: "Partial fix landed in 1f1c250 (dataclass default set to None). Follow-up fix staged locally… Will be included in the next push." As of the latest diff, ModelOptTrainerArguments.save_dtype still defaults to "bfloat16" (modelopt/torch/opt/plugins/transformers.py L244), so float16 models that go through ModelOptHFTrainer.save_model will silently have their config.json dtype rewritten to bfloat16. The promised follow-up to default to None and preserve the original dtype hasn't landed yet — please confirm before merge.
  • Architectural surface: this PR introduces ModelOptTrainerArguments (new public API), per-parameter LR YAML config, _forward_redirect, model-agnostic Liger fused-CE/JSD loss with DDP/ZeRO-3 support, and removes mtd.convert() from the HF KD path. The PR body justifies the design choices (model-agnostic vs HF's fixed-architecture Liger, ZeRO-3 vs FSDP-only, lazy teacher loading bug fix), and the lr_config is a thin pattern→kwargs map rather than a new DSL — design looks reasonable, but the public-API breadth and the KDTrainer semantics shift (no more DistillationModel, no intermediate-layer distillation on the HF path) warrant human eyes. GPU example tests (tests/examples/llm_qat/, tests/examples/llm_distill/) are still unchecked in the PR's own test plan.
  • Minor: _record_eval_kd_loss weights each batch's KD loss by find_batch_size(inputs) rather than non-ignore token count, so the reported eval_kd_loss is a batch-weighted mean that can drift from a true token-weighted mean when sequence padding varies — likely fine for monitoring but worth a comment.
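
The drift described in the last bullet can be illustrated with a small worked example. This is a sketch under stated assumptions, not the trainer's code: it assumes each batch reports a mean KD loss over its non-ignored tokens, and the numbers are made up for illustration.

```python
def weighted_mean(values, weights):
    """Mean of `values` where each entry counts `weights[i]` times."""
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


# Two eval batches with the same batch size but different padding.
batch_losses = [1.0, 3.0]   # per-batch mean KD loss (illustrative values)
batch_sizes = [4, 4]        # what a batch-size-based weighting would use
token_counts = [400, 100]   # non-ignored tokens in each batch

batch_weighted = weighted_mean(batch_losses, batch_sizes)
token_weighted = weighted_mean(batch_losses, token_counts)
```

Here the batch-weighted mean is 2.0 while the token-weighted mean is 1.4, because the heavily padded second batch counts as much as the first under batch-size weighting. When every batch holds a similar number of real tokens the two means coincide.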

@realAsma

realAsma commented Jun 4, 2026

/claude review


@cjluo-nv cjluo-nv left a comment


Bot review — DM the bot to share feedback.

Re-review: all prior critical correctness comments are resolved in this revision. These are: recipe.ptq_cfg → recipe.quantize (both simple_qat_train.py and transformers_trainer.py); LMLogitsLoss double-sum (class removed; KD now calls LogitsDistillationLoss directly); causal-shift inconsistency between standard/Liger KD (_standard_kd_loss now shifts to [..., :-1, :]); _forward_redirect failure cleanup (try/except restores module.forward and re-raises); CE validation loss (compute_loss runs CE in eval and records KD as eval_kd_loss secondary metric); unit tests added for KD and lr_config. Two prior LGTMs (mxinO, shengliangxu).
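
The `_forward_redirect` cleanup mentioned above follows a common pattern: temporarily replace `module.forward`, call the module, and always restore the original even if the call fails. The sketch below uses `try/finally` as a variant of the try/except-and-re-raise approach the review describes. `TinyModule` and `call_with_redirected_forward` are hypothetical stand-ins, not the PR's implementation, and no `torch` is involved.

```python
class TinyModule:
    """Stand-in for a wrapped model; calling it dispatches to forward()."""

    def forward(self, x):
        return x + 1

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def call_with_redirected_forward(module, replacement, *args, **kwargs):
    """Call `module` with its forward temporarily set to `replacement`."""
    original = module.forward
    module.forward = replacement
    try:
        return module(*args, **kwargs)
    finally:
        # Runs on success and on error, so the module is never left patched.
        module.forward = original


module = TinyModule()
redirected = call_with_redirected_forward(module, lambda x: x * 10, 3)


def failing_forward(x):
    raise RuntimeError("simulated failure inside the redirected call")


try:
    call_with_redirected_forward(module, failing_forward, 3)
    error_propagated = False
except RuntimeError:
    error_propagated = True

restored = module(3)  # uses the original forward again
```

Without the restore step, an exception inside the redirected call would leave the module permanently pointing at the replacement, which is the failure mode the earlier review flagged.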

Flagging for human sign-off because of architectural breadth and unchecked example tests:

  • 💬 save_dtype default — author replied: "Partial fix landed in 1f1c250; follow-up will be in next push." Now fully resolved: the save_dtype dataclass field is gone entirely, and ModelOptHFTrainer.__init__ derives _original_dtype from model.config.dtype / torch_dtype; _update_config_json_dtype early-returns on None, so float16 models are no longer silently rewritten to bfloat16.
  • Architectural surface: this PR introduces ModelOptTrainerArguments (new public API), per-parameter LR YAML config, _forward_redirect, and model-agnostic Liger fused-CE/JSD loss with DDP/ZeRO-3 support, and removes mtd.convert() from the HF KD path (DistillationModel/intermediate-layer distillation no longer available on HF path). The PR body justifies the design choices (model-agnostic vs HF's fixed-architecture Liger, ZeRO-3 vs FSDP-only, lazy teacher loading bug fix), and lr_config is a thin pattern→kwargs map rather than a new DSL — design looks reasonable, but the public-API breadth and the KDTrainer semantic shift warrant human eyes.
  • GPU example tests (tests/examples/llm_qat/, tests/examples/llm_distill/) remain unchecked in the PR's own test plan checkboxes — please confirm before merge.
  • Minor: _record_eval_kd_loss weights each batch's KD loss by find_batch_size(inputs) rather than non-ignore token count, so reported eval_kd_loss is batch-weighted and can drift from a token-weighted mean when sequence padding varies — likely fine for monitoring but worth a comment.
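
The resolved `save_dtype` behavior in the first bullet (early return on `None`, otherwise write the original dtype back) can be sketched with the standard library alone. This is an illustration under assumptions, not the code in `modelopt/torch/opt/plugins/transformers.py`: the function name, its return value, and the choice of `"dtype"` versus `"torch_dtype"` as the JSON key are all hypothetical.

```python
import json
import os
import tempfile


def update_config_json_dtype(config_path, original_dtype):
    """Write `original_dtype` into config.json; do nothing when it is None."""
    if original_dtype is None:
        return False  # no recorded dtype, so leave the file untouched
    with open(config_path) as f:
        config = json.load(f)
    # Assumption: the dtype is stored under "dtype" or the older "torch_dtype".
    key = "dtype" if "dtype" in config else "torch_dtype"
    config[key] = original_dtype
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    return True


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    with open(path, "w") as f:
        json.dump({"model_type": "demo", "torch_dtype": "bfloat16"}, f)

    changed_none = update_config_json_dtype(path, None)
    with open(path) as f:
        after_none = json.load(f)["torch_dtype"]

    changed_fp16 = update_config_json_dtype(path, "float16")
    with open(path) as f:
        after_fp16 = json.load(f)["torch_dtype"]
```

With `None` the file keeps whatever dtype it had, and with a recorded dtype such as `"float16"` that value is restored, so a float16 model is not relabeled as bfloat16 on save.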

realAsma added 3 commits June 5, 2026 14:36
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
@realAsma realAsma merged commit 433b549 into main Jun 5, 2026
52 checks passed
@realAsma realAsma deleted the asma/new-qat-2 branch June 5, 2026 15:47
8 participants