diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go index 58e92913..8a560f02 100644 --- a/cmd/cli/commands/run.go +++ b/cmd/cli/commands/run.go @@ -16,6 +16,7 @@ import ( "github.com/docker/model-runner/cmd/cli/commands/completion" "github.com/docker/model-runner/cmd/cli/desktop" "github.com/docker/model-runner/cmd/cli/readline" + "github.com/docker/model-runner/cmd/cli/tools" "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/inference/scheduling" "github.com/fatih/color" @@ -24,6 +25,15 @@ import ( "golang.org/x/term" ) +// defaultTools returns the tools enabled by default for interactive sessions. +// Web search can be disabled by setting DOCKER_MODEL_NO_WEBSEARCH=1. +func defaultTools() []desktop.ClientTool { + if os.Getenv("DOCKER_MODEL_NO_WEBSEARCH") != "" { + return nil + } + return []desktop.ClientTool{&tools.WebSearchTool{}} +} + // readMultilineInput reads input from stdin, supporting both single-line and multiline input. // For multiline input, it detects triple-quoted strings and shows continuation prompts. func readMultilineInput(cmd *cobra.Command, scanner *bufio.Scanner) (string, error) { @@ -632,11 +642,13 @@ func chatWithMarkdownContext(ctx context.Context, cmd *cobra.Command, client *de // This reflects exactly what the model receives. processedUserMessage = buildUserMessage(prompt, imageURLs) + activeTools := defaultTools() + if !useMarkdown { // Simple case: just stream as plain text assistantResponse, err = client.ChatWithMessagesContext(ctx, model, conversationHistory, prompt, imageURLs, func(content string) { cmd.Print(content) - }, false) + }, false, activeTools...) return assistantResponse, processedUserMessage, err } @@ -655,7 +667,7 @@ func chatWithMarkdownContext(ctx context.Context, cmd *cobra.Command, client *de } else if rendered != "" { cmd.Print(rendered) } - }, true) + }, true, activeTools...) if err != nil { return assistantResponse, processedUserMessage, err } diff --git a/cmd/cli/desktop/api.go b/cmd/cli/desktop/api.go index a19a8543..b5e10600 100644 --- a/cmd/cli/desktop/api.go +++ b/cmd/cli/desktop/api.go @@ -1,8 +1,37 @@ package desktop +// Tool represents an OpenAI function tool definition. +type Tool struct { + Type string `json:"type"` + Function ToolFunction `json:"function"` +} + +// ToolFunction holds the schema for a tool. +type ToolFunction struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters any `json:"parameters,omitempty"` +} + +// ToolCall represents a tool call in a message or streaming delta. +type ToolCall struct { + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Index int `json:"index"` + Function ToolCallFunction `json:"function"` +} + +// ToolCallFunction holds the name and accumulated arguments for a tool call. +type ToolCallFunction struct { + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` +} + type OpenAIChatMessage struct { - Role string `json:"role"` - Content interface{} `json:"content"` // Can be string or []ContentPart for multimodal + Role string `json:"role"` + Content any `json:"content,omitempty"` // Can be string or []ContentPart for multimodal + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` } // ContentPart represents a part of multimodal content (text or image) @@ -21,6 +50,7 @@ type OpenAIChatRequest struct { Model string `json:"model"` Messages []OpenAIChatMessage `json:"messages"` Stream bool `json:"stream"` + Tools []Tool `json:"tools,omitempty"` } type OpenAIChatResponse struct { @@ -30,13 +60,15 @@ type OpenAIChatResponse struct { Model string `json:"model"` Choices []struct { Delta struct { - Content string `json:"content"` - Role string `json:"role,omitempty"` - ReasoningContent string `json:"reasoning_content,omitempty"` + Content string `json:"content"` + Role string `json:"role,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } `json:"delta"` Message struct { - Content string `json:"content"` - Role string `json:"role,omitempty"` + Content string `json:"content"` + Role string `json:"role,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } `json:"message"` Index int `json:"index"` FinishReason string `json:"finish_reason"` diff --git a/cmd/cli/desktop/desktop.go b/cmd/cli/desktop/desktop.go index 361779e5..f941ade4 100644 --- a/cmd/cli/desktop/desktop.go +++ b/cmd/cli/desktop/desktop.go @@ -29,6 +29,13 @@ var ( ErrServiceUnavailable = errors.New("service unavailable") ) +// ClientTool is a tool that can be registered with the chat client. +type ClientTool interface { + Name() string + Schema() Tool + Execute(args map[string]any) (string, error) +} + type otelErrorSilencer struct{} func (oes *otelErrorSilencer) Handle(error) {} @@ -368,6 +375,13 @@ func (c *Client) Chat(model, prompt string, imageURLs []string, outputFunc func( return c.ChatWithContext(context.Background(), model, prompt, imageURLs, outputFunc, shouldUseMarkdown) } +// accumulatedToolCall collects streamed tool call fragments into a complete call. +type accumulatedToolCall struct { + id string + name string + arguments strings.Builder +} + // Preload loads a model into memory without running inference. // The model stays loaded for the idle timeout period. func (c *Client) Preload(ctx context.Context, model string) error { @@ -409,7 +423,9 @@ func (c *Client) Preload(ctx context.Context, model string) error { // ChatWithMessagesContext performs a chat request with conversation history and returns the assistant's response. // This allows maintaining conversation context across multiple exchanges. -func (c *Client) ChatWithMessagesContext(ctx context.Context, model string, conversationHistory []OpenAIChatMessage, prompt string, imageURLs []string, outputFunc func(string), shouldUseMarkdown bool) (string, error) { +// When tools are provided, the function implements an agentic loop: if the model requests a tool call, +// the tool is executed and the result is sent back until the model produces a final response. +func (c *Client) ChatWithMessagesContext(ctx context.Context, model string, conversationHistory []OpenAIChatMessage, prompt string, imageURLs []string, outputFunc func(string), shouldUseMarkdown bool, tools ...ClientTool) (string, error) { // Build the current user message content - either simple string or multimodal array var messageContent interface{} if len(imageURLs) > 0 { @@ -448,34 +464,21 @@ func (c *Client) ChatWithMessagesContext(ctx context.Context, model string, conv Content: messageContent, }) - reqBody := OpenAIChatRequest{ - Model: model, - Messages: messages, - Stream: true, - } + // initialMessages captures the messages before any tool calls so we can + // fall back to them if the model's chat template doesn't support tool roles. + initialMessages := messages - jsonData, err := json.Marshal(reqBody) - if err != nil { - return "", fmt.Errorf("error marshaling request: %w", err) + // Build tool schemas and lookup map + var toolSchemas []Tool + toolMap := make(map[string]ClientTool, len(tools)) + for _, t := range tools { + toolSchemas = append(toolSchemas, t.Schema()) + toolMap[t.Name()] = t } - completionsPath := c.modelRunner.OpenAIPathPrefix() + "/chat/completions" - - resp, err := c.doRequestWithAuthContext( - ctx, - http.MethodPost, - completionsPath, - bytes.NewReader(jsonData), - ) - if err != nil { - return "", c.handleQueryError(err, completionsPath) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("error response: status=%d body=%s", resp.StatusCode, body) - } + // toolsSupported is cleared if the model returns a Jinja template error, + // indicating its chat template doesn't support tool calling. + toolsSupported := len(toolSchemas) > 0 type chatPrinterState int const ( @@ -484,162 +487,316 @@ func (c *Client) ChatWithMessagesContext(ctx context.Context, model string, conv chatPrinterContent ) - printerState := chatPrinterNone reasoningFmt := color.New().Add(color.Italic) if !shouldUseMarkdown { reasoningFmt.DisableColor() } - var assistantResponse strings.Builder var finalUsage *struct { CompletionTokens int `json:"completion_tokens"` PromptTokens int `json:"prompt_tokens"` TotalTokens int `json:"total_tokens"` } - // Use a buffered reader so we can consume server-sent progress - // lines (e.g. "Installing vllm-metal backend...") that arrive - // before the actual SSE or JSON inference response. - br := bufio.NewReader(resp.Body) + completionsPath := c.modelRunner.OpenAIPathPrefix() + "/chat/completions" - // Consume any plain-text progress lines that precede the real - // response. We peek ahead: if the next non-empty content starts - // with '{' (JSON) or "data:" / ":" (SSE), the progress section - // is over and we fall through to normal processing. + var assistantResponse strings.Builder + + // Agentic loop: iterate until the model produces a stop response (no more tool calls). for { - peek, err := br.Peek(1) - if err != nil { - break + reqBody := OpenAIChatRequest{ + Model: model, + Messages: messages, + Stream: true, + Tools: toolSchemas, } - // JSON object or SSE stream — stop consuming progress lines. - if peek[0] == '{' || peek[0] == ':' { - break - } - line, err := br.ReadString('\n') - if err != nil && line == "" { - break - } - line = strings.TrimRight(line, "\r\n") - if line == "" { - continue - } - // SSE data line — stop, let the normal SSE parser handle it. - if strings.HasPrefix(line, "data:") { - // Put the line back by chaining a reader with the rest. - br = bufio.NewReader(io.MultiReader( - strings.NewReader(line+"\n"), - br, - )) - break - } - // Progress message — print to stderr. - fmt.Fprintln(os.Stderr, line) - } - // Detect streaming vs non-streaming response. Because server-sent - // progress lines may have been flushed before the Content-Type was - // set, we also peek at the body content to detect SSE. - isStreaming := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") - if !isStreaming { - if peek, err := br.Peek(5); err == nil { - isStreaming = strings.HasPrefix(string(peek), "data:") + jsonData, err := json.Marshal(reqBody) + if err != nil { + return assistantResponse.String(), fmt.Errorf("error marshaling request: %w", err) } - } - if !isStreaming { - // Non-streaming JSON response - body, err := io.ReadAll(br) + resp, err := c.doRequestWithAuthContext( + ctx, + http.MethodPost, + completionsPath, + bytes.NewReader(jsonData), + ) if err != nil { - return assistantResponse.String(), fmt.Errorf("error reading response body: %w", err) + return assistantResponse.String(), c.handleQueryError(err, completionsPath) } - var nonStreamResp OpenAIChatResponse - if err := json.Unmarshal(body, &nonStreamResp); err != nil { - return assistantResponse.String(), fmt.Errorf("error parsing response: %w", err) + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + // If the model doesn't support tool calling (e.g., its chat template throws a + // Jinja exception when tools are present), retry the request without tools. + // Only do this before any tool calls have been executed to avoid corrupting + // the message history. + // If the model's chat template doesn't support tool calling (Jinja exception), + // fall back to retrying with the original messages and no tools. + // This handles both cases: + // - Error before any tool calls: the tools parameter in the request itself + // breaks the template (e.g. injects an incompatible system message). + // - Error after tool calls: the tool/assistant(tool_calls) messages in the + // history use roles the template doesn't understand. + // In both cases we reset to the initial user messages and disable tools so the + // model can answer from its training data. + // + // Note: This detection relies on string matching because the model runner does + // not provide a structured error code for template incompatibility. The check + // looks for "Jinja" (the templating engine used by many models) or + // "template" in the error body. If this proves too brittle in practice, + // consider adding a specific error code or flag to the model runner API. + if toolsSupported && isTemplateIncompatibleError(body) { + toolSchemas = nil + toolMap = nil + toolsSupported = false + messages = initialMessages + assistantResponse.Reset() + continue + } + return assistantResponse.String(), fmt.Errorf("error response: status=%d body=%s", resp.StatusCode, body) } - // Extract content from non-streaming response - if len(nonStreamResp.Choices) > 0 && nonStreamResp.Choices[0].Message.Content != "" { - content := nonStreamResp.Choices[0].Message.Content - outputFunc(content) - assistantResponse.WriteString(content) - } + printerState := chatPrinterNone - if nonStreamResp.Usage != nil { - finalUsage = nonStreamResp.Usage - } - } else { - // SSE streaming response - process line by line - scanner := bufio.NewScanner(br) - - for scanner.Scan() { - // Check if context was cancelled - select { - case <-ctx.Done(): - return assistantResponse.String(), ctx.Err() - default: - } + // Accumulated tool calls for this iteration, keyed by index. + pendingToolCalls := make(map[int]*accumulatedToolCall) + var finishReason string - line := scanner.Text() + // Use a buffered reader so we can consume server-sent progress + // lines (e.g. "Installing vllm-metal backend...") that arrive + // before the actual SSE or JSON inference response. + br := bufio.NewReader(resp.Body) + + // Consume any plain-text progress lines that precede the real + // response. We peek ahead: if the next non-empty content starts + // with '{' (JSON) or "data:" / ":" (SSE), the progress section + // is over and we fall through to normal processing. + for { + peek, err := br.Peek(1) + if err != nil { + break + } + // JSON object or SSE stream — stop consuming progress lines. + if peek[0] == '{' || peek[0] == ':' { + break + } + line, err := br.ReadString('\n') + if err != nil && line == "" { + break + } + line = strings.TrimRight(line, "\r\n") if line == "" { continue } + // SSE data line — stop, let the normal SSE parser handle it. + if strings.HasPrefix(line, "data:") { + // Put the line back by chaining a reader with the rest. + br = bufio.NewReader(io.MultiReader( + strings.NewReader(line+"\n"), + br, + )) + break + } + // Progress message — print to stderr. + fmt.Fprintln(os.Stderr, line) + } - if !strings.HasPrefix(line, "data: ") { - continue + // Detect streaming vs non-streaming response. Because server-sent + // progress lines may have been flushed before the Content-Type was + // set, we also peek at the body content to detect SSE. + isStreaming := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") + if !isStreaming { + if peek, err := br.Peek(5); err == nil { + isStreaming = strings.HasPrefix(string(peek), "data:") } + } - data := strings.TrimPrefix(line, "data: ") + if !isStreaming { + // Non-streaming JSON response + body, err := io.ReadAll(br) + resp.Body.Close() + if err != nil { + return assistantResponse.String(), fmt.Errorf("error reading response body: %w", err) + } - if data == "[DONE]" { - break + var nonStreamResp OpenAIChatResponse + if err := json.Unmarshal(body, &nonStreamResp); err != nil { + return assistantResponse.String(), fmt.Errorf("error parsing response: %w", err) } - var streamResp OpenAIChatResponse - if err := json.Unmarshal([]byte(data), &streamResp); err != nil { - return assistantResponse.String(), fmt.Errorf("error parsing stream response: %w", err) + // Extract content from non-streaming response + if len(nonStreamResp.Choices) > 0 { + if nonStreamResp.Choices[0].Message.Content != "" { + content := nonStreamResp.Choices[0].Message.Content + outputFunc(content) + assistantResponse.WriteString(content) + } + finishReason = nonStreamResp.Choices[0].FinishReason + for _, tc := range nonStreamResp.Choices[0].Message.ToolCalls { + atc := &accumulatedToolCall{id: tc.ID, name: tc.Function.Name} + atc.arguments.WriteString(tc.Function.Arguments) + pendingToolCalls[tc.Index] = atc + } } - if streamResp.Usage != nil { - finalUsage = streamResp.Usage + if nonStreamResp.Usage != nil { + finalUsage = nonStreamResp.Usage } + } else { + // SSE streaming response - process line by line + scanner := bufio.NewScanner(br) + + for scanner.Scan() { + // Check if context was cancelled + select { + case <-ctx.Done(): + resp.Body.Close() + return assistantResponse.String(), ctx.Err() + default: + } + + line := scanner.Text() + if line == "" { + continue + } - if len(streamResp.Choices) > 0 { - if streamResp.Choices[0].Delta.ReasoningContent != "" { - chunk := streamResp.Choices[0].Delta.ReasoningContent - if printerState == chatPrinterContent { - outputFunc("\n\n") + if !strings.HasPrefix(line, "data: ") { + continue + } + + data := strings.TrimPrefix(line, "data: ") + + if data == "[DONE]" { + break + } + + var streamResp OpenAIChatResponse + if err := json.Unmarshal([]byte(data), &streamResp); err != nil { + resp.Body.Close() + return assistantResponse.String(), fmt.Errorf("error parsing stream response: %w", err) + } + + if streamResp.Usage != nil { + finalUsage = streamResp.Usage + } + + if len(streamResp.Choices) > 0 { + choice := streamResp.Choices[0] + + if choice.FinishReason != "" { + finishReason = choice.FinishReason } - if printerState != chatPrinterReasoning { - const thinkingHeader = "Thinking:\n" + + // Accumulate tool call fragments. + for _, tc := range choice.Delta.ToolCalls { + atc, ok := pendingToolCalls[tc.Index] + if !ok { + atc = &accumulatedToolCall{} + pendingToolCalls[tc.Index] = atc + } + if tc.ID != "" { + atc.id = tc.ID + } + if tc.Function.Name != "" { + atc.name = tc.Function.Name + } + atc.arguments.WriteString(tc.Function.Arguments) + } + + if choice.Delta.ReasoningContent != "" { + chunk := choice.Delta.ReasoningContent + if printerState == chatPrinterContent { + outputFunc("\n\n") + } + if printerState != chatPrinterReasoning { + const thinkingHeader = "Thinking:\n" + if reasoningFmt != nil { + reasoningFmt.Print(thinkingHeader) + } else { + outputFunc(thinkingHeader) + } + } + printerState = chatPrinterReasoning if reasoningFmt != nil { - reasoningFmt.Print(thinkingHeader) + reasoningFmt.Print(chunk) } else { - outputFunc(thinkingHeader) + outputFunc(chunk) } } - printerState = chatPrinterReasoning - if reasoningFmt != nil { - reasoningFmt.Print(chunk) - } else { + if choice.Delta.Content != "" { + chunk := choice.Delta.Content + if printerState == chatPrinterReasoning { + outputFunc("\n\n--\n\n") + } + printerState = chatPrinterContent outputFunc(chunk) + assistantResponse.WriteString(chunk) } } - if streamResp.Choices[0].Delta.Content != "" { - chunk := streamResp.Choices[0].Delta.Content - if printerState == chatPrinterReasoning { - outputFunc("\n\n--\n\n") + } + + resp.Body.Close() + if err := scanner.Err(); err != nil { + return assistantResponse.String(), fmt.Errorf("error reading response stream: %w", err) + } + } + + // If the model requested tool calls, execute them and loop. + if finishReason == "tool_calls" && len(pendingToolCalls) > 0 { + // Build assistant message with the tool calls. + toolCallSlice := make([]ToolCall, 0, len(pendingToolCalls)) + for idx := 0; idx < len(pendingToolCalls); idx++ { + atc, ok := pendingToolCalls[idx] + if !ok { + continue + } + toolCallSlice = append(toolCallSlice, ToolCall{ + ID: atc.id, + Type: "function", + Function: ToolCallFunction{ + Name: atc.name, + Arguments: atc.arguments.String(), + }, + }) + } + messages = append(messages, OpenAIChatMessage{ + Role: "assistant", + ToolCalls: toolCallSlice, + }) + + // Execute each tool and append results. + for _, tc := range toolCallSlice { + var result string + if tool, ok := toolMap[tc.Function.Name]; ok { + var args map[string]any + if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil { + result = fmt.Sprintf("error parsing tool arguments: %v", err) + } else { + var execErr error + result, execErr = tool.Execute(args) + if execErr != nil { + result = fmt.Sprintf("tool execution error: %v", execErr) + } } - printerState = chatPrinterContent - outputFunc(chunk) - assistantResponse.WriteString(chunk) + } else { + result = fmt.Sprintf("unknown tool: %s", tc.Function.Name) } + messages = append(messages, OpenAIChatMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: result, + }) } + // Reset for next iteration + assistantResponse.Reset() + continue } - if err := scanner.Err(); err != nil { - return assistantResponse.String(), fmt.Errorf("error reading response stream: %w", err) - } + // Normal stop — we're done. + break } if finalUsage != nil { @@ -664,6 +821,12 @@ func (c *Client) ChatWithContext(ctx context.Context, model, prompt string, imag return err } +// ChatWithToolsContext is like ChatWithContext but supports tool calls. +func (c *Client) ChatWithToolsContext(ctx context.Context, model, prompt string, imageURLs []string, outputFunc func(string), shouldUseMarkdown bool, tools ...ClientTool) error { + _, err := c.ChatWithMessagesContext(ctx, model, nil, prompt, imageURLs, outputFunc, shouldUseMarkdown, tools...) + return err +} + func (c *Client) Remove(modelArgs []string, force bool) (string, error) { modelRemoved := "" for _, model := range modelArgs { @@ -1167,3 +1330,20 @@ func (c *Client) RepackageModel(ctx context.Context, source, target string, opts return nil } + +// isTemplateIncompatibleError checks if the error body indicates a chat template +// incompatibility issue. This is used to detect when a model does not support +// tool-specific chat templates (e.g., Jinja template errors). +// +// The function checks for multiple common patterns (case-insensitive): +// - "jinja": the templating engine used by many Hugging Face models +// - "template": generic template-related errors +// +// This string-based detection is necessary because the model runner does not +// provide structured error codes for template incompatibility. If you encounter +// models that fail with template errors but are not detected by this function, +// consider adding additional patterns here. +func isTemplateIncompatibleError(body []byte) bool { + bodyStr := strings.ToLower(string(body)) + return strings.Contains(bodyStr, "jinja") || strings.Contains(bodyStr, "template") +} diff --git a/cmd/cli/desktop/desktop_test.go b/cmd/cli/desktop/desktop_test.go index 2719c64c..4494f7d9 100644 --- a/cmd/cli/desktop/desktop_test.go +++ b/cmd/cli/desktop/desktop_test.go @@ -10,6 +10,7 @@ import ( mockdesktop "github.com/docker/model-runner/cmd/cli/mocks" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) @@ -152,6 +153,141 @@ func TestPushRetryOnNetworkError(t *testing.T) { assert.NoError(t, err) } +// mockTool is a minimal ClientTool for testing. +type mockTool struct{ name string } + +func (m *mockTool) Name() string { return m.name } +func (m *mockTool) Schema() Tool { + return Tool{Type: "function", Function: ToolFunction{Name: m.name}} +} +func (m *mockTool) Execute(_ map[string]any) (string, error) { return "result", nil } + +// jinjaErrorBody is the 500 response body that a model with an incompatible chat +// template returns when tools are included in the request. +const jinjaErrorBody = `{"error":{"code":500,"message":"Jinja Exception: Conversation roles must alternate user/assistant/user/assistant/...","type":"server_error"}}` + +// sseResponse builds a minimal SSE response body with a single content chunk. +func sseResponse(content string) string { + return "data: {\"choices\":[{\"delta\":{\"content\":\"" + content + "\"},\"finish_reason\":null,\"index\":0}]}\n\n" + + "data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\",\"index\":0}]}\n\n" + + "data: [DONE]\n\n" +} + +// TestChatWithMessagesContext_JinjaFallbackNoTools verifies that when a model returns a +// 500 Jinja template error (because it doesn't support tool calling), the client +// retries the request without tools and succeeds. +func TestChatWithMessagesContext_JinjaFallbackNoTools(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(NewContextForMock(mockClient)) + + gomock.InOrder( + // First call includes tools → model returns Jinja error + mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{ + StatusCode: http.StatusInternalServerError, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewBufferString(jinjaErrorBody)), + }, nil), + // Retry without tools → model responds successfully + mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(bytes.NewBufferString(sseResponse("Hello!"))), + }, nil), + ) + + var output string + resp, err := client.ChatWithMessagesContext( + t.Context(), "gemma3", nil, "hi", nil, + func(s string) { output += s }, + false, + &mockTool{name: "web_search"}, + ) + require.NoError(t, err) + assert.Equal(t, "Hello!", resp) + assert.Equal(t, "Hello!", output) +} + +// toolCallSSEResponse returns an SSE stream that emits a tool_call finish and then +// a single tool call for the given tool name with empty arguments. +func toolCallSSEResponse(toolName string) string { + return `data: {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call1","type":"function","function":{"name":"` + toolName + `","arguments":"{}"}}]},"finish_reason":null,"index":0}]}` + "\n\n" + + `data: {"choices":[{"delta":{},"finish_reason":"tool_calls","index":0}]}` + "\n\n" + + "data: [DONE]\n\n" +} + +// TestChatWithMessagesContext_JinjaFallbackAfterToolCall verifies that when a model +// successfully executes a tool call but then fails with a Jinja error when the tool +// result is sent back (because it doesn't support the "tool" role), the client resets +// to the original messages and retries without tools. +func TestChatWithMessagesContext_JinjaFallbackAfterToolCall(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(NewContextForMock(mockClient)) + + gomock.InOrder( + // First call with tools → model responds with a tool_call + mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(bytes.NewBufferString(toolCallSSEResponse("web_search"))), + }, nil), + // Second call with tool results → model returns Jinja error (can't handle "tool" role) + mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{ + StatusCode: http.StatusInternalServerError, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewBufferString(jinjaErrorBody)), + }, nil), + // Third call: reset to original messages, no tools → model responds successfully + mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(bytes.NewBufferString(sseResponse("Here is the news."))), + }, nil), + ) + + var output string + resp, err := client.ChatWithMessagesContext( + t.Context(), "gemma3", nil, "Tell me the news", nil, + func(s string) { output += s }, + false, + &mockTool{name: "web_search"}, + ) + require.NoError(t, err) + assert.Equal(t, "Here is the news.", resp) + assert.Equal(t, "Here is the news.", output) +} + +// TestChatWithMessagesContext_Non500ErrorNotRetried verifies that non-Jinja 500 errors +// are not silently retried without tools. +func TestChatWithMessagesContext_Non500ErrorNotRetried(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) + client := New(NewContextForMock(mockClient)) + + // Only one call should be made — no retry for unrelated errors. + mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{ + StatusCode: http.StatusInternalServerError, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewBufferString(`{"error":"out of memory"}`)), + }, nil).Times(1) + + _, err := client.ChatWithMessagesContext( + t.Context(), "gemma3", nil, "hi", nil, + func(string) {}, + false, + &mockTool{name: "web_search"}, + ) + assert.Error(t, err) + assert.Contains(t, err.Error(), "out of memory") +} + func TestIsRetryableError(t *testing.T) { tests := []struct { name string @@ -180,3 +316,26 @@ func TestIsRetryableError(t *testing.T) { }) } } + +func TestIsTemplateIncompatibleError(t *testing.T) { + tests := []struct { + name string + body string + expected bool + }{ + {"empty body", "", false}, + {"jinja error", `{"error":"Jinja template error: unsupported role"}`, true}, + {"template error", `{"error":"template does not support tools"}`, true}, + {"generic error", `{"error":"out of memory"}`, false}, + {"jinja in message", "model failed: Jinja exception in chat template", true}, + {"Template capitalized", `{"error":"Template rendering failed"}`, true}, + {"JINJA uppercase", `{"error":"JINJA EXCEPTION"}`, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isTemplateIncompatibleError([]byte(tt.body)) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/cmd/cli/tools/websearch.go b/cmd/cli/tools/websearch.go new file mode 100644 index 00000000..b4c31913 --- /dev/null +++ b/cmd/cli/tools/websearch.go @@ -0,0 +1,132 @@ +package tools + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/docker/model-runner/cmd/cli/desktop" +) + +const ( + exaMCPURL = "https://mcp.exa.ai/mcp" + searchTimeout = 25 * time.Second + defaultNumResults = 8 +) + +// WebSearchTool implements web search via Exa's MCP API. +type WebSearchTool struct{} + +// Name returns the tool name. +func (w *WebSearchTool) Name() string { return "web_search" } + +// Schema returns the OpenAI tool definition for web search. +func (w *WebSearchTool) Schema() desktop.Tool { + return desktop.Tool{ + Type: "function", + Function: desktop.ToolFunction{ + Name: w.Name(), + Description: fmt.Sprintf("Search the web for current information up to %d results. Use this when you need up-to-date information that may not be in your training data.", defaultNumResults), + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "The search query", + }, + }, + "required": []string{"query"}, + }, + }, + } +} + +// Execute performs the web search using Exa's MCP endpoint. +func (w *WebSearchTool) Execute(args map[string]any) (string, error) { + query, ok := args["query"].(string) + if !ok || query == "" { + return "", fmt.Errorf("query parameter is required") + } + + reqBody := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{ + "name": "web_search_exa", + "arguments": map[string]any{ + "query": query, + "numResults": defaultNumResults, + "type": "auto", + "livecrawl": "fallback", + }, + }, + } + + jsonBody, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("marshaling request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), searchTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, exaMCPURL, bytes.NewBuffer(jsonBody)) + if err != nil { + return "", fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("executing search: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("search API returned status %d: %s", resp.StatusCode, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading response: %w", err) + } + + // Response may be SSE or plain JSON — handle both. + responseText := string(body) + + // For SSE (text/event-stream), accumulate all data: lines into a single payload. + var dataLines []string + for _, line := range strings.Split(responseText, "\n") { + if strings.HasPrefix(line, "data: ") { + dataLines = append(dataLines, strings.TrimPrefix(line, "data: ")) + } + } + if len(dataLines) > 0 { + responseText = strings.Join(dataLines, "\n") + } + + var mcpResp struct { + Result struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + } `json:"result"` + } + if err := json.Unmarshal([]byte(responseText), &mcpResp); err != nil { + return "", fmt.Errorf("parsing response: %w", err) + } + + if len(mcpResp.Result.Content) > 0 { + return mcpResp.Result.Content[0].Text, nil + } + return "No search results found.", nil +}