Add Flux2 LoKR adapter support prototype with dual conversion paths #13326
CalamitousFelicitousness wants to merge 1 commit into huggingface:main from
Conversation
- Custom lossless path: BFL LoKR keys → peft LoKrConfig (fuse-first QKV)
- Generic lossy path: optional SVD conversion via peft.convert_to_lora
- Fix alpha handling for lora_down/lora_up format checkpoints
- Re-fuse LoRA keys when model QKV is fused from prior LoKR load
PR looks good AFAICT, but of course I'm no Diffusers expert. @CalamitousFelicitousness do you have results from the test script that you can share? Also, for other readers: the LoRA conversion currently requires installing PEFT from the main branch (in the future: version 0.19).
As I mentioned in the message, I can't currently run the SVD tests; they OOM on my RTX 3090, and my 6000 Ada is not available at the moment. For now I only know the programmatic tests pass.
I will help with the lossy path results.
sayakpaul
left a comment
Thanks a lot for the clean PR! Left a couple of comments. LMK if anything is unclear.
```python
# Bake alpha/rank scaling into lora_A weights so .alpha keys are consumed.
# Matches the pattern used by _convert_kohya_flux_lora_to_diffusers for Flux1.
alpha_keys = [k for k in original_state_dict if k.endswith(".alpha")]
for alpha_key in alpha_keys:
    alpha = original_state_dict.pop(alpha_key).item()
    module_path = alpha_key[: -len(".alpha")]
    lora_a_key = f"{module_path}.lora_A.weight"
    if lora_a_key in original_state_dict:
        rank = original_state_dict[lora_a_key].shape[0]
        scale = alpha / rank
        original_state_dict[lora_a_key] = original_state_dict[lora_a_key] * scale
```
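For readers unfamiliar with the pattern, here is a minimal standalone sketch of the alpha-baking pass above, using plain Python lists in place of torch tensors (the real code multiplies tensors; all keys here are hypothetical):

```python
# Hypothetical sketch: consume ".alpha" keys by folding alpha / rank into lora_A.
def bake_alpha_into_lora_a(state_dict):
    alpha_keys = [k for k in state_dict if k.endswith(".alpha")]
    for alpha_key in alpha_keys:
        alpha = state_dict.pop(alpha_key)
        module_path = alpha_key[: -len(".alpha")]
        lora_a_key = f"{module_path}.lora_A.weight"
        if lora_a_key in state_dict:
            rows = state_dict[lora_a_key]
            rank = len(rows)  # shape[0] of the lora_A matrix
            scale = alpha / rank
            state_dict[lora_a_key] = [[v * scale for v in row] for row in rows]
    return state_dict

sd = {
    "blocks.0.attn.alpha": 8.0,
    "blocks.0.attn.lora_A.weight": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
}
sd = bake_alpha_into_lora_a(sd)
# rank = 4, so scale = 8 / 4 = 2.0, and the .alpha key is gone afterwards.
print(sd["blocks.0.attn.lora_A.weight"][0])  # [2.0, 4.0]
```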
Could you rebase with upstream please? I think this was fixed with #13313.
```
Uses fuse-first QKV mapping: BFL fused `img_attn.qkv` maps to diffusers `attn.to_qkv` (created by
`fuse_projections()`), avoiding lossy Kronecker factor splitting. The caller must fuse the model's
QKV projections before injecting the adapter.
```
We cannot assume that the user will always fuse_projections(). So, let's split the fused QKV as done in the other conversion workflows.
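For plain LoRA weights the split the reviewer suggests is exact: the fused delta is `lora_B @ lora_A`, and `lora_B` stacks the Q, K, V output rows, so only `lora_B` needs to be chunked while `lora_A` is shared. (This is distinct from splitting LoKR Kronecker factors, which is the lossy case the docstring mentions.) A hedged sketch with hypothetical key names, using lists instead of tensors:

```python
# Hypothetical sketch: split a fused-QKV LoRA into per-projection keys.
# delta_W = lora_B @ lora_A with lora_B of shape (3 * dim, r); chunking
# lora_B's rows into thirds yields exact per-projection factors, while
# lora_A (shape (r, in_dim)) is copied unchanged to each projection.
def split_fused_qkv_lora(state_dict, prefix):
    lora_a = state_dict.pop(f"{prefix}.to_qkv.lora_A.weight")
    lora_b = state_dict.pop(f"{prefix}.to_qkv.lora_B.weight")
    dim = len(lora_b) // 3
    chunks = (lora_b[:dim], lora_b[dim:2 * dim], lora_b[2 * dim:])
    for name, rows in zip(("to_q", "to_k", "to_v"), chunks):
        state_dict[f"{prefix}.{name}.lora_A.weight"] = lora_a
        state_dict[f"{prefix}.{name}.lora_B.weight"] = rows
    return state_dict

sd = {
    "attn.to_qkv.lora_A.weight": [[0.1, 0.2]],           # rank 1, in_dim 2
    "attn.to_qkv.lora_B.weight": [[1.0], [2.0], [3.0]],  # 3 * dim rows, dim = 1
}
sd = split_fused_qkv_lora(sd, "attn")
print(sd["attn.to_k.lora_B.weight"])  # [[2.0]]
```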
```python
    return converted_state_dict


def _refuse_flux2_lora_state_dict(state_dict):
```
This will also change given the above comment.
```python
state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict)
if metadata is None:
    metadata = {}
metadata["is_lokr"] = "true"
```
We could append to metadata as:

```python
metadata["is_lokr"] = is_lokr
```

We could make it a bit more general (e.g., adapter type) if needed in a future PR. Cc: @BenjaminBossan
```python
if is_lokr:
    transformer.fuse_qkv_projections()
elif (
    hasattr(transformer, "transformer_blocks")
    and len(transformer.transformer_blocks) > 0
    and getattr(transformer.transformer_blocks[0].attn, "fused_projections", False)
):
    # Model QKV is fused but LoRA targets separate Q/K/V - re-fuse the keys to match.
    state_dict = _refuse_flux2_lora_state_dict(state_dict)
```
Let's not assume it, as mentioned previously. Doing this will also simplify the code (getting rid of the conditionals, for example).
```python
if is_lokr:
    if adapter_name is None:
        adapter_name = get_adapter_name(self)
    lora_config = _create_lokr_config(state_dict)
```
Let's call it adapter_config to be a bit more idiomatic.
```python
    values that reproduce the same decomposition pattern as the checkpoint. For modules with full
    (non-decomposed) lokr_w2, we set rank = max(lokr_w2.shape) so that peft also creates a full w2.
    """
    from peft import LoKrConfig

    # Determine default rank (most common) and per-module rank pattern
    if rank_dict:
        import collections
```
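The "most common rank as default, per-module overrides" idea this snippet hints at can be sketched standalone as follows (an illustration under assumed names, not the PR's exact code):

```python
import collections

# Hypothetical sketch: pick the most common rank as the config default and
# record only the modules that deviate from it, mirroring how peft configs
# express per-module rank overrides via a rank_pattern dict.
def build_rank_pattern(rank_dict):
    default_rank = collections.Counter(rank_dict.values()).most_common(1)[0][0]
    rank_pattern = {name: r for name, r in rank_dict.items() if r != default_rank}
    return default_rank, rank_pattern

ranks = {"blocks.0.attn": 16, "blocks.1.attn": 16, "blocks.2.ff": 64}
default, pattern = build_rank_pattern(ranks)
print(default)   # 16
print(pattern)   # {'blocks.2.ff': 64}
```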
```python
    raise TypeError("`LoKrConfig` class could not be instantiated.") from e


def _convert_adapter_to_lora(model, rank, adapter_name="default"):
```
Didn't see its use in the benchmarking script. If you update the script, I can help it run across different settings (without quantization to isolate the impact) and post the results.
Oh I see it in convert_to_lora().
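For context on what the lossy path does conceptually: the PR delegates to a peft utility, but the underlying idea is a truncated SVD of the adapter's effective weight delta, keeping the top-r singular triplets as LoRA factors. A hedged numpy sketch (function name and shapes are illustrative, not the library's API):

```python
import numpy as np

# Illustrative sketch: approximate an arbitrary weight delta with rank-r
# LoRA factors A and B such that B @ A ~= delta_W, via truncated SVD.
def svd_to_lora(delta_w, rank):
    u, s, vt = np.linalg.svd(delta_w, full_matrices=False)
    lora_b = u[:, :rank] * s[:rank]   # (out_dim, rank), singular values folded in
    lora_a = vt[:rank]                # (rank, in_dim)
    return lora_a, lora_b

rng = np.random.default_rng(0)
# Build a delta that is exactly rank 2, so a rank-2 truncation is lossless;
# for a genuinely higher-rank delta the reconstruction would only approximate it.
delta = rng.standard_normal((8, 2)) @ rng.standard_normal((2, 6))
a, b = svd_to_lora(delta, rank=2)
print(np.allclose(b @ a, delta))  # True
```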
|
@claude please review this PR as well.
|
Also @CalamitousFelicitousness you might want to change the default prompt in the benchmarking script. That is highly NSFW. Let's be mindful of that.
|
If someone can provide me with a link to a SFW LoKR I can use, I can update it; the prompt was taken verbatim from the creator's examples. Using a prompt unrelated to the aim of the adapter is not ideal.
|
Then we will have to wait for one that is SFW. Respectfully, we cannot base our work on the grounds of NSFW content. @chaowenguo can you provide a LoKR checkpoint that doesn't use any form of nudity? |
Adds support for Flux2 LoKR, with a dual path to benchmark implementations.
Given that Civitai does not have a LoKR category I didn't feel like digging for a SFW one, so I just used what the user brought up to me when they reported the issue.
Benchmark test
Not sure if I'm doing it correctly for PEFT, but due to the lack of quantization support I couldn't run the lossy path for now.
What does this PR do?
Fixes #13261
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul @BenjaminBossan