116 changes: 83 additions & 33 deletions src/diffusers/schedulers/scheduling_dpmsolver_sde.py
@@ -14,7 +14,7 @@

import math
from dataclasses import dataclass
from typing import List, Literal, Optional, Tuple, Union
from typing import Callable, List, Literal, Optional, Tuple, Union

import numpy as np
import torch
@@ -51,7 +51,14 @@ class DPMSolverSDESchedulerOutput(BaseOutput):
class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""

def __init__(self, x, t0, t1, seed=None, **kwargs):
def __init__(
self,
x: torch.Tensor,
t0: float,
t1: float,
seed: Optional[Union[int, List[int]]] = None,
**kwargs,
):
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get("w0", torch.zeros_like(x))
if seed is None:
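A minimal usage sketch of the batched-entropy behavior this wrapper provides, assuming `torchsde` is installed; the shapes and seed values are illustrative only. A list of seeds creates one `BrownianTree` per batch item, while a single seed (or `None`) creates one unbatched tree:

```python
# Sketch, assuming torchsde is installed; shapes and seeds are illustrative.
import torch

x = torch.zeros(2, 3)                                       # batch of two items
tree = BatchedBrownianTree(x, t0=0.0, t1=1.0, seed=[0, 1])  # one tree per batch item
w = tree(0.0, 0.5)                                          # increments, same shape as x
```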
@@ -79,10 +86,23 @@ def __init__(self, x, t0, t1, seed=None, **kwargs):
]

@staticmethod
def sort(a, b):
return (a, b, 1) if a < b else (b, a, -1)
def sort(a: float, b: float) -> Tuple[float, float, float]:
"""
Sorts two float values and returns them along with a sign indicating whether they were swapped.

Args:
a (`float`):
The first value.
b (`float`):
The second value.

def __call__(self, t0, t1):
Returns:
`Tuple[float, float, float]`:
A tuple containing the sorted values (min, max) and a sign (1.0 if a < b, -1.0 otherwise).
"""
return (a, b, 1.0) if a < b else (b, a, -1.0)

def __call__(self, t0: float, t1: float) -> torch.Tensor:
t0, t1, sign = self.sort(t0, t1)
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
return w if self.batched else w[0]
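For intuition, the sign tracks the interval's orientation so that increments sampled over a reversed interval can be flipped back:

```python
BatchedBrownianTree.sort(0.2, 1.0)  # (0.2, 1.0, 1.0)  -- already ordered
BatchedBrownianTree.sort(1.0, 0.2)  # (0.2, 1.0, -1.0) -- swapped, sign flips
```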
@@ -92,23 +112,29 @@ class BrownianTreeNoiseSampler:
"""A noise sampler backed by a torchsde.BrownianTree.

Args:
x (Tensor): The tensor whose shape, device and dtype to use to generate
random samples.
sigma_min (float): The low end of the valid interval.
sigma_max (float): The high end of the valid interval.
seed (int or List[int]): The random seed. If a list of seeds is
x (`torch.Tensor`): The tensor whose shape, device and dtype are used to generate random samples.
sigma_min (`float`): The low end of the valid interval.
sigma_max (`float`): The high end of the valid interval.
seed (`int` or `List[int]`): The random seed. If a list of seeds is
supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each
with its own seed.
transform (callable): A function that maps sigma to the sampler's
transform (`Callable`): A function that maps sigma to the sampler's
internal timestep.
"""

def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
def __init__(
self,
x: torch.Tensor,
sigma_min: float,
sigma_max: float,
seed: Optional[Union[int, List[int]]] = None,
transform: Callable[[float], float] = lambda x: x,
):
self.transform = transform
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
self.tree = BatchedBrownianTree(x, t0, t1, seed)

def __call__(self, sigma, sigma_next):
def __call__(self, sigma: float, sigma_next: float) -> torch.Tensor:
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()

@@ -216,19 +242,28 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.00085, # sensible defaults
beta_end: float = 0.012,
beta_schedule: str = "linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
noise_sampler_seed: Optional[int] = None,
timestep_spacing: str = "linspace",
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
@@ -238,7 +273,15 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
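Note the guard above: at most one of `use_karras_sigmas`, `use_exponential_sigmas`, and `use_beta_sigmas` may be enabled, and beta sigmas additionally require `scipy`. The three `beta_schedule` options differ only in how the training betas are spaced; a values-only sketch (no scheduler state, names are illustrative):

```python
import torch

n, b0, b1 = 1000, 0.00085, 0.012
linear = torch.linspace(b0, b1, n)
scaled_linear = torch.linspace(b0**0.5, b1**0.5, n) ** 2  # linear in sqrt(beta)
# "squaredcos_cap_v2" instead comes from betas_for_alpha_bar(n), the Glide cosine schedule
```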
@@ -305,29 +348,29 @@ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
self._step_index = self._begin_index

@property
def init_noise_sigma(self):
def init_noise_sigma(self) -> torch.Tensor:
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()

return (self.sigmas.max() ** 2 + 1) ** 0.5

@property
def step_index(self):
def step_index(self) -> Union[int, None]:
"""
The index counter for the current timestep. It increases by 1 after each scheduler step.
"""
return self._step_index

@property
def begin_index(self):
def begin_index(self) -> Union[int, None]:
"""
The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
"""
return self._begin_index

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

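A sketch of how a pipeline consumes these properties, assuming the public diffusers API: `set_begin_index` is called up front (e.g. for image-to-image strength), and the step index is only initialized from it on the first `step` call:

```python
from diffusers import DPMSolverSDEScheduler

scheduler = DPMSolverSDEScheduler()
scheduler.set_timesteps(50)
scheduler.set_begin_index(20)          # e.g. img2img that skips 40% of the schedule
assert scheduler.begin_index == 20
assert scheduler.step_index is None    # set lazily on the first step() call
```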
@@ -369,7 +412,7 @@ def set_timesteps(
num_inference_steps: int,
device: Union[str, torch.device] = None,
num_train_timesteps: Optional[int] = None,
):
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).

@@ -378,6 +421,8 @@
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
num_train_timesteps (`int`, *optional*):
The number of training timesteps. If `None`, defaults to `self.config.num_train_timesteps`.
"""
self.num_inference_steps = num_inference_steps

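A call sketch, assuming an already-constructed scheduler: this runs once before the denoising loop, and per the hunk below the sigmas are then kept on the CPU to avoid per-step device synchronization:

```python
scheduler.set_timesteps(num_inference_steps=25, device="cuda")  # scheduler assumed constructed
```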
@@ -443,7 +488,7 @@ def set_timesteps(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.noise_sampler = None

def _second_order_timesteps(self, sigmas, log_sigmas):
def _second_order_timesteps(self, sigmas: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
def sigma_fn(_t):
return np.exp(-_t)

@@ -459,7 +504,7 @@ def t_fn(_sigma):
return timesteps

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
"""
Convert sigma values to corresponding timestep values through interpolation.

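For intuition about these helpers, under this file's convention that `t = -log(sigma)` (so `sigma = exp(-t)`): `_sigma_to_t` recovers a fractional training timestep by interpolating in log-sigma space. A rough nearest-neighbor sketch of the idea, with hypothetical training sigmas:

```python
import numpy as np

log_sigmas = np.log(np.linspace(10.0, 0.1, 1000))         # hypothetical training sigmas
sigma = 1.0
t_nearest = np.abs(log_sigmas - np.log(sigma)).argmin()   # the real helper interpolates
```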
@@ -604,14 +649,14 @@ def _convert_to_beta(
return sigmas

@property
def state_in_first_order(self):
def state_in_first_order(self) -> bool:
return self.sample is None

def step(
self,
model_output: Union[torch.Tensor, np.ndarray],
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: Union[torch.Tensor, np.ndarray],
sample: torch.Tensor,
return_dict: bool = True,
s_noise: float = 1.0,
) -> Union[DPMSolverSDESchedulerOutput, Tuple]:
@@ -620,11 +665,11 @@ def step(
process from the learned model outputs (most often the predicted noise).

Args:
model_output (`torch.Tensor` or `np.ndarray`):
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`float` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor` or `np.ndarray`):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
@@ -643,7 +688,9 @@ def step(
# Create a noise sampler if it hasn't been created yet
if self.noise_sampler is None:
min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()
self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed)
self.noise_sampler = BrownianTreeNoiseSampler(
sample, min_sigma.item(), max_sigma.item(), self.noise_sampler_seed
)

# Define functions to compute sigma and t from each other
def sigma_fn(_t: torch.Tensor) -> torch.Tensor:
@@ -694,7 +741,10 @@ def t_fn(_sigma: torch.Tensor) -> torch.Tensor:

sigma_from = sigma_fn(t)
sigma_to = sigma_fn(t_next)
sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
sigma_up = min(
sigma_to,
(sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
)
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
ancestral_t = t_fn(sigma_down)
prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - (
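The `sigma_up` / `sigma_down` split in the hunk above is the standard ancestral decomposition: the target noise level is divided into a deterministic part and an injected-noise part with `sigma_to**2 == sigma_down**2 + sigma_up**2`. A quick numeric check of that identity, with illustrative values:

```python
sigma_from, sigma_to = 2.0, 1.0
sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
assert abs(sigma_to**2 - (sigma_down**2 + sigma_up**2)) < 1e-12
```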
@@ -771,5 +821,5 @@ def add_noise(
noisy_samples = original_samples + noise * sigma
return noisy_samples

def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps
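Putting the pieces together, a minimal denoising-loop sketch against the public diffusers API; `model` stands in for a real UNet and is hypothetical:

```python
import torch
from diffusers import DPMSolverSDEScheduler

scheduler = DPMSolverSDEScheduler(noise_sampler_seed=0)   # reproducible SDE noise
scheduler.set_timesteps(num_inference_steps=25)
sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    noise_pred = model(model_input, t)                    # hypothetical UNet call
    sample = scheduler.step(noise_pred, t, sample).prev_sample
```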