Skip to content

Commit abcb6e2

Browse files
committed
Fix Metric Reporting Crash (ZeroDivisionError) using Tunix PerfTracer
1 parent cb2ef64 commit abcb6e2

2 files changed

Lines changed: 30 additions & 2 deletions

File tree

src/maxtext/trainers/post_train/sft/hooks.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,19 @@ def on_train_step_end(
125125
if self.metadata["first_train_step"] == train_step - 1:
126126
max_utils.print_mem_stats("After params initialized")
127127

128+
# Try to get high-fidelity metrics from Tunix PerfTracer
129+
# Fallback to a small epsilon to avoid ZeroDivisionError in MetricLogger
130+
# Note: the `step_time` argument is currently hardcoded to 0.0 in Tunix library.
131+
tracer_metrics = train_ctx._perf_tracer.export() # pylint: disable=protected-access
132+
actual_step_time, _ = tracer_metrics.get("perf/step_time_seconds", (1e-6, None))
133+
128134
metrics = {
129135
"scalar": {
130136
"learning/loss": train_loss,
131137
"learning/total_weights": self.train_metadata[train_step - 1]["total_weights"],
132138
}
133139
}
134-
self.metric_logger.record_train_metrics(metrics, train_step, step_time)
140+
self.metric_logger.record_train_metrics(metrics, train_step, actual_step_time)
135141
self.metric_logger.write_metrics(metrics, train_step)
136142
del self.train_metadata[train_step - 1]
137143

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from orbax import checkpoint as ocp
4949

5050
from tunix.sft import metrics_logger, peft_trainer, profiler
51+
from tunix.perf import trace, span
5152

5253
from maxtext.configs import pyconfig
5354
from maxtext.trainers.pre_train.train import loss_fn
@@ -141,6 +142,23 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ
141142
return trainer
142143

143144

145+
def sft_perf_export_fn(query):
146+
"""Custom exporter to extract SFT step time from Tunix spans."""
147+
# PeftTrainer uses a Span named "peft_train_step" (not a SpanGroup).
148+
# We have to find it manually in the root inner items.
149+
timelines = query._timelines # pylint: disable=protected-access
150+
main_tl = timelines.get(query.get_main_thread_id())
151+
if not main_tl:
152+
return {}
153+
154+
# Search for the last peft_train_step span
155+
for item in reversed(main_tl.root.inner):
156+
if isinstance(item, span.Span) and item.name == "peft_train_step":
157+
return {"perf/step_time_seconds": (item.duration, None)}
158+
159+
return {}
160+
161+
144162
def setup_trainer_state(mt_config, goodput_recorder=None):
145163
"""Set up prerequisites for training loop."""
146164
tunix_config = get_tunix_config(mt_config)
@@ -161,7 +179,11 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
161179
training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)
162180
data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder)
163181

164-
trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config)
182+
# Initialize the official Tunix performance tracer
183+
# Convert mesh devices to list to avoid ambiguous truth value error in Tunix tracer
184+
perf_tracer = trace.PerfTracer(mesh.devices.flatten().tolist(), sft_perf_export_fn)
185+
186+
trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config, perf_tracer=perf_tracer)
165187
trainer.with_training_hooks(training_hooks)
166188
trainer.with_data_hooks(data_hooks)
167189
trainer = use_maxtext_loss_function(trainer, mt_config)

0 commit comments

Comments
 (0)