-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Add Flux2 LoKR adapter support prototype with dual conversion paths #13326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2331,6 +2331,18 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict): | |
| temp_state_dict[new_key] = v | ||
| original_state_dict = temp_state_dict | ||
|
|
||
| # Bake alpha/rank scaling into lora_A weights so .alpha keys are consumed. | ||
| # Matches the pattern used by _convert_kohya_flux_lora_to_diffusers for Flux1. | ||
| alpha_keys = [k for k in original_state_dict if k.endswith(".alpha")] | ||
| for alpha_key in alpha_keys: | ||
| alpha = original_state_dict.pop(alpha_key).item() | ||
| module_path = alpha_key[: -len(".alpha")] | ||
| lora_a_key = f"{module_path}.lora_A.weight" | ||
| if lora_a_key in original_state_dict: | ||
| rank = original_state_dict[lora_a_key].shape[0] | ||
| scale = alpha / rank | ||
| original_state_dict[lora_a_key] = original_state_dict[lora_a_key] * scale | ||
|
|
||
| num_double_layers = 0 | ||
| num_single_layers = 0 | ||
| for key in original_state_dict.keys(): | ||
|
|
@@ -2443,6 +2455,152 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict): | |
| return converted_state_dict | ||
|
|
||
|
|
||
| def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict): | ||
| """Convert non-diffusers Flux2 LoKR state dict (kohya/LyCORIS format) to peft-compatible diffusers format. | ||
| Uses fuse-first QKV mapping: BFL fused `img_attn.qkv` maps to diffusers `attn.to_qkv` (created by | ||
| `fuse_projections()`), avoiding lossy Kronecker factor splitting. The caller must fuse the model's | ||
| QKV projections before injecting the adapter. | ||
|
Comment on lines
+2461
to
+2463
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We cannot assume that the user will always |
||
| """ | ||
| converted_state_dict = {} | ||
|
|
||
| prefix = "diffusion_model." | ||
| original_state_dict = {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()} | ||
|
|
||
| num_double_layers = 0 | ||
| num_single_layers = 0 | ||
| for key in original_state_dict: | ||
| if key.startswith("single_blocks."): | ||
| num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1) | ||
| elif key.startswith("double_blocks."): | ||
| num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1) | ||
|
|
||
| lokr_suffixes = ("lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2") | ||
|
|
||
| def _remap_lokr_module(bfl_path, diff_path): | ||
| """Pop all lokr keys for a BFL module, bake alpha scaling, and store under diffusers path.""" | ||
| alpha_key = f"{bfl_path}.alpha" | ||
| alpha = original_state_dict.pop(alpha_key).item() if alpha_key in original_state_dict else None | ||
|
|
||
| for suffix in lokr_suffixes: | ||
| src_key = f"{bfl_path}.{suffix}" | ||
| if src_key not in original_state_dict: | ||
| continue | ||
|
|
||
| weight = original_state_dict.pop(src_key) | ||
|
|
||
| # Bake alpha/rank scaling into the first w1 tensor encountered for this module. | ||
| # After baking, peft's config uses alpha=r so its runtime scaling is 1.0. | ||
| if alpha is not None and suffix in ("lokr_w1", "lokr_w1_a"): | ||
| w2a_key = f"{bfl_path}.lokr_w2_a" | ||
| w1a_key = f"{bfl_path}.lokr_w1_a" | ||
| if w2a_key in original_state_dict: | ||
| r_eff = original_state_dict[w2a_key].shape[1] | ||
| elif w1a_key in original_state_dict: | ||
| r_eff = original_state_dict[w1a_key].shape[1] | ||
| else: | ||
| r_eff = alpha | ||
| scale = alpha / r_eff | ||
| weight = weight * scale | ||
| alpha = None # only bake once per module | ||
|
|
||
| converted_state_dict[f"{diff_path}.{suffix}"] = weight | ||
|
|
||
| # --- Single blocks --- | ||
| for sl in range(num_single_layers): | ||
| _remap_lokr_module(f"single_blocks.{sl}.linear1", f"single_transformer_blocks.{sl}.attn.to_qkv_mlp_proj") | ||
| _remap_lokr_module(f"single_blocks.{sl}.linear2", f"single_transformer_blocks.{sl}.attn.to_out") | ||
|
|
||
| # --- Double blocks --- | ||
| for dl in range(num_double_layers): | ||
| tb = f"transformer_blocks.{dl}" | ||
| db = f"double_blocks.{dl}" | ||
|
|
||
| # QKV → fused to_qkv / to_added_qkv (model must be fused before injection) | ||
| _remap_lokr_module(f"{db}.img_attn.qkv", f"{tb}.attn.to_qkv") | ||
| _remap_lokr_module(f"{db}.txt_attn.qkv", f"{tb}.attn.to_added_qkv") | ||
|
|
||
| # Projections | ||
| _remap_lokr_module(f"{db}.img_attn.proj", f"{tb}.attn.to_out.0") | ||
| _remap_lokr_module(f"{db}.txt_attn.proj", f"{tb}.attn.to_add_out") | ||
|
|
||
| # MLPs | ||
| _remap_lokr_module(f"{db}.img_mlp.0", f"{tb}.ff.linear_in") | ||
| _remap_lokr_module(f"{db}.img_mlp.2", f"{tb}.ff.linear_out") | ||
| _remap_lokr_module(f"{db}.txt_mlp.0", f"{tb}.ff_context.linear_in") | ||
| _remap_lokr_module(f"{db}.txt_mlp.2", f"{tb}.ff_context.linear_out") | ||
|
|
||
| # --- Extra mappings (embedders, modulation, final layer) --- | ||
| extra_mappings = { | ||
| "img_in": "x_embedder", | ||
| "txt_in": "context_embedder", | ||
| "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1", | ||
| "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2", | ||
| "final_layer.linear": "proj_out", | ||
| "final_layer.adaLN_modulation.1": "norm_out.linear", | ||
| "single_stream_modulation.lin": "single_stream_modulation.linear", | ||
| "double_stream_modulation_img.lin": "double_stream_modulation_img.linear", | ||
| "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear", | ||
| } | ||
| for bfl_key, diff_key in extra_mappings.items(): | ||
| _remap_lokr_module(bfl_key, diff_key) | ||
|
|
||
| if len(original_state_dict) > 0: | ||
| raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.") | ||
|
|
||
| for key in list(converted_state_dict.keys()): | ||
| converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key) | ||
|
|
||
| return converted_state_dict | ||
|
|
||
|
|
||
| def _refuse_flux2_lora_state_dict(state_dict): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This will also change given the above comment. |
||
| """Re-fuse separate Q/K/V LoRA keys into fused to_qkv/to_added_qkv keys. | ||
| When the model's QKV projections are fused, incoming LoRA keys targeting separate | ||
| to_q/to_k/to_v must be re-fused to match. Inverse of the QKV split performed by | ||
| ``_convert_non_diffusers_flux2_lora_to_diffusers``. | ||
| """ | ||
| converted = {} | ||
| remaining = dict(state_dict) | ||
|
|
||
| # Detect double block indices from keys | ||
| num_double_layers = 0 | ||
| for key in remaining: | ||
| if ".transformer_blocks." in key: | ||
| parts = key.split(".") | ||
| idx = parts.index("transformer_blocks") + 1 | ||
| if idx < len(parts) and parts[idx].isdigit(): | ||
| num_double_layers = max(num_double_layers, int(parts[idx]) + 1) | ||
|
|
||
| # Fuse Q/K/V for image stream and text stream | ||
| qkv_groups = [ | ||
| (["to_q", "to_k", "to_v"], "to_qkv"), | ||
| (["add_q_proj", "add_k_proj", "add_v_proj"], "to_added_qkv"), | ||
| ] | ||
|
|
||
| for dl in range(num_double_layers): | ||
| attn_prefix = f"transformer.transformer_blocks.{dl}.attn" | ||
| for separate_keys, fused_name in qkv_groups: | ||
| for lora_key in ("lora_A", "lora_B"): | ||
| src_keys = [f"{attn_prefix}.{sk}.{lora_key}.weight" for sk in separate_keys] | ||
| if not all(k in remaining for k in src_keys): | ||
| continue | ||
|
|
||
| weights = [remaining.pop(k) for k in src_keys] | ||
| dst_key = f"{attn_prefix}.{fused_name}.{lora_key}.weight" | ||
| if lora_key == "lora_A": | ||
| # lora_A was replicated during split - all three are identical, take the first | ||
| converted[dst_key] = weights[0] | ||
| else: | ||
| # lora_B was chunked along dim=0 - concatenate back | ||
| converted[dst_key] = torch.cat(weights, dim=0) | ||
|
|
||
| # Pass through all non-QKV keys unchanged | ||
| converted.update(remaining) | ||
| return converted | ||
|
|
||
|
|
||
| def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict): | ||
| """ | ||
| Convert non-diffusers ZImage LoRA state dict to diffusers format. | ||
|
|
@@ -2600,14 +2758,14 @@ def get_alpha_scales(down_weight, alpha_key): | |
|
|
||
| base = k[: -len(lora_dot_down_key)] | ||
|
|
||
| # Skip combined "qkv" projection — individual to.q/k/v keys are also present. | ||
| # Skip combined "qkv" projection - individual to.q/k/v keys are also present. | ||
| if base.endswith(".qkv"): | ||
| state_dict.pop(k) | ||
| state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None) | ||
| state_dict.pop(base + ".alpha", None) | ||
| continue | ||
|
|
||
| # Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection. | ||
| # Skip bare "out.lora.*" - "to_out.0.lora.*" covers the same projection. | ||
| if re.search(r"\.out$", base) and ".to_out" not in base: | ||
| state_dict.pop(k) | ||
| state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,6 +45,7 @@ | |
| _convert_hunyuan_video_lora_to_diffusers, | ||
| _convert_kohya_flux_lora_to_diffusers, | ||
| _convert_musubi_wan_lora_to_diffusers, | ||
| _convert_non_diffusers_flux2_lokr_to_diffusers, | ||
| _convert_non_diffusers_flux2_lora_to_diffusers, | ||
| _convert_non_diffusers_hidream_lora_to_diffusers, | ||
| _convert_non_diffusers_lora_to_diffusers, | ||
|
|
@@ -56,6 +57,7 @@ | |
| _convert_non_diffusers_z_image_lora_to_diffusers, | ||
| _convert_xlabs_flux_lora_to_diffusers, | ||
| _maybe_map_sgm_blocks_to_diffusers, | ||
| _refuse_flux2_lora_state_dict, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -5679,12 +5681,18 @@ def lora_state_dict( | |
|
|
||
| is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict) | ||
| if is_ai_toolkit: | ||
| state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict) | ||
| is_lokr = any("lokr_" in k for k in state_dict) | ||
| if is_lokr: | ||
| state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict) | ||
| if metadata is None: | ||
| metadata = {} | ||
| metadata["is_lokr"] = "true" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could append to metadata["is_lokr"] = is_lokr. We could make it a bit more general (e.g., adapter type) if needed in a future PR. Cc: @BenjaminBossan |
||
| else: | ||
| state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict) | ||
|
|
||
| out = (state_dict, metadata) if return_lora_metadata else state_dict | ||
| return out | ||
|
|
||
| # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights | ||
| def load_lora_weights( | ||
| self, | ||
| pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], | ||
|
|
@@ -5712,13 +5720,26 @@ def load_lora_weights( | |
| kwargs["return_lora_metadata"] = True | ||
| state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) | ||
|
|
||
| is_correct_format = all("lora" in key for key in state_dict.keys()) | ||
| is_correct_format = all("lora" in key or "lokr" in key for key in state_dict.keys()) | ||
| if not is_correct_format: | ||
| raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") | ||
| raise ValueError("Invalid LoRA/LoKR checkpoint. Make sure all param names contain `'lora'` or `'lokr'`.") | ||
|
|
||
| # For LoKR adapters, fuse QKV projections so peft can target the fused modules directly. | ||
| is_lokr = metadata is not None and metadata.get("is_lokr") == "true" | ||
| transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer | ||
| if is_lokr: | ||
| transformer.fuse_qkv_projections() | ||
| elif ( | ||
| hasattr(transformer, "transformer_blocks") | ||
| and len(transformer.transformer_blocks) > 0 | ||
| and getattr(transformer.transformer_blocks[0].attn, "fused_projections", False) | ||
| ): | ||
| # Model QKV is fused but LoRA targets separate Q/K/V - re-fuse the keys to match. | ||
| state_dict = _refuse_flux2_lora_state_dict(state_dict) | ||
|
Comment on lines
+5730
to
+5738
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's not assume it as mentioned previously. Doing this will also simplify the code (getting rid of the conditionals, for example). |
||
|
|
||
| self.load_lora_into_transformer( | ||
| state_dict, | ||
| transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, | ||
| transformer=transformer, | ||
| adapter_name=adapter_name, | ||
| metadata=metadata, | ||
| _pipeline=self, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,15 +38,15 @@ | |
| set_adapter_layers, | ||
| set_weights_and_activate_adapters, | ||
| ) | ||
| from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys | ||
| from ..utils.peft_utils import _create_lokr_config, _create_lora_config, _maybe_warn_for_unhandled_keys | ||
| from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading | ||
| from .unet_loader_utils import _maybe_expand_lora_scales | ||
|
|
||
|
|
||
| logger = logging.get_logger(__name__) | ||
|
|
||
| _SET_ADAPTER_SCALE_FN_MAPPING = defaultdict( | ||
| lambda: (lambda model_cls, weights: weights), | ||
| lambda: lambda model_cls, weights: weights, | ||
| { | ||
| "UNet2DConditionModel": _maybe_expand_lora_scales, | ||
| "UNetMotionModel": _maybe_expand_lora_scales, | ||
|
|
@@ -213,56 +213,65 @@ def load_lora_adapter( | |
| "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping." | ||
| ) | ||
|
|
||
| # check with first key if is not in peft format | ||
| first_key = next(iter(state_dict.keys())) | ||
| if "lora_A" not in first_key: | ||
| state_dict = convert_unet_state_dict_to_peft(state_dict) | ||
|
|
||
| # Control LoRA from SAI is different from BFL Control LoRA | ||
| # https://huggingface.co/stabilityai/control-lora | ||
| # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors | ||
| is_sai_sd_control_lora = "lora_controlnet" in state_dict | ||
| if is_sai_sd_control_lora: | ||
| state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict) | ||
|
|
||
| rank = {} | ||
| for key, val in state_dict.items(): | ||
| # Cannot figure out rank from lora layers that don't have at least 2 dimensions. | ||
| # Bias layers in LoRA only have a single dimension | ||
| if "lora_B" in key and val.ndim > 1: | ||
| # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol. | ||
| # We may run into some ambiguous configuration values when a model has module | ||
| # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`, | ||
| # for example) and they have different LoRA ranks. | ||
| rank[f"^{key}"] = val.shape[1] | ||
|
|
||
| if network_alphas is not None and len(network_alphas) >= 1: | ||
| alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] | ||
| network_alphas = { | ||
| k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys | ||
| } | ||
|
|
||
| # adapter_name | ||
| if adapter_name is None: | ||
| adapter_name = get_adapter_name(self) | ||
|
|
||
| # create LoraConfig | ||
| lora_config = _create_lora_config( | ||
| state_dict, | ||
| network_alphas, | ||
| metadata, | ||
| rank, | ||
| model_state_dict=self.state_dict(), | ||
| adapter_name=adapter_name, | ||
| ) | ||
| # Detect whether this is a LoKR adapter (Kronecker product, not low-rank) | ||
| is_lokr = any("lokr_" in k for k in state_dict) | ||
|
|
||
| if is_lokr: | ||
| if adapter_name is None: | ||
| adapter_name = get_adapter_name(self) | ||
| lora_config = _create_lokr_config(state_dict) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's call it |
||
| is_sai_sd_control_lora = False | ||
| else: | ||
| # check with first key if is not in peft format | ||
| first_key = next(iter(state_dict.keys())) | ||
| if "lora_A" not in first_key: | ||
| state_dict = convert_unet_state_dict_to_peft(state_dict) | ||
|
|
||
| # Control LoRA from SAI is different from BFL Control LoRA | ||
| # https://huggingface.co/stabilityai/control-lora | ||
| # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors | ||
| is_sai_sd_control_lora = "lora_controlnet" in state_dict | ||
| if is_sai_sd_control_lora: | ||
| state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict) | ||
|
|
||
| rank = {} | ||
| for key, val in state_dict.items(): | ||
| # Cannot figure out rank from lora layers that don't have at least 2 dimensions. | ||
| # Bias layers in LoRA only have a single dimension | ||
| if "lora_B" in key and val.ndim > 1: | ||
| # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol. | ||
| # We may run into some ambiguous configuration values when a model has module | ||
| # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`, | ||
| # for example) and they have different LoRA ranks. | ||
| rank[f"^{key}"] = val.shape[1] | ||
|
|
||
| if network_alphas is not None and len(network_alphas) >= 1: | ||
| alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] | ||
| network_alphas = { | ||
| k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys | ||
| } | ||
|
|
||
| # adapter_name | ||
| if adapter_name is None: | ||
| adapter_name = get_adapter_name(self) | ||
|
|
||
| # create LoraConfig | ||
| lora_config = _create_lora_config( | ||
| state_dict, | ||
| network_alphas, | ||
| metadata, | ||
| rank, | ||
| model_state_dict=self.state_dict(), | ||
| adapter_name=adapter_name, | ||
| ) | ||
|
|
||
| # Adjust LoRA config for Control LoRA | ||
| if is_sai_sd_control_lora: | ||
| lora_config.lora_alpha = lora_config.r | ||
| lora_config.alpha_pattern = lora_config.rank_pattern | ||
| lora_config.bias = "all" | ||
| lora_config.modules_to_save = lora_config.exclude_modules | ||
| lora_config.exclude_modules = None | ||
| # Adjust LoRA config for Control LoRA | ||
| if is_sai_sd_control_lora: | ||
| lora_config.lora_alpha = lora_config.r | ||
| lora_config.alpha_pattern = lora_config.rank_pattern | ||
| lora_config.bias = "all" | ||
| lora_config.modules_to_save = lora_config.exclude_modules | ||
| lora_config.exclude_modules = None | ||
|
|
||
| # <Unsafe code | ||
| # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could it rebase with upstream please? I think this was fixed with #13313