Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 160 additions & 2 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2331,6 +2331,18 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
temp_state_dict[new_key] = v
original_state_dict = temp_state_dict

# Bake alpha/rank scaling into lora_A weights so .alpha keys are consumed.
# Matches the pattern used by _convert_kohya_flux_lora_to_diffusers for Flux1.
alpha_keys = [k for k in original_state_dict if k.endswith(".alpha")]
for alpha_key in alpha_keys:
alpha = original_state_dict.pop(alpha_key).item()
module_path = alpha_key[: -len(".alpha")]
lora_a_key = f"{module_path}.lora_A.weight"
if lora_a_key in original_state_dict:
rank = original_state_dict[lora_a_key].shape[0]
scale = alpha / rank
original_state_dict[lora_a_key] = original_state_dict[lora_a_key] * scale
Comment on lines +2334 to +2344
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you rebase with upstream, please? I think this was fixed by #13313.


num_double_layers = 0
num_single_layers = 0
for key in original_state_dict.keys():
Expand Down Expand Up @@ -2443,6 +2455,152 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
return converted_state_dict


def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict):
"""Convert non-diffusers Flux2 LoKR state dict (kohya/LyCORIS format) to peft-compatible diffusers format.
Uses fuse-first QKV mapping: BFL fused `img_attn.qkv` maps to diffusers `attn.to_qkv` (created by
`fuse_projections()`), avoiding lossy Kronecker factor splitting. The caller must fuse the model's
QKV projections before injecting the adapter.
Comment on lines +2461 to +2463
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot assume that the user will always fuse_projections(). So, let's split the fused QKV as done in the other conversion workflows.

"""
converted_state_dict = {}

prefix = "diffusion_model."
original_state_dict = {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()}

num_double_layers = 0
num_single_layers = 0
for key in original_state_dict:
if key.startswith("single_blocks."):
num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1)
elif key.startswith("double_blocks."):
num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1)

lokr_suffixes = ("lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2")

def _remap_lokr_module(bfl_path, diff_path):
"""Pop all lokr keys for a BFL module, bake alpha scaling, and store under diffusers path."""
alpha_key = f"{bfl_path}.alpha"
alpha = original_state_dict.pop(alpha_key).item() if alpha_key in original_state_dict else None

for suffix in lokr_suffixes:
src_key = f"{bfl_path}.{suffix}"
if src_key not in original_state_dict:
continue

weight = original_state_dict.pop(src_key)

# Bake alpha/rank scaling into the first w1 tensor encountered for this module.
# After baking, peft's config uses alpha=r so its runtime scaling is 1.0.
if alpha is not None and suffix in ("lokr_w1", "lokr_w1_a"):
w2a_key = f"{bfl_path}.lokr_w2_a"
w1a_key = f"{bfl_path}.lokr_w1_a"
if w2a_key in original_state_dict:
r_eff = original_state_dict[w2a_key].shape[1]
elif w1a_key in original_state_dict:
r_eff = original_state_dict[w1a_key].shape[1]
else:
r_eff = alpha
scale = alpha / r_eff
weight = weight * scale
alpha = None # only bake once per module

converted_state_dict[f"{diff_path}.{suffix}"] = weight

# --- Single blocks ---
for sl in range(num_single_layers):
_remap_lokr_module(f"single_blocks.{sl}.linear1", f"single_transformer_blocks.{sl}.attn.to_qkv_mlp_proj")
_remap_lokr_module(f"single_blocks.{sl}.linear2", f"single_transformer_blocks.{sl}.attn.to_out")

# --- Double blocks ---
for dl in range(num_double_layers):
tb = f"transformer_blocks.{dl}"
db = f"double_blocks.{dl}"

# QKV → fused to_qkv / to_added_qkv (model must be fused before injection)
_remap_lokr_module(f"{db}.img_attn.qkv", f"{tb}.attn.to_qkv")
_remap_lokr_module(f"{db}.txt_attn.qkv", f"{tb}.attn.to_added_qkv")

# Projections
_remap_lokr_module(f"{db}.img_attn.proj", f"{tb}.attn.to_out.0")
_remap_lokr_module(f"{db}.txt_attn.proj", f"{tb}.attn.to_add_out")

# MLPs
_remap_lokr_module(f"{db}.img_mlp.0", f"{tb}.ff.linear_in")
_remap_lokr_module(f"{db}.img_mlp.2", f"{tb}.ff.linear_out")
_remap_lokr_module(f"{db}.txt_mlp.0", f"{tb}.ff_context.linear_in")
_remap_lokr_module(f"{db}.txt_mlp.2", f"{tb}.ff_context.linear_out")

# --- Extra mappings (embedders, modulation, final layer) ---
extra_mappings = {
"img_in": "x_embedder",
"txt_in": "context_embedder",
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
"final_layer.linear": "proj_out",
"final_layer.adaLN_modulation.1": "norm_out.linear",
"single_stream_modulation.lin": "single_stream_modulation.linear",
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
}
for bfl_key, diff_key in extra_mappings.items():
_remap_lokr_module(bfl_key, diff_key)

if len(original_state_dict) > 0:
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")

for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

return converted_state_dict


def _refuse_flux2_lora_state_dict(state_dict):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will also change given the above comment.

"""Re-fuse separate Q/K/V LoRA keys into fused to_qkv/to_added_qkv keys.
When the model's QKV projections are fused, incoming LoRA keys targeting separate
to_q/to_k/to_v must be re-fused to match. Inverse of the QKV split performed by
``_convert_non_diffusers_flux2_lora_to_diffusers``.
"""
converted = {}
remaining = dict(state_dict)

# Detect double block indices from keys
num_double_layers = 0
for key in remaining:
if ".transformer_blocks." in key:
parts = key.split(".")
idx = parts.index("transformer_blocks") + 1
if idx < len(parts) and parts[idx].isdigit():
num_double_layers = max(num_double_layers, int(parts[idx]) + 1)

# Fuse Q/K/V for image stream and text stream
qkv_groups = [
(["to_q", "to_k", "to_v"], "to_qkv"),
(["add_q_proj", "add_k_proj", "add_v_proj"], "to_added_qkv"),
]

for dl in range(num_double_layers):
attn_prefix = f"transformer.transformer_blocks.{dl}.attn"
for separate_keys, fused_name in qkv_groups:
for lora_key in ("lora_A", "lora_B"):
src_keys = [f"{attn_prefix}.{sk}.{lora_key}.weight" for sk in separate_keys]
if not all(k in remaining for k in src_keys):
continue

weights = [remaining.pop(k) for k in src_keys]
dst_key = f"{attn_prefix}.{fused_name}.{lora_key}.weight"
if lora_key == "lora_A":
# lora_A was replicated during split - all three are identical, take the first
converted[dst_key] = weights[0]
else:
# lora_B was chunked along dim=0 - concatenate back
converted[dst_key] = torch.cat(weights, dim=0)

# Pass through all non-QKV keys unchanged
converted.update(remaining)
return converted


def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
"""
Convert non-diffusers ZImage LoRA state dict to diffusers format.
Expand Down Expand Up @@ -2600,14 +2758,14 @@ def get_alpha_scales(down_weight, alpha_key):

base = k[: -len(lora_dot_down_key)]

# Skip combined "qkv" projection individual to.q/k/v keys are also present.
# Skip combined "qkv" projection - individual to.q/k/v keys are also present.
if base.endswith(".qkv"):
state_dict.pop(k)
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
state_dict.pop(base + ".alpha", None)
continue

# Skip bare "out.lora.*" "to_out.0.lora.*" covers the same projection.
# Skip bare "out.lora.*" - "to_out.0.lora.*" covers the same projection.
if re.search(r"\.out$", base) and ".to_out" not in base:
state_dict.pop(k)
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
Expand Down
31 changes: 26 additions & 5 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
_convert_non_diffusers_flux2_lokr_to_diffusers,
_convert_non_diffusers_flux2_lora_to_diffusers,
_convert_non_diffusers_hidream_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
Expand All @@ -56,6 +57,7 @@
_convert_non_diffusers_z_image_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
_refuse_flux2_lora_state_dict,
)


Expand Down Expand Up @@ -5679,12 +5681,18 @@ def lora_state_dict(

is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
if is_ai_toolkit:
state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
is_lokr = any("lokr_" in k for k in state_dict)
if is_lokr:
state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict)
if metadata is None:
metadata = {}
metadata["is_lokr"] = "true"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could append to metadata as:

metadata["is_lokr"] = is_lokr

We could make it a bit more general (e.g., adapter type) if needed in a future PR. Cc: @BenjaminBossan

else:
state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)

out = (state_dict, metadata) if return_lora_metadata else state_dict
return out

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
Expand Down Expand Up @@ -5712,13 +5720,26 @@ def load_lora_weights(
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

is_correct_format = all("lora" in key for key in state_dict.keys())
is_correct_format = all("lora" in key or "lokr" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA/LoKR checkpoint. Make sure all param names contain `'lora'` or `'lokr'`.")

# For LoKR adapters, fuse QKV projections so peft can target the fused modules directly.
is_lokr = metadata is not None and metadata.get("is_lokr") == "true"
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
if is_lokr:
transformer.fuse_qkv_projections()
elif (
hasattr(transformer, "transformer_blocks")
and len(transformer.transformer_blocks) > 0
and getattr(transformer.transformer_blocks[0].attn, "fused_projections", False)
):
# Model QKV is fused but LoRA targets separate Q/K/V - re-fuse the keys to match.
state_dict = _refuse_flux2_lora_state_dict(state_dict)
Comment on lines +5730 to +5738
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not assume it as mentioned previously. Doing this will also simplify the code (getting rid of the conditionals, for example).


self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
transformer=transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
Expand Down
111 changes: 60 additions & 51 deletions src/diffusers/loaders/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@
set_adapter_layers,
set_weights_and_activate_adapters,
)
from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys
from ..utils.peft_utils import _create_lokr_config, _create_lora_config, _maybe_warn_for_unhandled_keys
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
from .unet_loader_utils import _maybe_expand_lora_scales


logger = logging.get_logger(__name__)

_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict(
lambda: (lambda model_cls, weights: weights),
lambda: lambda model_cls, weights: weights,
{
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
Expand Down Expand Up @@ -213,56 +213,65 @@ def load_lora_adapter(
"Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
)

# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)

# Control LoRA from SAI is different from BFL Control LoRA
# https://huggingface.co/stabilityai/control-lora
# https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors
is_sai_sd_control_lora = "lora_controlnet" in state_dict
if is_sai_sd_control_lora:
state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)

rank = {}
for key, val in state_dict.items():
# Cannot figure out rank from lora layers that don't have at least 2 dimensions.
# Bias layers in LoRA only have a single dimension
if "lora_B" in key and val.ndim > 1:
# Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
# We may run into some ambiguous configuration values when a model has module
# names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
# for example) and they have different LoRA ranks.
rank[f"^{key}"] = val.shape[1]

if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}

# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)

# create LoraConfig
lora_config = _create_lora_config(
state_dict,
network_alphas,
metadata,
rank,
model_state_dict=self.state_dict(),
adapter_name=adapter_name,
)
# Detect whether this is a LoKR adapter (Kronecker product, not low-rank)
is_lokr = any("lokr_" in k for k in state_dict)

if is_lokr:
if adapter_name is None:
adapter_name = get_adapter_name(self)
lora_config = _create_lokr_config(state_dict)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's call it adapter_config to be a bit more idiomatic.

is_sai_sd_control_lora = False
else:
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)

# Control LoRA from SAI is different from BFL Control LoRA
# https://huggingface.co/stabilityai/control-lora
# https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors
is_sai_sd_control_lora = "lora_controlnet" in state_dict
if is_sai_sd_control_lora:
state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)

rank = {}
for key, val in state_dict.items():
# Cannot figure out rank from lora layers that don't have at least 2 dimensions.
# Bias layers in LoRA only have a single dimension
if "lora_B" in key and val.ndim > 1:
# Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
# We may run into some ambiguous configuration values when a model has module
# names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
# for example) and they have different LoRA ranks.
rank[f"^{key}"] = val.shape[1]

if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}

# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)

# create LoraConfig
lora_config = _create_lora_config(
state_dict,
network_alphas,
metadata,
rank,
model_state_dict=self.state_dict(),
adapter_name=adapter_name,
)

# Adjust LoRA config for Control LoRA
if is_sai_sd_control_lora:
lora_config.lora_alpha = lora_config.r
lora_config.alpha_pattern = lora_config.rank_pattern
lora_config.bias = "all"
lora_config.modules_to_save = lora_config.exclude_modules
lora_config.exclude_modules = None
# Adjust LoRA config for Control LoRA
if is_sai_sd_control_lora:
lora_config.lora_alpha = lora_config.r
lora_config.alpha_pattern = lora_config.rank_pattern
lora_config.bias = "all"
lora_config.modules_to_save = lora_config.exclude_modules
lora_config.exclude_modules = None

# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
Expand Down
Loading