Merged
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -27,6 +27,8 @@ That's when the first version of LightDiffusion was born, which only counted [300

📚 Learn more in the [official documentation](https://aatricks.github.io/LightDiffusion-Next/)

For a source-based breakdown of the optimization stack, see the [Implemented Optimizations Report](https://aatricks.github.io/LightDiffusion-Next/implemented-optimizations-report/).

---

## 🌟 Highlights
@@ -181,7 +183,7 @@ docker-compose build \
Set `INSTALL_STABLE_FAST=1` to enable the compilation step for stable-fast, or `INSTALL_OLLAMA=1` to bake in the prompt enhancer runtime.

> [!NOTE]
> RTX 50 series (compute 12.0) GPUs currently only support SageAttention.
> RTX 50 series (compute 12.0) GPUs currently use SageAttention when the SageAttention kernel is installed. SpargeAttn remains limited to earlier supported architectures.

**Access the Web Interface:**
- **Streamlit UI** (default): `http://localhost:8501`
10 changes: 6 additions & 4 deletions docs/advanced-cfg-optimizations.md
@@ -12,7 +12,7 @@ This document describes three advanced optimizations for Classifier-Free Guidanc

### What It Does

Instead of running two separate forward passes for conditional and unconditional predictions, this optimization combines them into a single batched forward pass.
Instead of running two separate forward passes for conditional and unconditional predictions, this optimization can combine them into a single batched forward pass.

**Before:**
```python
@@ -47,13 +47,15 @@ samples = sampling.sample1(
steps=20,
cfg=7.5,
# ... other params ...
batched_cfg=True, # Enable batched CFG (default: True)
batched_cfg=True, # Joint cond/uncond batching (default: True)
)
```

In the current implementation, the heavy lifting still happens in the central conditioning packing path. `batched_cfg` controls whether conditional and unconditional branches are packed together into the same forward pass when possible. Conditioning chunks within each branch are still packed by the shared batching logic.
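As an illustrative sketch only (not the project's actual packing code), the difference between the two modes can be pictured as follows; `denoise` and the tensor shapes are hypothetical stand-ins for the model call:

```python
import torch

def denoise(model, x, t, context):
    # Stand-in for a single UNet forward pass.
    return model(x, t, context)

def cfg_separate(model, x, t, cond, uncond, cfg_scale):
    # Two forward passes: one per guidance branch.
    eps_cond = denoise(model, x, t, cond)
    eps_uncond = denoise(model, x, t, uncond)
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)

def cfg_batched(model, x, t, cond, uncond, cfg_scale):
    # One forward pass over a doubled batch, then split the result.
    x2 = torch.cat([x, x], dim=0)
    t2 = torch.cat([t, t], dim=0)
    ctx2 = torch.cat([cond, uncond], dim=0)
    eps = denoise(model, x2, t2, ctx2)
    eps_cond, eps_uncond = eps.chunk(2, dim=0)
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)
```

Both paths produce the same guided prediction; the batched variant simply amortizes kernel launches and weight reads across the doubled batch, at the cost of roughly double the activation memory.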

### When to Use

- **Always recommended** - This is a pure speed optimization with no quality tradeoff
- **Usually recommended** - This reduces duplicate cond/uncond forward passes when memory allows
- Particularly beneficial for high-resolution images or batch generation
- Compatible with all samplers and schedulers

@@ -232,7 +234,7 @@ samples = sampling.sample1(
### Batched CFG Issues

**Problem**: Memory errors with batched CFG
**Solution**: System may not have enough VRAM. Disable with `batched_cfg=False`
**Solution**: System may not have enough VRAM for joint cond/uncond batching. Disable it with `batched_cfg=False`, which keeps the conditioning path active but runs the two branches separately.

### Dynamic CFG Issues

484 changes: 484 additions & 0 deletions docs/implemented-optimizations-report.md

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions docs/optimizations.md
@@ -2,6 +2,8 @@

LightDiffusion-Next achieves its industry-leading inference speed through a layered stack of training-free optimizations that can be selectively enabled based on your hardware and quality requirements. This page provides an overview of each acceleration technique and links to detailed guides.

For a detailed source-based report on what is implemented today, including server-side throughput optimizations and practical implementation notes, see the [Implemented Optimizations Report](implemented-optimizations-report.md).

## Optimization Stack Overview

The pipeline orchestrates six primary acceleration paths:
@@ -113,10 +115,10 @@ Multi-Scale Diffusion optimizes performance by processing images at multiple res

### WaveSpeed Caching

**What it does:** Exploits temporal redundancy in diffusion processes by caching high-level features in the UNet/Transformer architecture and reusing them across multiple denoising steps. Includes two strategies:
**What it does:** Exploits temporal redundancy in diffusion processes by reusing work across denoising steps. In the current project stack this primarily means DeepCache on supported UNet models, with additional Flux-oriented cache groundwork present in the codebase.

1. **DeepCache** — Caches middle/output block activations in UNet models (SD1.5, SDXL)
2. **First Block Cache (FBCache)** — Caches initial Transformer block outputs in Flux models
1. **DeepCache** — Reuses prior denoiser outputs on selected steps in UNet models (SD1.5, SDXL)
2. **First Block Cache (FBCache)** — Flux-oriented cache machinery available for specialized integration work

**When to use:**
- Any workflow where you can tolerate slight smoothing in exchange for 2-3x speedup
@@ -177,9 +179,9 @@ steps: 10 # Reduced from 15 (same quality with AYS)
stable_fast: false # not supported
sageattention: auto
prompt_cache_enabled: true
fbcache:
deepcache:
enabled: true
residual_threshold: 0.01 # strict caching
interval: 2
```
**Expected:** ~2x speedup with minimal quality impact

12 changes: 7 additions & 5 deletions docs/Prompt-caching.md → docs/prompt-caching.md
@@ -1,4 +1,4 @@
## 1. Prompt Attention Caching
# Prompt Attention Caching

### What It Does

@@ -29,10 +29,11 @@ print(f"Hit rate: {stats['hit_rate']:.1%}")
```

**Cache Settings**:
- Maximum entries: 128 prompts
- Memory usage: ~50-200MB
- Cache cleared on: restart or manual clear
- Automatic pruning: removes oldest 25% when full
- Maximum entries: 256 prompts before pruning
- Cache structure: global dict keyed by prompt hash and CLIP identity
- Memory usage: workload-dependent, estimated from cached embedding tensors
- Cache cleared on: restart, disable, or manual clear
- Automatic pruning: removes the oldest 25% of entries when the cache exceeds its limit
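The settings above can be sketched as a small hash-keyed cache. This is a toy illustration, not the project's implementation; the class name, `encode_fn` callback, and hashing choice are assumptions, while the 256-entry limit and oldest-25% pruning come from the settings listed here:

```python
import hashlib
from collections import OrderedDict

MAX_ENTRIES = 256  # limit described above; pruning drops the oldest 25%

class PromptEmbedCache:
    """Toy sketch of a cache keyed by prompt hash plus CLIP identity."""

    def __init__(self, max_entries=MAX_ENTRIES):
        self.max_entries = max_entries
        self._store = OrderedDict()  # insertion order doubles as entry age
        self.hits = 0
        self.misses = 0

    def _key(self, prompt, clip_id):
        digest = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
        return (digest, clip_id)

    def get_or_encode(self, prompt, clip_id, encode_fn):
        key = self._key(prompt, clip_id)
        if key in self._store:
            self.hits += 1
            return self._store[key]
        self.misses += 1
        embedding = encode_fn(prompt)  # expensive CLIP encode happens only on miss
        self._store[key] = embedding
        if len(self._store) > self.max_entries:
            # Prune the oldest 25% of entries in one pass.
            for _ in range(self.max_entries // 4):
                self._store.popitem(last=False)
        return embedding

    @property
    def hit_rate(self):
        total = self.hits + self.misses
        return self.hits / total if total else 0.0
```

Keying on CLIP identity as well as the prompt hash is what keeps a model switch from returning stale embeddings.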

### Viewing Cache Stats

@@ -60,3 +61,4 @@ prompt_cache.print_cache_stats()
2. **Monitor hit rate** - should be >50% in typical workflows
3. **Clear cache** when switching models or major prompt changes
4. **Batch similar prompts** to maximize cache hits
5. **Expect global behavior** because the cache is shared across repeated prompt encodes rather than being scoped to a single generation session
2 changes: 1 addition & 1 deletion docs/sageattention.md
@@ -94,7 +94,7 @@ python -c "import spas_sage_attn; print('SpargeAttn installed successfully')"
| RTX 4060/4070/4080/4090 | 8.9 | `"8.9"` |
| A100 | 8.0 | `"8.0"` |
| H100 | 9.0 | `"9.0"` |
| RTX 5060/5070/5080/5090 | 12.0 | Not supported yet |
| RTX 5060/5070/5080/5090 | 12.0 | SageAttention supported, SpargeAttn pending |

### Docker Installation

45 changes: 17 additions & 28 deletions docs/wavespeed.md
@@ -2,14 +2,14 @@

## Overview

WaveSpeed is a collection of **feature caching strategies** that exploit temporal redundancy in diffusion processes. By reusing high-level features across multiple denoising steps, WaveSpeed can provide significant speedup with tunable quality trade-offs.
WaveSpeed is the project's caching-oriented optimization layer for reusing work across denoising steps. In the current codebase, the integrated path is DeepCache for UNet-based models, and the repository also contains groundwork for a Flux-oriented First Block Cache path.

LightDiffusion-Next implements two WaveSpeed variants:
LightDiffusion-Next contains two WaveSpeed-related implementations:

1. **DeepCache** — For UNet-based models (SD1.5, SDXL)
2. **First Block Cache (FBCache)** — For Transformer-based models (Flux)
1. **DeepCache** — Integrated for UNet-based models (SD1.5, SDXL)
2. **First Block Cache (FBCache)** — Flux-oriented cache machinery present in the codebase

Both are **training-free**, work alongside other optimizations and can be toggled per-generation.
Both are training-free. DeepCache is the user-facing path today; First Block Cache is codebase groundwork for a more specialized transformer caching path.

## How It Works

@@ -20,36 +20,25 @@ Diffusion models denoise images iteratively over 20-50 steps. Researchers observ
- **High-level features** (semantic structure, composition) change slowly across steps
- **Low-level features** (fine details, textures) require frequent updates

WaveSpeed caches the expensive high-level computations and reuses them for several steps, only updating low-level details cheaply.
WaveSpeed aims to reduce repeated computation across nearby denoising steps by reusing information from earlier steps where practical.

### DeepCache (UNet Models) {#deepcache}

DeepCache targets the middle and output blocks of the UNet architecture:

```
┌─────────────────────────────────────────┐
│ Input Blocks (always computed) │
├─────────────────────────────────────────┤
│ Middle Blocks (cached every N steps) │ ← DeepCache caching zone
├─────────────────────────────────────────┤
│ Output Blocks (cached every N steps) │ ← DeepCache caching zone
└─────────────────────────────────────────┘
```
DeepCache is the integrated WaveSpeed path for UNet models.

**Cache step (every N steps):**
1. Run full forward pass through all UNet blocks
2. Store middle/output block activations in cache
1. Run the full denoiser path
2. Store the output for later reuse

**Reuse step (N-1 intermediate steps):**
1. Run only input blocks
2. Retrieve cached middle/output activations
3. Skip expensive middle/output block computation
**Reuse step (intermediate steps):**
1. Reuse the cached denoiser output
2. Skip the full model recomputation for that step

**Speedup:** ~50-70% time saved per reuse step → 2-3x total speedup with `interval=3`
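The cache/reuse schedule above can be sketched in a few lines. This is a simplified illustration of the interval logic only (function names are hypothetical), not the project's patching code:

```python
def deepcache_schedule(num_steps, interval):
    """Toy sketch: which steps run the full model vs. reuse the cached output."""
    plan = []
    for step in range(num_steps):
        if step % interval == 0:
            plan.append((step, "full"))   # recompute and refresh the cache
        else:
            plan.append((step, "reuse"))  # reuse the cached denoiser output
    return plan

def run_with_deepcache(denoise_full, latents, timesteps, interval=3):
    """Apply the schedule; `denoise_full` is a stand-in for the model call."""
    cached = None
    outputs = []
    for step, t in enumerate(timesteps):
        if step % interval == 0 or cached is None:
            cached = denoise_full(latents, t)  # full pass on cache steps
        outputs.append(cached)                 # reuse on intermediate steps
    return outputs
```

With `interval=3`, only every third step pays for a full forward pass, which is where the 2-3x total speedup comes from.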

### First Block Cache (Flux Models)

Flux uses Transformer blocks instead of UNet convolutions. FBCache applies a similar principle:
Flux uses Transformer blocks instead of UNet convolutions. The repository includes a First Block Cache implementation for this architecture family:

```
┌─────────────────────────────────────────┐
@@ -65,7 +54,7 @@ Flux uses Transformer blocks instead of UNet convolutions. FBCache applies a sim
3. If difference < threshold: reuse cached remaining blocks
4. If difference ≥ threshold: run all blocks and update cache

**Adaptive caching:** Automatically decides when to cache vs. recompute based on feature similarity.
In the current project structure, this cache path is implementation groundwork rather than a standard generation toggle like DeepCache.
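The first-block residual check described above can be sketched as follows. This is an illustrative toy, not the repository's implementation; the function signature, state dict, and relative-change metric are assumptions, while the compare-then-reuse logic mirrors the four steps listed here:

```python
import torch

def fbcache_step(first_block, remaining_blocks, x, state, threshold=0.01):
    """Toy sketch of a first-block residual check.

    Runs the first transformer block, compares its output to the one cached
    on the previous step, and reuses the cached tail output when the relative
    change is below `threshold`.
    """
    first_out = first_block(x)
    prev_first = state.get("first")
    prev_tail = state.get("tail")
    if prev_first is not None and prev_tail is not None:
        rel_change = (first_out - prev_first).abs().mean() / (prev_first.abs().mean() + 1e-8)
        if rel_change < threshold:
            return prev_tail, state  # cheap path: skip the remaining blocks
    tail_out = remaining_blocks(first_out)  # full path: run and refresh cache
    state = {"first": first_out, "tail": tail_out}
    return tail_out, state
```

Because the decision is driven by the observed residual rather than a fixed interval, the cache automatically recomputes on steps where the features are actually changing.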

## DeepCache Configuration

@@ -160,7 +149,7 @@

### Usage

FBCache is applied automatically when generating Flux images. No UI controls yet — configured via pipeline code:
First Block Cache is not currently exposed as a standard per-generation toggle. The implementation is available in the codebase for specialized integration work:

```python
# In src/user/pipeline.py
Expand All @@ -169,7 +158,7 @@ from src.WaveSpeed import fbcache_nodes
# Create cache context
cache_context = fbcache_nodes.create_cache_context()

# Apply caching to Flux model
# Apply caching to a Flux-style model
with fbcache_nodes.cache_context(cache_context):
patched_model = fbcache_nodes.create_patch_flux_forward_orig(
flux_model,
@@ -196,7 +185,7 @@ Speedup scales with cache interval and depth:
| SD1.5 | 3 | Good speedup, slight quality loss |
| SD1.5 | 5 | High speedup, noticeable quality loss |
| SDXL | 3 | Good speedup, slight quality loss |
| Flux (FBCache) | auto | Moderate speedup, minimal quality loss |
| Flux-style caching paths | implementation-specific | Depends on the integration path |

**Performance varies based on:**
- GPU architecture
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -27,6 +27,7 @@ nav:
- REST API: api.md
- Performance Optimizations:
- Overview: optimizations.md
- Implementation Report: implemented-optimizations-report.md
- CFG-Free Sampling: cfg-free-sampling.md
- Token Merging (ToMe): tome.md
- SageAttention & SpargeAttn: sageattention.md
6 changes: 0 additions & 6 deletions src/Device/Device.py
@@ -853,12 +853,6 @@ def get_autocast_device(dev) -> str:
def sageattention_enabled() -> bool:
if cpu_state != CPUState.GPU or is_intel_xpu() or directml_enabled or is_rocm():
return False
if torch.cuda.is_available():
try:
if torch.cuda.get_device_capability()[0] >= 12:
return False
except:
pass
return SAGEATTENTION_IS_AVAILABLE


62 changes: 61 additions & 1 deletion src/Model/ModelPatcher.py
@@ -47,6 +47,56 @@ def __call__(self, weight: torch.Tensor) -> torch.Tensor:
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)


class ModelFunctionWrapperChain:
"""Compose multiple model_function_wrapper hooks without overwriting them.

Several optimizations patch the same U-Net wrapper hook. Keeping only the
last wrapper silently disables earlier optimizations. This chain preserves
application order by making the most recently-added wrapper the outermost
wrapper around the existing stack.
"""

def __init__(self, wrappers=None):
self.wrappers = list(wrappers or [])

def add_outer(self, wrapper):
self.wrappers.insert(0, wrapper)
return self

def __call__(self, model_function, params):
return self._invoke(0, model_function, params)

def _invoke(self, index, model_function, params):
if index >= len(self.wrappers):
return model_function(
params["input"],
params["timestep"],
**params.get("c", {}),
)

wrapper = self.wrappers[index]

def next_model_function(input_x, timestep, **c_kwargs):
next_params = dict(params)
next_params["input"] = input_x
next_params["timestep"] = timestep
next_params["c"] = c_kwargs
return self._invoke(index + 1, model_function, next_params)

return wrapper(next_model_function, params)

def to(self, device):
updated = []
for wrapper in self.wrappers:
if hasattr(wrapper, "to"):
moved = wrapper.to(device)
updated.append(moved if moved is not None else wrapper)
else:
updated.append(wrapper)
self.wrappers = updated
return self


class ModelPatcher:
def __init__(self, model: torch.nn.Module, load_device: torch.device, offload_device: torch.device,
size: int = 0, current_device: torch.device = None, weight_inplace_update: bool = False):
@@ -91,7 +141,17 @@ def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)

def set_model_unet_function_wrapper(self, f):
self.model_options["model_function_wrapper"] = f
existing = self.model_options.get("model_function_wrapper")
if existing is None:
self.model_options["model_function_wrapper"] = f
return

if isinstance(existing, ModelFunctionWrapperChain):
existing.add_outer(f)
self.model_options["model_function_wrapper"] = existing
return

self.model_options["model_function_wrapper"] = ModelFunctionWrapperChain([f, existing])

def set_model_denoise_mask_function(self, f):
self.model_options["denoise_mask_function"] = f
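The composition semantics of `ModelFunctionWrapperChain` can be demonstrated with two toy wrappers. The condensed class below mirrors the chain added in this PR for the sake of a runnable example; the tag-logging wrappers and the trivial model function are hypothetical:

```python
class ModelFunctionWrapperChain:
    # Condensed copy of the chain from this PR, for demonstration only.
    def __init__(self, wrappers=None):
        self.wrappers = list(wrappers or [])

    def add_outer(self, wrapper):
        self.wrappers.insert(0, wrapper)
        return self

    def __call__(self, model_function, params):
        return self._invoke(0, model_function, params)

    def _invoke(self, index, model_function, params):
        if index >= len(self.wrappers):
            return model_function(params["input"], params["timestep"], **params.get("c", {}))

        def next_model_function(input_x, timestep, **c_kwargs):
            next_params = dict(params, input=input_x, timestep=timestep, c=c_kwargs)
            return self._invoke(index + 1, model_function, next_params)

        return self.wrappers[index](next_model_function, params)


def make_tag_wrapper(tag, log):
    # Toy wrapper: record its tag, then delegate to the inner model function.
    def wrapper(model_function, params):
        log.append(tag)
        return model_function(params["input"], params["timestep"], **params.get("c", {}))
    return wrapper


log = []
chain = ModelFunctionWrapperChain([make_tag_wrapper("inner", log)])
chain.add_outer(make_tag_wrapper("outer", log))  # newest wrapper wraps the existing stack

result = chain(lambda x, t, **c: x + t, {"input": 1, "timestep": 2, "c": {}})
```

After the call, `log` reads outermost-first (`"outer"` before `"inner"`), which is exactly the ordering guarantee that keeps earlier optimizations active when a later one registers its own wrapper.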
9 changes: 8 additions & 1 deletion src/cond/cond.py
@@ -119,6 +119,7 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options) -> list:
out_conds = [torch.zeros_like(x_in) for _ in range(len(conds))]
out_counts = [torch.ones_like(x_in) * 1e-37 for _ in range(len(conds))]
to_run = []
batched_cfg = model_options.get("batched_cfg", True)

for i, cond in enumerate(conds):
if cond is not None:
@@ -130,9 +131,15 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options) -> list:
while to_run:
first = to_run[0]
first_shape = first[0][0].shape
first_cond_index = first[1]

# Find compatible conditions
to_batch_temp = [x for x in range(len(to_run)) if cond_util.can_concat_cond(to_run[x][0], first[0])]
to_batch_temp = [
x
for x in range(len(to_run))
if cond_util.can_concat_cond(to_run[x][0], first[0])
and (batched_cfg or to_run[x][1] == first_cond_index)
]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]

37 changes: 37 additions & 0 deletions tests/unit/test_calc_cond_batch_fallback.py
@@ -16,6 +16,16 @@ def apply_model(self, *args, **kwargs):
return inp


class RecordingDummyModel(DummyModel):
def __init__(self):
self.batch_sizes = []

def apply_model(self, *args, **kwargs):
inp = args[0] if args else kwargs.get("input")
self.batch_sizes.append(int(inp.shape[0]))
return inp


def test_calc_cond_batch_fallback_on_transformer_options_mismatch(monkeypatch):
called = {"flag": False}

@@ -45,3 +55,30 @@ def spy_run_model_per_chunk(model, x_in, timestep, input_x_list, c_list, batch_s
assert isinstance(out, list) and len(out) == 2
assert out[0].shape == x_in.shape
assert out[1].shape == x_in.shape


def test_calc_cond_batch_honors_batched_cfg_toggle():
x_in = torch.zeros((1, 4, 8, 8))
cond_dict = {"model_conds": {"c_crossattn": CONDRegular(torch.zeros((1, 1, 1, 1)))}}
conds = [[cond_dict], [cond_dict]]

batched_model = RecordingDummyModel()
calc_cond_batch(
batched_model,
conds,
x_in,
timestep=0,
model_options={"batched_cfg": True},
)

unbatched_model = RecordingDummyModel()
calc_cond_batch(
unbatched_model,
conds,
x_in,
timestep=0,
model_options={"batched_cfg": False},
)

assert batched_model.batch_sizes == [2]
assert unbatched_model.batch_sizes == [1, 1]