[Discrete Diffusion] Add LLaDA2 pipeline#13226
Conversation
…usion

Add support for LLaDA2/LLaDA2.1 discrete diffusion text generation:
- BlockRefinementPipeline: block-wise iterative refinement with confidence-based token commitment, supporting an editing threshold for LLaDA2.1 models
- LLaDA2Pipeline: convenience wrapper with LLaDA2-specific defaults
- DiscreteDiffusionPipelineMixin: shared SAR sampling utilities (top-k, top-p, temperature) and prompt/prefix helpers
- compute_confidence_aware_loss: CAP-style training loss
- Examples: sampling scripts for LLaDA2 and block refinement, training scripts with a Qwen causal LM
- Docs and tests included
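The confidence-based token commitment at the heart of block refinement can be sketched roughly as follows. This is an illustrative standalone function, not the PR's actual API; the function and variable names are made up for the example:

```python
import torch

def commit_confident_tokens(logits, tokens, still_masked, num_to_commit):
    """Fill the most confident still-masked positions with the model's
    top prediction; leave the rest masked for later refinement steps.

    logits:       (seq_len, vocab_size) model outputs for one sequence
    tokens:       (seq_len,) current token ids
    still_masked: (seq_len,) bool, True where the position is still masked
    """
    probs = logits.softmax(dim=-1)
    confidence, candidates = probs.max(dim=-1)                # best guess per position
    confidence = confidence.masked_fill(~still_masked, -1.0)  # never re-commit filled slots
    _, idx = confidence.topk(num_to_commit)                   # most confident masked slots
    tokens = tokens.clone()
    tokens[idx] = candidates[idx]
    still_masked = still_masked.clone()
    still_masked[idx] = False
    return tokens, still_masked
```

Each refinement step would call this with the step's token budget, shrinking the masked set until the block is fully committed.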
Extract the confidence-based token commit logic from BlockRefinementPipeline into a dedicated BlockRefinementScheduler, following diffusers conventions. The scheduler owns:
- Transfer schedule computation (get_num_transfer_tokens)
- Timestep management (set_timesteps)
- Step logic: confidence-based mask-filling and optional token editing

The pipeline now delegates scheduling to self.scheduler.step() and accepts a scheduler parameter in __init__.
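As a rough sketch of what the transfer-schedule computation might look like: only the function name comes from the PR; the even-split policy shown here is an assumption.

```python
def get_num_transfer_tokens(num_masked: int, steps: int) -> list[int]:
    """Split num_masked token commits across steps so every step commits
    an (almost) equal share and the counts sum exactly to num_masked."""
    base, remainder = divmod(num_masked, steps)
    # Give the earliest steps one extra token each until the remainder is used up.
    return [base + (1 if i < remainder else 0) for i in range(steps)]
```

For example, get_num_transfer_tokens(32, 5) returns [7, 7, 6, 6, 6], and with steps equal to the block length every step commits exactly one token.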
12 tests covering set_timesteps, get_num_transfer_tokens, step logic (confidence-based commits, threshold behavior, editing, prompt masking, batched inputs, tuple output).
- Add BlockRefinement and LLaDA2 to docs sidebar navigation - Add BlockRefinementScheduler to schedulers sidebar navigation - Move scheduler autodoc to its own page under api/schedulers/
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
- Add --revision argument for loading model revisions from the Hub - Replace deprecated torch_dtype with dtype for transformers 5.x compat
LLaDA2 models expect a boolean-style (1/0) attention mask, not an additive (0/-inf) mask. The model internally converts to additive, so passing 0/-inf caused double-masking and gibberish output.
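A minimal illustration of the two mask conventions. The shapes and dtype mirror the snippet in this PR; the conversion line only shows what the model does internally, it is not something callers should do:

```python
import torch

batch_size, total_length = 1, 8

# What LLaDA2 expects: a boolean-style mask, 1 = attend, 0 = ignore.
bool_mask = torch.ones((batch_size, total_length), dtype=torch.long)

# What the model builds from it internally: an additive 0/-inf mask.
# Passing this form directly would get converted again -> double-masking.
additive_mask = (1 - bool_mask).to(torch.float32) * torch.finfo(torch.float32).min
```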
…ment.py

- Remove toy train_block_refinement_cap.py (self-contained demo with a tiny model)
- Rename train_block_refinement_qwen_cap.py to train_block_refinement.py (it already works with any causal LM via AutoModelForCausalLM)
- Fix the torch_dtype deprecation and update the README with the correct script names
- Add usage examples with real model IDs and working code - Add recommended parameters table for LLaDA2.1 quality/speed modes - Note that editing is LLaDA2.1-only (not for LLaDA2.0 models) - Remove misleading config defaults section from BlockRefinement docs
- threshold: 0.95 -> 0.7 (quality mode)
- max_post_steps: 0 -> 16 (recommended for LLaDA2.1, harmless for 2.0)
- eos_early_stop: False -> True (stop at the EOS token)

block_length=32, steps=32, and temperature=0.0 were already correct. editing_threshold remains None (users enable it for LLaDA2.1 models).
LLaDA2.1 is the current generation. Users with LLaDA2.0 models can disable editing by passing editing_threshold=None.
- top_p filtering: add shift-right to preserve at least one token above threshold (matches official code line 1210) - temperature ordering: apply scaling before top-k/top-p filtering so filtering operates on scaled logits (matches official code lines 1232-1235) - greedy branch: return argmax directly when temperature=0 without filtering (matches official code lines 1226-1230)
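The fixed ordering can be sketched as follows. This is a standalone re-implementation for illustration, not the PR's exact code:

```python
import torch

def sample_token(logits, temperature=1.0, top_p=None):
    # Greedy branch: temperature=0 returns argmax directly, no filtering.
    if temperature == 0:
        return logits.argmax(dim=-1)
    # Scale BEFORE filtering so top-p operates on the scaled distribution.
    logits = logits / temperature
    if top_p is not None:
        sorted_logits, sorted_idx = logits.sort(descending=True, dim=-1)
        cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        remove = cumulative > top_p
        # Shift right so at least one token always survives the filter.
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        # Map the removal mask back from sorted order to vocabulary order.
        remove = torch.zeros_like(remove).scatter(-1, sorted_idx, remove)
        logits = logits.masked_fill(remove, float("-inf"))
    return torch.multinomial(logits.softmax(dim=-1), num_samples=1).squeeze(-1)
```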
…put_ids LLaDA2Pipeline._prepare_prompt_ids was a near-copy of DiscreteDiffusionPipelineMixin._prepare_input_ids. Remove the duplicate and call the mixin method directly. Also simplify _extract_input_ids since we always pass return_dict=True.
…ings - Update EXAMPLE_DOC_STRING to use dtype= and LLaDA2.1-mini model ID - Fix sample_block_refinement.py to use dtype=
yiyixuxu
left a comment
Thanks for working on this!
My main concern is that we might be over-generalizing it a bit: currently there is a 3-layer hierarchy (DiscreteDiffusionPipelineMixin -> BlockRefinementPipeline -> LLaDA2Pipeline). I think that is at least one layer too many for diffusers style, i.e., BlockRefinementPipeline should be combined with LLaDA2Pipeline. We do use a Mixin class to abstract away common pipeline methods, but since this is the first model we support, I would also prefer to wait on that.
Does it make sense to flatten this into a single LLaDA2Pipeline(DiffusionPipeline) for now? We can look into refactoring later when more discrete diffusion models arrive.
We can also support it in modular pipelines instead; that way you can avoid duplicated code and it works better with remote code. We have a skill coming soon that can help with that. But for this PR, I think we can just merge one first?
Thanks @yiyixuxu for the suggestion! Fixed.
yiyixuxu
left a comment
thanks!
i left some feedback!
self.fusing_vae = False

class DiscreteDiffusionPipelineMixin:
Can we remove this for now and move the methods to LLaDA2Pipeline?
Since right now there is only one pipeline, we don't really need a Mixin. We can extract a mixin later when we have more pipelines and see a common pattern - the refactor is easy and won't break anything.
def num_timesteps(self):
    return self._num_timesteps

def _model_forward_logits(
this is only used once - can you remove the method and put the code directly inside __call__?
I think we should remove try/except and maybe the attention_mask_mode argument altogether - the type of attention mask is determined by the model, no?
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
# 2D attention mask (no padding) — the model handles backend-specific conversion internally.
attn_mask = torch.ones((batch_size, total_length), device=device, dtype=torch.long)
If I try to run the example script from the docs, I get the following error from the LLaDA-2.1 MoE modeling code:
Traceback (most recent call last):
...
File "~/.cache/modules/transformers_modules/inclusionAI/LLaDA2_dot_1_hyphen_mini/f21be037104f6e044e1a86b6d8864a6b85cc868e/modeling_llada2_moe.py", line 874, in forward
raise ValueError(
ValueError: LLaDA2.0 only support block attention mask with shape: (1, 1, 32, 32), the input attention with shape attention_mask.size()=torch.Size([1, 32])!
From the error it seems like the original modeling code will only accept 4D (rather than 2D) attention masks?
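If the modeling code really does require a 4D mask, one possible workaround is to expand the 2D mask before calling the model. This is purely a guess at the expected layout based on the error message's shape, and should be checked against the modeling code before relying on it:

```python
import torch

def expand_to_4d(attn_mask_2d):
    """Expand a (batch, seq_len) boolean mask to (batch, 1, seq_len, seq_len),
    the shape the error message asks for. Each query row gets the same
    key-visibility pattern; no causal or block structure is imposed here."""
    bsz, seq_len = attn_mask_2d.shape
    return attn_mask_2d[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
```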
dg845
left a comment
I think the PR is looking good modulo the following issues:
- I think the LLaDA-2.X transformer should have a diffusers-native implementation. In particular, users should not need to set trust_remote_code=True to run the pipeline. @yiyixuxu, could you give a second opinion on this?
- The attention mask shape error in #13226 (comment).
|
@dg845 converting the models to diffusers would be quite a lot more work, and all the d-LLMs currently use transformers modeling files for their backbone, so yeah, I wanted to first judge how useful this will be in diffusers and then see about moving forward. Regarding the attention shape error: their modeling files are still stuck on some transformers 4 APIs, and I have fixed that for them, so in the meantime, test with
What does this PR do?

Add support for LLaDA2/LLaDA2.1 discrete diffusion text generation.

Fixes # (issue)
Before submitting
documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.