diff --git a/README.md b/README.md index 6f1e393..b359310 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,8 @@ That's when the first version of LightDiffusion was born which only counted [300 📚 Learn more in the [official documentation](https://aatricks.github.io/LightDiffusion-Next/) +For a source-based breakdown of the optimization stack, see the [Implemented Optimizations Report](https://aatricks.github.io/LightDiffusion-Next/implemented-optimizations-report/). + --- ## 🌟 Highlights @@ -181,7 +183,7 @@ docker-compose build \ Set `INSTALL_STABLE_FAST=1` to enable the compilation step for stable-fast, or `INSTALL_OLLAMA=1` to bake in the prompt enhancer runtime. > [!NOTE] -> RTX 50 series (compute 12.0) GPUs currently only support SageAttention. +> RTX 50 series (compute 12.0) GPUs currently use SageAttention when the SageAttention kernel is installed. SpargeAttn remains limited to earlier supported architectures. **Access the Web Interface:** - **Streamlit UI** (default): `http://localhost:8501` diff --git a/docs/advanced-cfg-optimizations.md b/docs/advanced-cfg-optimizations.md index bd0b83e..93ed168 100644 --- a/docs/advanced-cfg-optimizations.md +++ b/docs/advanced-cfg-optimizations.md @@ -12,7 +12,7 @@ This document describes three advanced optimizations for Classifier-Free Guidanc ### What It Does -Instead of running two separate forward passes for conditional and unconditional predictions, this optimization combines them into a single batched forward pass. +Instead of running two separate forward passes for conditional and unconditional predictions, this optimization can combine them into a single batched forward pass. **Before:** ```python @@ -47,13 +47,15 @@ samples = sampling.sample1( steps=20, cfg=7.5, # ... other params ... 
- batched_cfg=True, # Enable batched CFG (default: True) + batched_cfg=True, # Joint cond/uncond batching (default: True) ) ``` +In the current implementation, the heavy lifting still happens in the central conditioning packing path. `batched_cfg` controls whether conditional and unconditional branches are packed together into the same forward pass when possible. Conditioning chunks within each branch are still packed by the shared batching logic. + ### When to Use -- **Always recommended** - This is a pure speed optimization with no quality tradeoff +- **Usually recommended** - This reduces duplicate cond/uncond forward passes when memory allows - Particularly beneficial for high-resolution images or batch generation - Compatible with all samplers and schedulers @@ -232,7 +234,7 @@ samples = sampling.sample1( ### Batched CFG Issues **Problem**: Memory errors with batched CFG -**Solution**: System may not have enough VRAM. Disable with `batched_cfg=False` +**Solution**: System may not have enough VRAM for joint cond/uncond batching. Disable it with `batched_cfg=False`, which keeps the conditioning path active but runs the two branches separately. ### Dynamic CFG Issues diff --git a/docs/implemented-optimizations-report.md b/docs/implemented-optimizations-report.md new file mode 100644 index 0000000..109b0a4 --- /dev/null +++ b/docs/implemented-optimizations-report.md @@ -0,0 +1,484 @@ +# Implemented Optimizations Report + +This document presents a source-based engineering report on the optimization stack used across generation, model loading, and serving in LightDiffusion-Next. + +Unlike the overview pages: + +- The source tree is treated as the primary reference point. +- Each optimization is described in terms of purpose, implementation, integration, and trade-offs. +- Supporting infrastructure and codebase groundwork are included when they materially contribute to the performance profile of the project. 
+ +## Report Scope + +### Usage Profile Definitions + +- `default`: selected in the standard execution path +- `integrated`: part of the current generation or serving flow +- `optional`: integrated, but enabled through request settings, configuration, or model capabilities +- `conditional`: available when hardware, dependencies, or runtime capabilities allow it +- `implementation-specific`: implemented and used, but its effective behavior is shaped by a narrower internal path than the request surface alone suggests +- `infrastructure-level`: supports the fast path indirectly through loading, transfer, caching, or serving behavior +- `codebase groundwork`: implemented in the codebase as part of the optimization stack, but not yet surfaced as a broad standard pipeline option + +### What This Report Covers + +This report covers both model-level and system-level optimizations: + +- inference and sampling speedups +- precision and memory reductions +- request batching and pipeline throughput improvements +- preview and output-path latency reductions + +It does not catalog ordinary features unless they clearly reduce compute, memory, or end-to-end latency. 
+ +## Quick Inventory + +| Optimization | Usage Profile | Main Goal | Primary Evidence | +|---|---|---|---| +| CUDA runtime tuning (TF32, cuDNN benchmark, SDPA enablement) | integrated, conditional | faster kernels and better backend selection | `src/Device/Device.py` | +| Attention backend cascade (SpargeAttn/SageAttention/xformers/SDPA) | integrated, conditional | faster attention kernels with fallback | `src/Attention/Attention.py`, `src/Attention/AttentionMethods.py` | +| Flux2 SDPA backend priority | integrated, conditional | prefer cuDNN/Flash SDPA for Flux2 attention | `src/NeuralNetwork/flux2/layers.py`, `src/Device/Device.py` | +| Cross-attention K/V projection cache | integrated | skip repeated key/value projection work for static context | `src/Attention/Attention.py` | +| Prompt embedding cache | integrated | avoid re-encoding repeated prompts | `src/Utilities/prompt_cache.py`, `src/clip/Clip.py` | +| Conditioning batch packing and memory-aware concatenation | integrated | reduce forward passes and pack compatible condition chunks | `src/cond/cond.py` | +| CFG=1 unconditional-skip fast path | integrated | skip unnecessary unconditional branch at CFG 1.0 | `src/sample/CFG.py`, `src/sample/BaseSampler.py` | +| AYS scheduler | default | reach similar quality in fewer steps | `src/sample/ays_scheduler.py`, `src/sample/ksampler_util.py` | +| CFG++ samplers | integrated | improve denoising behavior with momentum-style correction | `src/sample/BaseSampler.py` | +| CFG-Free sampling | integrated, optional | taper CFG late in sampling for better detail/naturalness | `src/sample/CFG.py` | +| Dynamic CFG rescaling | integrated, optional | reduce overshoot and saturation from strong CFG | `src/sample/CFG.py` | +| Adaptive noise scheduling | integrated, optional | adjust schedule based on observed complexity | `src/sample/CFG.py` | +| `batched_cfg` request surface | implementation-specific | request-facing control around the deeper conditioning batching path | 
`src/sample/sampling.py`, `src/cond/cond.py` | +| Multi-scale latent switching | integrated, optional | do some denoising at reduced spatial resolution | `src/sample/BaseSampler.py` | +| HiDiffusion MSW-MSA patching | integrated, optional | patch UNet attention for high-resolution multiscale workflows | `src/Core/Pipeline.py`, `src/hidiffusion/msw_msa_attention.py` | +| Stable-Fast | integrated, conditional | trace/compile UNet forward path | `src/StableFast/StableFast.py`, `src/Core/Pipeline.py` | +| `torch.compile` | integrated, optional | compiler-based model speedup without Stable-Fast | `src/Device/Device.py`, `src/Core/AbstractModel.py` | +| VAE compile, tiled path, and transfer tuning | integrated | speed up decode/encode and avoid OOM | `src/AutoEncoders/VariationalAE.py` | +| BF16/FP16 automatic dtype selection | integrated, conditional | reduce memory and improve throughput on supported hardware | `src/Device/Device.py` | +| FP8 weight quantization | integrated, conditional | reduce weight memory and enable Flux2-friendly inference paths | `src/Core/AbstractModel.py`, `src/Model/ModelPatcher.py` | +| NVFP4 weight quantization | integrated, optional | stronger memory reduction than FP8 | `src/Core/AbstractModel.py`, `src/Model/ModelPatcher.py`, `src/Utilities/Quantization.py` | +| Flux2 load-time weight-only quantization | integrated, conditional | keep large Flux2/Klein components workable on smaller VRAM budgets | `src/Core/Models/Flux2KleinModel.py` | +| ToMe | integrated, optional | reduce attention cost by token merging on UNet models | `src/Model/ModelPatcher.py`, `src/Core/Pipeline.py` | +| DeepCache | integrated, optional, implementation-specific | reuse prior denoiser output between update steps | `src/WaveSpeed/deepcache_nodes.py`, `src/Core/Pipeline.py` | +| First Block Cache for Flux | codebase groundwork | cache transformer work for Flux-like models | `src/WaveSpeed/first_block_cache.py` | +| Low-VRAM partial loading and offload policy | 
integrated | load only what fits and offload the rest | `src/cond/cond_util.py`, `src/Device/Device.py`, `src/Model/ModelPatcher.py` | +| Async transfer helpers and pinned checkpoint tensors | integrated, infrastructure-level | reduce host/device transfer overhead | `src/Device/Device.py`, `src/Utilities/util.py` | +| Request coalescing and queue batching | integrated | increase throughput across compatible API requests | `server.py` | +| Large-group chunking and image-save guardrails | integrated | keep large coalesced runs from blowing up save/decode paths | `server.py`, `src/FileManaging/ImageSaver.py` | +| Next-model prefetch | integrated | hide future checkpoint load latency | `server.py`, `src/Device/ModelCache.py`, `src/Utilities/util.py` | +| Keep-models-loaded cache | integrated | reuse loaded checkpoints and reduce warm starts | `src/Device/ModelCache.py`, `server.py` | +| In-memory PNG byte buffer | integrated | avoid disk round-trip for API responses | `src/FileManaging/ImageSaver.py`, `server.py` | +| TAESD preview pacing and preview fidelity control | integrated, conditional | reduce preview overhead while keeping live feedback usable | `src/sample/BaseSampler.py`, `src/AutoEncoders/taesd.py`, `server.py` | + +## Executive Summary + +The optimization strategy in LightDiffusion-Next is layered and cumulative rather than dependent on a single acceleration mechanism. + +1. The core generation path combines runtime kernel selection, conditioning batching, lower-precision execution, and schedule optimization. +2. Several optimizations are part of the standard execution path, most notably AYS scheduling, prompt caching, attention backend selection, low-VRAM loading policy, and server-side request grouping. +3. A second layer of optional mechanisms provides workload-specific extensions, including Stable-Fast, `torch.compile`, ToMe, multiscale sampling, quantization, and guidance refinements such as CFG-Free and dynamic rescaling. +4. 
The serving layer contributes materially to end-to-end throughput and latency through request coalescing, chunking, model prefetching, keep-loaded caching, and in-memory response handling. +5. The codebase also contains foundational work for additional caching paths, particularly around Flux-oriented first-block caching, alongside the currently integrated DeepCache path. + +## Runtime And Attention Optimizations + +### CUDA runtime tuning + +- Status: `integrated, conditional` +- Purpose: use faster math modes and let the backend choose more aggressive convolution and attention kernels. +- Implementation in LightDiffusion-Next: `src/Device/Device.py` enables TF32 (`torch.backends.cuda.matmul.allow_tf32`, `torch.backends.cudnn.allow_tf32`), enables cuDNN benchmarking, and turns on PyTorch math/flash/memory-efficient SDPA when available. +- Project integration: these are process-wide defaults. They do not require per-request toggles, so supported CUDA deployments get them automatically. +- Effect: reduces matmul/convolution cost and opens better SDPA backends with no extra application-layer work. +- Benefits: automatic, broad coverage, low complexity. +- Trade-offs: hardware-conditional; benefits depend on GPU generation and PyTorch build. +- Evidence: `src/Device/Device.py`. + +### Attention backend cascade: SpargeAttn, SageAttention, xformers, PyTorch SDPA + +- Status: `integrated, conditional` +- Purpose: use the fastest available attention kernel and fall back safely when unsupported. +- Implementation in LightDiffusion-Next: UNet/VAE attention chooses `SpargeAttn > SageAttention > xformers > PyTorch` in `src/Attention/Attention.py`; the concrete kernels and fallback behavior live in `src/Attention/AttentionMethods.py`. +- Project integration: the selection happens once when the attention module is imported/constructed. 
Sage/Sparge paths reshape inputs to HND layouts and pad unsupported head sizes to supported dimensions where possible; larger unsupported head sizes fall back. +- Effect: faster attention on supported CUDA systems without changing calling code. +- Benefits: automatic fallback chain, works across UNet cross-attention and VAE attention blocks, handles padding for awkward head sizes. +- Trade-offs: dependency- and GPU-dependent; not all head sizes stay on the fast path; behavior differs between generic UNet/VAE attention and Flux2 attention. +- Evidence: `src/Attention/Attention.py`, `src/Attention/AttentionMethods.py`. + +### Flux2 SDPA backend priority + +- Status: `integrated, conditional` +- Purpose: prefer the best PyTorch SDPA backend for Flux2 transformer attention. +- Implementation in LightDiffusion-Next: `src/Device/Device.py` builds an SDPA priority context preferring cuDNN attention, then Flash, then efficient, then math; `src/NeuralNetwork/flux2/layers.py` uses `Device.get_sdpa_context()` around `scaled_dot_product_attention`. +- Project integration: Flux2 uses a separate attention implementation from the generic UNet attention path. It first tries prioritized SDPA, then xformers, then plain SDPA. +- Effect: prioritized fast attention for Flux2 with robust fallback behavior. +- Benefits: keeps Flux2 on the most optimized native backend available; does not require custom kernels. +- Trade-offs: benefits depend heavily on PyTorch version, backend support, and GPU runtime. +- Evidence: `src/Device/Device.py`, `src/NeuralNetwork/flux2/layers.py`. + +### Cross-attention static K/V projection cache + +- Status: `integrated` +- Purpose: when the context tensor is unchanged across denoising steps, avoid recomputing K/V projections every step. +- Implementation in LightDiffusion-Next: `CrossAttention` in `src/Attention/Attention.py` keeps a small `_context_cache` keyed by `id(context)` and caches projected `k` and `v`. 
+- Project integration: this primarily targets prompt-conditioning cases where context is static while the latent evolves. The cache is tiny and self-pruning. +- Effect: shaves repeated linear-projection work from cross-attention-heavy denoising loops. +- Benefits: simple, training-free, no user configuration. +- Trade-offs: keyed by object identity, so it only helps when the exact context object is reused; small cache size limits reuse breadth. +- Evidence: `src/Attention/Attention.py`. + +### Prompt embedding cache + +- Status: `integrated` +- Purpose: cache text encoder outputs for repeated prompts instead of re-encoding them each time. +- Implementation in LightDiffusion-Next: `src/Utilities/prompt_cache.py` stores `(cond, pooled)` entries keyed by prompt hash and CLIP identity; `src/clip/Clip.py` checks the cache before tokenization/encoding and writes back after encode. +- Project integration: prompt caching is globally enabled by default, applies to single prompts and prompt lists, and prunes old entries once the cache exceeds its configured maximum. +- Effect: reduces prompt-side overhead in repeated-prompt workflows, especially seed sweeps and incremental prompt refinement. +- Benefits: low complexity, wired into the actual CLIP encode path, no quality trade-off. +- Trade-offs: cache size is estimate-based and global, not per-model-session aware. +- Evidence: `src/Utilities/prompt_cache.py`, `src/clip/Clip.py`, cache clear hook in `src/Core/Pipeline.py`. + +### Conditioning batch packing and CFG=1 fast path + +- Status: `integrated` +- Purpose: concatenate compatible conditioning work into fewer forward calls, and skip unconditional work entirely when CFG is effectively disabled. +- Implementation in LightDiffusion-Next: `src/cond/cond.py::calc_cond_batch()` groups compatible condition chunks by shape and memory budget, concatenates them, and falls back per chunk when transformer options mismatch. 
`src/sample/CFG.py` sets `uncond_ = None` when `cond_scale == 1.0` and the optimization is not disabled. +- Project integration: this path is central to the standard sampling flow. The batching logic also validates Flux-style transformer image sizes and falls back when they do not match token grids. +- Effect: fewer model invocations, better GPU utilization, and a lower-cost path for CFG=1 workloads. +- Benefits: real throughput win, memory-aware, includes safety fallback for positional/shape mismatches. +- Trade-offs: batching heuristics are shape- and memory-sensitive; fallback behavior can reduce speed when conditions diverge. +- Evidence: `src/cond/cond.py`, `src/sample/CFG.py`, `src/sample/BaseSampler.py`, `tests/unit/test_calc_cond_batch_fallback.py`. + +## Sampling And Guidance Optimizations + +### AYS scheduler + +- Status: `default` +- Purpose: use precomputed sigma schedules that spend steps where they matter most, so fewer steps can reach comparable quality. +- Implementation in LightDiffusion-Next: schedules are encoded in `src/sample/ays_scheduler.py`; `src/sample/ksampler_util.py` routes `ays`, `ays_sd15`, and `ays_sdxl` to the scheduler and auto-detects model type when possible. +- Project integration: both `server.py` and `src/user/pipeline.py` default the scheduler to `ays`. Exact schedules are used when present; otherwise the code resamples or interpolates schedules. +- Effect: fewer denoising steps for similar output quality, especially on SD1.5 and SDXL. +- Benefits: training-free, defaulted into the request path, compatible with the sampler stack. +- Trade-offs: produces different trajectories than classic schedulers; unsupported step counts use interpolation rather than paper-derived schedules. +- Evidence: `src/sample/ays_scheduler.py`, `src/sample/ksampler_util.py`, defaults in `server.py` and `src/user/pipeline.py`, benchmark usage in `tests/benchmark_performance.py`. 
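As a rough illustration of the resampling fallback described above, the sketch below log-linearly interpolates a precomputed sigma table to an arbitrary step count. The table values, function name, and trailing-zero convention are illustrative assumptions for this report, not the published AYS schedules or the actual code in `src/sample/ays_scheduler.py`.

```python
import math

# Illustrative sigma table in the spirit of an AYS schedule. These values
# are placeholders, NOT the published schedules; the real tables live in
# src/sample/ays_scheduler.py.
BASE_SIGMAS = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152]

def resample_schedule(base, steps):
    """Log-linearly resample a precomputed sigma schedule to `steps` sigmas.

    Mirrors the fallback idea described above: use the exact table when the
    step count matches, otherwise interpolate in log-sigma space.
    """
    if steps == len(base):
        return list(base) + [0.0]
    logs = [math.log(s) for s in base]
    out = []
    for i in range(steps):
        # Map position i of the new schedule onto the base table.
        t = i * (len(base) - 1) / (steps - 1)
        lo = int(math.floor(t))
        hi = min(lo + 1, len(base) - 1)
        frac = t - lo
        out.append(math.exp(logs[lo] * (1.0 - frac) + logs[hi] * frac))
    return out + [0.0]  # samplers conventionally expect a trailing sigma of zero

sigmas = resample_schedule(BASE_SIGMAS, 20)
```

Interpolating in log-sigma space rather than linearly keeps the resampled schedule's step spacing proportional to the original table's, which matters because sigmas span two orders of magnitude across a run.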
+ +### CFG++ samplers + +- Status: `integrated` +- Purpose: apply CFG++-style momentum behavior in sampler variants to improve denoising stability and quality. +- Implementation in LightDiffusion-Next: sampler registry maps `_cfgpp` sampler names to the same sampler classes, and `get_sampler()` enables `use_momentum` whenever the sampler name contains `_cfgpp`. +- Project integration: the sampler loop stores prior denoised state and applies momentum-style correction through `BaseSampler.apply_cfg()`. The server default sampler is `dpmpp_sde_cfgpp`. +- Effect: better denoising behavior than plain sampler variants without a separate post-process stage. +- Benefits: integrated directly into the sampler registry; default sampler already uses it. +- Trade-offs: only applies on `_cfgpp` variants; behavior is coupled to sampler implementation details rather than being a universal guidance layer. +- Evidence: `src/sample/BaseSampler.py`, default sampler in `server.py`. + +### CFG-Free sampling + +- Status: `integrated, optional` +- Purpose: reduce CFG late in the denoising process so the model can finish with less over-guidance. +- Implementation in LightDiffusion-Next: `CFGGuider` stores `cfg_free_enabled` and `cfg_free_start_percent`, tracks current sigma position, and progressively reduces `self.cfg` once the configured progress threshold is crossed. +- Project integration: the flag is part of the request/context surface and is forwarded by SD1.5, SDXL, Flux2, HiResFix, and Img2Img code paths. +- Effect: potentially better detail recovery and more natural late-stage refinement. +- Benefits: integrated and actually wired through multiple pipelines; easy to combine with the rest of the sampler stack. +- Trade-offs: quality optimization rather than pure speedup; exact effect is prompt- and sampler-dependent. 
+- Evidence: `src/sample/CFG.py`, `src/Core/Models/SD15Model.py`, `src/Core/Models/SDXLModel.py`, `src/Core/Models/Flux2KleinModel.py`, `src/Processors/HiresFix.py`, `src/Processors/Img2Img.py`. + +### Dynamic CFG rescaling + +- Status: `integrated, optional` +- Purpose: reduce effective CFG when the guidance delta becomes too strong. +- Implementation in LightDiffusion-Next: `CFGGuider._apply_dynamic_cfg_rescaling()` computes either a variance-based or range-based adjustment and clamps the result. +- Project integration: it runs inside `cfg_function()` before CFG mixing is finalized, so it affects the real denoising path rather than acting as a post-hoc metric. +- Effect: reduces oversaturation and over-guided outputs for high-CFG workloads. +- Benefits: low incremental overhead and direct integration into CFG computation. +- Trade-offs: not a pure speed optimization; the chosen formulas are heuristic and can flatten outputs if pushed too hard. +- Evidence: `src/sample/CFG.py`. + +### Adaptive noise scheduling + +- Status: `integrated, optional` +- Purpose: use observed prediction complexity to perturb the sigma schedule during sampling. +- Implementation in LightDiffusion-Next: `CFGGuider` records complexity history during prediction and scales `sigmas` inside `inner_sample()` if adaptive mode is enabled. +- Project integration: complexity can be estimated with a spatial-difference metric or variance-like behavior, depending on the selected method. +- Effect: attempts to spend effort where the current prediction appears more complex. +- Benefits: implemented end-to-end in the guider. +- Trade-offs: heuristic, can alter reproducibility, and its benefit is much less established in this repo than AYS or request coalescing. +- Evidence: `src/sample/CFG.py`. + +### `batched_cfg` request surface + +- Status: `implementation-specific` +- Purpose: expose control over conditional/unconditional batching. 
+- Implementation in LightDiffusion-Next: the field exists in the request and context models and is passed into sampling, where it is stored in `model_options["batched_cfg"]`. +- Project integration: the main batching behavior is centered in `calc_cond_batch()`, while `batched_cfg` is carried through `model_options` as part of the request-side control surface around that path. +- Effect: provides a request-facing handle for a batching path whose heavy lifting is performed centrally in conditioning packing. +- Benefits: fits cleanly into the existing request and sampling pipeline. +- Trade-offs: its effect is indirect because the main concatenation behavior is implemented deeper in the conditioning layer. +- Evidence: `src/sample/sampling.py`, `src/Core/Context.py`, `src/cond/cond.py`. + +## Multiscale And Architecture-Specific Optimizations + +### Multi-scale latent switching + +- Status: `integrated, optional` +- Purpose: run some denoising steps at a downscaled latent resolution and return to full resolution for selected steps. +- Implementation in LightDiffusion-Next: `MultiscaleManager` in `src/sample/BaseSampler.py` computes a per-step full-resolution schedule and uses bilinear downscale/upscale around sampler model calls. +- Project integration: the samplers consult `ms.use_fullres(i)` each step. Flux and Flux2 are explicitly excluded because the code treats multiscale as incompatible with DiT-style architectures. +- Effect: lower compute on some denoising steps for compatible samplers and architectures. +- Benefits: actually participates in the sampler loop; configurable by factor and schedule. +- Trade-offs: it necessarily changes the denoising path and can trade detail for speed; not available for Flux/Flux2. +- Evidence: `src/sample/BaseSampler.py`, `src/sample/sampling.py`, `src/Core/Models/Flux2KleinModel.py`. 
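The per-step switching idea above can be sketched as a simple resolution plan: some steps run at full latent resolution while the rest run downscaled. The class name, the head/tail policy, and the default parameters here are illustrative assumptions, not the actual `MultiscaleManager` API in `src/sample/BaseSampler.py`.

```python
# Sketch of a per-step resolution plan for multiscale switching.
# The class name, policy, and defaults are illustrative assumptions,
# not the actual MultiscaleManager implementation.

class MultiscaleSchedule:
    def __init__(self, total_steps, scale=0.5, fullres_head=2, fullres_tail=3):
        # Keep the first steps (global structure) and the last steps (fine
        # detail) at full resolution; downscale the middle of the run.
        self.total = total_steps
        self.scale = scale
        self.head = fullres_head
        self.tail = fullres_tail

    def use_fullres(self, i):
        return i < self.head or i >= self.total - self.tail

    def factor(self, i):
        # Spatial scale applied to the latent before the model call; the
        # output would be upscaled back (e.g. bilinearly) afterwards.
        return 1.0 if self.use_fullres(i) else self.scale

ms = MultiscaleSchedule(total_steps=20)
plan = [ms.factor(i) for i in range(20)]
```

Under these assumed defaults, a 20-step run would execute five steps at full resolution and fifteen at half resolution, which is where the compute saving comes from; the real schedule logic is consulted via `ms.use_fullres(i)` inside the sampler loop.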
+ +### HiDiffusion MSW-MSA patching + +- Status: `integrated, optional` +- Purpose: patch UNet attention for high-resolution workflows using HiDiffusion-style MSW-MSA attention changes. +- Implementation in LightDiffusion-Next: the pipeline clones the inner model and applies `ApplyMSWMSAAttentionSimple` when multiscale is enabled on UNet architectures. +- Project integration: the patch is explicitly blocked for Flux/Flux2 and disabled in some sub-pipelines like refiner or certain detail passes where the project wants to avoid artifact risk. +- Effect: makes the multiscale/high-resolution path more efficient or more stable on SD1.5/SDXL-style UNets. +- Benefits: architecture-aware and guarded against obvious misuse. +- Trade-offs: not universal; adds another patching layer and can be brittle if architecture assumptions drift. +- Evidence: `src/Core/Pipeline.py`, `src/hidiffusion/msw_msa_attention.py`, `src/Core/AbstractModel.py`, `src/Core/Models/SD15Model.py`, `src/Core/Models/SDXLModel.py`. + +## Model Compilation, Precision, And Memory Optimizations + +### Stable-Fast + +- Status: `integrated, conditional` +- Purpose: trace and wrap UNet execution to reduce Python overhead and optionally use CUDA graph behavior. +- Implementation in LightDiffusion-Next: `src/StableFast/StableFast.py` builds a lazy trace module around the model function and stores compiled modules in a cache keyed by converted kwargs; `Pipeline._apply_optimizations()` applies it when `stable_fast` is enabled. +- Project integration: only model types that advertise `supports_stable_fast=True` can use it. Flux2 explicitly opts out at the capability layer. +- Effect: faster repeated UNet execution when the optional `sfast` dependency is present and shapes stay compatible enough for compilation reuse. +- Benefits: capability-gated, optional dependency handled defensively, integrated into the core optimization application phase. 
+- Trade-offs: dependency-sensitive, compilation overhead can dominate short runs, CUDA graph behavior is less flexible. +- Evidence: `src/StableFast/StableFast.py`, `src/Core/Pipeline.py`, `src/Core/Models/SD15Model.py`, `src/Core/Models/SDXLModel.py`, `src/Core/Models/Flux2KleinModel.py`. + +### `torch.compile` + +- Status: `integrated, optional` +- Purpose: rely on PyTorch compiler paths instead of Stable-Fast. +- Implementation in LightDiffusion-Next: `src/Device/Device.py::compile_model()` defaults to `max-autotune-no-cudagraphs`; `src/Core/AbstractModel.py::apply_torch_compile()` applies it to the top-level module or diffusion submodule when possible. +- Project integration: the optimization is mutually exclusive with Stable-Fast in the main pipeline. +- Effect: compiler-based speedups with a safer default mode than more fragile CUDA-graph-heavy settings. +- Benefits: built on standard PyTorch, tested for safe default mode. +- Trade-offs: compiler behavior is environment-dependent; still vulnerable to dynamic-shape and dynamic-state limitations. +- Evidence: `src/Device/Device.py`, `src/Core/AbstractModel.py`, `src/Core/Pipeline.py`, `tests/unit/test_fp8_compile.py`. + +### VAE compile, tiled path, and transfer tuning + +- Status: `integrated` +- Purpose: speed up VAE encode/decode, reduce overhead, and avoid OOM by choosing tiled or batched paths. +- Implementation in LightDiffusion-Next: `VariationalAE.VAE` compiles the decoder on first use, runs decode/encode under `torch.inference_mode()`, uses channels-last where useful, chooses tiled fallback when memory is tight, and uses non-blocking transfers. +- Project integration: this is automatic. Callers do not opt in. +- Effect: faster VAE stages, less repeated Python/autograd overhead, and better robustness under constrained memory. +- Benefits: always enabled and directly applied in the decode and encode hot path. 
+- Trade-offs: decoder compile still depends on `torch.compile` availability; tiling adds complexity and can affect throughput at small sizes. +- Evidence: `src/AutoEncoders/VariationalAE.py`. + +### BF16/FP16 automatic dtype selection + +- Status: `integrated, conditional` +- Purpose: pick a lower-precision working dtype that matches the hardware and model constraints. +- Implementation in LightDiffusion-Next: `src/Device/Device.py` contains the dtype selection logic for UNet, text encoder, and VAE devices/dtypes, including bf16 support checks and fallback rules. +- Project integration: loaders and patchers consult these helpers when deciding how to instantiate and place components. +- Effect: reduced memory footprint and better arithmetic throughput on modern hardware. +- Benefits: broad, centralized policy. +- Trade-offs: heuristic; wrong hardware assumptions can reduce numerical stability or disable a faster path. +- Evidence: `src/Device/Device.py`, `src/Model/ModelPatcher.py`, `src/FileManaging/Loader.py`. + +### FP8 weight quantization + +- Status: `integrated, conditional` +- Purpose: store weights in FP8 while casting them back to the input dtype during execution. +- Implementation in LightDiffusion-Next: `AbstractModel.apply_fp8()` hardware-gates support using `Device.is_fp8_supported()`, rewrites eligible weights to FP8, and enables runtime cast behavior on `CastWeightBiasOp` modules. The lower-level `ModelPatcher.weight_only_quantize()` also supports FP8-style quantization. +- Project integration: it is available through generation settings and also used in Flux2 load paths when appropriate. +- Effect: lower model weight memory with an execution path that avoids dtype-mismatch crashes. +- Benefits: tested explicitly, integrates with cast-aware modules, useful for large models. +- Trade-offs: hardware-gated; quality/performance trade-offs depend on model and layer mix. 
+- Evidence: `src/Core/AbstractModel.py`, `src/Device/Device.py`, `src/Model/ModelPatcher.py`, `tests/unit/test_fp8_compile.py`. + +### NVFP4 weight quantization + +- Status: `integrated, optional` +- Purpose: use a more aggressive 4-bit weight-only format to reduce memory further than FP8. +- Implementation in LightDiffusion-Next: both `AbstractModel.apply_nvfp4()` and `ModelPatcher.weight_only_quantize("nvfp4")` quantize supported weights, store scale buffers, and enable runtime casting/dequantization. +- Project integration: the quantization path is used most clearly in Flux2/Klein loading, but the abstract model path also exists for supported models. +- Effect: significant memory reduction at the cost of more aggressive approximation. +- Benefits: strongest memory reduction path in the repo. +- Trade-offs: more invasive than FP8, more likely to affect quality, and only applies to some weight shapes. +- Evidence: `src/Core/AbstractModel.py`, `src/Model/ModelPatcher.py`, `src/Utilities/Quantization.py`, `tests/test_nvfp4.py`, `tests/test_nvfp4_integration.py`. + +### Flux2 load-time weight-only quantization + +- Status: `integrated, conditional` +- Purpose: automatically quantize large Flux2 diffusion and Klein text encoder weights during loading when the configuration or hardware path calls for it. +- Implementation in LightDiffusion-Next: `Flux2KleinModel.load()` selects a quantization format and applies weight-only quantization to the diffusion model; `_load_klein_text_encoder()` applies the same idea to the text encoder before offloading it back to CPU. +- Project integration: Flux2 is the clearest example in the codebase where quantization is implemented as a first-class loading strategy rather than as a generic capability alone. +- Effect: keeps a large Flux2/Klein stack usable on lower-VRAM systems than an uncompressed load would allow. +- Benefits: integrated, architecture-specific, and directly aligned with large-model VRAM constraints. 
+- Trade-offs: tightly coupled to Flux2/Klein assumptions; not equivalent to a universally available quantized-mode toggle. +- Evidence: `src/Core/Models/Flux2KleinModel.py`. + +### ToMe + +- Status: `integrated, optional` +- Purpose: merge similar tokens to reduce attention workload in UNet-based models. +- Implementation in LightDiffusion-Next: `ModelPatcher.apply_tome()` applies and removes `tomesd` patches; `Pipeline._apply_optimizations()` applies it only when the model capabilities allow it. +- Project integration: SD1.5 and SDXL advertise `supports_tome=True`; Flux2 advertises `False`. +- Effect: lower attention cost on supported UNet models, particularly at higher token counts. +- Benefits: explicitly capability-gated, integrated into the core optimization phase. +- Trade-offs: optional dependency, UNet-only in current practice, and quality can soften if pushed too aggressively. +- Evidence: `src/Model/ModelPatcher.py`, `src/Core/Pipeline.py`, capability declarations in `src/Core/Models/*`, `tests/unit/test_tome_fix.py`. + +### DeepCache + +- Status: `integrated, optional, implementation-specific` +- Purpose: reuse work across denoising steps rather than running a full forward pass every time. +- Implementation in LightDiffusion-Next: `ApplyDeepCacheOnModel.patch()` clones the model and wraps its UNet function. On cache-update steps it runs the model normally and stores the output; on reuse steps it returns the cached output directly. +- Project integration: the main pipeline applies it from `_apply_optimizations()` when `deepcache_enabled` is true and the model advertises support. +- Effect: fewer full model computations on reuse steps, trading some fidelity for speed. +- Benefits: live integrated path, simple integration model, and capability gating. +- Trade-offs: the implementation works at whole-output reuse granularity rather than a finer-grained internal block reuse strategy, so its speed/fidelity profile is comparatively coarse. 
+- Evidence: `src/WaveSpeed/deepcache_nodes.py`, `src/Core/Pipeline.py`, `src/Core/AbstractModel.py`, `src/Core/Models/SD15Model.py`, `src/Core/Models/SDXLModel.py`, `tests/test_core_functionalities.py`. + +### First Block Cache for Flux + +- Status: `codebase groundwork` +- Purpose: cache downstream transformer work when the first-block residual indicates the state has not changed much. +- Implementation in LightDiffusion-Next: `src/WaveSpeed/first_block_cache.py` contains cache contexts and patch builders for both UNet-like and Flux-like forward paths. +- Project integration: the module provides the machinery for a Flux-oriented first-block caching path. In the current project flow, the directly surfaced caching path is DeepCache, while this module remains groundwork for a more specialized integration. +- Effect: establishes the components needed for a transformer-oriented cache path in the codebase. +- Benefits: nontrivial implementation foundation already exists. +- Trade-offs: it is not yet surfaced as a broad standard option in the same way as the main integrated optimizations. +- Evidence: `src/WaveSpeed/first_block_cache.py`. + +## Memory Management And Serving Optimizations + +### Low-VRAM partial loading and offload policy + +- Status: `integrated` +- Purpose: keep only the amount of model state in VRAM that current free memory allows, offloading the rest. +- Implementation in LightDiffusion-Next: `cond_util.prepare_sampling()` calls `Device.load_models_gpu(..., force_full_load=False)`; `Device.load_models_gpu()` computes low-VRAM budgets and delegates partial loading to `ModelPatcher.patch_model_lowvram()` and `partially_load()`. +- Project integration: this is a core loading behavior, not a side option. Text encoder and VAE also have explicit offload-device helpers. +- Effect: keeps generation viable on limited VRAM systems and reduces full reload pressure. 
+- Benefits: central to memory behavior in constrained environments, architecture-aware, and tied into checkpoint, text encoder, and VAE device policy. +- Trade-offs: more complex state management; partial loading can increase latency and complicate debugging. +- Evidence: `src/cond/cond_util.py`, `src/Device/Device.py`, `src/Model/ModelPatcher.py`. + +### Async transfer helpers and pinned checkpoint tensors + +- Status: `integrated, infrastructure-level` +- Purpose: reduce CPU<->GPU transfer cost with asynchronous copies, streams, and pinned host memory. +- Implementation in LightDiffusion-Next: `Device.cast_to()` can issue transfers on offload streams; checkpoint tensors are pinned on CUDA loads in `util.load_torch_file()`; VAE encode/decode uses non-blocking transfers. +- Project integration: these mechanisms appear most clearly in checkpoint loading, model movement, and VAE data flow. Some parts act as general transfer infrastructure rather than as a single user-facing optimization toggle. +- Effect: faster host/device movement and less transfer-induced stalling in hot paths that actually use the helpers. +- Benefits: useful on CUDA systems, especially during model load and VAE stages. +- Trade-offs: integration is uneven; some helper functions look broader than their current call footprint. +- Evidence: `src/Device/Device.py`, `src/Utilities/util.py`, `src/AutoEncoders/VariationalAE.py`. + +### Request coalescing and queue batching + +- Status: `integrated` +- Purpose: batch compatible API requests together so the backend does fewer larger pipeline invocations. +- Implementation in LightDiffusion-Next: `server.py::GenerationBuffer` groups pending requests by a signature that includes model, size, scheduler, sampler, steps, multiscale settings, and other batch-level properties. 
+- Project integration: the worker chooses the oldest eligible group, optionally waits for more arrivals, flattens per-request samples into one pipeline call, and later remaps saved results back to request futures. +- Effect: better throughput and GPU utilization for concurrent API use. +- Benefits: real server-level optimization, clearly implemented, includes observability-oriented logs. +- Trade-offs: requires careful grouping keys; incompatible request options fragment batching opportunities. +- Evidence: `server.py`. + +### Singleton policy, large-group chunking, and image-save guardrails + +- Status: `integrated` +- Purpose: prevent batching from hurting latency for lone requests, and prevent oversized coalesced batches from exploding decode/save paths. +- Implementation in LightDiffusion-Next: `LD_BATCH_WAIT_SINGLETONS` controls whether singletons wait; `LD_MAX_IMAGES_PER_GROUP` and `ImageSaver.MAX_IMAGES_PER_SAVE` drive chunking; large groups are split into smaller sequential pipeline runs. +- Project integration: the server keeps the coalescing optimization from turning into pathological giant save/decode operations, and tests cover the chunking behavior. +- Effect: better tail latency for single requests and more stable handling of large batched workloads. +- Benefits: directly addresses operational failure modes in large batched workloads. +- Trade-offs: chunking reduces some batching benefits; many environment variables affect behavior. +- Evidence: `server.py`, `src/FileManaging/ImageSaver.py`, `tests/unit/test_generation_buffer_chunking.py`, `docs/quirks.md`. + +### Next-model prefetch + +- Status: `integrated` +- Purpose: while one batch is running, read the next checkpoint into CPU RAM if the queued next batch needs a different model. 
+- Implementation in LightDiffusion-Next: `GenerationBuffer._look_ahead_and_prefetch()` resolves the next checkpoint, loads it via `util.load_torch_file()` on a background task, and stores it in `ModelCache` as a prefetched state dict. +- Project integration: the next load can reuse the prefetched state dict through `util.load_torch_file()` before the cache entry is cleared. +- Effect: overlaps some future checkpoint load cost with current generation work. +- Benefits: server-side latency hiding with minimal interface impact. +- Trade-offs: only helps when queued work is predictable; increases CPU RAM usage. +- Evidence: `server.py`, `src/Device/ModelCache.py`, `src/Utilities/util.py`. + +### Keep-models-loaded cache + +- Status: `integrated` +- Purpose: keep recently used checkpoints and sampling models resident instead of cleaning them up after every request. +- Implementation in LightDiffusion-Next: `ModelCache` stores checkpoints, TAESD models, sampling models, and the keep-loaded policy; `server.py` temporarily applies the request's `keep_models_loaded` directive for a group. +- Project integration: when enabled, main models are retained and only auxiliary control models are cleaned up aggressively. +- Effect: lower warm-start cost between related generations and less repetitive reload churn. +- Benefits: simple end-user behavior for a meaningful latency/memory trade-off. +- Trade-offs: consumes more VRAM/RAM; can make memory pressure less predictable on multi-user servers. +- Evidence: `src/Device/ModelCache.py`, `server.py`. + +### In-memory PNG byte buffer + +- Status: `integrated` +- Purpose: return API images from memory instead of reading them back from disk after save. +- Implementation in LightDiffusion-Next: `ImageSaver` can store encoded PNG bytes in `_image_bytes_buffer`; `server.py` first calls `pop_image_bytes()` when fulfilling request futures. 
+- Project integration: batched pipeline runs can still save images normally while the API path avoids a disk round-trip for the response payload. +- Effect: lower response latency and less unnecessary disk I/O for served images. +- Benefits: directly reduces response-path disk I/O in API-serving scenarios. +- Trade-offs: consumes temporary RAM; only helps when the buffer path is actually populated. +- Evidence: `src/FileManaging/ImageSaver.py`, `server.py`. + +### TAESD preview pacing and preview fidelity control + +- Status: `integrated, conditional` +- Purpose: keep live previews useful without letting preview generation dominate sampling time. +- Implementation in LightDiffusion-Next: `SamplerCallback` caches preview settings, only triggers previews at a coarse interval, and runs preview work on a background thread; the server also applies per-request preview fidelity presets (`low`, `balanced`, `high`). +- Project integration: previews are generated only when previewing is enabled, and the preview cadence is adaptive to total step count. +- Effect: live feedback with bounded preview overhead. +- Benefits: explicit pacing, non-blocking thread model, request-level fidelity override. +- Trade-offs: still extra work during sampling; fidelity presets are intentionally coarse. +- Evidence: `src/sample/BaseSampler.py`, `src/AutoEncoders/taesd.py`, `server.py`, preview tests under `tests/e2e` and `tests/integration/api`. + +## Integration Notes + +These notes highlight how several optimizations are currently integrated and used inside the project. + +### 1. Flux-oriented first block caching + +- The codebase contains a dedicated `src/WaveSpeed/first_block_cache.py` module with cache contexts and patch builders for Flux-oriented paths. +- In the current optimization stack, the directly surfaced caching path is DeepCache, while First Block Cache remains implementation groundwork for a more specialized integration. 
+- This establishes the core components for a transformer-oriented cache path even though it is not yet surfaced as a primary standard option.
+
+### 2. DeepCache reuse granularity
+
+- DeepCache is integrated through `src/WaveSpeed/deepcache_nodes.py` and is applied from the main pipeline when enabled.
+- In this project, it works by reusing prior denoiser outputs on designated reuse steps.
+- This yields a clear speed/fidelity profile based on output reuse rather than on finer-grained internal block caching.
+
+### 3. Conditioning batching control
+
+- Conditioning batching is centered in `src/cond/cond.py::calc_cond_batch()`, where compatible condition chunks are packed and concatenated.
+- The `batched_cfg` request field participates as request-side control metadata around this behavior.
+- In operation, the batching outcome is therefore shaped mainly by the central conditioning logic rather than by a standalone external switch.
+
+### 4. GPU attention backend selection
+
+- Attention backend selection is hardware- and build-aware, with the runtime choosing among SpargeAttn, SageAttention, xformers, and PyTorch SDPA based on capability checks.
+- The exact backend used in practice therefore depends on the active GPU generation, dependencies, and runtime configuration.
+- From the user perspective, backend acceleration is largely automatic, while remaining environment-specific in implementation.
+
+### 5. Prompt cache behavior
+
+- Prompt caching is implemented as a global dict-backed cache keyed by prompt hash and CLIP identity.
+- The cache prunes old entries once it exceeds its configured size threshold.
+- In operation, it primarily benefits repeated-prompt workflows such as seed sweeps and prompt iteration.
+
+## Conclusion
+
+LightDiffusion-Next uses a layered optimization strategy spanning runtime kernels, scheduling, guidance logic, precision and memory control, model patching, and server-side throughput management.
+ +- The core operational stack is built around AYS scheduling, attention backend selection, conditioning batching, low-VRAM loading policy, prompt caching, VAE tuning, and request coalescing. +- Optional paths such as Stable-Fast, `torch.compile`, ToMe, DeepCache, multiscale sampling, and quantization extend that stack for specific hardware targets, model families, and workload profiles. +- The serving layer is a first-class component of the performance model, with batching, chunking, prefetching, keep-loaded caches, and in-memory responses contributing directly to end-to-end latency and throughput. diff --git a/docs/optimizations.md b/docs/optimizations.md index 7cf8044..c1f704e 100644 --- a/docs/optimizations.md +++ b/docs/optimizations.md @@ -2,6 +2,8 @@ LightDiffusion-Next achieves its industry-leading inference speed through a layered stack of training-free optimizations that can be selectively enabled based on your hardware and quality requirements. This page provides an overview of each acceleration technique and links to detailed guides. +For a detailed source-based report on what is implemented today, including server-side throughput optimizations and practical implementation notes, see the [Implemented Optimizations Report](implemented-optimizations-report.md). + ## Optimization Stack Overview The pipeline orchestrates six primary acceleration paths: @@ -113,10 +115,10 @@ Multi-Scale Diffusion optimizes performance by processing images at multiple res ### WaveSpeed Caching -**What it does:** Exploits temporal redundancy in diffusion processes by caching high-level features in the UNet/Transformer architecture and reusing them across multiple denoising steps. Includes two strategies: +**What it does:** Exploits temporal redundancy in diffusion processes by reusing work across denoising steps. In the current project stack this primarily means DeepCache on supported UNet models, with additional Flux-oriented cache groundwork present in the codebase. -1. 
**DeepCache** — Caches middle/output block activations in UNet models (SD1.5, SDXL) -2. **First Block Cache (FBCache)** — Caches initial Transformer block outputs in Flux models +1. **DeepCache** — Reuses prior denoiser outputs on selected steps in UNet models (SD1.5, SDXL) +2. **First Block Cache (FBCache)** — Flux-oriented cache machinery available for specialized integration work **When to use:** - Any workflow where you can tolerate slight smoothing in exchange for 2-3x speedup @@ -177,9 +179,9 @@ steps: 10 # Reduced from 15 (same quality with AYS) stable_fast: false # not supported sageattention: auto prompt_cache_enabled: true -fbcache: +deepcache: enabled: true - residual_threshold: 0.01 # strict caching + interval: 2 ``` **Expected:** ~2x speedup with minimal quality impact diff --git a/docs/Prompt-caching.md b/docs/prompt-caching.md similarity index 73% rename from docs/Prompt-caching.md rename to docs/prompt-caching.md index d5f1c4a..e29e34b 100644 --- a/docs/Prompt-caching.md +++ b/docs/prompt-caching.md @@ -1,4 +1,4 @@ -## 1. Prompt Attention Caching +# Prompt Attention Caching ### What It Does @@ -29,10 +29,11 @@ print(f"Hit rate: {stats['hit_rate']:.1%}") ``` **Cache Settings**: -- Maximum entries: 128 prompts -- Memory usage: ~50-200MB -- Cache cleared on: restart or manual clear -- Automatic pruning: removes oldest 25% when full +- Maximum entries: 256 prompts before pruning +- Cache structure: global dict keyed by prompt hash and CLIP identity +- Memory usage: workload-dependent, estimated from cached embedding tensors +- Cache cleared on: restart, disable, or manual clear +- Automatic pruning: removes the oldest 25% of entries when the cache exceeds its limit ### Viewing Cache Stats @@ -60,3 +61,4 @@ prompt_cache.print_cache_stats() 2. **Monitor hit rate** - should be >50% in typical workflows 3. **Clear cache** when switching models or major prompt changes 4. **Batch similar prompts** to maximize cache hits +5. 
**Expect global behavior** because the cache is shared across repeated prompt encodes rather than being scoped to a single generation session diff --git a/docs/sageattention.md b/docs/sageattention.md index eff8e07..0679939 100644 --- a/docs/sageattention.md +++ b/docs/sageattention.md @@ -94,7 +94,7 @@ python -c "import spas_sage_attn; print('SpargeAttn installed successfully')" | RTX 4060/4070/4080/4090 | 8.9 | `"8.9"` | | A100 | 8.0 | `"8.0"` | | H100 | 9.0 | `"9.0"` | -| RTX 5060/5070/5080/5090 | 12.0 | Not supported yet | +| RTX 5060/5070/5080/5090 | 12.0 | SageAttention supported, SpargeAttn pending | ### Docker Installation diff --git a/docs/wavespeed.md b/docs/wavespeed.md index c2519c7..124e84a 100644 --- a/docs/wavespeed.md +++ b/docs/wavespeed.md @@ -2,14 +2,14 @@ ## Overview -WaveSpeed is a collection of **feature caching strategies** that exploit temporal redundancy in diffusion processes. By reusing high-level features across multiple denoising steps, WaveSpeed can provide significant speedup with tunable quality trade-offs. +WaveSpeed is the project's caching-oriented optimization layer for reusing work across denoising steps. In the current codebase, the integrated path is DeepCache for UNet-based models, and the repository also contains groundwork for a Flux-oriented First Block Cache path. -LightDiffusion-Next implements two WaveSpeed variants: +LightDiffusion-Next contains two WaveSpeed-related implementations: -1. **DeepCache** — For UNet-based models (SD1.5, SDXL) -2. **First Block Cache (FBCache)** — For Transformer-based models (Flux) +1. **DeepCache** — Integrated for UNet-based models (SD1.5, SDXL) +2. **First Block Cache (FBCache)** — Flux-oriented cache machinery present in the codebase -Both are **training-free**, work alongside other optimizations and can be toggled per-generation. +Both are training-free. DeepCache is the user-facing path today; First Block Cache is codebase groundwork for a more specialized transformer caching path. 
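The DeepCache-style reuse pattern — run the denoiser on cache steps, hand back the cached output in between — can be sketched in a few lines. This is a standalone illustration, not the project's `ApplyDeepCacheOnModel` API; `sample_with_interval_reuse` and `toy_denoise` are invented names for the sketch:

```python
def sample_with_interval_reuse(denoise, latents, timesteps, interval=3):
    # On cache steps (every `interval` steps) run the real denoiser and
    # store its output; on the steps in between, reuse the cached output.
    cached = None
    outputs = []
    for step, t in enumerate(timesteps):
        if step % interval == 0:
            cached = denoise(latents, t)  # full forward pass (cache step)
        outputs.append(cached)            # reuse step: no recomputation
    return outputs


# Toy stand-in denoiser that counts how many real forward passes happen.
calls = {"n": 0}

def toy_denoise(x, t):
    calls["n"] += 1
    return (x, t)

# 9 steps with interval=3 -> only 3 real forward passes
outs = sample_with_interval_reuse(toy_denoise, 0.0, range(9), interval=3)
```

The speed/fidelity trade-off is visible in the sketch: every reuse step returns a stale output, which is why the interval is kept small in practice.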
## How It Works @@ -20,36 +20,25 @@ Diffusion models denoise images iteratively over 20-50 steps. Researchers observ - **High-level features** (semantic structure, composition) change slowly across steps - **Low-level features** (fine details, textures) require frequent updates -WaveSpeed caches the expensive high-level computations and reuses them for several steps, only updating low-level details cheaply. +WaveSpeed aims to reduce repeated computation across nearby denoising steps by reusing information from earlier steps where practical. ### DeepCache (UNet Models) {#deepcache} -DeepCache targets the middle and output blocks of the UNet architecture: - -``` -┌─────────────────────────────────────────┐ -│ Input Blocks (always computed) │ -├─────────────────────────────────────────┤ -│ Middle Blocks (cached every N steps) │ ← DeepCache caching zone -├─────────────────────────────────────────┤ -│ Output Blocks (cached every N steps) │ ← DeepCache caching zone -└─────────────────────────────────────────┘ -``` +DeepCache is the integrated WaveSpeed path for UNet models. **Cache step (every N steps):** -1. Run full forward pass through all UNet blocks -2. Store middle/output block activations in cache +1. Run the full denoiser path +2. Store the output for later reuse -**Reuse step (N-1 intermediate steps):** -1. Run only input blocks -2. Retrieve cached middle/output activations -3. Skip expensive middle/output block computation +**Reuse step (intermediate steps):** +1. Reuse the cached denoiser output +2. Skip the full model recomputation for that step **Speedup:** ~50-70% time saved per reuse step → 2-3x total speedup with `interval=3` ### First Block Cache (Flux Models) -Flux uses Transformer blocks instead of UNet convolutions. FBCache applies a similar principle: +Flux uses Transformer blocks instead of UNet convolutions. 
The repository includes a First Block Cache implementation for this architecture family: ``` ┌─────────────────────────────────────────┐ @@ -65,7 +54,7 @@ Flux uses Transformer blocks instead of UNet convolutions. FBCache applies a sim 3. If difference < threshold: reuse cached remaining blocks 4. If difference ≥ threshold: run all blocks and update cache -**Adaptive caching:** Automatically decides when to cache vs. recompute based on feature similarity. +In the current project structure, this cache path is implementation groundwork rather than a standard generation toggle like DeepCache. ## DeepCache Configuration @@ -160,7 +149,7 @@ end_step: 800 ### Usage -FBCache is applied automatically when generating Flux images. No UI controls yet — configured via pipeline code: +First Block Cache is not currently exposed as a standard per-generation toggle. The implementation is available in the codebase for specialized integration work: ```python # In src/user/pipeline.py @@ -169,7 +158,7 @@ from src.WaveSpeed import fbcache_nodes # Create cache context cache_context = fbcache_nodes.create_cache_context() -# Apply caching to Flux model +# Apply caching to a Flux-style model with fbcache_nodes.cache_context(cache_context): patched_model = fbcache_nodes.create_patch_flux_forward_orig( flux_model, @@ -196,7 +185,7 @@ Speedup scales with cache interval and depth: | SD1.5 | 3 | Good speedup, slight quality loss | | SD1.5 | 5 | High speedup, noticeable quality loss | | SDXL | 3 | Good speedup, slight quality loss | -| Flux (FBCache) | auto | Moderate speedup, minimal quality loss | +| Flux-style caching paths | implementation-specific | Depends on the integration path | **Performance varies based on:** - GPU architecture diff --git a/mkdocs.yml b/mkdocs.yml index e2980c2..3b570f6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -27,6 +27,7 @@ nav: - REST API: api.md - Performance Optimizations: - Overview: optimizations.md + - Implementation Report: 
implemented-optimizations-report.md - CFG-Free Sampling: cfg-free-sampling.md - Token Merging (ToMe): tome.md - SageAttention & SpargeAttn: sageattention.md diff --git a/src/Device/Device.py b/src/Device/Device.py index 40b7527..69f6c7e 100644 --- a/src/Device/Device.py +++ b/src/Device/Device.py @@ -853,12 +853,6 @@ def get_autocast_device(dev) -> str: def sageattention_enabled() -> bool: if cpu_state != CPUState.GPU or is_intel_xpu() or directml_enabled or is_rocm(): return False - if torch.cuda.is_available(): - try: - if torch.cuda.get_device_capability()[0] >= 12: - return False - except: - pass return SAGEATTENTION_IS_AVAILABLE diff --git a/src/Model/ModelPatcher.py b/src/Model/ModelPatcher.py index cc52a12..257d237 100644 --- a/src/Model/ModelPatcher.py +++ b/src/Model/ModelPatcher.py @@ -47,6 +47,56 @@ def __call__(self, weight: torch.Tensor) -> torch.Tensor: return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key) +class ModelFunctionWrapperChain: + """Compose multiple model_function_wrapper hooks without overwriting them. + + Several optimizations patch the same U-Net wrapper hook. Keeping only the + last wrapper silently disables earlier optimizations. This chain preserves + application order by making the most recently-added wrapper the outermost + wrapper around the existing stack. 
+ """ + + def __init__(self, wrappers=None): + self.wrappers = list(wrappers or []) + + def add_outer(self, wrapper): + self.wrappers.insert(0, wrapper) + return self + + def __call__(self, model_function, params): + return self._invoke(0, model_function, params) + + def _invoke(self, index, model_function, params): + if index >= len(self.wrappers): + return model_function( + params["input"], + params["timestep"], + **params.get("c", {}), + ) + + wrapper = self.wrappers[index] + + def next_model_function(input_x, timestep, **c_kwargs): + next_params = dict(params) + next_params["input"] = input_x + next_params["timestep"] = timestep + next_params["c"] = c_kwargs + return self._invoke(index + 1, model_function, next_params) + + return wrapper(next_model_function, params) + + def to(self, device): + updated = [] + for wrapper in self.wrappers: + if hasattr(wrapper, "to"): + moved = wrapper.to(device) + updated.append(moved if moved is not None else wrapper) + else: + updated.append(wrapper) + self.wrappers = updated + return self + + class ModelPatcher: def __init__(self, model: torch.nn.Module, load_device: torch.device, offload_device: torch.device, size: int = 0, current_device: torch.device = None, weight_inplace_update: bool = False): @@ -91,7 +141,17 @@ def memory_required(self, input_shape): return self.model.memory_required(input_shape=input_shape) def set_model_unet_function_wrapper(self, f): - self.model_options["model_function_wrapper"] = f + existing = self.model_options.get("model_function_wrapper") + if existing is None: + self.model_options["model_function_wrapper"] = f + return + + if isinstance(existing, ModelFunctionWrapperChain): + existing.add_outer(f) + self.model_options["model_function_wrapper"] = existing + return + + self.model_options["model_function_wrapper"] = ModelFunctionWrapperChain([f, existing]) def set_model_denoise_mask_function(self, f): self.model_options["denoise_mask_function"] = f diff --git a/src/cond/cond.py 
b/src/cond/cond.py index 577e344..d97506e 100644 --- a/src/cond/cond.py +++ b/src/cond/cond.py @@ -119,6 +119,7 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options) -> list: out_conds = [torch.zeros_like(x_in) for _ in range(len(conds))] out_counts = [torch.ones_like(x_in) * 1e-37 for _ in range(len(conds))] to_run = [] + batched_cfg = model_options.get("batched_cfg", True) for i, cond in enumerate(conds): if cond is not None: @@ -130,9 +131,15 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options) -> list: while to_run: first = to_run[0] first_shape = first[0][0].shape + first_cond_index = first[1] # Find compatible conditions - to_batch_temp = [x for x in range(len(to_run)) if cond_util.can_concat_cond(to_run[x][0], first[0])] + to_batch_temp = [ + x + for x in range(len(to_run)) + if cond_util.can_concat_cond(to_run[x][0], first[0]) + and (batched_cfg or to_run[x][1] == first_cond_index) + ] to_batch_temp.reverse() to_batch = to_batch_temp[:1] diff --git a/tests/unit/test_calc_cond_batch_fallback.py b/tests/unit/test_calc_cond_batch_fallback.py index 1ae97f5..6f00faa 100644 --- a/tests/unit/test_calc_cond_batch_fallback.py +++ b/tests/unit/test_calc_cond_batch_fallback.py @@ -16,6 +16,16 @@ def apply_model(self, *args, **kwargs): return inp +class RecordingDummyModel(DummyModel): + def __init__(self): + self.batch_sizes = [] + + def apply_model(self, *args, **kwargs): + inp = args[0] if args else kwargs.get("input") + self.batch_sizes.append(int(inp.shape[0])) + return inp + + def test_calc_cond_batch_fallback_on_transformer_options_mismatch(monkeypatch): called = {"flag": False} @@ -45,3 +55,30 @@ def spy_run_model_per_chunk(model, x_in, timestep, input_x_list, c_list, batch_s assert isinstance(out, list) and len(out) == 2 assert out[0].shape == x_in.shape assert out[1].shape == x_in.shape + + +def test_calc_cond_batch_honors_batched_cfg_toggle(): + x_in = torch.zeros((1, 4, 8, 8)) + cond_dict = {"model_conds": {"c_crossattn": 
CONDRegular(torch.zeros((1, 1, 1, 1)))}} + conds = [[cond_dict], [cond_dict]] + + batched_model = RecordingDummyModel() + calc_cond_batch( + batched_model, + conds, + x_in, + timestep=0, + model_options={"batched_cfg": True}, + ) + + unbatched_model = RecordingDummyModel() + calc_cond_batch( + unbatched_model, + conds, + x_in, + timestep=0, + model_options={"batched_cfg": False}, + ) + + assert batched_model.batch_sizes == [2] + assert unbatched_model.batch_sizes == [1, 1] diff --git a/tests/unit/test_optimization_plumbing.py b/tests/unit/test_optimization_plumbing.py new file mode 100644 index 0000000..95fc7e7 --- /dev/null +++ b/tests/unit/test_optimization_plumbing.py @@ -0,0 +1,75 @@ +import torch + +from src.Device import Device +from src.Model.ModelPatcher import ModelPatcher + + +class DummyDiffusionModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def memory_required(self, input_shape=None): + return 1 + + +def test_model_function_wrappers_compose_in_application_order(): + patcher = ModelPatcher( + DummyDiffusionModel(), + load_device=torch.device("cpu"), + offload_device=torch.device("cpu"), + ) + call_order = [] + + def wrapper_one(model_function, params): + call_order.append("wrapper_one_before") + out = model_function(params["input"], params["timestep"], **params["c"]) + call_order.append("wrapper_one_after") + return out + 1 + + def wrapper_two(model_function, params): + call_order.append("wrapper_two_before") + out = model_function(params["input"], params["timestep"], **params["c"]) + call_order.append("wrapper_two_after") + return out * 2 + + patcher.set_model_unet_function_wrapper(wrapper_one) + patcher.set_model_unet_function_wrapper(wrapper_two) + + wrapped = patcher.model_options["model_function_wrapper"] + + def base_model_function(input_x, timestep, **c_kwargs): + call_order.append("base") + return input_x + c_kwargs["bias"] + + result = wrapped( + base_model_function, + { + "input": 
torch.tensor([1.0]), + "timestep": torch.tensor([0.0]), + "c": {"bias": torch.tensor([3.0])}, + }, + ) + + assert torch.equal(result, torch.tensor([10.0])) + assert call_order == [ + "wrapper_two_before", + "wrapper_one_before", + "base", + "wrapper_one_after", + "wrapper_two_after", + ] + + +def test_sageattention_enabled_allows_compute_12_when_available(monkeypatch): + monkeypatch.setattr(Device, "cpu_state", Device.CPUState.GPU) + monkeypatch.setattr(Device, "directml_enabled", False) + monkeypatch.setattr(Device, "SAGEATTENTION_IS_AVAILABLE", True) + monkeypatch.setattr(Device, "SPARGEATTN_IS_AVAILABLE", True) + monkeypatch.setattr(Device, "is_intel_xpu", lambda: False) + monkeypatch.setattr(Device, "is_rocm", lambda: False) + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda *args, **kwargs: (12, 0)) + + assert Device.sageattention_enabled() is True + assert Device.spargeattn_enabled() is False