Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/maxtext/trainers/post_train/sft/train_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

from tunix.sft import metrics_logger, peft_trainer, profiler

from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter
from maxtext.configs import pyconfig
from maxtext.trainers.pre_train.train import loss_fn
from maxtext.common.goodput import (
Expand Down Expand Up @@ -135,6 +136,10 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ
"targets_position": targets_position,
"targets_segmentation": targets_segmentation,
}
# If the model is wrapped in TunixMaxTextAdapter, we should unwrap it
# because MaxText's loss_fn expects the raw Transformer interface.
if isinstance(model, TunixMaxTextAdapter):
model = model.base
return loss_fn(model, mt_config, data, dropout_rng=None, params=None, is_train=True)

trainer = trainer.with_loss_fn(loss_func, has_aux=True)
Expand All @@ -147,6 +152,10 @@ def setup_trainer_state(mt_config, goodput_recorder=None):

with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT):
model, mesh = model_creation_utils.create_nnx_model(mt_config)

# Wrap model with Tunix adapter for consistent interface
model = TunixMaxTextAdapter(model)

learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config)
# pass in model for muon
optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model)
Expand Down
Loading