From 6c29e11f0b610f4f89330788d817b64c8823777d Mon Sep 17 00:00:00 2001
From: arrdel
Date: Wed, 3 Dec 2025 12:13:52 -0500
Subject: [PATCH] Disable Sage Attention sm90 backend due to confetti/noisy output

The _SAGE_QK_INT8_PV_FP8_CUDA_SM90 backend is causing confetti/noisy output
on SM 9.0+ GPUs. Temporarily disable this backend by commenting out its
registration until the upstream sageattention library fixes the issue.

Fixes #12783
---
 src/diffusers/models/attention_dispatch.py | 46 +++++++++++-----------
 1 file changed, 24 insertions(+), 22 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index ffad94cc7f27..79b7e6b3562c 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -2215,28 +2215,30 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
     )
 
 
-@_AttentionBackendRegistry.register(
-    AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
-    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
-)
-def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    is_causal: bool = False,
-    scale: Optional[float] = None,
-    return_lse: bool = False,
-    _parallel_config: Optional["ParallelConfig"] = None,
-) -> torch.Tensor:
-    return sageattn_qk_int8_pv_fp8_cuda_sm90(
-        q=query,
-        k=key,
-        v=value,
-        tensor_layout="NHD",
-        is_causal=is_causal,
-        sm_scale=scale,
-        return_lse=return_lse,
-    )
+# Temporarily disabled due to issue #12783 - sm90 backend causes confetti/noisy output
+# @_AttentionBackendRegistry.register(
+#     AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+#     constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+# )
+# def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
+#     query: torch.Tensor,
+#     key: torch.Tensor,
+#     value: torch.Tensor,
+#     is_causal: bool = False,
+#     scale: Optional[float] = None,
+#     return_lse: bool = False,
+#     _parallel_config: Optional["ParallelConfig"] = None,
+# ) -> torch.Tensor:
+#     return sageattn_qk_int8_pv_fp8_cuda_sm90(
+#         q=query,
+#         k=key,
+#         v=value,
+#         tensor_layout="NHD",
+#         is_causal=is_causal,
+#         sm_scale=scale,
+#         return_lse=return_lse,
+#     )
+
 
 
 @_AttentionBackendRegistry.register(
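
Note (not part of the patch): a minimal sketch for checking whether a given
sageattention build still produces the confetti/noisy output from #12783, by
calling the sm90 kernel directly rather than through the now-disabled diffusers
registration. The keyword arguments are taken verbatim from the commented-out
wrapper; the top-level import path, the "NHD" layout meaning
(batch, seq_len, num_heads, head_dim), and the example shapes are assumptions.

# Hypothetical repro harness for the sm90 SageAttention kernel; assumes an
# SM 9.0+ GPU and that the sageattention package exports this function at top level.
import math
import torch
from sageattention import sageattn_qk_int8_pv_fp8_cuda_sm90  # assumed export

# Assumed "NHD" layout: (batch, seq_len, num_heads, head_dim), fp16 inputs on CUDA.
batch, seq_len, num_heads, head_dim = 1, 4096, 24, 128
q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = sageattn_qk_int8_pv_fp8_cuda_sm90(
    q=q,
    k=k,
    v=v,
    tensor_layout="NHD",
    is_causal=False,
    sm_scale=1.0 / math.sqrt(head_dim),  # the wrapper forwarded the caller's scale here
    return_lse=False,
)
print(out.shape)  # compare the decoded images against a non-sm90 sage backend to spot the artifact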