Description
By adding SFT support, we can leverage TorchTitan's 3D parallelism (FSDP2 / TP / CP) for instruction tuning, alignment, and reasoning tasks. Key implementation features on top of a general SFT loop would be:
a) Packed Sequence Loading with Masking: Implement a collator that packs variable-length instruction-response pairs into fixed blocks. Crucially, this must support a label_mask so that loss is calculated only on assistant responses, not user prompts (as mentioned in this Issue).
b) Smart Batching & Auto-Accumulation: Add a BatchSizeConfig utility that automatically calculates the optimal micro_batch_size and gradient_accumulation_steps based on detected hardware limits (e.g., H100 vs B200) to hit the target global batch size without OOM errors.
c) Adaptive Context Parallelism: Logic to dynamically switch CP strategies (e.g., from ZigZag to Ring/Llama3 style) when packed datasets are detected, ensuring that attention masking correctly respects document boundaries across distributed ranks.
d) Hybrid Checkpoint Recovery: A robust loading mechanism that automatically detects the context: it should strictly load weights only when initializing from a pre-trained base model, but also restore optimizer states when resuming a crashed SFT run.
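Item (a) could be sketched roughly as follows, in plain Python for clarity. The helper name `pack_with_label_mask` is an illustrative assumption; `-100` is the label value PyTorch's cross-entropy ignores by default. A real collator would additionally emit document boundaries for block-diagonal attention:

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch cross-entropy by default

def pack_with_label_mask(pairs, block_size, pad_id=0):
    """Pack variable-length (prompt_ids, response_ids) pairs into one fixed block.

    `labels` mirrors `input_ids`, but prompt tokens and padding are masked with
    IGNORE_INDEX so loss is computed only on assistant responses.
    Hypothetical sketch, not TorchTitan's actual dataloader API.
    """
    input_ids, labels = [], []
    for prompt, response in pairs:
        input_ids += prompt + response
        # Mask every prompt token; keep response tokens as supervision targets.
        labels += [IGNORE_INDEX] * len(prompt) + response
    # Truncate, then pad up to the fixed block size.
    input_ids, labels = input_ids[:block_size], labels[:block_size]
    pad = block_size - len(input_ids)
    return input_ids + [pad_id] * pad, labels + [IGNORE_INDEX] * pad
```

Downstream, these lists would become tensors and the loss would simply use `ignore_index=-100`, so masked positions contribute nothing to the gradient.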
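For item (b), a minimal sketch of the batch-planning arithmetic, assuming a separate hardware-detection step has already produced a per-GPU `max_micro_batch` limit (the function name and signature here are hypothetical):

```python
def plan_batching(global_batch_size, dp_degree, max_micro_batch):
    """Derive (micro_batch_size, grad_accum_steps) so that
    micro * accum * dp_degree == global_batch_size, choosing the largest micro
    batch the memory budget allows. `max_micro_batch` stands in for a detected
    per-device limit (e.g. profiled separately for H100 vs B200).
    Hypothetical sketch of the proposed BatchSizeConfig logic.
    """
    if global_batch_size % dp_degree:
        raise ValueError("global batch size must divide evenly across DP ranks")
    per_rank = global_batch_size // dp_degree
    # Largest micro batch within the limit that evenly divides the per-rank load.
    micro = max(m for m in range(1, min(max_micro_batch, per_rank) + 1)
                if per_rank % m == 0)
    return micro, per_rank // micro
```

For example, a global batch of 256 on 8 DP ranks with a per-GPU limit of 6 yields micro batch 4 with 8 accumulation steps, hitting the target without exceeding the limit.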
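For item (c), the following hypothetical helper illustrates why packed datasets force the strategy switch: zigzag sharding gives each rank two non-contiguous spans (to balance causal-attention work), so a packed document's boundary can land across ranks, whereas a contiguous ring-style layout keeps each rank's span intact:

```python
def cp_shard_indices(seq_len, cp_degree, rank, zigzag):
    """Token indices owned by `rank` under two context-parallel layouts.

    Zigzag pairs chunk i with chunk (2*cp_degree - 1 - i) so early and late
    tokens (cheap and expensive under causal attention) are mixed per rank,
    but the two spans are non-contiguous. Illustrative helper only; not the
    actual TorchTitan CP implementation.
    """
    if not zigzag:
        # Ring / contiguous layout: one unbroken span per rank.
        chunk = seq_len // cp_degree
        return list(range(rank * chunk, (rank + 1) * chunk))
    # Zigzag layout: first chunk plus its mirrored partner chunk.
    chunk = seq_len // (2 * cp_degree)
    mirror = 2 * cp_degree - 1 - rank
    return (list(range(rank * chunk, (rank + 1) * chunk))
            + list(range(mirror * chunk, (mirror + 1) * chunk)))
```

With `seq_len=8` and `cp_degree=2`, rank 0 owns tokens `[0, 1, 6, 7]` under zigzag: a packed document spanning tokens 4-7 would be split across ranks, which is why the adaptive logic should fall back to the contiguous layout when packing is detected.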
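Item (d) might be sketched as a small decision function; the `step-N` directory layout and the returned dict are assumptions for illustration, not TorchTitan's actual checkpoint format:

```python
import os

def checkpoint_load_plan(checkpoint_dir, step):
    """Decide what to restore before training starts.

    If a step folder from a previous SFT run exists, resume fully (weights plus
    optimizer and dataloader state); otherwise treat the checkpoint as a
    pre-trained base model and load weights only. Hypothetical sketch.
    """
    resume_dir = os.path.join(checkpoint_dir, f"step-{step}")
    if step > 0 and os.path.isdir(resume_dir):
        # Crashed/interrupted SFT run: full resume including optimizer state.
        return {"path": resume_dir, "weights_only": False}
    # Fresh SFT run on a pre-trained base: strictly weights-only.
    return {"path": checkpoint_dir, "weights_only": True}
```

The key design point is that the caller never chooses a mode explicitly: the presence or absence of prior SFT state determines whether optimizer state is restored.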