Skip to content
Merged
21 changes: 21 additions & 0 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2200,6 +2200,12 @@ def forward(
query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))

if attn_mask is not None and attn_mask.shape[-1] == S_KV_LOCAL:
Comment thread
sayakpaul marked this conversation as resolved.
# All-gather a local mask so its layout matches the QKV layout after all-to-all.
mask_list = [torch.empty_like(attn_mask) for _ in range(world_size)]
dist.all_gather(mask_list, attn_mask, group=group)
attn_mask = torch.cat(mask_list, dim=-1)
Comment thread
sayakpaul marked this conversation as resolved.

out = forward_op(
ctx,
query,
Expand Down Expand Up @@ -2399,6 +2405,8 @@ def forward(
ctx.backward_op = backward_op
ctx._parallel_config = _parallel_config

_, S_KV_LOCAL, _, _ = key.shape

metadata = ulysses_anything_metadata(query)
query_wait = all_to_all_single_any_qkv_async(query, group, **metadata)
key_wait = all_to_all_single_any_qkv_async(key, group, **metadata)
Expand All @@ -2408,6 +2416,19 @@ def forward(
key = key_wait() # type: torch.Tensor
value = value_wait() # type: torch.Tensor

if attn_mask is not None and attn_mask.shape[-1] == S_KV_LOCAL:
# All-gather a local mask to match the post-all-to-all global sequence.
# The "anything" path allows unequal local sizes, so we pad to the
# maximum across ranks before all-gathering, then trim back.
mask_local_sizes = gather_size_by_comm(attn_mask.shape[-1], group)
max_local = max(mask_local_sizes)
if attn_mask.shape[-1] < max_local:
attn_mask = F.pad(attn_mask, (0, max_local - attn_mask.shape[-1]))
mask_list = [torch.empty_like(attn_mask) for _ in range(dist.get_world_size(group=group))]
dist.all_gather(mask_list, attn_mask, group=group)
attn_mask = torch.cat(mask_list, dim=-1)
attn_mask = attn_mask[..., : sum(mask_local_sizes)]

out = forward_op(
ctx,
query,
Expand Down
17 changes: 4 additions & 13 deletions src/diffusers/models/controlnets/controlnet_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,36 +205,27 @@ def forward(
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)

# Construct joint attention mask once to avoid reconstructing in every block
Comment thread
kashif marked this conversation as resolved.
block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
if encoder_hidden_states_mask is not None:
# Build joint mask: [text_mask, all_ones_for_image]
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
block_attention_kwargs["attention_mask"] = joint_attention_mask

block_samples = ()
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
encoder_hidden_states_mask,
temb,
image_rotary_emb,
block_attention_kwargs,
joint_attention_kwargs,
)

else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=block_attention_kwargs,
joint_attention_kwargs=joint_attention_kwargs,
)
block_samples = block_samples + (hidden_states,)

Expand Down
32 changes: 17 additions & 15 deletions src/diffusers/models/transformers/transformer_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,18 @@ def __call__(
if encoder_hidden_states is None:
raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

if attention_mask is not None:
raise ValueError(
"QwenDoubleStreamAttnProcessor2_0 does not accept an external attention_mask. "
"Pass encoder_hidden_states_mask to let the processor build the joint mask."
)

if encoder_hidden_states_mask is not None:
seq_img = hidden_states.shape[1]
image_mask = torch.ones((hidden_states.shape[0], seq_img), dtype=torch.bool, device=hidden_states.device)
attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
attention_mask = attention_mask[:, None, None, :]

Comment on lines +500 to +511

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the controlnet also have similar changes or not because it doesn't define _cp_plan?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the controlnet model is calling QwenDoubleStreamAttnProcessor2_0 here

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate that?

@zhtmike zhtmike Jun 8, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The controlnet model itself does not have the CP plan. However, because we have modified QwenDoubleStreamAttnProcessor2_0 in this PR with an extra guard:

if attention_mask is not None:
    raise ValueError(
        "QwenDoubleStreamAttnProcessor2_0 does not accept an external attention_mask. "
        "Pass encoder_hidden_states_mask to let the processor build the joint mask."
    )

So we either need to:

  • modify the controlnet following the QwenImage transformer’s change; or
  • drop the guard and avoid touching the controlnet.

In the future, we may need to add support for the SP implementation for controlnet. In that case, option 1 may be a better solution, since it will avoid such mask bugs more easily and is also consistent with the style of the Qwen‑Image transformer. So personally, I prefer option 1.

seq_txt = encoder_hidden_states.shape[1]

# Compute QKV for image stream (sample projections)
Expand Down Expand Up @@ -770,6 +782,7 @@ class QwenImageTransformer2DModel(
},
"transformer_blocks.*": {
"modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
"encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
},
"pos_embed": {
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
Expand Down Expand Up @@ -911,38 +924,27 @@ def forward(

image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)

# Construct joint attention mask once to avoid reconstructing in every block
# This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @kashif here since this would revert the optimization as part of this PR #12702

I tried to go over #12702, but was not able to find much detail about this optimization. I would love to understand more about the cause of the sync and performance delta, because the pre-built joint mask does not shard correctly under CP

@kashif kashif Jun 3, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @yiyixuxu! the #12702 bit was just me building the joint mask once instead of per-block. "eliminates 60 GPU syncs" was a bad comment on my part, i checked and it's actually 0 syncs, just a plain cat/ones. cost of dropping it is ~0.85ms/fwd eager and basically 0 with compile, so no real loss.

and yeah it has to go for CP anyway: the pre-built mask is over the full unsharded seq so it can't line up after the all-to-all. confirmed with the #13696 repro, main is off by 2.9e-2 and this PR is exactly 0.0. lgtm from me 👍

if encoder_hidden_states_mask is not None:
# Build joint mask: [text_mask, all_ones_for_image]
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
joint_attention_mask = joint_attention_mask[:, None, None, :]
block_attention_kwargs["attention_mask"] = joint_attention_mask

for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
encoder_hidden_states_mask,
temb,
image_rotary_emb,
block_attention_kwargs,
attention_kwargs,
modulate_index,
)

else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=block_attention_kwargs,
joint_attention_kwargs=attention_kwargs,
modulate_index=modulate_index,
)

Expand Down
37 changes: 35 additions & 2 deletions tests/models/testing_utils/parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,16 @@ def _find_free_port():


def _context_parallel_worker(
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None
rank,
world_size,
master_port,
model_class,
init_dict,
cp_dict,
inputs_dict,
return_dict,
attention_backend=None,
state_dict=None,
):
"""Worker function for context parallel testing."""
try:
Expand All @@ -75,6 +84,8 @@ def _context_parallel_worker(

# Create model
model = model_class(**init_dict)
if state_dict is not None:
model.load_state_dict(state_dict)
model.to(device)
model.eval()

Expand All @@ -100,6 +111,9 @@ def _context_parallel_worker(
if rank == 0:
return_dict["status"] = "success"
return_dict["output_shape"] = list(output.shape)
if state_dict is not None:
# Serialise via nested list so the manager dict can transport it across processes.
return_dict["output"] = output.cpu().tolist()

except Exception as e:
if rank == 0:
Expand Down Expand Up @@ -247,6 +261,12 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)

# Single-GPU reference
model = self.model_class(**init_dict).eval().to(torch_device)
state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
with torch.no_grad():
ref_output = model(**inputs_dict, return_dict=False)[0].cpu()

# Move all tensors to CPU for multiprocessing
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
cp_dict = {cp_type: world_size}
Expand All @@ -261,7 +281,17 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
# Spawn worker processes
mp.spawn(
_context_parallel_worker,
args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict),
args=(
world_size,
master_port,
self.model_class,
init_dict,
cp_dict,
inputs_dict,
return_dict,
None,
state_dict,
),
nprocs=world_size,
join=True,
)
Expand All @@ -270,6 +300,9 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)

cp_output = torch.tensor(return_dict["output"])
torch.testing.assert_close(ref_output, cp_output, atol=1e-4, rtol=1e-4)

@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_batch_inputs(self, cp_type):
self.test_context_parallel_inference(cp_type, batch_size=2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,15 @@ class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, Attent
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for QwenImage Transformer."""

def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
inputs = super().get_dummy_inputs(batch_size=batch_size)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"]
encoder_hidden_states_mask[:, 1] = 0
encoder_hidden_states_mask[:, 3] = 0
encoder_hidden_states_mask[:, 5:] = 0
inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask
return inputs


class TestQwenImageTransformerContextParallelAttnBackends(
QwenImageTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin
Expand Down
Loading