From 84c545f9f650819719095f099ec0917d70e57e8b Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Fri, 15 May 2026 12:13:17 +0800 Subject: [PATCH 01/10] fix mask --- .../models/transformers/transformer_qwenimage.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index bdb87a385da7..76b84532b007 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -913,10 +913,22 @@ def forward( # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} if encoder_hidden_states_mask is not None: - # Build joint mask: [text_mask, all_ones_for_image] batch_size, image_seq_len = hidden_states.shape[:2] image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) - joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + + parallel_config = getattr(self, "_parallel_config", None) + cp_config = parallel_config.context_parallel_config if parallel_config is not None else None + sp_size = cp_config.ulysses_degree if cp_config is not None else 1 + + if sp_size == 1: + # Build joint mask: [text_mask, all_ones_for_image] + joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + else: + # Interleave text and image mask chunks to match the post-all-to-all layout + txt_chunks = encoder_hidden_states_mask.chunk(sp_size, dim=1) + img_chunks = image_mask.chunk(sp_size, dim=1) + joint_attention_mask = torch.cat([x for pair in zip(txt_chunks, img_chunks) for x in pair], dim=1) + joint_attention_mask = joint_attention_mask[:, None, None, :] block_attention_kwargs["attention_mask"] = joint_attention_mask From 1d391b5f635d17173d4ec2512588ceb6454b3104 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Tue, 19 May 2026 11:47:12 +0800 Subject: [PATCH 02/10] handle mask locally --- src/diffusers/models/attention_dispatch.py | 5 +++ .../transformers/transformer_qwenimage.py | 38 ++++++------------- 2 files changed, 16 insertions(+), 27 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index e68d317bc140..59d4bb7ce8be 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2200,6 +2200,11 @@ def forward( query, key, value = (_all_to_all_single(x, group) for x in (query, key, value)) query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value)) + if attn_mask is not None and attn_mask.shape[-1] == S_KV_LOCAL: + mask_list = [torch.empty_like(attn_mask) for _ in range(world_size)] + dist.all_gather(mask_list, attn_mask, group=group) + attn_mask = torch.cat(mask_list, dim=-1) + out = forward_op( ctx, query, diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 76b84532b007..691f535e4a42 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -497,6 +497,12 @@ def __call__( if encoder_hidden_states is None: raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") + if attention_mask is None and encoder_hidden_states_mask is not None: + seq_img = hidden_states.shape[1] + image_mask = torch.ones((hidden_states.shape[0], seq_img), dtype=torch.bool, device=hidden_states.device) + attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + attention_mask = attention_mask[:, None, None, :] + seq_txt = encoder_hidden_states.shape[1] # Compute QKV for image stream (sample projections) @@ -770,6 +776,7 @@ class QwenImageTransformer2DModel( }, "transformer_blocks.*": { "modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), }, "pos_embed": { 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), @@ -909,39 +916,16 @@ def forward( image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) - # Construct joint attention mask once to avoid reconstructing in every block - # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility - block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} - if encoder_hidden_states_mask is not None: - batch_size, image_seq_len = hidden_states.shape[:2] - image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) - - parallel_config = getattr(self, "_parallel_config", None) - cp_config = parallel_config.context_parallel_config if parallel_config is not None else None - sp_size = cp_config.ulysses_degree if cp_config is not None else 1 - - if sp_size == 1: - # Build joint mask: [text_mask, all_ones_for_image] - joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) - else: - # Interleave text and image mask chunks to match the post-all-to-all layout - txt_chunks = encoder_hidden_states_mask.chunk(sp_size, dim=1) - img_chunks = image_mask.chunk(sp_size, dim=1) - joint_attention_mask = torch.cat([x for pair in zip(txt_chunks, img_chunks) for x in pair], dim=1) - - joint_attention_mask = joint_attention_mask[:, None, None, :] - block_attention_kwargs["attention_mask"] = joint_attention_mask - for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, - None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) + encoder_hidden_states_mask, temb, image_rotary_emb, - block_attention_kwargs, + attention_kwargs, modulate_index, ) @@ -949,10 +933,10 @@ def forward( encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) + encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=block_attention_kwargs, + joint_attention_kwargs=attention_kwargs, modulate_index=modulate_index, ) From 9346da0678ebbcea12da256e11b2be612736bbdb Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 20 May 2026 10:00:42 +0800 Subject: [PATCH 03/10] update according to comment --- src/diffusers/models/attention_dispatch.py | 1 + .../models/transformers/transformer_qwenimage.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 59d4bb7ce8be..9095ca17957f 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2201,6 +2201,7 @@ def forward( query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value)) if attn_mask is not None and attn_mask.shape[-1] == S_KV_LOCAL: + # All-gather a local mask so its layout matches the QKV layout after all-to-all. mask_list = [torch.empty_like(attn_mask) for _ in range(world_size)] dist.all_gather(mask_list, attn_mask, group=group) attn_mask = torch.cat(mask_list, dim=-1) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 691f535e4a42..cc6875336a84 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -491,13 +491,19 @@ def __call__( hidden_states: torch.FloatTensor, # Image stream encoder_hidden_states: torch.FloatTensor = None, # Text stream encoder_hidden_states_mask: torch.FloatTensor = None, - attention_mask: torch.FloatTensor | None = None, + attention_mask: None = None, image_rotary_emb: torch.Tensor | None = None, ) -> torch.FloatTensor: if encoder_hidden_states is None: raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") - if attention_mask is None and encoder_hidden_states_mask is not None: + if attention_mask is not None: + raise ValueError( + "QwenDoubleStreamAttnProcessor2_0 does not accept an external attention_mask. " + "Pass encoder_hidden_states_mask to let the processor build the joint mask." + ) + + if encoder_hidden_states_mask is not None: seq_img = hidden_states.shape[1] image_mask = torch.ones((hidden_states.shape[0], seq_img), dtype=torch.bool, device=hidden_states.device) attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) From 026327f528085fd4a4801f8eeb98906cf6e47ca0 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 20 May 2026 10:57:19 +0800 Subject: [PATCH 04/10] fix ulysses_anything as well --- src/diffusers/models/attention_dispatch.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 9095ca17957f..d706e360fa0d 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -2405,6 +2405,8 @@ def forward( ctx.backward_op = backward_op ctx._parallel_config = _parallel_config + _, S_KV_LOCAL, _, _ = key.shape + metadata = ulysses_anything_metadata(query) query_wait = all_to_all_single_any_qkv_async(query, group, **metadata) key_wait = all_to_all_single_any_qkv_async(key, group, **metadata) @@ -2414,6 +2416,19 @@ def forward( key = key_wait() # type: torch.Tensor value = value_wait() # type: torch.Tensor + if attn_mask is not None and attn_mask.shape[-1] == S_KV_LOCAL: + # All-gather a local mask to match the post-all-to-all global sequence. + # The "anything" path allows unequal local sizes, so we pad to the + # maximum across ranks before all-gathering, then trim back. + mask_local_sizes = gather_size_by_comm(attn_mask.shape[-1], group) + max_local = max(mask_local_sizes) + if attn_mask.shape[-1] < max_local: + attn_mask = F.pad(attn_mask, (0, max_local - attn_mask.shape[-1])) + mask_list = [torch.empty_like(attn_mask) for _ in range(dist.get_world_size(group=group))] + dist.all_gather(mask_list, attn_mask, group=group) + attn_mask = torch.cat(mask_list, dim=-1) + attn_mask = attn_mask[..., : sum(mask_local_sizes)] + out = forward_op( ctx, query, From 52bd9405a46353995097b7c65df5f5701a4d44a1 Mon Sep 17 00:00:00 2001 From: Cheung Ka Wai Date: Wed, 20 May 2026 11:29:03 +0800 Subject: [PATCH 05/10] Update src/diffusers/models/transformers/transformer_qwenimage.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/diffusers/models/transformers/transformer_qwenimage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index cc6875336a84..dbef7b2e1bca 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -491,7 +491,7 @@ def __call__( hidden_states: torch.FloatTensor, # Image stream encoder_hidden_states: torch.FloatTensor = None, # Text stream encoder_hidden_states_mask: torch.FloatTensor = None, - attention_mask: None = None, + attention_mask: torch.FloatTensor | None = None, image_rotary_emb: torch.Tensor | None = None, ) -> torch.FloatTensor: if encoder_hidden_states is None: From 03017a6ee961e4cb216ba7902281a351fc1d025d Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 20 May 2026 11:53:16 +0800 Subject: [PATCH 06/10] add accuracy test --- tests/models/testing_utils/parallelism.py | 91 +++++++++++++++++++ .../test_models_transformer_qwenimage.py | 9 ++ 2 files changed, 100 insertions(+) diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index d4f5e99d6763..f80b0658bd85 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -227,6 +227,51 @@ def _custom_mesh_worker( dist.destroy_process_group() +def _context_parallel_correctness_worker( + rank, world_size, master_port, model_class, init_dict, state_dict, cp_dict, inputs_dict, return_dict +): + """Worker that runs a CP forward pass and returns the output tensor for numerical comparison.""" + try: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(master_port) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"]) + backend = device_config["backend"] + device_module = device_config["module"] + + dist.init_process_group(backend=backend, rank=rank, world_size=world_size) + device_module.set_device(rank) + device = torch.device(f"{torch_device}:{rank}") + + model = model_class(**init_dict) + model.load_state_dict(state_dict) + model.to(device) + model.eval() + + inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} + + cp_config = ContextParallelConfig(**cp_dict) + model.enable_parallelism(config=cp_config) + + with torch.no_grad(): + output = model(**inputs_on_device, return_dict=False)[0] + + if rank == 0: + return_dict["status"] = "success" + # Serialise via nested list so the manager dict can transport it across processes. + return_dict["output"] = output.cpu().tolist() + + except Exception as e: + if rank == 0: + return_dict["status"] = "error" + return_dict["error"] = str(e) + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + @is_context_parallel @require_torch_multi_accelerator class ContextParallelTesterMixin: @@ -369,6 +414,52 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names) f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}" ) + @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) + def test_context_parallel_output_correctness(self, cp_type, batch_size: int = 1): + """Verify that CP output is numerically identical to a single-GPU reference forward pass.""" + if not torch.distributed.is_available(): + pytest.skip("torch.distributed is not available.") + + if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None: + pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") + + if cp_type == "ring_degree": + active_backend, _ = _AttentionBackendRegistry.get_active_backend() + if active_backend == AttentionBackendName.NATIVE: + pytest.skip("Ring attention is not supported with the native attention backend.") + + world_size = 2 + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs(batch_size=batch_size) + + # Single-GPU reference + model = self.model_class(**init_dict).eval().to(torch_device) + state_dict = {k: v.cpu() for k, v in model.state_dict().items()} + with torch.no_grad(): + ref_output = model(**inputs_dict, return_dict=False)[0].cpu() + + # Context-parallel run with the same weights + inputs_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} + cp_dict = {cp_type: world_size} + + master_port = _find_free_port() + manager = mp.Manager() + return_dict = manager.dict() + + mp.spawn( + _context_parallel_correctness_worker, + args=(world_size, master_port, self.model_class, init_dict, state_dict, cp_dict, inputs_cpu, return_dict), + nprocs=world_size, + join=True, + ) + + assert return_dict.get("status") == "success", ( + f"Context parallel correctness check failed: {return_dict.get('error', 'Unknown error')}" + ) + + cp_output = torch.tensor(return_dict["output"]) + torch.testing.assert_close(ref_output, cp_output, atol=1e-4, rtol=1e-4) + @is_attention @is_context_parallel diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 18da11c5f7a2..33bccf816c78 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -253,6 +253,15 @@ class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, Attent class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin): """Context Parallel inference tests for QwenImage Transformer.""" + def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: + inputs = super().get_dummy_inputs(batch_size=batch_size) + encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"] + encoder_hidden_states_mask[:, 1] = 0 + encoder_hidden_states_mask[:, 3] = 0 + encoder_hidden_states_mask[:, 5:] = 0 + inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask + return inputs + class TestQwenImageTransformerContextParallelAttnBackends( QwenImageTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin From 95694008998385d0dde2eda399bad4afe48e585d Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Fri, 5 Jun 2026 06:13:23 +0000 Subject: [PATCH 07/10] fix controlnet --- .../models/controlnets/controlnet_qwenimage.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index a7c91099926b..f721c51261e1 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -205,15 +205,6 @@ def forward( encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) - # Construct joint attention mask once to avoid reconstructing in every block - block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {} - if encoder_hidden_states_mask is not None: - # Build joint mask: [text_mask, all_ones_for_image] - batch_size, image_seq_len = hidden_states.shape[:2] - image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) - joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) - block_attention_kwargs["attention_mask"] = joint_attention_mask - block_samples = () for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -221,20 +212,20 @@ def forward( block, hidden_states, encoder_hidden_states, - None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) + encoder_hidden_states_mask, temb, image_rotary_emb, - block_attention_kwargs, + joint_attention_kwargs, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) + encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=block_attention_kwargs, + joint_attention_kwargs=joint_attention_kwargs, ) block_samples = block_samples + (hidden_states,) From 840a605b96acffe97e6c991dbf6ea158ba935874 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Fri, 5 Jun 2026 07:25:11 +0000 Subject: [PATCH 08/10] refactor worker --- tests/models/testing_utils/parallelism.py | 75 ++++++++--------------- 1 file changed, 27 insertions(+), 48 deletions(-) diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index f80b0658bd85..6324d564806f 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -51,7 +51,16 @@ def _find_free_port(): def _context_parallel_worker( - rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None + rank, + world_size, + master_port, + model_class, + init_dict, + cp_dict, + inputs_dict, + return_dict, + attention_backend=None, + state_dict=None, ): """Worker function for context parallel testing.""" try: @@ -75,6 +84,8 @@ def _context_parallel_worker( # Create model model = model_class(**init_dict) + if state_dict is not None: + model.load_state_dict(state_dict) model.to(device) model.eval() @@ -100,6 +111,9 @@ def _context_parallel_worker( if rank == 0: return_dict["status"] = "success" return_dict["output_shape"] = list(output.shape) + if state_dict is not None: + # Serialise via nested list so the manager dict can transport it across processes. + return_dict["output"] = output.cpu().tolist() except Exception as e: if rank == 0: @@ -227,51 +241,6 @@ def _custom_mesh_worker( dist.destroy_process_group() -def _context_parallel_correctness_worker( - rank, world_size, master_port, model_class, init_dict, state_dict, cp_dict, inputs_dict, return_dict -): - """Worker that runs a CP forward pass and returns the output tensor for numerical comparison.""" - try: - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(master_port) - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - - device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"]) - backend = device_config["backend"] - device_module = device_config["module"] - - dist.init_process_group(backend=backend, rank=rank, world_size=world_size) - device_module.set_device(rank) - device = torch.device(f"{torch_device}:{rank}") - - model = model_class(**init_dict) - model.load_state_dict(state_dict) - model.to(device) - model.eval() - - inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} - - cp_config = ContextParallelConfig(**cp_dict) - model.enable_parallelism(config=cp_config) - - with torch.no_grad(): - output = model(**inputs_on_device, return_dict=False)[0] - - if rank == 0: - return_dict["status"] = "success" - # Serialise via nested list so the manager dict can transport it across processes. - return_dict["output"] = output.cpu().tolist() - - except Exception as e: - if rank == 0: - return_dict["status"] = "error" - return_dict["error"] = str(e) - finally: - if dist.is_initialized(): - dist.destroy_process_group() - - @is_context_parallel @require_torch_multi_accelerator class ContextParallelTesterMixin: @@ -447,8 +416,18 @@ def test_context_parallel_output_correctness(self, cp_type, batch_size: int = 1) return_dict = manager.dict() mp.spawn( - _context_parallel_correctness_worker, - args=(world_size, master_port, self.model_class, init_dict, state_dict, cp_dict, inputs_cpu, return_dict), + _context_parallel_worker, + args=( + world_size, + master_port, + self.model_class, + init_dict, + cp_dict, + inputs_cpu, + return_dict, + None, + state_dict, + ), nprocs=world_size, join=True, ) From 2ab47d3a7a6348035cfd12646bf46564d07dc49d Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Fri, 5 Jun 2026 09:03:37 +0000 Subject: [PATCH 09/10] refactor correctness test --- tests/models/testing_utils/parallelism.py | 67 ++++------------------- 1 file changed, 10 insertions(+), 57 deletions(-) diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index 6324d564806f..560ac59de81e 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -261,6 +261,12 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1): init_dict = self.get_init_dict() inputs_dict = self.get_dummy_inputs(batch_size=batch_size) + # Single-GPU reference + model = self.model_class(**init_dict).eval().to(torch_device) + state_dict = {k: v.cpu() for k, v in model.state_dict().items()} + with torch.no_grad(): + ref_output = model(**inputs_dict, return_dict=False)[0].cpu() + # Move all tensors to CPU for multiprocessing inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} cp_dict = {cp_type: world_size} @@ -275,7 +281,7 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1): # Spawn worker processes mp.spawn( _context_parallel_worker, - args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict), + args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict, None, state_dict), nprocs=world_size, join=True, ) @@ -284,6 +290,9 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1): f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}" ) + cp_output = torch.tensor(return_dict["output"]) + torch.testing.assert_close(ref_output, cp_output, atol=1e-4, rtol=1e-4) + @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) def test_context_parallel_batch_inputs(self, cp_type): self.test_context_parallel_inference(cp_type, batch_size=2) @@ -383,62 +392,6 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names) f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}" ) - @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) - def test_context_parallel_output_correctness(self, cp_type, batch_size: int = 1): - """Verify that CP output is numerically identical to a single-GPU reference forward pass.""" - if not torch.distributed.is_available(): - pytest.skip("torch.distributed is not available.") - - if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None: - pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") - - if cp_type == "ring_degree": - active_backend, _ = _AttentionBackendRegistry.get_active_backend() - if active_backend == AttentionBackendName.NATIVE: - pytest.skip("Ring attention is not supported with the native attention backend.") - - world_size = 2 - init_dict = self.get_init_dict() - inputs_dict = self.get_dummy_inputs(batch_size=batch_size) - - # Single-GPU reference - model = self.model_class(**init_dict).eval().to(torch_device) - state_dict = {k: v.cpu() for k, v in model.state_dict().items()} - with torch.no_grad(): - ref_output = model(**inputs_dict, return_dict=False)[0].cpu() - - # Context-parallel run with the same weights - inputs_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} - cp_dict = {cp_type: world_size} - - master_port = _find_free_port() - manager = mp.Manager() - return_dict = manager.dict() - - mp.spawn( - _context_parallel_worker, - args=( - world_size, - master_port, - self.model_class, - init_dict, - cp_dict, - inputs_cpu, - return_dict, - None, - state_dict, - ), - nprocs=world_size, - join=True, - ) - - assert return_dict.get("status") == "success", ( - f"Context parallel correctness check failed: {return_dict.get('error', 'Unknown error')}" - ) - - cp_output = torch.tensor(return_dict["output"]) - torch.testing.assert_close(ref_output, cp_output, atol=1e-4, rtol=1e-4) - @is_attention @is_context_parallel From b2816d359f5d1e4d726ee8bf48a8c6036331cf50 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Mon, 8 Jun 2026 07:50:54 +0000 Subject: [PATCH 10/10] fix style --- tests/models/testing_utils/parallelism.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index 560ac59de81e..d6b8854f93bf 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -281,7 +281,17 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1): # Spawn worker processes mp.spawn( _context_parallel_worker, - args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict, None, state_dict), + args=( + world_size, + master_port, + self.model_class, + init_dict, + cp_dict, + inputs_dict, + return_dict, + None, + state_dict, + ), nprocs=world_size, join=True, )