diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 6f905a623d70..a1d4a997d126 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import List, Literal, Optional, Tuple, Union +from typing import Callable, List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -51,7 +51,14 @@ class DPMSolverSDESchedulerOutput(BaseOutput): class BatchedBrownianTree: """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" - def __init__(self, x, t0, t1, seed=None, **kwargs): + def __init__( + self, + x: torch.Tensor, + t0: float, + t1: float, + seed: Optional[Union[int, List[int]]] = None, + **kwargs, + ): t0, t1, self.sign = self.sort(t0, t1) w0 = kwargs.get("w0", torch.zeros_like(x)) if seed is None: @@ -79,10 +86,23 @@ def __init__(self, x, t0, t1, seed=None, **kwargs): ] @staticmethod - def sort(a, b): - return (a, b, 1) if a < b else (b, a, -1) + def sort(a: float, b: float) -> Tuple[float, float, float]: + """ + Sorts two float values and returns them along with a sign indicating if they were swapped. + + Args: + a (`float`): + The first value. + b (`float`): + The second value. - def __call__(self, t0, t1): + Returns: + `Tuple[float, float, float]`: + A tuple containing the sorted values (min, max) and a sign (1.0 if a < b, -1.0 otherwise). + """ + return (a, b, 1.0) if a < b else (b, a, -1.0) + + def __call__(self, t0: float, t1: float) -> torch.Tensor: t0, t1, sign = self.sort(t0, t1) w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) return w if self.batched else w[0] @@ -92,23 +112,29 @@ class BrownianTreeNoiseSampler: """A noise sampler backed by a torchsde.BrownianTree. Args: - x (Tensor): The tensor whose shape, device and dtype to use to generate - random samples. - sigma_min (float): The low end of the valid interval. - sigma_max (float): The high end of the valid interval. - seed (int or List[int]): The random seed. If a list of seeds is + x (`torch.Tensor`): The tensor whose shape, device and dtype is used to generate random samples. + sigma_min (`float`): The low end of the valid interval. + sigma_max (`float`): The high end of the valid interval. + seed (`int` or `List[int]`): The random seed. If a list of seeds is supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each with its own seed. - transform (callable): A function that maps sigma to the sampler's + transform (`callable`): A function that maps sigma to the sampler's internal timestep. """ - def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + def __init__( + self, + x: torch.Tensor, + sigma_min: float, + sigma_max: float, + seed: Optional[Union[int, List[int]]] = None, + transform: Callable[[float], float] = lambda x: x, + ): self.transform = transform t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) self.tree = BatchedBrownianTree(x, t0, t1, seed) - def __call__(self, sigma, sigma_next): + def __call__(self, sigma: float, sigma_next: float) -> torch.Tensor: t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) return self.tree(t0, t1) / (t1 - t0).abs().sqrt() @@ -216,19 +242,28 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.00085, # sensible defaults beta_end: float = 0.012, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False, noise_sampler_seed: Optional[int] = None, - timestep_spacing: str = "linspace", + timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace", steps_offset: int = 0, ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") - if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + if ( + sum( + [ + self.config.use_beta_sigmas, + self.config.use_exponential_sigmas, + self.config.use_karras_sigmas, + ] + ) + > 1 + ): raise ValueError( "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." ) @@ -238,7 +273,15 @@ def __init__( self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -305,7 +348,7 @@ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None: self._step_index = self._begin_index @property - def init_noise_sigma(self): + def init_noise_sigma(self) -> torch.Tensor: # standard deviation of the initial noise distribution if self.config.timestep_spacing in ["linspace", "trailing"]: return self.sigmas.max() @@ -313,21 +356,21 @@ def init_noise_sigma(self): return (self.sigmas.max() ** 2 + 1) ** 0.5 @property - def step_index(self): + def step_index(self) -> Union[int, None]: """ The index counter for current timestep. It will increase 1 after each scheduler step. """ return self._step_index @property - def begin_index(self): + def begin_index(self) -> Union[int, None]: """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def set_begin_index(self, begin_index: int = 0) -> None: """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. @@ -369,7 +412,7 @@ def set_timesteps( num_inference_steps: int, device: Union[str, torch.device] = None, num_train_timesteps: Optional[int] = None, - ): + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -378,6 +421,8 @@ def set_timesteps( The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + num_train_timesteps (`int`, *optional*): + The number of train timesteps. If `None`, uses `self.config.num_train_timesteps`. """ self.num_inference_steps = num_inference_steps @@ -443,7 +488,7 @@ def set_timesteps( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.noise_sampler = None - def _second_order_timesteps(self, sigmas, log_sigmas): + def _second_order_timesteps(self, sigmas: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray: def sigma_fn(_t): return np.exp(-_t) @@ -459,7 +504,7 @@ def t_fn(_sigma): return timesteps # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): + def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray: """ Convert sigma values to corresponding timestep values through interpolation. @@ -604,14 +649,14 @@ def _convert_to_beta( return sigmas @property - def state_in_first_order(self): + def state_in_first_order(self) -> bool: return self.sample is None def step( self, - model_output: Union[torch.Tensor, np.ndarray], + model_output: torch.Tensor, timestep: Union[float, torch.Tensor], - sample: Union[torch.Tensor, np.ndarray], + sample: torch.Tensor, return_dict: bool = True, s_noise: float = 1.0, ) -> Union[DPMSolverSDESchedulerOutput, Tuple]: @@ -620,11 +665,11 @@ def step( process from the learned model outputs (most often the predicted noise). Args: - model_output (`torch.Tensor` or `np.ndarray`): + model_output (`torch.Tensor`): The direct output from learned diffusion model. timestep (`float` or `torch.Tensor`): The current discrete timestep in the diffusion chain. - sample (`torch.Tensor` or `np.ndarray`): + sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or @@ -643,7 +688,9 @@ def step( # Create a noise sampler if it hasn't been created yet if self.noise_sampler is None: min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max() - self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed) + self.noise_sampler = BrownianTreeNoiseSampler( + sample, min_sigma.item(), max_sigma.item(), self.noise_sampler_seed + ) # Define functions to compute sigma and t from each other def sigma_fn(_t: torch.Tensor) -> torch.Tensor: @@ -694,7 +741,10 @@ def t_fn(_sigma: torch.Tensor) -> torch.Tensor: sigma_from = sigma_fn(t) sigma_to = sigma_fn(t_next) - sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5) + sigma_up = min( + sigma_to, + (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, + ) sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 ancestral_t = t_fn(sigma_down) prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - ( @@ -771,5 +821,5 @@ def add_noise( noisy_samples = original_samples + noise * sigma return noisy_samples - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps