From c41b7b5d9fcf33ba215e05a4e3f1f4fc06860794 Mon Sep 17 00:00:00 2001 From: CalamitousFelicitousness Date: Tue, 24 Mar 2026 00:16:37 +0000 Subject: [PATCH] Add Flux2 LoKR adapter support with dual conversion paths MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Custom lossless path: BFL LoKR keys → peft LoKrConfig (fuse-first QKV) - Generic lossy path: optional SVD conversion via peft.convert_to_lora - Fix alpha handling for lora_down/lora_up format checkpoints - Re-fuse LoRA keys when model QKV is fused from prior LoKR load --- .../loaders/lora_conversion_utils.py | 162 +++++++++++++++++- src/diffusers/loaders/lora_pipeline.py | 31 +++- src/diffusers/loaders/peft.py | 111 ++++++------ src/diffusers/utils/peft_utils.py | 115 +++++++++++++ 4 files changed, 361 insertions(+), 58 deletions(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 298aa61d37ed..6f33b8219815 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2331,6 +2331,18 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict): temp_state_dict[new_key] = v original_state_dict = temp_state_dict + # Bake alpha/rank scaling into lora_A weights so .alpha keys are consumed. + # Matches the pattern used by _convert_kohya_flux_lora_to_diffusers for Flux1. + alpha_keys = [k for k in original_state_dict if k.endswith(".alpha")] + for alpha_key in alpha_keys: + alpha = original_state_dict.pop(alpha_key).item() + module_path = alpha_key[: -len(".alpha")] + lora_a_key = f"{module_path}.lora_A.weight" + if lora_a_key in original_state_dict: + rank = original_state_dict[lora_a_key].shape[0] + scale = alpha / rank + original_state_dict[lora_a_key] = original_state_dict[lora_a_key] * scale + num_double_layers = 0 num_single_layers = 0 for key in original_state_dict.keys(): @@ -2443,6 +2455,152 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict): return converted_state_dict +def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict): + """Convert non-diffusers Flux2 LoKR state dict (kohya/LyCORIS format) to peft-compatible diffusers format. + + Uses fuse-first QKV mapping: BFL fused `img_attn.qkv` maps to diffusers `attn.to_qkv` (created by + `fuse_projections()`), avoiding lossy Kronecker factor splitting. The caller must fuse the model's + QKV projections before injecting the adapter. + """ + converted_state_dict = {} + + prefix = "diffusion_model." + original_state_dict = {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()} + + num_double_layers = 0 + num_single_layers = 0 + for key in original_state_dict: + if key.startswith("single_blocks."): + num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1) + elif key.startswith("double_blocks."): + num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1) + + lokr_suffixes = ("lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2") + + def _remap_lokr_module(bfl_path, diff_path): + """Pop all lokr keys for a BFL module, bake alpha scaling, and store under diffusers path.""" + alpha_key = f"{bfl_path}.alpha" + alpha = original_state_dict.pop(alpha_key).item() if alpha_key in original_state_dict else None + + for suffix in lokr_suffixes: + src_key = f"{bfl_path}.{suffix}" + if src_key not in original_state_dict: + continue + + weight = original_state_dict.pop(src_key) + + # Bake alpha/rank scaling into the first w1 tensor encountered for this module. + # After baking, peft's config uses alpha=r so its runtime scaling is 1.0. + if alpha is not None and suffix in ("lokr_w1", "lokr_w1_a"): + w2a_key = f"{bfl_path}.lokr_w2_a" + w1a_key = f"{bfl_path}.lokr_w1_a" + if w2a_key in original_state_dict: + r_eff = original_state_dict[w2a_key].shape[1] + elif w1a_key in original_state_dict: + r_eff = original_state_dict[w1a_key].shape[1] + else: + r_eff = alpha + scale = alpha / r_eff + weight = weight * scale + alpha = None # only bake once per module + + converted_state_dict[f"{diff_path}.{suffix}"] = weight + + # --- Single blocks --- + for sl in range(num_single_layers): + _remap_lokr_module(f"single_blocks.{sl}.linear1", f"single_transformer_blocks.{sl}.attn.to_qkv_mlp_proj") + _remap_lokr_module(f"single_blocks.{sl}.linear2", f"single_transformer_blocks.{sl}.attn.to_out") + + # --- Double blocks --- + for dl in range(num_double_layers): + tb = f"transformer_blocks.{dl}" + db = f"double_blocks.{dl}" + + # QKV → fused to_qkv / to_added_qkv (model must be fused before injection) + _remap_lokr_module(f"{db}.img_attn.qkv", f"{tb}.attn.to_qkv") + _remap_lokr_module(f"{db}.txt_attn.qkv", f"{tb}.attn.to_added_qkv") + + # Projections + _remap_lokr_module(f"{db}.img_attn.proj", f"{tb}.attn.to_out.0") + _remap_lokr_module(f"{db}.txt_attn.proj", f"{tb}.attn.to_add_out") + + # MLPs + _remap_lokr_module(f"{db}.img_mlp.0", f"{tb}.ff.linear_in") + _remap_lokr_module(f"{db}.img_mlp.2", f"{tb}.ff.linear_out") + _remap_lokr_module(f"{db}.txt_mlp.0", f"{tb}.ff_context.linear_in") + _remap_lokr_module(f"{db}.txt_mlp.2", f"{tb}.ff_context.linear_out") + + # --- Extra mappings (embedders, modulation, final layer) --- + extra_mappings = { + "img_in": "x_embedder", + "txt_in": "context_embedder", + "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1", + "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2", + "final_layer.linear": "proj_out", + "final_layer.adaLN_modulation.1": "norm_out.linear", + "single_stream_modulation.lin": "single_stream_modulation.linear", + "double_stream_modulation_img.lin": "double_stream_modulation_img.linear", + "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear", + } + for bfl_key, diff_key in extra_mappings.items(): + _remap_lokr_module(bfl_key, diff_key) + + if len(original_state_dict) > 0: + raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.") + + for key in list(converted_state_dict.keys()): + converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) + + return converted_state_dict + + +def _refuse_flux2_lora_state_dict(state_dict): + """Re-fuse separate Q/K/V LoRA keys into fused to_qkv/to_added_qkv keys. + + When the model's QKV projections are fused, incoming LoRA keys targeting separate + to_q/to_k/to_v must be re-fused to match. Inverse of the QKV split performed by + ``_convert_non_diffusers_flux2_lora_to_diffusers``. + """ + converted = {} + remaining = dict(state_dict) + + # Detect double block indices from keys + num_double_layers = 0 + for key in remaining: + if ".transformer_blocks." in key: + parts = key.split(".") + idx = parts.index("transformer_blocks") + 1 + if idx < len(parts) and parts[idx].isdigit(): + num_double_layers = max(num_double_layers, int(parts[idx]) + 1) + + # Fuse Q/K/V for image stream and text stream + qkv_groups = [ + (["to_q", "to_k", "to_v"], "to_qkv"), + (["add_q_proj", "add_k_proj", "add_v_proj"], "to_added_qkv"), + ] + + for dl in range(num_double_layers): + attn_prefix = f"transformer.transformer_blocks.{dl}.attn" + for separate_keys, fused_name in qkv_groups: + for lora_key in ("lora_A", "lora_B"): + src_keys = [f"{attn_prefix}.{sk}.{lora_key}.weight" for sk in separate_keys] + if not all(k in remaining for k in src_keys): + continue + + weights = [remaining.pop(k) for k in src_keys] + dst_key = f"{attn_prefix}.{fused_name}.{lora_key}.weight" + if lora_key == "lora_A": + # lora_A was replicated during split - all three are identical, take the first + converted[dst_key] = weights[0] + else: + # lora_B was chunked along dim=0 - concatenate back + converted[dst_key] = torch.cat(weights, dim=0) + + # Pass through all non-QKV keys unchanged + converted.update(remaining) + return converted + + def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict): """ Convert non-diffusers ZImage LoRA state dict to diffusers format. @@ -2600,14 +2758,14 @@ def get_alpha_scales(down_weight, alpha_key): base = k[: -len(lora_dot_down_key)] - # Skip combined "qkv" projection — individual to.q/k/v keys are also present. + # Skip combined "qkv" projection - individual to.q/k/v keys are also present. if base.endswith(".qkv"): state_dict.pop(k) state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None) state_dict.pop(base + ".alpha", None) continue - # Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection. + # Skip bare "out.lora.*" - "to_out.0.lora.*" covers the same projection. if re.search(r"\.out$", base) and ".to_out" not in base: state_dict.pop(k) state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 5d10f596f2e6..8be7420dbc54 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -45,6 +45,7 @@ _convert_hunyuan_video_lora_to_diffusers, _convert_kohya_flux_lora_to_diffusers, _convert_musubi_wan_lora_to_diffusers, + _convert_non_diffusers_flux2_lokr_to_diffusers, _convert_non_diffusers_flux2_lora_to_diffusers, _convert_non_diffusers_hidream_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, @@ -56,6 +57,7 @@ _convert_non_diffusers_z_image_lora_to_diffusers, _convert_xlabs_flux_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers, + _refuse_flux2_lora_state_dict, ) @@ -5679,12 +5681,18 @@ def lora_state_dict( is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict) if is_ai_toolkit: - state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict) + is_lokr = any("lokr_" in k for k in state_dict) + if is_lokr: + state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict) + if metadata is None: + metadata = {} + metadata["is_lokr"] = "true" + else: + state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict) out = (state_dict, metadata) if return_lora_metadata else state_dict return out - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -5712,13 +5720,26 @@ def load_lora_weights( kwargs["return_lora_metadata"] = True state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key for key in state_dict.keys()) + is_correct_format = all("lora" in key or "lokr" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") + raise ValueError("Invalid LoRA/LoKR checkpoint. Make sure all param names contain `'lora'` or `'lokr'`.") + + # For LoKR adapters, fuse QKV projections so peft can target the fused modules directly. + is_lokr = metadata is not None and metadata.get("is_lokr") == "true" + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + if is_lokr: + transformer.fuse_qkv_projections() + elif ( + hasattr(transformer, "transformer_blocks") + and len(transformer.transformer_blocks) > 0 + and getattr(transformer.transformer_blocks[0].attn, "fused_projections", False) + ): + # Model QKV is fused but LoRA targets separate Q/K/V - re-fuse the keys to match. + state_dict = _refuse_flux2_lora_state_dict(state_dict) self.load_lora_into_transformer( state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + transformer=transformer, adapter_name=adapter_name, metadata=metadata, _pipeline=self, diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index daa078bc25d5..1ec304f24944 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -38,7 +38,7 @@ set_adapter_layers, set_weights_and_activate_adapters, ) -from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys +from ..utils.peft_utils import _create_lokr_config, _create_lora_config, _maybe_warn_for_unhandled_keys from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading from .unet_loader_utils import _maybe_expand_lora_scales @@ -46,7 +46,7 @@ logger = logging.get_logger(__name__) _SET_ADAPTER_SCALE_FN_MAPPING = defaultdict( - lambda: (lambda model_cls, weights: weights), + lambda: lambda model_cls, weights: weights, { "UNet2DConditionModel": _maybe_expand_lora_scales, "UNetMotionModel": _maybe_expand_lora_scales, @@ -213,56 +213,65 @@ def load_lora_adapter( "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping." ) - # check with first key if is not in peft format - first_key = next(iter(state_dict.keys())) - if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) - - # Control LoRA from SAI is different from BFL Control LoRA - # https://huggingface.co/stabilityai/control-lora - # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors - is_sai_sd_control_lora = "lora_controlnet" in state_dict - if is_sai_sd_control_lora: - state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict) - - rank = {} - for key, val in state_dict.items(): - # Cannot figure out rank from lora layers that don't have at least 2 dimensions. - # Bias layers in LoRA only have a single dimension - if "lora_B" in key and val.ndim > 1: - # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol. - # We may run into some ambiguous configuration values when a model has module - # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`, - # for example) and they have different LoRA ranks. - rank[f"^{key}"] = val.shape[1] - - if network_alphas is not None and len(network_alphas) >= 1: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] - network_alphas = { - k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys - } - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(self) - - # create LoraConfig - lora_config = _create_lora_config( - state_dict, - network_alphas, - metadata, - rank, - model_state_dict=self.state_dict(), - adapter_name=adapter_name, - ) + # Detect whether this is a LoKR adapter (Kronecker product, not low-rank) + is_lokr = any("lokr_" in k for k in state_dict) + + if is_lokr: + if adapter_name is None: + adapter_name = get_adapter_name(self) + lora_config = _create_lokr_config(state_dict) + is_sai_sd_control_lora = False + else: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + # Control LoRA from SAI is different from BFL Control LoRA + # https://huggingface.co/stabilityai/control-lora + # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors + is_sai_sd_control_lora = "lora_controlnet" in state_dict + if is_sai_sd_control_lora: + state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict) + + rank = {} + for key, val in state_dict.items(): + # Cannot figure out rank from lora layers that don't have at least 2 dimensions. + # Bias layers in LoRA only have a single dimension + if "lora_B" in key and val.ndim > 1: + # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol. + # We may run into some ambiguous configuration values when a model has module + # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`, + # for example) and they have different LoRA ranks. + rank[f"^{key}"] = val.shape[1] + + if network_alphas is not None and len(network_alphas) >= 1: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] + network_alphas = { + k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys + } + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(self) + + # create LoraConfig + lora_config = _create_lora_config( + state_dict, + network_alphas, + metadata, + rank, + model_state_dict=self.state_dict(), + adapter_name=adapter_name, + ) - # Adjust LoRA config for Control LoRA - if is_sai_sd_control_lora: - lora_config.lora_alpha = lora_config.r - lora_config.alpha_pattern = lora_config.rank_pattern - lora_config.bias = "all" - lora_config.modules_to_save = lora_config.exclude_modules - lora_config.exclude_modules = None + # Adjust LoRA config for Control LoRA + if is_sai_sd_control_lora: + lora_config.lora_alpha = lora_config.r + lora_config.alpha_pattern = lora_config.rank_pattern + lora_config.bias = "all" + lora_config.modules_to_save = lora_config.exclude_modules + lora_config.exclude_modules = None # None: ) +def _create_lokr_config(state_dict): + """Create a peft LoKrConfig from a converted LoKR state dict. + + Infers rank, decompose_both, decompose_factor, and target_modules from the state dict key names + and tensor shapes. Alpha scaling is assumed to be already baked into the weights, so config + alpha = r (scaling = 1.0). + + Peft determines w2 decomposition via ``r < max(out_k, in_n) / 2``. We must set per-module rank + values that reproduce the same decomposition pattern as the checkpoint. For modules with full + (non-decomposed) lokr_w2, we set rank = max(lokr_w2.shape) so that peft also creates a full w2. + """ + from peft import LoKrConfig + + # Infer decompose_both from presence of lokr_w1_a keys + decompose_both = any("lokr_w1_a" in k for k in state_dict) + + # Infer decompose_factor from lokr_w1 shapes. + # With a fixed factor (e.g., 4), all w1 shapes are (factor, factor). + # With factor=-1 (near-sqrt), w1 shapes vary per module based on dimension. + w1_shapes = set() + for key, val in state_dict.items(): + if "lokr_w1" in key and "lokr_w1_a" not in key and "lokr_w1_b" not in key and val.ndim == 2: + w1_shapes.add(val.shape[0]) + if len(w1_shapes) == 1: + # All w1 have the same first dimension - this is the decompose_factor + decompose_factor = w1_shapes.pop() + else: + # Shapes vary - near-sqrt factorization was used + decompose_factor = -1 + + # Extract target modules and their decomposition state + lokr_suffixes = {"lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2"} + target_modules = set() + for key in state_dict: + for suffix in lokr_suffixes: + if f".{suffix}" in key: + target_modules.add(key.split(f".{suffix}")[0]) + break + + # Build per-module rank dict that ensures peft creates matching decomposition + rank_dict = {} + for key, val in state_dict.items(): + if "lokr_w2_a" in key and val.ndim > 1: + # Decomposed w2: rank = inner dimension of w2_a + module_name = key.split(".lokr_w2_a")[0] + rank_dict[module_name] = val.shape[1] + elif "lokr_w2" in key and "lokr_w2_a" not in key and "lokr_w2_b" not in key and val.ndim > 1: + # Full w2 matrix: set rank high enough so peft also creates full w2. + # Peft uses full w2 when r >= max(out_k, in_n) / 2, where (out_k, in_n) = lokr_w2.shape. + module_name = key.split(".lokr_w2")[0] + if module_name not in rank_dict: + rank_dict[module_name] = max(val.shape) + + # Also extract rank from w1_a if w2 info is missing + for key, val in state_dict.items(): + if "lokr_w1_a" in key and val.ndim > 1: + module_name = key.split(".lokr_w1_a")[0] + if module_name not in rank_dict: + rank_dict[module_name] = val.shape[1] + + # Determine default rank (most common) and per-module rank pattern + if rank_dict: + import collections + + r = collections.Counter(rank_dict.values()).most_common()[0][0] + rank_pattern = {k: v for k, v in rank_dict.items() if v != r} + else: + r = 1 + rank_pattern = {} + + lokr_config_kwargs = { + "r": r, + "alpha": r, # alpha baked into weights, so runtime scaling = alpha/r = 1.0 + "target_modules": list(target_modules), + "rank_pattern": rank_pattern, + "alpha_pattern": dict(rank_pattern), # keep alpha=r per module + "decompose_both": decompose_both, + "decompose_factor": decompose_factor, + } + + try: + return LoKrConfig(**lokr_config_kwargs) + except TypeError as e: + raise TypeError("`LoKrConfig` class could not be instantiated.") from e + + +def _convert_adapter_to_lora(model, rank, adapter_name="default"): + """Convert a loaded non-LoRA peft adapter (e.g., LoKR) to LoRA via truncated SVD. + + Wraps ``peft.convert_to_lora`` which materializes each adapter layer's delta weight + and decomposes it as ``U @ diag(S) @ V ≈ lora_B @ lora_A``. The conversion is lossy: + higher ``rank`` preserves more fidelity at the cost of larger LoRA matrices. + + Args: + model: ``nn.Module`` with a peft adapter already injected. + rank: ``int`` for a fixed LoRA rank, or ``float`` in (0, 1] as an energy threshold + (picks the smallest rank capturing that fraction of singular value energy). + adapter_name: Name of the adapter to convert. + + Returns: + Tuple of ``(LoraConfig, state_dict)`` for the converted LoRA adapter. + + Raises: + ImportError: If peft does not provide ``convert_to_lora`` (requires peft >= 0.19.0). + """ + try: + from peft import convert_to_lora + except ImportError: + raise ImportError( + "`peft.convert_to_lora` is required for lossy LoKR-to-LoRA conversion. " + "Install peft >= 0.19.0 or from source: pip install git+https://github.com/huggingface/peft.git" + ) + return convert_to_lora(model, rank, adapter_name=adapter_name) + + def _create_lora_config( state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None ):