Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/maxtext/trainers/post_train/sft/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,19 @@ def on_train_step_end(
if self.metadata["first_train_step"] == train_step - 1:
max_utils.print_mem_stats("After params initialized")

# Try to get high-fidelity metrics from Tunix PerfTracer
# Fallback to a small epsilon to avoid ZeroDivisionError in MetricLogger
# Note: the `step_time` argument is currently hardcoded to 0.0 in Tunix library.
tracer_metrics = train_ctx._perf_tracer.export() # pylint: disable=protected-access
actual_step_time, _ = tracer_metrics.get("perf/step_time_seconds", (1e-6, None))

metrics = {
"scalar": {
"learning/loss": train_loss,
"learning/total_weights": self.train_metadata[train_step - 1]["total_weights"],
}
}
self.metric_logger.record_train_metrics(metrics, train_step, step_time)
self.metric_logger.record_train_metrics(metrics, train_step, actual_step_time)
self.metric_logger.write_metrics(metrics, train_step)
del self.train_metadata[train_step - 1]

Expand Down
24 changes: 23 additions & 1 deletion src/maxtext/trainers/post_train/sft/train_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from orbax import checkpoint as ocp

from tunix.sft import metrics_logger, peft_trainer, profiler
from tunix.perf import trace, span

from maxtext.configs import pyconfig
from maxtext.trainers.pre_train.train import loss_fn
Expand Down Expand Up @@ -141,6 +142,23 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ
return trainer


def sft_perf_export_fn(query):
    """Custom exporter that pulls the SFT step time out of Tunix spans.

    PeftTrainer records a plain Span named "peft_train_step" (not a
    SpanGroup), so we scan the main-thread timeline's root items for it
    directly instead of relying on group-based export.
    """
    main_timeline = query._timelines.get(  # pylint: disable=protected-access
        query.get_main_thread_id()
    )
    if not main_timeline:
        return {}

    # Walk the root items back-to-front so the most recent step's span wins.
    candidates = (
        item
        for item in reversed(main_timeline.root.inner)
        if isinstance(item, span.Span) and item.name == "peft_train_step"
    )
    latest = next(candidates, None)
    if latest is None:
        return {}
    return {"perf/step_time_seconds": (latest.duration, None)}


def setup_trainer_state(mt_config, goodput_recorder=None):
"""Set up prerequisites for training loop."""
tunix_config = get_tunix_config(mt_config)
Expand All @@ -161,7 +179,11 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)
data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder)

trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config)
# Initialize the official Tunix performance tracer
# Convert mesh devices to list to avoid ambiguous truth value error in Tunix tracer
perf_tracer = trace.PerfTracer(mesh.devices.flatten().tolist(), sft_perf_export_fn)

trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config, perf_tracer=perf_tracer)
trainer.with_training_hooks(training_hooks)
trainer.with_data_hooks(data_hooks)
trainer = use_maxtext_loss_function(trainer, mt_config)
Expand Down
Loading