From f944a6458e1f55183254a123478967c9b8e9e194 Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Mon, 23 Mar 2026 15:53:04 +0000 Subject: [PATCH 1/3] initial architecture --- .../stable_diffusion_3/__init__.py | 46 +++ .../stable_diffusion_3/before_denoise.py | 280 ++++++++++++++++ .../stable_diffusion_3/decoders.py | 51 +++ .../stable_diffusion_3/denoise.py | 153 +++++++++ .../stable_diffusion_3/encoders.py | 304 ++++++++++++++++++ .../stable_diffusion_3/inputs.py | 141 ++++++++ .../modular_blocks_stable_diffusion_3.py | 105 ++++++ .../stable_diffusion_3/modular_pipeline.py | 48 +++ .../stable_diffusion_3/__init__.py | 0 ...est_modular_pipeline_stable_diffusion_3.py | 122 +++++++ 10 files changed, 1250 insertions(+) create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py create mode 100644 tests/modular_pipelines/stable_diffusion_3/__init__.py create mode 100644 tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py new file mode 100644 index 000000000000..13396327ee7c --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py @@ -0,0 +1,46 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_stable_diffusion_3"] = ["SD3AutoBlocks"] + _import_structure["modular_pipeline"] = ["SD3ModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_stable_diffusion_3 import SD3AutoBlocks + from .modular_pipeline import SD3ModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py new file mode 100644 index 000000000000..7eee1d7dc652 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py @@ -0,0 +1,280 @@ +import inspect + +import numpy as np +import torch + +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import SD3ModularPipeline + + +logger = logging.get_logger(__name__) + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +def _get_initial_timesteps_and_optionals( + transformer, + scheduler, + height, + width, + patch_size, + vae_scale_factor, + num_inference_steps, + sigmas, + device, + mu=None, +): + scheduler_kwargs = {} + if scheduler.config.get("use_dynamic_shifting", None) and mu is None: + image_seq_len = (height // vae_scale_factor // patch_size) * (width // vae_scale_factor // patch_size) + mu = calculate_shift( + image_seq_len, + scheduler.config.get("base_image_seq_len", 256), + scheduler.config.get("max_image_seq_len", 4096), + scheduler.config.get("base_shift", 0.5), + scheduler.config.get("max_shift", 1.16), + ) + scheduler_kwargs["mu"] = mu + elif mu is not None: + scheduler_kwargs["mu"] = mu + + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs + ) + return timesteps, num_inference_steps + + +class SD3SetTimestepsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return[ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> list[InputParam]: + return[ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("mu", type_hint=float), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[ + OutputParam("timesteps", type_hint=torch.Tensor), + OutputParam("num_inference_steps", type_hint=int), + ] + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + timesteps, num_inference_steps = _get_initial_timesteps_and_optionals( + components.transformer, + components.scheduler, + block_state.height, + block_state.width, + components.patch_size, + components.vae_scale_factor, + block_state.num_inference_steps, + block_state.sigmas, + block_state.device, + getattr(block_state, "mu", None) + ) + + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + + self.set_block_state(state, block_state) + return components, state + + +class SD3Img2ImgSetTimestepsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return[ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for img2img inference" + + @property + def inputs(self) -> list[InputParam]: + return[ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("strength", default=0.6), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("mu", type_hint=float), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[ + OutputParam("timesteps", type_hint=torch.Tensor), + OutputParam("num_inference_steps", type_hint=int), + ] + + @staticmethod + def get_timesteps(scheduler, num_inference_steps, strength): + init_timestep = min(num_inference_steps * strength, num_inference_steps) + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = scheduler.timesteps[t_start * scheduler.order :] + if hasattr(scheduler, "set_begin_index"): + scheduler.set_begin_index(t_start * scheduler.order) + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + timesteps, num_inference_steps = _get_initial_timesteps_and_optionals( + components.transformer, + components.scheduler, + block_state.height, + block_state.width, + components.patch_size, + components.vae_scale_factor, + block_state.num_inference_steps, + block_state.sigmas, + block_state.device, + getattr(block_state, "mu", None) + ) + + timesteps, num_inference_steps = self.get_timesteps( + components.scheduler, num_inference_steps, block_state.strength + ) + + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + + self.set_block_state(state, block_state) + return components, state + + +class SD3PrepareLatentsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def description(self) -> str: + return "Prepare latents step for Text-to-Image" + + @property + def inputs(self) -> list[InputParam]: + return[ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=torch.Tensor | None), + InputParam("num_images_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam("batch_size", required=True, type_hint=int), + InputParam("dtype", type_hint=torch.dtype), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[OutputParam("latents", type_hint=torch.Tensor)] + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + if block_state.latents is not None: + block_state.latents = block_state.latents.to(device=block_state.device, dtype=block_state.dtype) + else: + shape = ( + batch_size, + components.num_channels_latents, + int(block_state.height) // components.vae_scale_factor, + int(block_state.width) // components.vae_scale_factor, + ) + block_state.latents = randn_tensor(shape, generator=block_state.generator, device=block_state.device, dtype=block_state.dtype) + + self.set_block_state(state, block_state) + return components, state + + +class SD3Img2ImgPrepareLatentsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def inputs(self) -> list[InputParam]: + return[ + InputParam("latents", required=True, type_hint=torch.Tensor), + InputParam("image_latents", required=True, type_hint=torch.Tensor), + InputParam("timesteps", required=True, type_hint=torch.Tensor), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("initial_noise", type_hint=torch.Tensor)] + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) + block_state.initial_noise = block_state.latents + block_state.latents = components.scheduler.scale_noise( + block_state.image_latents, latent_timestep, block_state.latents + ) + self.set_block_state(state, block_state) + return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py new file mode 100644 index 000000000000..3f037f1fee01 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py @@ -0,0 +1,51 @@ +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL +from ...utils import logging +from ...image_processor import VaeImageProcessor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) + + +class SD3DecodeStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return[ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + return[ + InputParam("output_type", default="pil"), + InputParam("latents", required=True, type_hint=torch.Tensor), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("images", type_hint=list[PIL.Image.Image] | torch.Tensor)] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae = components.vae + + if not block_state.output_type == "latent": + latents = (block_state.latents / vae.config.scaling_factor) + vae.config.shift_factor + block_state.images = vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + else: + block_state.images = block_state.latents + + self.set_block_state(state, block_state) + return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py new file mode 100644 index 000000000000..4341c3daf3c9 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py @@ -0,0 +1,153 @@ +from typing import Any + +import torch + +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import SD3ModularPipeline + +logger = logging.get_logger(__name__) + + +class SD3LoopDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", SD3Transformer2DModel)] + + @property + def description(self) -> str: + return "Step within the denoising loop that denoise the latents." + + @property + def inputs(self) -> list[tuple[str, Any]]: + return[ + InputParam("joint_attention_kwargs"), + InputParam("latents", required=True, type_hint=torch.Tensor), + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor), + InputParam("do_classifier_free_guidance", type_hint=bool), + InputParam("guidance_scale", default=7.0), + InputParam("skip_guidance_layers", type_hint=list), + InputParam("skip_layer_guidance_scale", default=2.8), + InputParam("skip_layer_guidance_stop", default=0.2), + InputParam("skip_layer_guidance_start", default=0.01), + InputParam("original_prompt_embeds", type_hint=torch.Tensor), + InputParam("original_pooled_prompt_embeds", type_hint=torch.Tensor), + InputParam("num_inference_steps", type_hint=int), + ] + + @torch.no_grad() + def __call__( + self, components: SD3ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latent_model_input = torch.cat([block_state.latents] * 2) if block_state.do_classifier_free_guidance else block_state.latents + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=block_state.prompt_embeds, + pooled_projections=block_state.pooled_prompt_embeds, + joint_attention_kwargs=getattr(block_state, "joint_attention_kwargs", None), + return_dict=False, + )[0] + + if block_state.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + block_state.guidance_scale * (noise_pred_text - noise_pred_uncond) + + should_skip_layers = ( + getattr(block_state, "skip_guidance_layers", None) is not None + and i > getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_start", 0.01) + and i < getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_stop", 0.2) + ) + + if should_skip_layers: + timestep_skip = t.expand(block_state.latents.shape[0]) + noise_pred_skip_layers = components.transformer( + hidden_states=block_state.latents, + timestep=timestep_skip, + encoder_hidden_states=block_state.original_prompt_embeds, + pooled_projections=block_state.original_pooled_prompt_embeds, + joint_attention_kwargs=getattr(block_state, "joint_attention_kwargs", None), + return_dict=False, + skip_layers=block_state.skip_guidance_layers, + )[0] + noise_pred = noise_pred + (noise_pred_text - noise_pred_skip_layers) * getattr(block_state, "skip_layer_guidance_scale", 2.8) + + block_state.noise_pred = noise_pred + return components, block_state + + +class SD3LoopAfterDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[OutputParam("latents", type_hint=torch.Tensor)] + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class SD3DenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return[ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", SD3Transformer2DModel), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return[ + InputParam("timesteps", required=True, type_hint=torch.Tensor), + InputParam("num_inference_steps", required=True, type_hint=int), + ] + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): + progress_bar.update() + + self.set_block_state(state, block_state) + return components, state + + +class SD3DenoiseStep(SD3DenoiseLoopWrapper): + block_classes = [SD3LoopDenoiser, SD3LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py new file mode 100644 index 000000000000..24f38fbfce38 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -0,0 +1,304 @@ +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import SD3LoraLoaderMixin +from ...models import AutoencoderKL +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import SD3ModularPipeline + +logger = logging.get_logger(__name__) + +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.Generator, sample_mode="sample"): + if isinstance(generator, list): + image_latents =[ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode) + + image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor + return image_latents + +class SD3ProcessImagesInputStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def description(self) -> str: + return "Image Preprocess step for SD3." + + @property + def expected_components(self) -> list[ComponentSpec]: + return[ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8, "vae_latent_channels": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return[InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[OutputParam(name="processed_image")] + + @staticmethod + def check_inputs(height, width, vae_scale_factor, patch_size): + if height is not None and height % (vae_scale_factor * patch_size) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * patch_size} but is {height}") + + if width is not None and width % (vae_scale_factor * patch_size) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * patch_size} but is {width}") + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + if block_state.resized_image is None and block_state.image is None: + raise ValueError("`resized_image` and `image` cannot be None at the same time") + + if block_state.resized_image is None: + image = block_state.image + self.check_inputs( + height=block_state.height, width=block_state.width, + vae_scale_factor=components.vae_scale_factor, patch_size=components.patch_size + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width + else: + width, height = block_state.resized_image[0].size + image = block_state.resized_image + + block_state.processed_image = components.image_processor.preprocess(image=image, height=height, width=width) + + self.set_block_state(state, block_state) + return components, state + +class SD3VaeEncoderStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + def __init__(self, input_name: str = "processed_image", output_name: str = "image_latents", sample_mode: str = "sample"): + self._image_input_name = input_name + self._image_latents_output_name = output_name + self.sample_mode = sample_mode + super().__init__() + + @property + def description(self) -> str: + return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("vae", AutoencoderKL)] + + @property + def inputs(self) -> list[InputParam]: + return[InputParam(self._image_input_name), InputParam("generator")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[ + OutputParam(self._image_latents_output_name, type_hint=torch.Tensor, description="The latents representing the reference image") + ] + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + image = getattr(block_state, self._image_input_name) + + if image is None: + setattr(block_state, self._image_latents_output_name, None) + else: + device = components._execution_device + dtype = components.vae.dtype + image = image.to(device=device, dtype=dtype) + image_latents = encode_vae_image( + image=image, vae=components.vae, generator=block_state.generator, sample_mode=self.sample_mode + ) + setattr(block_state, self._image_latents_output_name, image_latents) + + self.set_block_state(state, block_state) + return components, state + +class SD3TextEncoderStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings to guide the image generation for SD3." + + @property + def expected_components(self) -> list[ComponentSpec]: + return[ + ComponentSpec("text_encoder", CLIPTextModelWithProjection), + ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), + ComponentSpec("tokenizer_2", CLIPTokenizer), + ComponentSpec("text_encoder_3", T5EncoderModel), + ComponentSpec("tokenizer_3", T5TokenizerFast), + ] + + @property + def inputs(self) -> list[InputParam]: + return[ + InputParam("prompt"), + InputParam("prompt_2"), + InputParam("prompt_3"), + InputParam("negative_prompt"), + InputParam("negative_prompt_2"), + InputParam("negative_prompt_3"), + InputParam("prompt_embeds", type_hint=torch.Tensor), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor), + InputParam("pooled_prompt_embeds", type_hint=torch.Tensor), + InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), + InputParam("guidance_scale", default=7.0), + InputParam("clip_skip", type_hint=int), + InputParam("max_sequence_length", type_hint=int, default=256), + InputParam("joint_attention_kwargs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[ + OutputParam("prompt_embeds", type_hint=torch.Tensor), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), + ] + + @staticmethod + def _get_t5_prompt_embeds(components, prompt, max_sequence_length, device): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if components.text_encoder_3 is None: + return torch.zeros( + (batch_size, max_sequence_length, components.transformer.config.joint_attention_dim), + device=device, + dtype=components.text_encoder.dtype, + ) + + text_inputs = components.tokenizer_3( + prompt, padding="max_length", max_length=max_sequence_length, + truncation=True, add_special_tokens=True, return_tensors="pt", + ) + prompt_embeds = components.text_encoder_3(text_inputs.input_ids.to(device))[0] + return prompt_embeds.to(dtype=components.text_encoder_3.dtype, device=device) + + @staticmethod + def _get_clip_prompt_embeds(components, prompt, device, clip_skip, clip_model_index): + clip_tokenizers = [components.tokenizer, components.tokenizer_2] + clip_text_encoders =[components.text_encoder, components.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + text_inputs = tokenizer( + prompt, padding="max_length", max_length=tokenizer.model_max_length, + truncation=True, return_tensors="pt", + ) + + prompt_embeds = text_encoder(text_inputs.input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + return prompt_embeds.to(dtype=components.text_encoder.dtype, device=device), pooled_prompt_embeds + + @staticmethod + def encode_prompt(components, block_state, device, do_classifier_free_guidance, lora_scale): + if lora_scale is not None and isinstance(components, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + if components.text_encoder is not None: scale_lora_layers(components.text_encoder, lora_scale) + if components.text_encoder_2 is not None: scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt_embeds = block_state.prompt_embeds + pooled_prompt_embeds = block_state.pooled_prompt_embeds + + if prompt_embeds is None: + prompt = [block_state.prompt] if isinstance(block_state.prompt, str) else block_state.prompt + prompt_2 = block_state.prompt_2 or prompt + prompt_3 = block_state.prompt_3 or prompt + + prompt_embed, pooled_embed = SD3TextEncoderStep._get_clip_prompt_embeds(components, prompt, device, block_state.clip_skip, 0) + prompt_2_embed, pooled_2_embed = SD3TextEncoderStep._get_clip_prompt_embeds(components, prompt_2, device, block_state.clip_skip, 1) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = SD3TextEncoderStep._get_t5_prompt_embeds(components, prompt_3, block_state.max_sequence_length, device) + clip_prompt_embeds = torch.nn.functional.pad(clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_embed, pooled_2_embed], dim=-1) + + negative_prompt_embeds = block_state.negative_prompt_embeds + negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds + + if do_classifier_free_guidance and negative_prompt_embeds is None: + batch_size = prompt_embeds.shape[0] + neg_prompt = block_state.negative_prompt or "" + neg_prompt_2 = block_state.negative_prompt_2 or neg_prompt + neg_prompt_3 = block_state.negative_prompt_3 or neg_prompt + + neg_prompt = batch_size * [neg_prompt] if isinstance(neg_prompt, str) else neg_prompt + neg_prompt_2 = batch_size * [neg_prompt_2] if isinstance(neg_prompt_2, str) else neg_prompt_2 + neg_prompt_3 = batch_size * [neg_prompt_3] if isinstance(neg_prompt_3, str) else neg_prompt_3 + + neg_embed, neg_pooled_embed = SD3TextEncoderStep._get_clip_prompt_embeds(components, neg_prompt, device, None, 0) + neg_2_embed, neg_2_pooled_embed = SD3TextEncoderStep._get_clip_prompt_embeds(components, neg_prompt_2, device, None, 1) + neg_clip_embeds = torch.cat([neg_embed, neg_2_embed], dim=-1) + + t5_neg_embed = SD3TextEncoderStep._get_t5_prompt_embeds(components, neg_prompt_3, block_state.max_sequence_length, device) + neg_clip_embeds = torch.nn.functional.pad(neg_clip_embeds, (0, t5_neg_embed.shape[-1] - neg_clip_embeds.shape[-1])) + + negative_prompt_embeds = torch.cat([neg_clip_embeds, t5_neg_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat([neg_pooled_embed, neg_2_pooled_embed], dim=-1) + + if lora_scale is not None and isinstance(components, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + if components.text_encoder is not None: unscale_lora_layers(components.text_encoder, lora_scale) + if components.text_encoder_2 is not None: unscale_lora_layers(components.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + do_classifier_free_guidance = block_state.guidance_scale > 1.0 + lora_scale = block_state.joint_attention_kwargs.get("scale", None) if block_state.joint_attention_kwargs else None + + prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.encode_prompt( + components, block_state, block_state.device, do_classifier_free_guidance, lora_scale + ) + + block_state.prompt_embeds = prompt_embeds + block_state.negative_prompt_embeds = negative_prompt_embeds + block_state.pooled_prompt_embeds = pooled_prompt_embeds + block_state.negative_pooled_prompt_embeds = negative_pooled_prompt_embeds + + self.set_block_state(state, block_state) + return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py new file mode 100644 index 000000000000..61ca894faafc --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py @@ -0,0 +1,141 @@ +import torch +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import InputParam, OutputParam +from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size +from .modular_pipeline import SD3ModularPipeline + +logger = logging.get_logger(__name__) + +class SD3TextInputStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def description(self) -> str: + return "Text input processing step that standardizes text embeddings for SD3, applying CFG duplication if needed." + + @property + def inputs(self) -> list[InputParam]: + return[ + InputParam("num_images_per_prompt", default=1), + InputParam("guidance_scale", default=7.0), + InputParam("skip_guidance_layers", type_hint=list), + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor), + InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return[ + OutputParam("batch_size", type_hint=int), + OutputParam("dtype", type_hint=torch.dtype), + OutputParam("do_classifier_free_guidance", type_hint=bool), + OutputParam("prompt_embeds", type_hint=torch.Tensor), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor), + OutputParam("original_prompt_embeds", type_hint=torch.Tensor), + OutputParam("original_pooled_prompt_embeds", type_hint=torch.Tensor), + ] + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + block_state.do_classifier_free_guidance = block_state.guidance_scale > 1.0 + + _, seq_len, _ = block_state.prompt_embeds.shape + prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + + if block_state.do_classifier_free_guidance and block_state.negative_prompt_embeds is not None: + _, neg_seq_len, _ = block_state.negative_prompt_embeds.shape + negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, neg_seq_len, -1) + + negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + + if block_state.skip_guidance_layers is not None: + block_state.original_prompt_embeds = prompt_embeds + block_state.original_pooled_prompt_embeds = pooled_prompt_embeds + + block_state.prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + block_state.pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + else: + block_state.prompt_embeds = prompt_embeds + block_state.pooled_prompt_embeds = pooled_prompt_embeds + + self.set_block_state(state, block_state) + return components, state + +class SD3AdditionalInputsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + def __init__(self, image_latent_inputs: list[str] = ["image_latents"], additional_batch_inputs: list[str] =[]): + self._image_latent_inputs = image_latent_inputs if isinstance(image_latent_inputs, list) else [image_latent_inputs] + self._additional_batch_inputs = additional_batch_inputs if isinstance(additional_batch_inputs, list) else[additional_batch_inputs] + super().__init__() + + @property + def description(self) -> str: + return "Updates height/width if None, and expands batch size. SD3 does not pack latents on pipeline level." + + @property + def inputs(self) -> list[InputParam]: + inputs =[ + InputParam("num_images_per_prompt", default=1), + InputParam("batch_size", required=True), + InputParam("height"), + InputParam("width"), + ] + for name in self._image_latent_inputs + self._additional_batch_inputs: + inputs.append(InputParam(name)) + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[ + OutputParam("image_height", type_hint=int), + OutputParam("image_width", type_hint=int), + ] + + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + for input_name in self._image_latent_inputs: + tensor = getattr(block_state, input_name) + if tensor is None: + continue + + height, width = calculate_dimension_from_latents(tensor, components.vae_scale_factor) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + if not hasattr(block_state, "image_height"): + block_state.image_height = height + if not hasattr(block_state, "image_width"): + block_state.image_width = width + + tensor = repeat_tensor_to_batch_size( + input_name=input_name, input_tensor=tensor, + num_images_per_prompt=block_state.num_images_per_prompt, batch_size=block_state.batch_size + ) + setattr(block_state, input_name, tensor) + + for input_name in self._additional_batch_inputs: + tensor = getattr(block_state, input_name) + if tensor is None: continue + tensor = repeat_tensor_to_batch_size( + input_name=input_name, input_tensor=tensor, + num_images_per_prompt=block_state.num_images_per_prompt, batch_size=block_state.batch_size + ) + setattr(block_state, input_name, tensor) + + self.set_block_state(state, block_state) + return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py new file mode 100644 index 000000000000..719910b5ca8f --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py @@ -0,0 +1,105 @@ +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + SD3Img2ImgPrepareLatentsStep, + SD3Img2ImgSetTimestepsStep, + SD3PrepareLatentsStep, + SD3SetTimestepsStep, +) +from .decoders import SD3DecodeStep +from .denoise import SD3DenoiseStep +from .encoders import ( + SD3ProcessImagesInputStep, + SD3TextEncoderStep, + SD3VaeEncoderStep, +) +from .inputs import ( + SD3AdditionalInputsStep, + SD3TextInputStep, +) + + +logger = logging.get_logger(__name__) + + +class SD3Img2ImgVaeEncoderStep(SequentialPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes = [SD3ProcessImagesInputStep(), SD3VaeEncoderStep()] + block_names = ["preprocess", "encode"] + + +class SD3AutoVaeEncoderStep(AutoPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes =[SD3Img2ImgVaeEncoderStep] + block_names = ["img2img"] + block_trigger_inputs =["image"] + + +class SD3BeforeDenoiseStep(SequentialPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes =[SD3PrepareLatentsStep(), SD3SetTimestepsStep()] + block_names = ["prepare_latents", "set_timesteps"] + + +class SD3Img2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes =[ + SD3PrepareLatentsStep(), + SD3Img2ImgSetTimestepsStep(), + SD3Img2ImgPrepareLatentsStep(), + ] + block_names = ["prepare_latents", "set_timesteps", "prepare_img2img_latents"] + + +class SD3AutoBeforeDenoiseStep(AutoPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes =[SD3Img2ImgBeforeDenoiseStep, SD3BeforeDenoiseStep] + block_names = ["img2img", "text2image"] + block_trigger_inputs = ["image_latents", None] + + +class SD3Img2ImgInputStep(SequentialPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes =[SD3TextInputStep(), SD3AdditionalInputsStep()] + block_names =["text_inputs", "additional_inputs"] + + +class SD3AutoInputStep(AutoPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes = [SD3Img2ImgInputStep, SD3TextInputStep] + block_names = ["img2img", "text2image"] + block_trigger_inputs = ["image_latents", None] + + +class SD3CoreDenoiseStep(SequentialPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes =[SD3AutoInputStep, SD3AutoBeforeDenoiseStep, SD3DenoiseStep] + block_names =["input", "before_denoise", "denoise"] + @property + def outputs(self): + return [OutputParam.template("latents")] + + +AUTO_BLOCKS = InsertableDict([ + ("text_encoder", SD3TextEncoderStep()), + ("vae_encoder", SD3AutoVaeEncoderStep()), + ("denoise", SD3CoreDenoiseStep()), + ("decode", SD3DecodeStep()), + ] +) + + +class SD3AutoBlocks(SequentialPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + + _workflow_map = { + "text2image": {"prompt": True}, + "image2image": {"image": True, "prompt": True}, + } + + @property + def outputs(self): + return [OutputParam.template("images")] \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py new file mode 100644 index 000000000000..56033fa08bc7 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py @@ -0,0 +1,48 @@ +from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + +logger = logging.get_logger(__name__) + + +class SD3ModularPipeline(ModularPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin): + """ + A ModularPipeline for Stable Diffusion 3. + + >[!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "SD3AutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.sample_size + return 128 + + @property + def patch_size(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.patch_size + return 2 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.in_channels + return 16 \ No newline at end of file diff --git a/tests/modular_pipelines/stable_diffusion_3/__init__.py b/tests/modular_pipelines/stable_diffusion_3/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py new file mode 100644 index 000000000000..20c1542ee3ab --- /dev/null +++ b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py @@ -0,0 +1,122 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +import random +import numpy as np +import PIL +import torch + +from diffusers.image_processor import VaeImageProcessor +from diffusers.modular_pipelines import ModularPipeline +from diffusers.modular_pipelines.stable_diffusion_3 import SD3AutoBlocks, SD3ModularPipeline + +from ...testing_utils import floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +SD3_TEXT2IMAGE_WORKFLOWS = { + "text2image":[ + ("text_encoder", "SD3TextEncoderStep"), + ("denoise.input", "SD3TextInputStep"), + ("denoise.before_denoise.prepare_latents", "SD3PrepareLatentsStep"), + ("denoise.before_denoise.set_timesteps", "SD3SetTimestepsStep"), + ("denoise.denoise", "SD3DenoiseStep"), + ("decode", "SD3DecodeStep"), + ] +} + +class TestSD3ModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = SD3ModularPipeline + pipeline_blocks_class = SD3AutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe" + + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + expected_workflow_blocks = SD3_TEXT2IMAGE_WORKFLOWS + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + return { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "max_sequence_length": 48, + "output_type": "pt", + } + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + +SD3_IMAGE2IMAGE_WORKFLOWS = { + "image2image":[ + ("text_encoder", "SD3TextEncoderStep"), + ("vae_encoder.preprocess", "SD3ProcessImagesInputStep"), + ("vae_encoder.encode", "SD3VaeEncoderStep"), + ("denoise.input.text_inputs", "SD3TextInputStep"), + ("denoise.input.additional_inputs", "SD3AdditionalInputsStep"), + ("denoise.before_denoise.prepare_latents", "SD3PrepareLatentsStep"), + ("denoise.before_denoise.set_timesteps", "SD3Img2ImgSetTimestepsStep"), + ("denoise.before_denoise.prepare_img2img_latents", "SD3Img2ImgPrepareLatentsStep"), + ("denoise.denoise", "SD3DenoiseStep"), + ("decode", "SD3DecodeStep"), + ] +} + +class TestSD3Img2ImgModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = SD3ModularPipeline + pipeline_blocks_class = SD3AutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe" + + params = frozenset(["prompt", "height", "width", "guidance_scale", "image"]) + batch_params = frozenset(["prompt", "image"]) + expected_workflow_blocks = SD3_IMAGE2IMAGE_WORKFLOWS + + def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): + pipeline = super().get_pipeline(components_manager, torch_dtype) + pipeline.image_processor = VaeImageProcessor(vae_scale_factor=8) + return pipeline + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 4, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "max_sequence_length": 48, + "output_type": "pt", + } + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(torch_device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = PIL.Image.fromarray(np.uint8(image)).convert("RGB") + inputs["image"] = init_image + inputs["strength"] = 0.5 + return inputs + + def test_save_from_pretrained(self, tmp_path): + pipes =[] + base_pipe = self.get_pipeline().to(torch_device) + pipes.append(base_pipe) + + base_pipe.save_pretrained(str(tmp_path)) + pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device) + pipe.load_components(torch_dtype=torch.float32) + pipe.to(torch_device) + pipe.image_processor = VaeImageProcessor(vae_scale_factor=8) + pipes.append(pipe) + + image_slices =[] + for pipe in pipes: + inputs = self.get_dummy_inputs() + image = pipe(**inputs, output="images") + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + + def test_float16_inference(self): + super().test_float16_inference(8e-2) \ No newline at end of file From 08d14c60e7d155568974b02792bc34248776779d Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Tue, 24 Mar 2026 17:17:32 +0000 Subject: [PATCH 2/3] add blocks to various inits --- src/diffusers/__init__.py | 4 + src/diffusers/modular_pipelines/__init__.py | 2 + .../modular_pipelines/modular_pipeline.py | 3 +- .../stable_diffusion_3/before_denoise.py | 14 ++++ .../stable_diffusion_3/decoders.py | 14 ++++ .../stable_diffusion_3/denoise.py | 14 ++++ .../stable_diffusion_3/encoders.py | 14 ++++ .../stable_diffusion_3/inputs.py | 21 +++++- .../modular_blocks_stable_diffusion_3.py | 14 ++++ .../stable_diffusion_3/modular_pipeline.py | 14 ++++ .../dummy_torch_and_transformers_objects.py | 29 ++++++++ ...est_modular_pipeline_stable_diffusion_3.py | 73 +++++++++++++++++-- 12 files changed, 205 insertions(+), 11 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0be7b8166a37..0f2852baf421 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -450,6 +450,8 @@ "QwenImageModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", + "SD3AutoBlocks", + "SD3ModularPipeline", "Wan22Blocks", "Wan22Image2VideoBlocks", "Wan22Image2VideoModularPipeline", @@ -1211,6 +1213,8 @@ QwenImageModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, + SD3AutoBlocks, + SD3ModularPipeline, Wan22Blocks, Wan22Image2VideoBlocks, Wan22Image2VideoModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index fd9bd691ca87..e9a92c5704ac 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -46,6 +46,7 @@ "InsertableDict", ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] + _import_structure["stable_diffusion_3"] =["SD3AutoBlocks", "SD3ModularPipeline"] _import_structure["wan"] = [ "WanBlocks", "Wan22Blocks", @@ -141,6 +142,7 @@ QwenImageModularPipeline, ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline + from .stable_diffusion_3 import SD3AutoBlocks, SD3ModularPipeline from .wan import ( Wan22Blocks, Wan22Image2VideoBlocks, diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 9cd2f9f5c6ae..e2ca24812e72 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -119,8 +119,9 @@ def _helios_pyramid_map_fn(config_dict=None): MODULAR_PIPELINE_MAPPING = OrderedDict( [ ("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")), + ("stable-diffusion-3", _create_default_map_fn("SD3ModularPipeline")), ("wan", _wan_map_fn), - ("wan-i2v", _wan_i2v_map_fn), + ("wan-i2v", _wan_i2v_map_fn), ("flux", _create_default_map_fn("FluxModularPipeline")), ("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")), ("flux2", _create_default_map_fn("Flux2ModularPipeline")), diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py index 7eee1d7dc652..ebadf45236da 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect import numpy as np diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py index 3f037f1fee01..c8d9f6a562c1 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import PIL import torch diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py index 4341c3daf3c9..a41e87665ede 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Any import torch diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py index 24f38fbfce38..6087f349a691 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py index 61ca894faafc..5ae213b09040 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState @@ -53,6 +67,9 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + block_state.original_prompt_embeds = prompt_embeds + block_state.original_pooled_prompt_embeds = pooled_prompt_embeds + if block_state.do_classifier_free_guidance and block_state.negative_prompt_embeds is not None: _, neg_seq_len, _ = block_state.negative_prompt_embeds.shape negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) @@ -61,10 +78,6 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - if block_state.skip_guidance_layers is not None: - block_state.original_prompt_embeds = prompt_embeds - block_state.original_pooled_prompt_embeds = pooled_prompt_embeds - block_state.prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) block_state.pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) else: diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py index 719910b5ca8f..0595d26346c2 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks from ..modular_pipeline_utils import InsertableDict, OutputParam diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py index 56033fa08bc7..a54b1fd54423 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin from ...utils import logging from ..modular_pipeline import ModularPipeline diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 2ec5bc002f41..a3d9f8bcf56c 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -391,6 +391,35 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class SD3AutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class SD3ModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls,["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + class Wan22Blocks(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py index 20c1542ee3ab..860fa70f0565 100644 --- a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py +++ b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py @@ -1,5 +1,18 @@ # coding=utf-8 -# Copyright 2025 HuggingFace Inc. +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import random import numpy as np import PIL @@ -46,10 +59,45 @@ def get_dummy_inputs(self, seed=0): "output_type": "pt", } + def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): + pipeline = self.pipeline_class.from_pretrained( + self.pretrained_model_name_or_path, torch_dtype=torch_dtype + ) + if components_manager is not None: + pipeline.components_manager = components_manager + return pipeline + + def test_save_from_pretrained(self, tmp_path): + pipes =[] + base_pipe = self.get_pipeline().to(torch_device) + pipes.append(base_pipe) + + base_pipe.save_pretrained(str(tmp_path)) + pipe = self.pipeline_class.from_pretrained(tmp_path).to(torch_device) + pipe.load_components(torch_dtype=torch.float32) + pipe.to(torch_device) + pipes.append(pipe) + + image_slices =[] + for p in pipes: + inputs = self.get_dummy_inputs() + image = p(**inputs, output="images") + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + + def test_load_expected_components_from_save_pretrained(self, tmp_path): + base_pipe = self.get_pipeline() + base_pipe.save_pretrained(str(tmp_path)) + + pipe = self.pipeline_class.from_pretrained(tmp_path) + pipe.load_components(torch_dtype=torch.float32) + + assert set(base_pipe.components.keys()) == set(pipe.components.keys()) + def test_float16_inference(self): super().test_float16_inference(9e-2) - SD3_IMAGE2IMAGE_WORKFLOWS = { "image2image":[ ("text_encoder", "SD3TextEncoderStep"), @@ -75,7 +123,11 @@ class TestSD3Img2ImgModularPipelineFast(ModularPipelineTesterMixin): expected_workflow_blocks = SD3_IMAGE2IMAGE_WORKFLOWS def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): - pipeline = super().get_pipeline(components_manager, torch_dtype) + pipeline = self.pipeline_class.from_pretrained( + self.pretrained_model_name_or_path, torch_dtype=torch_dtype + ) + if components_manager is not None: + pipeline.components_manager = components_manager pipeline.image_processor = VaeImageProcessor(vae_scale_factor=8) return pipeline @@ -104,19 +156,28 @@ def test_save_from_pretrained(self, tmp_path): pipes.append(base_pipe) base_pipe.save_pretrained(str(tmp_path)) - pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device) + pipe = self.pipeline_class.from_pretrained(tmp_path).to(torch_device) pipe.load_components(torch_dtype=torch.float32) pipe.to(torch_device) pipe.image_processor = VaeImageProcessor(vae_scale_factor=8) pipes.append(pipe) image_slices =[] - for pipe in pipes: + for p in pipes: inputs = self.get_dummy_inputs() - image = pipe(**inputs, output="images") + image = p(**inputs, output="images") image_slices.append(image[0, -3:, -3:, -1].flatten()) assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + def test_load_expected_components_from_save_pretrained(self, tmp_path): + base_pipe = self.get_pipeline() + base_pipe.save_pretrained(str(tmp_path)) + + pipe = self.pipeline_class.from_pretrained(tmp_path) + pipe.load_components(torch_dtype=torch.float32) + + assert set(base_pipe.components.keys()) == set(pipe.components.keys()) + def test_float16_inference(self): super().test_float16_inference(8e-2) \ No newline at end of file From 0a81741904319427b82e77a80c922084dbd933ce Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Tue, 24 Mar 2026 17:20:28 +0000 Subject: [PATCH 3/3] styling --- src/diffusers/__init__.py | 4 ++-- src/diffusers/modular_pipelines/__init__.py | 2 +- .../modular_pipelines/modular_pipeline.py | 2 +- .../stable_diffusion_3/__init__.py | 3 ++- .../stable_diffusion_3/before_denoise.py | 14 ++++++-------- .../stable_diffusion_3/decoders.py | 4 ++-- .../stable_diffusion_3/denoise.py | 7 ++++--- .../stable_diffusion_3/encoders.py | 19 ++++++++++++------- .../stable_diffusion_3/inputs.py | 11 +++++++---- .../modular_blocks_stable_diffusion_3.py | 2 +- .../stable_diffusion_3/modular_pipeline.py | 3 ++- .../dummy_torch_and_transformers_objects.py | 2 +- ...est_modular_pipeline_stable_diffusion_3.py | 12 ++++++------ 13 files changed, 47 insertions(+), 38 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0f2852baf421..c1fcef28465b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1211,10 +1211,10 @@ QwenImageLayeredAutoBlocks, QwenImageLayeredModularPipeline, QwenImageModularPipeline, - StableDiffusionXLAutoBlocks, - StableDiffusionXLModularPipeline, SD3AutoBlocks, SD3ModularPipeline, + StableDiffusionXLAutoBlocks, + StableDiffusionXLModularPipeline, Wan22Blocks, Wan22Image2VideoBlocks, Wan22Image2VideoModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index e9a92c5704ac..3e4802609be3 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -141,8 +141,8 @@ QwenImageLayeredModularPipeline, QwenImageModularPipeline, ) - from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .stable_diffusion_3 import SD3AutoBlocks, SD3ModularPipeline + from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .wan import ( Wan22Blocks, Wan22Image2VideoBlocks, diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index e2ca24812e72..2f36d8526cdc 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -121,7 +121,7 @@ def _helios_pyramid_map_fn(config_dict=None): ("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")), ("stable-diffusion-3", _create_default_map_fn("SD3ModularPipeline")), ("wan", _wan_map_fn), - ("wan-i2v", _wan_i2v_map_fn), + ("wan-i2v", _wan_i2v_map_fn), ("flux", _create_default_map_fn("FluxModularPipeline")), ("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")), ("flux2", _create_default_map_fn("Flux2ModularPipeline")), diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py index 13396327ee7c..d6a8b5891986 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py @@ -9,6 +9,7 @@ is_transformers_available, ) + _dummy_objects = {} _import_structure = {} @@ -43,4 +44,4 @@ ) for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) \ No newline at end of file + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py index ebadf45236da..6781235e1ac8 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect -import numpy as np import torch from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -143,10 +141,10 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe block_state.device, getattr(block_state, "mu", None) ) - + block_state.timesteps = timesteps block_state.num_inference_steps = num_inference_steps - + self.set_block_state(state, block_state) return components, state @@ -207,11 +205,11 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe block_state.device, getattr(block_state, "mu", None) ) - + timesteps, num_inference_steps = self.get_timesteps( components.scheduler, num_inference_steps, block_state.strength ) - + block_state.timesteps = timesteps block_state.num_inference_steps = num_inference_steps @@ -247,7 +245,7 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe block_state = self.get_block_state(state) block_state.device = components._execution_device batch_size = block_state.batch_size * block_state.num_images_per_prompt - + if block_state.latents is not None: block_state.latents = block_state.latents.to(device=block_state.device, dtype=block_state.dtype) else: @@ -291,4 +289,4 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe block_state.image_latents, latent_timestep, block_state.latents ) self.set_block_state(state, block_state) - return components, state \ No newline at end of file + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py index c8d9f6a562c1..939df4b5bf36 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py @@ -16,9 +16,9 @@ import torch from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL from ...utils import logging -from ...image_processor import VaeImageProcessor from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam @@ -62,4 +62,4 @@ def __call__(self, components, state: PipelineState) -> PipelineState: block_state.images = block_state.latents self.set_block_state(state, block_state) - return components, state \ No newline at end of file + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py index a41e87665ede..dc57b994e33b 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py @@ -28,6 +28,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import SD3ModularPipeline + logger = logging.get_logger(__name__) @@ -79,7 +80,7 @@ def __call__( if block_state.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + block_state.guidance_scale * (noise_pred_text - noise_pred_uncond) - + should_skip_layers = ( getattr(block_state, "skip_guidance_layers", None) is not None and i > getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_start", 0.01) @@ -151,7 +152,7 @@ def loop_inputs(self) -> list[InputParam]: def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: for i, t in enumerate(block_state.timesteps): components, block_state = self.loop_step(components, block_state, i=i, t=t) @@ -164,4 +165,4 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe class SD3DenoiseStep(SD3DenoiseLoopWrapper): block_classes = [SD3LoopDenoiser, SD3LoopAfterDenoiser] - block_names = ["denoiser", "after_denoiser"] \ No newline at end of file + block_names = ["denoiser", "after_denoiser"] diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py index 6087f349a691..46ae89ac65c9 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -24,6 +24,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import SD3ModularPipeline + logger = logging.get_logger(__name__) def retrieve_latents( @@ -95,7 +96,7 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState): if block_state.resized_image is None: image = block_state.image self.check_inputs( - height=block_state.height, width=block_state.width, + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor, patch_size=components.patch_size ) height = block_state.height or components.default_height @@ -233,7 +234,7 @@ def _get_clip_prompt_embeds(components, prompt, device, clip_skip, clip_model_in prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) - + prompt_embeds = text_encoder(text_inputs.input_ids.to(device), output_hidden_states=True) pooled_prompt_embeds = prompt_embeds[0] @@ -247,8 +248,10 @@ def _get_clip_prompt_embeds(components, prompt, device, clip_skip, clip_model_in @staticmethod def encode_prompt(components, block_state, device, do_classifier_free_guidance, lora_scale): if lora_scale is not None and isinstance(components, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: - if components.text_encoder is not None: scale_lora_layers(components.text_encoder, lora_scale) - if components.text_encoder_2 is not None: scale_lora_layers(components.text_encoder_2, lora_scale) + if components.text_encoder is not None: + scale_lora_layers(components.text_encoder, lora_scale) + if components.text_encoder_2 is not None: + scale_lora_layers(components.text_encoder_2, lora_scale) prompt_embeds = block_state.prompt_embeds pooled_prompt_embeds = block_state.pooled_prompt_embeds @@ -292,8 +295,10 @@ def encode_prompt(components, block_state, device, do_classifier_free_guidance, negative_pooled_prompt_embeds = torch.cat([neg_pooled_embed, neg_2_pooled_embed], dim=-1) if lora_scale is not None and isinstance(components, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: - if components.text_encoder is not None: unscale_lora_layers(components.text_encoder, lora_scale) - if components.text_encoder_2 is not None: unscale_lora_layers(components.text_encoder_2, lora_scale) + if components.text_encoder is not None: + unscale_lora_layers(components.text_encoder, lora_scale) + if components.text_encoder_2 is not None: + unscale_lora_layers(components.text_encoder_2, lora_scale) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -315,4 +320,4 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe block_state.negative_pooled_prompt_embeds = negative_pooled_prompt_embeds self.set_block_state(state, block_state) - return components, state \ No newline at end of file + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py index 5ae213b09040..225e23994f1f 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py @@ -13,12 +13,14 @@ # limitations under the License. import torch + from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import InputParam, OutputParam from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size from .modular_pipeline import SD3ModularPipeline + logger = logging.get_logger(__name__) class SD3TextInputStep(ModularPipelineBlocks): @@ -55,7 +57,7 @@ def intermediate_outputs(self) -> list[str]: @torch.no_grad() def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - + block_state.batch_size = block_state.prompt_embeds.shape[0] block_state.dtype = block_state.prompt_embeds.dtype block_state.do_classifier_free_guidance = block_state.guidance_scale > 1.0 @@ -129,7 +131,7 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe height, width = calculate_dimension_from_latents(tensor, components.vae_scale_factor) block_state.height = block_state.height or height block_state.width = block_state.width or width - + if not hasattr(block_state, "image_height"): block_state.image_height = height if not hasattr(block_state, "image_width"): @@ -143,7 +145,8 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe for input_name in self._additional_batch_inputs: tensor = getattr(block_state, input_name) - if tensor is None: continue + if tensor is None: + continue tensor = repeat_tensor_to_batch_size( input_name=input_name, input_tensor=tensor, num_images_per_prompt=block_state.num_images_per_prompt, batch_size=block_state.batch_size @@ -151,4 +154,4 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe setattr(block_state, input_name, tensor) self.set_block_state(state, block_state) - return components, state \ No newline at end of file + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py index 0595d26346c2..34e850bf11b8 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py @@ -116,4 +116,4 @@ class SD3AutoBlocks(SequentialPipelineBlocks): @property def outputs(self): - return [OutputParam.template("images")] \ No newline at end of file + return [OutputParam.template("images")] diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py index a54b1fd54423..657cda1a08ad 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py @@ -16,6 +16,7 @@ from ...utils import logging from ..modular_pipeline import ModularPipeline + logger = logging.get_logger(__name__) @@ -59,4 +60,4 @@ def vae_scale_factor(self): def num_channels_latents(self): if getattr(self, "transformer", None) is not None: return self.transformer.config.in_channels - return 16 \ No newline at end of file + return 16 diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index a3d9f8bcf56c..6f23acf9e9fd 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -419,7 +419,7 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) - + class Wan22Blocks(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py index 860fa70f0565..d256f22b9ae8 100644 --- a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py +++ b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py @@ -14,12 +14,12 @@ # limitations under the License. import random + import numpy as np import PIL import torch from diffusers.image_processor import VaeImageProcessor -from diffusers.modular_pipelines import ModularPipeline from diffusers.modular_pipelines.stable_diffusion_3 import SD3AutoBlocks, SD3ModularPipeline from ...testing_utils import floats_tensor, torch_device @@ -89,10 +89,10 @@ def test_save_from_pretrained(self, tmp_path): def test_load_expected_components_from_save_pretrained(self, tmp_path): base_pipe = self.get_pipeline() base_pipe.save_pretrained(str(tmp_path)) - + pipe = self.pipeline_class.from_pretrained(tmp_path) pipe.load_components(torch_dtype=torch.float32) - + assert set(base_pipe.components.keys()) == set(pipe.components.keys()) def test_float16_inference(self): @@ -173,11 +173,11 @@ def test_save_from_pretrained(self, tmp_path): def test_load_expected_components_from_save_pretrained(self, tmp_path): base_pipe = self.get_pipeline() base_pipe.save_pretrained(str(tmp_path)) - + pipe = self.pipeline_class.from_pretrained(tmp_path) pipe.load_components(torch_dtype=torch.float32) - + assert set(base_pipe.components.keys()) == set(pipe.components.keys()) def test_float16_inference(self): - super().test_float16_inference(8e-2) \ No newline at end of file + super().test_float16_inference(8e-2)