diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 066c9f71f3ec..9c1123bbda64 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -106,7 +106,8 @@ def apply_rotary_emb( freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, ): - x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + x1 = hidden_states[..., 0::2] + x2 = hidden_states[..., 1::2] cos = freqs_cos[..., 0::2] sin = freqs_sin[..., 1::2] out = torch.empty_like(hidden_states)