Conversation
Signed-off-by: Meng Xin <mxin@nvidia.com>
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
Important: Review skipped. Auto incremental reviews are disabled on this repository; check the settings in the CodeRabbit UI.
📝 Walkthrough

This PR adds LTX-2 distillation training support, including a comprehensive trainer implementation that integrates ModelOpt quantization, distributed training via FSDP, teacher-student model coordination, and a CLI-driven workflow with calibration and checkpoint management.
Sequence Diagram

```mermaid
sequenceDiagram
actor User
participant Main as Main Script
participant Trainer as DistillationTrainer
participant Teacher as Teacher Model
participant Calibration as Calibration Engine
participant Quantizer as ModelOpt Quantizer
participant Student as Student Model
participant DataLoader as DataLoader
User->>Main: Run with YAML config + CLI overrides
Main->>Trainer: Instantiate with config
Trainer->>Teacher: Load & freeze teacher model
Trainer->>Trainer: Prepare teacher with accelerator
alt Quantization Path
Trainer->>Calibration: Collect statistics via forward pass
Calibration->>DataLoader: Fetch calibration prompts
DataLoader-->>Calibration: Batch data
Calibration->>Teacher: Forward pass through teacher
Calibration->>Trainer: Cache embeddings
Trainer->>Quantizer: Apply quantization to student
end
Trainer->>Student: Prepare student model (post-quant)
Trainer->>Trainer: Start training loop
loop Each Training Step
DataLoader->>Trainer: Provide batch (latents, conditions)
Trainer->>Student: Forward pass (student predictions)
Trainer->>Teacher: Forward pass (teacher predictions, no grad)
Trainer->>Trainer: Compute L_task (student loss)
Trainer->>Trainer: Compute L_distill (teacher-student loss)
Trainer->>Trainer: Combine: L_total = α·L_task + (1-α)·L_distill
Trainer->>Trainer: Backward & optimize
end
Trainer->>Trainer: Save checkpoint & quantized model
Trainer-->>User: Training complete
```
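The loss combination in the training loop above (L_total = α·L_task + (1-α)·L_distill) can be sketched roughly as follows; the choice of MSE for both terms and all names used here are illustrative, not taken from the actual trainer:

```python
import torch
import torch.nn.functional as F

def combined_loss(student_pred: torch.Tensor,
                  teacher_pred: torch.Tensor,
                  target: torch.Tensor,
                  alpha: float = 0.5) -> torch.Tensor:
    """Sketch of L_total = alpha * L_task + (1 - alpha) * L_distill."""
    # Task loss: student prediction against the ground-truth target.
    l_task = F.mse_loss(student_pred, target)
    # Distillation loss: student prediction against the frozen teacher output.
    l_distill = F.mse_loss(student_pred, teacher_pred.detach())
    return alpha * l_task + (1.0 - alpha) * l_distill
```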
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed checks
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #892   +/-   ##
=======================================
  Coverage   73.74%   73.74%
=======================================
  Files         199      199
  Lines       21163    21163
=======================================
  Hits        15606    15606
  Misses       5557     5557
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Meng Xin <mxin@nvidia.com>
Signed-off-by: Meng Xin <mxin@nvidia.com>
Signed-off-by: Meng Xin <mxin@nvidia.com>
Actionable comments posted: 9
🤖 Fix all issues with AI agents
In `@examples/diffusers/distillation/configs/accelerate/fsdp.yaml`:
- Line 8: The usage comment references the wrong config filename: the commented flag line mentions
"--config_file configs/accelerate/fsdp_multinode.yaml", which does not match the actual file name
"fsdp.yaml". Update the comment to reference "configs/accelerate/fsdp.yaml" (or create/point to
"fsdp_multinode.yaml" consistently) so the example and actual filenames match.
In `@examples/diffusers/distillation/configs/distillation_example.yaml`:
- Line 86: The inline comment for optimizer_type is misleading: it claims "Use
8-bit optimizer for memory efficiency" while the config sets optimizer_type:
"adamw" (full-precision). Update either the comment or the value so they
match—e.g., change optimizer_type to an 8-bit optimizer name (if supported) or
modify the comment to reflect that "adamw" is full-precision; reference
optimizer_type and its value "adamw" when making the change.
- Line 60: The example config's resume_from_checkpoint setting currently
defaults to "latest", causing fresh runs to unintentionally resume; update the
example in distillation_example.yaml by changing the resume_from_checkpoint
field from "latest" to null (or remove/comment the line) so runs do not
auto-resume, and add a brief inline comment indicating users can set it to
"latest" or a specific checkpoint path when intentional.
In `@examples/diffusers/distillation/distillation_trainer.py`:
- Around line 746-750: The error message raised when
self._cached_calibration_embeddings is empty concatenates two string literals
without a space, producing "available!Probably"; update the RuntimeError message
in the block that checks self._cached_calibration_embeddings (the raise
RuntimeError in distillation_trainer.py) to include a space or combine into a
single properly spaced sentence so the message reads e.g. "No cached calibration
embeddings available! Probably the saved checkpoint has no modelopt_state.pt or
is corrupted."
- Around line 1676-1679: The reported steps_per_second is inflated because
actual_steps is set to cfg.optimization.steps - resume_step (planned steps)
rather than the number of steps actually executed when the loop exits early due
to a time-limit break; change the calculation of actual_steps to use the real
completed step count (e.g., current_step, steps_done, trainer_state.global_step,
or the loop's step counter) instead of cfg.optimization.steps — compute
actual_steps = max(0, completed_steps - resume_step) (where completed_steps is
the finalized step counter at loop exit), then recompute steps_per_second =
actual_steps / total_time_seconds (guarding for total_time_seconds == 0). A sketch of this fix
appears after these prompts.
- Around line 1228-1232: The atomic rename can fail if final_dir already exists;
replace the direct tmp_dir.rename(final_dir) call inside the is_global_rank0()
block with a safe sequence: if final_dir.exists(): remove it (use shutil.rmtree
for directories) then perform the rename, and wrap the operation in a try/except
to catch and log/raise errors; update the code locations referencing tmp_dir,
final_dir, and the call to tmp_dir.rename(final_dir) so the rename first ensures
final_dir is removed before renaming and only runs on is_global_rank0(). A sketch of this pattern
appears after these prompts.
- Around line 1043-1048: The cosine branch in the distillation loss currently
returns early and ignores loss_mask; modify the block handling loss_type ==
"cosine" (where student_pred, teacher_pred, student_flat, teacher_flat, cos_sim
are computed) to apply the existing loss_mask: flatten or broadcast loss_mask to
match cos_sim's shape, multiply or index cos_sim by the mask, compute the masked
mean (sum(masked cosine distances) / mask.sum().clamp_min(1)), and return that
masked loss instead of the unmasked cos_sim.mean(); ensure you do not
short-circuit before the mask logic used elsewhere. A sketch of the masked variant appears after
these prompts.
- Line 524: Remove the debug print that dumps the full model architecture (the
line printing f"Quantized model: {self._transformer}"); replace it with a
logger.debug call or remove it entirely so the model isn't printed to stdout on
every quantized run—locate the print in the DistillationTrainer (or the method
containing self._transformer) and change it to logger.debug("Quantized model:
%s", self._transformer) or simply delete the line.
In `@examples/diffusers/distillation/README.md`:
- Line 128: The README description for calibration_prompts_file is inverted;
update the table entry for calibration_prompts_file to read that it is a text
file with one prompt per line and that if the value is null the code falls back
to the HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts' (i.e., "Text
file with one prompt per line. If null, uses the HuggingFace dataset
'Gustavosta/Stable-Diffusion-Prompts'.") so the prose matches the actual
behavior referenced by calibration_prompts_file.
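For the steps_per_second comment above, a minimal sketch of the intended calculation; `completed_steps` is a stand-in name for whatever step counter the training loop finalizes, not a variable taken from the trainer:

```python
def compute_steps_per_second(completed_steps: int, resume_step: int,
                             total_time_seconds: float) -> float:
    # Use the steps actually executed, not the planned total, so an early
    # exit (e.g., a time-limit break) does not inflate throughput.
    actual_steps = max(0, completed_steps - resume_step)
    return actual_steps / total_time_seconds if total_time_seconds > 0 else 0.0
```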
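For the atomic-rename comment above, one possible shape of the guarded move; `tmp_dir` and `final_dir` are `pathlib.Path` objects as in the prompt, while the logger and the rank-0 guard around the call are assumed to exist in the trainer:

```python
import shutil
from pathlib import Path

def finalize_checkpoint_dir(tmp_dir: Path, final_dir: Path, logger) -> None:
    """Promote tmp_dir to final_dir; intended to run only on global rank 0."""
    try:
        if final_dir.exists():
            # rename() fails on many platforms if the target exists, so clear it first.
            shutil.rmtree(final_dir)
        tmp_dir.rename(final_dir)
    except OSError:
        logger.exception("Failed to move %s to %s", tmp_dir, final_dir)
        raise
```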
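For the cosine-loss comment above, a rough sketch of the masked variant; reducing `loss_mask` to one weight per sample is an assumption made here, not behavior taken from the trainer:

```python
import torch
import torch.nn.functional as F

def masked_cosine_distillation_loss(student_pred: torch.Tensor,
                                    teacher_pred: torch.Tensor,
                                    loss_mask: torch.Tensor) -> torch.Tensor:
    # Flatten everything past the batch dimension before the cosine similarity.
    student_flat = student_pred.flatten(start_dim=1)
    teacher_flat = teacher_pred.flatten(start_dim=1)
    cos_dist = 1.0 - F.cosine_similarity(student_flat, teacher_flat, dim=-1)  # [batch]
    # Reduce the mask to a per-sample weight so it matches cos_dist's shape.
    mask = loss_mask.reshape(loss_mask.shape[0], -1).float().amax(dim=1)
    # Masked mean: sum of masked distances over the number of active samples.
    return (cos_dist * mask).sum() / mask.sum().clamp_min(1)
```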
🧹 Nitpick comments (4)
examples/diffusers/distillation/requirements.txt (1)

1-4: Pin git dependencies to a specific commit or tag for reproducibility.

All three LTX-2 packages reference the default branch without a commit hash or tag. If the upstream repo introduces breaking changes, this example will silently break. Consider pinning, e.g.:

```diff
-ltx-core @ git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-core
+ltx-core @ git+https://github.com/Lightricks/LTX-2.git@<commit-or-tag>#subdirectory=packages/ltx-core
```

Same for `ltx-pipelines` and `ltx-trainer`. Also consider pinning `omegaconf` to a version range.

examples/diffusers/distillation/distillation_trainer.py (3)
1759-1773: Remove or gate debug prints behind a flag.

Multiple `print(f"[DEBUG] ...")` statements will emit on every rank in every run. These are noisy in production multi-node jobs. Consider using `logger.debug` or removing them.
594-596: `weights_only=False` in `torch.load` — security note is acknowledged but still risky.

The comment notes this is ModelOpt-generated state, but if a user points `resume_from_checkpoint` to an untrusted path, arbitrary code execution is possible. Consider using `weights_only=True` if the state dict doesn't contain unpicklable objects, or add a more prominent warning in the config docs. A sketch of the safer load appears after these nitpicks.
939-940: Use `register_save_state_pre_hook()` to filter models instead of accessing private `_models`/`_optimizers` attributes.

Accelerate doesn't provide a public unregister API for tracked models/optimizers. Instead of mutating `_models` and `_optimizers` directly, use the public hook API to filter out the teacher model before checkpointing:

```python
def drop_teacher_from_checkpoint(models, optimizers, input_dir):
    # Remove teacher model and its paired optimizer from checkpoint
    for i in range(len(models) - 1, -1, -1):
        if models[i] is self._teacher_transformer:
            models.pop(i)
            optimizers.pop(i)
            break

self._accelerator.register_save_state_pre_hook(drop_teacher_from_checkpoint)
```

This approach avoids coupling to private implementation details while achieving the same goal.
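For the `weights_only` nitpick above (lines 594-596), a safer load could look roughly like the sketch below; `ckpt_path` is an illustrative name and the fallback policy is only a suggestion, not the trainer's actual behavior:

```python
import torch

def load_modelopt_state(ckpt_path: str):
    # Prefer the restricted unpickler; it refuses arbitrary objects in the file.
    try:
        return torch.load(ckpt_path, map_location="cpu", weights_only=True)
    except Exception:
        # Fall back only for trusted, ModelOpt-generated checkpoints that
        # contain objects the restricted loader cannot handle.
        return torch.load(ckpt_path, map_location="cpu", weights_only=False)
```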
Signed-off-by: Meng Xin <mxin@nvidia.com>
jingyu-ml left a comment:

LGTM, let's merge it first
What does this PR do?
Type of change: new example
Overview:
Adding LTX-2 distillation trainer.
Usage

```sh
accelerate launch \
  --config_file configs/accelerate/fsdp.yaml \
  --num_processes 8 \
  distillation_trainer.py --config configs/distillation_example.yaml
```

See the README for more details.
Testing
Run training with single/multiple nodes.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features