Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/maxtext/trainers/post_train/sft/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from typing_extensions import override

import jax
import time
import jax.numpy as jnp

from flax import nnx
Expand Down Expand Up @@ -54,6 +55,7 @@ def __init__(self, config, mesh, learning_rate_schedule, goodput_recorder):
self.metadata = {}
self.train_metadata = defaultdict(float)
self.eval_metadata = defaultdict(float)
self.step_start_time = 0.0

@override
def on_train_start(self, train_ctx: peft_trainer.PeftTrainer):
Expand Down Expand Up @@ -93,6 +95,7 @@ def on_train_end(self, train_ctx: peft_trainer.PeftTrainer): # pylint: disable=
@override
def on_train_step_start(self, train_ctx: peft_trainer.PeftTrainer):
"""Called at the beginning of a training step."""
self.step_start_time = time.time()
if self.config.enable_goodput_recording:
record_goodput(self.goodput_recorder, f"record_{GoodputEvent.STEP.value}_start_time", train_ctx.train_steps)

Expand All @@ -109,7 +112,7 @@ def on_train_step_end(
train_ctx: peft_trainer.PeftTrainer,
train_step: int,
train_loss: float,
step_time: float,
deprecated_step_time: float = 0.0, # No longer provided. See https://github.com/google/tunix/pull/1289.
):
"""Called at the end of training step.
This hook is called by Tunix after the step counter has been incremented for logging purposes.
Expand All @@ -125,13 +128,16 @@ def on_train_step_end(
if self.metadata["first_train_step"] == train_step - 1:
max_utils.print_mem_stats("After params initialized")

# Use our own timing since Tunix might pass 0.0
actual_step_time = time.time() - self.step_start_time

metrics = {
"scalar": {
"learning/loss": train_loss,
"learning/total_weights": self.train_metadata[train_step - 1]["total_weights"],
}
}
self.metric_logger.record_train_metrics(metrics, train_step, step_time)
self.metric_logger.record_train_metrics(metrics, train_step, actual_step_time)
self.metric_logger.write_metrics(metrics, train_step)
del self.train_metadata[train_step - 1]

Expand Down
Loading