Note: This case study uses Stagehand within Serenity. The training framework, model loading, and SquareQ integration are Serenity code — Stagehand provides the block-swapping runtime.
Training a 12-billion-parameter image diffusion model (Flux 2 Dev) with a 24-billion-parameter text encoder (Mistral 3), using SquareQ INT8 quantization for frozen weights and Stagehand block-swapping, on a single 24GB GPU.
SquareQ stores frozen transformer weights in INT8, dequantizing to bf16 on-the-fly during block loading. Combined with Stagehand's block-swapping, this keeps peak VRAM at ~6 GB for a model that needs ~60 GB in full precision.
Flux 2 Dev is Black Forest Labs' image generation model. Its components:
| Component | Parameters | bf16 Size |
|---|---|---|
| Transformer (8 double-stream + 48 single-stream blocks) | ~12B | ~24 GB |
| Text encoder (Mistral 3 24B) | 24B | ~48 GB |
| VAE | ~160M | ~0.3 GB |
| Total | ~36B | ~72 GB |
Neither the transformer (~24 GB) nor the text encoder (~48 GB) fits on a 24GB card individually, let alone both together; the text encoder alone is 2x the card's capacity.
Three techniques work together:
- **SquareQ INT8 slab** — Frozen transformer weights quantized to INT8 (per-row symmetric). 203 layers in a ~30 GB safetensors slab on disk. Dequantized to bf16 during block loading.
- **Stagehand block-swapping** — Streams one transformer block at a time through a pinned memory pool. Forward and backward hooks manage the lifecycle automatically.
- **`accelerate.cpu_offload`** — The Mistral 3 text encoder uses per-layer CPU offload during the embedding caching pass. Each layer streams to the GPU one at a time (~1.2 GB/layer peak).
- **Load**: Mistral 3 24B text encoder + VAE
- **Strategy**: `accelerate.cpu_offload` (per-layer dispatch)
- **Work**: Process all captions, cache embeddings + latents to disk
Mistral 3's 48 GB of weights stream through the GPU one layer at a time via `accelerate.cpu_offload()`. Each layer's parameters are placed on the meta device with the state dict held in CPU RAM. During forward, each submodule is moved to the GPU, runs its forward, then moves back. Peak VRAM is ~1.2 GB per layer.
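The move-to-GPU / move-back dance can be approximated with plain forward hooks. The sketch below is illustrative only — accelerate's real implementation uses its own device-alignment hooks and keeps parameters on the meta device; `offload_per_layer` is a hypothetical helper, not an accelerate API:

```python
import torch
import torch.nn as nn

def offload_per_layer(model: nn.Module, device: torch.device) -> nn.Module:
    """Hand-rolled sketch of per-layer offload: each top-level child is moved
    to `device` just before its forward and back to CPU right after, so the
    device only ever holds roughly one layer's parameters at a time."""
    def pre_hook(mod, args):
        mod.to(device)
        # Move tensor inputs alongside the layer (returning a tuple replaces args).
        return tuple(a.to(device) if torch.is_tensor(a) else a for a in args)

    def post_hook(mod, args, output):
        mod.to("cpu")  # returning None leaves the output untouched on `device`

    for layer in model.children():
        layer.register_forward_pre_hook(pre_hook)
        layer.register_forward_hook(post_hook)
    return model

# Falls back to CPU so the sketch runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
offload_per_layer(model, device)
out = model(torch.randn(2, 16))
```

After the forward pass, every layer is back on CPU; only the activations for the current layer ever occupy the device.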
After caching, the text encoder is unloaded: accelerate hooks are removed, the model moves to CPU, and GPU memory drops to near zero.
- **Load**: Flux 2 Dev transformer (from SquareQ INT8 slab)
- **Strategy**: `StagehandStrategy` with SquareQ V2 backing
- **Work**: LoRA fine-tuning with cached embeddings
The transformer has 56 blocks (8 double-stream, 48 single-stream). With SquareQ backing, frozen weights are stored as INT8 in the slab file. On each block load:
- INT8 weights read from slab via memory-mapped I/O
- Dequantized to bf16 in the pinned CPU slab
- Copied to GPU via async DMA on a dedicated CUDA stream
- Module parameters repointed to the GPU tensor
- Forward/backward runs with bf16 precision
- Block evicted: GPU tensor freed, LoRA grads preserved on CPU
Only LoRA adapter weights (~50 MB at rank 16) and non-block submodules (~1.5 GB: norms, projections, embeddings) stay on GPU permanently.
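The load/evict cycle above can be sketched for a single layer. Names and shapes here are illustrative stand-ins, not Stagehand's actual data structures — the real path goes through a pinned CPU slab and an async CUDA stream rather than a direct `.to(device)`:

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-ins for one layer's slab entries: INT8 weight + per-row scales.
q = torch.randint(-127, 128, (8, 8), dtype=torch.int8)
scale = torch.rand(8, 1) * 0.02

block = nn.Linear(8, 8, bias=False)
block.weight.requires_grad_(False)  # frozen base weight

# Load: dequantize INT8 -> bf16, then repoint the parameter at the device copy.
block.weight.data = (q.float() * scale).to(torch.bfloat16).to(device)

out = block(torch.randn(2, 8, dtype=torch.bfloat16, device=device))

# Evict: drop the device copy so its memory can be reclaimed;
# trainable tensors (the LoRA adapters) would stay resident instead.
block.weight.data = torch.empty(0)
```

Repointing `param.data` is what lets the module keep its identity (hooks, optimizer references) while its storage migrates between slab, CPU, and GPU.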
| Resource | Minimum | Used |
|---|---|---|
| GPU VRAM | 16 GB (estimated) | 24 GB (RTX 3090/4090) |
| System RAM | 32 GB | 62 GB |
| Disk | 35 GB free (for slab) | NVMe |
Text encoding phase (~26 GB CPU):
- Mistral 3 weights on CPU: ~48 GB (meta device + state dict)
- Peak GPU: ~1.2 GB (single layer)
Training steady state:
- GPU allocated: 5.52 GB (non-block submodules + LoRA + activations)
- GPU reserved: 6.00 GB (with `empty_cache()` after each step)
- Pinned pool: 8 GB
- SquareQ slab: memory-mapped from disk (~30 GB file, not in RAM)
```json
{
  "model_type": "flux_2_dev",
  "training_method": "lora",
  "model": {
    "type": "flux_2_dev",
    "path": "black-forest-labs/FLUX.2-dev"
  },
  "adapter": {
    "type": "lora",
    "rank": 16,
    "alpha": 16.0,
    "target_modules": [
      "attn.to_q", "attn.to_k", "attn.to_v", "attn.to_out",
      "attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj", "attn.to_add_out"
    ]
  },
  "memory": {
    "gradient_checkpointing": "on",
    "strategy": "stagehand",
    "stagehand": {
      "pinned_pool_mb": 8192,
      "pinned_slab_mb": 2048,
      "vram_high_watermark_mb": 10000,
      "vram_low_watermark_mb": 6000,
      "squareq_slab_path": "output/squareq_slabs/flux2dev_int8.safetensors",
      "squareq_manifest_path": "output/squareq_slabs/flux2dev_int8.manifest.json"
    }
  },
  "resolution": 512,
  "batch_size": 1,
  "learning_rate": 1e-4,
  "bucket_policy_enabled": false
}
```

Training on 118 images with text captions, 200 steps, LoRA rank 16, cosine LR schedule with 10-step warmup.
```
Step 1:  loss=0.696 avg=0.696 lr=2.00e-05 gn=0.030
Step 5:  loss=0.554 avg=0.726 lr=6.00e-05 gn=0.098
Step 10: loss=0.626 avg=0.731 lr=1.00e-04 gn=0.061
Step 15: loss=0.714 avg=0.741 lr=9.97e-05 gn=0.032
Step 20: loss=0.616 avg=0.742 lr=9.90e-05 gn=0.037
Step 25: loss=0.675 avg=0.765 lr=9.83e-05 gn=0.025
Step 30: loss=0.584 avg=0.748 lr=9.70e-05 gn=0.039
```
Loss stabilized around 0.71-0.75 average within the first 30 steps. Gradient norms stayed healthy in the 0.02-0.10 range, indicating stable convergence.
| Metric | Value |
|---|---|
| Steady-state step time | ~160 s |
| VRAM allocated (steady state) | 5.52 GB |
| VRAM reserved (steady state) | 6.00 GB |
| Text encoding (118 samples) | ~12 min |
| Latent caching (118 samples) | ~12 min |
| Training (200 steps, estimated) | ~8.9 hours |
| SquareQ params matched | 192/203 |
| OOM events | 0 (after all fixes) |
The VRAM profile is remarkably flat compared to full-precision Stagehand training:
- Allocated: Constant 5.52 GB — non-block submodules + LoRA weights + per-step activations
- Reserved: Constant 6.00 GB — the `empty_cache()` call after each step keeps reserved memory tightly controlled
- Gap: Only 0.48 GB between allocated and reserved, meaning almost no wasted cache
This is possible because SquareQ INT8 blocks are ~half the size of bf16 blocks. Each block loads, dequantizes, runs forward/backward, and evicts within the watermark budget without approaching the card's 24 GB limit.
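The watermark budget from the config (`vram_high_watermark_mb` / `vram_low_watermark_mb`) amounts to hysteresis: do nothing until usage crosses the high mark, then evict least-recently-used blocks until it falls under the low mark. A minimal sketch of that policy — illustrative, not Stagehand's actual scheduler:

```python
from collections import OrderedDict

def maybe_evict(resident: OrderedDict, used_mb: int,
                high_mb: int = 10000, low_mb: int = 6000) -> int:
    """If `used_mb` exceeds the high watermark, pop LRU blocks from
    `resident` (name -> size in MB) until usage drops below the low one."""
    if used_mb <= high_mb:
        return used_mb                # under the high mark: keep everything cached
    while resident and used_mb > low_mb:
        _, size_mb = resident.popitem(last=False)   # evict the oldest block
        used_mb -= size_mb
    return used_mb

# Usage: nine 500 MB blocks must go to get from 10.5 GB back under 6 GB.
blocks = OrderedDict((f"block_{i}", 500) for i in range(21))
remaining = maybe_evict(blocks, used_mb=10500)
```

The gap between the two marks is what prevents thrashing: a single eviction pass buys several steps of headroom before the high watermark is hit again.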
SquareQ-backed blocks use a different loading path than module-backed or file-backed:
```python
# In scheduler._load_block():
if block_entry.squareq_backed:
    squareq_layers = self._get_squareq_v2_layers(source_path)
    _copy_squareq_backed_params_into_buffer(
        module=module,
        buffer=slab.buffer,
        layout=layout,
        squareq_layers=squareq_layers,
    )
```

INT8 weights are read from the slab, dequantized per-row using stored scales, and written into the pinned slab buffer as bf16. The buffer is then DMA'd to GPU. This happens transparently — the training loop sees bf16 module parameters.
Flux 2 Dev uses diffusers' naming internally (e.g. `ff.net.0.proj`) but the SquareQ slab was built from HF checkpoint names (e.g. `ff.linear_in`). The registry's `_candidate_tensor_keys()` uses bidirectional alias tables to bridge this gap:
```python
aliases = (
    ("ff.net.0.proj", "ff.linear_in"),
    ("ff.net.2", "ff.linear_out"),
    ("ff_context.net.0.proj", "ff_context.linear_in"),
    ("ff_context.net.2", "ff_context.linear_out"),
)
```

Additionally, LoRA injection renames base weights (`attn.to_q.weight` → `attn.to_q.orig.weight`). The `_candidate_squareq_layer_keys()` method strips these adapter suffixes before matching.
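Put together, the matching can be re-created as a small key-expansion routine. This is an illustrative reconstruction, not the actual `stagehand/registry.py` code:

```python
ALIASES = (
    ("ff.net.0.proj", "ff.linear_in"),
    ("ff.net.2", "ff.linear_out"),
    ("ff_context.net.0.proj", "ff_context.linear_in"),
    ("ff_context.net.2", "ff_context.linear_out"),
)

def candidate_keys(name: str) -> set:
    """Strip LoRA adapter suffixes, then expand bidirectional aliases so a
    diffusers-style module name can match an HF-checkpoint slab key."""
    # LoRA injection renames base weights: attn.to_q.weight -> attn.to_q.orig.weight
    name = name.replace(".orig.weight", ".weight")
    candidates = {name}
    for diffusers_key, hf_key in ALIASES:
        if diffusers_key in name:
            candidates.add(name.replace(diffusers_key, hf_key))
        if hf_key in name:
            candidates.add(name.replace(hf_key, diffusers_key))
    return candidates

# Both the diffusers name and the HF slab name come out as candidates.
keys = candidate_keys("single_blocks.0.ff.net.0.proj.orig.weight")
```

Generating a candidate set rather than a single rewritten key keeps the lookup robust: whichever naming scheme the slab was built with, one of the candidates will hit.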
Stagehand now calls `torch.cuda.empty_cache()` after each step in `forward_context()`. Without this, PyTorch's caching allocator holds reserved memory from evicted blocks indefinitely. With aspect ratio bucketing (varying activation tensor sizes), the allocator can't reuse fixed-size cached blocks, and reserved memory grows until OOM.
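The fix is a one-liner at the end of each step. A hedged sketch of the pattern — the `log_vram` helper is hypothetical, and Stagehand's actual call site is inside `forward_context()`:

```python
import torch

def log_vram() -> str:
    """Report allocated (live tensors) vs reserved (allocated + allocator
    cache) memory, in the same A/R shape as the vram= log lines."""
    if not torch.cuda.is_available():
        return "vram=n/a"
    a = torch.cuda.memory_allocated() / 2**30
    r = torch.cuda.memory_reserved() / 2**30
    return f"vram={a:.2f}/{r:.2f}G"

def end_of_step():
    print(log_vram())
    # Return cached-but-unused blocks to the CUDA driver. Without this,
    # memory from evicted blocks stays reserved and cannot serve
    # differently-sized activation tensors. (No-op if CUDA is uninitialized.)
    torch.cuda.empty_cache()
```

`empty_cache()` does not free live tensors — it only releases the allocator's cache, which is exactly the allocated/reserved gap the profile above tracks.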
Before the fix: `vram=5.52/14.95G` (9.4 GB dead reserved memory).
After the fix: `vram=5.52/6.00G` (0.48 GB overhead).
Five bugs were identified and fixed to make this work:

1. **Mistral 3 text encoder OOM**: `text_encoder.to(cuda)` tries to move 48 GB to GPU. Fix: use `accelerate.cpu_offload()` for per-layer streaming. File: `serenity/models/flux2.py`.
2. **Legacy config silently discarding Stagehand config**: Config keys `base_model_name` or `output_model_destination` trigger `_is_legacy_config()`, which rebuilds the entire config dict, discarding all stagehand/squareq settings. Fix: use the new Serenity config format. File: config restructure.
3. **SquareQ matching only 32/203 params**: Two sub-causes — the LoRA `.orig` suffix was not stripped during key matching, and Flux 2's FF layer naming differs between diffusers and the HF checkpoint. Fix: suffix stripping + bidirectional aliases in `_candidate_tensor_keys()`. File: `stagehand/registry.py`.
4. **CUDA memory fragmentation**: Stagehand evicts blocks but PyTorch holds the reserved memory. Over multiple steps, reserved memory grows to fill the card. Fix: `torch.cuda.empty_cache()` after each step. File: `serenity/memory/stagehand_strategy.py`.
5. **Bucket policy disabling gradient checkpointing**: After 3 successful steps, the memory predictor sees low VRAM (Stagehand makes the card appear nearly empty) and auto-assigns `FAST` mode, which disables gradient checkpointing. Without checkpointing, activations blow up VRAM. Fix: disable `bucket_policy` when Stagehand is active. File: config setting.
The SquareQ slab was built with `scripts/build_flux2dev_slab.py`:

```bash
python scripts/build_flux2dev_slab.py \
  --model black-forest-labs/FLUX.2-dev \
  --output output/squareq_slabs/flux2dev_int8 \
  --dtype int8
```

Output:

- `flux2dev_int8.safetensors` — ~30 GB slab with 203 quantized layers
- `flux2dev_int8.manifest.json` — canonical name → offset/shape/scale mapping
Quantization is per-row symmetric INT8: each row of a weight matrix gets its own scale factor. This preserves more precision than per-tensor quantization while keeping the format simple (no zero points, no group quantization).
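As a sketch — my own minimal implementation of that format, not SquareQ's code — per-row symmetric INT8 looks like:

```python
import torch

def quantize_per_row(w: torch.Tensor):
    """Per-row symmetric INT8: one scale per output row, no zero point.
    The symmetric range [-127, 127] keeps dequantization a single multiply."""
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    scale = scale.clamp(min=1e-12)              # guard all-zero rows
    q = (w / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale

def dequantize_per_row(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return (q.to(torch.float32) * scale).to(torch.bfloat16)

w = torch.randn(4, 8)
q, scale = quantize_per_row(w)
w_hat = dequantize_per_row(q, scale)
# Worst-case per-element error is about scale/2, plus bf16 rounding.
max_err = (w_hat.float() - w).abs().max().item()
```

Per-row scales bound the rounding error by each row's own dynamic range, which is why this preserves more precision than a single per-tensor scale: one outlier row no longer inflates the quantization step for every other row.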
A 12B image model + 24B text encoder = 36B total parameters, trained with LoRA on a single 24GB GPU. The frozen weights are INT8 quantized (halving their size) and streamed from disk one block at a time. Only ~6 GB of VRAM is used at steady state.
The cost is speed: ~160 seconds per step due to block-swapping overhead and INT8→bf16 dequantization. For LoRA fine-tuning (100-500 steps), this completes in 4-22 hours on a single consumer GPU — a task that previously required 48+ GB VRAM or multi-GPU setups.
The combination of SquareQ + Stagehand is strictly more memory-efficient than full-precision Stagehand alone: INT8 blocks are half the size, so they transfer faster and leave more VRAM headroom for activations. The trade-off is a small quantization error in the frozen base weights, which LoRA can compensate for during fine-tuning.