Add SFT to TorchTitan #167

@abhash-er

Description

By adding SFT support, we can leverage TorchTitan’s 3D parallelism (FSDP2 / TP / CP) for instruction tuning, alignment, and reasoning tasks. Key implementation features on top of a general SFT loop would be:

a) Packed Sequence Loading with Masking: Implement a collator that packs variable-length instruction-response pairs into fixed-length blocks. Crucially, this must support a label_mask so that loss is calculated only on assistant responses, not on user prompts (as mentioned in this Issue).
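A minimal sketch of such a collator, assuming already-tokenized examples (the `pack_examples` name, input schema, and `document_ids` field are illustrative, not existing TorchTitan APIs):

```python
import torch

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def pack_examples(examples, block_size, pad_token_id=0):
    """Pack variable-length instruction-response pairs into fixed blocks.

    `examples` is assumed to be a list of dicts with tokenized
    `prompt_ids` and `response_ids`; prompt positions get IGNORE_INDEX
    labels so loss is computed only on assistant responses.
    """
    input_ids, labels, doc_ids = [], [], []
    for doc_id, ex in enumerate(examples):
        prompt, response = ex["prompt_ids"], ex["response_ids"]
        input_ids += prompt + response
        labels += [IGNORE_INDEX] * len(prompt) + response  # mask the prompt
        doc_ids += [doc_id] * (len(prompt) + len(response))

    blocks = []
    for start in range(0, len(input_ids), block_size):
        ids = input_ids[start:start + block_size]
        lbl = labels[start:start + block_size]
        doc = doc_ids[start:start + block_size]
        pad = block_size - len(ids)
        blocks.append({
            "input_ids": torch.tensor(ids + [pad_token_id] * pad),
            "labels": torch.tensor(lbl + [IGNORE_INDEX] * pad),
            # per-token document ids let attention masking respect
            # sample boundaries inside a packed block
            "document_ids": torch.tensor(doc + [-1] * pad),
        })
    return blocks
```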

b) Smart Batching & Auto-Accumulation: Add a BatchSizeConfig utility that automatically calculates the optimal micro_batch_size and gradient_accumulation_steps based on detected hardware limits (e.g., H100 vs B200) to hit the target global batch size without OOM errors.
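The sizing arithmetic could look like the sketch below; the per-GPU capacity table is a placeholder, since real limits also depend on model size, sequence length, activation checkpointing, and parallelism degree:

```python
from dataclasses import dataclass

# Hypothetical per-GPU micro-batch capacities; would be measured or
# estimated per model in a real implementation.
MAX_MICRO_BATCH = {"H100": 4, "B200": 8}

@dataclass
class BatchSizeConfig:
    global_batch_size: int
    dp_degree: int          # data-parallel world size
    gpu_type: str = "H100"

    def resolve(self):
        """Pick the largest micro batch that fits, then derive accumulation.

        Invariant: global_batch_size ==
                   micro_batch_size * grad_accum_steps * dp_degree
        """
        per_rank = self.global_batch_size // self.dp_degree
        if per_rank * self.dp_degree != self.global_batch_size:
            raise ValueError("global batch size must divide by DP degree")
        micro = min(MAX_MICRO_BATCH.get(self.gpu_type, 1), per_rank)
        # shrink the micro batch until it divides the per-rank batch evenly
        while per_rank % micro:
            micro -= 1
        return micro, per_rank // micro

# Example: global batch 512 across 32 H100 ranks -> micro 4, accum 4
micro_bs, accum_steps = BatchSizeConfig(512, dp_degree=32).resolve()
```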

c) Adaptive Context Parallelism: Add logic to dynamically switch CP strategies (e.g., from ZigZag to Ring/Llama3-style) when packed datasets are detected, ensuring that attention masking correctly respects document boundaries across distributed ranks.
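A hedged sketch of the dispatch; the strategy names follow the issue text, and the helper is hypothetical rather than an existing TorchTitan config knob:

```python
from enum import Enum

class CPStrategy(Enum):
    ZIGZAG = "zigzag"  # load-balanced round-robin sharding of one causal sequence
    RING = "ring"      # contiguous Ring/Llama3-style chunking

def select_cp_strategy(dataset_is_packed: bool) -> CPStrategy:
    """Hypothetical helper: choose the CP sharding per dataset type.

    ZigZag sharding balances causal-attention work across ranks, but
    scatters tokens so a packed batch's document boundaries interleave
    across ranks; contiguous ring-style chunks keep each document's
    tokens adjacent, so per-document attention masks stay local.
    """
    return CPStrategy.RING if dataset_is_packed else CPStrategy.ZIGZAG
```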

d) Hybrid Checkpoint Recovery: A robust loading mechanism that automatically detects the context: it should strictly load weights only when initializing from a pre-trained base model, but fully load optimizer states when resuming a crashed SFT run.
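A sketch of the detection logic, assuming a single-file checkpoint layout and key names for brevity; TorchTitan's real checkpointing goes through torch.distributed.checkpoint (DCP), so the actual implementation would hook into that instead:

```python
import os
import torch

def load_for_sft(model, optimizer, base_ckpt: str, sft_ckpt_dir: str) -> int:
    """Hypothetical recovery logic: resume a crashed SFT run if a
    checkpoint exists, otherwise initialize weights-only from the base
    model. Paths and checkpoint keys ("model", "optimizer", "step")
    are assumptions for this sketch.
    """
    latest = os.path.join(sft_ckpt_dir, "latest.pt")
    if os.path.exists(latest):
        # Resuming a crashed run: restore model AND optimizer state.
        state = torch.load(latest, map_location="cpu")
        model.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
        return state["step"]
    # Fresh SFT: strictly weights-only from the pre-trained base model;
    # reusing pre-training optimizer state here would be incorrect.
    state = torch.load(base_ckpt, map_location="cpu", weights_only=True)
    model.load_state_dict(state["model"])
    return 0
```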
