8 changes: 7 additions & 1 deletion src/maxtext/trainers/post_train/sft/hooks.py
@@ -25,6 +25,7 @@
from typing_extensions import override

import jax
import time
import jax.numpy as jnp

from flax import nnx
@@ -54,6 +55,7 @@ def __init__(self, config, mesh, learning_rate_schedule, goodput_recorder):
self.metadata = {}
self.train_metadata = defaultdict(float)
self.eval_metadata = defaultdict(float)
self.step_start_time = 0.0

@override
def on_train_start(self, train_ctx: peft_trainer.PeftTrainer):
@@ -93,6 +95,7 @@ def on_train_end(self, train_ctx: peft_trainer.PeftTrainer): # pylint: disable=
@override
def on_train_step_start(self, train_ctx: peft_trainer.PeftTrainer):
"""Called at the beginning of a training step."""
self.step_start_time = time.time()
if self.config.enable_goodput_recording:
record_goodput(self.goodput_recorder, f"record_{GoodputEvent.STEP.value}_start_time", train_ctx.train_steps)

@@ -125,13 +128,16 @@ def on_train_step_end(
if self.metadata["first_train_step"] == train_step - 1:
max_utils.print_mem_stats("After params initialized")

# Use our own timing, since the step_time Tunix passes in may be 0.0.
actual_step_time = time.time() - self.step_start_time

metrics = {
"scalar": {
"learning/loss": train_loss,
"learning/total_weights": self.train_metadata[train_step - 1]["total_weights"],
}
}
self.metric_logger.record_train_metrics(metrics, train_step, step_time)
self.metric_logger.record_train_metrics(metrics, train_step, actual_step_time)
self.metric_logger.write_metrics(metrics, train_step)
del self.train_metadata[train_step - 1]

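For reference, a minimal standalone sketch of the timing pattern the hook changes introduce: stamp the wall clock when a step starts and compute the elapsed time when it ends, rather than trusting the step time the framework passes in. The class and names below are illustrative stand-ins, not the MaxText/Tunix hook classes.

```python
import time


class StepTimingHooks:
  """Minimal illustration: measure step duration with our own wall clock."""

  def __init__(self):
    self.step_start_time = 0.0

  def on_train_step_start(self, step):
    # Stamp the start of the step ourselves.
    self.step_start_time = time.time()

  def on_train_step_end(self, step, step_time):
    # Prefer our own measurement over the value handed to the hook, which may be 0.0.
    actual_step_time = time.time() - self.step_start_time
    print(f"step {step}: reported={step_time:.3f}s, measured={actual_step_time:.3f}s")


hooks = StepTimingHooks()
hooks.on_train_step_start(step=0)
time.sleep(0.01)  # stand-in for the real training step
hooks.on_train_step_end(step=0, step_time=0.0)
```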
30 changes: 27 additions & 3 deletions src/maxtext/trainers/post_train/sft/train_sft.py
@@ -43,12 +43,14 @@
import optax
import pathwaysutils

from flax.linen import partitioning as nn_partitioning
import flax.linen as nn
from flax import nnx

from orbax import checkpoint as ocp

from tunix.sft import metrics_logger, peft_trainer, profiler

from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter
from maxtext.configs import pyconfig
from maxtext.trainers.pre_train.train import loss_fn
from maxtext.common.goodput import (
@@ -109,6 +111,7 @@ def get_tunix_config(mt_config):
checkpointing_options=checkpointing_options,
metrics_logging_options=metrics_logging_options,
profiler_options=profiler_options,
skip_sharding_optimizer=True,
)


@@ -135,6 +138,10 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ
"targets_position": targets_position,
"targets_segmentation": targets_segmentation,
}
# If the model is wrapped in TunixMaxTextAdapter, we should unwrap it
# because MaxText's loss_fn expects the raw Transformer interface.
if isinstance(model, TunixMaxTextAdapter):
model = model.base
return loss_fn(model, mt_config, data, dropout_rng=None, params=None, is_train=True)

trainer = trainer.with_loss_fn(loss_func, has_aux=True)
@@ -147,6 +154,10 @@ def setup_trainer_state(mt_config, goodput_recorder=None):

with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT):
model, mesh = model_creation_utils.create_nnx_model(mt_config)

# Wrap model with Tunix adapter for consistent interface
model = TunixMaxTextAdapter(model)

learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config)
# pass in model for muon
optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model)
@@ -157,11 +168,24 @@
optimizer,
)

# Pre-shard the optimizer state to avoid a TypeError in Tunix's _shard_optimizer.
# Tunix then detects the state is already sharded and skips its internal sharding logic.
with mesh, nn.logical_axis_rules(mt_config.logical_axis_rules):
nnx_optimizer = nnx.Optimizer(model, optimizer, wrt=nnx.Param)
opt_state = nnx.state(nnx_optimizer, nnx.optimizer.OptState)
opt_pspecs = nnx.get_partition_spec(opt_state)
opt_sharded_state = jax.lax.with_sharding_constraint(opt_state, opt_pspecs)
nnx.update(nnx_optimizer, opt_sharded_state)

with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION):
training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)
data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder)

trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config)
# Create the trainer within the correct JAX context so that checkpoint restoration
# and internal sharding checks have access to the hardware mesh and axis rules.
with mesh, nn.logical_axis_rules(mt_config.logical_axis_rules):
trainer = peft_trainer.PeftTrainer(model, nnx_optimizer, tunix_config)

trainer.with_training_hooks(training_hooks)
trainer.with_data_hooks(data_hooks)
trainer = use_maxtext_loss_function(trainer, mt_config)
@@ -171,7 +195,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):

def train_model(mt_config, trainer, mesh):
"""Runs the SFT training loop in Tunix."""
with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
with jax.set_mesh(mesh), mesh, nn.logical_axis_rules(mt_config.logical_axis_rules):
trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
return trainer

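The loss-function change above relies on the adapter exposing the wrapped MaxText model as `.base`. A toy sketch of that wrap/unwrap pattern (the classes below are illustrative stand-ins, not `TunixMaxTextAdapter` or the real `loss_fn`):

```python
class RawModel:
  """Stands in for the underlying MaxText Transformer."""

  def loss(self, batch):
    return sum(batch)  # placeholder for the real loss computation


class Adapter:
  """Stands in for TunixMaxTextAdapter: wraps a model and exposes it as `.base`."""

  def __init__(self, base):
    self.base = base


def loss_func(model, batch):
  # The trainer hands us the adapter; unwrap it so the loss sees the raw model interface.
  if isinstance(model, Adapter):
    model = model.base
  return model.loss(batch)


print(loss_func(Adapter(RawModel()), [1, 2, 3]))  # -> 6
```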
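The optimizer pre-sharding block follows the usual Flax NNX recipe of deriving partition specs from variable metadata and constraining the state to them; this PR applies the same idea to the optimizer state so that Tunix's `_shard_optimizer` sees already-sharded state. Below is a minimal sketch of that recipe on a toy model's parameters, assuming a single-axis mesh; `TinyModel` and the axis name are illustrative only.

```python
import jax
import numpy as np
from flax import nnx
from jax.sharding import Mesh


class TinyModel(nnx.Module):
  """Toy model with sharding metadata attached to its only parameter."""

  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(
        4, 4,
        use_bias=False,
        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "data")),
        rngs=rngs,
    )

  def __call__(self, x):
    return self.linear(x)


mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
model = TinyModel(nnx.Rngs(0))

with mesh:
  # Pull out the variable state, derive PartitionSpecs from the metadata,
  # constrain the arrays to those specs, and write the sharded state back.
  state = nnx.state(model)
  pspecs = nnx.get_partition_spec(state)
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)
```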