From 37c3296e33a9667b270b9a3407a81f4e079ecab2 Mon Sep 17 00:00:00 2001
From: Igor Tsvetkov
Date: Sun, 5 Apr 2026 17:55:39 -0700
Subject: [PATCH] Standardize SFT-Tunix integration and add manual step timing

---
 src/maxtext/trainers/post_train/sft/hooks.py  |  8 ++++-
 .../trainers/post_train/sft/train_sft.py      | 30 +++++++++++++++++--
 2 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/src/maxtext/trainers/post_train/sft/hooks.py b/src/maxtext/trainers/post_train/sft/hooks.py
index 03fa610837..86c0fa3bec 100644
--- a/src/maxtext/trainers/post_train/sft/hooks.py
+++ b/src/maxtext/trainers/post_train/sft/hooks.py
@@ -25,6 +25,7 @@
 from typing_extensions import override
 
+import time
 import jax
 import jax.numpy as jnp
 
 from flax import nnx
@@ -54,6 +55,7 @@ def __init__(self, config, mesh, learning_rate_schedule, goodput_recorder):
     self.metadata = {}
     self.train_metadata = defaultdict(float)
     self.eval_metadata = defaultdict(float)
+    self.step_start_time = 0.0
 
   @override
   def on_train_start(self, train_ctx: peft_trainer.PeftTrainer):
@@ -93,6 +95,7 @@ def on_train_end(self, train_ctx: peft_trainer.PeftTrainer):  # pylint: disable=
 
   @override
   def on_train_step_start(self, train_ctx: peft_trainer.PeftTrainer):
     """Called at the beginning of a training step."""
+    self.step_start_time = time.time()
     if self.config.enable_goodput_recording:
       record_goodput(self.goodput_recorder, f"record_{GoodputEvent.STEP.value}_start_time", train_ctx.train_steps)
@@ -125,13 +128,16 @@ def on_train_step_end(
     if self.metadata["first_train_step"] == train_step - 1:
       max_utils.print_mem_stats("After params initialized")
 
+    # Measure step time ourselves, since the step_time Tunix passes in may be 0.0.
+    actual_step_time = time.time() - self.step_start_time
+
     metrics = {
         "scalar": {
             "learning/loss": train_loss,
            "learning/total_weights": self.train_metadata[train_step - 1]["total_weights"],
         }
     }
-    self.metric_logger.record_train_metrics(metrics, train_step, step_time)
+    self.metric_logger.record_train_metrics(metrics, train_step, actual_step_time)
     self.metric_logger.write_metrics(metrics, train_step)
 
     del self.train_metadata[train_step - 1]
diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py
index 90595a05fd..87b8c5e51f 100644
--- a/src/maxtext/trainers/post_train/sft/train_sft.py
+++ b/src/maxtext/trainers/post_train/sft/train_sft.py
@@ -43,12 +43,14 @@
 
 import optax
 import pathwaysutils
 
-from flax.linen import partitioning as nn_partitioning
+import flax.linen as nn
+from flax import nnx
 from orbax import checkpoint as ocp
 
 from tunix.sft import metrics_logger, peft_trainer, profiler
 
+from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter
 from maxtext.configs import pyconfig
 from maxtext.trainers.pre_train.train import loss_fn
 from maxtext.common.goodput import (
@@ -109,6 +111,7 @@
       checkpointing_options=checkpointing_options,
       metrics_logging_options=metrics_logging_options,
       profiler_options=profiler_options,
+      skip_sharding_optimizer=True,
   )
 
 
@@ -135,6 +138,10 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ
         "targets_position": targets_position,
         "targets_segmentation": targets_segmentation,
     }
+    # If the model is wrapped in TunixMaxTextAdapter, unwrap it because
+    # MaxText's loss_fn expects the raw Transformer interface.
+    if isinstance(model, TunixMaxTextAdapter):
+      model = model.base
     return loss_fn(model, mt_config, data, dropout_rng=None, params=None, is_train=True)
 
   trainer = trainer.with_loss_fn(loss_func, has_aux=True)
@@ -147,6 +154,10 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
 
   with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT):
     model, mesh = model_creation_utils.create_nnx_model(mt_config)
+
+    # Wrap the model in the Tunix adapter for a consistent interface.
+    model = TunixMaxTextAdapter(model)
+
   learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config)
   # pass in model for muon
   optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model)
@@ -157,11 +168,24 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
       optimizer,
   )
 
+  # Pre-shard the optimizer state to avoid a TypeError in Tunix's _shard_optimizer;
+  # Tunix then detects that the state is already sharded and skips its internal sharding logic.
+  with mesh, nn.logical_axis_rules(mt_config.logical_axis_rules):
+    nnx_optimizer = nnx.Optimizer(model, optimizer, wrt=nnx.Param)
+    opt_state = nnx.state(nnx_optimizer, nnx.optimizer.OptState)
+    opt_pspecs = nnx.get_partition_spec(opt_state)
+    opt_sharded_state = jax.lax.with_sharding_constraint(opt_state, opt_pspecs)
+    nnx.update(nnx_optimizer, opt_sharded_state)
+
   with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION):
     training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)
     data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder)
 
-    trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config)
+    # Create the trainer within the correct JAX context so that checkpoint restoration
+    # and internal sharding checks have access to the hardware mesh and axis rules.
+    with mesh, nn.logical_axis_rules(mt_config.logical_axis_rules):
+      trainer = peft_trainer.PeftTrainer(model, nnx_optimizer, tunix_config)
+
     trainer.with_training_hooks(training_hooks)
     trainer.with_data_hooks(data_hooks)
     trainer = use_maxtext_loss_function(trainer, mt_config)
@@ -171,7 +195,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
 
 
 def train_model(mt_config, trainer, mesh):
   """Runs the SFT training loop in Tunix."""
-  with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
+  with jax.set_mesh(mesh), mesh, nn.logical_axis_rules(mt_config.logical_axis_rules):
     trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
   return trainer
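
-- 
Illustration, not part of the patch: the hooks.py change above keeps its own
wall-clock timestamp because the step_time Tunix passes to on_train_step_end
may be 0.0. Below is a minimal, self-contained sketch of that pattern with
hypothetical names (StepTimer, on_step_start, on_step_end); it assumes only
the Python stdlib, not the Tunix hook interface.

import time


class StepTimer:
  """Wall-clock step timing that ignores an unreliable framework value."""

  def __init__(self):
    self.step_start_time = 0.0

  def on_step_start(self):
    # Mirrors on_train_step_start: remember when the step began.
    self.step_start_time = time.time()

  def on_step_end(self, framework_step_time: float) -> float:
    # Mirrors on_train_step_end: the framework-reported value (like Tunix's
    # step_time) may be 0.0, so it is ignored in favor of our own measurement.
    del framework_step_time  # unused; kept only to mirror the hook signature
    return time.time() - self.step_start_time


timer = StepTimer()
timer.on_step_start()
time.sleep(0.01)  # stand-in for one training step
print(f"step took {timer.on_step_end(0.0):.3f}s")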