From 68b4665d0e73182edeb3c168c1d879ba794dbc03 Mon Sep 17 00:00:00 2001 From: Igor Tsvetkov Date: Wed, 25 Mar 2026 16:48:22 -0700 Subject: [PATCH 1/6] WIP: Adding Tunix DPO support --- src/maxtext/configs/post_train/dpo.yml | 15 +- src/maxtext/configs/pyconfig.py | 5 +- .../input_pipeline/hf_data_processing.py | 3 + .../trainers/post_train/dpo/train_dpo.py | 179 ++++++++++++++++++ 4 files changed, 192 insertions(+), 10 deletions(-) create mode 100644 src/maxtext/trainers/post_train/dpo/train_dpo.py diff --git a/src/maxtext/configs/post_train/dpo.yml b/src/maxtext/configs/post_train/dpo.yml index dbcdadb1ba..4b1dde5a80 100644 --- a/src/maxtext/configs/post_train/dpo.yml +++ b/src/maxtext/configs/post_train/dpo.yml @@ -1,8 +1,8 @@ base_config: "base.yml" use_dpo: true -train_data_columns: ['chosen', 'rejected'] -eval_data_columns: ['chosen', 'rejected'] +train_data_columns: ['input', 'chosen', 'rejected'] +eval_data_columns: ['input', 'chosen', 'rejected'] base_output_directory: 'gs://maxtext-external/logs' per_device_batch_size: 2.0 @@ -12,11 +12,12 @@ eval_interval: 5 # test eval once, in the middle of 10 training steps eval_steps: 2 # TFDS Pipeline ---------------------- -dataset_type: tfds -dataset_path: 'gs://maxtext-dataset/dpo/anthropic_rlhf' -dataset_name: 'tfds:1.0.0' -eval_dataset_name: 'tfds:1.0.0' -eval_split: 'test' +#dataset_type: tfds +#dataset_path: 'gs://maxtext-dataset/dpo/anthropic_rlhf' +#dataset_name: 'tfds:1.0.0' +#eval_dataset_name: 'tfds:1.0.0' +#eval_split: 'test' +packing: False # DEBUG: DO NOT MERGE. # HF Pipeline ------------------------- hf_eval_split: 'test' diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py index 78d783270f..f44e8569e3 100644 --- a/src/maxtext/configs/pyconfig.py +++ b/src/maxtext/configs/pyconfig.py @@ -53,6 +53,7 @@ "maxtext.trainers.pre_train.train_compile": "base.yml", "maxtext.trainers.post_train.distillation.train_distill": "post_train/distillation.yml", "maxtext.trainers.post_train.rl.train_rl": "post_train/rl.yml", + "maxtext.trainers.post_train.dpo.train_dpo": "post_train/dpo.yml", "maxtext.trainers.post_train.sft.train_sft": "post_train/sft.yml", "maxtext.trainers.post_train.sft.train_sft_deprecated": "post_train/sft.yml", "maxtext.inference.decode": "base.yml", @@ -83,9 +84,7 @@ def _resolve_or_infer_config(argv: list[str]) -> tuple[str, list[str]]: return resolve_config_path(argv[1]), argv[2:] module = _module_from_path(argv[0]) if module not in _CONFIG_FILE_MAPPING: - raise ValueError( - f"No config file provided and no default config found for module '{module}'" - ) + raise ValueError(f"No config file provided and no default config found for module '{module}'") config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module]) logger.warning("No config file provided, using default config mapping: %s", config_path) return config_path, argv[1:] diff --git a/src/maxtext/input_pipeline/hf_data_processing.py b/src/maxtext/input_pipeline/hf_data_processing.py index d1be2c4890..df2681781e 100644 --- a/src/maxtext/input_pipeline/hf_data_processing.py +++ b/src/maxtext/input_pipeline/hf_data_processing.py @@ -318,6 +318,9 @@ def lists2array(x): return jax.tree.map(np.asarray, x, is_leaf=lambda y: isinstance(y, (list, tuple))) operations.append(grain.MapOperation(lists2array)) + + rekey_dict = {"prompts": "input", "chosen_responses": "chosen", "rejected_responses": "rejected"} + operations.append(input_pipeline_utils.Rekey(rekey_dict)) else: assert len(data_column_names) == 1 operations.append(input_pipeline_utils.HFNormalizeFeatures(data_column_names[0])) diff --git a/src/maxtext/trainers/post_train/dpo/train_dpo.py b/src/maxtext/trainers/post_train/dpo/train_dpo.py new file mode 100644 index 0000000000..8bf162556a --- /dev/null +++ b/src/maxtext/trainers/post_train/dpo/train_dpo.py @@ -0,0 +1,179 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DPO Training script that uses Tunix DPOTrainer on a MaxText model. + +Example command: +Training & Evaluation: + python3 -m maxtext.trainers.post_train.dpo.train_dpo \ + run_name=${WORKLOAD?} base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ + tokenizer_path="google/gemma-2-2b-it" tokenizer_type=huggingface \ + dataset_type="hf" hf_path="Anthropic/hh-rlhf" \ + model_name=${MODEL?} load_parameters_path=${MAXTEXT_CONVERTED_CHECKPOINT?}/0/items \ + hf_access_token=${HF_TOKEN?} per_device_batch_size=1 max_target_length=1024 \ + eval_interval=2 eval_steps=2 steps=10 profiler=xplane weight_dtype=bfloat16 +""" + +from absl import app +import jax +import optax +from orbax import checkpoint as ocp +import pathwaysutils + +from flax.linen import partitioning as nn_partitioning + +from tunix.sft import metrics_logger, profiler +from tunix.sft.dpo.dpo_trainer import DPOTrainer, DPOTrainingConfig + +from maxtext.configs import pyconfig +from maxtext.utils import max_utils +from maxtext.common.goodput import ( + GoodputEvent, + RECORD_JOB_END_TIME, + RECORD_JOB_START_TIME, + create_goodput_recorder, + maybe_monitor_goodput, + maybe_record_goodput, + record_goodput, +) +from maxtext.optimizers import optimizers +from maxtext.trainers.post_train.sft import hooks +from maxtext.utils import max_logging +from maxtext.utils import maxtext_utils +from maxtext.utils import model_creation_utils + + +def get_tunix_config(mt_config: pyconfig.HyperParameters) -> DPOTrainingConfig: + """Gets the Tunix training configurations from the MaxText config. + + Args: + mt_config: MaxText config. + + Returns: + A Tunix `DPOTrainingConfig` object. + """ + # Checkpointing configurations + checkpointing_options = ocp.CheckpointManagerOptions( + save_interval_steps=mt_config.checkpoint_period, + enable_async_checkpointing=mt_config.async_checkpointing, + ) + + # Metrics configurations + metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=mt_config.tensorboard_dir) + + # Profiler configurations + profiler_options = None + if mt_config.profiler: + set_profile_options = True + platform_version = jax.extend.backend.get_backend().platform_version.strip() + if platform_version.startswith("Pathways"): + max_logging.log("Pathways backend detected. Disabling setting profile options.") + set_profile_options = False + profiler_options = profiler.ProfilerOptions( + log_dir=mt_config.tensorboard_dir, + skip_first_n_steps=mt_config.skip_first_n_steps_for_profiler, + profiler_steps=mt_config.profiler_steps, + set_profile_options=set_profile_options, + ) + + return DPOTrainingConfig( + eval_every_n_steps=mt_config.eval_interval, + max_steps=mt_config.steps, + gradient_accumulation_steps=mt_config.gradient_accumulation_steps, + checkpoint_root_directory=mt_config.checkpoint_dir, + checkpointing_options=checkpointing_options, + metrics_logging_options=metrics_logging_options, + profiler_options=profiler_options, + algorithm="dpo", # TODO: add support of "orpo" + beta=mt_config.dpo_beta, + label_smoothing=mt_config.dpo_label_smoothing, + ) + + +def setup_trainer_state(mt_config, goodput_recorder=None): + """Set up prerequisites for training loop.""" + tunix_config = get_tunix_config(mt_config) + + with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT): + model, mesh = model_creation_utils.create_nnx_model(mt_config) + learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config) + # pass in model for muon + optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model) + + if mt_config.gradient_clipping_threshold > 0: + optimizer = optax.chain( + optax.clip_by_global_norm(max_norm=mt_config.gradient_clipping_threshold), + optimizer, + ) + + with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION): + training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder) + data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder) + + trainer = DPOTrainer(model=model, ref_model=None, optimizer=optimizer, training_config=tunix_config, tokenizer=None) + trainer.with_training_hooks(training_hooks) + trainer.with_data_hooks(data_hooks) + + # TODO(igorts-git): do we need this? It exists in SFT. + # trainer = use_maxtext_loss_function(trainer, mt_config) + + return trainer, mesh + + +def train_model(mt_config: pyconfig.HyperParameters, trainer, mesh): + """Runs the DPO training loop in Tunix.""" + with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): + trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator) + return trainer + + +def train(mt_config, goodput_recorder=None): + """Main method for DPO training. + + Args: + mt_config: MaxText config. + goodput_recorder: An optional GoodputRecorder to record performance metrics. + """ + trainer, mesh = setup_trainer_state(mt_config, goodput_recorder) + _job_completed_gracefully = False + try: + trainer = train_model(mt_config, trainer, mesh) + _job_completed_gracefully = True + finally: + if _job_completed_gracefully: + record_goodput(goodput_recorder, RECORD_JOB_END_TIME) + return trainer, mesh + + +def main(argv: list[str]) -> None: + """Main function to run DPO training. + + Args: + argv: Command-line arguments. + """ + # import debugpy; debugpy.listen(("localhost", 5678)); print("Attach VS Code now"); debugpy.wait_for_client() + + pathwaysutils.initialize() + + mt_config = pyconfig.initialize(argv) + max_utils.print_system_information() + + goodput_recorder = create_goodput_recorder(mt_config) + record_goodput(goodput_recorder, RECORD_JOB_START_TIME) + with maybe_monitor_goodput(mt_config): + train(mt_config, goodput_recorder) + + +if __name__ == "__main__": + app.run(main) From 2327be48383ef3e19084ee87d18e8bf32fde99e2 Mon Sep 17 00:00:00 2001 From: Igor Tsvetkov Date: Thu, 2 Apr 2026 14:54:48 -0700 Subject: [PATCH 2/6] Switch to using tokenizer inside Tunix --- src/maxtext/input_pipeline/hf_data_processing.py | 3 ++- src/maxtext/trainers/post_train/dpo/train_dpo.py | 15 ++++++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/maxtext/input_pipeline/hf_data_processing.py b/src/maxtext/input_pipeline/hf_data_processing.py index df2681781e..b4271edb10 100644 --- a/src/maxtext/input_pipeline/hf_data_processing.py +++ b/src/maxtext/input_pipeline/hf_data_processing.py @@ -280,7 +280,8 @@ def preprocessing_pipeline( pad_id = _get_pad_id(tokenizer) - if tokenize: + # Tunix-DPO handles tokenization internally to ensure proper padding/masking. + if tokenize and not use_dpo: dataset = dataset.map( input_pipeline_utils.tokenization, batched=True, diff --git a/src/maxtext/trainers/post_train/dpo/train_dpo.py b/src/maxtext/trainers/post_train/dpo/train_dpo.py index 8bf162556a..3eaa5cb75e 100644 --- a/src/maxtext/trainers/post_train/dpo/train_dpo.py +++ b/src/maxtext/trainers/post_train/dpo/train_dpo.py @@ -36,6 +36,7 @@ from tunix.sft import metrics_logger, profiler from tunix.sft.dpo.dpo_trainer import DPOTrainer, DPOTrainingConfig +import tunix from maxtext.configs import pyconfig from maxtext.utils import max_utils from maxtext.common.goodput import ( @@ -98,6 +99,8 @@ def get_tunix_config(mt_config: pyconfig.HyperParameters) -> DPOTrainingConfig: algorithm="dpo", # TODO: add support of "orpo" beta=mt_config.dpo_beta, label_smoothing=mt_config.dpo_label_smoothing, + max_prompt_length=mt_config.max_target_length // 2, + max_response_length=mt_config.max_target_length // 2, ) @@ -121,7 +124,17 @@ def setup_trainer_state(mt_config, goodput_recorder=None): training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder) data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder) - trainer = DPOTrainer(model=model, ref_model=None, optimizer=optimizer, training_config=tunix_config, tokenizer=None) + tokenizer = tunix.Tokenizer( + tokenizer_type=mt_config.tokenizer_type, + tokenizer_path=mt_config.tokenizer_path, + add_bos=mt_config.add_bos, + add_eos=mt_config.add_eos, + hf_access_token=mt_config.hf_access_token, + ) + + trainer = DPOTrainer( + model=model, ref_model=None, optimizer=optimizer, training_config=tunix_config, tokenizer=tokenizer + ) trainer.with_training_hooks(training_hooks) trainer.with_data_hooks(data_hooks) From 63f9dcd198e38c981abe2d16506e7a838677ee34 Mon Sep 17 00:00:00 2001 From: Igor Tsvetkov Date: Fri, 3 Apr 2026 11:04:28 -0700 Subject: [PATCH 3/6] fix mesh rules. Disable PadOrTrimToMaxLength --- .../input_pipeline/hf_data_processing.py | 3 ++- .../trainers/post_train/dpo/train_dpo.py | 24 +++++++++++++++---- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/maxtext/input_pipeline/hf_data_processing.py b/src/maxtext/input_pipeline/hf_data_processing.py index b4271edb10..ff9857701e 100644 --- a/src/maxtext/input_pipeline/hf_data_processing.py +++ b/src/maxtext/input_pipeline/hf_data_processing.py @@ -341,7 +341,8 @@ def lists2array(x): ) operations.append(input_pipeline_utils.ReformatPacking(data_column_names)) else: - operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id)) + if not use_dpo: + operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id)) operations.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder)) if shift and not use_dpo: diff --git a/src/maxtext/trainers/post_train/dpo/train_dpo.py b/src/maxtext/trainers/post_train/dpo/train_dpo.py index 3eaa5cb75e..43eda57914 100644 --- a/src/maxtext/trainers/post_train/dpo/train_dpo.py +++ b/src/maxtext/trainers/post_train/dpo/train_dpo.py @@ -31,6 +31,8 @@ from orbax import checkpoint as ocp import pathwaysutils +import flax.linen as nn +from flax import nnx from flax.linen import partitioning as nn_partitioning from tunix.sft import metrics_logger, profiler @@ -120,6 +122,14 @@ def setup_trainer_state(mt_config, goodput_recorder=None): optimizer, ) + # Pre-shard the optimizer to avoid TypeError in Tunix _shard_optimizer + with mesh, nn.logical_axis_rules(mt_config.logical_axis_rules): + nnx_optimizer = nnx.Optimizer(model, optimizer, wrt=nnx.Param) + opt_state = nnx.state(nnx_optimizer, nnx.optimizer.OptState) + opt_pspecs = nnx.get_partition_spec(opt_state) + opt_sharded_state = jax.lax.with_sharding_constraint(opt_state, opt_pspecs) + nnx.update(nnx_optimizer, opt_sharded_state) + with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION): training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder) data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder) @@ -132,9 +142,15 @@ def setup_trainer_state(mt_config, goodput_recorder=None): hf_access_token=mt_config.hf_access_token, ) - trainer = DPOTrainer( - model=model, ref_model=None, optimizer=optimizer, training_config=tunix_config, tokenizer=tokenizer - ) + # Pass raw optax optimizer to DPOTrainer, then inject pre-sharded nnx.Optimizer + with mesh, nn.logical_axis_rules(mt_config.logical_axis_rules): + trainer = DPOTrainer( + model=model, ref_model=None, optimizer=optimizer, training_config=tunix_config, tokenizer=tokenizer + ) + # Ensure the trainer uses our pre-sharded optimizer instance + trainer.optimizer = nnx_optimizer + # Override _shard_optimizer to avoid TypeError on scalar states + trainer._shard_optimizer = lambda mesh: None trainer.with_training_hooks(training_hooks) trainer.with_data_hooks(data_hooks) @@ -146,7 +162,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None): def train_model(mt_config: pyconfig.HyperParameters, trainer, mesh): """Runs the DPO training loop in Tunix.""" - with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): + with jax.set_mesh(mesh), mesh, nn.logical_axis_rules(mt_config.logical_axis_rules): trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator) return trainer From 00b8cd998f0da9677e7b6264e07873beed828e6a Mon Sep 17 00:00:00 2001 From: Igor Tsvetkov Date: Fri, 3 Apr 2026 11:30:31 -0700 Subject: [PATCH 4/6] Post training plan v1 --- post_training_plan.md | 52 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 post_training_plan.md diff --git a/post_training_plan.md b/post_training_plan.md new file mode 100644 index 0000000000..8ab9a4f866 --- /dev/null +++ b/post_training_plan.md @@ -0,0 +1,52 @@ +# Brainstorming: MaxText & Tunix Post-Training Integration Plan + +## 1. Executive Summary +The goal is to provide a best-in-class post-training suite (SFT, DPO, RLHF, GRPO) that scales to the largest models and TPU slices. + +Instead of maintaining duplicate implementations of complex alignment algorithms, we will establish a **"Hybrid Core"** architecture: +* **MaxText** acts as the **Performance Engine**: Providing optimized model implementations (NNX), robust sharding/SPMD rules, and high-throughput data loading. +* **Tunix** acts as the **Algorithmic Orchestrator**: Providing the training loops, specialized loss functions (DPO, PPO), and alignment-specific metrics. + +## 2. Shared Responsibilities & Strengths + +| Feature | MaxText Strength | Tunix Strength | Recommended Primary | +| :--- | :--- | :--- | :--- | +| **Model Arch** | highly optimized, NNX-based, TPU-aware | research-flexible | **MaxText** | +| **Sharding** | Robust logical-to-physical SPMD rules | Basic/Standard sharding | **MaxText** | +| **Dataloading** | Multi-host Grain integration | HF Datasets convenience | **Collaborative** (MaxText Grain + Tunix Prep) | +| **Loss Functions**| Standard Cross-Entropy | DPO, ORPO, PPO, GRPO | **Tunix** | +| **Metrics** | Goodput, Hardware utilization | KL-Divergence, Rewards, Accuracy | **Tunix** (Loop) + **MaxText** (System) | + +## 3. The "Bridge" Architecture (Implementation Strategy) + +To make these two projects work together without "technical friction," we should standardize the following interfaces: + +### A. The Model Wrapper (The "Naming Bridge") +Tunix trainers expect a generic model interface. We should formalize the `ModelWrapper` we started building. +* **Action:** Create a standard `MaxTextTunixWrapper` in `src/maxtext/trainers/post_train/utils.py` that handles mapping generic names (like `positions`, `attention_mask`) to model-specific names (like `decoder_input_tokens_positions`). + +### B. Sharding-Aware Initialization +Tunix's `PeftTrainer` currently makes assumptions about sharding that clash with MaxText's more advanced SPMD rules (e.g., the `norm` axis issue and scalar optimizer states). +* **Action:** Contribute to Tunix to make its internal sharding logic optional or configurable. +* **Action:** Ensure MaxText always provides a "Pre-Sharded" state to Tunix, and Tunix should respect existing sharding rather than attempting to re-apply it. + +### C. Standardized Data Schema +We need a unified "Post-Training Data Schema" that both projects understand. +* **DPO Schema:** `prompt_ids`, `chosen_ids`, `rejected_ids` + corresponding masks. +* **RL Schema:** `queries`, `responses`, `rewards`. +* **Action:** Implement `MaxTextDPOTprep` and `MaxTextRLPrep` transforms in the MaxText pipeline that output exactly what Tunix trainers expect. + +## 4. Specific Collaboration Opportunities + +### Contribution to Tunix +1. **Refactor `_shard_optimizer`:** Modify Tunix to check if an optimizer is already sharded before attempting to apply `with_sharding_constraint`. +2. **Generalized Keyword Arguments:** Update Tunix's `get_per_token_logps` to accept a mapping for keyword arguments, avoiding the need for a manual `ModelWrapper` for every new model. + +### Enhancement in MaxText +1. **Alignment-Aware Hooks:** Formalize the `SFTTrainingHooks` into a more general `PostTrainingHooks` system that detects the algorithm (SFT vs DPO vs RL) and adjusts metric calculation accordingly. +2. **Parameter-Efficient Fine-Tuning (PEFT):** Leverage Tunix's LoRA/PEFT logic while applying MaxText's optimized kernel implementations. + +## 5. Roadmap for DPO Integration (Immediate Next Steps) +1. **Finalize the `ModelWrapper`:** Fix the "too many values to unpack" error by ensuring the wrapper returns only what Tunix needs (logits). +2. **Formalize the "No-Op" Sharding Override:** Instead of a lambda hack, create a proper `MaxTextDPOTrainer` subclass that overrides `_shard_optimizer` cleanly. +3. **Unified Config:** Allow users to specify `post_training_flavor: tunix_dpo` in `dpo.yml` to automatically trigger these bridge behaviors. From 35133861656b6bf437520d44647c52698135e3f9 Mon Sep 17 00:00:00 2001 From: Igor Tsvetkov Date: Fri, 3 Apr 2026 13:34:10 -0700 Subject: [PATCH 5/6] Post training plan v2 --- post_training_plan.md | 42 +++++++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/post_training_plan.md b/post_training_plan.md index 8ab9a4f866..5b24c0ae7d 100644 --- a/post_training_plan.md +++ b/post_training_plan.md @@ -21,20 +21,27 @@ Instead of maintaining duplicate implementations of complex alignment algorithms To make these two projects work together without "technical friction," we should standardize the following interfaces: -### A. The Model Wrapper (The "Naming Bridge") -Tunix trainers expect a generic model interface. We should formalize the `ModelWrapper` we started building. -* **Action:** Create a standard `MaxTextTunixWrapper` in `src/maxtext/trainers/post_train/utils.py` that handles mapping generic names (like `positions`, `attention_mask`) to model-specific names (like `decoder_input_tokens_positions`). +### A. The Model Adapter (Unified Naming Bridge) +We discovered that `src/maxtext/integration/tunix/tunix_adapter.py` already contains a robust `TunixMaxTextAdapter`. +* **Current State:** Used effectively in RL (`train_rl.py`). +* **Action:** Refactor SFT (`train_sft.py`) and DPO (`train_dpo.py`) to use this same adapter instead of ad-hoc wrappers. This ensures that any model supported by MaxText is immediately compatible with all Tunix trainers. ### B. Sharding-Aware Initialization Tunix's `PeftTrainer` currently makes assumptions about sharding that clash with MaxText's more advanced SPMD rules (e.g., the `norm` axis issue and scalar optimizer states). -* **Action:** Contribute to Tunix to make its internal sharding logic optional or configurable. -* **Action:** Ensure MaxText always provides a "Pre-Sharded" state to Tunix, and Tunix should respect existing sharding rather than attempting to re-apply it. +* **Current State:** Handled via manual "pre-sharding" and no-op overrides in DPO. +* **Action:** Move this logic into a base `MaxTextTunixTrainer` class or a utility function used by all post-training scripts. +* **Action:** Contribute to Tunix to make its internal `_shard_optimizer` check for existing sharding before applying constraints. -### C. Standardized Data Schema -We need a unified "Post-Training Data Schema" that both projects understand. -* **DPO Schema:** `prompt_ids`, `chosen_ids`, `rejected_ids` + corresponding masks. -* **RL Schema:** `queries`, `responses`, `rewards`. -* **Action:** Implement `MaxTextDPOTprep` and `MaxTextRLPrep` transforms in the MaxText pipeline that output exactly what Tunix trainers expect. +### C. Standardized Data Schema (The "Input Bridge") +MaxText's multi-host Grain loader requires numeric arrays, while Tunix often expects strings. +* **Current State:** SFT/DPO/RL each handle this differently. +* **Action:** Standardize on a "Pre-tokenized numeric schema" where MaxText performs tokenization and padding (using DPO-aware left-padding when needed) and provides the `_ids` and `_mask` columns Tunix expects for pre-tokenized input. + +## 4. Documentation Strategy + +Existing documentation is fragmented (`docs/tutorials/posttraining/sft.md`, `rl.md`, etc.). +* **Action:** Create a unified `post_training_overview.md` that explains the MaxText-Tunix relationship (MaxText=Engine, Tunix=Brain). +* **Action:** Ensure all tutorials consistently mention the `maxtext[tpu-post-train]` installation requirement. ## 4. Specific Collaboration Opportunities @@ -46,7 +53,20 @@ We need a unified "Post-Training Data Schema" that both projects understand. 1. **Alignment-Aware Hooks:** Formalize the `SFTTrainingHooks` into a more general `PostTrainingHooks` system that detects the algorithm (SFT vs DPO vs RL) and adjusts metric calculation accordingly. 2. **Parameter-Efficient Fine-Tuning (PEFT):** Leverage Tunix's LoRA/PEFT logic while applying MaxText's optimized kernel implementations. -## 5. Roadmap for DPO Integration (Immediate Next Steps) +## 5. Cleanup: Deleting Legacy Post-Training Support + +As we transition to the Tunix-based "Hybrid Core" architecture, we should remove the legacy, non-Tunix implementations from MaxText to reduce maintenance burden. + +### A. Remove Legacy DPO +The existing internal DPO implementation is fragmented and harder to maintain than the Tunix version. +* **Action:** Delete `src/maxtext/trainers/post_train/dpo/dpo_utils.py`. +* **Action:** Remove DPO-specific branches and imports in: + * `src/maxtext/trainers/pre_train/train.py` + * `src/maxtext/utils/train_utils.py` + * `src/MaxText/__init__.py` +* **Action:** Deprecate legacy DPO-specific configuration parameters in `src/maxtext/configs/base.yml` once the Tunix bridge is stable. + +## 6. Roadmap for DPO Integration (Immediate Next Steps) 1. **Finalize the `ModelWrapper`:** Fix the "too many values to unpack" error by ensuring the wrapper returns only what Tunix needs (logits). 2. **Formalize the "No-Op" Sharding Override:** Instead of a lambda hack, create a proper `MaxTextDPOTrainer` subclass that overrides `_shard_optimizer` cleanly. 3. **Unified Config:** Allow users to specify `post_training_flavor: tunix_dpo` in `dpo.yml` to automatically trigger these bridge behaviors. From d2158bea67ee32b79f45d63134d4610866cc8f68 Mon Sep 17 00:00:00 2001 From: Igor Tsvetkov Date: Fri, 3 Apr 2026 14:50:10 -0700 Subject: [PATCH 6/6] First version that does not crash --- post_training_plan.md | 16 +++---- .../input_pipeline/hf_data_processing.py | 42 +++++++++++++++++-- .../trainers/post_train/dpo/train_dpo.py | 17 ++++---- src/maxtext/trainers/post_train/sft/hooks.py | 20 +++++++-- 4 files changed, 72 insertions(+), 23 deletions(-) diff --git a/post_training_plan.md b/post_training_plan.md index 5b24c0ae7d..ac6a0d6520 100644 --- a/post_training_plan.md +++ b/post_training_plan.md @@ -43,15 +43,17 @@ Existing documentation is fragmented (`docs/tutorials/posttraining/sft.md`, `rl. * **Action:** Create a unified `post_training_overview.md` that explains the MaxText-Tunix relationship (MaxText=Engine, Tunix=Brain). * **Action:** Ensure all tutorials consistently mention the `maxtext[tpu-post-train]` installation requirement. -## 4. Specific Collaboration Opportunities +## 4. Collaborative Enhancements (Modifications to Tunix) -### Contribution to Tunix -1. **Refactor `_shard_optimizer`:** Modify Tunix to check if an optimizer is already sharded before attempting to apply `with_sharding_constraint`. -2. **Generalized Keyword Arguments:** Update Tunix's `get_per_token_logps` to accept a mapping for keyword arguments, avoiding the need for a manual `ModelWrapper` for every new model. +To further reduce the "glue code" in MaxText, we should upstream the following improvements to the Tunix library: -### Enhancement in MaxText -1. **Alignment-Aware Hooks:** Formalize the `SFTTrainingHooks` into a more general `PostTrainingHooks` system that detects the algorithm (SFT vs DPO vs RL) and adjusts metric calculation accordingly. -2. **Parameter-Efficient Fine-Tuning (PEFT):** Leverage Tunix's LoRA/PEFT logic while applying MaxText's optimized kernel implementations. +### A. Flexible Sharding in `PeftTrainer` +Tunix's `_shard_optimizer` currently forces sharding constraints that can crash on pre-sharded MaxText states (especially with scalar values). +* **Action:** Modify `tunix/sft/peft_trainer.py` to only apply `with_sharding_constraint` if the optimizer is not already sharded or if a specific flag is set. + +### B. Generalized Model Call Interface +Tunix's `get_per_token_logps` hardcodes argument names like `positions` and `attention_mask`. +* **Action:** Update `tunix/rl/common.py` to allow passing a `name_mapping` dictionary. This would allow MaxText to tell Tunix: "Use `decoder_positions` instead of `positions`." ## 5. Cleanup: Deleting Legacy Post-Training Support diff --git a/src/maxtext/input_pipeline/hf_data_processing.py b/src/maxtext/input_pipeline/hf_data_processing.py index ff9857701e..425d670dfa 100644 --- a/src/maxtext/input_pipeline/hf_data_processing.py +++ b/src/maxtext/input_pipeline/hf_data_processing.py @@ -280,8 +280,10 @@ def preprocessing_pipeline( pad_id = _get_pad_id(tokenizer) - # Tunix-DPO handles tokenization internally to ensure proper padding/masking. - if tokenize and not use_dpo: + # Tunix-DPO handles tokenization internally if strings are passed. + # However, MaxText's multihost loader requires numeric JAX arrays. + # We tokenize here and rename columns to match Tunix's TrainingInput requirements. + if tokenize: dataset = dataset.map( input_pipeline_utils.tokenization, batched=True, @@ -320,8 +322,40 @@ def lists2array(x): operations.append(grain.MapOperation(lists2array)) - rekey_dict = {"prompts": "input", "chosen_responses": "chosen", "rejected_responses": "rejected"} - operations.append(input_pipeline_utils.Rekey(rekey_dict)) + # Generate masks and rename keys to match tunix.sft.dpo.dpo_trainer.TrainingInput + class DPOTunixPrep(grain.MapTransform): + + def __init__(self, pad_id, max_prompt_length, max_response_length): + self.pad_id = pad_id + self.max_prompt_length = max_prompt_length + self.max_response_length = max_response_length + + def _pad(self, x, length, left=False): + x = np.asarray(x) + pad_amount = max(length - x.shape[0], 0) + if left: + pad_width = ((pad_amount, 0),) + else: + pad_width = ((0, pad_amount),) + return np.pad(x[:length], pad_width, constant_values=self.pad_id) + + def map(self, x): + prompt_ids = self._pad(x.pop("input"), self.max_prompt_length, left=True) + chosen_ids = self._pad(x.pop("chosen"), self.max_response_length, left=False) + rejected_ids = self._pad(x.pop("rejected"), self.max_response_length, left=False) + + x["prompt_ids"] = prompt_ids + x["chosen_ids"] = chosen_ids + x["rejected_ids"] = rejected_ids + x["prompt_mask"] = (prompt_ids != self.pad_id).astype(np.int32) + x["chosen_mask"] = (chosen_ids != self.pad_id).astype(np.int32) + x["rejected_mask"] = (rejected_ids != self.pad_id).astype(np.int32) + return x + + # Tunix DPO expects prompt and response to share the total budget. + dpo_max_prompt_len = max_target_length // 2 + dpo_max_response_len = max_target_length // 2 + operations.append(DPOTunixPrep(pad_id, dpo_max_prompt_len, dpo_max_response_len)) else: assert len(data_column_names) == 1 operations.append(input_pipeline_utils.HFNormalizeFeatures(data_column_names[0])) diff --git a/src/maxtext/trainers/post_train/dpo/train_dpo.py b/src/maxtext/trainers/post_train/dpo/train_dpo.py index 43eda57914..0b6b843a6c 100644 --- a/src/maxtext/trainers/post_train/dpo/train_dpo.py +++ b/src/maxtext/trainers/post_train/dpo/train_dpo.py @@ -39,6 +39,7 @@ from tunix.sft.dpo.dpo_trainer import DPOTrainer, DPOTrainingConfig import tunix +from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter from maxtext.configs import pyconfig from maxtext.utils import max_utils from maxtext.common.goodput import ( @@ -112,6 +113,10 @@ def setup_trainer_state(mt_config, goodput_recorder=None): with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT): model, mesh = model_creation_utils.create_nnx_model(mt_config) + + # Wrap model with Tunix adapter for consistent interface + model = TunixMaxTextAdapter(model) + learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config) # pass in model for muon optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model) @@ -123,6 +128,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None): ) # Pre-shard the optimizer to avoid TypeError in Tunix _shard_optimizer + # Tunix will now detect it's already sharded and skip its internal sharding logic. with mesh, nn.logical_axis_rules(mt_config.logical_axis_rules): nnx_optimizer = nnx.Optimizer(model, optimizer, wrt=nnx.Param) opt_state = nnx.state(nnx_optimizer, nnx.optimizer.OptState) @@ -142,21 +148,14 @@ def setup_trainer_state(mt_config, goodput_recorder=None): hf_access_token=mt_config.hf_access_token, ) - # Pass raw optax optimizer to DPOTrainer, then inject pre-sharded nnx.Optimizer + # Pass the pre-sharded nnx.Optimizer directly to DPOTrainer. with mesh, nn.logical_axis_rules(mt_config.logical_axis_rules): trainer = DPOTrainer( - model=model, ref_model=None, optimizer=optimizer, training_config=tunix_config, tokenizer=tokenizer + model=model, ref_model=None, optimizer=nnx_optimizer, training_config=tunix_config, tokenizer=None ) - # Ensure the trainer uses our pre-sharded optimizer instance - trainer.optimizer = nnx_optimizer - # Override _shard_optimizer to avoid TypeError on scalar states - trainer._shard_optimizer = lambda mesh: None trainer.with_training_hooks(training_hooks) trainer.with_data_hooks(data_hooks) - # TODO(igorts-git): do we need this? It exists in SFT. - # trainer = use_maxtext_loss_function(trainer, mt_config) - return trainer, mesh diff --git a/src/maxtext/trainers/post_train/sft/hooks.py b/src/maxtext/trainers/post_train/sft/hooks.py index 03fa610837..791ac010aa 100644 --- a/src/maxtext/trainers/post_train/sft/hooks.py +++ b/src/maxtext/trainers/post_train/sft/hooks.py @@ -15,6 +15,7 @@ """Training and data loading hooks for SFT""" +import time from collections import defaultdict from sys import version_info @@ -93,11 +94,17 @@ def on_train_end(self, train_ctx: peft_trainer.PeftTrainer): # pylint: disable= @override def on_train_step_start(self, train_ctx: peft_trainer.PeftTrainer): """Called at the beginning of a training step.""" + self.step_start_time = time.time() if self.config.enable_goodput_recording: record_goodput(self.goodput_recorder, f"record_{GoodputEvent.STEP.value}_start_time", train_ctx.train_steps) # Calculate the number of non-padded tokens in the batch - total_weights = jnp.sum(train_ctx.data_hooks.train_batch["targets_segmentation"] != 0) + if self.config.use_dpo: + # For DPO, we sum both chosen and rejected masks + total_weights = jnp.sum(train_ctx.data_hooks.train_batch["chosen_mask"] != 0) + total_weights += jnp.sum(train_ctx.data_hooks.train_batch["rejected_mask"] != 0) + else: + total_weights = jnp.sum(train_ctx.data_hooks.train_batch["targets_segmentation"] != 0) self.train_metadata[train_ctx.train_steps] = { "total_weights": total_weights, @@ -117,6 +124,8 @@ def on_train_step_end( However, we will use the current `train_step` value to record metrics in this hook to be consistent with Tunix's metric logging convention. """ + # Use our own timing since Tunix might pass 0.0 + actual_step_time = time.time() - self.step_start_time assert train_step - 1 in self.train_metadata, ( "SFTTrainingHooks.on_train_step_start() must be called before" " SFTTrainingHooks.on_train_step_end()" @@ -131,7 +140,7 @@ def on_train_step_end( "learning/total_weights": self.train_metadata[train_step - 1]["total_weights"], } } - self.metric_logger.record_train_metrics(metrics, train_step, step_time) + self.metric_logger.record_train_metrics(metrics, train_step, actual_step_time) self.metric_logger.write_metrics(metrics, train_step) del self.train_metadata[train_step - 1] @@ -140,7 +149,12 @@ def on_eval_step_start(self, train_ctx: peft_trainer.PeftTrainer): """Called at the beginning of an evaluation step.""" self.eval_metadata["eval_step_count"] += 1.0 # Calculate the number of non-padded tokens in the batch - self.eval_metadata["total_weights"] += jnp.sum(train_ctx.data_hooks.eval_batch["targets_segmentation"] != 0) + if self.config.use_dpo: + total_weights = jnp.sum(train_ctx.data_hooks.eval_batch["chosen_mask"] != 0) + total_weights += jnp.sum(train_ctx.data_hooks.eval_batch["rejected_mask"] != 0) + self.eval_metadata["total_weights"] += total_weights + else: + self.eval_metadata["total_weights"] += jnp.sum(train_ctx.data_hooks.eval_batch["targets_segmentation"] != 0) @override def on_eval_step_end(self, train_ctx: peft_trainer.PeftTrainer, eval_loss: float):