fix: inject GPT-OSS stop tokens when not specified to prevent tool call failures (#1949) #2080
```diff
@@ -36,6 +36,10 @@
     "too many total text bytes",
 ]
 
 
+# Stop tokens for GPT-OSS models to enforce generation boundaries
+# https://github.com/openai/harmony/blob/main/src/registry.rs
+_GPT_OSS_STOP_TOKENS = ["<|call|>", "<|return|>", "<|end|>"]
+
 
 class Client(Protocol):
     """Protocol defining the OpenAI-compatible interface for the underlying provider client."""
```

```diff
@@ -467,6 +471,11 @@ def format_request(
             TypeError: If a message contains a content block type that cannot be converted to an OpenAI-compatible
                 format.
         """
+        params = cast(dict[str, Any], self.config.get("params", {}))
+        # Inject default GPT-OSS stop tokens unless the user has explicitly provided their own
+        if "gpt-oss" in cast(str, self.config.get("model_id", "")).lower() and "stop" not in params:
```
**Contributor:** Is this something we should always inject? Any potential downsides, i.e. reasons not to do it? 🤔 Or should the customer just pass this param for their custom self-hosting?

**Contributor:** My main hesitation is that this model provider is the gateway to all models that use OpenAI chat shapes. If we start injecting this, could it negatively impact others?

**Author:** We only inject the defaults when no stop tokens are specified at all. If the user passes their own, or if pre-configured middleware overrides with its own defaults (as happens for requests to the main Bedrock/Mantle endpoint), we don't touch them. We should NOT inject these tokens every time because OpenAI only allows a maximum of 4 stop sequences (ref). So if the user has already passed some of their own, or if pre-configured middleware supplies default stop tokens, appending these 3 defaults would almost certainly exceed the limit and cause errors.
```diff
+            params = {**params, "stop": _GPT_OSS_STOP_TOKENS}
 
         return {
             "messages": self.format_request_messages(
                 messages, system_prompt, system_prompt_content=system_prompt_content
```
```diff
@@ -486,7 +495,7 @@ def format_request(
                 for tool_spec in tool_specs or []
             ],
             **(self._format_request_tool_choice(tool_choice)),
-            **cast(dict[str, Any], self.config.get("params", {})),
+            **params,
         }
 
     def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
```
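For context on what these stop tokens accomplish: when `stop` is set, the API truncates generation immediately before the first matched sequence, so Harmony control tokens like `<|call|>` never leak into a tool call's JSON arguments. A rough client-side illustration of that server-side behavior (the helper `truncate_at_stop` is hypothetical, not part of the PR or the API):

```python
GPT_OSS_STOP_TOKENS = ["<|call|>", "<|return|>", "<|end|>"]


def truncate_at_stop(text: str, stop: list[str]) -> str:
    """Cut text at the earliest occurrence of any stop sequence,
    mimicking how the server applies the "stop" parameter."""
    cut = len(text)
    for token in stop:
        idx = text.find(token)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]


# Without stop tokens, trailing control tokens could corrupt the JSON
# arguments of a tool call; with them, generation ends cleanly.
raw = '{"location": "Seattle"}<|call|>'
assert truncate_at_stop(raw, GPT_OSS_STOP_TOKENS) == '{"location": "Seattle"}'
```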
```diff
@@ -627,6 +627,25 @@ def test_format_request(model, messages, tool_specs, system_prompt):
     assert tru_request == exp_request
 
 
+@pytest.mark.parametrize("model_id", ["openai.gpt-oss-120b"])
```
|
**Reviewer:** Suggestion: consider adding a parametrized test case with mixed-case model IDs, e.g.:

```python
@pytest.mark.parametrize("model_id", ["openai.gpt-oss-120b", "openai.GPT-OSS-120B"])
def test_format_request_gpt_oss_injects_stop_tokens(model_id, model, messages, tool_specs, system_prompt):
    ...
```

**Author:** Official model ID is lowercase.

**Contributor:** Nit: consider adopting the same fixture pattern.

**Author:** `model_id` is already defined as a pytest fixture (set to dummy id "m1"); the parametrize tag lets you override fixtures with new values. To your point, I suppose we could set up a fixture for `model_id_gpt_oss_120b`. I can follow up with some cosmetic upgrades for tests.
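The fixture-override pattern the author refers to is standard pytest behavior: directly parametrizing an argument shadows a fixture of the same name for that test only. A minimal standalone sketch (the fixture name `model_id` and dummy value `"m1"` follow the PR's test suite; everything else is illustrative):

```python
import pytest


@pytest.fixture
def model_id():
    """Default dummy model id used by most tests in the suite."""
    return "m1"


def test_uses_fixture_default(model_id):
    # Without parametrization, the fixture value "m1" is injected
    assert model_id == "m1"


@pytest.mark.parametrize("model_id", ["openai.gpt-oss-120b"])
def test_parametrize_overrides_fixture(model_id):
    # Direct parametrization shadows the fixture of the same name,
    # so this test sees the parametrized value instead of "m1"
    assert model_id == "openai.gpt-oss-120b"
```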
```diff
+def test_format_request_gpt_oss_injects_stop_tokens(model_id, model, messages, tool_specs, system_prompt):
+    tru_request = model.format_request(messages, tool_specs, system_prompt)
+    assert tru_request["stop"] == ["<|call|>", "<|return|>", "<|end|>"]
+
+
+@pytest.mark.parametrize("model_id", ["openai.gpt-oss-120b"])
+def test_format_request_gpt_oss_preserves_explicit_stop_tokens(model_id, model, messages, tool_specs, system_prompt):
+    model.update_config(params={"max_tokens": 1, "stop": ["<|end|>"]})
+
+    tru_request = model.format_request(messages, tool_specs, system_prompt)
+    assert tru_request["stop"] == ["<|end|>"]
+
+
+def test_format_request_non_gpt_oss_no_stop_tokens(model, messages, tool_specs, system_prompt):
+    tru_request = model.format_request(messages, tool_specs, system_prompt)
+    assert "stop" not in tru_request
+
+
 def test_format_request_with_tool_choice_auto(model, messages, tool_specs, system_prompt):
     tool_choice = {"auto": {}}
     tru_request = model.format_request(messages, tool_specs, system_prompt, tool_choice)
```