From 3c3bffff392d83825fe9358f1f044a0990497aa1 Mon Sep 17 00:00:00 2001
From: Igor Tsvetkov
Date: Mon, 6 Apr 2026 12:15:31 -0700
Subject: [PATCH] Fix ZeroDivisionError in MetricLogger with manual step timing

---
 src/maxtext/trainers/post_train/sft/hooks.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/maxtext/trainers/post_train/sft/hooks.py b/src/maxtext/trainers/post_train/sft/hooks.py
index 03fa610837..1d0cda71f7 100644
--- a/src/maxtext/trainers/post_train/sft/hooks.py
+++ b/src/maxtext/trainers/post_train/sft/hooks.py
@@ -25,6 +25,7 @@
 from typing_extensions import override
 
+import time
 import jax
 import jax.numpy as jnp
 from flax import nnx
@@ -54,6 +55,7 @@ def __init__(self, config, mesh, learning_rate_schedule, goodput_recorder):
     self.metadata = {}
     self.train_metadata = defaultdict(float)
     self.eval_metadata = defaultdict(float)
+    self.step_start_time = 0.0
 
   @override
   def on_train_start(self, train_ctx: peft_trainer.PeftTrainer):
@@ -93,6 +95,7 @@ def on_train_end(self, train_ctx: peft_trainer.PeftTrainer):  # pylint: disable=
   @override
   def on_train_step_start(self, train_ctx: peft_trainer.PeftTrainer):
     """Called at the beginning of a training step."""
+    self.step_start_time = time.time()
     if self.config.enable_goodput_recording:
       record_goodput(self.goodput_recorder, f"record_{GoodputEvent.STEP.value}_start_time", train_ctx.train_steps)
 
@@ -109,7 +112,7 @@ def on_train_step_end(
       train_ctx: peft_trainer.PeftTrainer,
       train_step: int,
       train_loss: float,
-      step_time: float,
+      deprecated_step_time: float = 0.0,  # No longer provided. See https://github.com/google/tunix/pull/1289.
   ):
     """Called at the end of a training step. This hook is called by Tunix
     after the step counter has been incremented for logging purposes.
@@ -125,13 +128,16 @@
     if self.metadata["first_train_step"] == train_step - 1:
       max_utils.print_mem_stats("After params initialized")
 
+    # Use our own timing; Tunix no longer provides step_time (it passes 0.0).
+    actual_step_time = time.time() - self.step_start_time
+
     metrics = {
         "scalar": {
             "learning/loss": train_loss,
             "learning/total_weights": self.train_metadata[train_step - 1]["total_weights"],
         }
     }
-    self.metric_logger.record_train_metrics(metrics, train_step, step_time)
+    self.metric_logger.record_train_metrics(metrics, train_step, actual_step_time)
     self.metric_logger.write_metrics(metrics, train_step)
 
     del self.train_metadata[train_step - 1]
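
For reference, the pattern the patch adopts is plain wall-clock timing around each step. Below is a minimal, self-contained sketch of the same idea. It is illustrative only: the StepTimingHook class, its print-based logging, and the None-sentinel guard are assumptions made for the sketch, not the Tunix hook API or MaxText's MetricLogger. It also uses time.perf_counter(), a monotonic clock better suited to interval measurement than the time.time() used in the patch.

import time


class StepTimingHook:
  """Minimal sketch of manual step timing (hypothetical, not the Tunix API)."""

  def __init__(self):
    # None is a sentinel for "no step in flight": an end-of-step call with no
    # matching start is skipped rather than timed against a bogus origin.
    self.step_start_time = None

  def on_train_step_start(self):
    # perf_counter() is monotonic, so the interval cannot go negative if the
    # system clock is adjusted mid-step (time.time() makes no such guarantee).
    self.step_start_time = time.perf_counter()

  def on_train_step_end(self, train_step, train_loss):
    if self.step_start_time is None:
      return
    step_time = time.perf_counter() - self.step_start_time
    self.step_start_time = None
    # Guard the division that raised ZeroDivisionError downstream when the
    # framework handed the metric logger step_time=0.0.
    if step_time > 0:
      print(f"step {train_step}: loss={train_loss:.4f}, {1.0 / step_time:.2f} steps/s")


if __name__ == "__main__":
  hook = StepTimingHook()
  hook.on_train_step_start()
  time.sleep(0.01)  # stand-in for the real training step
  hook.on_train_step_end(train_step=1, train_loss=0.5)

The patch itself initializes step_start_time to 0.0 rather than None; that is safe there because on_train_step_start is expected to fire before every on_train_step_end, whereas the sentinel makes the standalone sketch robust to an end call with no matching start.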