diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 8d262f2cac13..d9920a877112 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2200,6 +2200,12 @@ def forward( query, key, value = (_all_to_all_single(x, group) for x in (query, key, value)) query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value)) + if attn_mask is not None and attn_mask.shape[-1] == S_KV_LOCAL: + # All-gather a local mask so its layout matches the QKV layout after all-to-all. + mask_list = [torch.empty_like(attn_mask) for _ in range(world_size)] + dist.all_gather(mask_list, attn_mask, group=group) + attn_mask = torch.cat(mask_list, dim=-1) + out = forward_op( ctx, query, @@ -2399,6 +2405,8 @@ def forward( ctx.backward_op = backward_op ctx._parallel_config = _parallel_config + _, S_KV_LOCAL, _, _ = key.shape + metadata = ulysses_anything_metadata(query) query_wait = all_to_all_single_any_qkv_async(query, group, **metadata) key_wait = all_to_all_single_any_qkv_async(key, group, **metadata) @@ -2408,6 +2416,19 @@ def forward( key = key_wait() # type: torch.Tensor value = value_wait() # type: torch.Tensor + if attn_mask is not None and attn_mask.shape[-1] == S_KV_LOCAL: + # All-gather a local mask to match the post-all-to-all global sequence. + # The "anything" path allows unequal local sizes, so we pad to the + # maximum across ranks before all-gathering, then trim back. + mask_local_sizes = gather_size_by_comm(attn_mask.shape[-1], group) + max_local = max(mask_local_sizes) + if attn_mask.shape[-1] < max_local: + attn_mask = F.pad(attn_mask, (0, max_local - attn_mask.shape[-1])) + mask_list = [torch.empty_like(attn_mask) for _ in range(dist.get_world_size(group=group))] + dist.all_gather(mask_list, attn_mask, group=group) + attn_mask = torch.cat(mask_list, dim=-1) + attn_mask = attn_mask[..., : sum(mask_local_sizes)] + out = forward_op( ctx, query, diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index a7c91099926b..f721c51261e1 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -205,15 +205,6 @@ def forward( encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) - # Construct joint attention mask once to avoid reconstructing in every block - block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {} - if encoder_hidden_states_mask is not None: - # Build joint mask: [text_mask, all_ones_for_image] - batch_size, image_seq_len = hidden_states.shape[:2] - image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) - joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) - block_attention_kwargs["attention_mask"] = joint_attention_mask - block_samples = () for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -221,20 +212,20 @@ def forward( block, hidden_states, encoder_hidden_states, - None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) + encoder_hidden_states_mask, temb, image_rotary_emb, - block_attention_kwargs, + joint_attention_kwargs, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) + encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=block_attention_kwargs, + joint_attention_kwargs=joint_attention_kwargs, ) block_samples = block_samples + (hidden_states,) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 2385c0b1c8c3..464712bd94fd 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -497,6 +497,18 @@ def __call__( if encoder_hidden_states is None: raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") + if attention_mask is not None: + raise ValueError( + "QwenDoubleStreamAttnProcessor2_0 does not accept an external attention_mask. " + "Pass encoder_hidden_states_mask to let the processor build the joint mask." + ) + + if encoder_hidden_states_mask is not None: + seq_img = hidden_states.shape[1] + image_mask = torch.ones((hidden_states.shape[0], seq_img), dtype=torch.bool, device=hidden_states.device) + attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + attention_mask = attention_mask[:, None, None, :] + seq_txt = encoder_hidden_states.shape[1] # Compute QKV for image stream (sample projections) @@ -770,6 +782,7 @@ class QwenImageTransformer2DModel( }, "transformer_blocks.*": { "modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), }, "pos_embed": { 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), @@ -911,27 +924,16 @@ def forward( image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) - # Construct joint attention mask once to avoid reconstructing in every block - # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility - block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} - if encoder_hidden_states_mask is not None: - # Build joint mask: [text_mask, all_ones_for_image] - batch_size, image_seq_len = hidden_states.shape[:2] - image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) - joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) - joint_attention_mask = joint_attention_mask[:, None, None, :] - block_attention_kwargs["attention_mask"] = joint_attention_mask - for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, - None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) + encoder_hidden_states_mask, temb, image_rotary_emb, - block_attention_kwargs, + attention_kwargs, modulate_index, ) @@ -939,10 +941,10 @@ def forward( encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) + encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=block_attention_kwargs, + joint_attention_kwargs=attention_kwargs, modulate_index=modulate_index, ) diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index d4f5e99d6763..d6b8854f93bf 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -51,7 +51,16 @@ def _find_free_port(): def _context_parallel_worker( - rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None + rank, + world_size, + master_port, + model_class, + init_dict, + cp_dict, + inputs_dict, + return_dict, + attention_backend=None, + state_dict=None, ): """Worker function for context parallel testing.""" try: @@ -75,6 +84,8 @@ def _context_parallel_worker( # Create model model = model_class(**init_dict) + if state_dict is not None: + model.load_state_dict(state_dict) model.to(device) model.eval() @@ -100,6 +111,9 @@ def _context_parallel_worker( if rank == 0: return_dict["status"] = "success" return_dict["output_shape"] = list(output.shape) + if state_dict is not None: + # Serialise via nested list so the manager dict can transport it across processes. + return_dict["output"] = output.cpu().tolist() except Exception as e: if rank == 0: @@ -247,6 +261,12 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1): init_dict = self.get_init_dict() inputs_dict = self.get_dummy_inputs(batch_size=batch_size) + # Single-GPU reference + model = self.model_class(**init_dict).eval().to(torch_device) + state_dict = {k: v.cpu() for k, v in model.state_dict().items()} + with torch.no_grad(): + ref_output = model(**inputs_dict, return_dict=False)[0].cpu() + # Move all tensors to CPU for multiprocessing inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} cp_dict = {cp_type: world_size} @@ -261,7 +281,17 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1): # Spawn worker processes mp.spawn( _context_parallel_worker, - args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict), + args=( + world_size, + master_port, + self.model_class, + init_dict, + cp_dict, + inputs_dict, + return_dict, + None, + state_dict, + ), nprocs=world_size, join=True, ) @@ -270,6 +300,9 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1): f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}" ) + cp_output = torch.tensor(return_dict["output"]) + torch.testing.assert_close(ref_output, cp_output, atol=1e-4, rtol=1e-4) + @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) def test_context_parallel_batch_inputs(self, cp_type): self.test_context_parallel_inference(cp_type, batch_size=2) diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 18da11c5f7a2..33bccf816c78 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -253,6 +253,15 @@ class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, Attent class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin): """Context Parallel inference tests for QwenImage Transformer.""" + def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: + inputs = super().get_dummy_inputs(batch_size=batch_size) + encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"] + encoder_hidden_states_mask[:, 1] = 0 + encoder_hidden_states_mask[:, 3] = 0 + encoder_hidden_states_mask[:, 5:] = 0 + inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask + return inputs + class TestQwenImageTransformerContextParallelAttnBackends( QwenImageTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin