diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py
index 112edfa2f79b..82106ece38fe 100644
--- a/src/diffusers/hooks/layer_skip.py
+++ b/src/diffusers/hooks/layer_skip.py
@@ -108,8 +108,17 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
 
 
 class AttentionProcessorSkipHook(ModelHook):
-    def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
+    def __init__(
+        self,
+        skip_processor_output_fn: Callable,
+        skip_attention_scores: bool = False,
+        dropout: float = 1.0,
+        skip_attn_scores_fn: Callable | None = None,
+    ):
+        super().__init__()
         self.skip_processor_output_fn = skip_processor_output_fn
+        # STG default: return the values as attention output
+        self.skip_attn_scores_fn = skip_attn_scores_fn or (lambda attn, q, k, v: v)
         self.skip_attention_scores = skip_attention_scores
         self.dropout = dropout
 
@@ -119,8 +128,22 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
                 raise ValueError(
                     "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
                 )
-            with AttentionScoreSkipFunctionMode():
-                output = self.fn_ref.original_forward(*args, **kwargs)
+            processor_supports_skip_fn = hasattr(module.processor, "_skip_attn_scores")
+            if processor_supports_skip_fn:
+                module.processor._skip_attn_scores = True
+                module.processor._skip_attn_scores_fn = self.skip_attn_scores_fn
+            # Use try block in case attn processor raises an exception
+            try:
+                if processor_supports_skip_fn:
+                    output = self.fn_ref.original_forward(*args, **kwargs)
+                else:
+                    # Fallback to torch native SDPA intercept approach
+                    with AttentionScoreSkipFunctionMode():
+                        output = self.fn_ref.original_forward(*args, **kwargs)
+            finally:
+                if processor_supports_skip_fn:
+                    module.processor._skip_attn_scores = False
+                    module.processor._skip_attn_scores_fn = None
         else:
             if math.isclose(self.dropout, 1.0):
                 output = self.skip_processor_output_fn(module, *args, **kwargs)