Fix the QwenImage Attention mask under Ulysses SP#13756
Conversation
|
@sayakpaul, I can make this fix more generic by addressing Here’s my suggestion:
However, this approach requires handling masks in attention layer locally (and with extra communication cost) and may revert some of the performance improvements introduced in #12702. So do you have any suggestions? |
|
Could you explain a bit why the actual mask is Position: 0 1 2 3 | 4 5 | 6 7 8 9 | 10 11
Mask: 1 1 0 0 | 1 1 | 0 0 0 0 | 1 1 |
since the original input is (T: token for text; I: Token for image) after all-to-all so the mask should be in the correct position matching the content, avoid the wrong attention of KV |
|
I need @DN6's thoughts on this one. Dhruv, should we tackle it from the forward pass of the model or delegate to the Ulysses utilities we have? Personally, I would prefer doing it from the Ulysses utilities as it would help to keep the forward clean. |
|
@claude-2-serge could you do a review? |
There was a problem hiding this comment.
🤗 Serge says:
The fix correctly identifies that under Ulysses sequence parallelism, the attention mask must be interleaved to match the post-all-to-all token layout. The approach of chunking text and image masks separately and interleaving them is sound.
Potential Issue
_parallel_confignot declared on the transformer model class:getattr(self, "_parallel_config", None)at line 919 referencesselfwhich is theQwenImageTransformer2DModelinstance. However,_parallel_config = Noneis only declared onQwenDoubleStreamAttnProcessor2_0(line 480), not on the transformer model class itself. Other transformer models (e.g.,transformer_flux.py) declare_parallel_configon both the processor and the model class. If the parallel setup code sets_parallel_configon the model viasetattr, this works — but it would be more robust and consistent to also declare_parallel_config = Noneas a class attribute onQwenImageTransformer2DModel.
Style
- Trailing whitespace on line 926 (
else:).
Tests
- The PR description references issue #13696 and claims the test should pass, but no new tests are added in this diff. It would be good to confirm the existing test coverage is sufficient.
16 LLM turns · 20 tool calls · 92.9s · 217527 in / 3766 out tokens
Yeah, it would be better to handle via the CP utilities and keep the forward as clean as possible. |
|
Hi @sayakpaul and @DN6 , I've updated the PR using local mask handling. This is what happens now: Inputs (T for text token, I for image token): After CP shard (assume 2 ranks) After QKV All-to-all and mask all-gather The KV and mask should have the one-to-one correspondence now. |
|
@claude-2-serge could you do a review? |
There was a problem hiding this comment.
🤗 Serge says:
This PR correctly fixes the attention mask handling under Ulysses sequence parallelism for QwenImage by moving joint mask construction from the transformer's forward (where it was built once from unsplit tensors) into the per-block attention processor (where it's built from already-sharded encoder_hidden_states_mask and hidden_states). The attention dispatch layers then all-gather the local mask to match the post-all-to-all QKV layout.
Correctness
- The core fix is sound: under CP,
encoder_hidden_states_maskis now split via the_cp_plan(dim=1), passed into each block, and the processor builds the joint mask from the sharded text mask + sharded image ones. The all-gather inTemplatedUlyssesAttentionandTemplatedUlyssesAnythingAttentionthen reconstructs the global mask in the correct rank-concatenated order matching the QKV layout. - The "anything" path correctly handles unequal local sizes with pad → all-gather → trim.
- The guard
attn_mask.shape[-1] == S_KV_LOCALcorrectly identifies masks that need gathering (local masks) vs. masks that are already global. F.padon a bool mask defaults to padding withFalse, which is correct (padded positions should be masked out before trimming).
Minor issues
- The type annotation
attention_mask: None = Noneis technically valid but unconventional and could confuse tooling/users. A more standard approach would be to keep the original type hint and rely solely on the runtimeValueError. - The mask is reconstructed from scratch in every block (cat + unsqueeze), which adds minor overhead compared to the previous approach of building it once. This is the necessary trade-off for correctness under CP, but worth noting.
Suggestions / additional info (dead code trace)
Under the default pipeline call path, the encoder_hidden_states_mask flows correctly from the transformer forward → block → processor → dispatch_attention_fn. The attention_mask parameter on the processor is now effectively dead (always None from external callers, raises if not), which is the intended design — the processor owns mask construction.
23 LLM turns · 26 tool calls · 150.3s · 485476 in / 5617 out tokens
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Done. Added an accuracy test under |
|
|
||
| # Construct joint attention mask once to avoid reconstructing in every block | ||
| # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility | ||
| block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} |
There was a problem hiding this comment.
cc @kashif here since this would revert the optimization as part of this PR #12702
I tried to go over #12702, but was not able to find much detail about this optimization. I would love to understand more about the cause of the sync and performance delta, because the pre-built joint mask does not shard correctly under CP
There was a problem hiding this comment.
thanks @yiyixuxu! the #12702 bit was just me building the joint mask once instead of per-block. "eliminates 60 GPU syncs" was a bad comment on my part, i checked and it's actually 0 syncs, just a plain cat/ones. cost of dropping it is ~0.85ms/fwd eager and basically 0 with compile, so no real loss.
and yeah it has to go for CP anyway: the pre-built mask is over the full unsharded seq so it can't line up after the all-to-all. confirmed with the #13696 repro, main is off by 2.9e-2 and this PR is exactly 0.0. lgtm from me 👍
|
Hi @sayakpaul @yiyixuxu , may I know if there is any concern for this PR? |
|
We're waiting for @kashif for #13756 (comment) |
|
@zhtmike one heads up: the new correctness test actually passes on main as-is, so it won't catch the regression. if you bump it to batch_size=2 with per-sample padding like in #13696 (e.g. |
|
Hi @kashif , thank you for your valuable comments! In my local testing environment, I just used the latest main with two files cherry-picked from the current PR.
and ran the following test: pytest -rxXs tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerContextParallelAnd I got the following error ============================= test session starts ==============================
platform linux -- Python 3.12.12, pytest-9.0.2, pluggy-1.6.0
rootdir: /scratch/fq9hpsac/mikecheung/gitlocal/diffusers
configfile: pyproject.toml
plugins: timeout-2.4.0, xdist-3.8.0, anyio-4.12.1, hydra-core-1.3.2, requests-mock-1.10.0
collected 12 items
tests/models/transformers/test_models_transformer_qwenimage.py .s.s.s.ss [ 75%]
.Fs [100%]
=================================== FAILURES ===================================
_ TestQwenImageTransformerContextParallel.test_context_parallel_output_correctness[ulysses] _
self = <tests.models.transformers.test_models_transformer_qwenimage.TestQwenImageTransformerContextParallel object at 0x1550c59ab710>
cp_type = 'ulysses_degree', batch_size = 1
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_output_correctness(self, cp_type, batch_size: int = 1):
"""Verify that CP output is numerically identical to a single-GPU reference forward pass."""
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
if cp_type == "ring_degree":
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
if active_backend == AttentionBackendName.NATIVE:
pytest.skip("Ring attention is not supported with the native attention backend.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
# Single-GPU reference
model = self.model_class(**init_dict).eval().to(torch_device)
state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
with torch.no_grad():
ref_output = model(**inputs_dict, return_dict=False)[0].cpu()
# Context-parallel run with the same weights
inputs_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
cp_dict = {cp_type: world_size}
master_port = _find_free_port()
manager = mp.Manager()
return_dict = manager.dict()
mp.spawn(
_context_parallel_correctness_worker,
args=(world_size, master_port, self.model_class, init_dict, state_dict, cp_dict, inputs_cpu, return_dict),
nprocs=world_size,
join=True,
)
assert return_dict.get("status") == "success", (
f"Context parallel correctness check failed: {return_dict.get('error', 'Unknown error')}"
)
cp_output = torch.tensor(return_dict["output"])
> torch.testing.assert_close(ref_output, cp_output, atol=1e-4, rtol=1e-4)
E AssertionError: Tensor-likes are not close!
E
E Mismatched elements: 251 / 256 (98.0%)
E Greatest absolute difference: 0.0255483016371727 at index (0, 9, 3) (up to 0.0001 allowed)
E Greatest relative difference: 1.582349419593811 at index (0, 11, 3) (up to 0.0001 allowed)
tests/models/testing_utils/parallelism.py:461: AssertionErrorSo, in my testing environment, the newly added test correctly guards against the error in #13696. And I think the error should be irrelevant to the batch size, so I prefer to keep it simple with batch_size = 1. Can you please take a look to see if I have missed something? |
|
no you are right! @zhtmike I was not in the main branch 🙈 |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Kindly ping @sayakpaul |
|
We should fix the tests |
|
The two failed tests are related to |
|
Updated. The error is due to the input check of The |
sayakpaul
left a comment
There was a problem hiding this comment.
Thanks for your patience and immense amount of hardwork!
| ) | ||
|
|
||
| @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) | ||
| def test_context_parallel_output_correctness(self, cp_type, batch_size: int = 1): |
There was a problem hiding this comment.
Maybe I didn't make myself clear but why can't we just correctness validation to
There was a problem hiding this comment.
Done! Originally, I thought a two-level approach was more suitable, since an accurate test is a stricter guard than merely runnable. Now it has been merged into a single test.
|
do we want to add the same all CP modes, including ring? |
I think a precision test for all CP is still necessary? |
What does this PR do?
This fixes the issue #13696 . The test should be passed after this PR.
This the problem I found: The mask does not have a one-to-one correspondence with the content.
For QwenImage Pipeline, use the following example
After CP shard (assume 2 ranks)
After All-to-all
But the mask is not handled correctly
This PR makes mask correctly assigned
Fixes # (issue)
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@sayakpaul
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.