diff --git a/src/maxtext/trainers/post_train/sft/hooks.py b/src/maxtext/trainers/post_train/sft/hooks.py
index 03fa610837..dd66b5796f 100644
--- a/src/maxtext/trainers/post_train/sft/hooks.py
+++ b/src/maxtext/trainers/post_train/sft/hooks.py
@@ -125,13 +125,19 @@ def on_train_step_end(
     if self.metadata["first_train_step"] == train_step - 1:
       max_utils.print_mem_stats("After params initialized")
 
+    # Try to get high-fidelity metrics from the Tunix PerfTracer.
+    # Fall back to a small epsilon to avoid ZeroDivisionError in MetricLogger.
+    # Note: the `step_time` argument is currently hardcoded to 0.0 in the Tunix library.
+    tracer_metrics = train_ctx._perf_tracer.export()  # pylint: disable=protected-access
+    actual_step_time, _ = tracer_metrics.get("perf/step_time_seconds", (1e-6, None))
+
     metrics = {
         "scalar": {
             "learning/loss": train_loss,
             "learning/total_weights": self.train_metadata[train_step - 1]["total_weights"],
         }
     }
-    self.metric_logger.record_train_metrics(metrics, train_step, step_time)
+    self.metric_logger.record_train_metrics(metrics, train_step, actual_step_time)
     self.metric_logger.write_metrics(metrics, train_step)
 
     del self.train_metadata[train_step - 1]
diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py
index 90595a05fd..f991eb4b3c 100644
--- a/src/maxtext/trainers/post_train/sft/train_sft.py
+++ b/src/maxtext/trainers/post_train/sft/train_sft.py
@@ -48,6 +48,7 @@
 
 from orbax import checkpoint as ocp
 from tunix.sft import metrics_logger, peft_trainer, profiler
+from tunix.perf import trace, span
 
 from maxtext.configs import pyconfig
 from maxtext.trainers.pre_train.train import loss_fn
@@ -141,6 +142,23 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ
   return trainer
 
 
+def sft_perf_export_fn(query):
+  """Custom exporter to extract SFT step time from Tunix spans."""
+  # PeftTrainer uses a Span named "peft_train_step" (not a SpanGroup).
+  # We have to find it manually in the root inner items.
+  timelines = query._timelines  # pylint: disable=protected-access
+  main_tl = timelines.get(query.get_main_thread_id())
+  if not main_tl:
+    return {}
+
+  # Search for the last peft_train_step span.
+  for item in reversed(main_tl.root.inner):
+    if isinstance(item, span.Span) and item.name == "peft_train_step":
+      return {"perf/step_time_seconds": (item.duration, None)}
+
+  return {}
+
+
 def setup_trainer_state(mt_config, goodput_recorder=None):
   """Set up prerequisites for training loop."""
   tunix_config = get_tunix_config(mt_config)
@@ -161,7 +179,11 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
   training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)
   data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder)
 
-  trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config)
+  # Initialize the official Tunix performance tracer.
+  # Convert mesh devices to a list to avoid an ambiguous truth value error in the Tunix tracer.
+  perf_tracer = trace.PerfTracer(mesh.devices.flatten().tolist(), sft_perf_export_fn)
+
+  trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config, perf_tracer=perf_tracer)
   trainer.with_training_hooks(training_hooks)
   trainer.with_data_hooks(data_hooks)
   trainer = use_maxtext_loss_function(trainer, mt_config)