diff --git a/post_training_plan.md b/post_training_plan.md new file mode 100644 index 0000000000..ac6a0d6520 --- /dev/null +++ b/post_training_plan.md @@ -0,0 +1,74 @@ +# Brainstorming: MaxText & Tunix Post-Training Integration Plan + +## 1. Executive Summary +The goal is to provide a best-in-class post-training suite (SFT, DPO, RLHF, GRPO) that scales to the largest models and TPU slices. + +Instead of maintaining duplicate implementations of complex alignment algorithms, we will establish a **"Hybrid Core"** architecture: +* **MaxText** acts as the **Performance Engine**: Providing optimized model implementations (NNX), robust sharding/SPMD rules, and high-throughput data loading. +* **Tunix** acts as the **Algorithmic Orchestrator**: Providing the training loops, specialized loss functions (DPO, PPO), and alignment-specific metrics. + +## 2. Shared Responsibilities & Strengths + +| Feature | MaxText Strength | Tunix Strength | Recommended Primary | +| :--- | :--- | :--- | :--- | +| **Model Arch** | highly optimized, NNX-based, TPU-aware | research-flexible | **MaxText** | +| **Sharding** | Robust logical-to-physical SPMD rules | Basic/Standard sharding | **MaxText** | +| **Dataloading** | Multi-host Grain integration | HF Datasets convenience | **Collaborative** (MaxText Grain + Tunix Prep) | +| **Loss Functions**| Standard Cross-Entropy | DPO, ORPO, PPO, GRPO | **Tunix** | +| **Metrics** | Goodput, Hardware utilization | KL-Divergence, Rewards, Accuracy | **Tunix** (Loop) + **MaxText** (System) | + +## 3. The "Bridge" Architecture (Implementation Strategy) + +To make these two projects work together without "technical friction," we should standardize the following interfaces: + +### A. The Model Adapter (Unified Naming Bridge) +We discovered that `src/maxtext/integration/tunix/tunix_adapter.py` already contains a robust `TunixMaxTextAdapter`. +* **Current State:** Used effectively in RL (`train_rl.py`). 
+
+* **Action:** Refactor SFT (`train_sft.py`) and DPO (`train_dpo.py`) to use this same adapter instead of ad-hoc wrappers. This ensures that any model supported by MaxText is immediately compatible with all Tunix trainers.
+
+### B. Sharding-Aware Initialization
+Tunix's `PeftTrainer` currently makes assumptions about sharding that clash with MaxText's more advanced SPMD rules (e.g., the `norm` axis issue and scalar optimizer states).
+* **Current State:** Handled via manual "pre-sharding" and no-op overrides in DPO.
+* **Action:** Move this logic into a base `MaxTextTunixTrainer` class or a utility function used by all post-training scripts.
+* **Action:** Contribute to Tunix to make its internal `_shard_optimizer` check for existing sharding before applying constraints.
+
+### C. Standardized Data Schema (The "Input Bridge")
+MaxText's multi-host Grain loader requires numeric arrays, while Tunix often expects strings.
+* **Current State:** SFT/DPO/RL each handle this differently.
+* **Action:** Standardize on a "Pre-tokenized numeric schema" where MaxText performs tokenization and padding (using DPO-aware left-padding when needed) and provides the `_ids` and `_mask` columns Tunix expects for pre-tokenized input.
+
+## 4. Documentation Strategy
+
+Existing documentation is fragmented (`docs/tutorials/posttraining/sft.md`, `rl.md`, etc.).
+* **Action:** Create a unified `post_training_overview.md` that explains the MaxText-Tunix relationship (MaxText=Engine, Tunix=Brain).
+* **Action:** Ensure all tutorials consistently mention the `maxtext[tpu-post-train]` installation requirement.
+
+## 5. Collaborative Enhancements (Modifications to Tunix)
+
+To further reduce the "glue code" in MaxText, we should upstream the following improvements to the Tunix library:
+
+### A. Flexible Sharding in `PeftTrainer`
+Tunix's `_shard_optimizer` currently forces sharding constraints that can crash on pre-sharded MaxText states (especially with scalar values).
+* **Action:** Modify `tunix/sft/peft_trainer.py` to only apply `with_sharding_constraint` if the optimizer is not already sharded or if a specific flag is set.
+
+### B. Generalized Model Call Interface
+Tunix's `get_per_token_logps` hardcodes argument names like `positions` and `attention_mask`.
+* **Action:** Update `tunix/rl/common.py` to allow passing a `name_mapping` dictionary. This would allow MaxText to tell Tunix: "Use `decoder_positions` instead of `positions`."
+
+## 6. Cleanup: Deleting Legacy Post-Training Support
+
+As we transition to the Tunix-based "Hybrid Core" architecture, we should remove the legacy, non-Tunix implementations from MaxText to reduce maintenance burden.
+
+### A. Remove Legacy DPO
+The existing internal DPO implementation is fragmented and harder to maintain than the Tunix version.
+* **Action:** Delete `src/maxtext/trainers/post_train/dpo/dpo_utils.py`.
+* **Action:** Remove DPO-specific branches and imports in:
+  * `src/maxtext/trainers/pre_train/train.py`
+  * `src/maxtext/utils/train_utils.py`
+  * `src/MaxText/__init__.py`
+* **Action:** Deprecate legacy DPO-specific configuration parameters in `src/maxtext/configs/base.yml` once the Tunix bridge is stable.
+
+## 7. Roadmap for DPO Integration (Immediate Next Steps)
+1. **Finalize the `ModelWrapper`:** Fix the "too many values to unpack" error by ensuring the wrapper returns only what Tunix needs (logits).
+2. **Formalize the "No-Op" Sharding Override:** Instead of a lambda hack, create a proper `MaxTextDPOTrainer` subclass that overrides `_shard_optimizer` cleanly.
+3. **Unified Config:** Allow users to specify `post_training_flavor: tunix_dpo` in `dpo.yml` to automatically trigger these bridge behaviors.
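
To make the "Pre-tokenized numeric schema" from the Standardized Data Schema bridge (section 3.C) concrete, here is a minimal, self-contained sketch of the padding and masking convention. It mirrors the `DPOTunixPrep` transform in the pipeline diff below, but the helper names and the `pad_id=0` choice are illustrative, not MaxText or Tunix APIs:

```python
import numpy as np

PAD_ID = 0  # illustrative; the real pipeline derives the pad id from the tokenizer


def pad_to(ids, length, left=False, pad_id=PAD_ID):
  """Trim a 1-D token list to `length`, then pad with `pad_id` (on the left for prompts)."""
  ids = np.asarray(ids)[:length]
  pad_amount = length - ids.shape[0]
  pad_width = (pad_amount, 0) if left else (0, pad_amount)
  return np.pad(ids, pad_width, constant_values=pad_id)


def to_dpo_schema(example, max_prompt_len, max_response_len):
  """Map {'input', 'chosen', 'rejected'} token lists to the *_ids / *_mask columns."""
  prompt = pad_to(example["input"], max_prompt_len, left=True)
  chosen = pad_to(example["chosen"], max_response_len)
  rejected = pad_to(example["rejected"], max_response_len)
  return {
      "prompt_ids": prompt,
      "chosen_ids": chosen,
      "rejected_ids": rejected,
      # Masks mark non-pad positions so loss and token accounting skip padding.
      "prompt_mask": (prompt != PAD_ID).astype(np.int32),
      "chosen_mask": (chosen != PAD_ID).astype(np.int32),
      "rejected_mask": (rejected != PAD_ID).astype(np.int32),
  }
```

Note that with `pad_id=0` a genuine token id 0 would be masked out; the production pipeline should use the tokenizer's dedicated pad id, as the Grain transform does.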
diff --git a/src/maxtext/configs/post_train/dpo.yml b/src/maxtext/configs/post_train/dpo.yml index dbcdadb1ba..4b1dde5a80 100644 --- a/src/maxtext/configs/post_train/dpo.yml +++ b/src/maxtext/configs/post_train/dpo.yml @@ -1,8 +1,8 @@ base_config: "base.yml" use_dpo: true -train_data_columns: ['chosen', 'rejected'] -eval_data_columns: ['chosen', 'rejected'] +train_data_columns: ['input', 'chosen', 'rejected'] +eval_data_columns: ['input', 'chosen', 'rejected'] base_output_directory: 'gs://maxtext-external/logs' per_device_batch_size: 2.0 @@ -12,11 +12,12 @@ eval_interval: 5 # test eval once, in the middle of 10 training steps eval_steps: 2 # TFDS Pipeline ---------------------- -dataset_type: tfds -dataset_path: 'gs://maxtext-dataset/dpo/anthropic_rlhf' -dataset_name: 'tfds:1.0.0' -eval_dataset_name: 'tfds:1.0.0' -eval_split: 'test' +#dataset_type: tfds +#dataset_path: 'gs://maxtext-dataset/dpo/anthropic_rlhf' +#dataset_name: 'tfds:1.0.0' +#eval_dataset_name: 'tfds:1.0.0' +#eval_split: 'test' +packing: False # DEBUG: DO NOT MERGE. 
# HF Pipeline ------------------------- hf_eval_split: 'test' diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py index 78d783270f..f44e8569e3 100644 --- a/src/maxtext/configs/pyconfig.py +++ b/src/maxtext/configs/pyconfig.py @@ -53,6 +53,7 @@ "maxtext.trainers.pre_train.train_compile": "base.yml", "maxtext.trainers.post_train.distillation.train_distill": "post_train/distillation.yml", "maxtext.trainers.post_train.rl.train_rl": "post_train/rl.yml", + "maxtext.trainers.post_train.dpo.train_dpo": "post_train/dpo.yml", "maxtext.trainers.post_train.sft.train_sft": "post_train/sft.yml", "maxtext.trainers.post_train.sft.train_sft_deprecated": "post_train/sft.yml", "maxtext.inference.decode": "base.yml", @@ -83,9 +84,7 @@ def _resolve_or_infer_config(argv: list[str]) -> tuple[str, list[str]]: return resolve_config_path(argv[1]), argv[2:] module = _module_from_path(argv[0]) if module not in _CONFIG_FILE_MAPPING: - raise ValueError( - f"No config file provided and no default config found for module '{module}'" - ) + raise ValueError(f"No config file provided and no default config found for module '{module}'") config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module]) logger.warning("No config file provided, using default config mapping: %s", config_path) return config_path, argv[1:] diff --git a/src/maxtext/input_pipeline/hf_data_processing.py b/src/maxtext/input_pipeline/hf_data_processing.py index d1be2c4890..425d670dfa 100644 --- a/src/maxtext/input_pipeline/hf_data_processing.py +++ b/src/maxtext/input_pipeline/hf_data_processing.py @@ -280,6 +280,9 @@ def preprocessing_pipeline( pad_id = _get_pad_id(tokenizer) + # Tunix-DPO handles tokenization internally if strings are passed. + # However, MaxText's multihost loader requires numeric JAX arrays. + # We tokenize here and rename columns to match Tunix's TrainingInput requirements. 
if tokenize: dataset = dataset.map( input_pipeline_utils.tokenization, @@ -318,6 +321,41 @@ def lists2array(x): return jax.tree.map(np.asarray, x, is_leaf=lambda y: isinstance(y, (list, tuple))) operations.append(grain.MapOperation(lists2array)) + + # Generate masks and rename keys to match tunix.sft.dpo.dpo_trainer.TrainingInput + class DPOTunixPrep(grain.MapTransform): + + def __init__(self, pad_id, max_prompt_length, max_response_length): + self.pad_id = pad_id + self.max_prompt_length = max_prompt_length + self.max_response_length = max_response_length + + def _pad(self, x, length, left=False): + x = np.asarray(x) + pad_amount = max(length - x.shape[0], 0) + if left: + pad_width = ((pad_amount, 0),) + else: + pad_width = ((0, pad_amount),) + return np.pad(x[:length], pad_width, constant_values=self.pad_id) + + def map(self, x): + prompt_ids = self._pad(x.pop("input"), self.max_prompt_length, left=True) + chosen_ids = self._pad(x.pop("chosen"), self.max_response_length, left=False) + rejected_ids = self._pad(x.pop("rejected"), self.max_response_length, left=False) + + x["prompt_ids"] = prompt_ids + x["chosen_ids"] = chosen_ids + x["rejected_ids"] = rejected_ids + x["prompt_mask"] = (prompt_ids != self.pad_id).astype(np.int32) + x["chosen_mask"] = (chosen_ids != self.pad_id).astype(np.int32) + x["rejected_mask"] = (rejected_ids != self.pad_id).astype(np.int32) + return x + + # Tunix DPO expects prompt and response to share the total budget. 
+ dpo_max_prompt_len = max_target_length // 2 + dpo_max_response_len = max_target_length // 2 + operations.append(DPOTunixPrep(pad_id, dpo_max_prompt_len, dpo_max_response_len)) else: assert len(data_column_names) == 1 operations.append(input_pipeline_utils.HFNormalizeFeatures(data_column_names[0])) @@ -337,7 +375,8 @@ def lists2array(x): ) operations.append(input_pipeline_utils.ReformatPacking(data_column_names)) else: - operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id)) + if not use_dpo: + operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id)) operations.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder)) if shift and not use_dpo: diff --git a/src/maxtext/trainers/post_train/dpo/train_dpo.py b/src/maxtext/trainers/post_train/dpo/train_dpo.py new file mode 100644 index 0000000000..0b6b843a6c --- /dev/null +++ b/src/maxtext/trainers/post_train/dpo/train_dpo.py @@ -0,0 +1,207 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DPO Training script that uses Tunix DPOTrainer on a MaxText model. 
+ +Example command: +Training & Evaluation: + python3 -m maxtext.trainers.post_train.dpo.train_dpo \ + run_name=${WORKLOAD?} base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ + tokenizer_path="google/gemma-2-2b-it" tokenizer_type=huggingface \ + dataset_type="hf" hf_path="Anthropic/hh-rlhf" \ + model_name=${MODEL?} load_parameters_path=${MAXTEXT_CONVERTED_CHECKPOINT?}/0/items \ + hf_access_token=${HF_TOKEN?} per_device_batch_size=1 max_target_length=1024 \ + eval_interval=2 eval_steps=2 steps=10 profiler=xplane weight_dtype=bfloat16 +""" + +from absl import app +import jax +import optax +from orbax import checkpoint as ocp +import pathwaysutils + +import flax.linen as nn +from flax import nnx +from flax.linen import partitioning as nn_partitioning + +from tunix.sft import metrics_logger, profiler +from tunix.sft.dpo.dpo_trainer import DPOTrainer, DPOTrainingConfig + +import tunix +from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter +from maxtext.configs import pyconfig +from maxtext.utils import max_utils +from maxtext.common.goodput import ( + GoodputEvent, + RECORD_JOB_END_TIME, + RECORD_JOB_START_TIME, + create_goodput_recorder, + maybe_monitor_goodput, + maybe_record_goodput, + record_goodput, +) +from maxtext.optimizers import optimizers +from maxtext.trainers.post_train.sft import hooks +from maxtext.utils import max_logging +from maxtext.utils import maxtext_utils +from maxtext.utils import model_creation_utils + + +def get_tunix_config(mt_config: pyconfig.HyperParameters) -> DPOTrainingConfig: + """Gets the Tunix training configurations from the MaxText config. + + Args: + mt_config: MaxText config. + + Returns: + A Tunix `DPOTrainingConfig` object. 
+ """ + # Checkpointing configurations + checkpointing_options = ocp.CheckpointManagerOptions( + save_interval_steps=mt_config.checkpoint_period, + enable_async_checkpointing=mt_config.async_checkpointing, + ) + + # Metrics configurations + metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=mt_config.tensorboard_dir) + + # Profiler configurations + profiler_options = None + if mt_config.profiler: + set_profile_options = True + platform_version = jax.extend.backend.get_backend().platform_version.strip() + if platform_version.startswith("Pathways"): + max_logging.log("Pathways backend detected. Disabling setting profile options.") + set_profile_options = False + profiler_options = profiler.ProfilerOptions( + log_dir=mt_config.tensorboard_dir, + skip_first_n_steps=mt_config.skip_first_n_steps_for_profiler, + profiler_steps=mt_config.profiler_steps, + set_profile_options=set_profile_options, + ) + + return DPOTrainingConfig( + eval_every_n_steps=mt_config.eval_interval, + max_steps=mt_config.steps, + gradient_accumulation_steps=mt_config.gradient_accumulation_steps, + checkpoint_root_directory=mt_config.checkpoint_dir, + checkpointing_options=checkpointing_options, + metrics_logging_options=metrics_logging_options, + profiler_options=profiler_options, + algorithm="dpo", # TODO: add support of "orpo" + beta=mt_config.dpo_beta, + label_smoothing=mt_config.dpo_label_smoothing, + max_prompt_length=mt_config.max_target_length // 2, + max_response_length=mt_config.max_target_length // 2, + ) + + +def setup_trainer_state(mt_config, goodput_recorder=None): + """Set up prerequisites for training loop.""" + tunix_config = get_tunix_config(mt_config) + + with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT): + model, mesh = model_creation_utils.create_nnx_model(mt_config) + + # Wrap model with Tunix adapter for consistent interface + model = TunixMaxTextAdapter(model) + + learning_rate_schedule = 
maxtext_utils.create_learning_rate_schedule(mt_config) + # pass in model for muon + optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model) + + if mt_config.gradient_clipping_threshold > 0: + optimizer = optax.chain( + optax.clip_by_global_norm(max_norm=mt_config.gradient_clipping_threshold), + optimizer, + ) + + # Pre-shard the optimizer to avoid TypeError in Tunix _shard_optimizer + # Tunix will now detect it's already sharded and skip its internal sharding logic. + with mesh, nn.logical_axis_rules(mt_config.logical_axis_rules): + nnx_optimizer = nnx.Optimizer(model, optimizer, wrt=nnx.Param) + opt_state = nnx.state(nnx_optimizer, nnx.optimizer.OptState) + opt_pspecs = nnx.get_partition_spec(opt_state) + opt_sharded_state = jax.lax.with_sharding_constraint(opt_state, opt_pspecs) + nnx.update(nnx_optimizer, opt_sharded_state) + + with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION): + training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder) + data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder) + + tokenizer = tunix.Tokenizer( + tokenizer_type=mt_config.tokenizer_type, + tokenizer_path=mt_config.tokenizer_path, + add_bos=mt_config.add_bos, + add_eos=mt_config.add_eos, + hf_access_token=mt_config.hf_access_token, + ) + + # Pass the pre-sharded nnx.Optimizer directly to DPOTrainer. 
+ with mesh, nn.logical_axis_rules(mt_config.logical_axis_rules): + trainer = DPOTrainer( + model=model, ref_model=None, optimizer=nnx_optimizer, training_config=tunix_config, tokenizer=None + ) + trainer.with_training_hooks(training_hooks) + trainer.with_data_hooks(data_hooks) + + return trainer, mesh + + +def train_model(mt_config: pyconfig.HyperParameters, trainer, mesh): + """Runs the DPO training loop in Tunix.""" + with jax.set_mesh(mesh), mesh, nn.logical_axis_rules(mt_config.logical_axis_rules): + trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator) + return trainer + + +def train(mt_config, goodput_recorder=None): + """Main method for DPO training. + + Args: + mt_config: MaxText config. + goodput_recorder: An optional GoodputRecorder to record performance metrics. + """ + trainer, mesh = setup_trainer_state(mt_config, goodput_recorder) + _job_completed_gracefully = False + try: + trainer = train_model(mt_config, trainer, mesh) + _job_completed_gracefully = True + finally: + if _job_completed_gracefully: + record_goodput(goodput_recorder, RECORD_JOB_END_TIME) + return trainer, mesh + + +def main(argv: list[str]) -> None: + """Main function to run DPO training. + + Args: + argv: Command-line arguments. 
+ """ + # import debugpy; debugpy.listen(("localhost", 5678)); print("Attach VS Code now"); debugpy.wait_for_client() + + pathwaysutils.initialize() + + mt_config = pyconfig.initialize(argv) + max_utils.print_system_information() + + goodput_recorder = create_goodput_recorder(mt_config) + record_goodput(goodput_recorder, RECORD_JOB_START_TIME) + with maybe_monitor_goodput(mt_config): + train(mt_config, goodput_recorder) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxtext/trainers/post_train/sft/hooks.py b/src/maxtext/trainers/post_train/sft/hooks.py index 03fa610837..791ac010aa 100644 --- a/src/maxtext/trainers/post_train/sft/hooks.py +++ b/src/maxtext/trainers/post_train/sft/hooks.py @@ -15,6 +15,7 @@ """Training and data loading hooks for SFT""" +import time from collections import defaultdict from sys import version_info @@ -93,11 +94,17 @@ def on_train_end(self, train_ctx: peft_trainer.PeftTrainer): # pylint: disable= @override def on_train_step_start(self, train_ctx: peft_trainer.PeftTrainer): """Called at the beginning of a training step.""" + self.step_start_time = time.time() if self.config.enable_goodput_recording: record_goodput(self.goodput_recorder, f"record_{GoodputEvent.STEP.value}_start_time", train_ctx.train_steps) # Calculate the number of non-padded tokens in the batch - total_weights = jnp.sum(train_ctx.data_hooks.train_batch["targets_segmentation"] != 0) + if self.config.use_dpo: + # For DPO, we sum both chosen and rejected masks + total_weights = jnp.sum(train_ctx.data_hooks.train_batch["chosen_mask"] != 0) + total_weights += jnp.sum(train_ctx.data_hooks.train_batch["rejected_mask"] != 0) + else: + total_weights = jnp.sum(train_ctx.data_hooks.train_batch["targets_segmentation"] != 0) self.train_metadata[train_ctx.train_steps] = { "total_weights": total_weights, @@ -117,6 +124,8 @@ def on_train_step_end( However, we will use the current `train_step` value to record metrics in this hook to be consistent with Tunix's metric 
logging convention. """ + # Use our own timing since Tunix might pass 0.0 + actual_step_time = time.time() - self.step_start_time assert train_step - 1 in self.train_metadata, ( "SFTTrainingHooks.on_train_step_start() must be called before" " SFTTrainingHooks.on_train_step_end()" @@ -131,7 +140,7 @@ def on_train_step_end( "learning/total_weights": self.train_metadata[train_step - 1]["total_weights"], } } - self.metric_logger.record_train_metrics(metrics, train_step, step_time) + self.metric_logger.record_train_metrics(metrics, train_step, actual_step_time) self.metric_logger.write_metrics(metrics, train_step) del self.train_metadata[train_step - 1] @@ -140,7 +149,12 @@ def on_eval_step_start(self, train_ctx: peft_trainer.PeftTrainer): """Called at the beginning of an evaluation step.""" self.eval_metadata["eval_step_count"] += 1.0 # Calculate the number of non-padded tokens in the batch - self.eval_metadata["total_weights"] += jnp.sum(train_ctx.data_hooks.eval_batch["targets_segmentation"] != 0) + if self.config.use_dpo: + total_weights = jnp.sum(train_ctx.data_hooks.eval_batch["chosen_mask"] != 0) + total_weights += jnp.sum(train_ctx.data_hooks.eval_batch["rejected_mask"] != 0) + self.eval_metadata["total_weights"] += total_weights + else: + self.eval_metadata["total_weights"] += jnp.sum(train_ctx.data_hooks.eval_batch["targets_segmentation"] != 0) @override def on_eval_step_end(self, train_ctx: peft_trainer.PeftTrainer, eval_loss: float):
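
Editor's note on the DPO branch added to the hooks above: it counts non-padded tokens across both the chosen and rejected responses, since a DPO step consumes two sequences per example. A standalone sanity check of that accounting (numpy standing in for `jnp`; `dpo_total_weights` is a hypothetical helper, not part of `hooks.py`):

```python
import numpy as np


def dpo_total_weights(batch):
  """Non-padded token count for a DPO batch: chosen mask sum plus rejected mask sum."""
  chosen = np.sum(np.asarray(batch["chosen_mask"]) != 0)
  rejected = np.sum(np.asarray(batch["rejected_mask"]) != 0)
  return int(chosen + rejected)
```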