diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index deae25899475..b533bef35414 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -855,10 +855,11 @@ def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float | # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" is_npu = sample.device.type == "npu" + is_neuron = sample.device.type == "neuron" if isinstance(timestep, float): - dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + dtype = torch.float32 if (is_mps or is_npu or is_neuron) else torch.float64 else: - dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + dtype = torch.int32 if (is_mps or is_npu or is_neuron) else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d675f1de04a7..5b329f46e2aa 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -22,7 +22,7 @@ import types from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, List, Union, get_args, get_origin, get_type_hints +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin, get_type_hints import httpx import numpy as np @@ -68,6 +68,7 @@ is_transformers_version, logging, numpy_to_pil, + requires_backends, ) from ..utils.distributed_utils import is_torch_dist_rank_zero from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card @@ -2248,6 +2249,63 @@ def _is_pipeline_device_mapped(self): return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1 + def enable_neuron_compile( + self, + model_names: Optional[List[str]] = None, + cache_dir: Optional[str] = None, + fullgraph: bool = True, + ) -> None: + """ + Compiles the pipeline's nn.Module components with ``torch.compile(backend="neuron")``, + enabling whole-graph NEFF compilation for AWS Trainium/Inferentia. + + The first forward call per component triggers neuronx-cc compilation (slow). + Use ``neuron_warmup()`` to trigger this explicitly before timed inference. + + Args: + model_names (`List[str]`, *optional*): + Component names to compile. Defaults to all nn.Module components. + cache_dir (`str`, *optional*): + Path to persist compiled NEFFs across runs via ``TORCH_NEURONX_NEFF_CACHE_DIR``. + Skips recompilation on subsequent runs. + fullgraph (`bool`, defaults to `True`): + Disallow graph breaks (required for full-graph fusion). + """ + requires_backends(self, "torch_neuronx") + import torch_neuronx # noqa: F401 — registers neuron backend + from torch_neuronx.neuron_dynamo_backend import set_model_name + + if cache_dir is not None: + os.environ["TORCH_NEURONX_NEFF_CACHE_DIR"] = cache_dir + + if model_names is None: + model_names = [ + name for name, comp in self.components.items() if isinstance(comp, torch.nn.Module) + ] + + for name in model_names: + component = getattr(self, name, None) + if isinstance(component, torch.nn.Module) and not is_compiled_module(component): + logger.info(f"Compiling {name} with backend='neuron'") + set_model_name(name) + setattr(self, name, torch.compile(component, backend="neuron", fullgraph=fullgraph)) + + def neuron_warmup(self, *args, **kwargs) -> None: + """ + Runs a single dummy forward pass through the pipeline to trigger neuronx-cc + compilation for all components (static-shape NEFF compilation). + + This is equivalent to calling ``__call__`` with the same shapes but discards + the output. After warmup, subsequent calls reuse the compiled NEFFs and run fast. + + Pass the same arguments you would use for real inference (height, width, + num_inference_steps, batch_size, etc.) so that the compiled shapes match. + """ + logger.info("Running Neuron warmup forward pass to trigger NEFF compilation...") + with torch.no_grad(): + self(*args, **kwargs) + logger.info("Neuron warmup complete.") + class StableDiffusionMixin: r""" diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 2f6b105702e8..fdda2547f09e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1092,7 +1092,11 @@ def __call__( ) # 4. Prepare timesteps - if XLA_AVAILABLE: + # Keep timesteps on CPU for XLA (TPU) and Neuron: both use lazy/XLA execution where + # dynamic-shape ops like .nonzero() and .item() inside scheduler.index_for_timestep() + # are incompatible with static-graph compilation. + is_neuron_device = hasattr(device, "type") and device.type == "neuron" + if XLA_AVAILABLE or is_neuron_device: timestep_device = "cpu" else: timestep_device = device @@ -1195,15 +1199,23 @@ def __call__( # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # For Neuron: scale_model_input on CPU to avoid XLA ops outside the compiled UNet region. + # index_for_timestep() uses .nonzero()/.item() which are incompatible with static graphs. + if is_neuron_device: + latent_model_input = self.scheduler.scale_model_input(latent_model_input.to("cpu"), t).to(device) + else: + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: added_cond_kwargs["image_embeds"] = image_embeds + # For Neuron: pre-cast timestep to float32 on device. Neuron XLA does not support + # int64 ops; the compiled UNet graph requires a float32 timestep input on-device. + t_unet = t.to(torch.float32).to(device) if is_neuron_device else t noise_pred = self.unet( latent_model_input, - t, + t_unet, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, @@ -1222,7 +1234,13 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + # For Neuron: scheduler.step on CPU to keep scheduler arithmetic off the XLA device. + if is_neuron_device: + latents = self.scheduler.step( + noise_pred.to("cpu"), t, latents.to("cpu"), **extra_step_kwargs, return_dict=False + )[0].to(device) + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 23d7ac7c6c2d..8a86cf4f4151 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -110,6 +110,7 @@ is_timm_available, is_torch_available, is_torch_mlu_available, + is_torch_neuronx_available, is_torch_npu_available, is_torch_version, is_torch_xla_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 551fa358a28d..2ce989626b3d 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -193,6 +193,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> tuple[b _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla") _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu") _torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu") +_torch_neuronx_available, _torch_neuronx_version = _is_package_available("torch_neuronx") _transformers_available, _transformers_version = _is_package_available("transformers") _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub") _kernels_available, _kernels_version = _is_package_available("kernels") @@ -249,6 +250,10 @@ def is_torch_mlu_available(): return _torch_mlu_available +def is_torch_neuronx_available(): + return _torch_neuronx_available + + def is_flax_available(): return _flax_available @@ -579,6 +584,10 @@ def is_av_available(): """ +TORCH_NEURONX_IMPORT_ERROR = """ +{0} requires the torch_neuronx library (AWS Neuron SDK) but it was not found in your environment. Please install it following the AWS Neuron documentation: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/ +""" + BACKENDS_MAPPING = OrderedDict( [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), @@ -609,6 +618,7 @@ def is_av_available(): ("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)), ("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)), ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)), + ("torch_neuronx", (is_torch_neuronx_available, TORCH_NEURONX_IMPORT_ERROR)), ] ) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 8a48316bf3dd..55fee1d3249e 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -22,7 +22,13 @@ from typing import Callable, ParamSpec, TypeVar from . import logging -from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version +from .import_utils import ( + is_torch_available, + is_torch_mlu_available, + is_torch_neuronx_available, + is_torch_npu_available, + is_torch_version, +) T = TypeVar("T") @@ -33,12 +39,13 @@ import torch from torch.fft import fftn, fftshift, ifftn, ifftshift - BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True} + BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "neuron": False, "default": True} BACKEND_EMPTY_CACHE = { "cuda": torch.cuda.empty_cache, "xpu": torch.xpu.empty_cache, "cpu": None, "mps": torch.mps.empty_cache, + "neuron": None, "default": None, } BACKEND_DEVICE_COUNT = { @@ -46,6 +53,7 @@ "xpu": torch.xpu.device_count, "cpu": lambda: 0, "mps": lambda: 0, + "neuron": lambda: getattr(getattr(torch, "neuron", None), "device_count", lambda: 0)(), "default": 0, } BACKEND_MANUAL_SEED = { @@ -53,6 +61,7 @@ "xpu": torch.xpu.manual_seed, "cpu": torch.manual_seed, "mps": torch.mps.manual_seed, + "neuron": torch.manual_seed, "default": torch.manual_seed, } BACKEND_RESET_PEAK_MEMORY_STATS = { @@ -60,6 +69,7 @@ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, + "neuron": None, "default": None, } BACKEND_RESET_MAX_MEMORY_ALLOCATED = { @@ -67,6 +77,7 @@ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None), "cpu": None, "mps": None, + "neuron": None, "default": None, } BACKEND_MAX_MEMORY_ALLOCATED = { @@ -74,6 +85,7 @@ "xpu": getattr(torch.xpu, "max_memory_allocated", None), "cpu": 0, "mps": 0, + "neuron": 0, "default": 0, } BACKEND_SYNCHRONIZE = { @@ -81,6 +93,7 @@ "xpu": getattr(torch.xpu, "synchronize", None), "cpu": None, "mps": None, + "neuron": lambda: getattr(getattr(torch, "neuron", None), "synchronize", lambda: None)(), "default": None, } logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -169,11 +182,15 @@ def randn_tensor( layout = layout or torch.strided device = device or torch.device("cpu") + # Neuron (XLA) does not support creating random tensors directly on device; always use CPU + if device.type == "neuron": + rand_device = torch.device("cpu") + if generator is not None: gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type if gen_device_type != device.type and gen_device_type == "cpu": rand_device = "cpu" - if device != "mps": + if device.type not in ("mps", "neuron"): logger.info( f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" @@ -294,6 +311,8 @@ def get_device(): return "mps" elif is_torch_mlu_available(): return "mlu" + elif is_torch_neuronx_available() and hasattr(torch, "neuron") and torch.neuron.is_available(): + return "neuron" else: return "cpu"