diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 04642ad5d401..a62b173f1bd1 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -253,10 +253,6 @@ def load_model_dict_into_meta( param = param.to(dtype) set_module_kwargs["dtype"] = dtype - if is_accelerate_version(">", "1.8.1"): - set_module_kwargs["non_blocking"] = True - set_module_kwargs["clear_cache"] = False - # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 @@ -277,6 +273,17 @@ def load_model_dict_into_meta( param_device = _determine_param_device(param_name, device_map) + if is_accelerate_version(">", "1.8.1"): + # MPS does not support truly asynchronous non-blocking transfers from CPU. + # When non_blocking=True the source tensor may be freed or recycled (especially + # with mmap-backed safetensors) before the MPS copy completes, silently corrupting + # the destination weights. Force synchronous copies on MPS to avoid this. + is_mps_target = str(param_device) == "mps" or ( + isinstance(param_device, torch.device) and param_device.type == "mps" + ) + set_module_kwargs["non_blocking"] = not is_mps_target + set_module_kwargs["clear_cache"] = False + # bnb params are flattened. # gguf quants have a different shape based on the type of quantization applied if empty_state_dict[param_name].shape != param.shape: