I'm trying to capture the model's prompt cross-attention in order to apply some latent optimization techniques during inference, but latents.grad keeps coming back as None no matter what I try. I'll add small code snippets to describe what I'm trying to do:
import math
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F
import tqdm

# Model/pipeline-specific imports (LTX2VideoTransformer3DModel, LTX2Attention, LayoutLoss,
# dispatch_attention_fn, apply_interleaved_rotary_emb, apply_split_rotary_emb, is_torch_version)
# come from the rest of my code and are omitted here.


class Optimizer:
    def __init__(
        self,
        loss_fn: LayoutLoss,
        num_refinements: int = 3,
        lr_start: float = 0.01,
        lr_end: float = 0.05,
        betas: tuple[float, float] = (0.4, 0.9),
        weight_decay: float = 0.0,
    ):
        self.loss_fn = loss_fn
        self.num_refinements = num_refinements
        self.lr_start = lr_start
        self.lr_end = lr_end
        self.betas = betas
        self.weight_decay = weight_decay

    def optimize(
        self,
        transformer: LTX2VideoTransformer3DModel,
        latents: torch.Tensor,
        audio_latents: torch.Tensor,
        prompt_embeds: torch.Tensor,
        audio_prompt_embeds: torch.Tensor,
        timestep: torch.Tensor,
        attention_mask: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        fps: float,
        audio_num_frames: int,
        video_coords: torch.Tensor,
        audio_coords: torch.Tensor,
        attention_kwargs: Dict[str, Any],
        store: AttentionStore,
        progress_bar: tqdm.tqdm,
    ) -> torch.Tensor:
        latents = latents.clone().detach()
        latents = latents.to(transformer.dtype)
        optimizer = torch.optim.AdamW(
            [latents],
            lr=self.lr_start,
            betas=self.betas,
            weight_decay=self.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.num_refinements, eta_min=self.lr_end
        )
        transformer.zero_grad(set_to_none=True)
        first_loss = None
        with torch.enable_grad():
            for i in range(self.num_refinements):
                latents = latents.requires_grad_(True)
                store.reset()
                latent_model_input = latents.to(transformer.dtype)
                _ = transformer(
                    hidden_states=latent_model_input,
                    audio_hidden_states=audio_latents,
                    encoder_hidden_states=prompt_embeds,
                    audio_encoder_hidden_states=audio_prompt_embeds,
                    timestep=timestep,
                    encoder_attention_mask=attention_mask,
                    audio_encoder_attention_mask=attention_mask,
                    num_frames=num_frames,
                    height=height,
                    width=width,
                    fps=fps,
                    audio_num_frames=audio_num_frames,
                    video_coords=video_coords,
                    audio_coords=audio_coords,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
                attn = store.get_avg_attention().unsqueeze(0)
                loss = self.loss_fn(attn)
                # Backward
                loss.backward()
                # FIX 4: Verify gradients exist before stepping
                if latents.grad is None:
                    print(f"WARNING: latents.grad is None at iteration {i + 1}!")
                    print("  Gradient flow is broken. Check:")
                    print("  1. AttentionStore doesn't use .clone()")
                    print("  2. No dtype conversion breaks the computation graph")
                    print("  3. Gradient checkpointing is disabled")
                    break
                # Only step if we have gradients
                optimizer.step()
                scheduler.step()
                if i == 0:
                    first_loss = loss.item()
                current_lr = scheduler.get_last_lr()[0]
                progress_bar.set_postfix(
                    loss=f"{first_loss:.2f}→{loss.item():.2f}",
                    grad=f"{latents.grad.norm().item():.2e}",
                    lr=f"{current_lr:.2e}",
                    refine_step=f"{i + 1}/{self.num_refinements}",
                )
        store.reset()
        return latents.detach()
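For reference, this is the kind of connectivity check that can be dropped in right after loss = self.loss_fn(attn), to see whether the loss is attached to the latents at all. It's only a debugging sketch, not part of the pipeline:

# Debugging sketch: run right after computing `loss`, before loss.backward().
# retain_graph=True so the subsequent backward() still works.
grad_wrt_latents = torch.autograd.grad(
    loss, latents, retain_graph=True, allow_unused=True
)[0]
if grad_wrt_latents is None:
    print("loss is not connected to latents -> the graph breaks somewhere upstream")
else:
    print(f"loss -> latents is connected, grad norm {grad_wrt_latents.norm().item():.2e}")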
class AttentionStore:
    def __init__(self):
        self.accumulator = None
        self.count = 0
        self.keep_heads = False

    def __call__(self, probs: torch.Tensor) -> torch.Tensor:
        if probs.shape[0] == 2:
            probs = probs[1:]
        if not self.keep_heads:
            probs = probs.mean(dim=1)
        if self.accumulator is None:
            self.accumulator = probs
        else:
            self.accumulator = self.accumulator + probs
        self.count += 1
        return probs

    def reset(self):
        self.accumulator = None
        self.count = 0

    def get_avg_attention(self) -> torch.Tensor:
        return self.accumulator / self.count
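As a sanity check that the accumulate-and-average pattern in AttentionStore is not the problem by itself, here is a tiny self-contained test, independent of the LTX-2 pipeline (the shapes are arbitrary):

# Self-contained check: storing softmax probabilities in a plain Python object
# keeps the autograd graph intact, so the averaging itself should not drop gradients.
x = torch.randn(1, 4, 8, 16, requires_grad=True)   # stand-in for a latent-derived tensor
toy_store = AttentionStore()
for _ in range(3):
    scores = x @ x.transpose(-1, -2)               # fake attention scores, shape (1, 4, 8, 8)
    toy_store(F.softmax(scores, dim=-1))
toy_store.get_avg_attention().sum().backward()
print(x.grad is not None)                          # prints True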
class AttnProcessor:
    r"""
    Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0
    model. Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so
    that we can support audio-to-video (a2v) and video-to-audio (v2a) cross attention.

    FIXED: Now uses manual attention output for cross-attention to maintain gradient flow.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self, store: AttentionStore, name: str):
        if is_torch_version("<", "2.0"):
            raise ValueError(
                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
            )
        self.store = store
        self.name = name

    def __call__(
        self,
        attn: "LTX2Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        original_encoder_hidden_states = encoder_hidden_states
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if query_rotary_emb is not None:
            if attn.rope_type == "interleaved":
                query = apply_interleaved_rotary_emb(query, query_rotary_emb)
                key = apply_interleaved_rotary_emb(
                    key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
                )
            elif attn.rope_type == "split":
                query = apply_split_rotary_emb(query, query_rotary_emb)
                key = apply_split_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb)

        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        is_cross = original_encoder_hidden_states is not None and original_encoder_hidden_states is not hidden_states
        if is_cross:
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            scale_factor = 1.0 / math.sqrt(q.size(-1))
            scores = torch.matmul(q, k.transpose(-1, -2)) * scale_factor
            if attention_mask is not None:
                scores = scores + attention_mask
            probs = F.softmax(scores, dim=-1)
            self.store(probs)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
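A capture-time check can also help narrow down where the graph breaks; the sketch below is meant to sit right before self.store(probs) in the processor above:

# Debugging sketch: place right before self.store(probs) in AttnProcessor.__call__.
# If this code runs under no_grad()/inference_mode(), or inside a re-run triggered by
# reentrant gradient checkpointing, probs.grad_fn can be None here even though the
# latents themselves require grad.
if not torch.is_grad_enabled():
    print(f"{self.name}: autograd is disabled inside the attention processor")
elif probs.grad_fn is None:
    print(f"{self.name}: captured probs are detached from the autograd graph")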
Hoping someone can point me to my issue. I have a feeling the captured attention maps are not connected to the latents during the forward pass (i.e. they end up outside the computation graph).
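For completeness, here is a sketch of how I would rule out the environment-level culprits (gradient checkpointing and inference mode, per check 3 in my warning message) before calling optimize. It assumes the transformer exposes the usual diffusers ModelMixin helpers, which I haven't verified for the LTX-2 class:

# Sketch only: rule out environment-level causes before running the refinement loop.
# Assumes LTX2VideoTransformer3DModel inherits the standard diffusers ModelMixin helpers.
if getattr(transformer, "is_gradient_checkpointing", False):
    transformer.disable_gradient_checkpointing()   # checkpointed forwards can detach intermediates
assert not torch.is_inference_mode_enabled(), "inference_mode() would make every intermediate grad-free"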