Skip to content

feat(dllm): add DFlash and LLaDA2 SFT recipes#2214

Open
kashif wants to merge 24 commits into
NVIDIA-NeMo:mainfrom
kashif:add-dllm-pipelines
Open

feat(dllm): add DFlash and LLaDA2 SFT recipes#2214
kashif wants to merge 24 commits into
NVIDIA-NeMo:mainfrom
kashif:add-dllm-pipelines

Conversation

@kashif
Copy link
Copy Markdown

@kashif kashif commented May 12, 2026

What does this PR do ?

add DFlash and LLaDA2 SFT recipes.

Changelog

  • Add specific line by line info of high level changes in this PR.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

If you haven't finished some of the above items you can still open "Draft" PR.

Additional Information

  • Related to # (issue)

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 12, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

kashif and others added 3 commits May 12, 2026 16:17
- Split dflash YAML section: add required dllm: section alongside dflash:
  (DiffusionLMSFTRecipe.setup() reads dllm:, DFlashSFTRecipe reads dflash:)
- Auto-resolve mask_token_id by adding <|MASK|> special token to tokenizer
  when neither YAML nor base tokenizer (Qwen3) defines one
- Smoke test: use allenai/tulu-3-sft-mixture[:64] (ChatDataset requires
  OpenAI-format messages; wikitext plain text didn't qualify)
- Add standalone GPU test script (no torchao dependency)

Smoke run: 2 steps, loss 20.4→28.3, grad_norm flowing, mem 12.65 GiB ✓

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- main(): remove type hints from signature, rename recipe → trainer
- _run_train_optim_step(): remove return type and batches type hints
- _run_validation_epoch(): remove return type hint
- log_train_metrics(): remove return type hint; switch % → .format();
  add tps_per_gpu and mode to log line to match parent format
- MetricsSample: add Train/mfu (None), Train/supervised_tokens,
  change hardcoded "dflash" → self.dllm_mode; allreduce num_predicted_tokens

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Move all DFlash-specific training logic into DFlashStrategy so that
DiffusionLMSFTRecipe handles all dLLM model families without subclassing.

Strategy additions (strategy.py):
- setup_extra(recipe): hook for loading auxiliary models; DFlashStrategy
  loads+freezes target LM and resolves mask_token_id for tokenizers that
  have none (e.g. Qwen3 → adds <|MASK|> special token)
- pre_step(recipe, batches) → (noise_tokens, supervised_tokens):
  MDLM does corruption loop; DFlash does target forwards + offloads to CPU
- forward_backward(recipe, idx, batch, ...): MDLM delegates to existing
  _forward_backward_step; DFlash implements anchor-block draft forward
- loss_log_key property: "Loss/Train_DLLM" / "Loss/Train_DFlash"
- _build_target_layer_ids, _sample_anchor_block, _run_target_forward
  moved from DFlashSFTRecipe into DFlashStrategy

Recipe changes (train_ft.py):
- setup(): defer mask_token_id raise to after setup_extra() call
- _run_train_optim_step(): replace inline corruption loop with
  strategy.pre_step(); dispatch via strategy.forward_backward()
- _run_validation_epoch(): same pre_step + forward_backward dispatch
- log_train_metrics(): use strategy.loss_log_key instead of hardcoded key

train_dflash.py reduced to a 6-line entry-point shim pointing at
DiffusionLMSFTRecipe; DFlashSFTRecipe subclass removed entirely.

Smoke tested: DFlash 2-step run passes (loss 20.4→28.3, grads flowing).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file is not required. We can reuse: https://github.com/NVIDIA-NeMo/Automodel/blob/main/examples/dllm_sft/finetune.py to launch dflash SFT. It will be good to maintain a single entry point

@pthombre
Copy link
Copy Markdown
Contributor

/claude review

Comment on lines +462 to +466
loss_result = self.dflash_loss_fn(
logits=logits,
target_ids=block_targets,
block_mask=block_mask,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: DFlashDecayLoss.forward is called without num_tokens, so the loss is a raw weighted sum (not an average). In contrast, the MDLM path passes num_diffusion_tokens to its loss function for global normalization (see train_ft.py:278), and the inlined version in scripts/test_dflash_sft_gpu.py normalizes by weights.sum().

Without normalization, the gradient magnitude scales with the number of valid block positions per microbatch, which will cause incorrect gradient scaling under gradient accumulation and varying batch sizes.

Should this pass num_tokens=num_diffusion_tokens?

Suggested change
loss_result = self.dflash_loss_fn(
logits=logits,
target_ids=block_targets,
block_mask=block_mask,
)
loss_result = self.dflash_loss_fn(
logits=logits,
target_ids=block_targets,
block_mask=block_mask,
num_tokens=num_diffusion_tokens,
)

Comment thread scripts/test_dflash_sft_gpu.py Outdated
Comment on lines +66 to +71
def main():
print(f"Loading tokenizer from {TARGET_ID} ...")
tok = AutoTokenizer.from_pretrained(TARGET_ID)
if tok.pad_token_id is None:
tok.pad_token = tok.eos_token

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This inlined DFlashDecayLoss normalizes by weights.sum().clamp_min(1e-8), but the real DFlashDecayLoss in components/loss/dflash_loss.py does not normalize when num_tokens is None — it returns a raw weighted sum. This means the test is validating different behavior than what the real training code produces.

Consider importing or exactly matching the real loss logic so the test actually exercises the production code path.

Comment thread scripts/test_dflash_sft_gpu.py Outdated
"""Minimal GPU smoke test for DFlashSFTRecipe core logic.

Tests the three novel pieces without the full distributed recipe stack:
1. Target forward + hidden-state extraction
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: docstring says "Tests the three novel pieces" but lists four items (1–4).

Suggested change
1. Target forward + hidden-state extraction
Tests the four novel pieces without the full distributed recipe stack:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file can be merged into nemo_automodel/components/loss/dllm_loss.py

kashif added 7 commits May 13, 2026 18:22
- Remove nemo_automodel/recipes/dllm/train_dflash.py; use the existing
  examples/dllm_sft/finetune.py as the single entry point for all dLLM
  SFT modes including DFlash (pthombre)
- Update dflash_sft.yaml: switch recipe key to DiffusionLMSFTRecipe,
  update usage comment to point at finetune.py
- Pass num_tokens=num_diffusion_tokens to DFlashDecayLoss in
  DFlashStrategy.forward_backward so the loss is properly normalised
  across DP replicas and gradient-accumulation steps (claude review)
- Fix test_dflash_sft_gpu.py: import and use the real DFlashDecayLoss
  instead of an inlined variant with different normalization; pass
  num_tokens so the test exercises the same code path as production;
  fix docstring ("three" -> "four" novel pieces) (claude review)
- Merge DFlashDecayLoss into dllm_loss.py; delete dflash_loss.py
- DFlashStrategy: add _sample_anchor_blocks (stars-and-bars sampling of N
  non-overlapping anchors) and _build_block_attention_mask (sparse 4D mask
  where block b attends only to its own causal prefix and own noise positions)
- DFlashDecayLoss: add block_size param so per-block decay resets at each
  block boundary for N>1 training
- pre_step dispatches to multi-block path when num_blocks_per_sample > 1;
  position_ids use actual sequence positions for correct RoPE
- dflash_sft.yaml: num_blocks_per_sample: 1 (explicit default)
- dflash_smoke.yaml: num_blocks_per_sample: 4; fix recipe name
@kashif kashif requested a review from jgerh as a code owner May 13, 2026 19:03
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-customer Waiting on the original author to respond label May 13, 2026
Copy link
Copy Markdown
Contributor

@jgerh jgerh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completed tech pubs review of docs/guides/dllm/finetune.md and provided a few copyedits.

Comment thread docs/guides/dllm/finetune.md Outdated
Comment thread docs/guides/dllm/finetune.md Outdated
Comment thread docs/guides/dllm/finetune.md Outdated
Comment thread docs/guides/dllm/finetune.md Outdated
Comment thread examples/dllm_sft/llada2_sft.yaml Outdated
# Set eps > 0 to ensure at least some corruption every step.
dllm:
mode: mdlm
mask_token_id: 126336
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like llada2 has a different mask token id

@pthombre
Copy link
Copy Markdown
Contributor

/ok to test 23e8c22

# Pre-step: anchor-block sampling + target forwards
# ------------------------------------------------------------------

def _sample_anchor_block(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't we need loss mask here?

kashif and others added 9 commits May 14, 2026 10:18
Co-authored-by: jgerh <163925524+jgerh@users.noreply.github.com>
Co-authored-by: jgerh <163925524+jgerh@users.noreply.github.com>
Co-authored-by: jgerh <163925524+jgerh@users.noreply.github.com>
Co-authored-by: jgerh <163925524+jgerh@users.noreply.github.com>
…model/mask token

- AND block_mask with loss_mask in _sample_anchor_block and _sample_anchor_blocks
  so prompt tokens are not supervised during DFlash SFT
- Pre_step now reads loss_mask from the batch and passes it through
- Fix llada2_sft.yaml: correct model to inclusionAI/LLaDA2.1-mini and
  update mask_token_id from 126336 to 156895 (<|mask|> in Qwen tokenizer)
- Update docs: note different mask_token_id values for LLaDA vs LLaDA2.1
- Add test_loss_mask_zeros_block_mask to cover the new loss_mask path
…draft init script

- Use `dtype=` instead of deprecated `torch_dtype=` in AutoModelForCausalLM.from_pretrained
- Add `trust_remote_code=True` to support Nemotron-H hybrid SSM target models
- Add scripts/create_nemotron_nano30b_dflash_draft.py to initialise a 7-layer
  DFlash draft (hidden=2688, 21 Q-heads, 3 KV-heads) for NVIDIA-Nemotron-3-Nano-30B-A3B
7-layer draft (hidden=2688, 21 Q-heads, 3 KV-heads, head_dim=128) conditioned
on frozen NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 target, trained on
Nemotron-Post-Training-Dataset-v2 chat+math+code mix for 3 epochs.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-request waiting-on-customer Waiting on the original author to respond

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants