Skip to content

Add Flux2 LoKR adapter support prototype with dual conversion paths#13326

Open
CalamitousFelicitousness wants to merge 1 commit into huggingface:main from
CalamitousFelicitousness:feature/flux2-klein-lokr
Open

Add Flux2 LoKR adapter support prototype with dual conversion paths#13326
CalamitousFelicitousness wants to merge 1 commit into huggingface:main from
CalamitousFelicitousness:feature/flux2-klein-lokr

Conversation

@CalamitousFelicitousness
Copy link
Contributor

@CalamitousFelicitousness CalamitousFelicitousness commented Mar 25, 2026

Adds support for Flux2 LoKR, with dual path to benchmark implementations.

  • Custom lossless path: BFL LoKR keys → peft LoKrConfig (fuse-first QKV)
  • Generic lossy path: optional SVD conversion via peft.convert_to_lora
  • Fix alpha handling for lora_down/lora_up format checkpoints
  • Re-fuse LoRA keys when model QKV is fused from prior LoKR load

Given that Civitai does not have a LoKR category, I didn't feel like digging for a SFW one, so I just used the checkpoint a user brought up when they reported the issue.

Benchmark test
"""Benchmark: Lossless LoKR vs Lossy LoRA-via-SVD on Flux2 Klein 9B.

Generates images using both conversion paths for visual comparison.
Uses 4-bit quantization + CPU offload to fit on a single 24GB GPU.

Usage:
    python benchmark_lokr.py
    python benchmark_lokr.py --prompt "a cat in a garden" --ranks 32 64 128
"""

import argparse
import gc
import os
import time

import torch

# Use local model cache
# NOTE(review): machine-specific absolute path — parameterize or read it from
# the environment before this script is shared or merged.
os.environ["HF_HUB_CACHE"] = "/home/ohiom/database/models/huggingface"

# Imported after the cache env var is set — presumably so the hub cache
# location takes effect for the model download (hence the E402 suppressions).
from diffusers import Flux2KleinPipeline  # noqa: E402
from peft import convert_to_lora  # noqa: E402

# Hub repo id of the base model, local path of the LoKR checkpoint under test,
# and directory where the comparison images are written.
# NOTE(review): LOKR_PATH and OUTPUT_DIR are machine-specific as well.
MODEL_ID = "black-forest-labs/FLUX.2-klein-9B"
LOKR_PATH = "/home/ohiom/database/models/Lora/Flux.2 Klein 9B/klein_snofs_v1_1.safetensors"
OUTPUT_DIR = "/home/ohiom/diffusers/benchmark_output"


def load_pipeline():
    """Load Flux2 Klein 9B in bf16 with model CPU offload."""
    pipeline = Flux2KleinPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
    )
    # Offload idle submodules to CPU so the 9B model fits on a single GPU.
    pipeline.enable_model_cpu_offload()
    return pipeline


def generate(pipe, prompt, seed, num_steps=4, guidance_scale=1.0):
    """Generate a single image with fixed seed for reproducibility."""
    # A CPU generator keeps the noise deterministic regardless of device.
    rng = torch.Generator(device="cpu").manual_seed(seed)
    result = pipe(
        prompt=prompt,
        num_inference_steps=num_steps,
        guidance_scale=guidance_scale,
        generator=rng,
        height=1024,
        width=1024,
    )
    return result.images[0]


def benchmark_lossless(pipe, prompt, seed, num_steps=4):
    """Path A: Load LoKR natively (lossless).

    Loads the LoKR checkpoint, generates one image, unloads the adapter,
    and returns the image. `num_steps` is forwarded to `generate` (default 4
    preserves the previous behavior, where generate's own default was used).
    """
    print("\n=== Path A: Lossless LoKR ===")
    t0 = time.time()
    pipe.load_lora_weights(LOKR_PATH)
    print(f"  Loaded in {time.time() - t0:.1f}s")

    t0 = time.time()
    image = generate(pipe, prompt, seed, num_steps=num_steps)
    print(f"  Generated in {time.time() - t0:.1f}s")

    # Unload so subsequent benchmark passes start from a clean model.
    pipe.unload_lora_weights()
    return image


def benchmark_lossy(pipe, prompt, seed, rank, num_steps=4):
    """Path B: Load LoKR, convert to LoRA via SVD (lossy).

    Loads the LoKR checkpoint, converts the injected adapter to a rank-`rank`
    LoRA with peft's SVD-based `convert_to_lora`, swaps the adapter in place,
    generates one image, and returns it. `num_steps` is forwarded to
    `generate` (default 4 preserves the previous behavior).
    """
    print(f"\n=== Path B: Lossy LoRA via SVD (rank={rank}) ===")
    t0 = time.time()
    pipe.load_lora_weights(LOKR_PATH)
    load_time = time.time() - t0

    # Detect the actual adapter name assigned by peft. next(iter(...)) is the
    # idiomatic form of .keys().__iter__().__next__(); like the original it
    # raises StopIteration if no adapter was injected.
    adapter_name = next(iter(pipe.transformer.peft_config))
    print(f"  Adapter name: {adapter_name}")

    pipe.transformer.to("cuda")
    t0 = time.time()
    lora_config, lora_sd = convert_to_lora(pipe.transformer, rank, adapter_name=adapter_name, progressbar=True)
    convert_time = time.time() - t0
    print(f"  Loaded LoKR in {load_time:.1f}s, converted to LoRA in {convert_time:.1f}s")

    # Replace LoKR adapter with converted LoRA
    from peft import inject_adapter_in_model, set_peft_model_state_dict

    pipe.transformer.delete_adapter(adapter_name)
    inject_adapter_in_model(pipe.transformer, lora_config, adapter_name=adapter_name)
    set_peft_model_state_dict(pipe.transformer, lora_sd, adapter_name=adapter_name)

    t0 = time.time()
    image = generate(pipe, prompt, seed, num_steps=num_steps)
    print(f"  Generated in {time.time() - t0:.1f}s")

    pipe.unload_lora_weights()
    return image


def benchmark_baseline(pipe, prompt, seed, num_steps=4):
    """Baseline: No adapter.

    Generates one image with the unmodified model and returns it.
    `num_steps` is forwarded to `generate` (default 4 preserves the
    previous behavior, where generate's own default was used).
    """
    print("\n=== Baseline: No adapter ===")
    t0 = time.time()
    image = generate(pipe, prompt, seed, num_steps=num_steps)
    print(f"  Generated in {time.time() - t0:.1f}s")
    return image


def main():
    """CLI entry point: run baseline, lossless-LoKR, and lossy SVD-LoRA passes.

    Saves one PNG per configuration into OUTPUT_DIR for visual comparison.
    """
    parser = argparse.ArgumentParser(description="Benchmark LoKR vs LoRA-via-SVD")
    parser.add_argument(
        "--prompt",
        # Default kept SFW per review feedback; pass the adapter creator's own
        # example prompts explicitly when benchmarking adapter-specific output.
        default="A high-angle photograph of a woman with blonde hair in a sunlit garden",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--ranks", type=int, nargs="+", default=[32, 64, 128])
    # NOTE(review): --steps is parsed but never forwarded to the benchmark
    # helpers, so generation always uses generate()'s num_steps default.
    # Wire args.steps through once the benchmark helpers accept a step count.
    parser.add_argument("--steps", type=int, default=28)
    parser.add_argument("--skip-baseline", action="store_true")
    args = parser.parse_args()

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Echo the run configuration so saved images can be traced back to it.
    print(f"Model: {MODEL_ID}")
    print(f"LoKR:  {LOKR_PATH}")
    print(f"Prompt: {args.prompt}")
    print(f"Seed: {args.seed}")
    print(f"SVD ranks to test: {args.ranks}")

    print("\nLoading pipeline (bf16, model CPU offload)...")
    pipe = load_pipeline()

    # Baseline
    if not args.skip_baseline:
        img = benchmark_baseline(pipe, args.prompt, args.seed)
        path = os.path.join(OUTPUT_DIR, "baseline.png")
        img.save(path)
        print(f"  Saved: {path}")

    # Path A: Lossless LoKR
    img = benchmark_lossless(pipe, args.prompt, args.seed)
    path = os.path.join(OUTPUT_DIR, "lokr_lossless.png")
    img.save(path)
    print(f"  Saved: {path}")

    # Free GPU memory between passes before the heavier SVD conversions.
    gc.collect()
    torch.cuda.empty_cache()

    # Path B: Lossy LoRA via SVD at various ranks
    for rank in args.ranks:
        img = benchmark_lossy(pipe, args.prompt, args.seed, rank)
        path = os.path.join(OUTPUT_DIR, f"lora_svd_rank{rank}.png")
        img.save(path)
        print(f"  Saved: {path}")

        gc.collect()
        torch.cuda.empty_cache()

    print(f"\nAll results saved to {OUTPUT_DIR}/")
    print("Compare: baseline.png vs lokr_lossless.png vs lora_svd_rank*.png")


# Entry-point guard: run the benchmark only when executed as a script.
if __name__ == "__main__":
    main()

Not sure if I'm doing it correctly for PEFT, but due to lack of quantization support I couldn't run the lossy path for now.

What does this PR do?

Fixes #13261

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @BenjaminBossan

- Custom lossless path: BFL LoKR keys → peft LoKrConfig (fuse-first QKV)
- Generic lossy path: optional SVD conversion via peft.convert_to_lora
- Fix alpha handling for lora_down/lora_up format checkpoints
- Re-fuse LoRA keys when model QKV is fused from prior LoKR load
@BenjaminBossan
Copy link
Member

PR looks good AFAICT, but of course I'm no Diffusers expert. @CalamitousFelicitousness do you have results from the test script that you can share? Also, for other readers: The LoRA conversion currently requires installing PEFT from the main branch (in the future: version 0.19).

@CalamitousFelicitousness
Copy link
Contributor Author

CalamitousFelicitousness commented Mar 25, 2026

PR looks good AFAICT, but of course I'm no Diffusers expert. @CalamitousFelicitousness do you have results from the test script that you can share? Also, for other readers: The LoRA conversion currently requires installing PEFT from the main branch (in the future: version 0.19).

As I mentioned in the message, I can't currently fit it for the SVD tests — it OOMs on my RTX 3090, and my 6000 Ada is not available at the moment. For now I only know that the programmatic tests pass.

@sayakpaul
Copy link
Member

I will help with the lossy path results.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the clean PR! Left a couple of comments. LMK if anything is unclear.

Comment on lines +2334 to +2344
# Bake alpha/rank scaling into lora_A weights so .alpha keys are consumed.
# Matches the pattern used by _convert_kohya_flux_lora_to_diffusers for Flux1.
alpha_keys = [k for k in original_state_dict if k.endswith(".alpha")]
for alpha_key in alpha_keys:
alpha = original_state_dict.pop(alpha_key).item()
module_path = alpha_key[: -len(".alpha")]
lora_a_key = f"{module_path}.lora_A.weight"
if lora_a_key in original_state_dict:
rank = original_state_dict[lora_a_key].shape[0]
scale = alpha / rank
original_state_dict[lora_a_key] = original_state_dict[lora_a_key] * scale
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it rebase with upstream please? I think this was fixed with #13313

Comment on lines +2461 to +2463
Uses fuse-first QKV mapping: BFL fused `img_attn.qkv` maps to diffusers `attn.to_qkv` (created by
`fuse_projections()`), avoiding lossy Kronecker factor splitting. The caller must fuse the model's
QKV projections before injecting the adapter.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot assume that the user will always fuse_projections(). So, let's split the fused QKV as done in the other conversion workflows.

return converted_state_dict


def _refuse_flux2_lora_state_dict(state_dict):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will also change given the above comment.

state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict)
if metadata is None:
metadata = {}
metadata["is_lokr"] = "true"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could append to metadata as:

metadata["is_lokr"] = is_lokr

We could make it a bit more general (e.g., adapter type) if needed in a future PR. Cc: @BenjaminBossan

Comment on lines +5730 to +5738
if is_lokr:
transformer.fuse_qkv_projections()
elif (
hasattr(transformer, "transformer_blocks")
and len(transformer.transformer_blocks) > 0
and getattr(transformer.transformer_blocks[0].attn, "fused_projections", False)
):
# Model QKV is fused but LoRA targets separate Q/K/V - re-fuse the keys to match.
state_dict = _refuse_flux2_lora_state_dict(state_dict)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not assume it as mentioned previously. Doing this will also simplify the code (getting rid of the conditionals, for example).

if is_lokr:
if adapter_name is None:
adapter_name = get_adapter_name(self)
lora_config = _create_lokr_config(state_dict)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's call it adapter_config to be a bit more idiomatic.

values that reproduce the same decomposition pattern as the checkpoint. For modules with full
(non-decomposed) lokr_w2, we set rank = max(lokr_w2.shape) so that peft also creates a full w2.
"""
from peft import LoKrConfig
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can go at the top.


# Determine default rank (most common) and per-module rank pattern
if rank_dict:
import collections
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same.

raise TypeError("`LoKrConfig` class could not be instantiated.") from e


def _convert_adapter_to_lora(model, rank, adapter_name="default"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't see its use in the benchmarking script. If you update the script, I can help it run across different settings (without quantization to isolate the impact) and post the results.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see it in convert_to_lora().

@sayakpaul
Copy link
Member

@claude please review this PR as well.

@sayakpaul
Copy link
Member

Also @CalamitousFelicitousness you might want to change the default prompt in the benchmarking script. That is highly NSFW. Let's be mindful of that.

@CalamitousFelicitousness
Copy link
Contributor Author

If someone can provide me with a link to a SFW LoKR I can use, I can update it; the prompt was taken verbatim from the creator's examples. Using a prompt unrelated to the aim of the adapter is not ideal.

@sayakpaul
Copy link
Member

Then we will have to wait for one that is SFW. Respectfully, we cannot base our work on the grounds of NSFW content.

@chaowenguo can you provide a LoKR checkpoint that doesn't use any form of nudity?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

black-forest-labs/FLUX.2-klein-9B not working with lora with lokr

3 participants