Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,14 @@
"PyramidAttentionBroadcastConfig",
"SmoothedEnergyGuidanceConfig",
"TaylorSeerCacheConfig",
"TextKVCacheConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_mag_cache",
"apply_pyramid_attention_broadcast",
"apply_taylorseer_cache",
"apply_text_kv_cache",
]
)
_import_structure["models"].extend(
Expand Down Expand Up @@ -260,6 +262,7 @@
"PixArtTransformer2DModel",
"PriorTransformer",
"PRXTransformer2DModel",
"NucleusMoEImageTransformer2DModel",
"QwenImageControlNetModel",
"QwenImageMultiControlNetModel",
"QwenImageTransformer2DModel",
Expand Down Expand Up @@ -613,6 +616,7 @@
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
"PRXPipeline",
"NucleusMoEImagePipeline",
"QwenImageControlNetInpaintPipeline",
"QwenImageControlNetPipeline",
"QwenImageEditInpaintPipeline",
Expand Down Expand Up @@ -959,12 +963,14 @@
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
TaylorSeerCacheConfig,
TextKVCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_mag_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
apply_text_kv_cache,
)
from .models import (
AllegroTransformer3DModel,
Expand Down Expand Up @@ -1048,6 +1054,7 @@
PixArtTransformer2DModel,
PriorTransformer,
PRXTransformer2DModel,
NucleusMoEImageTransformer2DModel,
QwenImageControlNetModel,
QwenImageMultiControlNetModel,
QwenImageTransformer2DModel,
Expand Down Expand Up @@ -1376,6 +1383,7 @@
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
PRXPipeline,
NucleusMoEImagePipeline,
QwenImageControlNetInpaintPipeline,
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
from .text_kv_cache import TextKVCacheConfig, apply_text_kv_cache
109 changes: 109 additions & 0 deletions src/diffusers/hooks/text_kv_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

import torch

from .hooks import HookRegistry, ModelHook


# Registry name under which the hook is attached to each transformer block.
_TEXT_KV_CACHE_HOOK = "text_kv_cache"


@dataclass
class TextKVCacheConfig:
    """Configuration for exact (lossless) text K/V caching on transformer models.

    Per-block text key and value projections are computed once, before the
    denoising loop, and reused across all subsequent steps. Cache entries are
    keyed by the ``data_ptr()`` of the ``encoder_hidden_states`` tensor, which
    keeps the positive and negative prompts (``true_cfg_scale > 1``) separate.

    The cache is always exact, so this configuration carries no hyperparameters.
    """


class TextKVCacheHook(ModelHook):
    """Block-level hook that caches the text (key, value) projections per prompt.

    On the first call for a given ``encoder_hidden_states`` tensor, the hook
    computes the block's text key/value projections (including the added-key
    norm and rotary embedding, when present) and stores them. Every later call
    with the same tensor reuses the cached pair, letting the block skip
    ``encoder_proj`` and the ``add_k_proj``/``add_v_proj`` projections.
    """

    # The hook holds per-prompt tensors between calls, so the pipeline must be
    # able to reset it between runs.
    _is_stateful = True

    def __init__(self):
        super().__init__()
        # Maps encoder_hidden_states.data_ptr() → (txt_key, txt_value).
        # NOTE(review): data_ptr() can be recycled once a tensor is freed; this
        # assumes the prompt embeddings stay alive for the whole denoising
        # loop — confirm against the pipelines that enable this cache.
        self.kv_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {}

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        # Imported lazily to avoid a circular import with the transformer module.
        from ..models.transformers.transformer_nucleusmoe_image import _apply_rotary_emb_nucleus

        # --- extract encoder_hidden_states ---
        if "encoder_hidden_states" in kwargs:
            encoder_hidden_states = kwargs["encoder_hidden_states"]
        else:
            # positional layout: (hidden_states, encoder_hidden_states, temb, ...)
            encoder_hidden_states = args[1]

        # --- extract image_rotary_emb ---
        if "image_rotary_emb" in kwargs:
            image_rotary_emb = kwargs.get("image_rotary_emb")
        elif len(args) > 3:
            image_rotary_emb = args[3]
        else:
            image_rotary_emb = None

        ptr = encoder_hidden_states.data_ptr()

        if ptr not in self.kv_cache:
            context = module.encoder_proj(encoder_hidden_states)

            attn = module.attn
            head_dim = attn.inner_dim // attn.heads
            num_kv_heads = attn.inner_kv_dim // head_dim

            txt_key = attn.add_k_proj(context).unflatten(-1, (num_kv_heads, -1))
            txt_value = attn.add_v_proj(context).unflatten(-1, (num_kv_heads, -1))

            if attn.norm_added_k is not None:
                txt_key = attn.norm_added_k(txt_key)

            if image_rotary_emb is not None:
                _, txt_freqs = image_rotary_emb
                txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False)

            self.kv_cache[ptr] = (txt_key, txt_value)

        txt_key, txt_value = self.kv_cache[ptr]

        # Inject the cached k/v — the block sees cached_txt_key and skips
        # encoder_proj too. Copy the caller's attention_kwargs instead of
        # mutating it in place: the same dict may be shared across blocks and
        # reused by the caller, and must not retain stale cache entries after
        # the hook is removed.
        attn_kwargs = dict(kwargs.get("attention_kwargs") or {})
        attn_kwargs["cached_txt_key"] = txt_key
        attn_kwargs["cached_txt_value"] = txt_value
        kwargs["attention_kwargs"] = attn_kwargs

        return self.fn_ref.original_forward(*args, **kwargs)

    def reset_state(self, module: torch.nn.Module):
        """Drop all cached (key, value) pairs; invoked when hook state is reset."""
        self.kv_cache.clear()
        return module


def apply_text_kv_cache(module: torch.nn.Module, config: TextKVCacheConfig) -> None:
    """Attach a :class:`TextKVCacheHook` to every NucleusMoE transformer block in ``module``.

    Args:
        module: The transformer model whose blocks should cache text K/V.
        config: The (hyperparameter-free) cache configuration; accepted for API
            symmetry with the other ``apply_*`` cache helpers.
    """
    # Imported lazily to avoid a circular import with the transformer module.
    from ..models.transformers.transformer_nucleusmoe_image import NucleusMoEImageTransformerBlock

    target_blocks = (
        submodule
        for submodule in module.modules()
        if isinstance(submodule, NucleusMoEImageTransformerBlock)
    )
    for block in target_blocks:
        registry = HookRegistry.check_if_exists_or_initialize(block)
        registry.register_hook(TextKVCacheHook(), _TEXT_KV_CACHE_HOOK)
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"]
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
_import_structure["transformers.transformer_nucleusmoe_image"] = ["NucleusMoEImageTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
_import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
Expand Down Expand Up @@ -241,6 +242,7 @@
PixArtTransformer2DModel,
PriorTransformer,
PRXTransformer2DModel,
NucleusMoEImageTransformer2DModel,
QwenImageTransformer2DModel,
SanaTransformer2DModel,
SanaVideoTransformer3DModel,
Expand Down
11 changes: 10 additions & 1 deletion src/diffusers/models/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ def enable_cache(self, config) -> None:
Enable caching techniques on the model.

Args:
config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig`):
config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig | TextKVCacheConfig`):
The configuration for applying the caching technique. Currently supported caching techniques are:
- [`~hooks.PyramidAttentionBroadcastConfig`]
- [`~hooks.FasterCacheConfig`]
- [`~hooks.FirstBlockCacheConfig`]
- [`~hooks.TextKVCacheConfig`]

Example:

Expand All @@ -71,11 +72,13 @@ def enable_cache(self, config) -> None:
MagCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
TextKVCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_mag_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
apply_text_kv_cache,
)

if self.is_cache_enabled:
Expand All @@ -89,6 +92,8 @@ def enable_cache(self, config) -> None:
apply_first_block_cache(self, config)
elif isinstance(config, MagCacheConfig):
apply_mag_cache(self, config)
elif isinstance(config, TextKVCacheConfig):
apply_text_kv_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, TaylorSeerCacheConfig):
Expand All @@ -106,12 +111,14 @@ def disable_cache(self) -> None:
MagCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
TextKVCacheConfig,
)
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
from ..hooks.text_kv_cache import _TEXT_KV_CACHE_HOOK

if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
Expand All @@ -129,6 +136,8 @@ def disable_cache(self) -> None:
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, TextKVCacheConfig):
registry.remove_hook(_TEXT_KV_CACHE_HOOK, recurse=True)
elif isinstance(self._cache_config, TaylorSeerCacheConfig):
registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True)
else:
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_ovis_image import OvisImageTransformer2DModel
from .transformer_prx import PRXTransformer2DModel
from .transformer_nucleusmoe_image import NucleusMoEImageTransformer2DModel
from .transformer_qwenimage import QwenImageTransformer2DModel
from .transformer_sana_video import SanaVideoTransformer3DModel
from .transformer_sd3 import SD3Transformer2DModel
Expand Down
Loading