Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,10 +253,6 @@ def load_model_dict_into_meta(
param = param.to(dtype)
set_module_kwargs["dtype"] = dtype

if is_accelerate_version(">", "1.8.1"):
set_module_kwargs["non_blocking"] = True
set_module_kwargs["clear_cache"] = False

# For compatibility with PyTorch's `load_state_dict`, which casts each state-dict tensor to the dtype of the existing
# model parameter and uses `param.copy_(input_param)`, thereby preserving the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
Expand All @@ -277,6 +273,17 @@ def load_model_dict_into_meta(

param_device = _determine_param_device(param_name, device_map)

if is_accelerate_version(">", "1.8.1"):
# MPS does not support truly asynchronous non-blocking transfers from CPU.
# When non_blocking=True the source tensor may be freed or recycled (especially
# with mmap-backed safetensors) before the MPS copy completes, silently corrupting
# the destination weights. Force synchronous copies on MPS to avoid this.
is_mps_target = str(param_device) == "mps" or (
isinstance(param_device, torch.device) and param_device.type == "mps"
)
set_module_kwargs["non_blocking"] = not is_mps_target
set_module_kwargs["clear_cache"] = False

# bnb params are flattened.
# gguf quants have a different shape based on the type of quantization applied
if empty_state_dict[param_name].shape != param.shape:
Expand Down