Handle a missing chat template in the RL post-training path.

train_rl.py: log a warning at each of the three call sites when
load_template_from_file() returns None.
utils_rl.py: format_maxtext_messages() accepts template_config=None and then
wraps each message as a plain {"role": "user", "content": ...} entry.
rl_utils_test.py: cover both the templated and the template-less path.

NOTE(review): this patch was restored from a whitespace-flattened copy. Hunk
line counts match the @@ headers, but indentation and blank context lines are
reconstructed (2-space style assumed); confirm against the tree, or apply with
`git apply --ignore-space-change`.
NOTE(review): test_format_with_template now pins only the text the template
fixes around the substituted tokens, so the post-image blob id on the test
file's `index` line is stale.

diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py
index 7dbe14c480..7c03f357c2 100644
--- a/src/maxtext/trainers/post_train/rl/train_rl.py
+++ b/src/maxtext/trainers/post_train/rl/train_rl.py
@@ -142,6 +142,10 @@ def get_dataset(
   )
 
   template_config = load_template_from_file(tmvp_config.chat_template_path)
+  if template_config is None:
+    max_logging.warning(
+        f"Failed to load chat template from {tmvp_config.chat_template_path}. Proceeding without chat template."
+    )
 
   loaded_dataset = (
       grain.MapDataset.source(data)
@@ -339,6 +343,10 @@ def prepare_openinstructmath2_dataset(
     split_name = trainer_config.train_split if trainer_config.train_split != "train" else "train_1M"
     splits = prepare_openinstructmath2_dataset(split=split_name)
     template_config = load_template_from_file(trainer_config.chat_template_path)
+    if template_config is None:
+      max_logging.warning(
+          f"Failed to load chat template from {trainer_config.chat_template_path}. Proceeding without chat template."
+      )
 
     train_dataset = (
         grain.MapDataset.source(splits["train"])
@@ -616,6 +624,10 @@ def _reward_fn(**kwargs):
   )
   # Instantiate the custom MaxText chat parser
   template_config = load_template_from_file(trainer_config.chat_template_path)
+  if template_config is None:
+    max_logging.warning(
+        f"Failed to load chat template from {trainer_config.chat_template_path}. Proceeding without chat template."
+    )
   chat_parser = utils_rl.MaxTextChatParser(
       model_tokenizer=model_tokenizer, template_config=template_config, tmvp_config=trainer_config
   )
diff --git a/src/maxtext/trainers/post_train/rl/utils_rl.py b/src/maxtext/trainers/post_train/rl/utils_rl.py
index f8cb6fd1b7..8dc53f25ee 100644
--- a/src/maxtext/trainers/post_train/rl/utils_rl.py
+++ b/src/maxtext/trainers/post_train/rl/utils_rl.py
@@ -526,8 +526,11 @@ def make_optimizer(learning_rate):
   return optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)
 
 
-def format_maxtext_messages(messages: list[dict[str, str]], template_config: dict, tmvp_config) -> list[dict[str, str]]:
+def format_maxtext_messages(messages: list[str], template_config: dict | None, tmvp_config) -> list[dict[str, str]]:
   """Helper to inject MaxText's system prompt into the input user messages."""
+  if template_config is None:
+    return [{"role": "user", "content": msg} for msg in messages]
+
   formatted_messages = []
   for msg in messages:
     formatted_content = template_config["TEMPLATE"].format(
diff --git a/tests/post_training/unit/rl_utils_test.py b/tests/post_training/unit/rl_utils_test.py
index 5d3ada8558..1560dbd80b 100644
--- a/tests/post_training/unit/rl_utils_test.py
+++ b/tests/post_training/unit/rl_utils_test.py
@@ -370,5 +370,40 @@ def test_returns_optimizer_with_clipping(self):
     self.assertIn("learning_rate", state.hyperparams)
 
 
+class TestFormatMaxTextMessages(unittest.TestCase):
+  """Tests for utils_rl.format_maxtext_messages."""
+
+  def setUp(self):
+    self.config = _make_config()
+    self.template_config = {
+        "SYSTEM_PROMPT": "Reason between {reasoning_start_token} and {reasoning_end_token}. "
+        + "Solution between {solution_start_token} and {solution_end_token}.",
+        "TEMPLATE": "system: {system_prompt}\nquestion: {question}",
+    }
+
+  @pytest.mark.cpu_only
+  def test_format_with_template(self):
+    """Test formatting when a template is provided."""
+    messages = ["What is 2+2?"]
+    formatted = utils_rl.format_maxtext_messages(messages, self.template_config, self.config)
+    self.assertEqual(len(formatted), 1)
+    self.assertEqual(formatted[0]["role"], "user")
+    content = formatted[0]["content"]
+    # The substituted token values come from the config, so pin only the
+    # literal text that the template places around them.
+    self.assertTrue(content.startswith("system: Reason between "))
+    self.assertIn(". Solution between ", content)
+    self.assertTrue(content.endswith(".\nquestion: What is 2+2?"))
+
+  @pytest.mark.cpu_only
+  def test_format_without_template(self):
+    """Test that messages pass through unchanged when template_config is None."""
+    messages = ["What is 2+2?"]
+    formatted = utils_rl.format_maxtext_messages(messages, None, self.config)
+    self.assertEqual(len(formatted), 1)
+    self.assertEqual(formatted[0]["role"], "user")
+    self.assertEqual(formatted[0]["content"], "What is 2+2?")
+
+
 if __name__ == "__main__":
   unittest.main()