-
Notifications
You must be signed in to change notification settings - Fork 7k
Fix the QwenImage Attention mask under Ulysses SP #13756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
84c545f
1d391b5
9346da0
026327f
52bd940
03017a6
7d21597
fdc436f
cb4baf4
9569400
840a605
2ab47d3
6c45dc4
b2816d3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -497,6 +497,18 @@ def __call__( | |
| if encoder_hidden_states is None: | ||
| raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") | ||
|
|
||
| if attention_mask is not None: | ||
| raise ValueError( | ||
| "QwenDoubleStreamAttnProcessor2_0 does not accept an external attention_mask. " | ||
| "Pass encoder_hidden_states_mask to let the processor build the joint mask." | ||
| ) | ||
|
|
||
| if encoder_hidden_states_mask is not None: | ||
| seq_img = hidden_states.shape[1] | ||
| image_mask = torch.ones((hidden_states.shape[0], seq_img), dtype=torch.bool, device=hidden_states.device) | ||
| attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) | ||
| attention_mask = attention_mask[:, None, None, :] | ||
|
|
||
|
Comment on lines
+500
to
+511
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should the controlnet also have similar changes or not because it doesn't define
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the controlnet model is calling
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you elaborate that?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The controlnet model itself does not have the CP plan. However, because we have modified if attention_mask is not None:
raise ValueError(
"QwenDoubleStreamAttnProcessor2_0 does not accept an external attention_mask. "
"Pass encoder_hidden_states_mask to let the processor build the joint mask."
)So we either need to:
In the future, we may need to add support for the SP implementation for controlnet. In that case, option 1 may be a better solution, since it will avoid such mask bugs more easily and is also consistent with the style of the Qwen‑Image transformer. So personally, I prefer option 1. |
||
| seq_txt = encoder_hidden_states.shape[1] | ||
|
|
||
| # Compute QKV for image stream (sample projections) | ||
|
|
@@ -770,6 +782,7 @@ class QwenImageTransformer2DModel( | |
| }, | ||
| "transformer_blocks.*": { | ||
| "modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), | ||
| "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), | ||
| }, | ||
| "pos_embed": { | ||
| 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), | ||
|
|
@@ -911,38 +924,27 @@ def forward( | |
|
|
||
| image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) | ||
|
|
||
| # Construct joint attention mask once to avoid reconstructing in every block | ||
| # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility | ||
| block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc @kashif here since this would revert the optimization as part of this PR #12702 I tried to go over #12702, but was not able to find much detail about this optimization. I would love to understand more about the cause of the sync and performance delta, because the pre-built joint mask does not shard correctly under CP
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks @yiyixuxu! the #12702 bit was just me building the joint mask once instead of per-block. "eliminates 60 GPU syncs" was a bad comment on my part, i checked and it's actually 0 syncs, just a plain cat/ones. cost of dropping it is ~0.85ms/fwd eager and basically 0 with compile, so no real loss. and yeah it has to go for CP anyway: the pre-built mask is over the full unsharded seq so it can't line up after the all-to-all. confirmed with the #13696 repro, main is off by 2.9e-2 and this PR is exactly 0.0. lgtm from me 👍 |
||
| if encoder_hidden_states_mask is not None: | ||
| # Build joint mask: [text_mask, all_ones_for_image] | ||
| batch_size, image_seq_len = hidden_states.shape[:2] | ||
| image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) | ||
| joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) | ||
| joint_attention_mask = joint_attention_mask[:, None, None, :] | ||
| block_attention_kwargs["attention_mask"] = joint_attention_mask | ||
|
|
||
| for index_block, block in enumerate(self.transformer_blocks): | ||
| if torch.is_grad_enabled() and self.gradient_checkpointing: | ||
| encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( | ||
| block, | ||
| hidden_states, | ||
| encoder_hidden_states, | ||
| None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) | ||
| encoder_hidden_states_mask, | ||
| temb, | ||
| image_rotary_emb, | ||
| block_attention_kwargs, | ||
| attention_kwargs, | ||
| modulate_index, | ||
| ) | ||
|
|
||
| else: | ||
| encoder_hidden_states, hidden_states = block( | ||
| hidden_states=hidden_states, | ||
| encoder_hidden_states=encoder_hidden_states, | ||
| encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) | ||
| encoder_hidden_states_mask=encoder_hidden_states_mask, | ||
| temb=temb, | ||
| image_rotary_emb=image_rotary_emb, | ||
| joint_attention_kwargs=block_attention_kwargs, | ||
| joint_attention_kwargs=attention_kwargs, | ||
| modulate_index=modulate_index, | ||
| ) | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.