Skip to content

fix(eagle3): flush trailing partial grad-accum window each epoch#2257

Merged
HuiyingLi merged 2 commits into
NVIDIA-NeMo:mainfrom
khazic:fix/eagle3-flush-trailing-grad-accum
May 19, 2026
Merged

fix(eagle3): flush trailing partial grad-accum window each epoch#2257
HuiyingLi merged 2 commits into
NVIDIA-NeMo:mainfrom
khazic:fix/eagle3-flush-trailing-grad-accum

Conversation

@khazic
Copy link
Copy Markdown
Contributor

@khazic khazic commented May 18, 2026

Summary

When num_batches_per_epoch is not a multiple of grad_accumulation_steps, the EAGLE-3 recipe's training loop silently dropped the trailing micro-batches: their gradients reached backward() but never an optimizer.step() — the next epoch's zero_grad wiped them. Up to grad_accumulation_steps - 1 micro-batches per epoch were wasted, and the LR scheduler (sized via floor division) was off-by-N over a full run.

Two coordinated fixes in recipes/llm/train_eagle3.py:

  1. Ceil-divide optimizer steps via new _optim_steps_per_epoch helper so the LR scheduler covers the trailing flush and progress does not saturate at min_lr_ratio prematurely.
  2. Flush the partial window at the end of each epoch: rescale its gradients by grad_accumulation_steps / pending_micro_batches (every micro-batch had divided its loss by the full accumulation count, expecting a full window) and run one final clip_grad_norm_ / optimizer.step / lr_scheduler.step / zero_grad so the trailing step's update magnitude matches every other step.

Test plan

  • Added tests/unit_tests/recipes/llm/test_train_eagle3_grad_accum.py exercising _optim_steps_per_epoch on divisible / non-divisible / degenerate inputs.
  • Local: pytest tests/unit_tests/recipes/llm/test_train_eagle3_grad_accum.py tests/unit_tests/speculative/ — 32 passed, 2 skipped (FA2 CUDA path).

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 18, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi
Copy link
Copy Markdown
Contributor

/ok to test 62f21a7

When ``num_batches_per_epoch`` is not a multiple of
``grad_accumulation_steps``, the EAGLE-3 recipe's training loop
silently dropped the trailing micro-batches: their gradients reached
``backward()`` but never an ``optimizer.step()`` -- the next epoch's
``zero_grad`` wiped them. Up to ``grad_accumulation_steps - 1``
micro-batches per epoch were wasted, and the LR scheduler (sized via
floor division) was off-by-N at the end of training.

Two coordinated fixes in ``recipes/llm/train_eagle3.py``:

1. Compute optimizer steps with ceil division (new
   ``_optim_steps_per_epoch`` helper) so the LR scheduler covers the
   trailing flush and ``progress`` does not saturate at
   ``min_lr_ratio`` prematurely.

2. After the inner per-batch loop exits, if there is a partially-filled
   accumulation window, rescale its gradients by
   ``grad_accumulation_steps / pending_micro_batches`` (each
   micro-batch had divided its loss by the full accumulation count
   anticipating a full window) and run one final
   ``clip_grad_norm_`` / ``optimizer.step`` / ``lr_scheduler.step`` /
   ``zero_grad`` so the magnitude is comparable to a normal step.

Tests:
- ``tests/unit_tests/recipes/llm/test_train_eagle3_grad_accum.py``
  exercises ``_optim_steps_per_epoch`` on divisible / non-divisible /
  degenerate inputs.

Signed-off-by: khazic <khazzz1c@gmail.com>
@khazic khazic force-pushed the fix/eagle3-flush-trailing-grad-accum branch from 62f21a7 to 70d2932 Compare May 19, 2026 05:36
@HuiyingLi
Copy link
Copy Markdown
Contributor

/ok to test 70d2932

…nternals

Signed-off-by: khazic <khazzz1c@gmail.com>
@HuiyingLi
Copy link
Copy Markdown
Contributor

/ok to test 9b5e95d

@HuiyingLi HuiyingLi merged commit 9fba3ae into NVIDIA-NeMo:main May 19, 2026
65 checks passed
@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the waiting-on-customer Waiting on the original author to respond label May 19, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants