1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -22,6 +22,7 @@ NVIDIA Model Optimizer Changelog (Linux)
- Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is.
- Add support for image-text data calibration in PTQ for Nemotron VL models.
- Add PTQ support for Nemotron Parse.
- Add distillation support for LTX-2. See `examples/diffusers/distillation/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/diffusers/distillation>`_ for more details.

0.41 (2026-01-19)
^^^^^^^^^^^^^^^^^
153 changes: 153 additions & 0 deletions examples/diffusers/distillation/README.md
@@ -0,0 +1,153 @@
# LTX-2 Distillation Training with ModelOpt

Knowledge distillation for LTX-2 DiT models using NVIDIA ModelOpt. A frozen **teacher** guides a trainable **student** through a combined loss:

```text
L_total = α × L_task + (1-α) × L_distill
```
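
A minimal sketch of one training step under this scheme, assuming both models return velocity predictions of the same shape (the function and argument names below are illustrative, not the trainer's actual API):

```python
import torch
import torch.nn.functional as F

def distillation_loss_step(student, teacher, latents, timesteps, cond, target, alpha=0.5):
    """Illustrative combined loss: alpha * L_task + (1 - alpha) * L_distill."""
    # The teacher is frozen, so its forward pass needs no gradients.
    with torch.no_grad():
        teacher_pred = teacher(latents, timesteps, cond)

    student_pred = student(latents, timesteps, cond)

    # Task loss: student prediction vs. the flow-matching target.
    task_loss = F.mse_loss(student_pred, target)

    # Distillation loss: student prediction vs. the frozen teacher (MSE by default).
    distill_loss = F.mse_loss(student_pred, teacher_pred)

    return alpha * task_loss + (1.0 - alpha) * distill_loss
```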

Currently supported:

- **Quantization-Aware Distillation (QAD)** — student uses ModelOpt fake quantization

Planned:

- **Sparsity-Aware Distillation (SAD)** — student uses ModelOpt sparsity

## Installation

```bash
# From the distillation example directory
cd examples/diffusers/distillation

# Install Model-Optimizer (from repo root)
pip install -e ../../..

# Install all dependencies (ltx-trainer, ltx-core, ltx-pipelines, omegaconf)
pip install -r requirements.txt
```

## Quick Start

### 1. Prepare Your Dataset

Use the ltx-trainer preprocessing to extract latents and text embeddings:

```bash
python -m ltx_trainer.preprocess \
--input_dir /path/to/videos \
--output_dir /path/to/preprocessed \
--model_path /path/to/ltx2/checkpoint.safetensors
```

### 2. Configure

Copy and edit the example config:

```bash
cp configs/distillation_example.yaml configs/my_experiment.yaml
```

Key settings to update:

```yaml
model:
model_path: "/path/to/ltx2/checkpoint.safetensors"
text_encoder_path: "/path/to/gemma/model"

data:
preprocessed_data_root: "/path/to/preprocessed/data"

distillation:
distillation_alpha: 0.5 # 1.0 = pure task loss, 0.0 = pure distillation
quant_cfg: "FP8_DEFAULT_CFG" # or INT8_DEFAULT_CFG, NVFP4_DEFAULT_CFG, null

# IMPORTANT: disable ltx-trainer's built-in quantization
acceleration:
quantization: null
```

### 3. Run Training

#### Single GPU

```bash
python distillation_trainer.py --config configs/my_experiment.yaml
```

#### Multi-GPU (Single Node) with Accelerate

```bash
accelerate launch \
--config_file configs/accelerate/fsdp.yaml \
--num_processes 8 \
distillation_trainer.py --config configs/my_experiment.yaml
```

#### Multi-node Training with Accelerate

To launch on multiple nodes, make sure to set the following environment variables on each node:

- `NUM_NODES`: Total number of nodes
- `GPUS_PER_NODE`: Number of GPUs per node
- `NODE_RANK`: Unique rank/index of this node (0-based)
- `MASTER_ADDR`: IP address of the master node (rank 0)
- `MASTER_PORT`: Communication port (e.g., 29500)

Then run this (on every node):

```bash
accelerate launch \
--config_file configs/accelerate/fsdp.yaml \
--num_machines $NUM_NODES \
--num_processes $((NUM_NODES * GPUS_PER_NODE)) \
--machine_rank $NODE_RANK \
--main_process_ip $MASTER_ADDR \
--main_process_port $MASTER_PORT \
distillation_trainer.py --config configs/my_experiment.yaml
```

**Config overrides** can be passed via CLI using dotted notation:

```bash
accelerate launch ... distillation_trainer.py \
--config configs/my_experiment.yaml \
++distillation.distillation_alpha=0.6 \
++distillation.quant_cfg=INT8_DEFAULT_CFG \
++optimization.learning_rate=1e-5
```

## Configuration Reference

### Calibration

Before training begins, calibration runs full denoising inference to collect activation statistics for accurate quantizer scales. This is cached as a step-0 checkpoint and reused on subsequent runs.

| Parameter | Default | Description |
|-----------|---------|-------------|
| `calibration_prompts_file` | null | Text file with one prompt per line. If null, the HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts' is used. |
| `calibration_size` | 128 | Number of prompts (each runs a full denoising loop) |
| `calibration_n_steps` | 30 | Denoising steps per prompt |
| `calibration_guidance_scale` | 4.0 | CFG scale (should match the inference-time setting) |
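
Conceptually this follows the standard ModelOpt calibration pattern: `mtq.quantize` inserts fake quantizers and runs a user-supplied forward loop so activation ranges are observed before scales are frozen. A rough sketch, where `run_denoising` is a hypothetical stand-in for a full pipeline call, not the trainer's actual code:

```python
import modelopt.torch.quantization as mtq

def calibrate_transformer(transformer, prompts, n_steps=30, guidance_scale=4.0):
    """Sketch of full-inference calibration (not the trainer's exact code)."""

    def forward_loop(model):
        # Every prompt runs a complete denoising loop, so all noise levels
        # contribute to the collected activation statistics.
        for prompt in prompts:
            run_denoising(model, prompt, n_steps=n_steps,
                          guidance_scale=guidance_scale)  # hypothetical helper

    # Insert fake quantizers and calibrate their scales via the forward loop.
    return mtq.quantize(transformer, mtq.FP8_DEFAULT_CFG, forward_loop=forward_loop)
```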

### Checkpoint Resume

| Parameter | Default | Description |
|-----------|---------|-------------|
| `resume_from_checkpoint` | null | `"latest"` to auto-detect, or explicit path |
| `must_save_by` | null | Minutes after which to save and exit (for Slurm time limits) |
| `restore_quantized_checkpoint` | null | Restore a pre-quantized model (skips calibration) |
| `save_quantized_checkpoint` | null | Path to save the final quantized model |
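
The `restore_quantized_checkpoint` and `save_quantized_checkpoint` paths map onto ModelOpt's save/restore utilities, roughly like this (paths and the variable name are placeholders):

```python
import modelopt.torch.opt as mto

# After calibration/training: save the quantized student so later runs can
# skip calibration entirely (restore_quantized_checkpoint points here).
mto.save(student_transformer, "/path/to/quantized_student.pth")

# In a later run: rebuild the model architecture, then restore quantizer
# state and weights before training or export.
mto.restore(student_transformer, "/path/to/quantized_student.pth")
```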

### Custom Quantization Configs

To define custom quantization configs, add entries to `CUSTOM_QUANT_CONFIGS` in `distillation_trainer.py`:

```python
CUSTOM_QUANT_CONFIGS["MY_FP8_CFG"] = {
"quant_cfg": mtq.FP8_DEFAULT_CFG["quant_cfg"],
"algorithm": "max",
}
```

Then reference it in your YAML: `quant_cfg: MY_FP8_CFG`.
45 changes: 45 additions & 0 deletions examples/diffusers/distillation/configs/accelerate/fsdp.yaml
@@ -0,0 +1,45 @@
# FSDP Configuration
#
# FULL_SHARD across all GPUs for maximum memory efficiency.
# For multi-node training with `accelerate launch`.
#
# Usage:
# accelerate launch \
# --config_file configs/accelerate/fsdp.yaml \
# --num_processes 16 \
# --num_machines 2 \
# --machine_rank $MACHINE_RANK \
# --main_process_ip $MASTER_IP \
# --main_process_port 29500 \
# distillation_trainer.py --config configs/distillation_example.yaml

distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false

fsdp_config:
# FULL_SHARD: Shard optimizer states, gradients, and parameters across ALL GPUs
# This provides maximum memory efficiency for large models like LTX-2 19B
# Parameters are fully sharded across all nodes (not replicated)
fsdp_sharding_strategy: FULL_SHARD

# Enable activation checkpointing to reduce memory during backward pass
# Critical for 19B model training
fsdp_activation_checkpointing: true

fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock
fsdp_use_orig_params: true
fsdp_version: 1

# Note: num_machines and num_processes are overridden by accelerate launch command-line args
# These are just defaults for local testing
num_machines: 1
num_processes: 8
142 changes: 142 additions & 0 deletions examples/diffusers/distillation/configs/distillation_example.yaml
@@ -0,0 +1,142 @@
# LTX-2 Distillation Training Configuration with ModelOpt

# Model Configuration
model:
# Path to the LTX-2 checkpoint (used for both teacher and student)
model_path: "/path/to/ltx2/checkpoint.safetensors"

# Path to Gemma text encoder (required for LTX-2)
text_encoder_path: "/path/to/gemma/model"

# Training mode: "lora" is not supported yet
training_mode: "full"

# Distillation Configuration
distillation:
# Path to teacher model (if different from model.model_path)
# Set to null to use the same checkpoint as student (loaded without quantization)
teacher_model_path:

# Weight for task loss: L_total = α * L_task + (1-α) * L_distill
# α = 1.0: pure task loss (no distillation)
# α = 0.0: pure distillation loss
distillation_alpha: 0.0

# Type of distillation loss
# "mse": Mean squared error (recommended - transformer outputs are continuous velocity predictions)
# "cosine": Cosine similarity loss (matches direction only, ignores magnitude)
distillation_loss_type: "mse"

# Data type for teacher model (bfloat16 recommended for memory efficiency)
teacher_dtype: "bfloat16"

# ModelOpt Quantization Settings
# Name of the mtq config, e.g. FP8_DEFAULT_CFG, INT8_DEFAULT_CFG, NVFP4_DEFAULT_CFG.
# Custom configs defined in CUSTOM_QUANT_CONFIGS (distillation_trainer.py) are also supported.
quant_cfg:

# Full-inference calibration settings (matching PTQ workflow).
# Each prompt runs a complete denoising loop through the DiT, covering all noise levels.
# Path to a text file with one prompt per line. If null, uses the default
# HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts' (same as PTQ).
calibration_prompts_file:
# Total number of calibration prompts (set to 0 to skip calibration)
calibration_size: 128
# Number of denoising steps per prompt (matches PTQ --n-steps)
calibration_n_steps: 30
# CFG guidance scale during calibration (4.0 = PTQ default, calls transformer
# twice per step for positive + negative prompt; 1.0 = no CFG, saves memory)
calibration_guidance_scale: 4.0

# Path to restore a previously quantized model (from mto.save)
restore_quantized_checkpoint:

# Path to save the final quantized model checkpoint
save_quantized_checkpoint:

# Resume from a full training state checkpoint (saves model + optimizer + RNG + step)
# Set to "latest" to auto-find the most recent checkpoint in output_dir/checkpoints/
# Or set to an explicit path like "/path/to/checkpoints/step_001000"
resume_from_checkpoint: latest

# Time-limit-aware saving for Slurm jobs.
# Minutes after which training must save a checkpoint and exit gracefully.
# Set slightly below your Slurm --time limit (e.g. time=30min -> must_save_by: 25).
# Timer starts when train() is called (after model loading/calibration).
must_save_by:

# Debug/Test: Use mock data instead of real preprocessed data
# Useful for testing the training pipeline without preparing a dataset
use_mock_data: false
mock_data_samples: 100

# Training Strategy
training_strategy:
name: "text_to_video"
first_frame_conditioning_p: 0.1
with_audio: false

# Optimization Configuration
optimization:
learning_rate: 2.0e-6
steps: 10000
batch_size: 1
gradient_accumulation_steps: 4
max_grad_norm: 1.0
optimizer_type: "adamw" # Use "adamw8bit" for memory efficiency
scheduler_type: "cosine"
enable_gradient_checkpointing: true # Essential for memory savings

# Acceleration Configuration
acceleration:
mixed_precision_mode: "bf16"

# NOTE: Set to null - we use ModelOpt quantization instead of ltx-trainer's quanto
quantization:

# 8-bit text encoder for memory savings
load_text_encoder_in_8bit: false

# Data Configuration
data:
# Path to preprocessed training data (created by process_dataset.py)
preprocessed_data_root: "/path/to/preprocessed/data"
num_dataloader_workers: 2

# Validation Configuration
validation:
prompts:
- "A beautiful sunset over the ocean with gentle waves"
- "A cat playing with a ball of yarn in a cozy living room"
negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
video_dims: [512, 320, 33] # [width, height, frames]
frame_rate: 25.0
inference_steps: 30
interval: 500 # Validate every 500 steps
guidance_scale: 4.0
seed: 42

# Checkpointing Configuration
checkpoints:
interval: 1000 # Save checkpoint every 1000 steps
keep_last_n: 3 # Keep only last 3 checkpoints
precision: "bfloat16"

# Weights & Biases Logging
wandb:
enabled: true
project: "ltx2-distillation"
entity: # Your W&B username or team
tags:
- "distillation"
- "modelopt"
log_validation_videos: true

# Flow Matching Configuration
flow_matching:
timestep_sampling_mode: "shifted_logit_normal"
timestep_sampling_params: {}

# General Settings
seed: 42
output_dir: "./outputs/distillation_experiment"