Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/maxtext/trainers/post_train/rl/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ def get_dataset(
)

template_config = load_template_from_file(tmvp_config.chat_template_path)
if template_config is None:
max_logging.warning(
f"Failed to load chat template from {tmvp_config.chat_template_path}. Proceeding without chat template."
)

loaded_dataset = (
grain.MapDataset.source(data)
Expand Down Expand Up @@ -339,6 +343,10 @@ def prepare_openinstructmath2_dataset(
split_name = trainer_config.train_split if trainer_config.train_split != "train" else "train_1M"
splits = prepare_openinstructmath2_dataset(split=split_name)
template_config = load_template_from_file(trainer_config.chat_template_path)
if template_config is None:
max_logging.warning(
f"Failed to load chat template from {trainer_config.chat_template_path}. Proceeding without chat template."
)

train_dataset = (
grain.MapDataset.source(splits["train"])
Expand Down Expand Up @@ -616,6 +624,10 @@ def _reward_fn(**kwargs):
)
# Instantiate the custom MaxText chat parser
template_config = load_template_from_file(trainer_config.chat_template_path)
if template_config is None:
max_logging.warning(
f"Failed to load chat template from {trainer_config.chat_template_path}. Proceeding without chat template."
)
chat_parser = utils_rl.MaxTextChatParser(
model_tokenizer=model_tokenizer, template_config=template_config, tmvp_config=trainer_config
)
Expand Down
5 changes: 4 additions & 1 deletion src/maxtext/trainers/post_train/rl/utils_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,11 @@ def make_optimizer(learning_rate):
return optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)


def format_maxtext_messages(messages: list[dict[str, str]], template_config: dict, tmvp_config) -> list[dict[str, str]]:
def format_maxtext_messages(messages: list[str], template_config: dict | None, tmvp_config) -> list[dict[str, str]]:
"""Helper to inject MaxText's system prompt into the input user messages."""
if template_config is None:
return [{"role": "user", "content": msg} for msg in messages]

formatted_messages = []
for msg in messages:
formatted_content = template_config["TEMPLATE"].format(
Expand Down
35 changes: 35 additions & 0 deletions tests/post_training/unit/rl_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,5 +370,40 @@ def test_returns_optimizer_with_clipping(self):
self.assertIn("learning_rate", state.hyperparams)


class TestFormatMaxTextMessages(unittest.TestCase):
  """Unit tests covering utils_rl.format_maxtext_messages."""

  def setUp(self):
    # Fixtures shared by every test: a config object plus a two-key template
    # dict carrying the system prompt and the per-message layout.
    self.config = _make_config()
    system_prompt = (
        "Reason between {reasoning_start_token} and {reasoning_end_token}. "
        "Solution between {solution_start_token} and {solution_end_token}."
    )
    self.template_config = {
        "SYSTEM_PROMPT": system_prompt,
        "TEMPLATE": "system: {system_prompt}\nquestion: {question}",
    }

  @pytest.mark.cpu_only
  def test_format_with_template(self):
    """With a template, the question is embedded in the rendered template text."""
    prompts = ["What is 2+2?"]
    result = utils_rl.format_maxtext_messages(prompts, self.template_config, self.config)

    self.assertEqual(len(result), 1)
    entry = result[0]
    self.assertEqual(entry["role"], "user")
    self.assertEqual(
        entry["content"],
        "system: Reason between <reasoning> and </reasoning>. "
        "Solution between <answer> and </answer>.\n"
        "question: What is 2+2?",
    )

  @pytest.mark.cpu_only
  def test_format_without_template(self):
    """With template_config=None, the message text is passed through unchanged."""
    prompts = ["What is 2+2?"]
    result = utils_rl.format_maxtext_messages(prompts, None, self.config)

    self.assertEqual(len(result), 1)
    entry = result[0]
    self.assertEqual(entry["role"], "user")
    self.assertEqual(entry["content"], "What is 2+2?")


# Run this module's tests via unittest when it is executed as a script.
if __name__ == "__main__":
  unittest.main()
Loading