diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 7d966452d1a2..cb3d1a377f70 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -169,12 +169,14 @@ "PyramidAttentionBroadcastConfig", "SmoothedEnergyGuidanceConfig", "TaylorSeerCacheConfig", + "TextKVCacheConfig", "apply_faster_cache", "apply_first_block_cache", "apply_layer_skip", "apply_mag_cache", "apply_pyramid_attention_broadcast", "apply_taylorseer_cache", + "apply_text_kv_cache", ] ) _import_structure["models"].extend( @@ -260,6 +262,7 @@ "PixArtTransformer2DModel", "PriorTransformer", "PRXTransformer2DModel", + "NucleusMoEImageTransformer2DModel", "QwenImageControlNetModel", "QwenImageMultiControlNetModel", "QwenImageTransformer2DModel", @@ -613,6 +616,7 @@ "PixArtSigmaPAGPipeline", "PixArtSigmaPipeline", "PRXPipeline", + "NucleusMoEImagePipeline", "QwenImageControlNetInpaintPipeline", "QwenImageControlNetPipeline", "QwenImageEditInpaintPipeline", @@ -959,12 +963,14 @@ PyramidAttentionBroadcastConfig, SmoothedEnergyGuidanceConfig, TaylorSeerCacheConfig, + TextKVCacheConfig, apply_faster_cache, apply_first_block_cache, apply_layer_skip, apply_mag_cache, apply_pyramid_attention_broadcast, apply_taylorseer_cache, + apply_text_kv_cache, ) from .models import ( AllegroTransformer3DModel, @@ -1048,6 +1054,7 @@ PixArtTransformer2DModel, PriorTransformer, PRXTransformer2DModel, + NucleusMoEImageTransformer2DModel, QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageTransformer2DModel, @@ -1376,6 +1383,7 @@ PixArtSigmaPAGPipeline, PixArtSigmaPipeline, PRXPipeline, + NucleusMoEImagePipeline, QwenImageControlNetInpaintPipeline, QwenImageControlNetPipeline, QwenImageEditInpaintPipeline, diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 23c8bc92b2f1..2a9aa81608e7 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -27,3 +27,4 @@ from .pyramid_attention_broadcast import 
PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache + from .text_kv_cache import TextKVCacheConfig, apply_text_kv_cache diff --git a/src/diffusers/hooks/text_kv_cache.py b/src/diffusers/hooks/text_kv_cache.py new file mode 100644 index 000000000000..fb1a4875b366 --- /dev/null +++ b/src/diffusers/hooks/text_kv_cache.py @@ -0,0 +1,109 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch + +from .hooks import HookRegistry, ModelHook + + +_TEXT_KV_CACHE_HOOK = "text_kv_cache" + + +@dataclass +class TextKVCacheConfig: + """Enable exact (lossless) text K/V caching for transformer models. + + Pre-computes per-block text key and value projections once before the + denoising loop and reuses them across all steps. The cached values are keyed by + the ``data_ptr()`` of the ``encoder_hidden_states`` tensor so that both the positive + and negative prompts (when ``true_cfg_scale > 1``) are handled correctly. 
+ """ + + pass # no hyperparameters needed — cache is always exact + + +class TextKVCacheHook(ModelHook): + """Block-level hook that caches (txt_key, txt_value) per unique prompt.""" + + _is_stateful = True + + def __init__(self): + super().__init__() + # Maps encoder_hidden_states.data_ptr() → (txt_key, txt_value) + self.kv_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + from ..models.transformers.transformer_nucleusmoe_image import _apply_rotary_emb_nucleus + + # --- extract encoder_hidden_states --- + if "encoder_hidden_states" in kwargs: + encoder_hidden_states = kwargs["encoder_hidden_states"] + else: + # positional: (hidden_states, encoder_hidden_states, temb, ...) + encoder_hidden_states = args[1] + + # --- extract image_rotary_emb --- + if "image_rotary_emb" in kwargs: + image_rotary_emb = kwargs.get("image_rotary_emb") + elif len(args) > 3: + image_rotary_emb = args[3] + else: + image_rotary_emb = None + + ptr = encoder_hidden_states.data_ptr() + + if ptr not in self.kv_cache: + context = module.encoder_proj(encoder_hidden_states) + + attn = module.attn + head_dim = attn.inner_dim // attn.heads + num_kv_heads = attn.inner_kv_dim // head_dim + + txt_key = attn.add_k_proj(context).unflatten(-1, (num_kv_heads, -1)) + txt_value = attn.add_v_proj(context).unflatten(-1, (num_kv_heads, -1)) + + if attn.norm_added_k is not None: + txt_key = attn.norm_added_k(txt_key) + + if image_rotary_emb is not None: + _, txt_freqs = image_rotary_emb + txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False) + + self.kv_cache[ptr] = (txt_key, txt_value) + + txt_key, txt_value = self.kv_cache[ptr] + + # Inject cached k/v — block sees cached_txt_key and skips encoder_proj too + attn_kwargs = kwargs.get("attention_kwargs") or {} + attn_kwargs["cached_txt_key"] = txt_key + attn_kwargs["cached_txt_value"] = txt_value + kwargs["attention_kwargs"] = attn_kwargs + + return 
self.fn_ref.original_forward(*args, **kwargs) + + def reset_state(self, module: torch.nn.Module): + self.kv_cache.clear() + return module + + +def apply_text_kv_cache(module: torch.nn.Module, config: TextKVCacheConfig) -> None: + from ..models.transformers.transformer_nucleusmoe_image import NucleusMoEImageTransformerBlock + + for _, submodule in module.named_modules(): + if isinstance(submodule, NucleusMoEImageTransformerBlock): + hook = TextKVCacheHook() + registry = HookRegistry.check_if_exists_or_initialize(submodule) + registry.register_hook(hook, _TEXT_KV_CACHE_HOOK) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7ded56049833..b19461e7302c 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -119,6 +119,7 @@ _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"] _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"] + _import_structure["transformers.transformer_nucleusmoe_image"] = ["NucleusMoEImageTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] _import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] @@ -241,6 +242,7 @@ PixArtTransformer2DModel, PriorTransformer, PRXTransformer2DModel, + NucleusMoEImageTransformer2DModel, QwenImageTransformer2DModel, SanaTransformer2DModel, SanaVideoTransformer3DModel, diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 5f9587a1b4de..3bca773d8344 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -41,11 +41,12 @@ def enable_cache(self, config) -> None: Enable caching techniques on the model. 
Args: - config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig`): + config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig | TextKVCacheConfig`): The configuration for applying the caching technique. Currently supported caching techniques are: - [`~hooks.PyramidAttentionBroadcastConfig`] - [`~hooks.FasterCacheConfig`] - [`~hooks.FirstBlockCacheConfig`] + - [`~hooks.TextKVCacheConfig`] Example: @@ -71,11 +72,13 @@ def enable_cache(self, config) -> None: MagCacheConfig, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig, + TextKVCacheConfig, apply_faster_cache, apply_first_block_cache, apply_mag_cache, apply_pyramid_attention_broadcast, apply_taylorseer_cache, + apply_text_kv_cache, ) if self.is_cache_enabled: @@ -89,6 +92,8 @@ def enable_cache(self, config) -> None: apply_first_block_cache(self, config) elif isinstance(config, MagCacheConfig): apply_mag_cache(self, config) + elif isinstance(config, TextKVCacheConfig): + apply_text_kv_cache(self, config) elif isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) elif isinstance(config, TaylorSeerCacheConfig): @@ -106,12 +111,14 @@ def disable_cache(self) -> None: MagCacheConfig, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig, + TextKVCacheConfig, ) from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK + from ..hooks.text_kv_cache import _TEXT_KV_CACHE_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") @@ -129,6 +136,8 @@ def disable_cache(self) -> None: registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, 
recurse=True) elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) + elif isinstance(self._cache_config, TextKVCacheConfig): + registry.remove_hook(_TEXT_KV_CACHE_HOOK, recurse=True) elif isinstance(self._cache_config, TaylorSeerCacheConfig): registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True) else: diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 45157ee91808..f675196cfa92 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -43,6 +43,7 @@ from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_ovis_image import OvisImageTransformer2DModel from .transformer_prx import PRXTransformer2DModel + from .transformer_nucleusmoe_image import NucleusMoEImageTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel from .transformer_sana_video import SanaVideoTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel diff --git a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py new file mode 100644 index 000000000000..88df164f7d03 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py @@ -0,0 +1,761 @@ +# Copyright 2025 Nucleus-Image Team, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
) -> torch.Tensor:
torch.view_as_real(x_rotated * freqs_cis).flatten(3) + return x_out.type_as(x) + + +def _compute_text_seq_len_from_mask( + encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: torch.Tensor | None +) -> tuple[int, torch.Tensor | None, torch.Tensor | None]: + batch_size, text_seq_len = encoder_hidden_states.shape[:2] + if encoder_hidden_states_mask is None: + return text_seq_len, None, None + + if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len): + raise ValueError( + f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match " + f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})." + ) + + if encoder_hidden_states_mask.dtype != torch.bool: + encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool) + + position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long) + active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(())) + has_active = encoder_hidden_states_mask.any(dim=1) + per_sample_len = torch.where( + has_active, + active_positions.max(dim=1).values + 1, + torch.as_tensor(text_seq_len, device=encoder_hidden_states.device), + ) + return text_seq_len, per_sample_len, encoder_hidden_states_mask + + +class NucleusMoETimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, use_additional_t_cond=False): + super().__init__() + + self.time_proj = Timesteps(num_channels=embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = TimestepEmbedding( + in_channels=embedding_dim, time_embed_dim=4 * embedding_dim, out_dim=embedding_dim + ) + self.norm = RMSNorm(embedding_dim, eps=1e-6) + self.use_additional_t_cond = use_additional_t_cond + if use_additional_t_cond: + self.addition_t_embedding = nn.Embedding(2, embedding_dim) + + def forward(self, timestep, hidden_states, addition_t_cond=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = 
A `(frame, height, width)` tuple — or a list of such tuples for batched inputs — describing the latent grid(s) used for RoPE computation.
raise ValueError("`max_txt_seq_len` must be provided.")
+ vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=128) + def _compute_video_freqs( + self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None + ) -> torch.Tensor: + seq_lens = frame * height * width + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + +class NucleusMoEAttnProcessor2_0: + """ + Attention processor for the NucleusMoE architecture. Image queries attend to concatenated + image+text keys/values (cross-attention style, no text query). Supports grouped-query + attention (GQA) when num_key_value_heads is set on the Attention module. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "NucleusMoEAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: torch.FloatTensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + cached_txt_key: torch.FloatTensor | None = None, + cached_txt_value: torch.FloatTensor | None = None, + ) -> torch.FloatTensor: + head_dim = attn.inner_dim // attn.heads + num_kv_heads = attn.inner_kv_dim // head_dim + num_kv_groups = attn.heads // num_kv_heads + + img_query = attn.to_q(hidden_states).unflatten(-1, (attn.heads, -1)) + img_key = attn.to_k(hidden_states).unflatten(-1, (num_kv_heads, -1)) + img_value = attn.to_v(hidden_states).unflatten(-1, (num_kv_heads, -1)) + + if attn.norm_q is not None: + img_query = attn.norm_q(img_query) + if attn.norm_k is not None: + img_key = attn.norm_k(img_key) + + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_query = _apply_rotary_emb_nucleus(img_query, img_freqs, use_real=False) + img_key = _apply_rotary_emb_nucleus(img_key, img_freqs, use_real=False) + + if cached_txt_key is not None and cached_txt_value is not None: + txt_key, txt_value = cached_txt_key, cached_txt_value + joint_key = torch.cat([img_key, txt_key], dim=1) + joint_value = torch.cat([img_value, txt_value], dim=1) + elif encoder_hidden_states is not None: + txt_key = attn.add_k_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1)) + txt_value = attn.add_v_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1)) + + if attn.norm_added_k is not None: + txt_key = attn.norm_added_k(txt_key) + + if image_rotary_emb is not None: + txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False) + + joint_key = torch.cat([img_key, txt_key], dim=1) + joint_value = torch.cat([img_value, txt_value], dim=1) + else: + joint_key = img_key + joint_value = img_value + + if num_kv_groups > 1: + joint_key = joint_key.repeat_interleave(num_kv_groups, dim=2) + joint_value = 
joint_value.repeat_interleave(num_kv_groups, dim=2) + + hidden_states = dispatch_attention_fn( + img_query, + joint_key, + joint_value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(img_query.dtype) + + hidden_states = attn.to_out[0](hidden_states) + if len(attn.to_out) > 1: + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +def _is_moe_layer(strategy: str, layer_idx: int, num_layers: int) -> bool: + if strategy == "leave_first_three_and_last_block_dense": + return layer_idx >= 3 and layer_idx < num_layers - 1 + elif strategy == "leave_first_three_blocks_dense": + return layer_idx >= 3 + elif strategy == "leave_first_block_dense": + return layer_idx >= 1 + elif strategy == "all_moe": + return True + elif strategy == "all_dense": + return False + return True + + +class NucleusMoELayer(nn.Module): + """ + Mixture-of-Experts layer with expert-choice routing and a shared expert. + + Each expert is a separate ``FeedForward`` module stored in an ``nn.ModuleList``. + The router concatenates a timestep embedding with the (unmodulated) hidden state + to produce per-token affinity scores, then selects the top-C tokens per expert + (expert-choice routing). A shared expert processes all tokens in parallel and its + output is combined with the routed expert outputs via scatter-add. 
+ """ + + def __init__( + self, + hidden_size: int, + moe_intermediate_dim: int, + num_experts: int, + capacity_factor: float, + use_sigmoid: bool, + route_scale: float, + ): + super().__init__() + self.num_experts = num_experts + self.capacity_factor = capacity_factor + self.use_sigmoid = use_sigmoid + self.route_scale = route_scale + + self.gate = nn.Linear(hidden_size * 2, num_experts, bias=False) + self.experts = nn.ModuleList( + [ + FeedForward( + dim=hidden_size, + dim_out=hidden_size, + inner_dim=moe_intermediate_dim, + activation_fn="swiglu", + bias=False, + ) + for _ in range(num_experts) + ] + ) + self.shared_expert = FeedForward( + dim=hidden_size, + dim_out=hidden_size, + inner_dim=moe_intermediate_dim, + activation_fn="swiglu", + bias=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + hidden_states_unmodulated: torch.Tensor, + timestep: torch.Tensor | None = None, + ) -> torch.Tensor: + bs, slen, dim = hidden_states.shape + + if timestep is not None: + timestep_expanded = timestep.unsqueeze(1).expand(-1, slen, -1) + router_input = torch.cat([timestep_expanded, hidden_states_unmodulated], dim=-1) + else: + router_input = hidden_states_unmodulated + + logits = self.gate(router_input) + + if self.use_sigmoid: + scores = torch.sigmoid(logits.float()).to(logits.dtype) + else: + scores = F.softmax(logits.float(), dim=-1).to(logits.dtype) + + affinity = scores.transpose(1, 2) # (B, E, S) + capacity = max(1, math.ceil(self.capacity_factor * slen / self.num_experts)) + + topk = torch.topk(affinity, k=capacity, dim=-1) + top_indices = topk.indices # (B, E, C) + gating = affinity.gather(dim=-1, index=top_indices) # (B, E, C) + + batch_offsets = torch.arange(bs, device=hidden_states.device, dtype=torch.long).view(bs, 1, 1) * slen + global_token_indices = (batch_offsets + top_indices).transpose(0, 1).reshape(self.num_experts, -1).reshape(-1) + gating_flat = gating.transpose(0, 1).reshape(self.num_experts, -1).reshape(-1) + + token_score_sums = 
torch.zeros(bs * slen, device=hidden_states.device, dtype=gating_flat.dtype) + token_score_sums.scatter_add_(0, global_token_indices, gating_flat) + gating_flat = gating_flat / (token_score_sums[global_token_indices] + 1e-12) + gating_flat = gating_flat * self.route_scale + + x_flat = hidden_states.reshape(bs * slen, dim) + routed_input = x_flat[global_token_indices] + + tokens_per_expert = bs * capacity + routed_output_parts = [] + for i, expert in enumerate(self.experts): + start = i * tokens_per_expert + end = start + tokens_per_expert + expert_out = expert(routed_input[start:end]) + routed_output_parts.append(expert_out) + + routed_output = torch.cat(routed_output_parts, dim=0) + routed_output = (routed_output.float() * gating_flat.unsqueeze(-1)).to(hidden_states.dtype) + + out = self.shared_expert(hidden_states).reshape(bs * slen, dim) + + scatter_idx = global_token_indices.reshape(-1, 1).expand(-1, dim) + out = out.scatter_add(dim=0, index=scatter_idx, src=routed_output) + out = out.reshape(bs, slen, dim) + + return out + + +class NucleusMoEImageTransformerBlock(nn.Module): + """ + Single-stream DiT block with optional Mixture-of-Experts MLP. Only the image + stream receives adaptive modulation; the text context is projected per-block + and used as cross-attention keys/values. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_key_value_heads: int | None = None, + joint_attention_dim: int = 3584, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + mlp_ratio: float = 4.0, + moe_enabled: bool = False, + num_experts: int = 128, + moe_intermediate_dim: int = 1344, + capacity_factor: float = 8.0, + use_sigmoid: bool = False, + route_scale: float = 2.5, + ): + super().__init__() + self.dim = dim + self.moe_enabled = moe_enabled + + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 4 * dim, bias=True), + ) + + self.encoder_proj = nn.Linear(joint_attention_dim, dim) + + self.pre_attn_norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) + self.attn = Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_key_value_heads, + dim_head=attention_head_dim, + added_kv_proj_dim=dim, + added_proj_bias=False, + out_dim=dim, + out_bias=False, + bias=False, + processor=NucleusMoEAttnProcessor2_0(), + qk_norm=qk_norm, + eps=eps, + context_pre_only=None, + ) + + self.pre_mlp_norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) + + if moe_enabled: + self.img_mlp = NucleusMoELayer( + hidden_size=dim, + moe_intermediate_dim=moe_intermediate_dim, + num_experts=num_experts, + capacity_factor=capacity_factor, + use_sigmoid=use_sigmoid, + route_scale=route_scale, + ) + else: + mlp_inner_dim = int(dim * mlp_ratio * 2 / 3) // 128 * 128 + self.img_mlp = FeedForward( + dim=dim, + dim_out=dim, + inner_dim=mlp_inner_dim, + activation_fn="swiglu", + bias=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_kwargs: dict[str, Any] | None = None, + ) -> torch.Tensor: + scale1, gate1, scale2, gate2 = self.img_mod(temb).unsqueeze(1).chunk(4, dim=-1) + + gate1 = gate1.clamp(min=-2.0, max=2.0) + gate2 = 
class NucleusMoEImageTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
    """
    Nucleus MoE Transformer for image generation. Single-stream DiT with
    cross-attention to text and optional Mixture-of-Experts feed-forward layers.

    Args:
        patch_size (`int`, defaults to `2`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `64`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If not specified, it defaults to `in_channels`.
        num_layers (`int`, defaults to `24`):
            The number of transformer blocks.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `16`):
            The number of attention heads to use.
        num_key_value_heads (`int`, *optional*):
            The number of key/value heads for grouped-query attention. Defaults to `num_attention_heads`.
        joint_attention_dim (`int`, defaults to `3584`):
            The embedding dimension of the encoder hidden states (text).
        axes_dims_rope (`tuple[int]`, defaults to `(16, 56, 56)`):
            The dimensions to use for the rotary positional embeddings.
        mlp_ratio (`float`, defaults to `4.0`):
            Multiplier for the MLP hidden dimension in dense (non-MoE) blocks.
        moe_enabled (`bool`, defaults to `True`):
            Whether to use Mixture-of-Experts layers.
        dense_moe_strategy (`str`, defaults to ``"leave_first_three_and_last_block_dense"``):
            Strategy for choosing which layers are MoE vs dense.
        num_experts (`int`, defaults to `128`):
            Number of experts per MoE layer.
        moe_intermediate_dim (`int`, defaults to `1344`):
            Hidden dimension inside each expert.
        capacity_factors (`float | list[float]`, defaults to `8.0`):
            Expert-choice capacity factor per layer.
        use_sigmoid (`bool`, defaults to `False`):
            Use sigmoid instead of softmax for routing scores.
        route_scale (`float`, defaults to `2.5`):
            Scaling factor applied to routing weights.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["NucleusMoEImageTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    _repeated_blocks = ["NucleusMoEImageTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 64,
        out_channels: int | None = None,
        num_layers: int = 24,
        attention_head_dim: int = 128,
        num_attention_heads: int = 16,
        num_key_value_heads: int | None = None,
        joint_attention_dim: int = 3584,
        axes_dims_rope: tuple[int, int, int] = (16, 56, 56),
        mlp_ratio: float = 4.0,
        moe_enabled: bool = True,
        dense_moe_strategy: str = "leave_first_three_and_last_block_dense",
        num_experts: int = 128,
        moe_intermediate_dim: int = 1344,
        capacity_factors: float | list[float] = 8.0,
        use_sigmoid: bool = False,
        route_scale: float = 2.5,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim
        # Broadcast a scalar capacity factor into one value per layer so the
        # per-block construction below can always index it.
        capacity_factors = capacity_factors if isinstance(capacity_factors, list) else [capacity_factors] * num_layers

        self.pos_embed = NucleusMoEEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)

        self.time_text_embed = NucleusMoETimestepProjEmbeddings(embedding_dim=self.inner_dim)

        # Text embeddings are RMS-normalized before cross-attention.
        self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
        self.img_in = nn.Linear(in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                NucleusMoEImageTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    num_key_value_heads=num_key_value_heads,
                    joint_attention_dim=joint_attention_dim,
                    mlp_ratio=mlp_ratio,
                    # A layer is MoE only when MoE is globally enabled AND the
                    # placement strategy selects this index.
                    moe_enabled=moe_enabled and _is_moe_layer(dense_moe_strategy, idx, num_layers),
                    num_experts=num_experts,
                    moe_intermediate_dim=moe_intermediate_dim,
                    capacity_factor=capacity_factors[idx],
                    use_sigmoid=use_sigmoid,
                    route_scale=route_scale,
                )
                for idx in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        img_shapes: tuple[int, int, int] | list[tuple[int, int, int]],
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        attention_kwargs: dict[str, Any] | None = None,
        return_dict: bool = True,
    ) -> torch.Tensor | Transformer2DModelOutput:
        """
        The [`NucleusMoEImageTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            img_shapes (`list[tuple[int, int, int]]`, *optional*):
                Image shapes ``(frame, height, width)`` for RoPE computation.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
                Boolean mask for the encoder hidden states.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            attention_kwargs (`dict`, *optional*):
                Extra kwargs forwarded to the attention processor.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.transformer_2d.Transformer2DModelOutput`].

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # Pop the LoRA `scale` kwarg so it is not forwarded to the attention
        # processors; copy first so the caller's dict is never mutated.
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            scale_lora_layers(self, lora_scale)

        hidden_states = self.img_in(hidden_states)
        timestep = timestep.to(hidden_states.dtype)

        encoder_hidden_states = self.txt_norm(encoder_hidden_states)

        # Derives the padded text length (and possibly a trimmed mask) used to
        # size the text portion of the rotary embedding.
        text_seq_len, _, encoder_hidden_states_mask = _compute_text_seq_len_from_mask(
            encoder_hidden_states, encoder_hidden_states_mask
        )

        temb = self.time_text_embed(timestep, hidden_states)

        image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)

        block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
        if encoder_hidden_states_mask is not None:
            # Joint mask layout is [image tokens, text tokens]; image tokens are
            # always attendable, text tokens follow the provided mask.
            batch_size, image_seq_len = hidden_states.shape[:2]
            image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
            joint_attention_mask = torch.cat([image_mask, encoder_hidden_states_mask], dim=1)
            block_attention_kwargs["attention_mask"] = joint_attention_mask

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Positional call order must match the block's signature:
                # (hidden_states, encoder_hidden_states, temb, rotary, kwargs).
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    block_attention_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    attention_kwargs=block_attention_kwargs,
                )

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # Undo the temporary LoRA scaling applied at the top of forward.
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Filled with placeholder (dummy) objects when torch/transformers are missing,
# so attribute access raises a helpful error instead of an ImportError.
_dummy_objects = {}
_additional_imports = {}
# Lazy-import map: submodule name -> public names it provides.
_import_structure = {"pipeline_output": ["NucleusMoEImagePipelineOutput"]}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    # Only expose the real pipeline when both backends are importable.
    _import_structure["pipeline_nucleusmoe_image"] = ["NucleusMoEImagePipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: used by type checkers and when slow imports are requested.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_nucleusmoe_image import NucleusMoEImagePipeline
else:
    import sys

    # Replace this module object with a lazy loader so heavy dependencies are
    # only imported on first attribute access.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLQwenImage, NucleusMoEImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import NucleusMoEImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) + +DEFAULT_SYSTEM_PROMPT = ( + "You are an image generation assistant. Follow the user's prompt literally. Pay careful attention to spatial layout: objects described as on the left must appear on the left, on the right on the right. Match exact object counts and assign colors to the correct objects." +) + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import NucleusMoEImagePipeline + + >>> pipe = NucleusMoEImagePipeline.from_pretrained( + ... "NucleusAI/NucleusMoE-Image", torch_dtype=torch.bfloat16 + ... 
) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe(prompt, num_inference_steps=50).images[0] + >>> image.save("nucleus_moe.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. 
+ + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class NucleusMoEImagePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using NucleusMoE. + + This pipeline uses a single-stream DiT with Mixture-of-Experts feed-forward layers, + cross-attention to a Qwen3-VL text encoder, and a flow-matching Euler discrete scheduler. + + Args: + transformer ([`NucleusMoEImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. 
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLQwenImage`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen3VLForConditionalGeneration`]): + Text encoder for computing prompt embeddings. + processor ([`Qwen3VLProcessor`]): + Processor for tokenizing text inputs. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + transformer: NucleusMoEImageTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen3VLForConditionalGeneration, + processor: Qwen3VLProcessor, + ): + super().__init__() + self.register_modules( + transformer=transformer, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + processor=processor, + ) + self.vae_scale_factor = ( + 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + self.default_sample_size = 64 + self.default_max_sequence_length = 1024 + self.default_return_index = -8 + + def _format_prompt(self, prompt: str, system_prompt: str | None = None) -> str: + if system_prompt is None: + system_prompt = DEFAULT_SYSTEM_PROMPT + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + return self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + def encode_prompt( + self, + prompt: str | list[str] = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int | None = None, + return_index: int | None = None, + ): 
+ r""" + Encode text prompt(s) into embeddings using the Qwen3-VL text encoder. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to encode. + device (`torch.device`, *optional*): + Torch device for the resulting tensors. + num_images_per_prompt (`int`, defaults to 1): + Number of images to generate per prompt. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Skips encoding when provided. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Attention mask for pre-generated embeddings. + max_sequence_length (`int`, defaults to 1024): + Maximum token length for the encoded prompt. + """ + device = device or self._execution_device + return_index = return_index or self.default_return_index + + if prompt_embeds is None: + prompt = [prompt] if isinstance(prompt, str) else prompt + formatted = [self._format_prompt(p) for p in prompt] + + inputs = self.processor( + text=formatted, + padding="longest", + pad_to_multiple_of=8, + max_length=max_sequence_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ).to(device=device) + + prompt_embeds_mask = inputs.attention_mask + + outputs = self.text_encoder( + **inputs, use_cache=False, return_dict=True, output_hidden_states=True + ) + prompt_embeds = outputs.hidden_states[return_index] + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(device=device) + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask.to(device=device) + + if num_images_per_prompt > 1: + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask.repeat_interleave( + num_images_per_prompt, dim=0 + ) + + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + 
height, + width, + negative_prompt=None, + prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + return_index=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} " + f"but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, " + f"but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. " + "Please make sure to only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both undefined." + ) + elif prompt is not None and not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and " + f"`negative_prompt_embeds`: {negative_prompt_embeds}. " + "Please make sure to only forward one of the two." 
+ ) + + if max_sequence_length is not None and max_sequence_length > self.default_max_sequence_length: + raise ValueError( + f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}" + ) + + if return_index is not None and abs(return_index) >= self.text_encoder.config.text_config.num_hidden_layers: + raise ValueError( + f"absolute value of `return_index` cannot be >= {self.text_encoder.config.text_config.num_hidden_layers} " + f"but is {abs(return_index)}" + ) + + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size): + latents = latents.view(batch_size, num_channels_latents, height // patch_size, patch_size, width // patch_size, patch_size) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // patch_size) * (width // patch_size), num_channels_latents * patch_size * patch_size) + return latents + + @staticmethod + def _unpack_latents(latents, height, width, patch_size, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + height = patch_size * (int(height) // (vae_scale_factor * patch_size)) + width = patch_size * (int(width) // (vae_scale_factor * patch_size)) + latents = latents.view(batch_size, height // patch_size, width // patch_size, channels // (patch_size * patch_size), patch_size, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5) + latents = latents.reshape(batch_size, channels // (patch_size * patch_size), 1, height, width) + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + patch_size, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = patch_size * (int(height) // (self.vae_scale_factor * patch_size)) + width = patch_size * (int(width) // (self.vae_scale_factor * patch_size)) + shape = (batch_size, 1, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, 
list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + guidance_scale: float = 4.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + num_images_per_prompt: int = 1, + max_sequence_length: int | None = None, + return_index: int | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + r""" + Function invoked when calling the pipeline for generation. 
+ + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, an empty string is used + when `true_cfg_scale > 1`. + true_cfg_scale (`float`, *optional*, defaults to 4.0): + Classifier-free guidance scale. Values greater than 1 enable CFG. + height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list[float]`, *optional*): + Custom sigmas for the denoising schedule. If not defined, a linear schedule is used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of torch generators to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents to be used as inputs for image generation. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Attention mask for pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Attention mask for pre-generated negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `"pil"`, `"np"`, or `"latent"`. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`NucleusMoEImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + Kwargs passed to the attention processor. + callback_on_step_end (`Callable`, *optional*): + A function called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list`, *optional*): + Tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, defaults to 512): + Maximum sequence length for the text prompt. + + Examples: + + Returns: + [`NucleusMoEImagePipelineOutput`] or `tuple`: + [`NucleusMoEImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple` where the first + element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + max_sequence_length = max_sequence_length or self.default_max_sequence_length + + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + return_index=return_index, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs or {} + self._current_timestep = None + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_cfg = guidance_scale > 1 + + if do_cfg and not 
has_neg_prompt: + negative_prompt = [""] * batch_size + + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + return_index=return_index, + ) + if do_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + return_index=return_index, + ) + + num_channels_latents = self.transformer.config.in_channels // 4 + patch_size = self.transformer.config.patch_size + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + patch_size, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + img_shapes = [( + 1, + height // self.vae_scale_factor // patch_size, + width // self.vae_scale_factor // patch_size + )] * (batch_size * num_images_per_prompt) + + sigmas = ( + np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + ) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + self.scheduler.set_begin_index(0) + + if self.transformer.is_cache_enabled: + self.transformer._reset_stateful_cache() + + with 
self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / self.scheduler.config.num_train_timesteps, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + img_shapes=img_shapes, + attention_kwargs=self._attention_kwargs, + return_dict=False, + )[0] + + if do_cfg: + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / self.scheduler.config.num_train_timesteps, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + img_shapes=img_shapes, + attention_kwargs=self._attention_kwargs, + return_dict=False, + )[0] + + comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + noise_pred = -noise_pred + + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = 
self._unpack_latents(latents, height, width, patch_size, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + 1.0 + / torch.tensor(self.vae.config.latents_std) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return NucleusMoEImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/nucleusmoe_image/pipeline_output.py b/src/diffusers/pipelines/nucleusmoe_image/pipeline_output.py new file mode 100644 index 000000000000..84483355fd6b --- /dev/null +++ b/src/diffusers/pipelines/nucleusmoe_image/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class NucleusMoEImagePipelineOutput(BaseOutput): + """ + Output class for NucleusMoE Image pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
+ """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index fa37388fe75a..9208ccffb822 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -311,6 +311,25 @@ def apply_taylorseer_cache(*args, **kwargs): requires_backends(apply_taylorseer_cache, ["torch"]) +def apply_text_kv_cache(*args, **kwargs): + requires_backends(apply_text_kv_cache, ["torch"]) + + +class TextKVCacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AllegroTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] @@ -1556,6 +1575,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class NucleusMoEImageTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class QwenImageTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 1e4d14566160..691a3e0b2f63 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2807,6 +2807,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class NucleusMoEImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def 
__init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class QwenImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_nucleusmoe_image.py b/tests/models/transformers/test_models_transformer_nucleusmoe_image.py new file mode 100644 index 000000000000..edd6de53701a --- /dev/null +++ b/tests/models/transformers/test_models_transformer_nucleusmoe_image.py @@ -0,0 +1,242 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import warnings + +import torch + +from diffusers import NucleusMoEImageTransformer2DModel +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + BitsAndBytesTesterMixin, + LoraHotSwappingForModelTesterMixin, + LoraTesterMixin, + MemoryTesterMixin, + ModelTesterMixin, + TorchAoTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) + + +enable_full_determinism() + + +class NucleusMoEImageTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return NucleusMoEImageTransformer2DModel + + @property + def output_shape(self) -> tuple[int, int]: + return (16, 16) + + @property + def input_shape(self) -> tuple[int, int]: + return (16, 16) + + @property + def model_split_percents(self) -> list: + return [0.7, 0.6, 0.6] + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict: + return { + "patch_size": 2, + "in_channels": 16, + "out_channels": 4, + "num_layers": 2, + "attention_head_dim": 16, + "num_attention_heads": 4, + "joint_attention_dim": 16, + "axes_dims_rope": (8, 4, 4), + "moe_enabled": False, + "capacity_factors": [8.0, 8.0], + } + + def get_dummy_inputs(self) -> dict: + batch_size = 1 + in_channels = 16 + joint_attention_dim = 16 + height = width = 4 + sequence_length = 8 + + hidden_states = randn_tensor( + (batch_size, height * width, in_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + img_shapes = [(1, 
height, width)] * batch_size + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + "img_shapes": img_shapes, + } + + +class TestNucleusMoEImageTransformer(NucleusMoEImageTransformerTesterConfig, ModelTesterMixin): + def test_txt_seq_lens_deprecation(self): + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + inputs_with_deprecated = inputs.copy() + inputs_with_deprecated.pop("encoder_hidden_states_mask") + inputs_with_deprecated["txt_seq_lens"] = [inputs["encoder_hidden_states"].shape[1]] + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + with torch.no_grad(): + output = model(**inputs_with_deprecated) + + future_warnings = [x for x in w if issubclass(x.category, FutureWarning)] + assert len(future_warnings) > 0, "Expected FutureWarning to be raised" + warning_message = str(future_warnings[0].message) + assert "txt_seq_lens" in warning_message + assert "deprecated" in warning_message + + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] + + def test_with_attention_mask(self): + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + # Mask out some text tokens + mask = inputs["encoder_hidden_states_mask"].clone() + mask[:, 4:] = 0 + inputs["encoder_hidden_states_mask"] = mask + + with torch.no_grad(): + output = model(**inputs) + + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] + + def test_without_attention_mask(self): + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + inputs["encoder_hidden_states_mask"] = None + + with torch.no_grad(): + output = model(**inputs) + + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] + 
+class TestNucleusMoEImageTransformerMemory(NucleusMoEImageTransformerTesterConfig, MemoryTesterMixin): + """Memory optimization tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerTraining(NucleusMoEImageTransformerTesterConfig, TrainingTesterMixin): + """Training tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerAttention(NucleusMoEImageTransformerTesterConfig, AttentionTesterMixin): + """Attention processor tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerLoRA(NucleusMoEImageTransformerTesterConfig, LoraTesterMixin): + """LoRA adapter tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerLoRAHotSwap( + NucleusMoEImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin +): + """LoRA hot-swapping tests for NucleusMoE Image Transformer.""" + + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict: + batch_size = 1 + in_channels = 16 + joint_attention_dim = 16 + sequence_length = 8 + + hidden_states = randn_tensor( + (batch_size, height * width, in_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + img_shapes = [(1, height, width)] * batch_size + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + "img_shapes": img_shapes, + } + + +class TestNucleusMoEImageTransformerCompile(NucleusMoEImageTransformerTesterConfig, TorchCompileTesterMixin): + """Torch compile tests for 
NucleusMoE Image Transformer.""" + + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict: + batch_size = 1 + in_channels = 16 + joint_attention_dim = 16 + sequence_length = 8 + + hidden_states = randn_tensor( + (batch_size, height * width, in_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + img_shapes = [(1, height, width)] * batch_size + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + "img_shapes": img_shapes, + } + + +class TestNucleusMoEImageTransformerBitsAndBytes(NucleusMoEImageTransformerTesterConfig, BitsAndBytesTesterMixin): + """BitsAndBytes quantization tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerTorchAo(NucleusMoEImageTransformerTesterConfig, TorchAoTesterMixin): + """TorchAO quantization tests for NucleusMoE Image Transformer.""" diff --git a/tests/pipelines/nucleusmoe_image/__init__.py b/tests/pipelines/nucleusmoe_image/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py b/tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py new file mode 100644 index 000000000000..6b5c4b9a4baf --- /dev/null +++ b/tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py @@ -0,0 +1,340 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor + +from diffusers import ( + AutoencoderKLQwenImage, + FlowMatchEulerDiscreteScheduler, + NucleusMoEImagePipeline, + NucleusMoEImageTransformer2DModel, +) + +from ...testing_utils import enable_full_determinism, torch_device +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class NucleusMoEImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = NucleusMoEImagePipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = NucleusMoEImageTransformer2DModel( + patch_size=2, + in_channels=16, + out_channels=4, + num_layers=2, + attention_head_dim=16, + num_attention_heads=4, + joint_attention_dim=16, + axes_dims_rope=(8, 4, 
4), + moe_enabled=False, + capacity_factors=[8.0, 8.0], + ) + + torch.manual_seed(0) + z_dim = 4 + vae = AutoencoderKLQwenImage( + base_dim=z_dim * 6, + z_dim=z_dim, + dim_mult=[1, 2, 4], + num_res_blocks=1, + temperal_downsample=[False, True], + # fmt: off + latents_mean=[0.0] * z_dim, + latents_std=[1.0] * z_dim, + # fmt: on + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + config = Qwen3VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 8, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + "vocab_size": 151936, + "head_dim": 8, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_channels": 16, + }, + ) + text_encoder = Qwen3VLForConditionalGeneration(config).eval() + processor = Qwen3VLProcessor.from_pretrained( + "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "processor": processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "A cat sitting on a mat", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "return_index": -1, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = 
pipe(**inputs).images + generated_image = image[0] + self.assertEqual(generated_image.shape, (3, 32, 32)) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) + + def test_true_cfg(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["guidance_scale"] = 4.0 + inputs["negative_prompt"] = "low quality" + image = pipe(**inputs).images + self.assertEqual(image[0].shape, (3, 32, 32)) + + def test_prompt_embeds(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + prompt_embeds, prompt_embeds_mask = pipe.encode_prompt( + prompt=inputs["prompt"], + device=device, + max_sequence_length=inputs["max_sequence_length"], + ) + + inputs_with_embeds = self.get_dummy_inputs(device) + inputs_with_embeds.pop("prompt") + inputs_with_embeds["prompt_embeds"] = prompt_embeds + inputs_with_embeds["prompt_embeds_mask"] = prompt_embeds_mask + + image = pipe(**inputs_with_embeds).images + self.assertEqual(image[0].shape, (3, 32, 32)) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + # PipelineTesterMixin compares outputs with assert_mean_pixel_difference, which assumes HWC numpy/PIL layout. + # With output_type="pt", tensors are CHW; numpy_to_pil then fails. Match QwenImage: only assert max diff. 
+ if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4): + # PipelineTesterMixin only keeps components whose keys contain "text" or "tokenizer"; this pipeline also + # needs `processor` for encode_prompt (apply_chat_template). Mirror the mixin with that key included. 
+ if not hasattr(self.pipeline_class, "encode_prompt"): + return + + components = self.get_dummy_components() + for key in components: + if "text_encoder" in key and hasattr(components[key], "eval"): + components[key].eval() + + def _is_text_stack_component(k): + return "text" in k or "tokenizer" in k or k == "processor" + + components_with_text_encoders = {} + for k in components: + if _is_text_stack_component(k): + components_with_text_encoders[k] = components[k] + else: + components_with_text_encoders[k] = None + pipe_with_just_text_encoder = self.pipeline_class(**components_with_text_encoders) + pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + encode_prompt_signature = inspect.signature(pipe_with_just_text_encoder.encode_prompt) + encode_prompt_parameters = list(encode_prompt_signature.parameters.values()) + + required_params = [] + for param in encode_prompt_parameters: + if param.name == "self" or param.name == "kwargs": + continue + if param.default is inspect.Parameter.empty: + required_params.append(param.name) + + encode_prompt_param_names = [p.name for p in encode_prompt_parameters if p.name != "self"] + input_keys = list(inputs.keys()) + encode_prompt_inputs = {k: inputs.pop(k) for k in input_keys if k in encode_prompt_param_names} + + pipe_call_signature = inspect.signature(pipe_with_just_text_encoder.__call__) + pipe_call_parameters = pipe_call_signature.parameters + + for required_param_name in required_params: + if required_param_name not in encode_prompt_inputs: + pipe_call_param = pipe_call_parameters.get(required_param_name, None) + if pipe_call_param is not None and pipe_call_param.default is not inspect.Parameter.empty: + encode_prompt_inputs[required_param_name] = pipe_call_param.default + elif extra_required_param_value_dict is not None and isinstance(extra_required_param_value_dict, dict): + encode_prompt_inputs[required_param_name] = 
extra_required_param_value_dict[required_param_name] + else: + raise ValueError( + f"Required parameter '{required_param_name}' in " + f"encode_prompt has no default in either encode_prompt or __call__." + ) + + with torch.no_grad(): + encoded_prompt_outputs = pipe_with_just_text_encoder.encode_prompt(**encode_prompt_inputs) + + ast_visitor = ReturnNameVisitor() + encode_prompt_tree = ast_visitor.get_ast_tree(cls=self.pipeline_class) + ast_visitor.visit(encode_prompt_tree) + prompt_embed_kwargs = ast_visitor.return_names + prompt_embeds_kwargs = dict(zip(prompt_embed_kwargs, encoded_prompt_outputs)) + + adapted_prompt_embeds_kwargs = { + k: prompt_embeds_kwargs.pop(k) for k in list(prompt_embeds_kwargs.keys()) if k in pipe_call_parameters + } + + components_with_text_encoders = {} + for k in components: + if _is_text_stack_component(k): + components_with_text_encoders[k] = None + else: + components_with_text_encoders[k] = components[k] + pipe_without_text_encoders = self.pipeline_class(**components_with_text_encoders).to(torch_device) + + pipe_without_tes_inputs = {**inputs, **adapted_prompt_embeds_kwargs} + if ( + pipe_call_parameters.get("negative_prompt", None) is not None + and pipe_call_parameters.get("negative_prompt").default is not None + ): + pipe_without_tes_inputs.update({"negative_prompt": None}) + + if ( + pipe_call_parameters.get("prompt", None) is not None + and pipe_call_parameters.get("prompt").default is inspect.Parameter.empty + and pipe_call_parameters.get("prompt_embeds", None) is not None + and pipe_call_parameters.get("prompt_embeds").default is None + ): + pipe_without_tes_inputs.update({"prompt": None}) + + pipe_out = pipe_without_text_encoders(**pipe_without_tes_inputs)[0] + + full_pipe = self.pipeline_class(**components).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + pipe_out_2 = full_pipe(**inputs)[0] + + if isinstance(pipe_out, np.ndarray) and isinstance(pipe_out_2, np.ndarray): + self.assertTrue(np.allclose(pipe_out, 
pipe_out_2, atol=atol, rtol=rtol)) + elif isinstance(pipe_out, torch.Tensor) and isinstance(pipe_out_2, torch.Tensor): + self.assertTrue(torch.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol))