Add Megatron-Bridge recipe-free distillation example script #861
kevalmorabia97 merged 6 commits into main from
Conversation
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
📝 Walkthrough

The pull request extends the Megatron-Bridge examples with a comprehensive distillation workflow, including a new distill.py script for orchestrating student model distillation from teacher models, expanded documentation with end-to-end instructions, and minor enhancements to logging and utility scripts.
Sequence Diagram(s)

sequenceDiagram
participant CLI as Command Line
participant Main as main(args)
participant HF as HuggingFace<br/>Checkpoints
participant Bridge as AutoBridge<br/>Providers
participant Distill as DistillationProvider
participant Config as ConfigContainer
participant Trainer as distill()
CLI->>Main: Parse arguments (student/teacher HF paths, data, parallelism)
Main->>HF: Load student & teacher checkpoints
HF-->>Bridge: Return models
Bridge->>Bridge: Build Megatron providers
Bridge->>Bridge: Override parallelism & training settings
Main->>Distill: Wrap providers with DistillationProvider
Main->>Config: Assemble dataset, optimizer, scheduler,<br/>logging, checkpoint configs
Config-->>Trainer: Pass ConfigContainer
Main->>Trainer: Execute distill(config)
Trainer->>Trainer: Create output/checkpoint directories
Trainer->>Trainer: Run distributed training loop
Trainer-->>Main: Report completion
Main->>Main: Cleanup distributed environment
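To make the flow above concrete, here is a self-contained toy skeleton that mirrors the diagram. Every class and function in it is a stand-in written for illustration; the real distill.py relies on Megatron-Bridge and Model-Optimizer APIs whose exact names and signatures are not reproduced here.

```python
# Toy stand-ins mirroring the sequence diagram; not the real Megatron-Bridge APIs.
import argparse
from dataclasses import dataclass


@dataclass
class Provider:  # stand-in for a Megatron model provider built from an HF checkpoint
    hf_path: str
    tensor_parallel: int = 1


@dataclass
class DistillationProvider:  # stand-in: pairs the student provider with its teacher
    student: Provider
    teacher: Provider


@dataclass
class ConfigContainer:  # stand-in: bundles model/data/optimizer/logging settings
    model: DistillationProvider
    output_dir: str


def distill(config: ConfigContainer) -> None:
    # Stand-in for the distributed training loop that the real script launches.
    print(f"Distilling {config.model.student.hf_path} from {config.model.teacher.hf_path}")


def main(args: argparse.Namespace) -> None:
    student = Provider(args.student_hf_path, tensor_parallel=args.tp)
    teacher = Provider(args.teacher_hf_path, tensor_parallel=args.tp)
    config = ConfigContainer(DistillationProvider(student, teacher), args.output_dir)
    distill(config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--student_hf_path", required=True)
    parser.add_argument("--teacher_hf_path", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--tp", type=int, default=1)
    main(parser.parse_args())
```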
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

@@ Coverage Diff @@
##             main     #861      +/-   ##
==========================================
- Coverage   73.72%   73.44%   -0.28%
==========================================
  Files         196      197       +1
  Lines       20457    20657     +200
==========================================
+ Hits        15082    15172      +90
- Misses       5375     5485     +110

☔ View full report in Codecov by Sentry.
)

print_rank_0("\nStarting distillation...")
distill(config)
Should we make it like the NeMo one where it can do either pretrain(), distill(), or finetune() all in one file? (@ChenhanYu would that be preferred?)
How about we incrementally extend this file as we get to needing these options?
Maybe I should rename it to train.py?
I guess right now we can easily just put a pretrain() call if the KD-specific args aren't provided. SFT can be done later since the dataset/template/etc. are different.
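A hypothetical sketch of that fallback, with placeholder callables standing in for the real pretrain()/distill() entry points and an illustrative flag name:

```python
# Hypothetical dispatch: run plain pre-training when no teacher checkpoint is given.
# The flag and callable names are illustrative, not the script's actual API.
import argparse
from typing import Callable


def run(args: argparse.Namespace, pretrain_fn: Callable, distill_fn: Callable) -> None:
    if args.teacher_hf_path is None:  # no KD-specific args -> plain pretrain()
        pretrain_fn(args)
    else:
        distill_fn(args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--teacher_hf_path", default=None,
                        help="Omit to run pre-training instead of distillation.")
    run(parser.parse_args(), pretrain_fn=print, distill_fn=print)  # stubs for illustration
```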
examples/megatron_bridge/README.md (Outdated)

```bash
python /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help
torchrun --nproc_per_node 1 /opt/Megatron-Bridge/3rdparty/Model-Optimizer/examples/megatron_bridge/prune_minitron.py --help
```
I want to only print help on rank 0, so I need to initialize the distributed processes, which will only happen under torchrun. I am not spawning multiple processes in the script, so running with python ... will also result in an error trying to find the RANK env variable during dist setup.
mbridge has its own print_rank_0 which accounts for that.
We should additionally change our own print_rank_0 to work without dist initialized.
Our print_rank_0 works fine in a non-dist env. The issue here is I am manually doing dist.setup(), which fails if not running with torchrun. Since I am doing all the low-level M-Bridge stuff myself because of the lack of top-level APIs, we don't get the dist setup from M-Bridge.
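For reference, a common pattern for making a `print_rank_0` helper behave sensibly both under torchrun and under plain `python` is to check whether the process group exists before querying the rank. This is a generic sketch, not necessarily how modelopt or Megatron-Bridge implement their own helpers:

```python
import torch.distributed as dist


def print_rank_0(*args, **kwargs) -> None:
    """Print only on global rank 0; print normally if torch.distributed isn't initialized."""
    if dist.is_available() and dist.is_initialized():
        if dist.get_rank() == 0:
            print(*args, **kwargs)
    else:
        # Plain `python script.py` (no torchrun): there is only one process anyway.
        print(*args, **kwargs)
```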
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
8d780a5 to 48c74bd
eb0aa58 to ce4d081
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
d0f930f to 1df11df
1df11df to 808c1e0
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/megatron_bridge/README.md (1)
11-11: ⚠️ Potential issue | 🟡 Minor | Grammatical error: "distillation" → "distilling".
"Examples of distillation a pruned or quantized model" should read "Examples of distilling a pruned or quantized model".
Proposed fix
-| Distillation | Examples of distillation a pruned or quantized model | \[[Link](`#distillation`)\] | |
+| Distillation | Examples of distilling a pruned or quantized model | \[[Link](`#distillation`)\] | |
🤖 Fix all issues with AI agents
In `@examples/megatron_bridge/distill.py`:
- Around line 120-122: Make --use_mock_data and --data_paths mutually exclusive
instead of silently letting mock data win: when building the CLI parser, create
a mutually exclusive group via parser.add_mutually_exclusive_group() and add the
two flags to that group (referencing args.use_mock_data and args.data_paths),
then remove the manual validation block that raises ValueError for
neither-provided; this ensures argparse enforces exclusivity and you can keep
the later code path that reads data_paths unchanged.
- Around line 132-134: The code computes checkpoint_dir and tensorboard_dir in
main(args: argparse.Namespace) but never ensures they exist; add explicit
directory creation before these paths are passed into
CheckpointConfig/LoggerConfig by calling os.makedirs(checkpoint_dir,
exist_ok=True) and os.makedirs(tensorboard_dir, exist_ok=True) (ensure imports
include os if not already) so the directories derived from args.output_dir are
created ahead of use.
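A minimal sketch of the two fixes described in the list above, assuming only the flag names mentioned in the comments (`--use_mock_data`, `--data_paths`, `--output_dir`); the subdirectory names and everything else about the surrounding distill.py code are illustrative:

```python
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", required=True)

# Fix 1: let argparse enforce that exactly one data source is given,
# instead of a manual ValueError check or mock data silently winning.
data_group = parser.add_mutually_exclusive_group(required=True)
data_group.add_argument("--use_mock_data", action="store_true")
data_group.add_argument("--data_paths", nargs="+")

args = parser.parse_args()

# Fix 2: create the derived output directories before handing them to
# CheckpointConfig / LoggerConfig (subdirectory names are illustrative).
checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
tensorboard_dir = os.path.join(args.output_dir, "tensorboard")
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(tensorboard_dir, exist_ok=True)
```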
🧹 Nitpick comments (3)
modelopt/torch/utils/plugins/megatron_preprocess_data.py (1)

113-116: Minor: `num2hrb` on small document counts displays decimals (e.g. "5.00 docs"). When `count` is small, `num2hrb` formats it as "5.00", which reads slightly oddly for a document count. This is cosmetic and doesn't affect functionality; just worth noting if you want polished early-iteration output.

examples/megatron_bridge/README.md (1)

36-38: Hardcoded Python 3.12 path in site-packages mount is fragile. The volume mount path `/opt/venv/lib/python3.12/site-packages/modelopt` assumes the NeMo container uses Python 3.12. This will silently break if a future container version changes the Python version. Since you pin `nemo:26.02`, this is acceptable for now, but consider adding a comment noting the Python version dependency so future maintainers know to update this path.

Suggested comment:

     -v ${MODELOPT_DIR}:/opt/Model-Optimizer \
-    -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \
+    -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \  # Update python3.12 if container Python version changes

examples/megatron_bridge/distill.py (1)

162-168: Hardcoded `adam_beta2=0.98`; consider exposing it as a CLI arg or documenting the choice. `adam_beta2=0.98` differs from the common default of 0.999. While 0.98 is reasonable for distillation/pre-training, it's not configurable via CLI. A comment explaining the choice would help users who want to tune this.
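If the suggestion were adopted, one minimal way to expose the value could look like the following; the flag name and default come from the comment above, everything else is illustrative:

```python
import argparse

parser = argparse.ArgumentParser()
# Keep 0.98 as the default but let users override it; the common PyTorch default is 0.999.
parser.add_argument("--adam_beta2", type=float, default=0.98,
                    help="Adam beta2 used by the optimizer config.")
args = parser.parse_args()
```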
808c1e0 to 86e81b1
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
86e81b1 to 50b6b7e
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
dataset_kwargs = {
    "seq_length": args.seq_length,
    "path_to_cache": args.data_path_to_cache,
    "random_seed": SEED,
Can the random seed also be an arg?
You mean randomly generate it every time? Then the results may not be reproducible.
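One way to reconcile both points: expose the seed as a flag with a fixed default, so runs stay reproducible unless the user explicitly changes it. The flag names and defaults below are illustrative, not necessarily what distill.py uses:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--seq_length", type=int, default=4096)
parser.add_argument("--data_path_to_cache", default=None)
# Fixed default keeps runs reproducible; override only when a different shuffle is wanted.
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()

dataset_kwargs = {
    "seq_length": args.seq_length,
    "path_to_cache": args.data_path_to_cache,
    "random_seed": args.seed,  # instead of the hardcoded SEED constant shown above
}
```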
    --ulimit memlock=-1 \
    --rm -it \
    -v ${MODELOPT_DIR}:/opt/Model-Optimizer \
    -v ${MODELOPT_DIR}/modelopt:/opt/venv/lib/python3.12/site-packages/modelopt \
Why is mounting to the venv also necessary?
So users can mount the library and the examples from the same version. This avoids the case where a user uses an old modelopt with examples from the main branch.
To convert the Megatron checkpoint from the last iteration (or any intermediate iteration) to Hugging Face format, you need the pruned model config (`--output_hf_path` from the `prune_minitron.py` script) and the distilled Megatron checkpoint dir (`<distill_output_dir>/checkpoints/iter_<iter_number>`) to run the following command:

```bash
uv run python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py export \
```
Do we assume the user already has uv installed?
It's in the NeMo container, so it's already installed.
What does this PR do?
Type of change: New example script
Usage
Testing
Best subnet from NAS:
{'num_layers': 30, 'hidden_size': 3584, 'ffn_hidden_size': 11776} -> 5.99B params, 0.5718 score

Previous NeMo 2 experiments on depth-pruned Qwen3 8B -> 6B (24 layers) had MMLU ~72.0, so more or less similar. No hparam tuning was done for the current M-Bridge distillation run.
Before your PR is "Ready for review"
Summary by CodeRabbit
Release Notes
New Features
Documentation
Improvements