From 76dcd51e4674d50f5fdbd021ea2a2883a9c77904 Mon Sep 17 00:00:00 2001 From: Murali Nandan Nagarapu Date: Fri, 20 Mar 2026 07:59:47 +0000 Subject: [PATCH 01/14] adding NucleusMoE-Image model --- src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformer_nucleusmoe_image.py | 779 ++++++++++++++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/auto_pipeline.py | 2 + .../pipelines/nucleusmoe_image/__init__.py | 48 ++ .../pipeline_nucleusmoe_image.py | 663 +++++++++++++++ .../nucleusmoe_image/pipeline_output.py | 20 + src/diffusers/utils/dummy_pt_objects.py | 15 + .../dummy_torch_and_transformers_objects.py | 15 + 11 files changed, 1551 insertions(+) create mode 100644 src/diffusers/models/transformers/transformer_nucleusmoe_image.py create mode 100644 src/diffusers/pipelines/nucleusmoe_image/__init__.py create mode 100644 src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py create mode 100644 src/diffusers/pipelines/nucleusmoe_image/pipeline_output.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0be7b8166a37..0fdcfa97b065 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -258,6 +258,7 @@ "PixArtTransformer2DModel", "PriorTransformer", "PRXTransformer2DModel", + "NucleusMoEImageTransformer2DModel", "QwenImageControlNetModel", "QwenImageMultiControlNetModel", "QwenImageTransformer2DModel", @@ -607,6 +608,7 @@ "PixArtSigmaPAGPipeline", "PixArtSigmaPipeline", "PRXPipeline", + "NucleusMoEImagePipeline", "QwenImageControlNetInpaintPipeline", "QwenImageControlNetPipeline", "QwenImageEditInpaintPipeline", @@ -1040,6 +1042,7 @@ PixArtTransformer2DModel, PriorTransformer, PRXTransformer2DModel, + NucleusMoEImageTransformer2DModel, QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageTransformer2DModel, @@ -1364,6 +1367,7 @@ PixArtSigmaPAGPipeline, PixArtSigmaPipeline, PRXPipeline, + 
NucleusMoEImagePipeline, QwenImageControlNetInpaintPipeline, QwenImageControlNetPipeline, QwenImageEditInpaintPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e4bc95fdf884..a333db4295b7 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -117,6 +117,7 @@ _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"] _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"] + _import_structure["transformers.transformer_nucleusmoe_image"] = ["NucleusMoEImageTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] _import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] @@ -237,6 +238,7 @@ PixArtTransformer2DModel, PriorTransformer, PRXTransformer2DModel, + NucleusMoEImageTransformer2DModel, QwenImageTransformer2DModel, SanaTransformer2DModel, SanaVideoTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 45157ee91808..f675196cfa92 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -43,6 +43,7 @@ from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_ovis_image import OvisImageTransformer2DModel from .transformer_prx import PRXTransformer2DModel + from .transformer_nucleusmoe_image import NucleusMoEImageTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel from .transformer_sana_video import SanaVideoTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel diff --git a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py 
b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py new file mode 100644 index 000000000000..9c2aa17f162a --- /dev/null +++ b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py @@ -0,0 +1,779 @@ +# Copyright 2025 Nucleus-Image Team, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import math +from typing import Any, List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import AttentionMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..attention_processor import Attention +from ..cache_utils import CacheMixin +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, RMSNorm + + +logger = logging.get_logger(__name__) + + +def _apply_rotary_emb_nucleus( + x: torch.Tensor, + freqs_cis: torch.Tensor | tuple[torch.Tensor], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> tuple[torch.Tensor, torch.Tensor]: + if use_real: + cos, sin = freqs_cis + cos = cos[None, None] 
+ sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + else: + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(1) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + return x_out.type_as(x) + + +def _compute_text_seq_len_from_mask( + encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: torch.Tensor | None +) -> tuple[int, torch.Tensor | None, torch.Tensor | None]: + batch_size, text_seq_len = encoder_hidden_states.shape[:2] + if encoder_hidden_states_mask is None: + return text_seq_len, None, None + + if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len): + raise ValueError( + f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match " + f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})." 
+ ) + + if encoder_hidden_states_mask.dtype != torch.bool: + encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool) + + position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long) + active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(())) + has_active = encoder_hidden_states_mask.any(dim=1) + per_sample_len = torch.where( + has_active, + active_positions.max(dim=1).values + 1, + torch.as_tensor(text_seq_len, device=encoder_hidden_states.device), + ) + return text_seq_len, per_sample_len, encoder_hidden_states_mask + + +class NucleusMoETimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, use_additional_t_cond=False): + super().__init__() + + self.time_proj = Timesteps(num_channels=embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) + self.timestep_embedder = TimestepEmbedding( + in_channels=embedding_dim, time_embed_dim=4 * embedding_dim, out_dim=embedding_dim + ) + self.norm = RMSNorm(embedding_dim, eps=1e-6) + self.use_additional_t_cond = use_additional_t_cond + if use_additional_t_cond: + self.addition_t_embedding = nn.Embedding(2, embedding_dim) + + def forward(self, timestep, hidden_states, addition_t_cond=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) + + conditioning = timesteps_emb + if self.use_additional_t_cond: + if addition_t_cond is None: + raise ValueError("When additional_t_cond is True, addition_t_cond must be provided.") + addition_t_emb = self.addition_t_embedding(addition_t_cond) + addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype) + conditioning = conditioning + addition_t_emb + + return self.norm(conditioning) + + +class NucleusMoEEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: list[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = 
torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self._rope_params(pos_index, self.axes_dim[0], self.theta), + self._rope_params(pos_index, self.axes_dim[1], self.theta), + self._rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self._rope_params(neg_index, self.axes_dim[0], self.theta), + self._rope_params(neg_index, self.axes_dim[1], self.theta), + self._rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + + self.scale_rope = scale_rope + + @staticmethod + def _rope_params(index, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward( + self, + video_fhw: tuple[int, int, int, list[tuple[int, int, int]]], + txt_seq_lens: list[int] | None = None, + device: torch.device = None, + max_txt_seq_len: int | torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + video_fhw (`tuple[int, int, int]` or `list[tuple[int, int, int]]`): + A list of 3 integers [frame, height, width] representing the shape of the video. + txt_seq_lens (`list[int]`, *optional*, **Deprecated**): + Deprecated parameter. Use `max_txt_seq_len` instead. + device: (`torch.device`, *optional*): + The device on which to perform the RoPE computation. + max_txt_seq_len (`int` or `torch.Tensor`, *optional*): + The maximum text sequence length for RoPE computation. + """ + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. 
" + "Please use `max_txt_seq_len` instead.", + standard_warn=False, + ) + if max_txt_seq_len is None: + max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens + + if max_txt_seq_len is None: + raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.") + + if isinstance(video_fhw, list) and len(video_fhw) > 1: + first_fhw = video_fhw[0] + if not all(fhw == first_fhw for fhw in video_fhw): + logger.warning( + "Batch inference with variable-sized images is not currently supported in NucleusMoEEmbedRope. " + "All images in the batch should have the same dimensions (frame, height, width). " + f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} " + "for RoPE computation, which may lead to incorrect results for other images in the batch." + ) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + video_freq = self._compute_video_freqs(frame, height, width, idx, device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_txt_seq_len_int = int(max_txt_seq_len) + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] 
+ vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=128) + def _compute_video_freqs( + self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None + ) -> torch.Tensor: + seq_lens = frame * height * width + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + +class NucleusMoEAttnProcessor2_0: + """ + Attention processor for the NucleusMoE architecture. Image queries attend to concatenated + image+text keys/values (cross-attention style, no text query). Supports grouped-query + attention (GQA) when num_key_value_heads is set on the Attention module. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "NucleusMoEAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: torch.FloatTensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + ) -> torch.FloatTensor: + head_dim = attn.inner_dim // attn.heads + num_kv_heads = attn.inner_kv_dim // head_dim + num_kv_groups = attn.heads // num_kv_heads + + img_query = attn.to_q(hidden_states).unflatten(-1, (attn.heads, -1)) + img_key = attn.to_k(hidden_states).unflatten(-1, (num_kv_heads, -1)) + img_value = attn.to_v(hidden_states).unflatten(-1, (num_kv_heads, -1)) + + if attn.norm_q is not None: + img_query = attn.norm_q(img_query) + if attn.norm_k is not None: + img_key = attn.norm_k(img_key) + + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_query = _apply_rotary_emb_nucleus(img_query, img_freqs, use_real=False) + img_key = _apply_rotary_emb_nucleus(img_key, img_freqs, use_real=False) + + if encoder_hidden_states is not None: + txt_key = attn.add_k_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1)) + txt_value = attn.add_v_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1)) + + if attn.norm_added_k is not None: + txt_key = attn.norm_added_k(txt_key) + + if image_rotary_emb is not None: + txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False) + + joint_key = torch.cat([img_key, txt_key], dim=1) + joint_value = torch.cat([img_value, txt_value], dim=1) + else: + joint_key = img_key + joint_value = img_value + + if num_kv_groups > 1: + joint_key = joint_key.repeat_interleave(num_kv_groups, dim=2) + joint_value = joint_value.repeat_interleave(num_kv_groups, dim=2) + + hidden_states = dispatch_attention_fn( + img_query, + joint_key, + joint_value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = 
hidden_states.to(img_query.dtype) + + hidden_states = attn.to_out[0](hidden_states) + if len(attn.to_out) > 1: + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +def _is_moe_layer(strategy: str, layer_idx: int, num_layers: int) -> bool: + if strategy == "leave_first_three_and_last_block_dense": + return layer_idx >= 3 and layer_idx < num_layers - 1 + elif strategy == "leave_first_three_blocks_dense": + return layer_idx >= 3 + elif strategy == "leave_first_block_dense": + return layer_idx >= 1 + elif strategy == "all_moe": + return True + elif strategy == "all_dense": + return False + return True + + +class NucleusMoELayer(nn.Module): + """ + Mixture-of-Experts layer with expert-choice routing and a shared expert. + + Each expert is a separate ``FeedForward`` module stored in an ``nn.ModuleList``. + The router concatenates a timestep embedding with the (unmodulated) hidden state + to produce per-token affinity scores, then selects the top-C tokens per expert + (expert-choice routing). A shared expert processes all tokens in parallel and its + output is combined with the routed expert outputs via scatter-add. 
+ """ + + def __init__( + self, + hidden_size: int, + moe_intermediate_dim: int, + num_experts: int, + capacity_factor: float, + use_sigmoid: bool, + route_scale: float, + ): + super().__init__() + self.num_experts = num_experts + self.capacity_factor = capacity_factor + self.use_sigmoid = use_sigmoid + self.route_scale = route_scale + + self.gate = nn.Linear(hidden_size * 2, num_experts, bias=False) + self.experts = nn.ModuleList( + [ + FeedForward( + dim=hidden_size, + dim_out=hidden_size, + inner_dim=moe_intermediate_dim, + activation_fn="swiglu", + bias=False, + ) + for _ in range(num_experts) + ] + ) + self.shared_expert = FeedForward( + dim=hidden_size, + dim_out=hidden_size, + inner_dim=moe_intermediate_dim, + activation_fn="swiglu", + bias=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + hidden_states_unmodulated: torch.Tensor, + timestep: torch.Tensor | None = None, + ) -> torch.Tensor: + bs, slen, dim = hidden_states.shape + + if timestep is not None: + timestep_expanded = timestep.unsqueeze(1).expand(-1, slen, -1) + router_input = torch.cat([timestep_expanded, hidden_states_unmodulated], dim=-1) + else: + router_input = hidden_states_unmodulated + + logits = self.gate(router_input) + + if self.use_sigmoid: + scores = torch.sigmoid(logits.float()).to(logits.dtype) + else: + scores = F.softmax(logits.float(), dim=-1).to(logits.dtype) + + affinity = scores.transpose(1, 2) # (B, E, S) + capacity = max(1, math.ceil(self.capacity_factor * slen / self.num_experts)) + + topk = torch.topk(affinity, k=capacity, dim=-1) + top_indices = topk.indices # (B, E, C) + gating = affinity.gather(dim=-1, index=top_indices) # (B, E, C) + + batch_offsets = torch.arange(bs, device=hidden_states.device, dtype=torch.long).view(bs, 1, 1) * slen + global_token_indices = (batch_offsets + top_indices).transpose(0, 1).reshape(self.num_experts, -1).reshape(-1) + gating_flat = gating.transpose(0, 1).reshape(self.num_experts, -1).reshape(-1) + + token_score_sums = 
torch.zeros(bs * slen, device=hidden_states.device, dtype=gating_flat.dtype) + token_score_sums.scatter_add_(0, global_token_indices, gating_flat) + gating_flat = gating_flat / (token_score_sums[global_token_indices] + 1e-12) + gating_flat = gating_flat * self.route_scale + + x_flat = hidden_states.reshape(bs * slen, dim) + routed_input = x_flat[global_token_indices] + + tokens_per_expert = bs * capacity + routed_output_parts = [] + for i, expert in enumerate(self.experts): + start = i * tokens_per_expert + end = start + tokens_per_expert + expert_out = expert(routed_input[start:end]) + routed_output_parts.append(expert_out) + + routed_output = torch.cat(routed_output_parts, dim=0) + routed_output = (routed_output.float() * gating_flat.unsqueeze(-1)).to(hidden_states.dtype) + + out = self.shared_expert(hidden_states).reshape(bs * slen, dim) + + scatter_idx = global_token_indices.reshape(-1, 1).expand(-1, dim) + out = out.scatter_add(dim=0, index=scatter_idx, src=routed_output) + out = out.reshape(bs, slen, dim) + + return out + + +@maybe_allow_in_graph +class NucleusMoEImageTransformerBlock(nn.Module): + """ + Single-stream DiT block with optional Mixture-of-Experts MLP. Only the image + stream receives adaptive modulation; the text context is projected per-block + and used as cross-attention keys/values. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_key_value_heads: int | None = None, + joint_attention_dim: int = 3584, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + mlp_ratio: float = 4.0, + moe_enabled: bool = False, + num_experts: int = 128, + moe_intermediate_dim: int = 1344, + capacity_factor: float = 8.0, + use_sigmoid: bool = False, + route_scale: float = 2.5, + ): + super().__init__() + self.dim = dim + self.moe_enabled = moe_enabled + + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, 4 * dim, bias=True), + ) + + self.encoder_proj = nn.Linear(joint_attention_dim, dim) + + self.pre_attn_norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) + self.attn = Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_key_value_heads, + dim_head=attention_head_dim, + added_kv_proj_dim=dim, + added_proj_bias=False, + out_dim=dim, + out_bias=False, + bias=False, + processor=NucleusMoEAttnProcessor2_0(), + qk_norm=qk_norm, + eps=eps, + context_pre_only=None, + ) + + self.pre_mlp_norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) + + if moe_enabled: + self.img_mlp = NucleusMoELayer( + hidden_size=dim, + moe_intermediate_dim=moe_intermediate_dim, + num_experts=num_experts, + capacity_factor=capacity_factor, + use_sigmoid=use_sigmoid, + route_scale=route_scale, + ) + else: + mlp_inner_dim = int(dim * mlp_ratio * 2 / 3) // 128 * 128 + self.img_mlp = FeedForward( + dim=dim, + dim_out=dim, + inner_dim=mlp_inner_dim, + activation_fn="swiglu", + bias=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_kwargs: dict[str, Any] | None = None, + ) -> torch.Tensor: + scale1, gate1, scale2, gate2 = self.img_mod(temb).unsqueeze(1).chunk(4, dim=-1) + scale1, scale2 = 1 + scale1, 1 + scale2 + + gate1 = 
gate1.clamp(min=-2.0, max=2.0) + gate2 = gate2.clamp(min=-2.0, max=2.0) + + context = self.encoder_proj(encoder_hidden_states) + + img_normed = self.pre_attn_norm(hidden_states) + img_modulated = img_normed * scale1 + + attention_kwargs = attention_kwargs or {} + img_attn_output = self.attn( + hidden_states=img_modulated, + encoder_hidden_states=context, + image_rotary_emb=image_rotary_emb, + **attention_kwargs, + ) + + hidden_states = hidden_states + gate1.tanh() * img_attn_output + + img_normed2 = self.pre_mlp_norm(hidden_states) + img_modulated2 = img_normed2 * scale2 + + if self.moe_enabled: + img_mlp_output = self.img_mlp(img_modulated2, img_normed2, timestep=temb) + else: + img_mlp_output = self.img_mlp(img_modulated2) + + hidden_states = hidden_states + gate2.tanh() * img_mlp_output + + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return hidden_states + + +class NucleusMoEImageTransformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): + """ + Nucleus MoE Transformer for image generation. Single-stream DiT with + cross-attention to text and optional Mixture-of-Experts feed-forward layers. + + Args: + patch_size (`int`, defaults to `2`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `24`): + The number of transformer blocks. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `16`): + The number of attention heads to use. + num_key_value_heads (`int`, *optional*): + The number of key/value heads for grouped-query attention. Defaults to `num_attention_heads`. 
+ joint_attention_dim (`int`, defaults to `3584`): + The embedding dimension of the encoder hidden states (text). + axes_dims_rope (`tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. + mlp_ratio (`float`, defaults to `4.0`): + Multiplier for the MLP hidden dimension in dense (non-MoE) blocks. + moe_enabled (`bool`, defaults to `True`): + Whether to use Mixture-of-Experts layers. + dense_moe_strategy (`str`, defaults to ``"leave_first_three_and_last_block_dense"``): + Strategy for choosing which layers are MoE vs dense. + num_experts (`int`, defaults to `128`): + Number of experts per MoE layer. + moe_intermediate_dim (`int`, defaults to `1344`): + Hidden dimension inside each expert. + capacity_factors (`list[float]`, defaults to `[8.0] * 24`): + Expert-choice capacity factor per layer. + use_sigmoid (`bool`, defaults to `False`): + Use sigmoid instead of softmax for routing scores. + route_scale (`float`, defaults to `2.5`): + Scaling factor applied to routing weights. 
+ """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["NucleusMoEImageTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _repeated_blocks = ["NucleusMoEImageTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 64, + out_channels: int | None = None, + num_layers: int = 24, + attention_head_dim: int = 128, + num_attention_heads: int = 16, + num_key_value_heads: int | None = None, + joint_attention_dim: int = 3584, + axes_dims_rope: tuple[int, int, int] = (16, 56, 56), + mlp_ratio: float = 4.0, + moe_enabled: bool = True, + dense_moe_strategy: str = "leave_first_three_and_last_block_dense", + num_experts: int = 128, + moe_intermediate_dim: int = 1344, + capacity_factors: List[float] = [8.0] * 24, + use_sigmoid: bool = False, + route_scale: float = 2.5, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = NucleusMoEEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) + + self.time_text_embed = NucleusMoETimestepProjEmbeddings(embedding_dim=self.inner_dim) + + self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) + self.img_in = nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + NucleusMoEImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_key_value_heads=num_key_value_heads, + joint_attention_dim=joint_attention_dim, + mlp_ratio=mlp_ratio, + moe_enabled=moe_enabled and _is_moe_layer(dense_moe_strategy, idx, num_layers), + num_experts=num_experts, + moe_intermediate_dim=moe_intermediate_dim, + capacity_factor=capacity_factors[idx], + use_sigmoid=use_sigmoid, + route_scale=route_scale, + ) + for idx in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, 
eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + img_shapes: list[tuple[int, int, int]] | None = None, + encoder_hidden_states: torch.Tensor = None, + encoder_hidden_states_mask: torch.Tensor = None, + timestep: torch.LongTensor = None, + txt_seq_lens: list[int] | None = None, + attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + ) -> torch.Tensor | Transformer2DModelOutput: + """ + The [`NucleusMoEImageTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + img_shapes (`list[tuple[int, int, int]]`, *optional*): + Image shapes ``(frame, height, width)`` for RoPE computation. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): + Boolean mask for the encoder hidden states. + timestep (`torch.LongTensor`): + Used to indicate denoising step. + txt_seq_lens (`list[int]`, *optional*, **Deprecated**): + Deprecated. Use ``encoder_hidden_states_mask`` instead. + attention_kwargs (`dict`, *optional*): + Extra kwargs forwarded to the attention processor. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.transformer_2d.Transformer2DModelOutput`]. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. 
" + "Please use `encoder_hidden_states_mask` instead.", + standard_warn=False, + ) + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + scale_lora_layers(self, lora_scale) + + hidden_states = self.img_in(hidden_states) + timestep = timestep.to(hidden_states.dtype) + + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + + text_seq_len, _, encoder_hidden_states_mask = _compute_text_seq_len_from_mask( + encoder_hidden_states, encoder_hidden_states_mask + ) + + temb = self.time_text_embed(timestep, hidden_states) + + image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) + + block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} + if encoder_hidden_states_mask is not None: + batch_size, image_seq_len = hidden_states.shape[:2] + image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) + joint_attention_mask = torch.cat([image_mask, encoder_hidden_states_mask], dim=1) + block_attention_kwargs["attention_mask"] = joint_attention_mask + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + block_attention_kwargs, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + attention_kwargs=block_attention_kwargs, + ) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git 
a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b9596f4b7952..cc4ab82ffed8 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -430,6 +430,7 @@ "SkyReelsV2ImageToVideoPipeline", "SkyReelsV2Pipeline", ] + _import_structure["nucleusmoe_image"] = ["NucleusMoEImagePipeline"] _import_structure["qwenimage"] = [ "QwenImagePipeline", "QwenImageImg2ImgPipeline", @@ -772,6 +773,7 @@ from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .prx import PRXPipeline + from .nucleusmoe_image import NucleusMoEImagePipeline from .qwenimage import ( QwenImageControlNetInpaintPipeline, QwenImageControlNetPipeline, diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 7f8ebd06cef1..7dedab4ed85d 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -96,6 +96,7 @@ ) from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .prx import PRXPipeline +from .nucleusmoe_image import NucleusMoEImagePipeline from .qwenimage import ( QwenImageControlNetPipeline, QwenImageEditInpaintPipeline, @@ -179,6 +180,7 @@ ("helios", HeliosPipeline), ("helios-pyramid", HeliosPyramidPipeline), ("cogview4-control", CogView4ControlPipeline), + ("nucleusmoe-image", NucleusMoEImagePipeline), ("qwenimage", QwenImagePipeline), ("qwenimage-controlnet", QwenImageControlNetPipeline), ("z-image", ZImagePipeline), diff --git a/src/diffusers/pipelines/nucleusmoe_image/__init__.py b/src/diffusers/pipelines/nucleusmoe_image/__init__.py new file mode 100644 index 000000000000..d46644ab237f --- /dev/null +++ b/src/diffusers/pipelines/nucleusmoe_image/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + 
+_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["NucleusMoEImagePipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_nucleusmoe_image"] = ["NucleusMoEImagePipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_nucleusmoe_image import NucleusMoEImagePipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py new file mode 100644 index 000000000000..3fc22a393077 --- /dev/null +++ b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py @@ -0,0 +1,663 @@ +# Copyright 2025 Nucleus-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLQwenImage, NucleusMoEImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import NucleusMoEImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) + +DEFAULT_SYSTEM_PROMPT = ( + "You are an assistant designed to generate photorealistic, ultra-high-quality images based on user prompts." +) + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import NucleusMoEImagePipeline + + >>> pipe = NucleusMoEImagePipeline.from_pretrained( + ... "NucleusAI/NucleusMoE-Image", torch_dtype=torch.bfloat16 + ... 
) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe(prompt, num_inference_steps=50).images[0] + >>> image.save("nucleus_moe.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and + the second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class NucleusMoEImagePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using NucleusMoE. + + This pipeline uses a single-stream DiT with Mixture-of-Experts feed-forward layers, + cross-attention to a Qwen3-VL text encoder, and a flow-matching Euler discrete scheduler. + + Args: + transformer ([`NucleusMoEImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLQwenImage`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
+ text_encoder ([`Qwen3VLForConditionalGeneration`]): + Text encoder for computing prompt embeddings. + processor ([`Qwen3VLProcessor`]): + Processor for tokenizing text inputs. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + transformer: NucleusMoEImageTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen3VLForConditionalGeneration, + processor: Qwen3VLProcessor, + ): + super().__init__() + self.register_modules( + transformer=transformer, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + processor=processor, + ) + self.vae_scale_factor = ( + 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 128 + self.return_index = -8 + + def _format_prompt(self, prompt: str, system_prompt: str | None = None) -> str: + if system_prompt is None: + system_prompt = DEFAULT_SYSTEM_PROMPT + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + return self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + def encode_prompt( + self, + prompt: str | list[str] = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + ): + r""" + Encode text prompt(s) into embeddings using the Qwen3-VL text encoder. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to encode. + device (`torch.device`, *optional*): + Torch device for the resulting tensors. + num_images_per_prompt (`int`, defaults to 1): + Number of images to generate per prompt. 
+ prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Skips encoding when provided. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Attention mask for pre-generated embeddings. + max_sequence_length (`int`, defaults to 1024): + Maximum token length for the encoded prompt. + """ + device = device or self._execution_device + + if prompt_embeds is None: + prompt = [prompt] if isinstance(prompt, str) else prompt + formatted = [self._format_prompt(p) for p in prompt] + + inputs = self.processor( + text=formatted, + padding="longest", + pad_to_multiple_of=8, + max_length=max_sequence_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ).to(device=device) + + prompt_embeds_mask = inputs.attention_mask + + outputs = self.text_encoder( + **inputs, use_cache=False, return_dict=True, output_hidden_states=True + ) + prompt_embeds = outputs.hidden_states[self.return_index] + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(device=device) + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask.to(device=device) + + if num_images_per_prompt > 1: + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask.repeat_interleave( + num_images_per_prompt, dim=0 + ) + + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by 
{self.vae_scale_factor * 2} " + f"but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, " + f"but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. " + "Please make sure to only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both undefined." + ) + elif prompt is not None and not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and " + f"`negative_prompt_embeds`: {negative_prompt_embeds}. " + "Please make sure to only forward one of the two." 
+ ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError( + f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}" + ) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, 1, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents + + def enable_vae_slicing(self): + r"""Enable sliced VAE decoding for memory efficiency.""" + depr_message = ( + f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and will be " + "removed in a future version. Please use `pipe.vae.enable_slicing()`." + ) + deprecate("enable_vae_slicing", "0.40.0", depr_message) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r"""Disable sliced VAE decoding.""" + depr_message = ( + f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and will be " + "removed in a future version. Please use `pipe.vae.disable_slicing()`." + ) + deprecate("disable_vae_slicing", "0.40.0", depr_message) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r"""Enable tiled VAE decoding for memory efficiency.""" + depr_message = ( + f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and will be " + "removed in a future version. Please use `pipe.vae.enable_tiling()`." + ) + deprecate("enable_vae_tiling", "0.40.0", depr_message) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r"""Disable tiled VAE decoding.""" + depr_message = ( + f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and will be " + "removed in a future version. Please use `pipe.vae.disable_tiling()`." 
+ ) + deprecate("disable_vae_tiling", "0.40.0", depr_message) + self.vae.disable_tiling() + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + true_cfg_scale: float = 4.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, an empty string is used + when `true_cfg_scale > 1`. + true_cfg_scale (`float`, *optional*, defaults to 4.0): + Classifier-free guidance scale. Values greater than 1 enable CFG. 
+ height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list[float]`, *optional*): + Custom sigmas for the denoising schedule. If not defined, a linear schedule is used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of torch generators to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents to be used as inputs for image generation. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Attention mask for pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Attention mask for pre-generated negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `"pil"`, `"np"`, or `"latent"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`NucleusMoEImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + Kwargs passed to the attention processor. + callback_on_step_end (`Callable`, *optional*): + A function called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list`, *optional*): + Tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, defaults to 512): + Maximum sequence length for the text prompt. 
+ + Examples: + + Returns: + [`NucleusMoEImagePipelineOutput`] or `tuple`: + [`NucleusMoEImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple` where the first + element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._attention_kwargs = attention_kwargs or {} + self._current_timestep = None + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_true_cfg = true_cfg_scale > 1 + + if do_true_cfg and not has_neg_prompt: + negative_prompt = [""] * batch_size + + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + num_channels_latents = self.transformer.config.in_channels 
// 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + latent_h = 2 * (int(height) // (self.vae_scale_factor * 2)) + latent_w = 2 * (int(width) // (self.vae_scale_factor * 2)) + img_shapes = [(1, latent_h // 2, latent_w // 2)] * (batch_size * num_images_per_prompt) + + sigmas = ( + np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + ) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + img_shapes=img_shapes, + attention_kwargs=self._attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + img_shapes=img_shapes, + attention_kwargs=self._attention_kwargs, + return_dict=False, + )[0] + + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + 
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + noise_pred = -noise_pred + + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + 1.0 + / torch.tensor(self.vae.config.latents_std) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return NucleusMoEImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/nucleusmoe_image/pipeline_output.py b/src/diffusers/pipelines/nucleusmoe_image/pipeline_output.py new file mode 100644 index 000000000000..84483355fd6b --- /dev/null +++ 
b/src/diffusers/pipelines/nucleusmoe_image/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class NucleusMoEImagePipelineOutput(BaseOutput): + """ + Output class for NucleusMoE Image pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 3425cc8d2b61..9fba2f661e3b 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1526,6 +1526,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class NucleusMoEImageTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class QwenImageTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 2ec5bc002f41..2c904d8ce0bc 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2777,6 +2777,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class NucleusMoEImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", 
"transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class QwenImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From f6913953bf5514f580461ddf4e1d0def15a982aa Mon Sep 17 00:00:00 2001 From: sippycoder Date: Fri, 20 Mar 2026 08:05:29 +0000 Subject: [PATCH 02/14] update system prompt --- .../pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py index 3fc22a393077..70d8ec8212ad 100644 --- a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py +++ b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py @@ -39,7 +39,7 @@ logger = logging.get_logger(__name__) DEFAULT_SYSTEM_PROMPT = ( - "You are an assistant designed to generate photorealistic, ultra-high-quality images based on user prompts." + "You are an image generation assistant. Follow the user's prompt literally. Pay careful attention to spatial layout: objects described as on the left must appear on the left, on the right on the right. Match exact object counts and assign colors to the correct objects." 
) EXAMPLE_DOC_STRING = """ From 7eef03eb2aff8eab6905d89d3b86753a03c17d44 Mon Sep 17 00:00:00 2001 From: sippycoder Date: Fri, 20 Mar 2026 22:40:50 +0000 Subject: [PATCH 03/14] Add text kv caching --- src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/text_kv_cache.py | 109 ++++++++ src/diffusers/models/cache_utils.py | 11 +- .../transformer_nucleusmoe_image.py | 15 +- .../pipeline_nucleusmoe_image.py | 4 + ...est_models_transformer_nucleusmoe_image.py | 242 ++++++++++++++++++ tests/pipelines/nucleusmoe_image/__init__.py | 0 .../nucleusmoe_image/test_nucleusmoe_image.py | 197 ++++++++++++++ 8 files changed, 574 insertions(+), 5 deletions(-) create mode 100644 src/diffusers/hooks/text_kv_cache.py create mode 100644 tests/models/transformers/test_models_transformer_nucleusmoe_image.py create mode 100644 tests/pipelines/nucleusmoe_image/__init__.py create mode 100644 tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 23c8bc92b2f1..23466afe5bc5 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -27,3 +27,4 @@ from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache + from .text_kv_cache import NucleusMoETextKVCacheConfig, apply_nucleus_moe_text_kv_cache diff --git a/src/diffusers/hooks/text_kv_cache.py b/src/diffusers/hooks/text_kv_cache.py new file mode 100644 index 000000000000..e718943444ca --- /dev/null +++ b/src/diffusers/hooks/text_kv_cache.py @@ -0,0 +1,109 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch + +from .hooks import HookRegistry, ModelHook + + +_TEXT_KV_CACHE_HOOK = "nucleus_moe_text_kv_cache" + + +@dataclass +class NucleusMoETextKVCacheConfig: + """Enable exact (lossless) text K/V caching for NucleusMoEImageTransformer2DModel. + + Pre-computes per-block text key and value projections once before the + denoising loop and reuses them across all steps. The cached values are keyed by + the ``data_ptr()`` of the ``encoder_hidden_states`` tensor so that both the positive + and negative prompts (when ``true_cfg_scale > 1``) are handled correctly. + """ + + pass # no hyperparameters needed — cache is always exact + + +class NucleusMoETextKVCacheHook(ModelHook): + """Block-level hook that caches (txt_key, txt_value) per unique prompt.""" + + _is_stateful = True + + def __init__(self): + super().__init__() + # Maps encoder_hidden_states.data_ptr() → (context, txt_key, txt_value) + self.kv_cache: dict[int, tuple] = {} + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + from ..models.transformers.transformer_nucleusmoe_image import _apply_rotary_emb_nucleus + + # --- extract encoder_hidden_states --- + if "encoder_hidden_states" in kwargs: + encoder_hidden_states = kwargs["encoder_hidden_states"] + else: + # positional: (hidden_states, encoder_hidden_states, temb, ...) 
+ encoder_hidden_states = args[1] + + # --- extract image_rotary_emb --- + if "image_rotary_emb" in kwargs: + image_rotary_emb = kwargs.get("image_rotary_emb") + elif len(args) > 3: + image_rotary_emb = args[3] + else: + image_rotary_emb = None + + ptr = encoder_hidden_states.data_ptr() + + if ptr not in self.kv_cache: + context = module.encoder_proj(encoder_hidden_states) + + attn = module.attn + head_dim = attn.inner_dim // attn.heads + num_kv_heads = attn.inner_kv_dim // head_dim + + txt_key = attn.add_k_proj(context).unflatten(-1, (num_kv_heads, -1)) + txt_value = attn.add_v_proj(context).unflatten(-1, (num_kv_heads, -1)) + + if attn.norm_added_k is not None: + txt_key = attn.norm_added_k(txt_key) + + if image_rotary_emb is not None: + _, txt_freqs = image_rotary_emb + txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False) + + self.kv_cache[ptr] = (txt_key, txt_value) + + txt_key, txt_value = self.kv_cache[ptr] + + # Inject cached k/v — block sees cached_txt_key and skips encoder_proj too + attn_kwargs = kwargs.get("attention_kwargs") or {} + attn_kwargs["cached_txt_key"] = txt_key + attn_kwargs["cached_txt_value"] = txt_value + kwargs["attention_kwargs"] = attn_kwargs + + return self.fn_ref.original_forward(*args, **kwargs) + + def reset_state(self, module: torch.nn.Module): + self.kv_cache.clear() + return module + + +def apply_nucleus_moe_text_kv_cache(module: torch.nn.Module, config: NucleusMoETextKVCacheConfig) -> None: + from ..models.transformers.transformer_nucleusmoe_image import NucleusMoEImageTransformerBlock + + for _, submodule in module.named_modules(): + if isinstance(submodule, NucleusMoEImageTransformerBlock): + hook = NucleusMoETextKVCacheHook() + registry = HookRegistry.check_if_exists_or_initialize(submodule) + registry.register_hook(hook, _TEXT_KV_CACHE_HOOK) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 5f9587a1b4de..365562465930 100644 --- 
a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -41,11 +41,12 @@ def enable_cache(self, config) -> None: Enable caching techniques on the model. Args: - config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig`): + config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig | NucleusMoETextKVCacheConfig`): The configuration for applying the caching technique. Currently supported caching techniques are: - [`~hooks.PyramidAttentionBroadcastConfig`] - [`~hooks.FasterCacheConfig`] - [`~hooks.FirstBlockCacheConfig`] + - [`~hooks.NucleusMoETextKVCacheConfig`] Example: @@ -69,11 +70,13 @@ def enable_cache(self, config) -> None: FasterCacheConfig, FirstBlockCacheConfig, MagCacheConfig, + NucleusMoETextKVCacheConfig, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig, apply_faster_cache, apply_first_block_cache, apply_mag_cache, + apply_nucleus_moe_text_kv_cache, apply_pyramid_attention_broadcast, apply_taylorseer_cache, ) @@ -89,6 +92,8 @@ def enable_cache(self, config) -> None: apply_first_block_cache(self, config) elif isinstance(config, MagCacheConfig): apply_mag_cache(self, config) + elif isinstance(config, NucleusMoETextKVCacheConfig): + apply_nucleus_moe_text_kv_cache(self, config) elif isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) elif isinstance(config, TaylorSeerCacheConfig): @@ -104,6 +109,7 @@ def disable_cache(self) -> None: FirstBlockCacheConfig, HookRegistry, MagCacheConfig, + NucleusMoETextKVCacheConfig, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig, ) @@ -112,6 +118,7 @@ def disable_cache(self) -> None: from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK + from ..hooks.text_kv_cache import _TEXT_KV_CACHE_HOOK if self._cache_config is 
None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") @@ -129,6 +136,8 @@ def disable_cache(self) -> None: registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True) elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) + elif isinstance(self._cache_config, NucleusMoETextKVCacheConfig): + registry.remove_hook(_TEXT_KV_CACHE_HOOK, recurse=True) elif isinstance(self._cache_config, TaylorSeerCacheConfig): registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True) else: diff --git a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py index 9c2aa17f162a..342919a26148 100644 --- a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py +++ b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py @@ -268,6 +268,8 @@ def __call__( encoder_hidden_states: torch.FloatTensor = None, attention_mask: torch.FloatTensor | None = None, image_rotary_emb: torch.Tensor | None = None, + cached_txt_key: torch.FloatTensor | None = None, + cached_txt_value: torch.FloatTensor | None = None, ) -> torch.FloatTensor: head_dim = attn.inner_dim // attn.heads num_kv_heads = attn.inner_kv_dim // head_dim @@ -287,7 +289,11 @@ def __call__( img_query = _apply_rotary_emb_nucleus(img_query, img_freqs, use_real=False) img_key = _apply_rotary_emb_nucleus(img_key, img_freqs, use_real=False) - if encoder_hidden_states is not None: + if cached_txt_key is not None and cached_txt_value is not None: + txt_key, txt_value = cached_txt_key, cached_txt_value + joint_key = torch.cat([img_key, txt_key], dim=1) + joint_value = torch.cat([img_value, txt_value], dim=1) + elif encoder_hidden_states is not None: txt_key = attn.add_k_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1)) txt_value = attn.add_v_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1)) @@ -537,17 
+543,18 @@ def forward( gate1 = gate1.clamp(min=-2.0, max=2.0) gate2 = gate2.clamp(min=-2.0, max=2.0) - context = self.encoder_proj(encoder_hidden_states) + # Skip encoder_proj when text K/V are already cached — context won't be used by the processor + attn_kwargs = attention_kwargs or {} + context = None if attn_kwargs.get("cached_txt_key") is not None else self.encoder_proj(encoder_hidden_states) img_normed = self.pre_attn_norm(hidden_states) img_modulated = img_normed * scale1 - attention_kwargs = attention_kwargs or {} img_attn_output = self.attn( hidden_states=img_modulated, encoder_hidden_states=context, image_rotary_emb=image_rotary_emb, - **attention_kwargs, + **attn_kwargs, ) hidden_states = hidden_states + gate1.tanh() * img_attn_output diff --git a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py index 70d8ec8212ad..650ec145744e 100644 --- a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py +++ b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py @@ -574,6 +574,10 @@ def __call__( self._num_timesteps = len(timesteps) self.scheduler.set_begin_index(0) + + if self.transformer.is_cache_enabled: + self.transformer._reset_stateful_cache() + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: diff --git a/tests/models/transformers/test_models_transformer_nucleusmoe_image.py b/tests/models/transformers/test_models_transformer_nucleusmoe_image.py new file mode 100644 index 000000000000..edd6de53701a --- /dev/null +++ b/tests/models/transformers/test_models_transformer_nucleusmoe_image.py @@ -0,0 +1,242 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import torch + +from diffusers import NucleusMoEImageTransformer2DModel +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + BitsAndBytesTesterMixin, + LoraHotSwappingForModelTesterMixin, + LoraTesterMixin, + MemoryTesterMixin, + ModelTesterMixin, + TorchAoTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) + + +enable_full_determinism() + + +class NucleusMoEImageTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return NucleusMoEImageTransformer2DModel + + @property + def output_shape(self) -> tuple[int, int]: + return (16, 16) + + @property + def input_shape(self) -> tuple[int, int]: + return (16, 16) + + @property + def model_split_percents(self) -> list: + return [0.7, 0.6, 0.6] + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict: + return { + "patch_size": 2, + "in_channels": 16, + "out_channels": 4, + "num_layers": 2, + "attention_head_dim": 16, + "num_attention_heads": 4, + "joint_attention_dim": 16, + "axes_dims_rope": (8, 4, 4), + "moe_enabled": False, + "capacity_factors": [8.0, 8.0], + } + + def get_dummy_inputs(self) -> dict: + batch_size = 1 + in_channels = 16 + joint_attention_dim = 16 + height = width = 4 + sequence_length = 8 + + hidden_states = randn_tensor( + (batch_size, height 
* width, in_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + img_shapes = [(1, height, width)] * batch_size + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + "img_shapes": img_shapes, + } + + +class TestNucleusMoEImageTransformer(NucleusMoEImageTransformerTesterConfig, ModelTesterMixin): + def test_txt_seq_lens_deprecation(self): + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + inputs_with_deprecated = inputs.copy() + inputs_with_deprecated.pop("encoder_hidden_states_mask") + inputs_with_deprecated["txt_seq_lens"] = [inputs["encoder_hidden_states"].shape[1]] + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + with torch.no_grad(): + output = model(**inputs_with_deprecated) + + future_warnings = [x for x in w if issubclass(x.category, FutureWarning)] + assert len(future_warnings) > 0, "Expected FutureWarning to be raised" + warning_message = str(future_warnings[0].message) + assert "txt_seq_lens" in warning_message + assert "deprecated" in warning_message + + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] + + def test_with_attention_mask(self): + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + # Mask out some text tokens + mask = inputs["encoder_hidden_states_mask"].clone() + mask[:, 4:] = 0 + inputs["encoder_hidden_states_mask"] = mask + + with torch.no_grad(): + output = model(**inputs) + + 
assert output.sample.shape[1] == inputs["hidden_states"].shape[1] + + def test_without_attention_mask(self): + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + inputs["encoder_hidden_states_mask"] = None + + with torch.no_grad(): + output = model(**inputs) + + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] + +class TestNucleusMoEImageTransformerMemory(NucleusMoEImageTransformerTesterConfig, MemoryTesterMixin): + """Memory optimization tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerTraining(NucleusMoEImageTransformerTesterConfig, TrainingTesterMixin): + """Training tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerAttention(NucleusMoEImageTransformerTesterConfig, AttentionTesterMixin): + """Attention processor tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerLoRA(NucleusMoEImageTransformerTesterConfig, LoraTesterMixin): + """LoRA adapter tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerLoRAHotSwap( + NucleusMoEImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin +): + """LoRA hot-swapping tests for NucleusMoE Image Transformer.""" + + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict: + batch_size = 1 + in_channels = 16 + joint_attention_dim = 16 + sequence_length = 8 + + hidden_states = randn_tensor( + (batch_size, height * width, in_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) 
+ img_shapes = [(1, height, width)] * batch_size + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + "img_shapes": img_shapes, + } + + +class TestNucleusMoEImageTransformerCompile(NucleusMoEImageTransformerTesterConfig, TorchCompileTesterMixin): + """Torch compile tests for NucleusMoE Image Transformer.""" + + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict: + batch_size = 1 + in_channels = 16 + joint_attention_dim = 16 + sequence_length = 8 + + hidden_states = randn_tensor( + (batch_size, height * width, in_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + img_shapes = [(1, height, width)] * batch_size + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + "img_shapes": img_shapes, + } + + +class TestNucleusMoEImageTransformerBitsAndBytes(NucleusMoEImageTransformerTesterConfig, BitsAndBytesTesterMixin): + """BitsAndBytes quantization tests for NucleusMoE Image Transformer.""" + + +class TestNucleusMoEImageTransformerTorchAo(NucleusMoEImageTransformerTesterConfig, TorchAoTesterMixin): + """TorchAO quantization tests for NucleusMoE Image Transformer.""" diff --git a/tests/pipelines/nucleusmoe_image/__init__.py b/tests/pipelines/nucleusmoe_image/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git 
a/tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py b/tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py new file mode 100644 index 000000000000..1327bfdcb88b --- /dev/null +++ b/tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py @@ -0,0 +1,197 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from diffusers import ( + AutoencoderKLQwenImage, + FlowMatchEulerDiscreteScheduler, + NucleusMoEImagePipeline, + NucleusMoEImageTransformer2DModel, +) + +from ...testing_utils import enable_full_determinism, torch_device +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class NucleusMoEImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = NucleusMoEImagePipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + 
test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = NucleusMoEImageTransformer2DModel( + patch_size=2, + in_channels=16, + out_channels=4, + num_layers=2, + attention_head_dim=16, + num_attention_heads=4, + joint_attention_dim=16, + axes_dims_rope=(8, 4, 4), + moe_enabled=False, + capacity_factors=[8.0, 8.0], + ) + + torch.manual_seed(0) + z_dim = 4 + vae = AutoencoderKLQwenImage( + base_dim=z_dim * 6, + z_dim=z_dim, + dim_mult=[1, 2, 4], + num_res_blocks=1, + temperal_downsample=[False, True], + # fmt: off + latents_mean=[0.0] * z_dim, + latents_std=[1.0] * z_dim, + # fmt: on + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + config = Qwen3VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 8, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + "vocab_size": 151936, + "head_dim": 8, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_channels": 16, + }, + ) + text_encoder = Qwen3VLForConditionalGeneration(config).eval() + processor = Qwen3VLProcessor.from_pretrained( + "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "processor": processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "A cat sitting on a mat", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "true_cfg_scale": 1.0, + "height": 32, + "width": 32, + 
"max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + generated_image = image[0] + self.assertEqual(generated_image.shape, (3, 32, 32)) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) + + def test_true_cfg(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["true_cfg_scale"] = 4.0 + inputs["negative_prompt"] = "low quality" + image = pipe(**inputs).images + self.assertEqual(image[0].shape, (3, 32, 32)) + + def test_prompt_embeds(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + prompt_embeds, prompt_embeds_mask = pipe.encode_prompt( + prompt=inputs["prompt"], + device=device, + max_sequence_length=inputs["max_sequence_length"], + ) + + inputs_with_embeds = self.get_dummy_inputs(device) + inputs_with_embeds.pop("prompt") + inputs_with_embeds["prompt_embeds"] = prompt_embeds + inputs_with_embeds["prompt_embeds_mask"] = prompt_embeds_mask + + image = pipe(**inputs_with_embeds).images + self.assertEqual(image[0].shape, (3, 32, 32)) From cb63a958b5ccd1e8560707c6ca84a94321dda3a7 Mon Sep 17 00:00:00 2001 From: sippycoder Date: Fri, 20 Mar 2026 23:10:00 +0000 Subject: [PATCH 04/14] Class/function name changes --- src/diffusers/hooks/__init__.py | 2 +- src/diffusers/hooks/text_kv_cache.py | 16 ++++++++-------- src/diffusers/models/cache_utils.py | 16 
++++++++-------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 23466afe5bc5..2a9aa81608e7 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -27,4 +27,4 @@ from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache - from .text_kv_cache import NucleusMoETextKVCacheConfig, apply_nucleus_moe_text_kv_cache + from .text_kv_cache import TextKVCacheConfig, apply_text_kv_cache diff --git a/src/diffusers/hooks/text_kv_cache.py b/src/diffusers/hooks/text_kv_cache.py index e718943444ca..53777ae185a0 100644 --- a/src/diffusers/hooks/text_kv_cache.py +++ b/src/diffusers/hooks/text_kv_cache.py @@ -19,15 +19,15 @@ from .hooks import HookRegistry, ModelHook -_TEXT_KV_CACHE_HOOK = "nucleus_moe_text_kv_cache" +_TEXT_KV_CACHE_HOOK = "text_kv_cache" @dataclass -class NucleusMoETextKVCacheConfig: - """Enable exact (lossless) text K/V caching for NucleusMoEImageTransformer2DModel. +class TextKVCacheConfig: + """Enable exact (lossless) text K/V caching for transformer models. Pre-computes per-block text key and value projections once before the - denoising loop and reuses them across all steps. The cached values are keyed by + denoising loop and reuses them across all steps. The cached values are keyed by the ``data_ptr()`` of the ``encoder_hidden_states`` tensor so that both the positive and negative prompts (when ``true_cfg_scale > 1``) are handled correctly. 
""" @@ -35,14 +35,14 @@ class NucleusMoETextKVCacheConfig: pass # no hyperparameters needed — cache is always exact -class NucleusMoETextKVCacheHook(ModelHook): +class TextKVCacheHook(ModelHook): """Block-level hook that caches (txt_key, txt_value) per unique prompt.""" _is_stateful = True def __init__(self): super().__init__() - # Maps encoder_hidden_states.data_ptr() → (context, txt_key, txt_value) + # Maps encoder_hidden_states.data_ptr() → (txt_key, txt_value) self.kv_cache: dict[int, tuple] = {} def new_forward(self, module: torch.nn.Module, *args, **kwargs): @@ -99,11 +99,11 @@ def reset_state(self, module: torch.nn.Module): return module -def apply_nucleus_moe_text_kv_cache(module: torch.nn.Module, config: NucleusMoETextKVCacheConfig) -> None: +def apply_text_kv_cache(module: torch.nn.Module, config: TextKVCacheConfig) -> None: from ..models.transformers.transformer_nucleusmoe_image import NucleusMoEImageTransformerBlock for _, submodule in module.named_modules(): if isinstance(submodule, NucleusMoEImageTransformerBlock): - hook = NucleusMoETextKVCacheHook() + hook = TextKVCacheHook() registry = HookRegistry.check_if_exists_or_initialize(submodule) registry.register_hook(hook, _TEXT_KV_CACHE_HOOK) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 365562465930..3bca773d8344 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -41,12 +41,12 @@ def enable_cache(self, config) -> None: Enable caching techniques on the model. Args: - config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig | NucleusMoETextKVCacheConfig`): + config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig | TextKVCacheConfig`): The configuration for applying the caching technique. 
Currently supported caching techniques are: - [`~hooks.PyramidAttentionBroadcastConfig`] - [`~hooks.FasterCacheConfig`] - [`~hooks.FirstBlockCacheConfig`] - - [`~hooks.NucleusMoETextKVCacheConfig`] + - [`~hooks.TextKVCacheConfig`] Example: @@ -70,15 +70,15 @@ def enable_cache(self, config) -> None: FasterCacheConfig, FirstBlockCacheConfig, MagCacheConfig, - NucleusMoETextKVCacheConfig, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig, + TextKVCacheConfig, apply_faster_cache, apply_first_block_cache, apply_mag_cache, - apply_nucleus_moe_text_kv_cache, apply_pyramid_attention_broadcast, apply_taylorseer_cache, + apply_text_kv_cache, ) if self.is_cache_enabled: @@ -92,8 +92,8 @@ def enable_cache(self, config) -> None: apply_first_block_cache(self, config) elif isinstance(config, MagCacheConfig): apply_mag_cache(self, config) - elif isinstance(config, NucleusMoETextKVCacheConfig): - apply_nucleus_moe_text_kv_cache(self, config) + elif isinstance(config, TextKVCacheConfig): + apply_text_kv_cache(self, config) elif isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) elif isinstance(config, TaylorSeerCacheConfig): @@ -109,9 +109,9 @@ def disable_cache(self) -> None: FirstBlockCacheConfig, HookRegistry, MagCacheConfig, - NucleusMoETextKVCacheConfig, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig, + TextKVCacheConfig, ) from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK @@ -136,7 +136,7 @@ def disable_cache(self) -> None: registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True) elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) - elif isinstance(self._cache_config, NucleusMoETextKVCacheConfig): + elif isinstance(self._cache_config, TextKVCacheConfig): registry.remove_hook(_TEXT_KV_CACHE_HOOK, recurse=True) elif 
isinstance(self._cache_config, TaylorSeerCacheConfig): registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True) From f2eec8272ac4439fd3e233ffeb4f6e045228449b Mon Sep 17 00:00:00 2001 From: sippycoder Date: Sat, 21 Mar 2026 05:01:17 +0000 Subject: [PATCH 05/14] add missing imports --- src/diffusers/__init__.py | 4 ++++ src/diffusers/utils/dummy_pt_objects.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0fdcfa97b065..1fa1bd176782 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -169,12 +169,14 @@ "PyramidAttentionBroadcastConfig", "SmoothedEnergyGuidanceConfig", "TaylorSeerCacheConfig", + "TextKVCacheConfig", "apply_faster_cache", "apply_first_block_cache", "apply_layer_skip", "apply_mag_cache", "apply_pyramid_attention_broadcast", "apply_taylorseer_cache", + "apply_text_kv_cache", ] ) _import_structure["models"].extend( @@ -955,12 +957,14 @@ PyramidAttentionBroadcastConfig, SmoothedEnergyGuidanceConfig, TaylorSeerCacheConfig, + TextKVCacheConfig, apply_faster_cache, apply_first_block_cache, apply_layer_skip, apply_mag_cache, apply_pyramid_attention_broadcast, apply_taylorseer_cache, + apply_text_kv_cache, ) from .models import ( AllegroTransformer3DModel, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 9fba2f661e3b..222d452d652c 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -311,6 +311,25 @@ def apply_taylorseer_cache(*args, **kwargs): requires_backends(apply_taylorseer_cache, ["torch"]) +def apply_text_kv_cache(*args, **kwargs): + requires_backends(apply_text_kv_cache, ["torch"]) + + +class TextKVCacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def 
from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AllegroTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] From d8b50e5dee9fffcf6c46d5141c15c14fd05719ca Mon Sep 17 00:00:00 2001 From: sippycoder Date: Sun, 22 Mar 2026 20:26:07 +0000 Subject: [PATCH 06/14] add RoPE credits --- .../models/transformers/transformer_nucleusmoe_image.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py index 342919a26148..6bf48b70d629 100644 --- a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py +++ b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py @@ -37,6 +37,7 @@ logger = logging.get_logger(__name__) +# copied from diffusers.models.transformers.transformer_qwenimage.apply_rotary_emb_qwen def _apply_rotary_emb_nucleus( x: torch.Tensor, freqs_cis: torch.Tensor | tuple[torch.Tensor], @@ -122,6 +123,7 @@ def forward(self, timestep, hidden_states, addition_t_cond=None): return self.norm(conditioning) +# copied from diffusers.models.transformers.transformer_qwenimage.QwenEmbedRope class NucleusMoEEmbedRope(nn.Module): def __init__(self, theta: int, axes_dim: list[int], scale_rope=False): super().__init__() From 15667849c73d0f5dfc642a0c9b326d67c9469e45 Mon Sep 17 00:00:00 2001 From: sippycoder <134823555+sippycoder@users.noreply.github.com> Date: Wed, 25 Mar 2026 09:34:20 -0700 Subject: [PATCH 07/14] Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- .../models/transformers/transformer_nucleusmoe_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py index 6bf48b70d629..bb2149f5dcd4 100644 --- 
a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py +++ b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py @@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) -# copied from diffusers.models.transformers.transformer_qwenimage.apply_rotary_emb_qwen +# Copied from diffusers.models.transformers.transformer_qwenimage.apply_rotary_emb_qwen with qwen->nucleus def _apply_rotary_emb_nucleus( x: torch.Tensor, freqs_cis: torch.Tensor | tuple[torch.Tensor], From a5314002e6f05ebcf3a1d001d5892ad6cd56ab75 Mon Sep 17 00:00:00 2001 From: sippycoder <134823555+sippycoder@users.noreply.github.com> Date: Wed, 25 Mar 2026 09:34:30 -0700 Subject: [PATCH 08/14] Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- .../models/transformers/transformer_nucleusmoe_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py index bb2149f5dcd4..050712833937 100644 --- a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py +++ b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py @@ -123,7 +123,7 @@ def forward(self, timestep, hidden_states, addition_t_cond=None): return self.norm(conditioning) -# copied from diffusers.models.transformers.transformer_qwenimage.QwenEmbedRope +# Copied from diffusers.models.transformers.transformer_qwenimage.QwenEmbedRope with Qwen->NucleusMoE class NucleusMoEEmbedRope(nn.Module): def __init__(self, theta: int, axes_dim: list[int], scale_rope=False): super().__init__() From 5c7ec5417081a8e142e72fc17523eb82b7e8f17f Mon Sep 17 00:00:00 2001 From: sippycoder <134823555+sippycoder@users.noreply.github.com> Date: Wed, 25 Mar 2026 09:35:14 -0700 Subject: [PATCH 09/14] Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py Co-authored-by: dg845 
<58458699+dg845@users.noreply.github.com> --- .../models/transformers/transformer_nucleusmoe_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py index 050712833937..ad66e78a3a64 100644 --- a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py +++ b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py @@ -159,7 +159,7 @@ def _rope_params(index, dim, theta=10000): def forward( self, - video_fhw: tuple[int, int, int, list[tuple[int, int, int]]], + video_fhw: tuple[int, int, int] | list[tuple[int, int, int]], txt_seq_lens: list[int] | None = None, device: torch.device = None, max_txt_seq_len: int | torch.Tensor | None = None, From cffe758fa5dac0dd5a279ec18a3df45c7125a3f6 Mon Sep 17 00:00:00 2001 From: sippycoder <134823555+sippycoder@users.noreply.github.com> Date: Wed, 25 Mar 2026 09:40:09 -0700 Subject: [PATCH 10/14] Update src/diffusers/models/transformers/transformer_nucleusmoe_image.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- .../models/transformers/transformer_nucleusmoe_image.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py index ad66e78a3a64..ce2f3f762036 100644 --- a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py +++ b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py @@ -572,7 +572,8 @@ def forward( hidden_states = hidden_states + gate2.tanh() * img_mlp_output if hidden_states.dtype == torch.float16: - hidden_states = hidden_states.clip(-65504, 65504) + fp16_finfo = torch.finfo(torch.float16) + hidden_states = hidden_states.clip(fp16_finfo.min, fp16_finfo.max) return hidden_states From 6a3b48b58ab2b62161ca7c821b669a9210814f4b Mon Sep 17 00:00:00 2001 From: 
sippycoder Date: Wed, 25 Mar 2026 17:02:24 +0000 Subject: [PATCH 11/14] update defaults --- .../pipeline_nucleusmoe_image.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py index 650ec145744e..635b5a8055a3 100644 --- a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py +++ b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py @@ -174,8 +174,10 @@ def __init__( 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) - self.default_sample_size = 128 - self.return_index = -8 + + self.default_sample_size = 64 + self.default_max_sequence_length = 1024 + self.default_return_index = -8 def _format_prompt(self, prompt: str, system_prompt: str | None = None) -> str: if system_prompt is None: @@ -195,7 +197,8 @@ def encode_prompt( num_images_per_prompt: int = 1, prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, - max_sequence_length: int = 1024, + max_sequence_length: int | None = None, + return_index: int | None = None, ): r""" Encode text prompt(s) into embeddings using the Qwen3-VL text encoder. 
@@ -235,7 +238,7 @@ def encode_prompt( outputs = self.text_encoder( **inputs, use_cache=False, return_dict=True, output_hidden_states=True ) - prompt_embeds = outputs.hidden_states[self.return_index] + prompt_embeds = outputs.hidden_states[return_index] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) else: prompt_embeds = prompt_embeds.to(device=device) @@ -266,6 +269,7 @@ def check_inputs( negative_prompt_embeds_mask=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, + return_index=None, ): if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: logger.warning( @@ -300,10 +304,17 @@ def check_inputs( "Please make sure to only forward one of the two." ) - if max_sequence_length is not None and max_sequence_length > 1024: + if max_sequence_length is not None and max_sequence_length > self.default_max_sequence_length: raise ValueError( f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}" ) + + if return_index is not None and abs(return_index) >= self.text_encoder.config.text_config.num_hidden_layers: + raise ValueError( + f"absolute value of `return_index` cannot be >= {self.text_encoder.config.text_config.num_hidden_layers} " + f"but is {abs(return_index)}" + ) + @staticmethod def _pack_latents(latents, batch_size, num_channels_latents, height, width): @@ -414,6 +425,8 @@ def __call__( num_inference_steps: int = 50, sigmas: list[float] | None = None, num_images_per_prompt: int = 1, + max_sequence_length: int | None = None, + return_index: int | None = None, generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, @@ -425,7 +438,6 @@ def __call__( attention_kwargs: dict[str, Any] | None = None, callback_on_step_end: Callable[[int, int, dict], None] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], - max_sequence_length: int = 512, ): r""" 
Function invoked when calling the pipeline for generation. @@ -484,6 +496,9 @@ def __call__( height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + max_sequence_length = max_sequence_length or self.default_max_sequence_length + return_index = return_index or self.default_return_index + self.check_inputs( prompt, height, @@ -495,6 +510,7 @@ def __call__( negative_prompt_embeds_mask=negative_prompt_embeds_mask, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, + return_index=return_index, ) self._attention_kwargs = attention_kwargs or {} @@ -525,6 +541,7 @@ def __call__( device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + return_index=return_index, ) if do_true_cfg: negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( @@ -534,6 +551,7 @@ def __call__( device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, + return_index=return_index, ) num_channels_latents = self.transformer.config.in_channels // 4 From 96f6a05bd69037244419d356a407320756150a0c Mon Sep 17 00:00:00 2001 From: sippycoder <134823555+sippycoder@users.noreply.github.com> Date: Wed, 25 Mar 2026 10:06:29 -0700 Subject: [PATCH 12/14] Update src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --- .../pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py index 635b5a8055a3..52eabdf64f99 100644 --- a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py +++ b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py @@ -265,8 +265,6 @@ def check_inputs( 
negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, - prompt_embeds_mask=None, - negative_prompt_embeds_mask=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, return_index=None, From 151bec4e395e3d5f0882ea1027d4ce48eeb5e85a Mon Sep 17 00:00:00 2001 From: sippycoder Date: Thu, 26 Mar 2026 06:24:39 +0000 Subject: [PATCH 13/14] review updates --- src/diffusers/hooks/text_kv_cache.py | 2 +- .../transformer_nucleusmoe_image.py | 44 ++------- .../pipeline_nucleusmoe_image.py | 90 +++++++------------ 3 files changed, 40 insertions(+), 96 deletions(-) diff --git a/src/diffusers/hooks/text_kv_cache.py b/src/diffusers/hooks/text_kv_cache.py index 53777ae185a0..fb1a4875b366 100644 --- a/src/diffusers/hooks/text_kv_cache.py +++ b/src/diffusers/hooks/text_kv_cache.py @@ -43,7 +43,7 @@ class TextKVCacheHook(ModelHook): def __init__(self): super().__init__() # Maps encoder_hidden_states.data_ptr() → (txt_key, txt_value) - self.kv_cache: dict[int, tuple] = {} + self.kv_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {} def new_forward(self, module: torch.nn.Module, *args, **kwargs): from ..models.transformers.transformer_nucleusmoe_image import _apply_rotary_emb_nucleus diff --git a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py index ce2f3f762036..88df164f7d03 100644 --- a/src/diffusers/models/transformers/transformer_nucleusmoe_image.py +++ b/src/diffusers/models/transformers/transformer_nucleusmoe_image.py @@ -22,8 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers -from ...utils.torch_utils import maybe_allow_in_graph +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import AttentionMixin, FeedForward from 
..attention_dispatch import dispatch_attention_fn from ..attention_processor import Attention @@ -160,7 +159,6 @@ def forward( self, video_fhw: tuple[int, int, int] | list[tuple[int, int, int]], - txt_seq_lens: list[int] | None = None, device: torch.device = None, max_txt_seq_len: int | torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -168,26 +166,14 @@ Args: video_fhw (`tuple[int, int, int]` or `list[tuple[int, int, int]]`): A list of 3 integers [frame, height, width] representing the shape of the video. - txt_seq_lens (`list[int]`, *optional*, **Deprecated**): - Deprecated parameter. Use `max_txt_seq_len` instead. device: (`torch.device`, *optional*): The device on which to perform the RoPE computation. max_txt_seq_len (`int` or `torch.Tensor`, *optional*): The maximum text sequence length for RoPE computation. """ - if txt_seq_lens is not None: - deprecate( - "txt_seq_lens", - "0.39.0", - "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " - "Please use `max_txt_seq_len` instead.", - standard_warn=False, - ) - if max_txt_seq_len is None: - max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens if max_txt_seq_len is None: - raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.") + raise ValueError("`max_txt_seq_len` must be provided.") if isinstance(video_fhw, list) and len(video_fhw) > 1: first_fhw = video_fhw[0] @@ -457,7 +442,6 @@ forward( return out -@maybe_allow_in_graph class NucleusMoEImageTransformerBlock(nn.Module): """ Single-stream DiT block with optional Mixture-of-Experts MLP.
Only the image @@ -540,7 +524,6 @@ def forward( attention_kwargs: dict[str, Any] | None = None, ) -> torch.Tensor: scale1, gate1, scale2, gate2 = self.img_mod(temb).unsqueeze(1).chunk(4, dim=-1) - scale1, scale2 = 1 + scale1, 1 + scale2 gate1 = gate1.clamp(min=-2.0, max=2.0) gate2 = gate2.clamp(min=-2.0, max=2.0) @@ -550,7 +533,7 @@ def forward( context = None if attn_kwargs.get("cached_txt_key") is not None else self.encoder_proj(encoder_hidden_states) img_normed = self.pre_attn_norm(hidden_states) - img_modulated = img_normed * scale1 + img_modulated = img_normed * (1 + scale1) img_attn_output = self.attn( hidden_states=img_modulated, @@ -562,7 +545,7 @@ def forward( hidden_states = hidden_states + gate1.tanh() * img_attn_output img_normed2 = self.pre_mlp_norm(hidden_states) - img_modulated2 = img_normed2 * scale2 + img_modulated2 = img_normed2 * (1 + scale2) if self.moe_enabled: img_mlp_output = self.img_mlp(img_modulated2, img_normed2, timestep=temb) @@ -614,7 +597,7 @@ class NucleusMoEImageTransformer2DModel( Number of experts per MoE layer. moe_intermediate_dim (`int`, defaults to `1344`): Hidden dimension inside each expert. - capacity_factors (`list[float]`, defaults to `[8.0] * 24`): + capacity_factors (`float | list[float]`, defaults to `8.0`): Expert-choice capacity factor per layer. use_sigmoid (`bool`, defaults to `False`): Use sigmoid instead of softmax for routing scores. 
@@ -644,13 +627,14 @@ def __init__( dense_moe_strategy: str = "leave_first_three_and_last_block_dense", num_experts: int = 128, moe_intermediate_dim: int = 1344, - capacity_factors: List[float] = [8.0] * 24, + capacity_factors: float | list[float] = 8.0, use_sigmoid: bool = False, route_scale: float = 2.5, ): super().__init__() self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads * attention_head_dim + capacity_factors = capacity_factors if isinstance(capacity_factors, list) else [capacity_factors] * num_layers self.pos_embed = NucleusMoEEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) @@ -687,11 +671,10 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - img_shapes: list[tuple[int, int, int]] | None = None, + img_shapes: tuple[int, int, int] | list[tuple[int, int, int]], encoder_hidden_states: torch.Tensor = None, encoder_hidden_states_mask: torch.Tensor = None, timestep: torch.LongTensor = None, - txt_seq_lens: list[int] | None = None, attention_kwargs: dict[str, Any] | None = None, return_dict: bool = True, ) -> torch.Tensor | Transformer2DModelOutput: @@ -709,8 +692,6 @@ def forward( Boolean mask for the encoder hidden states. timestep (`torch.LongTensor`): Used to indicate denoising step. - txt_seq_lens (`list[int]`, *optional*, **Deprecated**): - Deprecated. Use ``encoder_hidden_states_mask`` instead. attention_kwargs (`dict`, *optional*): Extra kwargs forwarded to the attention processor. return_dict (`bool`, *optional*, defaults to `True`): @@ -720,15 +701,6 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - if txt_seq_lens is not None: - deprecate( - "txt_seq_lens", - "0.39.0", - "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. 
" - "Please use `encoder_hidden_states_mask` instead.", - standard_warn=False, - ) - if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) diff --git a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py index 52eabdf64f99..7eaaa6eaf5fe 100644 --- a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py +++ b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py @@ -59,6 +59,7 @@ """ +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, base_seq_len: int = 256, @@ -72,6 +73,7 @@ def calculate_shift( return mu +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: int | None = None, @@ -315,26 +317,27 @@ def check_inputs( @staticmethod - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + def _pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size): + latents = latents.view(batch_size, num_channels_latents, height // patch_size, patch_size, width // patch_size, patch_size) latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + latents = latents.reshape(batch_size, (height // patch_size) * (width // patch_size), num_channels_latents * patch_size * patch_size) return latents @staticmethod - def _unpack_latents(latents, height, width, vae_scale_factor): + def _unpack_latents(latents, height, width, patch_size, vae_scale_factor): batch_size, num_patches, channels = latents.shape - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - latents = 
latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + height = patch_size * (int(height) // (vae_scale_factor * patch_size)) + width = patch_size * (int(width) // (vae_scale_factor * patch_size)) + latents = latents.view(batch_size, height // patch_size, width // patch_size, channels // (patch_size * patch_size), patch_size, patch_size) latents = latents.permute(0, 3, 1, 4, 2, 5) - latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + latents = latents.reshape(batch_size, channels // (patch_size * patch_size), 1, height, width) return latents def prepare_latents( self, batch_size, num_channels_latents, + patch_size, height, width, dtype, @@ -342,8 +345,8 @@ def prepare_latents( generator, latents=None, ): - height = 2 * (int(height) // (self.vae_scale_factor * 2)) - width = 2 * (int(width) // (self.vae_scale_factor * 2)) + height = patch_size * (int(height) // (self.vae_scale_factor * patch_size)) + width = patch_size * (int(width) // (self.vae_scale_factor * patch_size)) shape = (batch_size, 1, num_channels_latents, height, width) if latents is not None: @@ -356,45 +359,9 @@ def prepare_latents( ) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size) return latents - def enable_vae_slicing(self): - r"""Enable sliced VAE decoding for memory efficiency.""" - depr_message = ( - f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and will be " - "removed in a future version. Please use `pipe.vae.enable_slicing()`." 
- ) - deprecate("enable_vae_slicing", "0.40.0", depr_message) - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r"""Disable sliced VAE decoding.""" - depr_message = ( - f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and will be " - "removed in a future version. Please use `pipe.vae.disable_slicing()`." - ) - deprecate("disable_vae_slicing", "0.40.0", depr_message) - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r"""Enable tiled VAE decoding for memory efficiency.""" - depr_message = ( - f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and will be " - "removed in a future version. Please use `pipe.vae.enable_tiling()`." - ) - deprecate("enable_vae_tiling", "0.40.0", depr_message) - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r"""Disable tiled VAE decoding.""" - depr_message = ( - f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and will be " - "removed in a future version. Please use `pipe.vae.disable_tiling()`." 
- ) - deprecate("disable_vae_tiling", "0.40.0", depr_message) - self.vae.disable_tiling() - @property def attention_kwargs(self): return self._attention_kwargs @@ -417,7 +384,7 @@ def __call__( self, prompt: str | list[str] = None, negative_prompt: str | list[str] = None, - true_cfg_scale: float = 4.0, + guidance_scale: float = 4.0, height: int | None = None, width: int | None = None, num_inference_steps: int = 50, @@ -527,9 +494,9 @@ def __call__( has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None ) - do_true_cfg = true_cfg_scale > 1 + do_cfg = guidance_scale > 1 - if do_true_cfg and not has_neg_prompt: + if do_cfg and not has_neg_prompt: negative_prompt = [""] * batch_size prompt_embeds, prompt_embeds_mask = self.encode_prompt( @@ -541,7 +508,7 @@ def __call__( max_sequence_length=max_sequence_length, return_index=return_index, ) - if do_true_cfg: + if do_cfg: negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( prompt=negative_prompt, prompt_embeds=negative_prompt_embeds, @@ -553,9 +520,12 @@ def __call__( ) num_channels_latents = self.transformer.config.in_channels // 4 + patch_size = self.transformer.config.patch_size + latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, + patch_size, height, width, prompt_embeds.dtype, @@ -564,9 +534,11 @@ def __call__( latents, ) - latent_h = 2 * (int(height) // (self.vae_scale_factor * 2)) - latent_w = 2 * (int(width) // (self.vae_scale_factor * 2)) - img_shapes = [(1, latent_h // 2, latent_w // 2)] * (batch_size * num_images_per_prompt) + img_shapes = [( + 1, + height // self.vae_scale_factor // patch_size, + width // self.vae_scale_factor // patch_size + )] * (batch_size * num_images_per_prompt) sigmas = ( np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas @@ -604,7 +576,7 @@ def __call__( noise_pred = self.transformer( hidden_states=latents, - 
timestep=timestep / 1000, + timestep=timestep / self.scheduler.config.num_train_timesteps, encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, @@ -612,10 +584,10 @@ def __call__( return_dict=False, )[0] - if do_true_cfg: + if do_cfg: neg_noise_pred = self.transformer( hidden_states=latents, - timestep=timestep / 1000, + timestep=timestep / self.scheduler.config.num_train_timesteps, encoder_hidden_states=negative_prompt_embeds, encoder_hidden_states_mask=negative_prompt_embeds_mask, img_shapes=img_shapes, @@ -623,7 +595,7 @@ def __call__( return_dict=False, )[0] - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) noise_pred = comb_pred * (cond_norm / noise_norm) @@ -658,7 +630,7 @@ def __call__( if output_type == "latent": image = latents else: - latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = self._unpack_latents(latents, height, width, patch_size, self.vae_scale_factor) latents = latents.to(self.vae.dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) From e9624106393b97ac4f0ef718da3a56216e61ae6f Mon Sep 17 00:00:00 2001 From: sippycoder Date: Thu, 26 Mar 2026 08:08:09 +0000 Subject: [PATCH 14/14] fix the tests --- .../pipeline_nucleusmoe_image.py | 11 +- .../nucleusmoe_image/test_nucleusmoe_image.py | 147 +++++++++++++++++- 2 files changed, 154 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py index 7eaaa6eaf5fe..cab4b7ad02b9 100644 --- a/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py +++ b/src/diffusers/pipelines/nucleusmoe_image/pipeline_nucleusmoe_image.py @@ -220,6 +220,7 @@ def 
encode_prompt( Maximum token length for the encoded prompt. """ device = device or self._execution_device + return_index = self.default_return_index if return_index is None else return_index if prompt_embeds is None: prompt = [prompt] if isinstance(prompt, str) else prompt @@ -266,7 +267,9 @@ def check_inputs( width, negative_prompt=None, prompt_embeds=None, + prompt_embeds_mask=None, negative_prompt_embeds=None, + negative_prompt_embeds_mask=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, return_index=None, @@ -362,6 +365,10 @@ def prepare_latents( latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size) return latents + @property + def guidance_scale(self): + return self._guidance_scale + @property def attention_kwargs(self): return self._attention_kwargs @@ -462,7 +469,6 @@ def __call__( width = width or self.default_sample_size * self.vae_scale_factor max_sequence_length = max_sequence_length or self.default_max_sequence_length - return_index = return_index or self.default_return_index self.check_inputs( @@ -470,14 +476,15 @@ def __call__( width, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds_mask=negative_prompt_embeds_mask, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, return_index=return_index, ) + self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs or {} self._current_timestep = None self._interrupt = False diff --git a/tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py b/tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py index 1327bfdcb88b..6b5c4b9a4baf 100644 --- a/tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py +++ b/tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py @@ -12,11 +12,15 @@ # See the License for the 
specific language governing permissions and # limitations under the License. +import inspect import unittest +import numpy as np import torch from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration, Qwen3VLProcessor +from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor + from diffusers import ( AutoencoderKLQwenImage, FlowMatchEulerDiscreteScheduler, @@ -135,7 +139,8 @@ def get_dummy_inputs(self, device, seed=0): "negative_prompt": "bad quality", "generator": generator, "num_inference_steps": 2, - "true_cfg_scale": 1.0, + "return_index": -1, + "guidance_scale": 1.0, "height": 32, "width": 32, "max_sequence_length": 16, @@ -168,7 +173,7 @@ def test_true_cfg(self): pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) - inputs["true_cfg_scale"] = 4.0 + inputs["guidance_scale"] = 4.0 inputs["negative_prompt"] = "low quality" image = pipe(**inputs).images self.assertEqual(image[0].shape, (3, 32, 32)) @@ -195,3 +200,141 @@ def test_prompt_embeds(self): image = pipe(**inputs_with_embeds).images self.assertEqual(image[0].shape, (3, 32, 32)) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + # PipelineTesterMixin compares outputs with assert_mean_pixel_difference, which assumes HWC numpy/PIL layout. + # With output_type="pt", tensors are CHW; numpy_to_pil then fails. Match QwenImage: only assert max diff. 
+ if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4): + # PipelineTesterMixin only keeps components whose keys contain "text" or "tokenizer"; this pipeline also + # needs `processor` for encode_prompt (apply_chat_template). Mirror the mixin with that key included. 
+ if not hasattr(self.pipeline_class, "encode_prompt"): + return + + components = self.get_dummy_components() + for key in components: + if "text_encoder" in key and hasattr(components[key], "eval"): + components[key].eval() + + def _is_text_stack_component(k): + return "text" in k or "tokenizer" in k or k == "processor" + + components_with_text_encoders = {} + for k in components: + if _is_text_stack_component(k): + components_with_text_encoders[k] = components[k] + else: + components_with_text_encoders[k] = None + pipe_with_just_text_encoder = self.pipeline_class(**components_with_text_encoders) + pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + encode_prompt_signature = inspect.signature(pipe_with_just_text_encoder.encode_prompt) + encode_prompt_parameters = list(encode_prompt_signature.parameters.values()) + + required_params = [] + for param in encode_prompt_parameters: + if param.name == "self" or param.name == "kwargs": + continue + if param.default is inspect.Parameter.empty: + required_params.append(param.name) + + encode_prompt_param_names = [p.name for p in encode_prompt_parameters if p.name != "self"] + input_keys = list(inputs.keys()) + encode_prompt_inputs = {k: inputs.pop(k) for k in input_keys if k in encode_prompt_param_names} + + pipe_call_signature = inspect.signature(pipe_with_just_text_encoder.__call__) + pipe_call_parameters = pipe_call_signature.parameters + + for required_param_name in required_params: + if required_param_name not in encode_prompt_inputs: + pipe_call_param = pipe_call_parameters.get(required_param_name, None) + if pipe_call_param is not None and pipe_call_param.default is not inspect.Parameter.empty: + encode_prompt_inputs[required_param_name] = pipe_call_param.default + elif extra_required_param_value_dict is not None and isinstance(extra_required_param_value_dict, dict): + encode_prompt_inputs[required_param_name] = 
extra_required_param_value_dict[required_param_name] + else: + raise ValueError( + f"Required parameter '{required_param_name}' in " + f"encode_prompt has no default in either encode_prompt or __call__." + ) + + with torch.no_grad(): + encoded_prompt_outputs = pipe_with_just_text_encoder.encode_prompt(**encode_prompt_inputs) + + ast_visitor = ReturnNameVisitor() + encode_prompt_tree = ast_visitor.get_ast_tree(cls=self.pipeline_class) + ast_visitor.visit(encode_prompt_tree) + prompt_embed_kwargs = ast_visitor.return_names + prompt_embeds_kwargs = dict(zip(prompt_embed_kwargs, encoded_prompt_outputs)) + + adapted_prompt_embeds_kwargs = { + k: prompt_embeds_kwargs.pop(k) for k in list(prompt_embeds_kwargs.keys()) if k in pipe_call_parameters + } + + components_with_text_encoders = {} + for k in components: + if _is_text_stack_component(k): + components_with_text_encoders[k] = None + else: + components_with_text_encoders[k] = components[k] + pipe_without_text_encoders = self.pipeline_class(**components_with_text_encoders).to(torch_device) + + pipe_without_tes_inputs = {**inputs, **adapted_prompt_embeds_kwargs} + if ( + pipe_call_parameters.get("negative_prompt", None) is not None + and pipe_call_parameters.get("negative_prompt").default is not None + ): + pipe_without_tes_inputs.update({"negative_prompt": None}) + + if ( + pipe_call_parameters.get("prompt", None) is not None + and pipe_call_parameters.get("prompt").default is inspect.Parameter.empty + and pipe_call_parameters.get("prompt_embeds", None) is not None + and pipe_call_parameters.get("prompt_embeds").default is None + ): + pipe_without_tes_inputs.update({"prompt": None}) + + pipe_out = pipe_without_text_encoders(**pipe_without_tes_inputs)[0] + + full_pipe = self.pipeline_class(**components).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + pipe_out_2 = full_pipe(**inputs)[0] + + if isinstance(pipe_out, np.ndarray) and isinstance(pipe_out_2, np.ndarray): + self.assertTrue(np.allclose(pipe_out, 
pipe_out_2, atol=atol, rtol=rtol)) + elif isinstance(pipe_out, torch.Tensor) and isinstance(pipe_out_2, torch.Tensor): + self.assertTrue(torch.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol))