Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,8 @@
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"SD3AutoBlocks",
"SD3ModularPipeline",
"Wan22Blocks",
"Wan22Image2VideoBlocks",
"Wan22Image2VideoModularPipeline",
Expand Down Expand Up @@ -1209,6 +1211,8 @@
QwenImageLayeredAutoBlocks,
QwenImageLayeredModularPipeline,
QwenImageModularPipeline,
SD3AutoBlocks,
SD3ModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
Wan22Blocks,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/modular_pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"InsertableDict",
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["stable_diffusion_3"] =["SD3AutoBlocks", "SD3ModularPipeline"]
_import_structure["wan"] = [
"WanBlocks",
"Wan22Blocks",
Expand Down Expand Up @@ -140,6 +141,7 @@
QwenImageLayeredModularPipeline,
QwenImageModularPipeline,
)
from .stable_diffusion_3 import SD3AutoBlocks, SD3ModularPipeline
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from .wan import (
Wan22Blocks,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/modular_pipelines/modular_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def _helios_pyramid_map_fn(config_dict=None):
MODULAR_PIPELINE_MAPPING = OrderedDict(
[
("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")),
("stable-diffusion-3", _create_default_map_fn("SD3ModularPipeline")),
("wan", _wan_map_fn),
("wan-i2v", _wan_i2v_map_fn),
("flux", _create_default_map_fn("FluxModularPipeline")),
Expand Down
47 changes: 47 additions & 0 deletions src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Public names mapped to dummy placeholders when optional dependencies are missing.
_dummy_objects = {}
# Lazy-import structure: submodule name -> list of public symbols it exports.
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["modular_blocks_stable_diffusion_3"] = ["SD3AutoBlocks"]
    _import_structure["modular_pipeline"] = ["SD3ModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager imports for type checkers or when DIFFUSERS_SLOW_IMPORT is set.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .modular_blocks_stable_diffusion_3 import SD3AutoBlocks
        from .modular_pipeline import SD3ModularPipeline
else:
    import sys

    # Replace this module with a lazy proxy so heavy imports happen on first attribute access.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
292 changes: 292 additions & 0 deletions src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import SD3ModularPipeline


logger = logging.get_logger(__name__)


def retrieve_timesteps(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're using the one from here

let's directly import or add a "# Copied from ..." comment.

scheduler,
num_inference_steps: int | None = None,
device: str | torch.device | None = None,
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,
**kwargs,
):
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps


# NOTE(review): per reviewer, either import this helper directly or add an exact
# "# Copied from ..." marker once the canonical source module is confirmed.
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """
    Linearly interpolate the timestep-shift parameter `mu` from the image sequence length.

    Maps `base_seq_len -> base_shift` and `max_seq_len -> max_shift`; lengths outside the
    range are extrapolated along the same line.
    """
    # Slope/intercept of the line through (base_seq_len, base_shift) and (max_seq_len, max_shift).
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


def _get_initial_timesteps_and_optionals(
    transformer,
    scheduler,
    height,
    width,
    patch_size,
    vae_scale_factor,
    num_inference_steps,
    sigmas,
    device,
    mu=None,
):
    """Resolve the timestep schedule, deriving the dynamic-shift `mu` from the latent patch count when needed."""
    extra_kwargs = {}
    if mu is not None:
        # Caller supplied an explicit shift; forward it untouched.
        extra_kwargs["mu"] = mu
    elif scheduler.config.get("use_dynamic_shifting", None):
        # Number of latent patches at the requested resolution drives the shift.
        seq_len = (height // vae_scale_factor // patch_size) * (width // vae_scale_factor // patch_size)
        extra_kwargs["mu"] = calculate_shift(
            seq_len,
            scheduler.config.get("base_image_seq_len", 256),
            scheduler.config.get("max_image_seq_len", 4096),
            scheduler.config.get("base_shift", 0.5),
            scheduler.config.get("max_shift", 1.16),
        )

    return retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, **extra_kwargs)


# NOTE(review): reviewer suggested renaming to `StableDiffusion3SetTimestepsStep` for
# consistency with other pipelines; kept as-is here to avoid breaking the exported name.
class SD3SetTimestepsStep(ModularPipelineBlocks):
    """
    Sets the scheduler's timestep schedule for Stable Diffusion 3 text-to-image inference.

    Reads `num_inference_steps`, optional custom `timesteps`/`sigmas`, `height`/`width`
    (used for dynamic shifting) and an optional `mu`, then writes the resolved
    `timesteps` and `num_inference_steps` back to the pipeline state.
    """

    model_name = "stable-diffusion-3"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]

    @property
    def description(self) -> str:
        return "Step that sets the scheduler's timesteps for inference"

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("num_inference_steps", default=50),
            InputParam("timesteps"),
            InputParam("sigmas"),
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam("mu", type_hint=float),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("timesteps", type_hint=torch.Tensor),
            OutputParam("num_inference_steps", type_hint=int),
        ]

    @torch.no_grad()
    def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device

        timesteps, num_inference_steps = _get_initial_timesteps_and_optionals(
            components.transformer,
            components.scheduler,
            block_state.height,
            block_state.width,
            components.patch_size,
            components.vae_scale_factor,
            block_state.num_inference_steps,
            block_state.sigmas,
            block_state.device,
            getattr(block_state, "mu", None),
        )

        block_state.timesteps = timesteps
        block_state.num_inference_steps = num_inference_steps

        self.set_block_state(state, block_state)
        return components, state


class SD3Img2ImgSetTimestepsStep(ModularPipelineBlocks):
    """
    Sets the scheduler's timestep schedule for Stable Diffusion 3 image-to-image inference.

    After resolving the full schedule (like the text-to-image step), the schedule is
    truncated by `strength` so that only the final `strength * num_inference_steps`
    steps are actually run.
    """

    model_name = "stable-diffusion-3"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]

    @property
    def description(self) -> str:
        return "Step that sets the scheduler's timesteps for img2img inference"

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("num_inference_steps", default=50),
            InputParam("timesteps"),
            InputParam("sigmas"),
            InputParam("strength", default=0.6),
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam("mu", type_hint=float),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("timesteps", type_hint=torch.Tensor),
            OutputParam("num_inference_steps", type_hint=int),
        ]

    @staticmethod
    def get_timesteps(scheduler, num_inference_steps, strength):
        """Drop the first `(1 - strength)` fraction of the schedule and set the scheduler's begin index."""
        init_timestep = min(num_inference_steps * strength, num_inference_steps)
        t_start = int(max(num_inference_steps - init_timestep, 0))
        timesteps = scheduler.timesteps[t_start * scheduler.order :]
        if hasattr(scheduler, "set_begin_index"):
            scheduler.set_begin_index(t_start * scheduler.order)
        return timesteps, num_inference_steps - t_start

    @torch.no_grad()
    def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device

        timesteps, num_inference_steps = _get_initial_timesteps_and_optionals(
            components.transformer,
            components.scheduler,
            block_state.height,
            block_state.width,
            components.patch_size,
            components.vae_scale_factor,
            block_state.num_inference_steps,
            block_state.sigmas,
            block_state.device,
            getattr(block_state, "mu", None),
        )

        # Truncate the schedule according to strength (img2img starts partway through denoising).
        timesteps, num_inference_steps = self.get_timesteps(
            components.scheduler, num_inference_steps, block_state.strength
        )

        block_state.timesteps = timesteps
        block_state.num_inference_steps = num_inference_steps

        self.set_block_state(state, block_state)
        return components, state


class SD3PrepareLatentsStep(ModularPipelineBlocks):
    """
    Prepares the initial latent tensor for Stable Diffusion 3 text-to-image generation.

    Either moves user-provided `latents` to the execution device/dtype, or samples fresh
    Gaussian noise of the correct shape for the requested resolution and batch size.
    """

    model_name = "stable-diffusion-3"

    @property
    def description(self) -> str:
        return "Prepare latents step for Text-to-Image"

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam("latents", type_hint=torch.Tensor | None),
            InputParam("num_images_per_prompt", type_hint=int, default=1),
            InputParam("generator"),
            InputParam("batch_size", required=True, type_hint=int),
            InputParam("dtype", type_hint=torch.dtype),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor)]

    @torch.no_grad()
    def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device
        # Effective batch size: each prompt is duplicated `num_images_per_prompt` times.
        batch_size = block_state.batch_size * block_state.num_images_per_prompt

        if block_state.latents is not None:
            # User supplied latents; just move them to the right device/dtype.
            block_state.latents = block_state.latents.to(device=block_state.device, dtype=block_state.dtype)
        else:
            shape = (
                batch_size,
                components.num_channels_latents,
                int(block_state.height) // components.vae_scale_factor,
                int(block_state.width) // components.vae_scale_factor,
            )
            block_state.latents = randn_tensor(
                shape, generator=block_state.generator, device=block_state.device, dtype=block_state.dtype
            )

        self.set_block_state(state, block_state)
        return components, state


class SD3Img2ImgPrepareLatentsStep(ModularPipelineBlocks):
    """
    Prepares latents for Stable Diffusion 3 image-to-image generation.

    Noises the encoded `image_latents` toward the pure-noise `latents` at the first
    timestep of the (strength-truncated) schedule via the scheduler's `scale_noise`,
    and stores the untouched noise as `initial_noise` for downstream steps.
    """

    model_name = "stable-diffusion-3"

    @property
    def description(self) -> str:
        # Added for consistency with the other SD3 blocks, which all expose a description.
        return "Prepare latents step for Image-to-Image"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("latents", required=True, type_hint=torch.Tensor),
            InputParam("image_latents", required=True, type_hint=torch.Tensor),
            InputParam("timesteps", required=True, type_hint=torch.Tensor),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [OutputParam("initial_noise", type_hint=torch.Tensor)]

    @torch.no_grad()
    def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        # Broadcast the first scheduled timestep across the batch.
        latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
        # Keep the pure noise around before it is overwritten below.
        block_state.initial_noise = block_state.latents
        block_state.latents = components.scheduler.scale_noise(
            block_state.image_latents, latent_timestep, block_state.latents
        )
        self.set_block_state(state, block_state)
        return components, state
Loading