[OMNIML-3505] LTX-2 Distillation Trainer #892
# LTX-2 Distillation Training with ModelOpt

Knowledge distillation for LTX-2 DiT models using NVIDIA ModelOpt. A frozen **teacher** guides a trainable **student** through a combined loss:

```text
L_total = α × L_task + (1-α) × L_distill
```

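In PyTorch terms, the objective is a weighted sum of the flow-matching task loss and a distillation loss against the teacher's outputs. The snippet below is a minimal sketch of that weighting only, assuming MSE for both terms; the tensor names (`student_pred`, `teacher_pred`, `target`) are illustrative, not the trainer's actual variables.

```python
import torch
import torch.nn.functional as F

def combined_loss(student_pred: torch.Tensor,
                  teacher_pred: torch.Tensor,
                  target: torch.Tensor,
                  alpha: float = 0.5) -> torch.Tensor:
    """L_total = alpha * L_task + (1 - alpha) * L_distill (MSE variants)."""
    l_task = F.mse_loss(student_pred, target)                    # match the training target
    l_distill = F.mse_loss(student_pred, teacher_pred.detach())  # match the frozen teacher
    return alpha * l_task + (1 - alpha) * l_distill
```

With `alpha = 0.0` the student trains purely to match the teacher; with `alpha = 1.0` distillation is disabled.
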
Currently supported:

- **Quantization-Aware Distillation (QAD)** — student uses ModelOpt fake quantization

Planned:

- **Sparsity-Aware Distillation (SAD)** — student uses ModelOpt sparsity

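For QAD, the teacher stays frozen in full precision while the student is converted to fake quantization with `modelopt.torch.quantization` before training. The toy model and forward loop below are stand-ins chosen to keep the sketch runnable; the real trainer applies the same pattern to the LTX-2 transformer.

```python
import torch
import modelopt.torch.quantization as mtq

def build_model():
    # Stand-in for the LTX-2 DiT; the trainer loads the real transformer
    # from model.model_path for both teacher and student.
    return torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64))

teacher = build_model().eval().requires_grad_(False)   # frozen, full precision
student = build_model()                                # trainable

def forward_loop(model):
    # Calibration data flows through the model so the inserted quantizers
    # can record activation ranges (see "Calibration" below).
    for _ in range(8):
        model(torch.randn(4, 64))

# Insert fake quantizers into the student and calibrate their scales.
student = mtq.quantize(student, mtq.FP8_DEFAULT_CFG, forward_loop)
```
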
## Installation

```bash
# From the distillation example directory
cd examples/diffusers/distillation

# Install Model-Optimizer (from repo root)
pip install -e ../../..

# Install all dependencies (ltx-trainer, ltx-core, ltx-pipelines, omegaconf)
pip install -r requirements.txt
```

## Quick Start

### 1. Prepare Your Dataset

Use the ltx-trainer preprocessing to extract latents and text embeddings:

```bash
python -m ltx_trainer.preprocess \
  --input_dir /path/to/videos \
  --output_dir /path/to/preprocessed \
  --model_path /path/to/ltx2/checkpoint.safetensors
```

### 2. Configure

Copy and edit the example config:

```bash
cp configs/distillation_example.yaml configs/my_experiment.yaml
```

Key settings to update:

```yaml
model:
  model_path: "/path/to/ltx2/checkpoint.safetensors"
  text_encoder_path: "/path/to/gemma/model"

data:
  preprocessed_data_root: "/path/to/preprocessed/data"

distillation:
  distillation_alpha: 0.5        # 1.0 = pure task loss, 0.0 = pure distillation
  quant_cfg: "FP8_DEFAULT_CFG"   # or INT8_DEFAULT_CFG, NVFP4_DEFAULT_CFG, null

# IMPORTANT: disable ltx-trainer's built-in quantization
acceleration:
  quantization: null
```

### 3. Run Training

#### Single GPU

```bash
python distillation_trainer.py --config configs/my_experiment.yaml
```

#### Multi-GPU (Single Node) with Accelerate

```bash
accelerate launch \
  --config_file configs/accelerate/fsdp.yaml \
  --num_processes 8 \
  distillation_trainer.py --config configs/my_experiment.yaml
```

#### Multi-Node Training with Accelerate

To launch on multiple nodes, set the following environment variables on each node:

- `NUM_NODES`: Total number of nodes
- `GPUS_PER_NODE`: Number of GPUs per node
- `NODE_RANK`: Unique rank/index of this node (0-based)
- `MASTER_ADDR`: IP address of the master node (rank 0)
- `MASTER_PORT`: Communication port (e.g., 29500)

Then run the following on every node:

```bash
accelerate launch \
  --config_file configs/accelerate/fsdp.yaml \
  --num_machines $NUM_NODES \
  --num_processes $((NUM_NODES * GPUS_PER_NODE)) \
  --machine_rank $NODE_RANK \
  --main_process_ip $MASTER_ADDR \
  --main_process_port $MASTER_PORT \
  distillation_trainer.py --config configs/my_experiment.yaml
```

**Config overrides** can be passed via CLI using dotted notation:

```bash
accelerate launch ... distillation_trainer.py \
  --config configs/my_experiment.yaml \
  ++distillation.distillation_alpha=0.6 \
  ++distillation.quant_cfg=INT8_DEFAULT_CFG \
  ++optimization.learning_rate=1e-5
```

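`requirements.txt` pulls in omegaconf, and the dotted `++key=value` syntax maps naturally onto OmegaConf's dotlist overrides. As a rough illustration of what such a merge looks like (assumed behavior; the exact parsing lives in `distillation_trainer.py`):

```python
from omegaconf import OmegaConf

base = OmegaConf.load("configs/my_experiment.yaml")

# Dotted overrides as they would arrive from the CLI, leading "++" stripped.
overrides = OmegaConf.from_dotlist([
    "distillation.distillation_alpha=0.6",
    "distillation.quant_cfg=INT8_DEFAULT_CFG",
    "optimization.learning_rate=1e-5",
])

cfg = OmegaConf.merge(base, overrides)
print(cfg.distillation.distillation_alpha)  # 0.6
```
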
## Configuration Reference

### Calibration

Before training begins, calibration runs full denoising inference to collect activation statistics for accurate quantizer scales. The result is cached as a step-0 checkpoint and reused on subsequent runs.

| Parameter | Default | Description |
|-----------|---------|-------------|
| `calibration_prompts_file` | null | Text file with one prompt per line. If null, uses the HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts'. |
| `calibration_size` | 128 | Number of prompts (each runs a full denoising loop) |
| `calibration_n_steps` | 30 | Denoising steps per prompt |
| `calibration_guidance_scale` | 4.0 | CFG scale (should match the inference-time setting) |

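Conceptually the calibration step follows ModelOpt's standard pattern: a forward loop runs full denoising inference while the quantizers record activation statistics, and the resulting scales are then used during training. The sketch below shows the shape of that pattern only; `pipeline` and `load_calibration_prompts` are hypothetical stand-ins, not APIs defined in this example.

```python
import modelopt.torch.quantization as mtq

prompts = load_calibration_prompts()[:128]      # hypothetical helper; calibration_size = 128

def forward_loop(transformer):
    # Each prompt runs a complete denoising loop, so the quantizers observe
    # activations at every noise level (calibration_n_steps, guidance scale).
    for prompt in prompts:
        pipeline(prompt, num_inference_steps=30, guidance_scale=4.0)   # hypothetical pipeline call

pipeline.transformer = mtq.quantize(pipeline.transformer, mtq.FP8_DEFAULT_CFG, forward_loop)
```
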
### Checkpoint Resume

| Parameter | Default | Description |
|-----------|---------|-------------|
| `resume_from_checkpoint` | null | `"latest"` to auto-detect, or explicit path |
| `must_save_by` | null | Minutes after which to save and exit (for Slurm time limits) |
| `restore_quantized_checkpoint` | null | Restore a pre-quantized model (skips calibration) |
| `save_quantized_checkpoint` | null | Path to save the final quantized model |

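`must_save_by` amounts to an elapsed-time check inside the training loop, which keeps Slurm jobs from being killed mid-run without a checkpoint. A minimal sketch of the idea (illustrative only, not the trainer's implementation):

```python
import time

def train(max_steps: int, must_save_by_minutes: float | None = None):
    start = time.monotonic()   # timer starts when train() is called
    for step in range(max_steps):
        # ... run one distillation training step here ...
        elapsed_min = (time.monotonic() - start) / 60.0
        if must_save_by_minutes is not None and elapsed_min >= must_save_by_minutes:
            # ... save a resumable checkpoint here, then exit gracefully ...
            print(f"Hit the {must_save_by_minutes}-minute budget at step {step}; exiting.")
            break
```
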
### Custom Quantization Configs

To define custom quantization configs, add entries to `CUSTOM_QUANT_CONFIGS` in `distillation_trainer.py`:

```python
CUSTOM_QUANT_CONFIGS["MY_FP8_CFG"] = {
    "quant_cfg": mtq.FP8_DEFAULT_CFG["quant_cfg"],
    "algorithm": "max",
}
```

Then reference it in your YAML: `quant_cfg: MY_FP8_CFG`.

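A common variant is to exclude sensitive layers from quantization. ModelOpt quant configs accept wildcard patterns that disable matching quantizers; the sketch below is an assumption-laden example, and `*final_layer*` is an illustrative placeholder rather than a verified LTX-2 module name.

```python
import copy

import modelopt.torch.quantization as mtq

# Copy the FP8 defaults, then disable quantizers whose names match a pattern.
skip_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG["quant_cfg"])
skip_cfg["*final_layer*"] = {"enable": False}   # placeholder layer pattern

CUSTOM_QUANT_CONFIGS["FP8_SKIP_FINAL_LAYER"] = {
    "quant_cfg": skip_cfg,
    "algorithm": "max",
}
```

Reference it the same way: `quant_cfg: FP8_SKIP_FINAL_LAYER`.
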
examples/diffusers/distillation/configs/accelerate/fsdp.yaml (new file, 45 lines)

# FSDP Configuration
#
# FULL_SHARD across all GPUs for maximum memory efficiency.
# For multi-node training with `accelerate launch`.
#
# Usage:
# accelerate launch \
#   --config_file configs/accelerate/fsdp.yaml \
#   --num_processes 16 \
#   --num_machines 2 \
#   --machine_rank $MACHINE_RANK \
#   --main_process_ip $MASTER_IP \
#   --main_process_port 29500 \
#   distillation_trainer.py --config configs/distillation_example.yaml

distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false

fsdp_config:
  # FULL_SHARD: Shard optimizer states, gradients, and parameters across ALL GPUs
  # This provides maximum memory efficiency for large models like LTX-2 19B
  # Parameters are fully sharded across all nodes (not replicated)
  fsdp_sharding_strategy: FULL_SHARD

  # Enable activation checkpointing to reduce memory during backward pass
  # Critical for 19B model training
  fsdp_activation_checkpointing: true

  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock
  fsdp_use_orig_params: true
  fsdp_version: 1

# Note: num_machines and num_processes are overridden by accelerate launch command-line args
# These are just defaults for local testing
num_machines: 1
num_processes: 8

examples/diffusers/distillation/configs/distillation_example.yaml (new file, 142 lines)

# LTX-2 Distillation Training Configuration with ModelOpt

# Model Configuration
model:
  # Path to the LTX-2 checkpoint (used for both teacher and student)
  model_path: "/path/to/ltx2/checkpoint.safetensors"

  # Path to Gemma text encoder (required for LTX-2)
  text_encoder_path: "/path/to/gemma/model"

  # Training mode: "lora" is not supported yet
  training_mode: "full"

# Distillation Configuration
distillation:
  # Path to teacher model (if different from model.model_path)
  # Set to null to use the same checkpoint as student (loaded without quantization)
  teacher_model_path:

  # Weight for task loss: L_total = α * L_task + (1-α) * L_distill
  # α = 1.0: pure task loss (no distillation)
  # α = 0.0: pure distillation loss
  distillation_alpha: 0.0

  # Type of distillation loss
  # "mse": Mean squared error (recommended - transformer outputs are continuous velocity predictions)
  # "cosine": Cosine similarity loss (matches direction only, ignores magnitude)
  distillation_loss_type: "mse"

  # Data type for teacher model (bfloat16 recommended for memory efficiency)
  teacher_dtype: "bfloat16"

  # ModelOpt Quantization Settings
  # Name of the mtq config, e.g. FP8_DEFAULT_CFG, INT8_DEFAULT_CFG, NVFP4_DEFAULT_CFG.
  # Custom configs defined in CUSTOM_QUANT_CONFIGS (distillation_trainer.py) are also supported.
  quant_cfg:

  # Full-inference calibration settings (matching PTQ workflow).
  # Each prompt runs a complete denoising loop through the DiT, covering all noise levels.
  # Path to a text file with one prompt per line. If null, uses the default
  # HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts' (same as PTQ).
  calibration_prompts_file:
  # Total number of calibration prompts (set to 0 to skip calibration)
  calibration_size: 128
  # Number of denoising steps per prompt (matches PTQ --n-steps)
  calibration_n_steps: 30
  # CFG guidance scale during calibration (4.0 = PTQ default, calls transformer
  # twice per step for positive + negative prompt; 1.0 = no CFG, saves memory)
  calibration_guidance_scale: 4.0

  # Path to restore a previously quantized model (from mto.save)
  restore_quantized_checkpoint:

  # Path to save the final quantized model checkpoint
  save_quantized_checkpoint:

  # Resume from a full training state checkpoint (saves model + optimizer + RNG + step)
  # Set to "latest" to auto-find the most recent checkpoint in output_dir/checkpoints/
  # Or set to an explicit path like "/path/to/checkpoints/step_001000"
  resume_from_checkpoint: latest

  # Time-limit-aware saving for Slurm jobs.
  # Minutes after which training must save a checkpoint and exit gracefully.
  # Set slightly below your Slurm --time limit (e.g. time=30min -> must_save_by: 25).
  # Timer starts when train() is called (after model loading/calibration).
  must_save_by:

  # Debug/Test: Use mock data instead of real preprocessed data
  # Useful for testing the training pipeline without preparing a dataset
  use_mock_data: false
  mock_data_samples: 100

# Training Strategy
training_strategy:
  name: "text_to_video"
  first_frame_conditioning_p: 0.1
  with_audio: false

# Optimization Configuration
optimization:
  learning_rate: 2.0e-6
  steps: 10000
  batch_size: 1
  gradient_accumulation_steps: 4
  max_grad_norm: 1.0
  optimizer_type: "adamw"  # Use "adamw8bit" for memory efficiency
  scheduler_type: "cosine"
  enable_gradient_checkpointing: true  # Essential for memory savings

# Acceleration Configuration
acceleration:
  mixed_precision_mode: "bf16"

  # NOTE: Set to null - we use ModelOpt quantization instead of ltx-trainer's quanto
  quantization:

  # 8-bit text encoder for memory savings
  load_text_encoder_in_8bit: false

# Data Configuration
data:
  # Path to preprocessed training data (created by process_dataset.py)
  preprocessed_data_root: "/path/to/preprocessed/data"
  num_dataloader_workers: 2

# Validation Configuration
validation:
  prompts:
    - "A beautiful sunset over the ocean with gentle waves"
    - "A cat playing with a ball of yarn in a cozy living room"
  negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
  video_dims: [512, 320, 33]  # [width, height, frames]
  frame_rate: 25.0
  inference_steps: 30
  interval: 500  # Validate every 500 steps
  guidance_scale: 4.0
  seed: 42

# Checkpointing Configuration
checkpoints:
  interval: 1000  # Save checkpoint every 1000 steps
  keep_last_n: 3  # Keep only last 3 checkpoints
  precision: "bfloat16"

# Weights & Biases Logging
wandb:
  enabled: true
  project: "ltx2-distillation"
  entity:  # Your W&B username or team
  tags:
    - "distillation"
    - "modelopt"
  log_validation_videos: true

# Flow Matching Configuration
flow_matching:
  timestep_sampling_mode: "shifted_logit_normal"
  timestep_sampling_params: {}

# General Settings
seed: 42
output_dir: "./outputs/distillation_experiment"