From 0bf4907a362cdf289e96098dd7c2e11cacb0e264 Mon Sep 17 00:00:00 2001 From: Igor Tsvetkov Date: Tue, 7 Apr 2026 15:45:52 -0700 Subject: [PATCH] Standardize SFT Model Interface with TunixMaxTextAdapter --- src/maxtext/trainers/post_train/sft/train_sft.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index 90595a05fd..057bbdb446 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -49,6 +49,7 @@ from tunix.sft import metrics_logger, peft_trainer, profiler +from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter from maxtext.configs import pyconfig from maxtext.trainers.pre_train.train import loss_fn from maxtext.common.goodput import ( @@ -135,6 +136,10 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ "targets_position": targets_position, "targets_segmentation": targets_segmentation, } + # If the model is wrapped in TunixMaxTextAdapter, we should unwrap it + # because MaxText's loss_fn expects the raw Transformer interface. + if isinstance(model, TunixMaxTextAdapter): + model = model.base return loss_fn(model, mt_config, data, dropout_rng=None, params=None, is_train=True) trainer = trainer.with_loss_fn(loss_func, has_aux=True) @@ -147,6 +152,10 @@ def setup_trainer_state(mt_config, goodput_recorder=None): with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT): model, mesh = model_creation_utils.create_nnx_model(mt_config) + + # Wrap model with Tunix adapter for consistent interface + model = TunixMaxTextAdapter(model) + learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config) # pass in model for muon optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model)