From e27c4d3acce44275d57803a5e3bf22979d02e7ff Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 7 Mar 2026 01:09:47 +0100 Subject: [PATCH 1/2] Modify AttentionProcessorSkipHook to support _skip_attn_scores flag on attn processors to allow custom STG-style logic --- src/diffusers/hooks/layer_skip.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 112edfa2f79b..867b1d937e38 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -108,8 +108,17 @@ def __torch_function__(self, func, types, args=(), kwargs=None): class AttentionProcessorSkipHook(ModelHook): - def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0): + def __init__( + self, + skip_processor_output_fn: Callable, + skip_attn_scores_fn: Callable | None = None, + skip_attention_scores: bool = False, + dropout: float = 1.0, + ): + super().__init__() self.skip_processor_output_fn = skip_processor_output_fn + # STG default: return the values as attention output + self.skip_attn_scores_fn = skip_attn_scores_fn or (lambda attn, q, k, v: v) self.skip_attention_scores = skip_attention_scores self.dropout = dropout @@ -119,8 +128,22 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): raise ValueError( "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0." ) - with AttentionScoreSkipFunctionMode(): - output = self.fn_ref.original_forward(*args, **kwargs) + processor_supports_skip_fn = hasattr(module.processor, "_skip_attn_scores") + if processor_supports_skip_fn: + module.processor._skip_attn_scores = True + module.processor._skip_attn_scores_fn = self.skip_attn_scores_fn + # Use try block in case attn processor raises an exception + try: + if processor_supports_skip_fn: + output = self.fn_ref.original_forward(*args, **kwargs) + else: + # Fallback to torch native SDPA intercept approach + with AttentionScoreSkipFunctionMode(): + output = self.fn_ref.original_forward(*args, **kwargs) + finally: + if processor_supports_skip_fn: + module.processor._skip_attn_scores = False + module.processor._skip_attn_scores_fn = None else: if math.isclose(self.dropout, 1.0): output = self.skip_processor_output_fn(module, *args, **kwargs) From 82da133a384ae58a774db28cba5d742ddf90caaf Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sun, 8 Mar 2026 03:00:49 +0100 Subject: [PATCH 2/2] Respect original order of args for AttentionProcessorSkipHook --- src/diffusers/hooks/layer_skip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 867b1d937e38..82106ece38fe 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -111,9 +111,9 @@ class AttentionProcessorSkipHook(ModelHook): def __init__( self, skip_processor_output_fn: Callable, - skip_attn_scores_fn: Callable | None = None, skip_attention_scores: bool = False, dropout: float = 1.0, + skip_attn_scores_fn: Callable | None = None, ): super().__init__() self.skip_processor_output_fn = skip_processor_output_fn