[Bug]: rl.yml defaults to hardcoded Gemma chat template, causing degenerate loops in LLaMA 3 GRPO #3572
Description
Bug report
When running the GRPO post-training pipeline (configs/post_train/rl.yml) with a non-Gemma model such as LLaMA 3.1, training silently fails and produces 0.0 rewards due to a hardcoded chat template.
Logs/Output
The Problem:
In rl.yml, the chat_template_path defaults to maxtext/examples/chat_templates/gsm8k_rl.json. This JSON file hardcodes Google's <start_of_turn> and <end_of_turn> tokens into the "TEMPLATE" string.
When training a LLaMA 3 model, feeding it <start_of_turn> causes the model to panic, hallucinate its own format, and fall into a degenerate repetition loop (e.g., repeatedly outputting <start_of_turn>user until generation is cut off). The reward function then fails to extract an answer from the garbage output, resulting in a 0 reward and a broken actor.
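To make the mismatch concrete, here is a minimal sketch of the two prompt formats. The delimiter strings are the models' actual special tokens; the single-turn template strings and the `render` helper are illustrative, not MaxText's actual JSON contents:

```python
# Sketch: why a Gemma-format prompt confuses a LLaMA 3 model.
# The turn delimiters are real special tokens for each model family;
# the template strings themselves are simplified for illustration.

GEMMA_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
LLAMA3_TEMPLATE = (
    "<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)

def render(template: str, prompt: str) -> str:
    """Fill a single-turn chat template with the user prompt."""
    return template.format(prompt=prompt)

question = "Natalia sold clips to 48 of her friends..."
gemma_prompt = render(GEMMA_TEMPLATE, question)
llama_prompt = render(LLAMA3_TEMPLATE, question)

# A LLaMA 3 tokenizer has no <start_of_turn> special token, so the Gemma
# prompt is tokenized as plain text and the model never sees a valid turn
# boundary -- hence the degenerate repetition described above.
assert "<start_of_turn>" in gemma_prompt
assert "<|start_header_id|>" in llama_prompt
```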
Environment Information
Hardware: TPU v5e-8 slice via GKE
Additional Context
The Workaround:
I was able to fix this and successfully train the model by creating a custom llama3_rl.json file using LLaMA's native <|start_header_id|> tokens, and passing it via chat_template_path=/path/to/llama3_rl.json.
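For reference, my llama3_rl.json looked roughly like this. Only the "TEMPLATE" key and the LLaMA 3 special tokens are taken from the existing gsm8k_rl.json convention; the placeholder syntax is illustrative, so mirror whatever the shipped template actually uses:

```json
{
  "TEMPLATE": "<|start_header_id|>user<|end_header_id|>\n\n{{prompt}}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
}
```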
Suggested Fixes:
Short term: Add a llama3_gsm8k_rl.json to the examples/chat_templates/ folder and document how to swap it in the RL tutorial.
Long term: Deprecate the hardcoded JSON templates entirely and allow the pipeline to dynamically build the prompts using the model's native tokenizer.chat_template from Hugging Face.
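A sketch of what the long-term fix could look like. In the real pipeline this would call `tokenizer.apply_chat_template()` on a Hugging Face tokenizer loaded for the model being trained; the `DummyTokenizer` below is a stand-in so the sketch is self-contained, and its template strings are illustrative:

```python
# Sketch of the proposed long-term fix: build the prompt from the model's
# own tokenizer instead of a hardcoded JSON file. A real implementation
# would use transformers' AutoTokenizer and its apply_chat_template();
# DummyTokenizer mimics that interface with plain str.format.

class DummyTokenizer:
    """Stand-in for a HF tokenizer that carries its own chat template."""

    def __init__(self, turn_template: str, gen_header: str):
        self.turn_template = turn_template  # per-message template
        self.gen_header = gen_header        # opener for the model's reply

    def apply_chat_template(self, messages, add_generation_prompt=True):
        text = "".join(
            self.turn_template.format(role=m["role"], content=m["content"])
            for m in messages
        )
        if add_generation_prompt:
            text += self.gen_header
        return text

def build_prompt(tokenizer, question: str) -> str:
    """Model-agnostic prompt construction: no hardcoded tokens here."""
    messages = [{"role": "user", "content": question}]
    return tokenizer.apply_chat_template(messages, add_generation_prompt=True)

llama = DummyTokenizer(
    "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>",
    "<|start_header_id|>assistant<|end_header_id|>\n\n",
)
gemma = DummyTokenizer(
    "<start_of_turn>{role}\n{content}<end_of_turn>\n",
    "<start_of_turn>model\n",
)

# The same pipeline code now emits the right format for each model family.
assert build_prompt(llama, "What is 2 + 2?").startswith("<|start_header_id|>user")
assert build_prompt(gemma, "What is 2 + 2?").endswith("<start_of_turn>model\n")
```

Because the template travels with the tokenizer, adding a new model family would need no new JSON file at all.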