From 4dd00fa5dd8c6eef9a9926f1b484a837319d3b42 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 4 Feb 2026 16:41:40 +0100 Subject: [PATCH 1/5] use types to encode and decode jsonrpc queries --- go/client.go | 603 +++++++------------------------ go/client_test.go | 23 +- go/internal/jsonrpc2/jsonrpc2.go | 150 +++++--- go/session.go | 230 +++--------- go/types.go | 295 +++++++++++---- 5 files changed, 505 insertions(+), 796 deletions(-) diff --git a/go/client.go b/go/client.go index d45d3447..a6d3e950 100644 --- a/go/client.go +++ b/go/client.go @@ -396,36 +396,6 @@ func (c *Client) ForceStop() { } } -// buildProviderParams converts a ProviderConfig to a map for JSON-RPC params. -func buildProviderParams(p *ProviderConfig) map[string]any { - params := make(map[string]any) - if p.Type != "" { - params["type"] = p.Type - } - if p.WireApi != "" { - params["wireApi"] = p.WireApi - } - if p.BaseURL != "" { - params["baseUrl"] = p.BaseURL - } - if p.APIKey != "" { - params["apiKey"] = p.APIKey - } - if p.BearerToken != "" { - params["bearerToken"] = p.BearerToken - } - if p.Azure != nil { - azure := make(map[string]any) - if p.Azure.APIVersion != "" { - azure["apiVersion"] = p.Azure.APIVersion - } - if len(azure) > 0 { - params["azure"] = azure - } - } - return params -} - func (c *Client) ensureConnected() error { if c.client != nil { return nil @@ -467,166 +437,54 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses return nil, err } - params := make(map[string]any) + req := createSessionRequest{} if config != nil { - if config.Model != "" { - params["model"] = config.Model - } - if config.SessionID != "" { - params["sessionId"] = config.SessionID - } - if config.ReasoningEffort != "" { - params["reasoningEffort"] = config.ReasoningEffort - } - if len(config.Tools) > 0 { - toolDefs := make([]map[string]any, 0, len(config.Tools)) - for _, tool := range config.Tools { - if tool.Name == "" { - continue - } - definition := map[string]any{ - "name": tool.Name, - "description": tool.Description, - } - if tool.Parameters != nil { - definition["parameters"] = tool.Parameters - } - toolDefs = append(toolDefs, definition) - } - if len(toolDefs) > 0 { - params["tools"] = toolDefs - } - } - // Add system message configuration if provided - if config.SystemMessage != nil { - systemMessage := make(map[string]any) - - if config.SystemMessage.Mode != "" { - systemMessage["mode"] = config.SystemMessage.Mode - } + req.Model = config.Model + req.SessionID = config.SessionID + req.ReasoningEffort = config.ReasoningEffort + req.ConfigDir = config.ConfigDir + req.Tools = config.Tools + req.SystemMessage = config.SystemMessage + req.AvailableTools = config.AvailableTools + req.ExcludedTools = config.ExcludedTools + req.Provider = config.Provider + req.WorkingDirectory = config.WorkingDirectory + req.MCPServers = config.MCPServers + req.CustomAgents = config.CustomAgents + req.SkillDirectories = config.SkillDirectories + req.DisabledSkills = config.DisabledSkills + req.InfiniteSessions = config.InfiniteSessions - if config.SystemMessage.Mode == "replace" { - if config.SystemMessage.Content != "" { - systemMessage["content"] = config.SystemMessage.Content - } - } else { - if config.SystemMessage.Content != "" { - systemMessage["content"] = config.SystemMessage.Content - } - } - - if len(systemMessage) > 0 { - params["systemMessage"] = systemMessage - } - } - // Add tool filtering options - if len(config.AvailableTools) > 0 { - params["availableTools"] = config.AvailableTools - } - if len(config.ExcludedTools) > 0 { - params["excludedTools"] = config.ExcludedTools - } - // Add streaming option if config.Streaming { - params["streaming"] = config.Streaming - } - // Add provider configuration - if config.Provider != nil { - params["provider"] = buildProviderParams(config.Provider) + req.Streaming = Bool(true) } - // Add permission request flag if config.OnPermissionRequest != nil { - params["requestPermission"] = true + req.RequestPermission = Bool(true) } - // Add user input request flag if config.OnUserInputRequest != nil { - params["requestUserInput"] = true + req.RequestUserInput = Bool(true) } - // Add hooks flag if config.Hooks != nil && (config.Hooks.OnPreToolUse != nil || config.Hooks.OnPostToolUse != nil || config.Hooks.OnUserPromptSubmitted != nil || config.Hooks.OnSessionStart != nil || config.Hooks.OnSessionEnd != nil || config.Hooks.OnErrorOccurred != nil) { - params["hooks"] = true - } - // Add working directory - if config.WorkingDirectory != "" { - params["workingDirectory"] = config.WorkingDirectory - } - // Add MCP servers configuration - if len(config.MCPServers) > 0 { - params["mcpServers"] = config.MCPServers - } - // Add custom agents configuration - if len(config.CustomAgents) > 0 { - customAgents := make([]map[string]any, 0, len(config.CustomAgents)) - for _, agent := range config.CustomAgents { - agentMap := map[string]any{ - "name": agent.Name, - "prompt": agent.Prompt, - } - if agent.DisplayName != "" { - agentMap["displayName"] = agent.DisplayName - } - if agent.Description != "" { - agentMap["description"] = agent.Description - } - if len(agent.Tools) > 0 { - agentMap["tools"] = agent.Tools - } - if len(agent.MCPServers) > 0 { - agentMap["mcpServers"] = agent.MCPServers - } - if agent.Infer != nil { - agentMap["infer"] = *agent.Infer - } - customAgents = append(customAgents, agentMap) - } - params["customAgents"] = customAgents - } - // Add config directory override - if config.ConfigDir != "" { - params["configDir"] = config.ConfigDir - } - // Add skill directories configuration - if len(config.SkillDirectories) > 0 { - params["skillDirectories"] = config.SkillDirectories - } - // Add disabled skills configuration - if len(config.DisabledSkills) > 0 { - params["disabledSkills"] = config.DisabledSkills - } - // Add infinite sessions configuration - if config.InfiniteSessions != nil { - infiniteSessions := make(map[string]any) - if config.InfiniteSessions.Enabled != nil { - infiniteSessions["enabled"] = *config.InfiniteSessions.Enabled - } - if config.InfiniteSessions.BackgroundCompactionThreshold != nil { - infiniteSessions["backgroundCompactionThreshold"] = *config.InfiniteSessions.BackgroundCompactionThreshold - } - if config.InfiniteSessions.BufferExhaustionThreshold != nil { - infiniteSessions["bufferExhaustionThreshold"] = *config.InfiniteSessions.BufferExhaustionThreshold - } - params["infiniteSessions"] = infiniteSessions + req.Hooks = Bool(true) } } - result, err := c.client.Request("session.create", params) + result, err := c.client.Request("session.create", req) if err != nil { return nil, fmt.Errorf("failed to create session: %w", err) } - sessionID, ok := result["sessionId"].(string) - if !ok { - return nil, fmt.Errorf("invalid response: missing sessionId") + var response createSessionResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) } - workspacePath, _ := result["workspacePath"].(string) - - session := newSession(sessionID, c.client, workspacePath) + session := newSession(response.SessionID, c.client, response.WorkspacePath) if config != nil { session.registerTools(config.Tools) @@ -644,7 +502,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses } c.sessionsMux.Lock() - c.sessions[sessionID] = session + c.sessions[response.SessionID] = session c.sessionsMux.Unlock() return session, nil @@ -676,119 +534,52 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, return nil, err } - params := map[string]any{ - "sessionId": sessionID, - } - + var req resumeSessionRequest + req.SessionID = sessionID if config != nil { - if config.ReasoningEffort != "" { - params["reasoningEffort"] = config.ReasoningEffort - } - if len(config.Tools) > 0 { - toolDefs := make([]map[string]any, 0, len(config.Tools)) - for _, tool := range config.Tools { - if tool.Name == "" { - continue - } - definition := map[string]any{ - "name": tool.Name, - "description": tool.Description, - } - if tool.Parameters != nil { - definition["parameters"] = tool.Parameters - } - toolDefs = append(toolDefs, definition) - } - if len(toolDefs) > 0 { - params["tools"] = toolDefs - } - } - if config.Provider != nil { - params["provider"] = buildProviderParams(config.Provider) - } - // Add streaming option + req.ReasoningEffort = config.ReasoningEffort + req.Tools = config.Tools + req.Provider = config.Provider if config.Streaming { - params["streaming"] = config.Streaming + req.Streaming = Bool(true) } - // Add permission request flag if config.OnPermissionRequest != nil { - params["requestPermission"] = true + req.RequestPermission = Bool(true) } - // Add user input request flag - if config.OnUserInputRequest != nil { - params["requestUserInput"] = true + if config.OnPermissionRequest != nil { + req.RequestUserInput = Bool(true) } - // Add hooks flag if config.Hooks != nil && (config.Hooks.OnPreToolUse != nil || config.Hooks.OnPostToolUse != nil || config.Hooks.OnUserPromptSubmitted != nil || config.Hooks.OnSessionStart != nil || config.Hooks.OnSessionEnd != nil || config.Hooks.OnErrorOccurred != nil) { - params["hooks"] = true - } - // Add working directory - if config.WorkingDirectory != "" { - params["workingDirectory"] = config.WorkingDirectory + req.Hooks = Bool(true) } - // Add disable resume flag + req.WorkingDirectory = config.WorkingDirectory if config.DisableResume { - params["disableResume"] = true + req.DisableResume = Bool(true) } - // Add MCP servers configuration if len(config.MCPServers) > 0 { - params["mcpServers"] = config.MCPServers - } - // Add custom agents configuration - if len(config.CustomAgents) > 0 { - customAgents := make([]map[string]any, 0, len(config.CustomAgents)) - for _, agent := range config.CustomAgents { - agentMap := map[string]any{ - "name": agent.Name, - "prompt": agent.Prompt, - } - if agent.DisplayName != "" { - agentMap["displayName"] = agent.DisplayName - } - if agent.Description != "" { - agentMap["description"] = agent.Description - } - if len(agent.Tools) > 0 { - agentMap["tools"] = agent.Tools - } - if len(agent.MCPServers) > 0 { - agentMap["mcpServers"] = agent.MCPServers - } - if agent.Infer != nil { - agentMap["infer"] = *agent.Infer - } - customAgents = append(customAgents, agentMap) - } - params["customAgents"] = customAgents - } - // Add skill directories configuration - if len(config.SkillDirectories) > 0 { - params["skillDirectories"] = config.SkillDirectories - } - // Add disabled skills configuration - if len(config.DisabledSkills) > 0 { - params["disabledSkills"] = config.DisabledSkills + req.MCPServers = config.MCPServers } + req.CustomAgents = config.CustomAgents + req.SkillDirectories = config.SkillDirectories + req.DisabledSkills = config.DisabledSkills } - result, err := c.client.Request("session.resume", params) + result, err := c.client.Request("session.resume", req) if err != nil { return nil, fmt.Errorf("failed to resume session: %w", err) } - resumedSessionID, ok := result["sessionId"].(string) - if !ok { - return nil, fmt.Errorf("invalid response: missing sessionId") + var response resumeSessionResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) } - workspacePath, _ := result["workspacePath"].(string) - - session := newSession(resumedSessionID, c.client, workspacePath) + session := newSession(response.SessionID, c.client, response.WorkspacePath) if config != nil { session.registerTools(config.Tools) if config.OnPermissionRequest != nil { @@ -805,7 +596,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, } c.sessionsMux.Lock() - c.sessions[resumedSessionID] = session + c.sessions[response.SessionID] = session c.sessionsMux.Unlock() return session, nil @@ -830,19 +621,13 @@ func (c *Client) ListSessions(ctx context.Context) ([]SessionMetadata, error) { return nil, err } - result, err := c.client.Request("session.list", map[string]any{}) + result, err := c.client.Request("session.list", listSessionsRequest{}) if err != nil { return nil, err } - // Marshal and unmarshal to convert map to struct - jsonBytes, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal sessions response: %w", err) - } - - var response ListSessionsResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response listSessionsResponse + if err := json.Unmarshal(result, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal sessions response: %w", err) } @@ -864,23 +649,13 @@ func (c *Client) DeleteSession(ctx context.Context, sessionID string) error { return err } - params := map[string]any{ - "sessionId": sessionID, - } - - result, err := c.client.Request("session.delete", params) + result, err := c.client.Request("session.delete", deleteSessionRequest{SessionID: sessionID}) if err != nil { return err } - // Marshal and unmarshal to convert map to struct - jsonBytes, err := json.Marshal(result) - if err != nil { - return fmt.Errorf("failed to marshal delete response: %w", err) - } - - var response DeleteSessionResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response deleteSessionResponse + if err := json.Unmarshal(result, &response); err != nil { return fmt.Errorf("failed to unmarshal delete response: %w", err) } @@ -925,18 +700,13 @@ func (c *Client) GetForegroundSessionID(ctx context.Context) (*string, error) { } } - result, err := c.client.Request("session.getForeground", map[string]any{}) + result, err := c.client.Request("session.getForeground", getForegroundSessionRequest{}) if err != nil { return nil, err } - jsonBytes, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal getForeground response: %w", err) - } - - var response GetForegroundSessionResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response getForegroundSessionResponse + if err := json.Unmarshal(result, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal getForeground response: %w", err) } @@ -964,22 +734,13 @@ func (c *Client) SetForegroundSessionID(ctx context.Context, sessionID string) e } } - params := map[string]any{ - "sessionId": sessionID, - } - - result, err := c.client.Request("session.setForeground", params) + result, err := c.client.Request("session.setForeground", setForegroundSessionRequest{SessionID: sessionID}) if err != nil { return err } - jsonBytes, err := json.Marshal(result) - if err != nil { - return fmt.Errorf("failed to marshal setForeground response: %w", err) - } - - var response SetForegroundSessionResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response setForegroundSessionResponse + if err := json.Unmarshal(result, &response); err != nil { return fmt.Errorf("failed to unmarshal setForeground response: %w", err) } @@ -1057,7 +818,7 @@ func (c *Client) OnEventType(eventType SessionLifecycleEventType, handler Sessio } // dispatchLifecycleEvent dispatches a lifecycle event to all registered handlers -func (c *Client) dispatchLifecycleEvent(event SessionLifecycleEvent) { +func (c *Client) handleLifecycleEvent(event SessionLifecycleEvent) { c.lifecycleHandlersMux.Lock() // Copy handlers to avoid holding lock during callbacks typedHandlers := make([]SessionLifecycleHandler, 0) @@ -1111,87 +872,57 @@ func (c *Client) State() ConnectionState { // } else { // log.Printf("Server responded at %d", resp.Timestamp) // } -func (c *Client) Ping(ctx context.Context, message string) (*PingResponse, error) { +func (c *Client) Ping(ctx context.Context, message string) (*pingResponse, error) { if c.client == nil { return nil, fmt.Errorf("client not connected") } - params := map[string]any{} - if message != "" { - params["message"] = message - } - - result, err := c.client.Request("ping", params) + result, err := c.client.Request("ping", pingRequest{Message: message}) if err != nil { return nil, err } - response := &PingResponse{} - if msg, ok := result["message"].(string); ok { - response.Message = msg - } - if ts, ok := result["timestamp"].(float64); ok { - response.Timestamp = int64(ts) - } - if pv, ok := result["protocolVersion"].(float64); ok { - v := int(pv) - response.ProtocolVersion = &v + var response pingResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, err } - - return response, nil + return &response, nil } // GetStatus returns CLI status including version and protocol information -func (c *Client) GetStatus(ctx context.Context) (*GetStatusResponse, error) { +func (c *Client) GetStatus(ctx context.Context) (*getStatusResponse, error) { if c.client == nil { return nil, fmt.Errorf("client not connected") } - result, err := c.client.Request("status.get", map[string]any{}) + result, err := c.client.Request("status.get", getStatusRequest{}) if err != nil { return nil, err } - response := &GetStatusResponse{} - if v, ok := result["version"].(string); ok { - response.Version = v - } - if pv, ok := result["protocolVersion"].(float64); ok { - response.ProtocolVersion = int(pv) + var response getStatusResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, err } - - return response, nil + return &response, nil } // GetAuthStatus returns current authentication status -func (c *Client) GetAuthStatus(ctx context.Context) (*GetAuthStatusResponse, error) { +func (c *Client) GetAuthStatus(ctx context.Context) (*getAuthStatusResponse, error) { if c.client == nil { return nil, fmt.Errorf("client not connected") } - result, err := c.client.Request("auth.getStatus", map[string]any{}) + result, err := c.client.Request("auth.getStatus", getAuthStatusRequest{}) if err != nil { return nil, err } - response := &GetAuthStatusResponse{} - if v, ok := result["isAuthenticated"].(bool); ok { - response.IsAuthenticated = v - } - if v, ok := result["authType"].(string); ok { - response.AuthType = &v - } - if v, ok := result["host"].(string); ok { - response.Host = &v - } - if v, ok := result["login"].(string); ok { - response.Login = &v - } - if v, ok := result["statusMessage"].(string); ok { - response.StatusMessage = &v + var response getAuthStatusResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, err } - - return response, nil + return &response, nil } // ListModels returns available models with their metadata. @@ -1216,19 +947,13 @@ func (c *Client) ListModels(ctx context.Context) ([]ModelInfo, error) { } // Cache miss - fetch from backend while holding lock - result, err := c.client.Request("models.list", map[string]any{}) + result, err := c.client.Request("models.list", listModelsRequest{}) if err != nil { return nil, err } - // Marshal and unmarshal to convert map to struct - jsonBytes, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal models response: %w", err) - } - - var response GetModelsResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response listModelsResponse + if err := json.Unmarshal(result, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal models response: %w", err) } @@ -1422,82 +1147,48 @@ func (c *Client) connectViaTcp(ctx context.Context) error { // setupNotificationHandler configures handlers for session events, tool calls, and permission requests. func (c *Client) setupNotificationHandler() { - c.client.SetNotificationHandler(func(method string, params map[string]any) { - switch method { - case "session.event": - // Extract sessionId and event - sessionID, ok := params["sessionId"].(string) - if !ok { - return - } - - // Marshal the event back to JSON and unmarshal into typed struct - eventJSON, err := json.Marshal(params["event"]) - if err != nil { - return - } - - event, err := UnmarshalSessionEvent(eventJSON) - if err != nil { - return - } - - // Dispatch to session - c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] - c.sessionsMux.Unlock() - - if ok { - session.dispatchEvent(event) - } - case "session.lifecycle": - // Handle session lifecycle events - eventJSON, err := json.Marshal(params) - if err != nil { - return - } - - var event SessionLifecycleEvent - if err := json.Unmarshal(eventJSON, &event); err != nil { - return - } + c.client.SetRequestHandler("session.event", jsonrpc2.NotificationHandlerFor(c.handleSessionEvent)) + c.client.SetRequestHandler("session.lifecycle", jsonrpc2.NotificationHandlerFor(c.handleLifecycleEvent)) + c.client.SetRequestHandler("tool.call", jsonrpc2.RequestHandlerFor(c.handleToolCallRequest)) + c.client.SetRequestHandler("permission.request", jsonrpc2.RequestHandlerFor(c.handlePermissionRequest)) + c.client.SetRequestHandler("userInput.request", jsonrpc2.RequestHandlerFor(c.handleUserInputRequest)) + c.client.SetRequestHandler("hooks.invoke", jsonrpc2.RequestHandlerFor(c.handleHooksInvoke)) +} - c.dispatchLifecycleEvent(event) - } - }) +func (c *Client) handleSessionEvent(req sessionEventRequest) { + if req.SessionID == "" { + return + } + // Dispatch to session + c.sessionsMux.Lock() + session, ok := c.sessions[req.SessionID] + c.sessionsMux.Unlock() - c.client.SetRequestHandler("tool.call", c.handleToolCallRequest) - c.client.SetRequestHandler("permission.request", c.handlePermissionRequest) - c.client.SetRequestHandler("userInput.request", c.handleUserInputRequest) - c.client.SetRequestHandler("hooks.invoke", c.handleHooksInvoke) + if ok { + session.dispatchEvent(req.Event) + } } // handleToolCallRequest handles a tool call request from the CLI server. -func (c *Client) handleToolCallRequest(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - toolCallID, _ := params["toolCallId"].(string) - toolName, _ := params["toolName"].(string) - - if sessionID == "" || toolCallID == "" || toolName == "" { +func (c *Client) handleToolCallRequest(req toolCallRequest) (*toolCallResponse, *jsonrpc2.Error) { + if req.SessionID == "" || req.ToolCallID == "" || req.ToolName == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid tool call payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - handler, ok := session.getToolHandler(toolName) + handler, ok := session.getToolHandler(req.ToolName) if !ok { - return map[string]any{"result": buildUnsupportedToolResult(toolName)}, nil + return &toolCallResponse{Result: buildUnsupportedToolResult(req.ToolName)}, nil } - arguments := params["arguments"] - result := c.executeToolCall(sessionID, toolCallID, toolName, arguments, handler) - - return map[string]any{"result": result}, nil + result := c.executeToolCall(req.SessionID, req.ToolCallID, req.ToolName, req.Arguments, handler) + return &toolCallResponse{Result: result}, nil } // executeToolCall executes a tool handler and returns the result. @@ -1531,100 +1222,70 @@ func (c *Client) executeToolCall( } // handlePermissionRequest handles a permission request from the CLI server. -func (c *Client) handlePermissionRequest(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - permissionRequest, _ := params["permissionRequest"].(map[string]any) - - if sessionID == "" { +func (c *Client) handlePermissionRequest(req permissionRequestRequest) (*permissionRequestResponse, *jsonrpc2.Error) { + if req.SessionID == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid permission request payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - result, err := session.handlePermissionRequest(permissionRequest) + result, err := session.handlePermissionRequest(req.Request) if err != nil { // Return denial on error - return map[string]any{ - "result": map[string]any{ - "kind": "denied-no-approval-rule-and-could-not-request-from-user", + return &permissionRequestResponse{ + Result: PermissionRequestResult{ + Kind: "denied-no-approval-rule-and-could-not-request-from-user", }, }, nil } - return map[string]any{"result": result}, nil + return &permissionRequestResponse{Result: result}, nil } // handleUserInputRequest handles a user input request from the CLI server. -func (c *Client) handleUserInputRequest(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - question, _ := params["question"].(string) - - if sessionID == "" || question == "" { +func (c *Client) handleUserInputRequest(req userInputRequest) (*userInputResponse, *jsonrpc2.Error) { + if req.SessionID == "" || req.Question == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid user input request payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - // Parse choices - var choices []string - if choicesRaw, ok := params["choices"].([]any); ok { - for _, choice := range choicesRaw { - if s, ok := choice.(string); ok { - choices = append(choices, s) - } - } - } - - var allowFreeform *bool - if af, ok := params["allowFreeform"].(bool); ok { - allowFreeform = &af - } - - request := UserInputRequest{ - Question: question, - Choices: choices, - AllowFreeform: allowFreeform, - } - - response, err := session.handleUserInputRequest(request) + response, err := session.handleUserInputRequest(UserInputRequest{ + Question: req.Question, + Choices: req.Choices, + AllowFreeform: req.AllowFreeform, + }) if err != nil { return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} } - return map[string]any{ - "answer": response.Answer, - "wasFreeform": response.WasFreeform, - }, nil + return &userInputResponse{Answer: response.Answer, WasFreeform: response.WasFreeform}, nil } // handleHooksInvoke handles a hooks invocation from the CLI server. -func (c *Client) handleHooksInvoke(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - hookType, _ := params["hookType"].(string) - input, _ := params["input"].(map[string]any) - - if sessionID == "" || hookType == "" { +func (c *Client) handleHooksInvoke(req hooksInvokeRequest) (map[string]any, *jsonrpc2.Error) { + if req.SessionID == "" || req.Type == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid hooks invoke payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - output, err := session.handleHooksInvoke(hookType, input) + output, err := session.handleHooksInvoke(req.Type, req.Input) if err != nil { return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} } diff --git a/go/client_test.go b/go/client_test.go index 185bb4cb..176dad8c 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -25,25 +25,20 @@ func TestClient_HandleToolCallRequest(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - params := map[string]any{ - "sessionId": session.SessionID, - "toolCallId": "123", - "toolName": "missing_tool", - "arguments": map[string]any{}, + params := toolCallRequest{ + SessionID: session.SessionID, + ToolCallID: "123", + ToolName: "missing_tool", + Arguments: map[string]any{}, } response, _ := client.handleToolCallRequest(params) - result, ok := response["result"].(ToolResult) - if !ok { - t.Fatalf("Expected result to be ToolResult, got %T", response["result"]) + if response.Result.ResultType != "failure" { + t.Errorf("Expected resultType to be 'failure', got %q", response.Result.ResultType) } - if result.ResultType != "failure" { - t.Errorf("Expected resultType to be 'failure', got %q", result.ResultType) - } - - if result.Error != "tool 'missing_tool' not supported" { - t.Errorf("Expected error to be \"tool 'missing_tool' not supported\", got %q", result.Error) + if response.Result.Error != "tool 'missing_tool' not supported" { + t.Errorf("Expected error to be \"tool 'missing_tool' not supported\", got %q", response.Result.Error) } }) } diff --git a/go/internal/jsonrpc2/jsonrpc2.go b/go/internal/jsonrpc2/jsonrpc2.go index 8e4a0f6a..1a6e17d1 100644 --- a/go/internal/jsonrpc2/jsonrpc2.go +++ b/go/internal/jsonrpc2/jsonrpc2.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "reflect" "sync" ) @@ -23,43 +24,39 @@ func (e *Error) Error() string { // Request represents a JSON-RPC 2.0 request type Request struct { JSONRPC string `json:"jsonrpc"` - ID json.RawMessage `json:"id"` + ID json.RawMessage `json:"id"` // nil for notifications Method string `json:"method"` - Params map[string]any `json:"params"` + Params json.RawMessage `json:"params"` +} + +func (r *Request) IsCall() bool { + return len(r.ID) > 0 } // Response represents a JSON-RPC 2.0 response type Response struct { JSONRPC string `json:"jsonrpc"` ID json.RawMessage `json:"id,omitempty"` - Result map[string]any `json:"result,omitempty"` + Result json.RawMessage `json:"result,omitempty"` Error *Error `json:"error,omitempty"` } -// Notification represents a JSON-RPC 2.0 notification -type Notification struct { - JSONRPC string `json:"jsonrpc"` - Method string `json:"method"` - Params map[string]any `json:"params"` -} - // NotificationHandler handles incoming notifications -type NotificationHandler func(method string, params map[string]any) +type NotificationHandler func(method string, params json.RawMessage) // RequestHandler handles incoming server requests and returns a result or error -type RequestHandler func(params map[string]any) (map[string]any, *Error) +type RequestHandler func(params json.RawMessage) (json.RawMessage, *Error) // Client is a minimal JSON-RPC 2.0 client for stdio transport type Client struct { - stdin io.WriteCloser - stdout io.ReadCloser - mu sync.Mutex - pendingRequests map[string]chan *Response - notificationHandler NotificationHandler - requestHandlers map[string]RequestHandler - running bool - stopChan chan struct{} - wg sync.WaitGroup + stdin io.WriteCloser + stdout io.ReadCloser + mu sync.Mutex + pendingRequests map[string]chan *Response + requestHandlers map[string]RequestHandler + running bool + stopChan chan struct{} + wg sync.WaitGroup } // NewClient creates a new JSON-RPC client @@ -96,11 +93,55 @@ func (c *Client) Stop() { c.wg.Wait() } -// SetNotificationHandler sets the handler for incoming notifications -func (c *Client) SetNotificationHandler(handler NotificationHandler) { - c.mu.Lock() - defer c.mu.Unlock() - c.notificationHandler = handler +func NotificationHandlerFor[In any](handler func(params In)) RequestHandler { + return func(params json.RawMessage) (json.RawMessage, *Error) { + var in In + // If In is a pointer type, allocate the underlying value and unmarshal into it directly + var target any = &in + if t := reflect.TypeOf(in); t != nil && t.Kind() == reflect.Pointer { + in = reflect.New(t.Elem()).Interface().(In) + target = in + } + if err := json.Unmarshal(params, target); err != nil { + return nil, &Error{ + Code: -32602, + Message: fmt.Sprintf("Invalid params: %v", err), + } + } + handler(in) + return nil, nil + } +} + +// RequestHandlerFor creates a RequestHandler from a typed function +func RequestHandlerFor[In, Out any](handler func(params In) (Out, *Error)) RequestHandler { + return func(params json.RawMessage) (json.RawMessage, *Error) { + var in In + // If In is a pointer type, allocate the underlying value and unmarshal into it directly + var target any = &in + if t := reflect.TypeOf(in); t != nil && t.Kind() == reflect.Pointer { + in = reflect.New(t.Elem()).Interface().(In) + target = in + } + if err := json.Unmarshal(params, target); err != nil { + return nil, &Error{ + Code: -32602, + Message: fmt.Sprintf("Invalid params: %v", err), + } + } + out, errj := handler(in) + if errj != nil { + return nil, errj + } + outData, err := json.Marshal(out) + if err != nil { + return nil, &Error{ + Code: -32603, + Message: fmt.Sprintf("Failed to marshal response: %v", err), + } + } + return outData, nil + } } // SetRequestHandler registers a handler for incoming requests from the server @@ -115,7 +156,7 @@ func (c *Client) SetRequestHandler(method string, handler RequestHandler) { } // Request sends a JSON-RPC request and waits for the response -func (c *Client) Request(method string, params map[string]any) (map[string]any, error) { +func (c *Client) Request(method string, params any) (json.RawMessage, error) { requestID := generateUUID() // Create response channel @@ -131,12 +172,17 @@ func (c *Client) Request(method string, params map[string]any) (map[string]any, c.mu.Unlock() }() + paramsData, err := json.Marshal(params) + if err != nil { + return nil, fmt.Errorf("failed to marshal params: %w", err) + } + // Send request request := Request{ JSONRPC: "2.0", ID: json.RawMessage(`"` + requestID + `"`), Method: method, - Params: params, + Params: json.RawMessage(paramsData), } if err := c.sendMessage(request); err != nil { @@ -156,11 +202,16 @@ func (c *Client) Request(method string, params map[string]any) (map[string]any, } // Notify sends a JSON-RPC notification (no response expected) -func (c *Client) Notify(method string, params map[string]any) error { - notification := Notification{ +func (c *Client) Notify(method string, params any) error { + paramsData, err := json.Marshal(params) + if err != nil { + return fmt.Errorf("failed to marshal params: %w", err) + } + + notification := Request{ JSONRPC: "2.0", Method: method, - Params: params, + Params: json.RawMessage(paramsData), } return c.sendMessage(notification) } @@ -231,7 +282,7 @@ func (c *Client) readLoop() { // Try to parse as request first (has both ID and Method) var request Request - if err := json.Unmarshal(body, &request); err == nil && request.Method != "" && len(request.ID) > 0 { + if err := json.Unmarshal(body, &request); err == nil && request.Method != "" { c.handleRequest(&request) continue } @@ -242,13 +293,6 @@ func (c *Client) readLoop() { c.handleResponse(&response) continue } - - // Try to parse as notification (has Method but no ID) - var notification Notification - if err := json.Unmarshal(body, ¬ification); err == nil && notification.Method != "" { - c.handleNotification(¬ification) - continue - } } } @@ -270,47 +314,41 @@ func (c *Client) handleResponse(response *Response) { } } -// handleNotification dispatches a notification to the handler -func (c *Client) handleNotification(notification *Notification) { - c.mu.Lock() - handler := c.notificationHandler - c.mu.Unlock() - - if handler != nil { - handler(notification.Method, notification.Params) - } -} - func (c *Client) handleRequest(request *Request) { c.mu.Lock() handler := c.requestHandlers[request.Method] c.mu.Unlock() if handler == nil { - c.sendErrorResponse(request.ID, -32601, fmt.Sprintf("Method not found: %s", request.Method), nil) + if request.IsCall() { + c.sendErrorResponse(request.ID, -32601, fmt.Sprintf("Method not found: %s", request.Method), nil) + } return } go func() { defer func() { if r := recover(); r != nil { - c.sendErrorResponse(request.ID, -32603, fmt.Sprintf("request handler panic: %v", r), nil) + if request.IsCall() { + c.sendErrorResponse(request.ID, -32603, fmt.Sprintf("request handler panic: %v", r), nil) + } } }() result, err := handler(request.Params) + if !request.IsCall() { + // Only send a response if this is a call + return + } if err != nil { c.sendErrorResponse(request.ID, err.Code, err.Message, err.Data) return } - if result == nil { - result = make(map[string]any) - } c.sendResponse(request.ID, result) }() } -func (c *Client) sendResponse(id json.RawMessage, result map[string]any) { +func (c *Client) sendResponse(id json.RawMessage, result json.RawMessage) { response := Response{ JSONRPC: "2.0", ID: id, diff --git a/go/session.go b/go/session.go index e4f1473d..5d494710 100644 --- a/go/session.go +++ b/go/session.go @@ -106,29 +106,23 @@ func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) // log.Printf("Failed to send message: %v", err) // } func (s *Session) Send(ctx context.Context, options MessageOptions) (string, error) { - params := map[string]any{ - "sessionId": s.SessionID, - "prompt": options.Prompt, + req := sessionSendRequest{ + SessionID: s.SessionID, + Prompt: options.Prompt, + Attachments: options.Attachments, + Mode: options.Mode, } - if options.Attachments != nil { - params["attachments"] = options.Attachments - } - if options.Mode != "" { - params["mode"] = options.Mode - } - - result, err := s.client.Request("session.send", params) + result, err := s.client.Request("session.send", req) if err != nil { return "", fmt.Errorf("failed to send message: %w", err) } - messageID, ok := result["messageId"].(string) - if !ok { - return "", fmt.Errorf("invalid response: missing messageId") + var response sessionSendResponse + if err := json.Unmarshal(result, &response); err != nil { + return "", fmt.Errorf("failed to unmarshal send response: %w", err) } - - return messageID, nil + return response.MessageID, nil } // SendAndWait sends a message to this session and waits until the session becomes idle. @@ -306,7 +300,7 @@ func (s *Session) getPermissionHandler() PermissionHandler { // handlePermissionRequest handles a permission request from the Copilot CLI. // This is an internal method called by the SDK when the CLI requests permission. -func (s *Session) handlePermissionRequest(requestData map[string]any) (PermissionRequestResult, error) { +func (s *Session) handlePermissionRequest(request PermissionRequest) (PermissionRequestResult, error) { handler := s.getPermissionHandler() if handler == nil { @@ -315,16 +309,6 @@ func (s *Session) handlePermissionRequest(requestData map[string]any) (Permissio }, nil } - // Convert map to PermissionRequest struct - kind, _ := requestData["kind"].(string) - toolCallID, _ := requestData["toolCallId"].(string) - - request := PermissionRequest{ - Kind: kind, - ToolCallID: toolCallID, - Extra: requestData, - } - invocation := PermissionInvocation{ SessionID: s.SessionID, } @@ -388,7 +372,7 @@ func (s *Session) getHooks() *SessionHooks { // handleHooksInvoke handles a hook invocation from the Copilot CLI. // This is an internal method called by the SDK when the CLI invokes a hook. -func (s *Session) handleHooksInvoke(hookType string, input map[string]any) (any, error) { +func (s *Session) handleHooksInvoke(hookType string, rawInput json.RawMessage) (any, error) { hooks := s.getHooks() if hooks == nil { @@ -404,153 +388,66 @@ func (s *Session) handleHooksInvoke(hookType string, input map[string]any) (any, if hooks.OnPreToolUse == nil { return nil, nil } - hookInput := parsePreToolUseInput(input) - return hooks.OnPreToolUse(hookInput, invocation) + var input PreToolUseHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnPreToolUse(input, invocation) case "postToolUse": if hooks.OnPostToolUse == nil { return nil, nil } - hookInput := parsePostToolUseInput(input) - return hooks.OnPostToolUse(hookInput, invocation) + var input PostToolUseHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnPostToolUse(input, invocation) case "userPromptSubmitted": if hooks.OnUserPromptSubmitted == nil { return nil, nil } - hookInput := parseUserPromptSubmittedInput(input) - return hooks.OnUserPromptSubmitted(hookInput, invocation) + var input UserPromptSubmittedHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnUserPromptSubmitted(input, invocation) case "sessionStart": if hooks.OnSessionStart == nil { return nil, nil } - hookInput := parseSessionStartInput(input) - return hooks.OnSessionStart(hookInput, invocation) + var input SessionStartHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnSessionStart(input, invocation) case "sessionEnd": if hooks.OnSessionEnd == nil { return nil, nil } - hookInput := parseSessionEndInput(input) - return hooks.OnSessionEnd(hookInput, invocation) + var input SessionEndHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnSessionEnd(input, invocation) case "errorOccurred": if hooks.OnErrorOccurred == nil { return nil, nil } - hookInput := parseErrorOccurredInput(input) - return hooks.OnErrorOccurred(hookInput, invocation) - + var input ErrorOccurredHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnErrorOccurred(input, invocation) default: return nil, fmt.Errorf("unknown hook type: %s", hookType) } } -// Helper functions to parse hook inputs - -func parsePreToolUseInput(input map[string]any) PreToolUseHookInput { - result := PreToolUseHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if name, ok := input["toolName"].(string); ok { - result.ToolName = name - } - result.ToolArgs = input["toolArgs"] - return result -} - -func parsePostToolUseInput(input map[string]any) PostToolUseHookInput { - result := PostToolUseHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if name, ok := input["toolName"].(string); ok { - result.ToolName = name - } - result.ToolArgs = input["toolArgs"] - result.ToolResult = input["toolResult"] - return result -} - -func parseUserPromptSubmittedInput(input map[string]any) UserPromptSubmittedHookInput { - result := UserPromptSubmittedHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if prompt, ok := input["prompt"].(string); ok { - result.Prompt = prompt - } - return result -} - -func parseSessionStartInput(input map[string]any) SessionStartHookInput { - result := SessionStartHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if source, ok := input["source"].(string); ok { - result.Source = source - } - if prompt, ok := input["initialPrompt"].(string); ok { - result.InitialPrompt = prompt - } - return result -} - -func parseSessionEndInput(input map[string]any) SessionEndHookInput { - result := SessionEndHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if reason, ok := input["reason"].(string); ok { - result.Reason = reason - } - if msg, ok := input["finalMessage"].(string); ok { - result.FinalMessage = msg - } - if errStr, ok := input["error"].(string); ok { - result.Error = errStr - } - return result -} - -func parseErrorOccurredInput(input map[string]any) ErrorOccurredHookInput { - result := ErrorOccurredHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if errMsg, ok := input["error"].(string); ok { - result.Error = errMsg - } - if ctx, ok := input["errorContext"].(string); ok { - result.ErrorContext = ctx - } - if rec, ok := input["recoverable"].(bool); ok { - result.Recoverable = rec - } - return result -} - // dispatchEvent dispatches an event to all registered handlers. // This is an internal method; handlers are called synchronously and any panics // are recovered to prevent crashing the event dispatcher. @@ -596,38 +493,17 @@ func (s *Session) dispatchEvent(event SessionEvent) { // } // } func (s *Session) GetMessages(ctx context.Context) ([]SessionEvent, error) { - params := map[string]any{ - "sessionId": s.SessionID, - } - result, err := s.client.Request("session.getMessages", params) + result, err := s.client.Request("session.getMessages", sessionGetMessagesRequest{SessionID: s.SessionID}) if err != nil { return nil, fmt.Errorf("failed to get messages: %w", err) } - eventsRaw, ok := result["events"].([]any) - if !ok { - return nil, fmt.Errorf("invalid response: missing events") + var response sessionGetMessagesResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal get messages response: %w", err) } - - // Convert to SessionEvent structs - events := make([]SessionEvent, 0, len(eventsRaw)) - for _, eventRaw := range eventsRaw { - // Marshal back to JSON and unmarshal into typed struct - eventJSON, err := json.Marshal(eventRaw) - if err != nil { - continue - } - - event, err := UnmarshalSessionEvent(eventJSON) - if err != nil { - continue - } - - events = append(events, event) - } - - return events, nil + return response.Events, nil } // Destroy destroys this session and releases all associated resources. @@ -645,11 +521,7 @@ func (s *Session) GetMessages(ctx context.Context) ([]SessionEvent, error) { // log.Printf("Failed to destroy session: %v", err) // } func (s *Session) Destroy() error { - params := map[string]any{ - "sessionId": s.SessionID, - } - - _, err := s.client.Request("session.destroy", params) + _, err := s.client.Request("session.destroy", sessionDestroyRequest{SessionID: s.SessionID}) if err != nil { return fmt.Errorf("failed to destroy session: %w", err) } @@ -692,11 +564,11 @@ func (s *Session) Destroy() error { // log.Printf("Failed to abort: %v", err) // } func (s *Session) Abort(ctx context.Context) error { - params := map[string]any{ - "sessionId": s.SessionID, + req := sessionAbortRequest{ + SessionID: s.SessionID, } - _, err := s.client.Request("session.abort", params) + _, err := s.client.Request("session.abort", req) if err != nil { return fmt.Errorf("failed to abort session: %w", err) } diff --git a/go/types.go b/go/types.go index 7a1917f0..9ca57dc7 100644 --- a/go/types.go +++ b/go/types.go @@ -1,5 +1,7 @@ package copilot +import "encoding/json" + // ConnectionState represents the client connection state type ConnectionState string @@ -113,15 +115,15 @@ type PermissionInvocation struct { // UserInputRequest represents a request for user input from the agent type UserInputRequest struct { - Question string `json:"question"` - Choices []string `json:"choices,omitempty"` - AllowFreeform *bool `json:"allowFreeform,omitempty"` + Question string + Choices []string + AllowFreeform *bool } // UserInputResponse represents the user's response to an input request type UserInputResponse struct { - Answer string `json:"answer"` - WasFreeform bool `json:"wasFreeform"` + Answer string + WasFreeform bool } // UserInputHandler handles user input requests from the agent @@ -307,13 +309,13 @@ type CustomAgentConfig struct { // limits through background compaction and persist state to a workspace directory. type InfiniteSessionConfig struct { // Enabled controls whether infinite sessions are enabled (default: true) - Enabled *bool + Enabled *bool `json:"enabled,omitempty"` // BackgroundCompactionThreshold is the context utilization (0.0-1.0) at which // background compaction starts. Default: 0.80 - BackgroundCompactionThreshold *float64 + BackgroundCompactionThreshold *float64 `json:"backgroundCompactionThreshold,omitempty"` // BufferExhaustionThreshold is the context utilization (0.0-1.0) at which // the session blocks until compaction completes. Default: 0.95 - BufferExhaustionThreshold *float64 + BufferExhaustionThreshold *float64 `json:"bufferExhaustionThreshold,omitempty"` } // SessionConfig configures a new session @@ -369,10 +371,10 @@ type SessionConfig struct { // Tool describes a caller-implemented tool that can be invoked by Copilot type Tool struct { - Name string - Description string // optional - Parameters map[string]any - Handler ToolHandler + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters map[string]any `json:"parameters,omitempty"` + Handler ToolHandler `json:"-"` } // ToolInvocation describes a tool call initiated by Copilot @@ -477,43 +479,6 @@ type MessageOptions struct { // SessionEventHandler is a callback for session events type SessionEventHandler func(event SessionEvent) -// PingResponse is the response from a ping request -type PingResponse struct { - Message string `json:"message"` - Timestamp int64 `json:"timestamp"` - ProtocolVersion *int `json:"protocolVersion,omitempty"` -} - -// SessionCreateResponse is the response from session.create -type SessionCreateResponse struct { - SessionID string `json:"sessionId"` -} - -// SessionSendResponse is the response from session.send -type SessionSendResponse struct { - MessageID string `json:"messageId"` -} - -// SessionGetMessagesResponse is the response from session.getMessages -type SessionGetMessagesResponse struct { - Events []SessionEvent `json:"events"` -} - -// GetStatusResponse is the response from status.get -type GetStatusResponse struct { - Version string `json:"version"` - ProtocolVersion int `json:"protocolVersion"` -} - -// GetAuthStatusResponse is the response from auth.getStatus -type GetAuthStatusResponse struct { - IsAuthenticated bool `json:"isAuthenticated"` - AuthType *string `json:"authType,omitempty"` - Host *string `json:"host,omitempty"` - Login *string `json:"login,omitempty"` - StatusMessage *string `json:"statusMessage,omitempty"` -} - // ModelVisionLimits contains vision-specific limits type ModelVisionLimits struct { SupportedMediaTypes []string `json:"supported_media_types"` @@ -562,11 +527,6 @@ type ModelInfo struct { DefaultReasoningEffort string `json:"defaultReasoningEffort,omitempty"` } -// GetModelsResponse is the response from models.list -type GetModelsResponse struct { - Models []ModelInfo `json:"models"` -} - // SessionMetadata contains metadata about a session type SessionMetadata struct { SessionID string `json:"sessionId"` @@ -576,22 +536,6 @@ type SessionMetadata struct { IsRemote bool `json:"isRemote"` } -// ListSessionsResponse is the response from session.list -type ListSessionsResponse struct { - Sessions []SessionMetadata `json:"sessions"` -} - -// DeleteSessionRequest is the request for session.delete -type DeleteSessionRequest struct { - SessionID string `json:"sessionId"` -} - -// DeleteSessionResponse is the response from session.delete -type DeleteSessionResponse struct { - Success bool `json:"success"` - Error *string `json:"error,omitempty"` -} - // SessionLifecycleEventType represents the type of session lifecycle event type SessionLifecycleEventType string @@ -620,19 +564,218 @@ type SessionLifecycleEventMetadata struct { // SessionLifecycleHandler is a callback for session lifecycle events type SessionLifecycleHandler func(event SessionLifecycleEvent) -// GetForegroundSessionResponse is the response from session.getForeground -type GetForegroundSessionResponse struct { +// permissionRequestRequest represents the request data for a permission request +type permissionRequestRequest struct { + SessionID string `json:"sessionId"` + Request PermissionRequest `json:"permissionRequest"` +} + +// permissionRequestResponse represents the response to a permission request +type permissionRequestResponse struct { + Result PermissionRequestResult `json:"result"` +} + +// createSessionRequest is the request for session.create +type createSessionRequest struct { + Model string `json:"model,omitempty"` + SessionID string `json:"sessionId,omitempty"` + ReasoningEffort string `json:"reasoningEffort,omitempty"` + Tools []Tool `json:"tools,omitempty"` + SystemMessage *SystemMessageConfig `json:"systemMessage,omitempty"` + AvailableTools []string `json:"availableTools,omitempty"` + ExcludedTools []string `json:"excludedTools,omitempty"` + Provider *ProviderConfig `json:"provider,omitempty"` + RequestPermission *bool `json:"requestPermission,omitempty"` + RequestUserInput *bool `json:"requestUserInput,omitempty"` + Hooks *bool `json:"hooks,omitempty"` + WorkingDirectory string `json:"workingDirectory,omitempty"` + Streaming *bool `json:"streaming,omitempty"` + MCPServers map[string]MCPServerConfig `json:"mcpServers,omitempty"` + CustomAgents []CustomAgentConfig `json:"customAgents,omitempty"` + ConfigDir string `json:"configDir,omitempty"` + SkillDirectories []string `json:"skillDirectories,omitempty"` + DisabledSkills []string `json:"disabledSkills,omitempty"` + InfiniteSessions *InfiniteSessionConfig `json:"infiniteSessions,omitempty"` +} + +// createSessionResponse is the response from session.create +type createSessionResponse struct { + SessionID string `json:"sessionId"` + WorkspacePath string `json:"workspacePath"` +} + +// resumeSessionRequest is the request for session.resume +type resumeSessionRequest struct { + SessionID string `json:"sessionId"` + ReasoningEffort string `json:"reasoningEffort,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Provider *ProviderConfig `json:"provider,omitempty"` + RequestPermission *bool `json:"requestPermission,omitempty"` + RequestUserInput *bool `json:"requestUserInput,omitempty"` + Hooks *bool `json:"hooks,omitempty"` + WorkingDirectory string `json:"workingDirectory,omitempty"` + DisableResume *bool `json:"disableResume,omitempty"` + Streaming *bool `json:"streaming,omitempty"` + MCPServers map[string]MCPServerConfig `json:"mcpServers,omitempty"` + CustomAgents []CustomAgentConfig `json:"customAgents,omitempty"` + SkillDirectories []string `json:"skillDirectories,omitempty"` + DisabledSkills []string `json:"disabledSkills,omitempty"` +} + +// resumeSessionResponse is the response from session.resume +type resumeSessionResponse struct { + SessionID string `json:"sessionId"` + WorkspacePath string `json:"workspacePath"` +} + +type hooksInvokeRequest struct { + SessionID string `json:"sessionId"` + Type string `json:"hookType"` + Input json.RawMessage `json:"input"` +} + +// listSessionsRequest is the request for session.list +type listSessionsRequest struct{} + +// listSessionsResponse is the response from session.list +type listSessionsResponse struct { + Sessions []SessionMetadata `json:"sessions"` +} + +// deleteSessionRequest is the request for session.delete +type deleteSessionRequest struct { + SessionID string `json:"sessionId"` +} + +// deleteSessionResponse is the response from session.delete +type deleteSessionResponse struct { + Success bool `json:"success"` + Error *string `json:"error,omitempty"` +} + +// getForegroundSessionRequest is the request for session.getForeground +type getForegroundSessionRequest struct{} + +// getForegroundSessionResponse is the response from session.getForeground +type getForegroundSessionResponse struct { SessionID *string `json:"sessionId,omitempty"` WorkspacePath *string `json:"workspacePath,omitempty"` } -// SetForegroundSessionRequest is the request for session.setForeground -type SetForegroundSessionRequest struct { +// setForegroundSessionRequest is the request for session.setForeground +type setForegroundSessionRequest struct { SessionID string `json:"sessionId"` } -// SetForegroundSessionResponse is the response from session.setForeground -type SetForegroundSessionResponse struct { +// setForegroundSessionResponse is the response from session.setForeground +type setForegroundSessionResponse struct { Success bool `json:"success"` Error *string `json:"error,omitempty"` } + +type pingRequest struct { + Message string `json:"message,omitempty"` +} + +// pingResponse is the response from a ping request +type pingResponse struct { + Message string `json:"message"` + Timestamp int64 `json:"timestamp"` + ProtocolVersion *int `json:"protocolVersion,omitempty"` +} + +// getStatusRequest is the request for status.get +type getStatusRequest struct{} + +// getStatusResponse is the response from status.get +type getStatusResponse struct { + Version string `json:"version"` + ProtocolVersion int `json:"protocolVersion"` +} + +// getAuthStatusRequest is the request for auth.getStatus +type getAuthStatusRequest struct{} + +// getAuthStatusResponse is the response from auth.getStatus +type getAuthStatusResponse struct { + IsAuthenticated bool `json:"isAuthenticated"` + AuthType *string `json:"authType,omitempty"` + Host *string `json:"host,omitempty"` + Login *string `json:"login,omitempty"` + StatusMessage *string `json:"statusMessage,omitempty"` +} + +// listModelsRequest is the request for models.list +type listModelsRequest struct{} + +// listModelsResponse is the response from models.list +type listModelsResponse struct { + Models []ModelInfo `json:"models"` +} + +// sessionGetMessagesRequest is the request for session.getMessages +type sessionGetMessagesRequest struct { + SessionID string `json:"sessionId"` +} + +// sessionGetMessagesResponse is the response from session.getMessages +type sessionGetMessagesResponse struct { + Events []SessionEvent `json:"events"` +} + +// sessionDestroyRequest is the request for session.destroy +type sessionDestroyRequest struct { + SessionID string `json:"sessionId"` +} + +// sessionAbortRequest is the request for session.abort +type sessionAbortRequest struct { + SessionID string `json:"sessionId"` +} + +type sessionSendRequest struct { + SessionID string `json:"sessionId"` + Prompt string `json:"prompt"` + Attachments []Attachment `json:"attachments,omitempty"` + Mode string `json:"mode,omitempty"` +} + +// sessionSendResponse is the response from session.send +type sessionSendResponse struct { + MessageID string `json:"messageId"` +} + +// sessionEventRequest is the request for session event notifications +type sessionEventRequest struct { + SessionID string `json:"sessionId"` + Event SessionEvent `json:"event"` +} + +// toolCallRequest represents a tool call request from the server +// to the client for execution. +type toolCallRequest struct { + SessionID string `json:"sessionId"` + ToolCallID string `json:"toolCallId"` + ToolName string `json:"toolName"` + Arguments any `json:"arguments"` +} + +// toolCallResponse represents the response to a tool call request +// from the client back to the server. +type toolCallResponse struct { + Result ToolResult `json:"result"` +} + +// userInputRequest represents a request for user input from the agent +type userInputRequest struct { + SessionID string `json:"sessionId"` + Question string `json:"question"` + Choices []string `json:"choices,omitempty"` + AllowFreeform *bool `json:"allowFreeform,omitempty"` +} + +// userInputResponse represents the user's response to an input request +type userInputResponse struct { + Answer string `json:"answer"` + WasFreeform bool `json:"wasFreeform"` +} From 94c5f50545b1a5729acddc814d64cebce0893338 Mon Sep 17 00:00:00 2001 From: Quim Muntal Date: Wed, 4 Feb 2026 16:51:24 +0100 Subject: [PATCH 2/5] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- go/client.go | 4 ++-- go/internal/jsonrpc2/jsonrpc2.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/client.go b/go/client.go index a6d3e950..53a38b23 100644 --- a/go/client.go +++ b/go/client.go @@ -546,7 +546,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, if config.OnPermissionRequest != nil { req.RequestPermission = Bool(true) } - if config.OnPermissionRequest != nil { + if config.OnUserInputRequest != nil { req.RequestUserInput = Bool(true) } if config.Hooks != nil && (config.Hooks.OnPreToolUse != nil || @@ -817,7 +817,7 @@ func (c *Client) OnEventType(eventType SessionLifecycleEventType, handler Sessio } } -// dispatchLifecycleEvent dispatches a lifecycle event to all registered handlers +// handleLifecycleEvent dispatches a lifecycle event to all registered handlers func (c *Client) handleLifecycleEvent(event SessionLifecycleEvent) { c.lifecycleHandlersMux.Lock() // Copy handlers to avoid holding lock during callbacks diff --git a/go/internal/jsonrpc2/jsonrpc2.go b/go/internal/jsonrpc2/jsonrpc2.go index 1a6e17d1..a226f11f 100644 --- a/go/internal/jsonrpc2/jsonrpc2.go +++ b/go/internal/jsonrpc2/jsonrpc2.go @@ -98,7 +98,7 @@ func NotificationHandlerFor[In any](handler func(params In)) RequestHandler { var in In // If In is a pointer type, allocate the underlying value and unmarshal into it directly var target any = &in - if t := reflect.TypeOf(in); t != nil && t.Kind() == reflect.Pointer { + if t := reflect.TypeFor[In](); t.Kind() == reflect.Pointer { in = reflect.New(t.Elem()).Interface().(In) target = in } From d00a6147e472b2aaf8eefa12b30afebfbcd60a0f Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 4 Feb 2026 16:54:02 +0100 Subject: [PATCH 3/5] reexport some types --- go/client.go | 8 ++++---- go/types.go | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/go/client.go b/go/client.go index 53a38b23..b89f2f36 100644 --- a/go/client.go +++ b/go/client.go @@ -890,7 +890,7 @@ func (c *Client) Ping(ctx context.Context, message string) (*pingResponse, error } // GetStatus returns CLI status including version and protocol information -func (c *Client) GetStatus(ctx context.Context) (*getStatusResponse, error) { +func (c *Client) GetStatus(ctx context.Context) (*GetStatusResponse, error) { if c.client == nil { return nil, fmt.Errorf("client not connected") } @@ -900,7 +900,7 @@ func (c *Client) GetStatus(ctx context.Context) (*getStatusResponse, error) { return nil, err } - var response getStatusResponse + var response GetStatusResponse if err := json.Unmarshal(result, &response); err != nil { return nil, err } @@ -908,7 +908,7 @@ func (c *Client) GetStatus(ctx context.Context) (*getStatusResponse, error) { } // GetAuthStatus returns current authentication status -func (c *Client) GetAuthStatus(ctx context.Context) (*getAuthStatusResponse, error) { +func (c *Client) GetAuthStatus(ctx context.Context) (*GetAuthStatusResponse, error) { if c.client == nil { return nil, fmt.Errorf("client not connected") } @@ -918,7 +918,7 @@ func (c *Client) GetAuthStatus(ctx context.Context) (*getAuthStatusResponse, err return nil, err } - var response getAuthStatusResponse + var response GetAuthStatusResponse if err := json.Unmarshal(result, &response); err != nil { return nil, err } diff --git a/go/types.go b/go/types.go index 9ca57dc7..1dee8a43 100644 --- a/go/types.go +++ b/go/types.go @@ -687,8 +687,8 @@ type pingResponse struct { // getStatusRequest is the request for status.get type getStatusRequest struct{} -// getStatusResponse is the response from status.get -type getStatusResponse struct { +// GetStatusResponse is the response from status.get +type GetStatusResponse struct { Version string `json:"version"` ProtocolVersion int `json:"protocolVersion"` } @@ -696,8 +696,8 @@ type getStatusResponse struct { // getAuthStatusRequest is the request for auth.getStatus type getAuthStatusRequest struct{} -// getAuthStatusResponse is the response from auth.getStatus -type getAuthStatusResponse struct { +// GetAuthStatusResponse is the response from auth.getStatus +type GetAuthStatusResponse struct { IsAuthenticated bool `json:"isAuthenticated"` AuthType *string `json:"authType,omitempty"` Host *string `json:"host,omitempty"` From 107d5455b88d985681abbb2e587cf93c08f4c7c8 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 4 Feb 2026 17:24:46 +0100 Subject: [PATCH 4/5] fix race --- go/internal/e2e/mcp_and_agents_test.go | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/go/internal/e2e/mcp_and_agents_test.go b/go/internal/e2e/mcp_and_agents_test.go index 1d21651b..244589a1 100644 --- a/go/internal/e2e/mcp_and_agents_test.go +++ b/go/internal/e2e/mcp_and_agents_test.go @@ -37,18 +37,13 @@ func TestMCPServers(t *testing.T) { } // Simple interaction to verify session works - _, err = session.Send(t.Context(), copilot.MessageOptions{ + message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "What is 2+2?", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - message, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get final message: %v", err) - } - if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "4") { t.Errorf("Expected message to contain '4', got: %v", message.Data.Content) } @@ -168,18 +163,13 @@ func TestCustomAgents(t *testing.T) { } // Simple interaction to verify session works - _, err = session.Send(t.Context(), copilot.MessageOptions{ + message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "What is 5+5?", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - message, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get final message: %v", err) - } - if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "10") { t.Errorf("Expected message to contain '10', got: %v", message.Data.Content) } @@ -373,18 +363,13 @@ func TestCombinedConfiguration(t *testing.T) { t.Error("Expected non-empty session ID") } - _, err = session.Send(t.Context(), copilot.MessageOptions{ + message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "What is 7+7?", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - message, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get final message: %v", err) - } - if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "14") { t.Errorf("Expected message to contain '14', got: %v", message.Data.Content) } From 55759d750b628bff1c5f657324f674fefd279444 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Wed, 4 Feb 2026 17:40:12 +0100 Subject: [PATCH 5/5] fix test --- go/internal/e2e/session_test.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/go/internal/e2e/session_test.go b/go/internal/e2e/session_test.go index 62183286..5d225b35 100644 --- a/go/internal/e2e/session_test.go +++ b/go/internal/e2e/session_test.go @@ -68,6 +68,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") + } + if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "2") { t.Errorf("Expected assistant message to contain '2', got %v", assistantMessage.Data.Content) } @@ -77,6 +81,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send second message: %v", err) } + if secondMessage == nil { + t.Fatal("Expected second assistant message, got nil") + } + if secondMessage.Data.Content == nil || !strings.Contains(*secondMessage.Data.Content, "4") { t.Errorf("Expected second message to contain '4', got %v", secondMessage.Data.Content) }