diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py
index 7dbe14c480..7c03f357c2 100644
--- a/src/maxtext/trainers/post_train/rl/train_rl.py
+++ b/src/maxtext/trainers/post_train/rl/train_rl.py
@@ -142,6 +142,10 @@ def get_dataset(
)
template_config = load_template_from_file(tmvp_config.chat_template_path)
+ if template_config is None:
+ max_logging.warning(
+ f"Failed to load chat template from {tmvp_config.chat_template_path}. Proceeding without chat template."
+ )
loaded_dataset = (
grain.MapDataset.source(data)
@@ -339,6 +343,10 @@ def prepare_openinstructmath2_dataset(
split_name = trainer_config.train_split if trainer_config.train_split != "train" else "train_1M"
splits = prepare_openinstructmath2_dataset(split=split_name)
template_config = load_template_from_file(trainer_config.chat_template_path)
+ if template_config is None:
+ max_logging.warning(
+ f"Failed to load chat template from {trainer_config.chat_template_path}. Proceeding without chat template."
+ )
train_dataset = (
grain.MapDataset.source(splits["train"])
@@ -616,6 +624,10 @@ def _reward_fn(**kwargs):
)
# Instantiate the custom MaxText chat parser
template_config = load_template_from_file(trainer_config.chat_template_path)
+ if template_config is None:
+ max_logging.warning(
+ f"Failed to load chat template from {trainer_config.chat_template_path}. Proceeding without chat template."
+ )
chat_parser = utils_rl.MaxTextChatParser(
model_tokenizer=model_tokenizer, template_config=template_config, tmvp_config=trainer_config
)
diff --git a/src/maxtext/trainers/post_train/rl/utils_rl.py b/src/maxtext/trainers/post_train/rl/utils_rl.py
index f8cb6fd1b7..8dc53f25ee 100644
--- a/src/maxtext/trainers/post_train/rl/utils_rl.py
+++ b/src/maxtext/trainers/post_train/rl/utils_rl.py
@@ -526,8 +526,11 @@ def make_optimizer(learning_rate):
return optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)
-def format_maxtext_messages(messages: list[dict[str, str]], template_config: dict, tmvp_config) -> list[dict[str, str]]:
- """Helper to inject MaxText's system prompt into the input user messages."""
+def format_maxtext_messages(messages: list[str], template_config: dict | None, tmvp_config) -> list[dict[str, str]]:
+ """Format messages as user turns, injecting MaxText's system prompt when a template is given."""
+ if template_config is None:
+ return [{"role": "user", "content": msg} for msg in messages]
+
formatted_messages = []
for msg in messages:
formatted_content = template_config["TEMPLATE"].format(
diff --git a/tests/post_training/unit/rl_utils_test.py b/tests/post_training/unit/rl_utils_test.py
index 5d3ada8558..1560dbd80b 100644
--- a/tests/post_training/unit/rl_utils_test.py
+++ b/tests/post_training/unit/rl_utils_test.py
@@ -370,5 +370,40 @@ def test_returns_optimizer_with_clipping(self):
self.assertIn("learning_rate", state.hyperparams)
+class TestFormatMaxTextMessages(unittest.TestCase):
+ """Tests for utils_rl.format_maxtext_messages."""
+
+ def setUp(self):
+ self.config = _make_config()
+ self.template_config = {
+ "SYSTEM_PROMPT": "Reason between {reasoning_start_token} and {reasoning_end_token}. "
+ + "Solution between {solution_start_token} and {solution_end_token}.",
+ "TEMPLATE": "system: {system_prompt}\nquestion: {question}",
+ }
+
+ @pytest.mark.cpu_only
+ def test_format_with_template(self):
+ """Test formatting when a template is provided."""
+ messages = ["What is 2+2?"]
+ formatted = utils_rl.format_maxtext_messages(messages, self.template_config, self.config)
+ self.assertEqual(len(formatted), 1)
+ self.assertEqual(formatted[0]["role"], "user")
+ expected_content = (
+ "system: Reason between and . "
+ "Solution between and .\n"
+ "question: What is 2+2?"
+ )
+ self.assertEqual(formatted[0]["content"], expected_content)
+
+ @pytest.mark.cpu_only
+ def test_format_without_template(self):
+ """Test that messages are wrapped unchanged as user turns when template_config is None."""
+ messages = ["What is 2+2?"]
+ formatted = utils_rl.format_maxtext_messages(messages, None, self.config)
+ self.assertEqual(len(formatted), 1)
+ self.assertEqual(formatted[0]["role"], "user")
+ self.assertEqual(formatted[0]["content"], "What is 2+2?")
+
+
if __name__ == "__main__":
unittest.main()