4848from orbax import checkpoint as ocp
4949
5050from tunix .sft import metrics_logger , peft_trainer , profiler
51+ from tunix .perf import trace , span
5152
5253from maxtext .configs import pyconfig
5354from maxtext .trainers .pre_train .train import loss_fn
@@ -141,6 +142,23 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ
141142 return trainer
142143
143144
def sft_perf_export_fn(query):
    """Export the duration of the last SFT train step found in Tunix spans.

    PeftTrainer records its step as a plain ``Span`` named
    "peft_train_step" (not a ``SpanGroup``), so the span is located by hand
    among the inner items of the main thread timeline's root.

    Args:
      query: Perf query object exposing ``_timelines`` (a mapping keyed by
        thread id) and ``get_main_thread_id()``.

    Returns:
      ``{"perf/step_time_seconds": (duration, None)}`` for the last matching
      span, or an empty dict when the main thread has no (truthy) timeline
      or no "peft_train_step" span is present.
    """
    all_timelines = query._timelines  # pylint: disable=protected-access
    main_timeline = all_timelines.get(query.get_main_thread_id())
    if not main_timeline:
        return {}

    # Scan in reverse so the first hit is the last such span in the list.
    step_spans = (
        child
        for child in reversed(main_timeline.root.inner)
        if isinstance(child, span.Span) and child.name == "peft_train_step"
    )
    latest_step = next(step_spans, None)
    if latest_step is None:
        return {}
    return {"perf/step_time_seconds": (latest_step.duration, None)}
160+
161+
144162def setup_trainer_state (mt_config , goodput_recorder = None ):
145163 """Set up prerequisites for training loop."""
146164 tunix_config = get_tunix_config (mt_config )
@@ -161,7 +179,11 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
161179 training_hooks = hooks .SFTTrainingHooks (mt_config , mesh , learning_rate_schedule , goodput_recorder )
162180 data_hooks = hooks .SFTDataHooks (mt_config , mesh , goodput_recorder )
163181
164- trainer = peft_trainer .PeftTrainer (model , optimizer , tunix_config )
182+ # Initialize the official Tunix performance tracer
183+ # Convert mesh devices to list to avoid ambiguous truth value error in Tunix tracer
184+ perf_tracer = trace .PerfTracer (mesh .devices .flatten ().tolist (), sft_perf_export_fn )
185+
186+ trainer = peft_trainer .PeftTrainer (model , optimizer , tunix_config , perf_tracer = perf_tracer )
165187 trainer .with_training_hooks (training_hooks )
166188 trainer .with_data_hooks (data_hooks )
167189 trainer = use_maxtext_loss_function (trainer , mt_config )
0 commit comments