(4/5) async grpo break out of generation loop (is_done)#5321
Cursor Bugbot has reviewed your changes and found 1 potential issue.
            tool_execution_duration_s=0.0,
        )
    )
return _build_completion(turns, truncated=False, total_duration=time.monotonic() - t_start)
truncated always False even when max turns exceeded
Low Severity
The truncated field on RolloutCompletion is always set to False across all return paths. On line 600, two semantically different conditions are combined with or: (1) tool_calls_raw is None (natural completion) and (2) iteration_num >= max_num_turns (forced stop). When the model produces tool calls but max_num_turns forces an early exit, truncated is incorrectly set to False — the generation was cut short, which is truncation. The truncated value needs to depend on which sub-condition triggered the exit (e.g., truncated=(tool_calls_raw is not None)).
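A minimal sketch of the distinction the review asks for, assuming the loop exits on the combined condition described above. `exit_status` is a hypothetical helper and not part of the PR; only `tool_calls_raw`, `iteration_num`, and `max_num_turns` are taken from the comment.

```python
from typing import Optional


def exit_status(
    tool_calls_raw: Optional[list],
    iteration_num: int,
    max_num_turns: int,
) -> tuple[bool, bool]:
    """Return (should_stop, truncated) for one pass of the generation loop.

    Hypothetical helper: it separates the two sub-conditions that the
    reviewed code combines with `or`.
    """
    natural_stop = tool_calls_raw is None          # model produced no tool calls
    forced_stop = iteration_num >= max_num_turns   # turn budget exhausted
    should_stop = natural_stop or forced_stop
    # Truncated only when the stop was forced while the model still wanted tools.
    truncated = should_stop and not natural_stop
    return should_stop, truncated
```

With this split, a rollout that ends because the model stopped calling tools reports `truncated=False`, while one cut off by the turn budget with pending tool calls reports `truncated=True`.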
Discussion continued from PR #5299 regarding the

@qgallouedec: Requiring extra methods unnecessarily narrows the set of valid implementations and reduces substitutability. We should depend on the minimal protocol the class actually needs. There is nothing we can't do without

@AmineDiro:

@qgallouedec:
class MyEnv:
    def reset(self, **kwargs):
        self._done = False

    def my_func(self):
        """Some nice documentation"""
        if self._done:
            raise Exception("Session expired, you can't use this tool anymore!")
        if some_condition:
            self._done = True
        return

If the generation stops when the environment is done, then the model never learns that [environment is done] implies [it can't use it anymore].

@AmineDiro:
async def _generate_one(...):
    while True:
        # ....
        if tool_calls is None or (max_iterations is not None and iteration_num >= max_iterations):
            return ...

Now we can have done() raise an Exception, but I am not sure it provides a clean signal to tell whether a rollout raised because of an env-specific error or genuinely reached end of turn. If we go the Exception route, I imagine we would need to define a library-side Exception like

Also, having an env-defined done tool extends the tools context for the LLM, which can be a good choice in some specific cases but is overall maybe unnecessary?

@qgallouedec: it seems like a proxy the a new

But more generally, don't you think it's out of the scope of this PR?
To be precise, what I quoted is the previous generation loop. The changed one does use it to break; see the line:
    if is_done is not None and is_done():
        return _build_completion(turns, truncated=False, total_duration=time.monotonic() - t_start)
Truncation is separate from done and is mainly for metrics and debugging purposes. It can be computed from "isn't last EOS", but it's just a bool, and recomputing something that small seems unnecessary?
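The early-exit behaviour under discussion can be sketched as follows. This is an illustration under stated assumptions, not the PR's implementation: `CountdownEnv`, `discover_is_done`, and `generate_one` are hypothetical names, and the model turn is replaced by a stub that always calls the environment's tool.

```python
import asyncio
from typing import Callable, Optional


class CountdownEnv:
    """Hypothetical environment that reports done after two tool calls."""

    def __init__(self) -> None:
        self.calls = 0

    def step(self) -> str:
        self.calls += 1
        return f"call {self.calls}"

    def is_done(self) -> bool:
        return self.calls >= 2


def discover_is_done(env: object) -> Optional[Callable[[], bool]]:
    """Return the env's bound is_done method if it defines one, else None."""
    candidate = getattr(env, "is_done", None)
    return candidate if callable(candidate) else None


async def generate_one(env, max_num_turns: int = 10) -> list[str]:
    """Run stubbed turns until the env is done or the turn budget runs out."""
    is_done = discover_is_done(env)
    turns: list[str] = []
    for _ in range(max_num_turns):
        await asyncio.sleep(0)      # stand-in for awaiting a model generation
        turns.append(env.step())    # stand-in for executing the requested tool
        if is_done is not None and is_done():
            break                   # environment signalled completion
    return turns
```

Because `is_done` is looked up with `getattr`, environments that do not define it keep working and simply run until the turn budget is reached, which keeps the required protocol minimal.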
I took another look at this PR, and there’s a pattern here; I’m not sure if it’s intentional. Let me know if I’m understanding this correctly: The
Yes, you are right.


🔴 🔴 IMPORTANT: Depends on: feature/async-grpo-data-classes
What does this PR do?
- Detects the environment's optional is_done method upon initialization.
- Passes is_done into the _generate_one method and breaks the multi-turn generation loop early if the environment signals it has completed its objective.
- Uses get_json_schema to ensure environment tools pass transformers schema validation.

Note
Medium Risk
Moderate risk: changes the rollout data model and multi-turn generation control flow, which can affect reward inputs, token accounting, and tool-calling behavior in async sampling.
Overview
Adds an environment-driven early-exit path for multi-turn tool-calling: AsyncRolloutWorker now discovers an optional is_done method per environment instance and stops _generate_one once it returns true.

Refactors rollout outputs from flat lists into structured records (RolloutCompletion/TurnRecord/ToolCallRecord) that track per-turn messages, token ids/logprobs, tool-response suffix ids, and timing; scoring/reward inputs, tool metrics, and completion masks/logprobs are updated to derive from this structure. Also pre-converts environment bound methods to JSON schema via get_json_schema so they pass Transformers tool validation.

Written by Cursor Bugbot for commit ae9ea3e.
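To make the structured-record idea concrete, here is a rough sketch of what such records and a derived completion mask might look like. Only the class names and the `truncated`, `total_duration`, and `tool_execution_duration_s` fields appear in this PR page; every other field name, the `completion_ids_and_mask` helper, and the convention of masking tool-response tokens with 0 are assumptions for illustration.

```python
from dataclasses import dataclass, field


@dataclass
class ToolCallRecord:
    name: str                                 # assumed field
    arguments: dict                           # assumed field
    tool_execution_duration_s: float = 0.0


@dataclass
class TurnRecord:
    token_ids: list[int]                      # tokens sampled by the model (assumed field)
    logprobs: list[float]                     # one logprob per sampled token (assumed field)
    tool_suffix_ids: list[int] = field(default_factory=list)   # tool-response tokens
    tool_calls: list[ToolCallRecord] = field(default_factory=list)


@dataclass
class RolloutCompletion:
    turns: list[TurnRecord]
    truncated: bool = False
    total_duration: float = 0.0


def completion_ids_and_mask(completion: RolloutCompletion) -> tuple[list[int], list[int]]:
    """Flatten all turns into token ids plus a mask (1 = model token, 0 = tool token)."""
    ids: list[int] = []
    mask: list[int] = []
    for turn in completion.turns:
        ids += turn.token_ids
        mask += [1] * len(turn.token_ids)
        ids += turn.tool_suffix_ids
        mask += [0] * len(turn.tool_suffix_ids)
    return ids, mask
```

Deriving the flat sequences from per-turn records, rather than storing parallel flat lists, keeps token ids, logprobs, and masks aligned by construction.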