From cc41b5c3cbb3619a03fea8c54a382a04e798938f Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 24 Nov 2025 19:50:24 +0100 Subject: [PATCH 1/6] temp --- src/diffusers/__init__.py | 4 + src/diffusers/hooks/rolling_kv_cache.py | 136 +++ src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_krea.py | 718 ++++++++++++++++ src/diffusers/pipelines/__init__.py | 2 + .../__init__.py | 47 ++ .../pipeline.py | 771 ++++++++++++++++++ .../pipeline_output.py | 20 + 9 files changed, 1701 insertions(+) create mode 100644 src/diffusers/hooks/rolling_kv_cache.py create mode 100644 src/diffusers/models/transformers/transformer_krea.py create mode 100644 src/diffusers/pipelines/autoregressive_block_diffusion/__init__.py create mode 100644 src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py create mode 100644 src/diffusers/pipelines/autoregressive_block_diffusion/pipeline_output.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 572aad4bd3f1..93a10facb72c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -227,6 +227,7 @@ "I2VGenXLUNet", "Kandinsky3UNet", "Kandinsky5Transformer3DModel", + "KreaTransformerModel", "LatteTransformer3DModel", "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", @@ -430,6 +431,7 @@ "AudioLDM2UNet2DConditionModel", "AudioLDMPipeline", "AuraFlowPipeline", + "BaseAutoregressiveDiffusionPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", "BriaFiboPipeline", @@ -934,6 +936,7 @@ I2VGenXLUNet, Kandinsky3UNet, Kandinsky5Transformer3DModel, + KreaTransformerModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, @@ -1109,6 +1112,7 @@ AudioLDM2UNet2DConditionModel, AudioLDMPipeline, AuraFlowPipeline, + BaseAutoregressiveDiffusionPipeline, BriaFiboPipeline, BriaPipeline, ChromaImg2ImgPipeline, diff --git a/src/diffusers/hooks/rolling_kv_cache.py b/src/diffusers/hooks/rolling_kv_cache.py new file mode 100644 index 000000000000..a7b0dd9a027f --- /dev/null +++ b/src/diffusers/hooks/rolling_kv_cache.py @@ -0,0 +1,136 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
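The export changes above make `KreaTransformerModel` and `BaseAutoregressiveDiffusionPipeline` importable from the top-level `diffusers` namespace. A minimal, hedged loading sketch, assuming a checkpoint laid out for this pipeline exists (the repository path and prompt below are placeholders, not part of this patch):

import torch
from diffusers import BaseAutoregressiveDiffusionPipeline, KreaTransformerModel

# Placeholder path; substitute a real checkpoint once one is published.
transformer = KreaTransformerModel.from_pretrained(
    "path/to/checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = BaseAutoregressiveDiffusionPipeline.from_pretrained(
    "path/to/checkpoint", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

output = pipe(prompt="a corgi surfing a wave", num_frames=81, num_inference_steps=6)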
+
+from dataclasses import dataclass
+from typing import Tuple
+
+import torch
+
+from ..models.attention_processor import Attention
+from ..utils import get_logger
+from ..utils.torch_utils import unwrap_module
+from ._helpers import TransformerBlockRegistry
+from .hooks import ModelHook, StateManager
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class RollingKVCacheConfig:
+    local_attn_size: int = -1
+    num_sink_tokens: int = 1
+    frame_seq_length: int = 128
+    batch_size: int = 1
+    max_seq_length: int = 32760
+
+
+# One hook per attention layer
+class RollingKVCacheHook(ModelHook):
+    _is_stateful = True
+
+    def __init__(
+        self,
+        state_manager: StateManager,
+        batch_size: int,
+        max_seq_length: int,
+        num_sink_tokens: int,
+        frame_seq_length: int,
+        local_attn_size: int,
+    ):
+        self.state_manager = state_manager
+        self.batch_size = batch_size
+        self.sink_tokens = num_sink_tokens
+        if local_attn_size != -1:
+            self.max_seq_length = local_attn_size * frame_seq_length
+        else:
+            self.max_seq_length = max_seq_length
+        self._metadata = None
+        self.cache_initialized = False
+
+    def initialize_hook(self, module):
+        unwrapped_module = unwrap_module(module)
+        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
+        components = unwrapped_module.components()
+        if "transformer" not in components:
+            raise ValueError(
+                f"{unwrapped_module.__class__.__name__} has no `transformer` component and can't apply a rolling KV cache."
+            )
+
+        transformer = components["transformer"]
+        self.device = transformer.device
+        self.dtype = transformer.dtype
+        self.num_layers = len(transformer.blocks)
+        num_heads = transformer.config.num_heads
+        hidden_dim = transformer.config.dim
+        encoder_hidden_dim = transformer.config.encoder_dim  # TODO: what's the common config attribute name?
+        self.self_attn_kv_shape = [self.batch_size, self.max_seq_length, num_heads, hidden_dim // num_heads]
+        self.cross_attn_kv_shape = [self.batch_size, encoder_hidden_dim, num_heads, hidden_dim // num_heads]
+        return module
+
+    def lazy_initialize_cache(self, device: torch.device, dtype: torch.dtype):
+        """
+        Lazily initialize a per-GPU KV cache for the transformer's attention layers.
+        """
+        if not self.cache_initialized:
+            self.key_cache = torch.zeros(self.self_attn_kv_shape, device=device, dtype=dtype)
+            self.value_cache = torch.zeros(self.self_attn_kv_shape, device=device, dtype=dtype)
+            self.cross_key_cache = torch.zeros(self.cross_attn_kv_shape, device=device, dtype=dtype)
+            self.cross_value_cache = torch.zeros(self.cross_attn_kv_shape, device=device, dtype=dtype)
+            self.cache_initialized = True
+            self.global_end_index = 0
+            self.local_start_index = 0
+            self.local_end_index = 0
+        return self.key_cache, self.value_cache, self.cross_key_cache, self.cross_value_cache
+
+    def new_forward(self, module: Attention, *args, **kwargs):
+        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+        current_cache = self.lazy_initialize_cache(original_hidden_states.device, original_hidden_states.dtype)
+        kwargs["kv_cache"] = current_cache
+        output = self.fn_ref.original_forward(*args, **kwargs)
+        return output
+
+    def reset_cache(self, module):
+        if self.cache_initialized:
+            self.key_cache.zero_()
+            self.value_cache.zero_()
+            self.cross_key_cache.zero_()
+            self.cross_value_cache.zero_()
+            self.global_end_index = 0
+            self.local_end_index = 0
+            self.local_start_index = 0
+        return module
+
+    @torch.compiler.disable
+    def update(self, key_states: torch.Tensor, value_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Keys/values are cached with shape (batch, seq, heads, head_dim); assign the new tokens along dim 1
+        num_new_tokens = key_states.shape[1]
+        start_idx, end_idx = self.maybe_roll_back(num_new_tokens)
+        self.key_cache[:, start_idx:end_idx] = key_states
+        self.value_cache[:, start_idx:end_idx] = value_states
+        self.local_start_index += num_new_tokens
+        return key_states, value_states
+
+    @torch.compiler.disable
+    def maybe_roll_back(self, num_new_tokens: int):
+        if num_new_tokens + self.local_end_index > self.max_seq_length:
+            num_evicted_tokens = num_new_tokens + self.local_end_index - self.max_seq_length
+        else:
+            num_evicted_tokens = 0
+
+        # Skip `sink_tokens` and `num_evicted_tokens`.
Roll back cache by removing the evicted tokens + num_tokens_to_skip = self.sink_tokens + num_evicted_tokens + self.key_cache[:, self.sink_tokens :] = self.key_cache[:, num_tokens_to_skip:].clone() + self.value_cache[:, self.sink_tokens :] = self.value_cache[:, num_tokens_to_skip:].clone() + + self.local_start_index = self.local_start_index - num_evicted_tokens + self.local_end_index = self.local_start_index + num_new_tokens + return self.local_start_index, self.local_end_index diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 202e77fd197d..92bfaa74357b 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -96,6 +96,7 @@ _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] + _import_structure["transformers.transformer_krea"] = ["KreaTransformerModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] @@ -194,6 +195,7 @@ HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, Kandinsky5Transformer3DModel, + KreaTransformerModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 15408a4b15cc..070ca8961973 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -30,6 +30,7 @@ from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_hunyuanimage import HunyuanImageTransformer2DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel + from .transformer_krea import KreaTransformerModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_krea.py b/src/diffusers/models/transformers/transformer_krea.py new file mode 100644 index 000000000000..7940eefe5c44 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_krea.py @@ -0,0 +1,718 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
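The hook above keeps a fixed-size, per-layer KV cache: the first sink tokens stay pinned, and when a new block of keys/values would overflow `max_seq_length`, the oldest non-sink entries are shifted out. A minimal, self-contained sketch of that index arithmetic, using a hypothetical helper that is not part of this patch:

def roll_back_indices(write_end: int, num_new_tokens: int, max_seq_length: int):
    # How many of the oldest non-sink tokens must be evicted to fit the new block?
    num_evicted = max(write_end + num_new_tokens - max_seq_length, 0)
    start = write_end - num_evicted
    return start, start + num_new_tokens, num_evicted


# Toy walk-through: an 8-slot cache filled in blocks of 3 tokens.
write_end = 0
for block in range(4):
    start, end, evicted = roll_back_indices(write_end, num_new_tokens=3, max_seq_length=8)
    print(f"block {block}: write slots [{start}:{end}), evict {evicted} oldest non-sink tokens")
    write_end = end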
+ +import math +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _get_qkv_projections(attn: "KreaAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): + # encoder_hidden_states is only passed for cross-attention + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.fused_projections: + if attn.cross_attention_dim_head is None: + # In self-attention layers, we can fuse the entire QKV projection into a single linear + query, key, value = attn.qkv(hidden_states).chunk(3, dim=-1) + else: + # In cross-attention layers, we can only fuse the KV projections into a single linear + query = attn.to_q(hidden_states) + key, value = attn.kv(encoder_hidden_states).chunk(2, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + return query, key, value + + +def _get_added_kv_projections(attn: "KreaAttention", encoder_hidden_states_img: torch.Tensor): + if attn.fused_projections: + key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1) + else: + key_img = attn.add_k_proj(encoder_hidden_states_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + return key_img, value_img + + +class KreaAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "KreaAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." 
+ ) + + def __call__( + self, + attn: "KreaAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if rotary_emb is not None: + + def apply_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ): + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) + + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) + + # I2V task + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img) + key_img = attn.norm_added_k(key_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)) + value_img = value_img.unflatten(2, (attn.heads, -1)) + + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states_img = hidden_states_img.flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + hidden_states = attn.to_out(hidden_states) + return hidden_states + + +class KreaAttnProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = ( + "The KreaAttnProcessor2_0 class is deprecated and will be removed in a future version. " + "Please use KreaAttnProcessor instead. 
" + ) + deprecate("KreaAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) + return KreaAttnProcessor(*args, **kwargs) + + +class KreaAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = KreaAttnProcessor + _available_processors = [KreaAttnProcessor] + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-5, + dropout: float = 0.0, + added_kv_proj_dim: Optional[int] = None, + cross_attention_dim_head: Optional[int] = None, + processor=None, + is_cross_attention=None, + ): + super().__init__() + + self.inner_dim = dim_head * heads + self.heads = heads + self.added_kv_proj_dim = added_kv_proj_dim + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + + self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=True) + + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + + self.add_k_proj = self.add_v_proj = None + if added_kv_proj_dim is not None: + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps) + + self.is_cross_attention = cross_attention_dim_head is not None + + self.set_processor(processor) + + def fuse_projections(self): + if getattr(self, "fused_projections", False): + return + + if self.cross_attention_dim_head is None: + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.qkv = nn.Linear(in_features, out_features, bias=True) + self.qkv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.kv = nn.Linear(in_features, out_features, bias=True) + self.kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + if self.added_kv_proj_dim is not None: + concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) + concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_added_kv = nn.Linear(in_features, out_features, bias=True) + self.to_added_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + self.fused_projections = True + + @torch.no_grad() + def unfuse_projections(self): + if not getattr(self, "fused_projections", False): + return + + if hasattr(self, "qkv"): + delattr(self, "qkv") + if hasattr(self, "kv"): + delattr(self, "kv") + if hasattr(self, "to_added_kv"): + delattr(self, 
"to_added_kv") + + self.fused_projections = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs) + + +class KreaImageEmbedding(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): + super().__init__() + + self.norm1 = FP32LayerNorm(in_features) + self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") + self.norm2 = FP32LayerNorm(out_features) + if pos_embed_seq_len is not None: + self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features)) + else: + self.pos_embed = None + + def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: + if self.pos_embed is not None: + batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape + encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim) + encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed + + hidden_states = self.norm1(encoder_hidden_states_image) + hidden_states = self.ff(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + + +class KreaTimeTextImageEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embed_dim: Optional[int] = None, + pos_embed_seq_len: Optional[int] = None, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + + self.image_embedder = None + if image_embed_dim is not None: + self.image_embedder = KreaImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len) + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + timestep_seq_len: Optional[int] = None, + ): + timestep = self.timesteps_proj(timestep) + if timestep_seq_len is not None: + timestep = timestep.unflatten(0, (-1, timestep_seq_len)) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + + +class KreaRotaryPosEmbed(nn.Module): + def __init__( + self, + attention_head_dim: int, + patch_size: Tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len + + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + 
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + + freqs_cos = [] + freqs_sin = [] + + for dim in [t_dim, h_dim, w_dim]: + freq_cos, freq_sin = get_1d_rotary_pos_embed( + dim, + max_seq_len, + theta, + use_real=True, + repeat_interleave_real=True, + freqs_dtype=freqs_dtype, + ) + freqs_cos.append(freq_cos) + freqs_sin.append(freq_sin) + + self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) + self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + + split_sizes = [ + self.attention_head_dim - 2 * (self.attention_head_dim // 3), + self.attention_head_dim // 3, + self.attention_head_dim // 3, + ] + + freqs_cos = self.freqs_cos.split(split_sizes, dim=1) + freqs_sin = self.freqs_sin.split(split_sizes, dim=1) + + freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + + return freqs_cos, freqs_sin + + +@maybe_allow_in_graph +class KreaTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: Optional[int] = None, + ): + super().__init__() + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = KreaAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + cross_attention_dim_head=None, + processor=KreaAttnProcessor(), + ) + + # 2. Cross-attention + self.attn2 = KreaAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + cross_attention_dim_head=dim // num_heads, + processor=KreaAttnProcessor(), + ) + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + + # 3. 
Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") + # self.ffn = nn.Sequential( + # nn.Linear(dim, ffn_dim), + # nn.GELU(approximate="tanh"), + # nn.Linear(ffn_dim, dim), + # ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + ) -> torch.Tensor: + if temb.ndim == 4: + # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0) + temb.float() + ).chunk(6, dim=2) + # batch_size, seq_len, 1, inner_dim + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) + else: + # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=1) + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + + # 2. Cross-attention + norm_hidden_states = self.norm3(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm2(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( + hidden_states + ) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + + return hidden_states + + +class KreaTransformerModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): + r""" + A Transformer model for video-like data used in the Krea model. + + Args: + patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). + num_attention_heads (`int`, defaults to `40`): + Fixed length for text embeddings. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_dim (`int`, defaults to `512`): + Input dimension for text embeddings. + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. + ffn_dim (`int`, defaults to `13824`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `40`): + The number of layers of transformer blocks to use. + window_size (`Tuple[int]`, defaults to `(-1, -1)`): + Window size for local attention (-1 indicates global attention). + cross_attn_norm (`bool`, defaults to `True`): + Enable cross-attention normalization. + qk_norm (`bool`, defaults to `True`): + Enable query/key normalization. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + add_img_emb (`bool`, defaults to `False`): + Whether to use img_emb. 
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] + _no_split_modules = ["KreaTransformerBlock"] + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + _repeated_blocks = ["KreaTransformerBlock"] + _cp_plan = { + "rope": { + 0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True), + 1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True), + }, + "blocks.0": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "blocks.*": { + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + "": { + "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, + } + + @register_to_config + def __init__( + self, + patch_size: Tuple[int] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + qk_norm: Optional[str] = "rms_norm_across_heads", + eps: float = 1e-6, + image_dim: Optional[int] = None, + added_kv_proj_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: Optional[int] = None, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Patch & position embedding + self.rope = KreaRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + # 2. Condition embeddings + # image_embedding_dim=1280 for I2V model + self.condition_embedder = KreaTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + pos_embed_seq_len=pos_embed_seq_len, + ) + + # 3. Transformer blocks + self.blocks = nn.ModuleList( + [ + KreaTransformerBlock( + inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim + ) + for _ in range(num_layers) + ] + ) + + # 4. 
Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + rotary_emb = self.rope(hidden_states) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v) + if timestep.ndim == 2: + ts_seq_len = timestep.shape[1] + timestep = timestep.flatten() # batch_size * seq_len + else: + ts_seq_len = None + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len + ) + if ts_seq_len is not None: + # batch_size, seq_len, 6, inner_dim + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + else: + # batch_size, 6, inner_dim + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + else: + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + + # 5. Output norm, projection & unpatchify + if temb.ndim == 3: + # batch_size, seq_len, inner_dim (wan 2.2 ti2v) + shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2) + shift = shift.squeeze(2) + scale = scale.squeeze(2) + else: + # batch_size, inner_dim + shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. 
+ shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 87d953845e21..6623c9981e92 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -152,6 +152,7 @@ "AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel", ] + _import_structure["autoregressive_block_diffusion"] = ["BaseAutoregressiveDiffusionPipeline"] _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"] _import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"] _import_structure["cogvideo"] = [ @@ -562,6 +563,7 @@ AudioLDM2UNet2DConditionModel, ) from .aura_flow import AuraFlowPipeline + from .autoregressive_block_diffusion import BaseAutoregressiveDiffusionPipeline from .blip_diffusion import BlipDiffusionPipeline from .bria import BriaPipeline from .bria_fibo import BriaFiboPipeline diff --git a/src/diffusers/pipelines/autoregressive_block_diffusion/__init__.py b/src/diffusers/pipelines/autoregressive_block_diffusion/__init__.py new file mode 100644 index 000000000000..8c15630fcf40 --- /dev/null +++ b/src/diffusers/pipelines/autoregressive_block_diffusion/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline"] = ["BaseAutoregressiveDiffusionPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline import BaseAutoregressiveDiffusionPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py b/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py new file mode 100644 index 000000000000..243bfed0757b --- /dev/null +++ b/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py @@ -0,0 +1,771 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import regex as re +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, KreaTransformerModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import KreaPipelineOutput + + +if is_ftfy_available(): + import ftfy + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class BaseAutoregressiveDiffusionPipeline(DiffusionPipeline, WanLoraLoaderMixin):
+    r"""
+    Pipeline for text-to-video generation using hybrid autoregressive-diffusion models.
+
+    This pipeline inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKLWan`]):
+            Variational Auto-Encoder (VAE) model to encode and decode videos to and from latent representations.
+        text_encoder ([`UMT5EncoderModel`]):
+            Frozen text-encoder. The prompt is embedded with
+            [UMT5](https://huggingface.co/docs/transformers/model_doc/umt5#transformers.UMT5EncoderModel).
+        tokenizer (`AutoTokenizer`):
+            Tokenizer of class
+            [AutoTokenizer](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoTokenizer).
+        transformer ([`KreaTransformerModel`]):
+            A text-conditioned `KreaTransformerModel` to denoise the encoded video latents.
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: KreaTransformerModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + boundary_ratio: Optional[float] = None, + expand_timesteps: bool = False, # Wan2.2 ti2v + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.register_to_config(boundary_ratio=boundary_ratio) + self.register_to_config(expand_timesteps=expand_timesteps) + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). 
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.check_inputs + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + guidance_scale_2=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. 
Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if self.config.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_blocks: int = 9, + num_frames_per_block: int = 3, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = num_blocks * num_frames_per_block + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. 
Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def using_cache(self): + for module_name, module in self.transformer.named_modules(): + if module_name == "": + continue + if hasattr(module, "_diffusers_hook") and "rolling_kv_hook" in module._diffusers_hook.hooks.keys(): + return True + return False + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 6, + guidance_scale: float = 5.0, + guidance_scale_2: Optional[float] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + block_size: int = 9, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `6`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`KreaPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~KreaPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`KreaPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + if self.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = ( + self.transformer.config.in_channels + if self.transformer is not None + else self.transformer_2.config.in_channels + ) + + if num_frames % block_size != 0: + raise ValueError(f"Number of frames={num_frames} is not divisible by block size={block_size}") + num_blocks = num_frames // block_size + num_frames_per_block = num_frames // num_blocks + + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_blocks, + num_frames_per_block, + torch.float32, + device, + generator, + latents, + ) + + mask = torch.ones(latents.shape, dtype=torch.float32, device=device) + latent_h, latent_w = latents.shape[-2:] + + # 6. Denoising loop + # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + output = torch.zeros( + batch_size, num_channels_latents, num_frames, latent_h, latent_w, device=latents.device, dtype=latents.dtype + ) + with self.progress_bar(total=num_blocks) as progress_bar: + for block in range(num_blocks): + cache_start_frame = block * block_size if self.using_cache else 0 + current_latents = latents[:, :, cache_start_frame : cache_start_frame + block_size] + current_latents = current_latents.to(transformer_dtype) + denoised_latents = self.denoise_once( + current_latents, + timesteps, + mask, + prompt_embeds, + negative_prompt_embeds, + attention_kwargs, + guidance_scale, + callback_on_step_end, + callback_on_step_end_tensor_inputs, + ) + + self.scheduler._init_step_index(timesteps[0]) + output[:, :, cache_start_frame : cache_start_frame + block_size] = denoised_latents + # TODO: update cache and recompute with clean latents + with self.transformer.cache_context("cond"): + _ = self.transformer( + hidden_states=denoised_latents, + encoder_hidden_states=prompt_embeds, + timestep=torch.tensor([0] * batch_size, device=timesteps.device), + attention_kwargs=attention_kwargs, + return_dict=False, + ) + with self.transformer.cache_context("uncond"): + _ = self.transformer( + hidden_states=denoised_latents, + encoder_hidden_states=negative_prompt_embeds, + timestep=torch.tensor([0] * batch_size, device=timesteps.device), + attention_kwargs=attention_kwargs, + return_dict=False, + ) + progress_bar.update() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, 
self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return KreaPipelineOutput(frames=video) + + def denoise_once( + self, + latent_model_input, + timesteps, + mask, + prompt_embeds, + negative_prompt_embeds, + attention_kwargs, + guidance_scale, + callback_on_step_end, + callback_on_step_end_tensor_inputs, + ): + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + if self.config.expand_timesteps: + # seq_len: num_latent_frames * latent_height//2 * latent_width//2 + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + # batch_size, seq_len + timestep = temp_ts.unsqueeze(0).expand(latent_model_input.shape[0], -1) + else: + timestep = t.expand(latent_model_input.shape[0]) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latent_model_input, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if XLA_AVAILABLE: + xm.mark_step() + + return latents diff --git a/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline_output.py b/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline_output.py new file mode 100644 index 000000000000..69343672d5ab --- /dev/null +++ b/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class KreaPipelineOutput(BaseOutput): + r""" + Output class for Krea pipelines. 
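+
+    A minimal call sketch (the checkpoint id is a placeholder for illustration; only the pipeline class
+    and the `frames` attribute below are defined by this patch series):
+
+    ```py
+    >>> from diffusers import BaseAutoregressiveDiffusionPipeline
+
+    >>> pipe = BaseAutoregressiveDiffusionPipeline.from_pretrained("<placeholder>/krea-block-diffusion").to("cuda")
+    >>> out = pipe(prompt="a red panda drinking coffee")  # returns a `KreaPipelineOutput` when `return_dict=True`
+    >>> video = out.frames
+    ```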
+ + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor From 1bcec1a3956c696a2e485ac4ce9114d70b38035d Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 2 Dec 2025 10:09:13 +0100 Subject: [PATCH 2/6] a few updates --- src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/_helpers.py | 10 ++ src/diffusers/hooks/rolling_kv_cache.py | 115 +++++++++----- .../models/transformers/transformer_krea.py | 144 +++++++++--------- .../pipeline.py | 120 +++++++++------ .../scheduling_flow_match_euler_discrete.py | 4 +- 6 files changed, 234 insertions(+), 160 deletions(-) diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 524a92ea9966..ceaa7e211ed0 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -24,4 +24,5 @@ from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + from .rolling_kv_cache import RollingKVCacheConfig, apply_rolling_kv_cache from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index 790199f3c978..3ae3fd57d0c6 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -175,6 +175,7 @@ def _register_transformer_blocks_metadata(): HunyuanImageSingleTransformerBlock, HunyuanImageTransformerBlock, ) + from ..models.transformers.transformer_krea import KreaTransformerBlock from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock from ..models.transformers.transformer_mochi import MochiTransformerBlock from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock @@ -287,6 +288,15 @@ def _register_transformer_blocks_metadata(): ), ) + # Krea + TransformerBlockRegistry.register( + model_class=KreaTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + # QwenImage TransformerBlockRegistry.register( model_class=QwenImageTransformerBlock, diff --git a/src/diffusers/hooks/rolling_kv_cache.py b/src/diffusers/hooks/rolling_kv_cache.py index a7b0dd9a027f..3e7d82502cb0 100644 --- a/src/diffusers/hooks/rolling_kv_cache.py +++ b/src/diffusers/hooks/rolling_kv_cache.py @@ -19,19 +19,23 @@ from ..models.attention_processor import Attention from ..utils import get_logger from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS from ._helpers import TransformerBlockRegistry -from .hooks import ModelHook, StateManager +from .hooks import HookRegistry, ModelHook logger = get_logger(__name__) # pylint: disable=invalid-name +ROLLING_KV_CACHE_HOOK = "rolling_kv_cache_hook" + @dataclass -class RollingKVCacheCacheConfig: +class RollingKVCacheConfig: local_attn_size: int = -1 num_sink_tokens: int = 1 frame_seq_length: int = 128 batch_size: int = 1 + num_layers: int = None max_seq_length: int = 32760 @@ -41,16 +45,16 @@ class RollingKVCachekHook(ModelHook): def __init__( self, - state_manager: StateManager, batch_size: int, max_seq_length: int, 
num_sink_tokens: int, frame_seq_length: int, + num_layers: int, local_attn_size: int, ): - self.state_manager = state_manager self.batch_size = batch_size self.num_sink_tokens = num_sink_tokens + self.num_layers = num_layers if local_attn_size != -1: self.max_seq_length = local_attn_size * frame_seq_length else: @@ -61,21 +65,20 @@ def __init__( def initialize_hook(self, module): unwrapped_module = unwrap_module(module) self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) - components = unwrapped_module.components() - if "transformer" not in components: - raise ValueError( - f"{unwrapped_module.__class__.__name__} has no transformer block and can't apply a Rolling KV cache." - ) - transformer = components["transformer"] - self.dtype = transformer.device - self.device = transformer.dtype - self.num_layers = len(transformer.blocks) - num_heads = transformer.config.num_heads - hidden_dim = transformer.config.dim - encoder_hidden_dim = transformer.config.encoder_dim # whats the common name? - self.self_attn_kv_shape = [self.batch_size, self.max_seq_length, num_heads, hidden_dim // num_heads] - self.cross_attn_kv_shape = [self.batch_size, encoder_hidden_dim, num_heads, hidden_dim // num_heads] + # No access to config anymore from each transformer block? Would be great to get dims from config + self.self_attn_kv_shape = [ + self.batch_size, + self.max_seq_length, + module.num_heads, + module.dim // module.num_heads, + ] + self.cross_attn_kv_shape = [ + self.batch_size, + module.encoder_dim, + module.num_heads, + module.dim // module.num_heads, + ] return module def lazy_initialize_cache(self, device: str, dtype: torch.dtype): @@ -83,14 +86,11 @@ def lazy_initialize_cache(self, device: str, dtype: torch.dtype): Initialize a Per-GPU KV cache for the Wan model. 
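+
+        Each transformer block holds its own cache of shape `[batch_size, max_seq_length, num_heads,
+        head_dim]` (see `self_attn_kv_shape` in `initialize_hook`). When `local_attn_size != -1`, the
+        rolling window length is `local_attn_size * frame_seq_length` (for example, `local_attn_size=4`
+        with `frame_seq_length=128` keeps 512 positions); otherwise the configured `max_seq_length` is
+        used as-is.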
""" if not self.cache_initialized: - self.key_cache = torch.zeros(self.self_attn_kv_shape, device=device, dtype=dtype) - self.value_cache = torch.zeros(self.self_attn_kv_shape, device=device, dtype=dtype) - self.cross_key_cache = torch.zeros(self.cross_attn_kv_shape, device=device, dtype=dtype) - self.cross_value_cache = torch.zeros(self.cross_attn_kv_shape, device=device, dtype=dtype) + self.cache = CacheLayer( + self.num_sink_tokens, self.self_attn_kv_shape, self.cross_attn_kv_shape, device=device, dtype=dtype + ) self.cache_initialized = True - self.global_end_index - self.local_end_index - return self.key_cache, self.value_cache, self.cross_key_cache, self.cross_value_cache + return self.cache def new_forward(self, module: Attention, *args, **kwargs): original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) @@ -99,38 +99,73 @@ def new_forward(self, module: Attention, *args, **kwargs): output = self.fn_ref.original_forward(*args, **kwargs) return output - def reset_cache(self, module): + def reset_state(self, module): if self.cache_initialized: - self.key_cache.zero_() - self.value_cache.zero_() - self.cross_key_cache.zero_() - self.cross_value_cache.zero_() - self.global_end_index = 0 - self.local_end_index = 0 - self.local_start_index = 0 + self.cache.reset() return module + +class CacheLayer: + def __init__(self, num_sink_tokens, self_attn_kv_shape, cross_attn_kv_shape, device, dtype) -> None: + self.key_cache = torch.zeros(self_attn_kv_shape, device=device, dtype=dtype) + self.value_cache = torch.zeros(self_attn_kv_shape, device=device, dtype=dtype) + # self.cross_key_cache = torch.zeros(cross_attn_kv_shape, device=device, dtype=dtype) + # self.cross_value_cache = torch.zeros(cross_attn_kv_shape, device=device, dtype=dtype) + self.global_end_index = 0 + self.local_end_index = 0 + self.local_start_index = 0 + self.num_sink_tokens = num_sink_tokens + self.max_seq_length = self_attn_kv_shape[1] + + def reset(self): + self.key_cache.zero_() + self.value_cache.zero_() + # self.cross_key_cache.zero_() + # self.cross_value_cache.zero_() + self.global_end_index = 0 + self.local_end_index = 0 + self.local_start_index = 0 + @torch.compiler.disable def update(self, key_states: torch.Tensor, value_states: torch.Tensor) -> bool: # Assign new keys/values directly up to current_end - start_idx, end_idx = self.maybe_roll_back(key_states.shape[2]) + start_idx, end_idx = self.maybe_roll_back(key_states.shape[1]) self.key_cache[:, start_idx:end_idx] = key_states self.value_cache[:, start_idx:end_idx] = value_states - self.local_start_index += key_states.shape[0] + # self.local_start_index += key_states.shape[1] return key_states, value_states @torch.compiler.disable def maybe_roll_back(self, num_new_tokens: int): if num_new_tokens + self.local_end_index > self.max_seq_length: - num_evicted_tokens = self.max_seq_length - (num_new_tokens + self.local_end_index) + num_evicted_tokens = (num_new_tokens + self.local_end_index) - self.max_seq_length else: num_evicted_tokens = 0 - # Skip `sink_tokens` and `num_evicted_tokens`. Roll back cache by removing the evicted tokens - num_tokens_to_skip = self.sink_tokens + num_evicted_tokens - self.key_cache[:, self.sink_tokens :] = self.key_cache[:, num_tokens_to_skip:].clone() - self.value_cache[:, self.sink_tokens :] = self.value_cache[:, num_tokens_to_skip:].clone() + # Skip `sink_tokens` and `evicted_tokens`. 
Roll back cache by removing the evicted tokens + num_tokens_to_skip = self.num_sink_tokens + num_evicted_tokens + # self.key_cache[:, self.num_sink_tokens : self.num_sink_tokens + num_tokens_to_skip] = self.key_cache[:, num_tokens_to_skip:].clone() + # self.value_cache[:, self.num_sink_tokens : self.num_sink_tokens + num_tokens_to_skip] = self.value_cache[:, num_tokens_to_skip:].clone() + self.key_cache.roll(-num_tokens_to_skip, dims=1) + self.value_cache.roll(-num_tokens_to_skip, dims=1) self.local_start_index = self.local_start_index - num_evicted_tokens self.local_end_index = self.local_start_index + num_new_tokens return self.local_start_index, self.local_end_index + + +def apply_rolling_kv_cache(module: torch.nn.Module, config: RollingKVCacheConfig) -> None: + for name, submodule in module.named_children(): + if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + continue + for block in submodule: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = RollingKVCachekHook( + batch_size=config.batch_size, + max_seq_length=config.max_seq_length, + num_sink_tokens=config.num_sink_tokens, + frame_seq_length=config.frame_seq_length, + num_layers=config.num_layers, + local_attn_size=config.local_attn_size, + ) + registry.register_hook(hook, ROLLING_KV_CACHE_HOOK) diff --git a/src/diffusers/models/transformers/transformer_krea.py b/src/diffusers/models/transformers/transformer_krea.py index 7940eefe5c44..8b7156feccc4 100644 --- a/src/diffusers/models/transformers/transformer_krea.py +++ b/src/diffusers/models/transformers/transformer_krea.py @@ -65,6 +65,25 @@ def _get_added_kv_projections(attn: "KreaAttention", encoder_hidden_states_img: return key_img, value_img +class KreaRMSNorm(nn.Module): + """ + KreaRMSNorm is equivalent to LlamaRMSNorm + """ + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + return self.weight * hidden_states.to(input_dtype) + + class KreaAttnProcessor: _attention_backend = None _parallel_config = None @@ -82,6 +101,7 @@ def __call__( encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache: list[torch.Tensor] = None, ) -> torch.Tensor: encoder_hidden_states_img = None if attn.add_k_proj is not None: @@ -117,6 +137,9 @@ def apply_rotary_emb( query = apply_rotary_emb(query, *rotary_emb) key = apply_rotary_emb(key, *rotary_emb) + if kv_cache is not None: + key, value = kv_cache.update(key, value) + # I2V task hidden_states_img = None if encoder_hidden_states_img is not None: @@ -198,14 +221,14 @@ def __init__( self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=True) - self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) - self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_q = KreaRMSNorm(dim_head * heads, eps=eps) + self.norm_k = KreaRMSNorm(dim_head * heads, eps=eps) self.add_k_proj = self.add_v_proj = None if added_kv_proj_dim is not None: self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, 
bias=True) self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) - self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps) + self.norm_added_k = nn.RMSNorm(dim_head * heads, eps=eps) self.is_cross_attention = cross_attention_dim_head is not None @@ -266,9 +289,12 @@ def forward( encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache: list[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: - return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs) + return self.processor( + self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, kv_cache, **kwargs + ) class KreaImageEmbedding(torch.nn.Module): @@ -322,12 +348,8 @@ def forward( timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, - timestep_seq_len: Optional[int] = None, ): timestep = self.timesteps_proj(timestep) - if timestep_seq_len is not None: - timestep = timestep.unflatten(0, (-1, timestep_seq_len)) - time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: timestep = timestep.to(time_embedder_dtype) @@ -419,8 +441,12 @@ def __init__( ): super().__init__() + self.num_heads = num_heads + self.dim = dim + self.encoder_dim = dim + # 1. Self-attention - self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.norm1 = nn.LayerNorm(dim, eps, elementwise_affine=False) self.attn1 = KreaAttention( dim=dim, heads=num_heads, @@ -440,16 +466,11 @@ def __init__( cross_attention_dim_head=dim // num_heads, processor=KreaAttnProcessor(), ) - self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.norm3 = nn.LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() # 3. Feed-forward self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") - # self.ffn = nn.Sequential( - # nn.Linear(dim, ffn_dim), - # nn.GELU(approximate="tanh"), - # nn.Linear(ffn_dim, dim), - # ) - self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.norm2 = nn.LayerNorm(dim, eps, elementwise_affine=False) self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) @@ -459,41 +480,32 @@ def forward( encoder_hidden_states: torch.Tensor, temb: torch.Tensor, rotary_emb: torch.Tensor, + clean_latents: bool = False, + kv_cache: list[torch.Tensor] = None, ) -> torch.Tensor: - if temb.ndim == 4: - # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( - self.scale_shift_table.unsqueeze(0) + temb.float() - ).chunk(6, dim=2) - # batch_size, seq_len, 1, inner_dim - shift_msa = shift_msa.squeeze(2) - scale_msa = scale_msa.squeeze(2) - gate_msa = gate_msa.squeeze(2) - c_shift_msa = c_shift_msa.squeeze(2) - c_scale_msa = c_scale_msa.squeeze(2) - c_gate_msa = c_gate_msa.squeeze(2) - else: - # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( - self.scale_shift_table + temb.float() - ).chunk(6, dim=1) + # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb # .float() + ).chunk(6, dim=1) # 1. 
Self-attention - norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) - attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) - hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb, kv_cache) + hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states) # 2. Cross-attention - norm_hidden_states = self.norm3(hidden_states.float()).type_as(hidden_states) + norm_hidden_states = self.norm3(hidden_states).type_as(hidden_states) attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) hidden_states = hidden_states + attn_output # 3. Feed-forward - norm_hidden_states = (self.norm2(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( - hidden_states - ) + norm_hidden_states = (self.norm2(hidden_states) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states) ff_output = self.ffn(norm_hidden_states) - hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states) + + # Last: update the cache position once per denoising loop - when recomputing cache with clean latent + if clean_latents: + kv_cache.local_start_index += hidden_states.shape[1] return hidden_states @@ -540,7 +552,7 @@ class KreaTransformerModel( _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] _no_split_modules = ["KreaTransformerBlock"] - _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + # _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] _repeated_blocks = ["KreaTransformerBlock"] _cp_plan = { @@ -611,7 +623,7 @@ def __init__( ) # 4. Output norm & projection - self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.norm_out = nn.LayerNorm(inner_dim, eps, elementwise_affine=False) self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) @@ -623,6 +635,7 @@ def forward( timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, + clean_latents: bool = False, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: @@ -641,7 +654,7 @@ def forward( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
) - batch_size, num_channels, num_frames, height, width = hidden_states.shape + batch_size, _, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.config.patch_size post_patch_num_frames = num_frames // p_t post_patch_height = height // p_h @@ -652,22 +665,11 @@ def forward( hidden_states = self.patch_embedding(hidden_states) hidden_states = hidden_states.flatten(2).transpose(1, 2) - # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v) - if timestep.ndim == 2: - ts_seq_len = timestep.shape[1] - timestep = timestep.flatten() # batch_size * seq_len - else: - ts_seq_len = None - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( - timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len + timestep, encoder_hidden_states, encoder_hidden_states_image ) - if ts_seq_len is not None: - # batch_size, seq_len, 6, inner_dim - timestep_proj = timestep_proj.unflatten(2, (6, -1)) - else: - # batch_size, 6, inner_dim - timestep_proj = timestep_proj.unflatten(1, (6, -1)) + # batch_size, 6, inner_dim + timestep_proj = timestep_proj.unflatten(1, (6, -1)) if encoder_hidden_states_image is not None: encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) @@ -676,21 +678,25 @@ def forward( if torch.is_grad_enabled() and self.gradient_checkpointing: for block in self.blocks: hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + clean_latents, ) else: for block in self.blocks: - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + clean_latents, + ) # 5. Output norm, projection & unpatchify - if temb.ndim == 3: - # batch_size, seq_len, inner_dim (wan 2.2 ti2v) - shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2) - shift = shift.squeeze(2) - scale = scale.squeeze(2) - else: - # batch_size, inner_dim - shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) # Move the shift and scale tensors to the same device as hidden_states. # When using multi-GPU inference via accelerate these will be on the @@ -699,7 +705,7 @@ def forward( shift = shift.to(hidden_states.device) scale = scale.to(hidden_states.device) - hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( diff --git a/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py b/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py index 243bfed0757b..4b2504215ca4 100644 --- a/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py +++ b/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py @@ -152,9 +152,9 @@ class BaseAutoregressiveDiffusionPipeline(DiffusionPipeline, WanLoraLoaderMixin) tokenizer (`T5Tokenizer`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). 
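+
+    The denoising loop picks its attention strategy based on whether rolling KV cache hooks are attached
+    to the transformer: without hooks it builds an explicit blockwise-causal mask and re-feeds the past
+    latents, with hooks it relies on the per-block cache. A rough sketch of enabling the cache path
+    (argument values are illustrative, not tuned defaults):
+
+    ```py
+    >>> from diffusers.hooks import RollingKVCacheConfig, apply_rolling_kv_cache
+
+    >>> apply_rolling_kv_cache(pipe.transformer, RollingKVCacheConfig(local_attn_size=4, frame_seq_length=128))
+    >>> pipe.using_cache()  # True once the hooks are registered on each transformer block
+    ```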
- transformer ([`UniPCMultistepScheduler`]): + transformer ([`FlowMatchEulerDiscreteScheduler`]): A text conditioned `KreaTransformerModel` to denoise the encoded video latents. - scheduler ([`UniPCMultistepScheduler`]): + scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded video latents. """ @@ -438,10 +438,21 @@ def using_cache(self): for module_name, module in self.transformer.named_modules(): if module_name == "": continue - if hasattr(module, "_diffusers_hook") and "rolling_kv_hook" in module._diffusers_hook.hooks.keys(): + if hasattr(module, "_diffusers_hook") and "rolling_kv_cache_hook" in module._diffusers_hook.hooks.keys(): return True return False + def prepare_blockwise_mask(self, batch_size, kv_length, block_size, dtype, device): + idx = torch.arange(kv_length, device=device) + blocks = idx // block_size + + # blockwise permission: j’s block <= i’s block + block_mask = blocks[:, None] >= blocks[None, :] + block_mask = block_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + inverted_mask = 1.0 - block_mask.int() + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + @torch.no_grad() def __call__( self, @@ -452,7 +463,6 @@ def __call__( num_frames: int = 81, num_inference_steps: int = 6, guidance_scale: float = 5.0, - guidance_scale_2: Optional[float] = None, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -550,7 +560,6 @@ def __call__( prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, - guidance_scale_2, ) if num_frames % self.vae_scale_factor_temporal != 1: @@ -560,11 +569,7 @@ def __call__( num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) - if self.config.boundary_ratio is not None and guidance_scale_2 is None: - guidance_scale_2 = guidance_scale - self._guidance_scale = guidance_scale - self._guidance_scale_2 = guidance_scale_2 self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False @@ -628,54 +633,78 @@ def __call__( mask = torch.ones(latents.shape, dtype=torch.float32, device=device) latent_h, latent_w = latents.shape[-2:] - # 6. Denoising loop + # 7. 
Denoising loop # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) output = torch.zeros( - batch_size, num_channels_latents, num_frames, latent_h, latent_w, device=latents.device, dtype=latents.dtype + batch_size, + num_channels_latents, + num_frames, + latent_h, + latent_w, + device=latents.device, + dtype=latents.dtype, ) with self.progress_bar(total=num_blocks) as progress_bar: for block in range(num_blocks): - cache_start_frame = block * block_size if self.using_cache else 0 + cache_start_frame = block * block_size current_latents = latents[:, :, cache_start_frame : cache_start_frame + block_size] + kv_length = cache_start_frame + block_size + if not self.using_cache(): + # If the transformer has no cache hooks attached, we need a causal blockwise mask + # and concatenate past latents with the inputs + current_latents = torch.cat([output[:, :, :cache_start_frame], current_latents], dim=2) + attention_kwargs["attention_mask"] = self.prepare_blockwise_mask( + batch_size=batch_size, + kv_length=kv_length, + block_size=block_size, + dtype=transformer_dtype, + device=current_latents.device, + ) + current_latents = current_latents.to(transformer_dtype) - denoised_latents = self.denoise_once( - current_latents, - timesteps, - mask, - prompt_embeds, - negative_prompt_embeds, - attention_kwargs, - guidance_scale, - callback_on_step_end, - callback_on_step_end_tensor_inputs, + denoised_latents = self.denoise_single_block( + latent_model_input=current_latents, + timesteps=timesteps, + mask=mask, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + guidance_scale=guidance_scale, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) self.scheduler._init_step_index(timesteps[0]) - output[:, :, cache_start_frame : cache_start_frame + block_size] = denoised_latents + output[:, :, cache_start_frame : cache_start_frame + block_size] = denoised_latents[:, :, -block_size:] + # TODO: update cache and recompute with clean latents - with self.transformer.cache_context("cond"): - _ = self.transformer( - hidden_states=denoised_latents, - encoder_hidden_states=prompt_embeds, - timestep=torch.tensor([0] * batch_size, device=timesteps.device), - attention_kwargs=attention_kwargs, - return_dict=False, - ) - with self.transformer.cache_context("uncond"): - _ = self.transformer( - hidden_states=denoised_latents, - encoder_hidden_states=negative_prompt_embeds, - timestep=torch.tensor([0] * batch_size, device=timesteps.device), - attention_kwargs=attention_kwargs, - return_dict=False, - ) + if self.using_cache(): + with self.transformer.cache_context("cond"): + _ = self.transformer( + hidden_states=denoised_latents, + encoder_hidden_states=prompt_embeds, + timestep=torch.tensor([0] * batch_size, device=timesteps.device), + attention_kwargs=attention_kwargs, + clean_latents=True, + return_dict=False, + ) + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + _ = self.transformer( + hidden_states=denoised_latents, + encoder_hidden_states=negative_prompt_embeds, + timestep=torch.tensor([0] * batch_size, device=timesteps.device), + attention_kwargs=attention_kwargs, + clean_latents=True, + return_dict=False, + ) progress_bar.update() self._current_timestep = None - if not output_type == "latent": + if output_type != "latent": latents = latents.to(self.vae.dtype) latents_mean = ( 
torch.tensor(self.vae.config.latents_mean) @@ -699,7 +728,7 @@ def __call__( return KreaPipelineOutput(frames=video) - def denoise_once( + def denoise_single_block( self, latent_model_input, timesteps, @@ -716,13 +745,7 @@ def denoise_once( continue self._current_timestep = t - if self.config.expand_timesteps: - # seq_len: num_latent_frames * latent_height//2 * latent_width//2 - temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() - # batch_size, seq_len - timestep = temp_ts.unsqueeze(0).expand(latent_model_input.shape[0], -1) - else: - timestep = t.expand(latent_model_input.shape[0]) + timestep = t.expand(latent_model_input.shape[0]) with self.transformer.cache_context("cond"): noise_pred = self.transformer( @@ -732,7 +755,6 @@ def denoise_once( attention_kwargs=attention_kwargs, return_dict=False, )[0] - if self.do_classifier_free_guidance: with self.transformer.cache_context("uncond"): noise_uncond = self.transformer( diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 1a4f12ddfa53..f20151f8f981 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -114,7 +114,7 @@ def __init__( if time_shift_type not in {"exponential", "linear"}: raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.") - timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = np.linspace(0, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) sigmas = timesteps / num_train_timesteps @@ -315,7 +315,7 @@ def set_timesteps( sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value - if self.config.shift_terminal: + if self.config.shift_terminal is not None: sigmas = self.stretch_shift_to_terminal(sigmas) # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules From dd5dfdbc64da911bd153f66049e96af56b04c845 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 3 Dec 2025 15:38:14 +0100 Subject: [PATCH 3/6] issues with naming went unnoticed and I wasted too much time :( --- .../pipeline.py | 34 ++++++++----------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py b/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py index 4b2504215ca4..bc7c572fc2bc 100644 --- a/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py +++ b/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py @@ -562,14 +562,8 @@ def __call__( callback_on_step_end_tensor_inputs, ) - if num_frames % self.vae_scale_factor_temporal != 1: - logger.warning( - f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." 
- ) - num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 - num_frames = max(num_frames, 1) - self._guidance_scale = guidance_scale + attention_kwargs = attention_kwargs if attention_kwargs is not None else {} self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False @@ -665,13 +659,14 @@ def __call__( current_latents = current_latents.to(transformer_dtype) denoised_latents = self.denoise_single_block( - latent_model_input=current_latents, + latents=current_latents, timesteps=timesteps, mask=mask, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, attention_kwargs=attention_kwargs, guidance_scale=guidance_scale, + generator=generator, callback_on_step_end=callback_on_step_end, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) @@ -705,20 +700,20 @@ def __call__( self._current_timestep = None if output_type != "latent": - latents = latents.to(self.vae.dtype) + output = output.to(self.vae.dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) + .to(output.device, output.dtype) ) latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype + output.device, output.dtype ) - latents = latents / latents_std + latents_mean - video = self.vae.decode(latents, return_dict=False)[0] + output = output / latents_std + latents_mean + video = self.vae.decode(output, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) else: - video = latents + video = output # Offload all models self.maybe_free_model_hooks() @@ -730,13 +725,14 @@ def __call__( def denoise_single_block( self, - latent_model_input, + latents, timesteps, mask, prompt_embeds, negative_prompt_embeds, attention_kwargs, guidance_scale, + generator, callback_on_step_end, callback_on_step_end_tensor_inputs, ): @@ -745,11 +741,11 @@ def denoise_single_block( continue self._current_timestep = t - timestep = t.expand(latent_model_input.shape[0]) + timestep = t.expand(latents.shape[0]) with self.transformer.cache_context("cond"): noise_pred = self.transformer( - hidden_states=latent_model_input, + hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, attention_kwargs=attention_kwargs, @@ -758,7 +754,7 @@ def denoise_single_block( if self.do_classifier_free_guidance: with self.transformer.cache_context("uncond"): noise_uncond = self.transformer( - hidden_states=latent_model_input, + hidden_states=latents, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, attention_kwargs=attention_kwargs, @@ -767,7 +763,7 @@ def denoise_single_block( noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latent_model_input, return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents, generator=generator, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} From 70603a764575791e33c41bd8a61c24c45a84411e Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 3 Dec 2025 15:39:14 +0100 Subject: [PATCH 4/6] add optional generator + compute in double --- .../schedulers/scheduling_flow_match_euler_discrete.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py 
b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index f20151f8f981..623c9e2f9faa 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -454,8 +454,13 @@ def step( dt = sigma_next - sigma if self.config.stochastic_sampling: - x0 = sample - current_sigma * model_output - noise = torch.randn_like(sample) + x0 = (sample.double() - current_sigma.double() * model_output.double()).to(sample.dtype) + noise = torch.randn( + sample.shape, + device=sample.device, + dtype=sample.dtype, + generator=generator, + ) prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise else: prev_sample = sample + dt * model_output From fecb75cd5ed3d3944e938fb953ba45935f5499e4 Mon Sep 17 00:00:00 2001 From: raushan Date: Sat, 6 Dec 2025 19:25:04 +0100 Subject: [PATCH 5/6] commit before i refactor and clean --- src/diffusers/hooks/rolling_kv_cache.py | 112 ++++++++++++++---- .../models/transformers/transformer_krea.py | 48 ++++++-- .../pipeline.py | 12 +- 3 files changed, 134 insertions(+), 38 deletions(-) diff --git a/src/diffusers/hooks/rolling_kv_cache.py b/src/diffusers/hooks/rolling_kv_cache.py index 3e7d82502cb0..cb688c4d305b 100644 --- a/src/diffusers/hooks/rolling_kv_cache.py +++ b/src/diffusers/hooks/rolling_kv_cache.py @@ -51,10 +51,12 @@ def __init__( frame_seq_length: int, num_layers: int, local_attn_size: int, + layer_idx: int = None, ): self.batch_size = batch_size self.num_sink_tokens = num_sink_tokens self.num_layers = num_layers + self.layer_idx = layer_idx if local_attn_size != -1: self.max_seq_length = local_attn_size * frame_seq_length else: @@ -87,7 +89,12 @@ def lazy_initialize_cache(self, device: str, dtype: torch.dtype): """ if not self.cache_initialized: self.cache = CacheLayer( - self.num_sink_tokens, self.self_attn_kv_shape, self.cross_attn_kv_shape, device=device, dtype=dtype + self.num_sink_tokens, + self.self_attn_kv_shape, + self.cross_attn_kv_shape, + self.layer_idx, + device=device, + dtype=dtype, ) self.cache_initialized = True return self.cache @@ -106,14 +113,16 @@ def reset_state(self, module): class CacheLayer: - def __init__(self, num_sink_tokens, self_attn_kv_shape, cross_attn_kv_shape, device, dtype) -> None: + def __init__(self, num_sink_tokens, self_attn_kv_shape, cross_attn_kv_shape, layer_idx, device, dtype) -> None: self.key_cache = torch.zeros(self_attn_kv_shape, device=device, dtype=dtype) self.value_cache = torch.zeros(self_attn_kv_shape, device=device, dtype=dtype) # self.cross_key_cache = torch.zeros(cross_attn_kv_shape, device=device, dtype=dtype) # self.cross_value_cache = torch.zeros(cross_attn_kv_shape, device=device, dtype=dtype) + self.layer_idx = layer_idx self.global_end_index = 0 - self.local_end_index = 0 + self.global_start_index = 0 self.local_start_index = 0 + self.local_end_index = 0 self.num_sink_tokens = num_sink_tokens self.max_seq_length = self_attn_kv_shape[1] @@ -123,42 +132,98 @@ def reset(self): # self.cross_key_cache.zero_() # self.cross_value_cache.zero_() self.global_end_index = 0 - self.local_end_index = 0 + self.global_start_index = 0 self.local_start_index = 0 + self.local_end_index = 0 @torch.compiler.disable - def update(self, key_states: torch.Tensor, value_states: torch.Tensor) -> bool: - # Assign new keys/values directly up to current_end - start_idx, end_idx = self.maybe_roll_back(key_states.shape[1]) - self.key_cache[:, start_idx:end_idx] = key_states - self.value_cache[:, start_idx:end_idx] = value_states - # 
self.local_start_index += key_states.shape[1] - return key_states, value_states + def update(self, key_states: torch.Tensor, value_states: torch.Tensor, rotary_emb: tuple[torch.Tensor, torch.Tensor], roll=None) -> bool: + # Assign new keys/values directly up to current_end and update running cache positions + self.local_start_index = self.maybe_roll_back(key_states.shape[1], rotary_emb=rotary_emb, roll=roll) + self.local_end_index = self.local_start_index + key_states.shape[1] + + self.key_cache[:, self.local_start_index : self.local_end_index].copy_(key_states) + self.value_cache[:, self.local_start_index : self.local_end_index].copy_(value_states) + return self.key_cache[:, : self.local_end_index], self.value_cache[:, : self.local_end_index] @torch.compiler.disable - def maybe_roll_back(self, num_new_tokens: int): + def maybe_roll_back(self, num_new_tokens: int, rotary_emb: tuple[torch.Tensor, torch.Tensor], roll=None): + # Skip `sink_tokens` and `evicted_tokens`. Roll back cache by removing evicted tokens if num_new_tokens + self.local_end_index > self.max_seq_length: num_evicted_tokens = (num_new_tokens + self.local_end_index) - self.max_seq_length + num_evicted_tokens_global = (num_new_tokens + self.global_start_index) - self.max_seq_length + # num_tokens_to_skip = self.num_sink_tokens + num_evicted_tokens + + if roll: + keys_to_keep = self.key_cache[:, self.num_sink_tokens + num_evicted_tokens :] + keys_to_keep = self.rerotate_key_rotary_pos_emb(keys_to_keep, *rotary_emb, num_evicted_tokens) + + # tail_keys = self.key_cache[:, self.num_sink_tokens :].clone() + # tail_keys = tail_keys.roll(-num_evicted_tokens, dims=1) + self.key_cache[:, self.num_sink_tokens : -num_evicted_tokens].copy_(keys_to_keep) + + tail_values = self.value_cache[:, self.num_sink_tokens :].clone() + tail_values = tail_values.roll(-num_evicted_tokens, dims=1) + self.value_cache[:, self.num_sink_tokens :].copy_(tail_values) else: - num_evicted_tokens = 0 + num_evicted_tokens_global = num_evicted_tokens = 0 - # Skip `sink_tokens` and `evicted_tokens`. 
Roll back cache by removing the evicted tokens - num_tokens_to_skip = self.num_sink_tokens + num_evicted_tokens - # self.key_cache[:, self.num_sink_tokens : self.num_sink_tokens + num_tokens_to_skip] = self.key_cache[:, num_tokens_to_skip:].clone() - # self.value_cache[:, self.num_sink_tokens : self.num_sink_tokens + num_tokens_to_skip] = self.value_cache[:, num_tokens_to_skip:].clone() - self.key_cache.roll(-num_tokens_to_skip, dims=1) - self.value_cache.roll(-num_tokens_to_skip, dims=1) + start_idx = self.global_start_index - num_evicted_tokens_global - self.local_start_index = self.local_start_index - num_evicted_tokens - self.local_end_index = self.local_start_index + num_new_tokens - return self.local_start_index, self.local_end_index + if self.layer_idx == 0: + print( + num_evicted_tokens_global, num_evicted_tokens, self.global_start_index, self.global_end_index, self.local_start_index, self.local_end_index, start_idx + ) + return start_idx + + def recompute_indices(self, num_new_tokens: int): + self.global_start_index = self.global_start_index + num_new_tokens + self.global_end_index = self.global_start_index + num_new_tokens + self.local_start_index = min(self.max_seq_length - num_new_tokens, self.local_start_index + num_new_tokens) + + @staticmethod + def _apply_rope(hidden_states, freqs_cos, freqs_sin): + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) + + def rerotate_key_rotary_pos_emb( + self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, num_evicted_tokens: int + ) -> tuple[torch.Tensor, torch.Tensor]: + # Compute the cos and sin required for back- and forward-rotating to `num_evicted_tokens` position earlier in the sequence + original_cos = cos[:, self.num_sink_tokens + num_evicted_tokens : self.num_sink_tokens + num_evicted_tokens + key_states.shape[1]] + shifted_cos = cos[:, self.num_sink_tokens : self.num_sink_tokens + key_states.shape[1]] + original_sin = sin[:, self.num_sink_tokens + num_evicted_tokens : self.num_sink_tokens + num_evicted_tokens + key_states.shape[1]] + shifted_sin = sin[:, self.num_sink_tokens : self.num_sink_tokens + key_states.shape[1]] + rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin + rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin + + # rerotation_cos = cos[:, num_evicted_tokens] + # rerotation_sin = -sin[:, num_evicted_tokens] + rotated_key_states = self._apply_rope(key_states, rerotation_cos, rerotation_sin) + + # if self.layer_idx == 0: + # print(key_states.shape, num_evicted_tokens, original_cos.shape, shifted_cos.shape) + return rotated_key_states + + def get_rerotated_num_positions(self): + if (4680 * 2) + self.global_start_index > self.max_seq_length: + num_evicted_tokens = ((4680 * 2) + self.global_start_index) - self.max_seq_length + num_tokens_to_rotate_back = num_evicted_tokens - self.num_sink_tokens + else: + num_tokens_to_rotate_back = 0 + return num_tokens_to_rotate_back def apply_rolling_kv_cache(module: torch.nn.Module, config: RollingKVCacheConfig) -> None: for name, submodule in module.named_children(): if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): continue - for block in submodule: + for i, block in enumerate(submodule): registry = HookRegistry.check_if_exists_or_initialize(block) hook = 
RollingKVCachekHook( batch_size=config.batch_size, @@ -167,5 +232,6 @@ def apply_rolling_kv_cache(module: torch.nn.Module, config: RollingKVCacheConfig frame_seq_length=config.frame_seq_length, num_layers=config.num_layers, local_attn_size=config.local_attn_size, + layer_idx=i, ) registry.register_hook(hook, ROLLING_KV_CACHE_HOOK) diff --git a/src/diffusers/models/transformers/transformer_krea.py b/src/diffusers/models/transformers/transformer_krea.py index 8b7156feccc4..901bf6f93214 100644 --- a/src/diffusers/models/transformers/transformer_krea.py +++ b/src/diffusers/models/transformers/transformer_krea.py @@ -102,6 +102,8 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, kv_cache: list[torch.Tensor] = None, + clean_latents: bool = None, + timestep = None, ) -> torch.Tensor: encoder_hidden_states_img = None if attn.add_k_proj is not None: @@ -125,7 +127,12 @@ def apply_rotary_emb( hidden_states: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, + kv_cache: torch.Tensor, ): + block_length = hidden_states.shape[1] + start_index = kv_cache.local_start_index if kv_cache is not None else 0 + freqs_cos = freqs_cos[:, start_index : start_index + block_length] + freqs_sin = freqs_sin[:, start_index : start_index + block_length] x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) cos = freqs_cos[..., 0::2] sin = freqs_sin[..., 1::2] @@ -134,11 +141,16 @@ def apply_rotary_emb( out[..., 1::2] = x1 * sin + x2 * cos return out.type_as(hidden_states) - query = apply_rotary_emb(query, *rotary_emb) - key = apply_rotary_emb(key, *rotary_emb) + # if clean_latents and kv_cache.local_start_index != 0 and kv_cache.layer_idx == 0: + # keys_with_rope_correct = apply_rotary_emb(key, *rotary_emb, None) + # print("What I need", keys_with_rope_correct.shape, keys_with_rope_correct[:, -3120:]) + + query = apply_rotary_emb(query, *rotary_emb, kv_cache) + key = apply_rotary_emb(key, *rotary_emb, kv_cache) if kv_cache is not None: - key, value = kv_cache.update(key, value) + roll = timestep.flatten() == 1000 + key, value = kv_cache.update(key, value, rotary_emb, roll) # I2V task hidden_states_img = None @@ -202,11 +214,9 @@ def __init__( heads: int = 8, dim_head: int = 64, eps: float = 1e-5, - dropout: float = 0.0, added_kv_proj_dim: Optional[int] = None, cross_attention_dim_head: Optional[int] = None, processor=None, - is_cross_attention=None, ): super().__init__() @@ -399,8 +409,9 @@ def __init__( self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, num_channels, num_frames, height, width = hidden_states.shape + def forward(self, hidden_states: torch.Tensor, max_num_frames: int | None = None) -> torch.Tensor: + _, _, num_frames, height, width = hidden_states.shape + num_frames = max_num_frames if max_num_frames is not None else num_frames p_t, p_h, p_w = self.patch_size ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w @@ -481,6 +492,7 @@ def forward( temb: torch.Tensor, rotary_emb: torch.Tensor, clean_latents: bool = False, + timestep = None, kv_cache: list[torch.Tensor] = None, ) -> torch.Tensor: # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) @@ -488,9 +500,22 @@ def forward( self.scale_shift_table + temb # .float() ).chunk(6, dim=1) + # NOTE: ideally if we can pass `position_ids` to the model's forward + # and use 
it to compute cos/sin. Similar to simply autoregressive + # LLMs For now let's initialize one big cos/sin array and slice it + # for current frame position + block_length = hidden_states.shape[1] + start_index = kv_cache.local_start_index if kv_cache is not None else 0 + # rotary_emb = ( + # rotary_emb[0][:, start_index : start_index + block_length], + # rotary_emb[1][:, start_index : start_index + block_length], + # ) + # if kv_cache.layer_idx == 0: + # print(clean_latents, start_index, kv_cache.key_cache[:, :4680]) + # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).type_as(hidden_states) - attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb, kv_cache) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb, kv_cache, clean_latents=clean_latents, timestep=timestep) hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states) # 2. Cross-attention @@ -505,7 +530,7 @@ def forward( # Last: update the cache position once per denoising loop - when recomputing cache with clean latent if clean_latents: - kv_cache.local_start_index += hidden_states.shape[1] + kv_cache.recompute_indices(num_new_tokens=hidden_states.shape[1]) return hidden_states @@ -635,6 +660,7 @@ def forward( timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, + max_num_frames: Optional[int] = None, clean_latents: bool = False, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -660,7 +686,7 @@ def forward( post_patch_height = height // p_h post_patch_width = width // p_w - rotary_emb = self.rope(hidden_states) + rotary_emb = self.rope(hidden_states, max_num_frames=max_num_frames) hidden_states = self.patch_embedding(hidden_states) hidden_states = hidden_states.flatten(2).transpose(1, 2) @@ -684,6 +710,7 @@ def forward( timestep_proj, rotary_emb, clean_latents, + timestep, ) else: for block in self.blocks: @@ -693,6 +720,7 @@ def forward( timestep_proj, rotary_emb, clean_latents, + timestep, ) # 5. Output norm, projection & unpatchify diff --git a/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py b/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py index bc7c572fc2bc..41619bd2330a 100644 --- a/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py +++ b/src/diffusers/pipelines/autoregressive_block_diffusion/pipeline.py @@ -624,7 +624,6 @@ def __call__( latents, ) - mask = torch.ones(latents.shape, dtype=torch.float32, device=device) latent_h, latent_w = latents.shape[-2:] # 7. 
Denoising loop @@ -661,12 +660,12 @@ def __call__( denoised_latents = self.denoise_single_block( latents=current_latents, timesteps=timesteps, - mask=mask, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, attention_kwargs=attention_kwargs, guidance_scale=guidance_scale, generator=generator, + max_num_frames=num_frames, callback_on_step_end=callback_on_step_end, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) @@ -674,14 +673,14 @@ def __call__( self.scheduler._init_step_index(timesteps[0]) output[:, :, cache_start_frame : cache_start_frame + block_size] = denoised_latents[:, :, -block_size:] - # TODO: update cache and recompute with clean latents - if self.using_cache(): + if self.using_cache() and block != num_blocks - 1: with self.transformer.cache_context("cond"): _ = self.transformer( hidden_states=denoised_latents, encoder_hidden_states=prompt_embeds, timestep=torch.tensor([0] * batch_size, device=timesteps.device), attention_kwargs=attention_kwargs, + max_num_frames=num_frames, clean_latents=True, return_dict=False, ) @@ -692,6 +691,7 @@ def __call__( encoder_hidden_states=negative_prompt_embeds, timestep=torch.tensor([0] * batch_size, device=timesteps.device), attention_kwargs=attention_kwargs, + max_num_frames=num_frames, clean_latents=True, return_dict=False, ) @@ -727,12 +727,12 @@ def denoise_single_block( self, latents, timesteps, - mask, prompt_embeds, negative_prompt_embeds, attention_kwargs, guidance_scale, generator, + max_num_frames, callback_on_step_end, callback_on_step_end_tensor_inputs, ): @@ -749,6 +749,7 @@ def denoise_single_block( timestep=timestep, encoder_hidden_states=prompt_embeds, attention_kwargs=attention_kwargs, + max_num_frames=max_num_frames, return_dict=False, )[0] if self.do_classifier_free_guidance: @@ -758,6 +759,7 @@ def denoise_single_block( timestep=timestep, encoder_hidden_states=negative_prompt_embeds, attention_kwargs=attention_kwargs, + max_num_frames=max_num_frames, return_dict=False, )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) From 0bd80ed670c8353ae3c677fca88d04a4cb634bbd Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 8 Dec 2025 18:44:47 +0100 Subject: [PATCH 6/6] this works by tracking the cache position internally in each layer --- src/diffusers/hooks/rolling_kv_cache.py | 88 +++++++------------ .../models/transformers/transformer_krea.py | 31 ++----- 2 files changed, 37 insertions(+), 82 deletions(-) diff --git a/src/diffusers/hooks/rolling_kv_cache.py b/src/diffusers/hooks/rolling_kv_cache.py index cb688c4d305b..e0daa6933397 100644 --- a/src/diffusers/hooks/rolling_kv_cache.py +++ b/src/diffusers/hooks/rolling_kv_cache.py @@ -119,10 +119,7 @@ def __init__(self, num_sink_tokens, self_attn_kv_shape, cross_attn_kv_shape, lay # self.cross_key_cache = torch.zeros(cross_attn_kv_shape, device=device, dtype=dtype) # self.cross_value_cache = torch.zeros(cross_attn_kv_shape, device=device, dtype=dtype) self.layer_idx = layer_idx - self.global_end_index = 0 - self.global_start_index = 0 self.local_start_index = 0 - self.local_end_index = 0 self.num_sink_tokens = num_sink_tokens self.max_seq_length = self_attn_kv_shape[1] @@ -131,55 +128,33 @@ def reset(self): self.value_cache.zero_() # self.cross_key_cache.zero_() # self.cross_value_cache.zero_() - self.global_end_index = 0 - self.global_start_index = 0 self.local_start_index = 0 - self.local_end_index = 0 @torch.compiler.disable - def update(self, key_states: torch.Tensor, value_states: torch.Tensor, rotary_emb: 
tuple[torch.Tensor, torch.Tensor], roll=None) -> bool: - # Assign new keys/values directly up to current_end and update running cache positions - self.local_start_index = self.maybe_roll_back(key_states.shape[1], rotary_emb=rotary_emb, roll=roll) - self.local_end_index = self.local_start_index + key_states.shape[1] - - self.key_cache[:, self.local_start_index : self.local_end_index].copy_(key_states) - self.value_cache[:, self.local_start_index : self.local_end_index].copy_(value_states) - return self.key_cache[:, : self.local_end_index], self.value_cache[:, : self.local_end_index] - - @torch.compiler.disable - def maybe_roll_back(self, num_new_tokens: int, rotary_emb: tuple[torch.Tensor, torch.Tensor], roll=None): + def update( + self, key_states: torch.Tensor, value_states: torch.Tensor, rotary_emb: tuple[torch.Tensor, torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + num_new_tokens = key_states.shape[1] # Skip `sink_tokens` and `evicted_tokens`. Roll back cache by removing evicted tokens - if num_new_tokens + self.local_end_index > self.max_seq_length: - num_evicted_tokens = (num_new_tokens + self.local_end_index) - self.max_seq_length - num_evicted_tokens_global = (num_new_tokens + self.global_start_index) - self.max_seq_length - # num_tokens_to_skip = self.num_sink_tokens + num_evicted_tokens - - if roll: - keys_to_keep = self.key_cache[:, self.num_sink_tokens + num_evicted_tokens :] - keys_to_keep = self.rerotate_key_rotary_pos_emb(keys_to_keep, *rotary_emb, num_evicted_tokens) - - # tail_keys = self.key_cache[:, self.num_sink_tokens :].clone() - # tail_keys = tail_keys.roll(-num_evicted_tokens, dims=1) - self.key_cache[:, self.num_sink_tokens : -num_evicted_tokens].copy_(keys_to_keep) - - tail_values = self.value_cache[:, self.num_sink_tokens :].clone() - tail_values = tail_values.roll(-num_evicted_tokens, dims=1) - self.value_cache[:, self.num_sink_tokens :].copy_(tail_values) - else: - num_evicted_tokens_global = num_evicted_tokens = 0 + if num_new_tokens + self.local_start_index > self.max_seq_length: + num_evicted_tokens = (num_new_tokens + self.local_start_index) - self.max_seq_length - start_idx = self.global_start_index - num_evicted_tokens_global + keys_to_keep = self.key_cache[:, self.num_sink_tokens + num_evicted_tokens :] + keys_to_keep = self.rerotate_key_rotary_pos_emb(keys_to_keep, *rotary_emb, num_evicted_tokens) + self.key_cache[:, self.num_sink_tokens : -num_evicted_tokens].copy_(keys_to_keep) - if self.layer_idx == 0: - print( - num_evicted_tokens_global, num_evicted_tokens, self.global_start_index, self.global_end_index, self.local_start_index, self.local_end_index, start_idx - ) - return start_idx + values_to_keep = self.value_cache[:, self.num_sink_tokens + num_evicted_tokens :].clone() # clone to avoid an overlapping in-place copy + self.value_cache[:, self.num_sink_tokens : -num_evicted_tokens].copy_(values_to_keep) + self.local_start_index = self.local_start_index - num_evicted_tokens + + # Assign the new keys/values directly up to `end_index` and update the running cache position + end_index = self.local_start_index + key_states.shape[1] + self.key_cache[:, self.local_start_index : end_index].copy_(key_states) + self.value_cache[:, self.local_start_index : end_index].copy_(value_states) + return self.key_cache[:, :end_index], self.value_cache[:, :end_index] def recompute_indices(self, num_new_tokens: int): - self.global_start_index = self.global_start_index + num_new_tokens - self.global_end_index = self.global_start_index + num_new_tokens - self.local_start_index = min(self.max_seq_length - num_new_tokens, self.local_start_index + 
num_new_tokens) + self.local_start_index = self.local_start_index + num_new_tokens @staticmethod def _apply_rope(hidden_states, freqs_cos, freqs_sin): @@ -195,9 +170,19 @@ def rerotate_key_rotary_pos_emb( self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, num_evicted_tokens: int ) -> tuple[torch.Tensor, torch.Tensor]: # Compute the cos and sin required for back- and forward-rotating to `num_evicted_tokens` position earlier in the sequence - original_cos = cos[:, self.num_sink_tokens + num_evicted_tokens : self.num_sink_tokens + num_evicted_tokens + key_states.shape[1]] + original_cos = cos[ + :, + self.num_sink_tokens + num_evicted_tokens : self.num_sink_tokens + + num_evicted_tokens + + key_states.shape[1], + ] shifted_cos = cos[:, self.num_sink_tokens : self.num_sink_tokens + key_states.shape[1]] - original_sin = sin[:, self.num_sink_tokens + num_evicted_tokens : self.num_sink_tokens + num_evicted_tokens + key_states.shape[1]] + original_sin = sin[ + :, + self.num_sink_tokens + num_evicted_tokens : self.num_sink_tokens + + num_evicted_tokens + + key_states.shape[1], + ] shifted_sin = sin[:, self.num_sink_tokens : self.num_sink_tokens + key_states.shape[1]] rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin @@ -205,19 +190,8 @@ def rerotate_key_rotary_pos_emb( # rerotation_cos = cos[:, num_evicted_tokens] # rerotation_sin = -sin[:, num_evicted_tokens] rotated_key_states = self._apply_rope(key_states, rerotation_cos, rerotation_sin) - - # if self.layer_idx == 0: - # print(key_states.shape, num_evicted_tokens, original_cos.shape, shifted_cos.shape) return rotated_key_states - def get_rerotated_num_positions(self): - if (4680 * 2) + self.global_start_index > self.max_seq_length: - num_evicted_tokens = ((4680 * 2) + self.global_start_index) - self.max_seq_length - num_tokens_to_rotate_back = num_evicted_tokens - self.num_sink_tokens - else: - num_tokens_to_rotate_back = 0 - return num_tokens_to_rotate_back - def apply_rolling_kv_cache(module: torch.nn.Module, config: RollingKVCacheConfig) -> None: for name, submodule in module.named_children(): diff --git a/src/diffusers/models/transformers/transformer_krea.py b/src/diffusers/models/transformers/transformer_krea.py index 901bf6f93214..dc133369bbb0 100644 --- a/src/diffusers/models/transformers/transformer_krea.py +++ b/src/diffusers/models/transformers/transformer_krea.py @@ -102,8 +102,6 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, kv_cache: list[torch.Tensor] = None, - clean_latents: bool = None, - timestep = None, ) -> torch.Tensor: encoder_hidden_states_img = None if attn.add_k_proj is not None: @@ -129,6 +127,10 @@ def apply_rotary_emb( freqs_sin: torch.Tensor, kv_cache: torch.Tensor, ): + # NOTE: ideally if we can pass `position_ids` to the model's forward + # and use it to compute cos/sin. 
Similar to common autoregressive + # LLMs. For now, initialize one big cos/sin array and slice it + # for the current frame position block_length = hidden_states.shape[1] start_index = kv_cache.local_start_index if kv_cache is not None else 0 freqs_cos = freqs_cos[:, start_index : start_index + block_length] freqs_sin = freqs_sin[:, start_index : start_index + block_length] @@ -141,16 +143,11 @@ def apply_rotary_emb( out[..., 1::2] = x1 * sin + x2 * cos return out.type_as(hidden_states) - # if clean_latents and kv_cache.local_start_index != 0 and kv_cache.layer_idx == 0: - # keys_with_rope_correct = apply_rotary_emb(key, *rotary_emb, None) - # print("What I need", keys_with_rope_correct.shape, keys_with_rope_correct[:, -3120:]) - query = apply_rotary_emb(query, *rotary_emb, kv_cache) key = apply_rotary_emb(key, *rotary_emb, kv_cache) if kv_cache is not None: - roll = timestep.flatten() == 1000 - key, value = kv_cache.update(key, value, rotary_emb, roll) + key, value = kv_cache.update(key, value, rotary_emb) # I2V task hidden_states_img = None @@ -492,7 +489,6 @@ def forward( temb: torch.Tensor, rotary_emb: torch.Tensor, clean_latents: bool = False, - timestep = None, kv_cache: list[torch.Tensor] = None, ) -> torch.Tensor: # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) @@ -500,22 +496,9 @@ def forward( self.scale_shift_table + temb # .float() ).chunk(6, dim=1) - # NOTE: ideally if we can pass `position_ids` to the model's forward - # and use it to compute cos/sin. Similar to simply autoregressive - # LLMs For now let's initialize one big cos/sin array and slice it - # for current frame position - block_length = hidden_states.shape[1] - start_index = kv_cache.local_start_index if kv_cache is not None else 0 - # rotary_emb = ( - # rotary_emb[0][:, start_index : start_index + block_length], - # rotary_emb[1][:, start_index : start_index + block_length], - # ) - # if kv_cache.layer_idx == 0: - # print(clean_latents, start_index, kv_cache.key_cache[:, :4680]) - # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).type_as(hidden_states) - attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb, kv_cache, clean_latents=clean_latents, timestep=timestep) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb, kv_cache) hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states) # 2. Cross-attention @@ -710,7 +693,6 @@ def forward( timestep_proj, rotary_emb, clean_latents, - timestep, ) else: for block in self.blocks: @@ -720,7 +702,6 @@ def forward( timestep_proj, rotary_emb, clean_latents, - timestep, ) # 5. Output norm, projection & unpatchify
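The following is a minimal, self-contained sketch (not part of the patch) of the eviction-and-append behaviour that the final `update`/`recompute_indices` implement: fixed-size per-layer key/value buffers, `num_sink_tokens` pinned at the front, and the oldest non-sink tokens shifted out when a new block does not fit. The class name `RollingKVCacheSketch`, the `advance` helper, and the omission of RoPE rerotation are illustrative simplifications.

import torch

class RollingKVCacheSketch:
    """Fixed-size per-layer K/V buffers with `num_sink_tokens` pinned at the front.
    When a new block does not fit, the oldest non-sink tokens are evicted by
    shifting the tail left. Key rerotation is omitted here (see the check below)."""

    def __init__(self, batch_size, max_seq_length, num_heads, head_dim, num_sink_tokens=1):
        shape = (batch_size, max_seq_length, num_heads, head_dim)
        self.key_cache = torch.zeros(shape)
        self.value_cache = torch.zeros(shape)
        self.max_seq_length = max_seq_length
        self.num_sink_tokens = num_sink_tokens
        self.local_start_index = 0

    def update(self, key_states, value_states):
        num_new_tokens = key_states.shape[1]
        overflow = num_new_tokens + self.local_start_index - self.max_seq_length
        if overflow > 0:
            # Drop the `overflow` oldest tokens after the sinks; `.clone()` avoids
            # copying between overlapping views of the same buffer.
            for cache in (self.key_cache, self.value_cache):
                keep = cache[:, self.num_sink_tokens + overflow :].clone()
                cache[:, self.num_sink_tokens : -overflow].copy_(keep)
            self.local_start_index -= overflow
        end = self.local_start_index + num_new_tokens
        self.key_cache[:, self.local_start_index : end].copy_(key_states)
        self.value_cache[:, self.local_start_index : end].copy_(value_states)
        return self.key_cache[:, :end], self.value_cache[:, :end]

    def advance(self, num_new_tokens):
        # Mirrors `recompute_indices`: called once per block after the clean-latent pass.
        self.local_start_index += num_new_tokens

# Tiny smoke test: a cache of 8 tokens with 1 sink token, fed 3-token blocks.
cache = RollingKVCacheSketch(batch_size=1, max_seq_length=8, num_heads=2, head_dim=4, num_sink_tokens=1)
for _ in range(4):
    k = torch.randn(1, 3, 2, 4)
    v = torch.randn(1, 3, 2, 4)
    keys, values = cache.update(k, v)
    cache.advance(3)
print(keys.shape, cache.local_start_index)  # torch.Size([1, 8, 2, 4]) 8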
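The rerotation in `rerotate_key_rotary_pos_emb` relies on the identities cos(a)cos(b) + sin(a)sin(b) = cos(a - b) and cos(a)sin(b) - sin(a)cos(b) = sin(b - a): applying a second RoPE with these combined cos/sin terms moves a cached key from the position it was rotated for (a) to the position it slides into after eviction (b). The snippet below checks this numerically, assuming the interleaved cos/sin layout implied by `_apply_rope`; the helper names and toy sizes are illustrative only.

import torch

def apply_rope(x, freqs_cos, freqs_sin):
    # Same interleaved convention as `_apply_rope` in the patch above.
    x1, x2 = x.unflatten(-1, (-1, 2)).unbind(-1)
    cos = freqs_cos[..., 0::2]
    sin = freqs_sin[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def cos_sin(pos, theta):
    # Duplicate each per-pair angle so [..., 0::2] and [..., 1::2] index the same pair.
    angles = pos * theta
    return torch.repeat_interleave(torch.cos(angles), 2), torch.repeat_interleave(torch.sin(angles), 2)

dim = 8
theta = torch.rand(dim // 2)
cos_orig, sin_orig = cos_sin(7.0, theta)  # position the key was cached at
cos_new, sin_new = cos_sin(4.0, theta)    # position it occupies after evicting 3 tokens

# Same combination as in `rerotate_key_rotary_pos_emb`
rerotation_cos = cos_orig * cos_new + sin_orig * sin_new
rerotation_sin = -sin_orig * cos_new + cos_orig * sin_new

key = torch.randn(dim)
rerotated = apply_rope(apply_rope(key, cos_orig, sin_orig), rerotation_cos, rerotation_sin)
direct = apply_rope(key, cos_new, sin_new)
assert torch.allclose(rerotated, direct, atol=1e-5)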
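For context on how the cache is driven, here is a schematic, standalone version of the outer loop changed in `pipeline.py`: each block is denoised with the rolling cache active and, for every block except the last, the transformer is run once more on the clean denoised latents at timestep 0 so the cache holds clean-latent keys/values before the next block. All functions below are stubs; only the control flow mirrors the pipeline, and the slicing of `latents` is simplified.

import torch

def denoise_single_block(latents):
    # Stand-in for the inner denoising loop (scheduler steps and CFG omitted).
    return latents - 0.1 * torch.randn_like(latents)

def recache_clean_latents(latents):
    # Stand-in for transformer(hidden_states=latents, timestep=0, clean_latents=True, ...).
    pass

num_blocks, block_size = 4, 3
batch, channels, height, width = 1, 16, 8, 8
latents = torch.randn(batch, channels, num_blocks * block_size, height, width)
output = torch.zeros_like(latents)

cache_start_frame = 0
for block in range(num_blocks):
    current = latents[:, :, cache_start_frame : cache_start_frame + block_size]
    denoised = denoise_single_block(current)
    output[:, :, cache_start_frame : cache_start_frame + block_size] = denoised[:, :, -block_size:]
    # Refresh the rolling KV cache with clean latents, skipping the final block.
    if block != num_blocks - 1:
        recache_clean_latents(denoised)
    cache_start_frame += block_size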