diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index ffad94cc7f27..79b7e6b3562c 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -2215,28 +2215,30 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
     )
 
 
-@_AttentionBackendRegistry.register(
-    AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
-    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
-)
-def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    is_causal: bool = False,
-    scale: Optional[float] = None,
-    return_lse: bool = False,
-    _parallel_config: Optional["ParallelConfig"] = None,
-) -> torch.Tensor:
-    return sageattn_qk_int8_pv_fp8_cuda_sm90(
-        q=query,
-        k=key,
-        v=value,
-        tensor_layout="NHD",
-        is_causal=is_causal,
-        sm_scale=scale,
-        return_lse=return_lse,
-    )
+# Temporarily disabled due to issue #12783 - sm90 backend causes confetti/noisy output
+# @_AttentionBackendRegistry.register(
+#     AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+#     constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+# )
+# def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
+#     query: torch.Tensor,
+#     key: torch.Tensor,
+#     value: torch.Tensor,
+#     is_causal: bool = False,
+#     scale: Optional[float] = None,
+#     return_lse: bool = False,
+#     _parallel_config: Optional["ParallelConfig"] = None,
+# ) -> torch.Tensor:
+#     return sageattn_qk_int8_pv_fp8_cuda_sm90(
+#         q=query,
+#         k=key,
+#         v=value,
+#         tensor_layout="NHD",
+#         is_causal=is_causal,
+#         sm_scale=scale,
+#         return_lse=return_lse,
+#     )
+
 
 
 @_AttentionBackendRegistry.register(
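
For anyone who still wants to experiment with the SM90 kernel while this registration is commented out, it can be re-registered from user code. The sketch below is illustrative only and is not part of this PR: it assumes the private registry helpers visible in the diff (`_AttentionBackendRegistry`, `AttentionBackendName`, `_check_device_cuda_atleast_smXY`, `_check_shape`) remain importable from `diffusers.models.attention_dispatch`, that the registry accepts a new registration for this backend name, and that the `sageattention` package exposes `sageattn_qk_int8_pv_fp8_cuda_sm90` at the top level. Expect the noisy output described in issue #12783.

```python
# Illustrative opt-in sketch, not part of this PR. Import paths and the
# re-registration behaviour are assumptions, not confirmed by the diff.
from typing import Optional

import torch
from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    _AttentionBackendRegistry,
    _check_device_cuda_atleast_smXY,
    _check_shape,
)
from sageattention import sageattn_qk_int8_pv_fp8_cuda_sm90


@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
)
def _sage_sm90_attention_optin(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
    _parallel_config=None,
) -> torch.Tensor:
    # Same call the disabled function made; may reproduce the confetti/noisy
    # output reported in issue #12783 on Hopper GPUs.
    return sageattn_qk_int8_pv_fp8_cuda_sm90(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )
```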