
[OMNIML-3505] LTX-2 Distillation Trainer#892

Merged
mxinO merged 7 commits into main from mxin/ltx-distill
Feb 14, 2026

Conversation

@mxinO
Contributor

@mxinO mxinO commented Feb 14, 2026

What does this PR do?

Type of change: new example

Overview:
Adds the LTX-2 distillation trainer.

Usage

accelerate launch \
    --config_file configs/accelerate/fsdp.yaml \
    --num_processes 8 \
    distillation_trainer.py --config configs/distillation_example.yaml

See the README for more details.
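For multi-node runs the same entrypoint can be launched with accelerate's standard multi-machine flags. The values below (machine count, rank, rendezvous address, process count) are placeholders for illustration, not settings taken from this example:

# Hypothetical two-node launch; run once per node with the matching --machine_rank.
accelerate launch \
    --config_file configs/accelerate/fsdp.yaml \
    --num_machines 2 \
    --machine_rank 0 \
    --main_process_ip <node0-ip> \
    --main_process_port 29500 \
    --num_processes 16 \
    distillation_trainer.py --config configs/distillation_example.yaml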

Testing

Ran training on a single node and on multiple nodes.

Before your PR is "Ready for review"

  • Make sure you have read and followed the Contributor guidelines and that your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: NA
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: Yes

Additional Information

Summary by CodeRabbit

New Features

  • Added distillation training support for LTX-2 models with quantization integration.
  • Introduced comprehensive documentation and example configurations for distillation workflows.
  • Added multi-GPU and multi-node training setup with distributed launch support and customizable configuration templates.

Signed-off-by: Meng Xin <mxin@nvidia.com>
@copy-pr-bot

copy-pr-bot bot commented Feb 14, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Feb 14, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough


This PR adds LTX-2 distillation training support, including a comprehensive trainer implementation that integrates ModelOpt quantization, distributed training via FSDP, teacher-student model coordination, and a CLI-driven workflow with calibration and checkpoint management.

Changes

Cohort / File(s) and Summary

Documentation & Changelog: CHANGELOG.rst, examples/diffusers/distillation/README.md
Added a changelog entry and a comprehensive README documenting distillation training setup, installation, configuration options, multi-GPU execution, quantization features, and concrete usage examples.

Configuration Files: examples/diffusers/distillation/configs/accelerate/fsdp.yaml, examples/diffusers/distillation/configs/distillation_example.yaml
Added an FSDP configuration for distributed training and a detailed distillation training config covering model paths, distillation loss settings, quantization, calibration, optimization, validation, checkpointing, and logging parameters.

Core Implementation: examples/diffusers/distillation/distillation_trainer.py
Introduced the DistillationTrainer class extending LtxvTrainer with teacher-student model handling, ModelOpt quantization workflows, calibration via cached embeddings, combined loss training (task + distillation), distributed checkpointing with FSDP/LoRA support, resume logic, and a CLI entrypoint.

Dependencies: examples/diffusers/distillation/requirements.txt
Added VCS references for the ltx-core, ltx-pipelines, and ltx-trainer packages, plus an omegaconf dependency.
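To give a concrete feel for the shape of the distillation training config (distillation_example.yaml above), here is a hypothetical excerpt. Only the option names that appear elsewhere in this PR (optimizer_type, resume_from_checkpoint, the calibration prompts file, optimization steps, and the task/distillation weight) are grounded; the grouping, keys such as teacher_path, and the default values are illustrative assumptions rather than the shipped config:

# Hypothetical excerpt of a distillation config; structure and defaults are illustrative.
model:
  teacher_path: /path/to/ltx2-teacher      # placeholder path
  student_path: /path/to/ltx2-student      # placeholder path
distillation:
  loss_type: mse                           # the trainer also supports a cosine loss
  alpha: 0.5                               # L_total = alpha * L_task + (1 - alpha) * L_distill
quantization:
  enabled: true
  calibration_prompts_file: null           # null falls back to a HuggingFace prompt dataset
optimization:
  steps: 10000
  optimizer_type: "adamw"                  # full-precision AdamW
checkpoints:
  resume_from_checkpoint: null             # "latest" or a path to resume intentionally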

Sequence Diagram

sequenceDiagram
    actor User
    participant Main as Main Script
    participant Trainer as DistillationTrainer
    participant Teacher as Teacher Model
    participant Calibration as Calibration Engine
    participant Quantizer as ModelOpt Quantizer
    participant Student as Student Model
    participant DataLoader as DataLoader

    User->>Main: Run with YAML config + CLI overrides
    Main->>Trainer: Instantiate with config
    Trainer->>Teacher: Load & freeze teacher model
    Trainer->>Trainer: Prepare teacher with accelerator
    
    alt Quantization Path
        Trainer->>Calibration: Collect statistics via forward pass
        Calibration->>DataLoader: Fetch calibration prompts
        DataLoader-->>Calibration: Batch data
        Calibration->>Teacher: Forward pass through teacher
        Calibration->>Trainer: Cache embeddings
        Trainer->>Quantizer: Apply quantization to student
    end
    
    Trainer->>Student: Prepare student model (post-quant)
    Trainer->>Trainer: Start training loop
    
    loop Each Training Step
        DataLoader->>Trainer: Provide batch (latents, conditions)
        Trainer->>Student: Forward pass (student predictions)
        Trainer->>Teacher: Forward pass (teacher predictions, no grad)
        Trainer->>Trainer: Compute L_task (student loss)
        Trainer->>Trainer: Compute L_distill (teacher-student loss)
        Trainer->>Trainer: Combine: L_total = α·L_task + (1-α)·L_distill
        Trainer->>Trainer: Backward & optimize
    end
    
    Trainer->>Trainer: Save checkpoint & quantized model
    Trainer-->>User: Training complete
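A compact sketch of the flow in the diagram: calibrate and quantize the student with ModelOpt, then train with the combined task/distillation loss. The ModelOpt call follows the public modelopt.torch.quantization API, but the model/dataloader wiring, the batch and target shapes, and the helper names are assumptions rather than code taken from the actual trainer:

import modelopt.torch.quantization as mtq
import torch
import torch.nn.functional as F


def quantize_student(student, quant_cfg, calibration_batches):
    # Calibration forward loop: run calibration batches through the model so that
    # ModelOpt can collect activation statistics before inserting quantizers.
    def forward_loop(model):
        with torch.no_grad():
            for batch in calibration_batches:
                model(**batch)

    # quant_cfg would be one of ModelOpt's predefined quantization configs,
    # selected via the example's YAML in the real trainer.
    return mtq.quantize(student, quant_cfg, forward_loop=forward_loop)


def training_step(student, teacher, batch, target, alpha=0.5):
    student_pred = student(**batch)
    with torch.no_grad():                                    # teacher stays frozen
        teacher_pred = teacher(**batch)
    task_loss = F.mse_loss(student_pred, target)             # L_task
    distill_loss = F.mse_loss(student_pred, teacher_pred)    # L_distill
    return alpha * task_loss + (1 - alpha) * distill_loss    # L_total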

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 4 passed

  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title directly describes the main addition, an LTX-2 Distillation Trainer; it is clear, specific, and accurately reflects the primary change in the changeset.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 93.75%, above the required threshold of 80.00%.
  • Merge Conflict Detection: ✅ Passed. No merge conflicts detected when merging into main.


@mxinO mxinO changed the title LTX-2 Distillation Trainer [OMNIML-3505] LTX-2 Distillation Trainer Feb 14, 2026
@codecov

codecov bot commented Feb 14, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.74%. Comparing base (5c4ef8e) to head (ce45cde).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #892   +/-   ##
=======================================
  Coverage   73.74%   73.74%           
=======================================
  Files         199      199           
  Lines       21163    21163           
=======================================
  Hits        15606    15606           
  Misses       5557     5557           

@mxinO mxinO marked this pull request as ready for review February 14, 2026 03:41
@mxinO mxinO requested a review from a team as a code owner February 14, 2026 03:41
@mxinO mxinO requested a review from jingyu-ml February 14, 2026 03:41
@mxinO mxinO self-assigned this Feb 14, 2026
@mxinO mxinO requested review from Edwardf0t1 and kaix-nv February 14, 2026 03:42
Contributor

@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 9

🤖 Fix all issues with AI agents
In `@examples/diffusers/distillation/configs/accelerate/fsdp.yaml`:
- Line 8: The usage comment references the wrong config filename: it mentions "--config_file configs/accelerate/fsdp_multinode.yaml", but this file is named "fsdp.yaml". Update the commented example to reference "configs/accelerate/fsdp.yaml", or add the intended "fsdp_multinode.yaml" config, so the example and the actual filename match.

In `@examples/diffusers/distillation/configs/distillation_example.yaml`:
- Line 86: The inline comment for optimizer_type is misleading: it claims "Use 8-bit optimizer for memory efficiency" while the config sets optimizer_type: "adamw" (full-precision). Either switch the value to an 8-bit optimizer (if supported) or update the comment to state that "adamw" is full-precision.
- Line 60: resume_from_checkpoint defaults to "latest", so fresh runs unintentionally resume. Change it to null (or remove/comment the line) and add a brief inline comment noting that users can set it to "latest" or a specific checkpoint path when resuming is intentional. Both fixes are sketched below.

In `@examples/diffusers/distillation/distillation_trainer.py`:
- Around lines 746-750: The RuntimeError raised when self._cached_calibration_embeddings is empty concatenates two string literals without a space, producing "available!Probably". Add the missing space so the message reads "No cached calibration embeddings available! Probably the saved checkpoint has no modelopt_state.pt or is corrupted."
- Around lines 1676-1679: steps_per_second is inflated because actual_steps is computed as cfg.optimization.steps - resume_step (planned steps) rather than the number of steps actually executed when the loop exits early on a time limit. Use the real completed step count (e.g. the loop's step counter or trainer_state.global_step): actual_steps = max(0, completed_steps - resume_step), then steps_per_second = actual_steps / total_time_seconds, guarding against total_time_seconds == 0.
- Around lines 1228-1232: The tmp_dir.rename(final_dir) call inside the is_global_rank0() block can fail if final_dir already exists. Remove an existing final_dir first (shutil.rmtree for directories), then rename, and wrap the operation in try/except so failures are logged or re-raised; the whole sequence should still run only on is_global_rank0().
- Around lines 1043-1048: The cosine branch of the distillation loss returns early and ignores loss_mask. Flatten or broadcast loss_mask to match cos_sim, mask the cosine distances, and return the masked mean (sum of masked distances / mask.sum().clamp_min(1)) instead of the unmasked cos_sim.mean(); do not short-circuit before the mask logic used in the other branches.
- Line 524: Remove the debug print that dumps the full model architecture (print(f"Quantized model: {self._transformer}")); replace it with logger.debug("Quantized model: %s", self._transformer) or delete the line.

The throughput, rename, and masked-cosine fixes are sketched together right after this list.

In `@examples/diffusers/distillation/README.md`:
- Line 128: The README description for calibration_prompts_file is inverted. Update the table entry to say that it is a text file with one prompt per line and that, if the value is null, the code falls back to the HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts' (i.e., "Text file with one prompt per line. If null, uses the HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts'."), so the prose matches the actual behavior.
🧹 Nitpick comments (4)
examples/diffusers/distillation/requirements.txt (1)

1-4: Pin git dependencies to a specific commit or tag for reproducibility.

All three LTX-2 packages reference the default branch without a commit hash or tag. If the upstream repo introduces breaking changes, this example will silently break. Consider pinning, e.g.:

-ltx-core @ git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-core
+ltx-core @ git+https://github.com/Lightricks/LTX-2.git@<commit-or-tag>#subdirectory=packages/ltx-core

Same for ltx-pipelines and ltx-trainer. Also consider pinning omegaconf to a version range.

examples/diffusers/distillation/distillation_trainer.py (3)

1759-1773: Remove or gate debug prints behind a flag.

Multiple print(f"[DEBUG] ...") statements will emit on every rank in every run. These are noisy in production multi-node jobs. Consider using logger.debug or removing them.


594-596: weights_only=False in torch.load: the security note is acknowledged, but the call is still risky.

The comment notes this is ModelOpt-generated state, but if a user points resume_from_checkpoint to an untrusted path, arbitrary code execution is possible. Consider using weights_only=True if the state dict doesn't contain unpicklable objects, or add a more prominent warning in the config docs (a hardened-loading sketch follows these nitpicks).


939-940: Use register_save_state_pre_hook() to filter models instead of accessing private _models/_optimizers attributes.

Accelerate doesn't provide a public unregister API for tracked models/optimizers. Instead of mutating _models and _optimizers directly, use the public hook API to filter out the teacher model before checkpointing:

def drop_teacher_from_checkpoint(models, optimizers, input_dir):
    # Remove teacher model and its paired optimizer from checkpoint
    for i in range(len(models) - 1, -1, -1):
        if models[i] is self._teacher_transformer:
            models.pop(i)
            optimizers.pop(i)
            break

self._accelerator.register_save_state_pre_hook(drop_teacher_from_checkpoint)

This approach avoids coupling to private implementation details while achieving the same goal.
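One possible hardening of that resume path, preferring weights_only=True and falling back only for explicitly trusted ModelOpt checkpoints; the helper name and the trusted flag are illustrative, not part of the trainer:

import torch

def load_modelopt_state(path, trusted: bool = False):
    # Safe default: refuse arbitrary pickled objects in the checkpoint.
    try:
        return torch.load(path, map_location="cpu", weights_only=True)
    except Exception:
        if not trusted:
            raise
        # ModelOpt state may contain non-tensor objects; allow them only for trusted paths.
        return torch.load(path, map_location="cpu", weights_only=False)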

@mxinO mxinO enabled auto-merge (squash) February 14, 2026 10:19
Contributor

@jingyu-ml jingyu-ml left a comment

LGTM, let's merge it first

@mxinO mxinO merged commit ca1f968 into main Feb 14, 2026
36 checks passed
@mxinO mxinO deleted the mxin/ltx-distill branch February 14, 2026 21:28