From 9da9b68cd2cf7afd0080003c736c4ae3c1741467 Mon Sep 17 00:00:00 2001 From: Prakhar Agarwal Date: Sun, 22 Mar 2026 02:20:34 -0700 Subject: [PATCH] fix: disable non-blocking tensor copies to MPS during model loading When loading model weights with `device_map="mps"`, `load_model_dict_into_meta` unconditionally passes `non_blocking=True` to `set_module_tensor_to_device` (accelerate > 1.8.1). With mmap-backed safetensors the source CPU memory can be released before the asynchronous MPS copy completes, silently corrupting the destination weights. The corruption is non-deterministic and dtype-dependent (float32 corrupts weights but not biases; float16 corrupts biases but not weights), producing extreme values (~1e37), LayerNorm overflow, and NaN outputs. Move the `non_blocking` / `clear_cache` assignment after `param_device` is resolved, and force `non_blocking=False` when the target is MPS. Fixes https://github.com/huggingface/diffusers/issues/13227 Made-with: Cursor --- src/diffusers/models/model_loading_utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 04642ad5d401..a62b173f1bd1 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -253,10 +253,6 @@ def load_model_dict_into_meta( param = param.to(dtype) set_module_kwargs["dtype"] = dtype - if is_accelerate_version(">", "1.8.1"): - set_module_kwargs["non_blocking"] = True - set_module_kwargs["clear_cache"] = False - # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 @@ -277,6 +273,17 @@ def load_model_dict_into_meta( param_device = _determine_param_device(param_name, device_map) + if is_accelerate_version(">", "1.8.1"): + # MPS does not support truly asynchronous non-blocking transfers from CPU. + # When non_blocking=True the source tensor may be freed or recycled (especially + # with mmap-backed safetensors) before the MPS copy completes, silently corrupting + # the destination weights. Force synchronous copies on MPS to avoid this. + is_mps_target = str(param_device) == "mps" or ( + isinstance(param_device, torch.device) and param_device.type == "mps" + ) + set_module_kwargs["non_blocking"] = not is_mps_target + set_module_kwargs["clear_cache"] = False + # bnb params are flattened. # gguf quants have a different shape based on the type of quantization applied if empty_state_dict[param_name].shape != param.shape: