diff --git a/go/client.go b/go/client.go index d45d3447..6692e93d 100644 --- a/go/client.go +++ b/go/client.go @@ -396,36 +396,6 @@ func (c *Client) ForceStop() { } } -// buildProviderParams converts a ProviderConfig to a map for JSON-RPC params. -func buildProviderParams(p *ProviderConfig) map[string]any { - params := make(map[string]any) - if p.Type != "" { - params["type"] = p.Type - } - if p.WireApi != "" { - params["wireApi"] = p.WireApi - } - if p.BaseURL != "" { - params["baseUrl"] = p.BaseURL - } - if p.APIKey != "" { - params["apiKey"] = p.APIKey - } - if p.BearerToken != "" { - params["bearerToken"] = p.BearerToken - } - if p.Azure != nil { - azure := make(map[string]any) - if p.Azure.APIVersion != "" { - azure["apiVersion"] = p.Azure.APIVersion - } - if len(azure) > 0 { - params["azure"] = azure - } - } - return params -} - func (c *Client) ensureConnected() error { if c.client != nil { return nil @@ -467,166 +437,54 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses return nil, err } - params := make(map[string]any) + req := createSessionRequest{} if config != nil { - if config.Model != "" { - params["model"] = config.Model - } - if config.SessionID != "" { - params["sessionId"] = config.SessionID - } - if config.ReasoningEffort != "" { - params["reasoningEffort"] = config.ReasoningEffort - } - if len(config.Tools) > 0 { - toolDefs := make([]map[string]any, 0, len(config.Tools)) - for _, tool := range config.Tools { - if tool.Name == "" { - continue - } - definition := map[string]any{ - "name": tool.Name, - "description": tool.Description, - } - if tool.Parameters != nil { - definition["parameters"] = tool.Parameters - } - toolDefs = append(toolDefs, definition) - } - if len(toolDefs) > 0 { - params["tools"] = toolDefs - } - } - // Add system message configuration if provided - if config.SystemMessage != nil { - systemMessage := make(map[string]any) - - if config.SystemMessage.Mode != "" { - systemMessage["mode"] = config.SystemMessage.Mode - } + req.Model = config.Model + req.SessionID = config.SessionID + req.ReasoningEffort = config.ReasoningEffort + req.ConfigDir = config.ConfigDir + req.Tools = config.Tools + req.SystemMessage = config.SystemMessage + req.AvailableTools = config.AvailableTools + req.ExcludedTools = config.ExcludedTools + req.Provider = config.Provider + req.WorkingDirectory = config.WorkingDirectory + req.MCPServers = config.MCPServers + req.CustomAgents = config.CustomAgents + req.SkillDirectories = config.SkillDirectories + req.DisabledSkills = config.DisabledSkills + req.InfiniteSessions = config.InfiniteSessions - if config.SystemMessage.Mode == "replace" { - if config.SystemMessage.Content != "" { - systemMessage["content"] = config.SystemMessage.Content - } - } else { - if config.SystemMessage.Content != "" { - systemMessage["content"] = config.SystemMessage.Content - } - } - - if len(systemMessage) > 0 { - params["systemMessage"] = systemMessage - } - } - // Add tool filtering options - if len(config.AvailableTools) > 0 { - params["availableTools"] = config.AvailableTools - } - if len(config.ExcludedTools) > 0 { - params["excludedTools"] = config.ExcludedTools - } - // Add streaming option if config.Streaming { - params["streaming"] = config.Streaming + req.Streaming = Bool(true) } - // Add provider configuration - if config.Provider != nil { - params["provider"] = buildProviderParams(config.Provider) - } - // Add permission request flag if config.OnPermissionRequest != nil { - params["requestPermission"] = true + req.RequestPermission = Bool(true) } - // Add user input request flag if config.OnUserInputRequest != nil { - params["requestUserInput"] = true + req.RequestUserInput = Bool(true) } - // Add hooks flag if config.Hooks != nil && (config.Hooks.OnPreToolUse != nil || config.Hooks.OnPostToolUse != nil || config.Hooks.OnUserPromptSubmitted != nil || config.Hooks.OnSessionStart != nil || config.Hooks.OnSessionEnd != nil || config.Hooks.OnErrorOccurred != nil) { - params["hooks"] = true - } - // Add working directory - if config.WorkingDirectory != "" { - params["workingDirectory"] = config.WorkingDirectory - } - // Add MCP servers configuration - if len(config.MCPServers) > 0 { - params["mcpServers"] = config.MCPServers - } - // Add custom agents configuration - if len(config.CustomAgents) > 0 { - customAgents := make([]map[string]any, 0, len(config.CustomAgents)) - for _, agent := range config.CustomAgents { - agentMap := map[string]any{ - "name": agent.Name, - "prompt": agent.Prompt, - } - if agent.DisplayName != "" { - agentMap["displayName"] = agent.DisplayName - } - if agent.Description != "" { - agentMap["description"] = agent.Description - } - if len(agent.Tools) > 0 { - agentMap["tools"] = agent.Tools - } - if len(agent.MCPServers) > 0 { - agentMap["mcpServers"] = agent.MCPServers - } - if agent.Infer != nil { - agentMap["infer"] = *agent.Infer - } - customAgents = append(customAgents, agentMap) - } - params["customAgents"] = customAgents - } - // Add config directory override - if config.ConfigDir != "" { - params["configDir"] = config.ConfigDir - } - // Add skill directories configuration - if len(config.SkillDirectories) > 0 { - params["skillDirectories"] = config.SkillDirectories - } - // Add disabled skills configuration - if len(config.DisabledSkills) > 0 { - params["disabledSkills"] = config.DisabledSkills - } - // Add infinite sessions configuration - if config.InfiniteSessions != nil { - infiniteSessions := make(map[string]any) - if config.InfiniteSessions.Enabled != nil { - infiniteSessions["enabled"] = *config.InfiniteSessions.Enabled - } - if config.InfiniteSessions.BackgroundCompactionThreshold != nil { - infiniteSessions["backgroundCompactionThreshold"] = *config.InfiniteSessions.BackgroundCompactionThreshold - } - if config.InfiniteSessions.BufferExhaustionThreshold != nil { - infiniteSessions["bufferExhaustionThreshold"] = *config.InfiniteSessions.BufferExhaustionThreshold - } - params["infiniteSessions"] = infiniteSessions + req.Hooks = Bool(true) } } - result, err := c.client.Request("session.create", params) + result, err := c.client.Request("session.create", req) if err != nil { return nil, fmt.Errorf("failed to create session: %w", err) } - sessionID, ok := result["sessionId"].(string) - if !ok { - return nil, fmt.Errorf("invalid response: missing sessionId") + var response createSessionResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) } - workspacePath, _ := result["workspacePath"].(string) - - session := newSession(sessionID, c.client, workspacePath) + session := newSession(response.SessionID, c.client, response.WorkspacePath) if config != nil { session.registerTools(config.Tools) @@ -644,7 +502,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses } c.sessionsMux.Lock() - c.sessions[sessionID] = session + c.sessions[response.SessionID] = session c.sessionsMux.Unlock() return session, nil @@ -676,119 +534,52 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, return nil, err } - params := map[string]any{ - "sessionId": sessionID, - } - + var req resumeSessionRequest + req.SessionID = sessionID if config != nil { - if config.ReasoningEffort != "" { - params["reasoningEffort"] = config.ReasoningEffort - } - if len(config.Tools) > 0 { - toolDefs := make([]map[string]any, 0, len(config.Tools)) - for _, tool := range config.Tools { - if tool.Name == "" { - continue - } - definition := map[string]any{ - "name": tool.Name, - "description": tool.Description, - } - if tool.Parameters != nil { - definition["parameters"] = tool.Parameters - } - toolDefs = append(toolDefs, definition) - } - if len(toolDefs) > 0 { - params["tools"] = toolDefs - } - } - if config.Provider != nil { - params["provider"] = buildProviderParams(config.Provider) - } - // Add streaming option + req.ReasoningEffort = config.ReasoningEffort + req.Tools = config.Tools + req.Provider = config.Provider if config.Streaming { - params["streaming"] = config.Streaming + req.Streaming = Bool(true) } - // Add permission request flag if config.OnPermissionRequest != nil { - params["requestPermission"] = true + req.RequestPermission = Bool(true) } - // Add user input request flag if config.OnUserInputRequest != nil { - params["requestUserInput"] = true + req.RequestUserInput = Bool(true) } - // Add hooks flag if config.Hooks != nil && (config.Hooks.OnPreToolUse != nil || config.Hooks.OnPostToolUse != nil || config.Hooks.OnUserPromptSubmitted != nil || config.Hooks.OnSessionStart != nil || config.Hooks.OnSessionEnd != nil || config.Hooks.OnErrorOccurred != nil) { - params["hooks"] = true + req.Hooks = Bool(true) } - // Add working directory - if config.WorkingDirectory != "" { - params["workingDirectory"] = config.WorkingDirectory - } - // Add disable resume flag + req.WorkingDirectory = config.WorkingDirectory if config.DisableResume { - params["disableResume"] = true + req.DisableResume = Bool(true) } - // Add MCP servers configuration if len(config.MCPServers) > 0 { - params["mcpServers"] = config.MCPServers - } - // Add custom agents configuration - if len(config.CustomAgents) > 0 { - customAgents := make([]map[string]any, 0, len(config.CustomAgents)) - for _, agent := range config.CustomAgents { - agentMap := map[string]any{ - "name": agent.Name, - "prompt": agent.Prompt, - } - if agent.DisplayName != "" { - agentMap["displayName"] = agent.DisplayName - } - if agent.Description != "" { - agentMap["description"] = agent.Description - } - if len(agent.Tools) > 0 { - agentMap["tools"] = agent.Tools - } - if len(agent.MCPServers) > 0 { - agentMap["mcpServers"] = agent.MCPServers - } - if agent.Infer != nil { - agentMap["infer"] = *agent.Infer - } - customAgents = append(customAgents, agentMap) - } - params["customAgents"] = customAgents - } - // Add skill directories configuration - if len(config.SkillDirectories) > 0 { - params["skillDirectories"] = config.SkillDirectories - } - // Add disabled skills configuration - if len(config.DisabledSkills) > 0 { - params["disabledSkills"] = config.DisabledSkills + req.MCPServers = config.MCPServers } + req.CustomAgents = config.CustomAgents + req.SkillDirectories = config.SkillDirectories + req.DisabledSkills = config.DisabledSkills } - result, err := c.client.Request("session.resume", params) + result, err := c.client.Request("session.resume", req) if err != nil { return nil, fmt.Errorf("failed to resume session: %w", err) } - resumedSessionID, ok := result["sessionId"].(string) - if !ok { - return nil, fmt.Errorf("invalid response: missing sessionId") + var response resumeSessionResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) } - workspacePath, _ := result["workspacePath"].(string) - - session := newSession(resumedSessionID, c.client, workspacePath) + session := newSession(response.SessionID, c.client, response.WorkspacePath) if config != nil { session.registerTools(config.Tools) if config.OnPermissionRequest != nil { @@ -805,7 +596,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, } c.sessionsMux.Lock() - c.sessions[resumedSessionID] = session + c.sessions[response.SessionID] = session c.sessionsMux.Unlock() return session, nil @@ -830,19 +621,13 @@ func (c *Client) ListSessions(ctx context.Context) ([]SessionMetadata, error) { return nil, err } - result, err := c.client.Request("session.list", map[string]any{}) + result, err := c.client.Request("session.list", listSessionsRequest{}) if err != nil { return nil, err } - // Marshal and unmarshal to convert map to struct - jsonBytes, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal sessions response: %w", err) - } - - var response ListSessionsResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response listSessionsResponse + if err := json.Unmarshal(result, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal sessions response: %w", err) } @@ -864,23 +649,13 @@ func (c *Client) DeleteSession(ctx context.Context, sessionID string) error { return err } - params := map[string]any{ - "sessionId": sessionID, - } - - result, err := c.client.Request("session.delete", params) + result, err := c.client.Request("session.delete", deleteSessionRequest{SessionID: sessionID}) if err != nil { return err } - // Marshal and unmarshal to convert map to struct - jsonBytes, err := json.Marshal(result) - if err != nil { - return fmt.Errorf("failed to marshal delete response: %w", err) - } - - var response DeleteSessionResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response deleteSessionResponse + if err := json.Unmarshal(result, &response); err != nil { return fmt.Errorf("failed to unmarshal delete response: %w", err) } @@ -925,18 +700,13 @@ func (c *Client) GetForegroundSessionID(ctx context.Context) (*string, error) { } } - result, err := c.client.Request("session.getForeground", map[string]any{}) + result, err := c.client.Request("session.getForeground", getForegroundSessionRequest{}) if err != nil { return nil, err } - jsonBytes, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal getForeground response: %w", err) - } - - var response GetForegroundSessionResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response getForegroundSessionResponse + if err := json.Unmarshal(result, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal getForeground response: %w", err) } @@ -964,22 +734,13 @@ func (c *Client) SetForegroundSessionID(ctx context.Context, sessionID string) e } } - params := map[string]any{ - "sessionId": sessionID, - } - - result, err := c.client.Request("session.setForeground", params) + result, err := c.client.Request("session.setForeground", setForegroundSessionRequest{SessionID: sessionID}) if err != nil { return err } - jsonBytes, err := json.Marshal(result) - if err != nil { - return fmt.Errorf("failed to marshal setForeground response: %w", err) - } - - var response SetForegroundSessionResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response setForegroundSessionResponse + if err := json.Unmarshal(result, &response); err != nil { return fmt.Errorf("failed to unmarshal setForeground response: %w", err) } @@ -1056,8 +817,8 @@ func (c *Client) OnEventType(eventType SessionLifecycleEventType, handler Sessio } } -// dispatchLifecycleEvent dispatches a lifecycle event to all registered handlers -func (c *Client) dispatchLifecycleEvent(event SessionLifecycleEvent) { +// handleLifecycleEvent dispatches a lifecycle event to all registered handlers +func (c *Client) handleLifecycleEvent(event SessionLifecycleEvent) { c.lifecycleHandlersMux.Lock() // Copy handlers to avoid holding lock during callbacks typedHandlers := make([]SessionLifecycleHandler, 0) @@ -1116,29 +877,16 @@ func (c *Client) Ping(ctx context.Context, message string) (*PingResponse, error return nil, fmt.Errorf("client not connected") } - params := map[string]any{} - if message != "" { - params["message"] = message - } - - result, err := c.client.Request("ping", params) + result, err := c.client.Request("ping", pingRequest{Message: message}) if err != nil { return nil, err } - response := &PingResponse{} - if msg, ok := result["message"].(string); ok { - response.Message = msg - } - if ts, ok := result["timestamp"].(float64); ok { - response.Timestamp = int64(ts) - } - if pv, ok := result["protocolVersion"].(float64); ok { - v := int(pv) - response.ProtocolVersion = &v + var response PingResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, err } - - return response, nil + return &response, nil } // GetStatus returns CLI status including version and protocol information @@ -1147,20 +895,16 @@ func (c *Client) GetStatus(ctx context.Context) (*GetStatusResponse, error) { return nil, fmt.Errorf("client not connected") } - result, err := c.client.Request("status.get", map[string]any{}) + result, err := c.client.Request("status.get", getStatusRequest{}) if err != nil { return nil, err } - response := &GetStatusResponse{} - if v, ok := result["version"].(string); ok { - response.Version = v - } - if pv, ok := result["protocolVersion"].(float64); ok { - response.ProtocolVersion = int(pv) + var response GetStatusResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, err } - - return response, nil + return &response, nil } // GetAuthStatus returns current authentication status @@ -1169,29 +913,16 @@ func (c *Client) GetAuthStatus(ctx context.Context) (*GetAuthStatusResponse, err return nil, fmt.Errorf("client not connected") } - result, err := c.client.Request("auth.getStatus", map[string]any{}) + result, err := c.client.Request("auth.getStatus", getAuthStatusRequest{}) if err != nil { return nil, err } - response := &GetAuthStatusResponse{} - if v, ok := result["isAuthenticated"].(bool); ok { - response.IsAuthenticated = v - } - if v, ok := result["authType"].(string); ok { - response.AuthType = &v - } - if v, ok := result["host"].(string); ok { - response.Host = &v - } - if v, ok := result["login"].(string); ok { - response.Login = &v - } - if v, ok := result["statusMessage"].(string); ok { - response.StatusMessage = &v + var response GetAuthStatusResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, err } - - return response, nil + return &response, nil } // ListModels returns available models with their metadata. @@ -1216,19 +947,13 @@ func (c *Client) ListModels(ctx context.Context) ([]ModelInfo, error) { } // Cache miss - fetch from backend while holding lock - result, err := c.client.Request("models.list", map[string]any{}) + result, err := c.client.Request("models.list", listModelsRequest{}) if err != nil { return nil, err } - // Marshal and unmarshal to convert map to struct - jsonBytes, err := json.Marshal(result) - if err != nil { - return nil, fmt.Errorf("failed to marshal models response: %w", err) - } - - var response GetModelsResponse - if err := json.Unmarshal(jsonBytes, &response); err != nil { + var response listModelsResponse + if err := json.Unmarshal(result, &response); err != nil { return nil, fmt.Errorf("failed to unmarshal models response: %w", err) } @@ -1422,82 +1147,48 @@ func (c *Client) connectViaTcp(ctx context.Context) error { // setupNotificationHandler configures handlers for session events, tool calls, and permission requests. func (c *Client) setupNotificationHandler() { - c.client.SetNotificationHandler(func(method string, params map[string]any) { - switch method { - case "session.event": - // Extract sessionId and event - sessionID, ok := params["sessionId"].(string) - if !ok { - return - } - - // Marshal the event back to JSON and unmarshal into typed struct - eventJSON, err := json.Marshal(params["event"]) - if err != nil { - return - } - - event, err := UnmarshalSessionEvent(eventJSON) - if err != nil { - return - } - - // Dispatch to session - c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] - c.sessionsMux.Unlock() - - if ok { - session.dispatchEvent(event) - } - case "session.lifecycle": - // Handle session lifecycle events - eventJSON, err := json.Marshal(params) - if err != nil { - return - } - - var event SessionLifecycleEvent - if err := json.Unmarshal(eventJSON, &event); err != nil { - return - } + c.client.SetRequestHandler("session.event", jsonrpc2.NotificationHandlerFor(c.handleSessionEvent)) + c.client.SetRequestHandler("session.lifecycle", jsonrpc2.NotificationHandlerFor(c.handleLifecycleEvent)) + c.client.SetRequestHandler("tool.call", jsonrpc2.RequestHandlerFor(c.handleToolCallRequest)) + c.client.SetRequestHandler("permission.request", jsonrpc2.RequestHandlerFor(c.handlePermissionRequest)) + c.client.SetRequestHandler("userInput.request", jsonrpc2.RequestHandlerFor(c.handleUserInputRequest)) + c.client.SetRequestHandler("hooks.invoke", jsonrpc2.RequestHandlerFor(c.handleHooksInvoke)) +} - c.dispatchLifecycleEvent(event) - } - }) +func (c *Client) handleSessionEvent(req sessionEventRequest) { + if req.SessionID == "" { + return + } + // Dispatch to session + c.sessionsMux.Lock() + session, ok := c.sessions[req.SessionID] + c.sessionsMux.Unlock() - c.client.SetRequestHandler("tool.call", c.handleToolCallRequest) - c.client.SetRequestHandler("permission.request", c.handlePermissionRequest) - c.client.SetRequestHandler("userInput.request", c.handleUserInputRequest) - c.client.SetRequestHandler("hooks.invoke", c.handleHooksInvoke) + if ok { + session.dispatchEvent(req.Event) + } } // handleToolCallRequest handles a tool call request from the CLI server. -func (c *Client) handleToolCallRequest(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - toolCallID, _ := params["toolCallId"].(string) - toolName, _ := params["toolName"].(string) - - if sessionID == "" || toolCallID == "" || toolName == "" { +func (c *Client) handleToolCallRequest(req toolCallRequest) (*toolCallResponse, *jsonrpc2.Error) { + if req.SessionID == "" || req.ToolCallID == "" || req.ToolName == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid tool call payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - handler, ok := session.getToolHandler(toolName) + handler, ok := session.getToolHandler(req.ToolName) if !ok { - return map[string]any{"result": buildUnsupportedToolResult(toolName)}, nil + return &toolCallResponse{Result: buildUnsupportedToolResult(req.ToolName)}, nil } - arguments := params["arguments"] - result := c.executeToolCall(sessionID, toolCallID, toolName, arguments, handler) - - return map[string]any{"result": result}, nil + result := c.executeToolCall(req.SessionID, req.ToolCallID, req.ToolName, req.Arguments, handler) + return &toolCallResponse{Result: result}, nil } // executeToolCall executes a tool handler and returns the result. @@ -1531,100 +1222,70 @@ func (c *Client) executeToolCall( } // handlePermissionRequest handles a permission request from the CLI server. -func (c *Client) handlePermissionRequest(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - permissionRequest, _ := params["permissionRequest"].(map[string]any) - - if sessionID == "" { +func (c *Client) handlePermissionRequest(req permissionRequestRequest) (*permissionRequestResponse, *jsonrpc2.Error) { + if req.SessionID == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid permission request payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - result, err := session.handlePermissionRequest(permissionRequest) + result, err := session.handlePermissionRequest(req.Request) if err != nil { // Return denial on error - return map[string]any{ - "result": map[string]any{ - "kind": "denied-no-approval-rule-and-could-not-request-from-user", + return &permissionRequestResponse{ + Result: PermissionRequestResult{ + Kind: "denied-no-approval-rule-and-could-not-request-from-user", }, }, nil } - return map[string]any{"result": result}, nil + return &permissionRequestResponse{Result: result}, nil } // handleUserInputRequest handles a user input request from the CLI server. -func (c *Client) handleUserInputRequest(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - question, _ := params["question"].(string) - - if sessionID == "" || question == "" { +func (c *Client) handleUserInputRequest(req userInputRequest) (*userInputResponse, *jsonrpc2.Error) { + if req.SessionID == "" || req.Question == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid user input request payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - // Parse choices - var choices []string - if choicesRaw, ok := params["choices"].([]any); ok { - for _, choice := range choicesRaw { - if s, ok := choice.(string); ok { - choices = append(choices, s) - } - } - } - - var allowFreeform *bool - if af, ok := params["allowFreeform"].(bool); ok { - allowFreeform = &af - } - - request := UserInputRequest{ - Question: question, - Choices: choices, - AllowFreeform: allowFreeform, - } - - response, err := session.handleUserInputRequest(request) + response, err := session.handleUserInputRequest(UserInputRequest{ + Question: req.Question, + Choices: req.Choices, + AllowFreeform: req.AllowFreeform, + }) if err != nil { return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} } - return map[string]any{ - "answer": response.Answer, - "wasFreeform": response.WasFreeform, - }, nil + return &userInputResponse{Answer: response.Answer, WasFreeform: response.WasFreeform}, nil } // handleHooksInvoke handles a hooks invocation from the CLI server. -func (c *Client) handleHooksInvoke(params map[string]any) (map[string]any, *jsonrpc2.Error) { - sessionID, _ := params["sessionId"].(string) - hookType, _ := params["hookType"].(string) - input, _ := params["input"].(map[string]any) - - if sessionID == "" || hookType == "" { +func (c *Client) handleHooksInvoke(req hooksInvokeRequest) (map[string]any, *jsonrpc2.Error) { + if req.SessionID == "" || req.Type == "" { return nil, &jsonrpc2.Error{Code: -32602, Message: "invalid hooks invoke payload"} } c.sessionsMux.Lock() - session, ok := c.sessions[sessionID] + session, ok := c.sessions[req.SessionID] c.sessionsMux.Unlock() if !ok { - return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", req.SessionID)} } - output, err := session.handleHooksInvoke(hookType, input) + output, err := session.handleHooksInvoke(req.Type, req.Input) if err != nil { return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} } diff --git a/go/client_test.go b/go/client_test.go index 185bb4cb..176dad8c 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -25,25 +25,20 @@ func TestClient_HandleToolCallRequest(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - params := map[string]any{ - "sessionId": session.SessionID, - "toolCallId": "123", - "toolName": "missing_tool", - "arguments": map[string]any{}, + params := toolCallRequest{ + SessionID: session.SessionID, + ToolCallID: "123", + ToolName: "missing_tool", + Arguments: map[string]any{}, } response, _ := client.handleToolCallRequest(params) - result, ok := response["result"].(ToolResult) - if !ok { - t.Fatalf("Expected result to be ToolResult, got %T", response["result"]) + if response.Result.ResultType != "failure" { + t.Errorf("Expected resultType to be 'failure', got %q", response.Result.ResultType) } - if result.ResultType != "failure" { - t.Errorf("Expected resultType to be 'failure', got %q", result.ResultType) - } - - if result.Error != "tool 'missing_tool' not supported" { - t.Errorf("Expected error to be \"tool 'missing_tool' not supported\", got %q", result.Error) + if response.Result.Error != "tool 'missing_tool' not supported" { + t.Errorf("Expected error to be \"tool 'missing_tool' not supported\", got %q", response.Result.Error) } }) } diff --git a/go/internal/e2e/compaction_test.go b/go/internal/e2e/compaction_test.go index da9ea240..5fae9393 100644 --- a/go/internal/e2e/compaction_test.go +++ b/go/internal/e2e/compaction_test.go @@ -83,6 +83,9 @@ func TestCompaction(t *testing.T) { if err != nil { t.Fatalf("Failed to send verification message: %v", err) } + if answer == nil { + t.Fatalf("Expected an answer, got nil") + } if answer.Data.Content == nil || !strings.Contains(strings.ToLower(*answer.Data.Content), "dragon") { t.Errorf("Expected answer to contain 'dragon', got %v", answer.Data.Content) } diff --git a/go/internal/e2e/mcp_and_agents_test.go b/go/internal/e2e/mcp_and_agents_test.go index 1d21651b..33ad8479 100644 --- a/go/internal/e2e/mcp_and_agents_test.go +++ b/go/internal/e2e/mcp_and_agents_test.go @@ -37,16 +37,15 @@ func TestMCPServers(t *testing.T) { } // Simple interaction to verify session works - _, err = session.Send(t.Context(), copilot.MessageOptions{ + message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "What is 2+2?", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - message, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get final message: %v", err) + if message == nil { + t.Fatal("Expected a message, got nil") } if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "4") { @@ -97,6 +96,10 @@ func TestMCPServers(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message == nil { + t.Fatalf("Expected a message, got nil") + } + if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "6") { t.Errorf("Expected message to contain '6', got: %v", message.Data.Content) } @@ -168,16 +171,15 @@ func TestCustomAgents(t *testing.T) { } // Simple interaction to verify session works - _, err = session.Send(t.Context(), copilot.MessageOptions{ + message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "What is 5+5?", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - message, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get final message: %v", err) + if message == nil { + t.Fatal("Expected a message, got nil") } if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "10") { @@ -228,6 +230,10 @@ func TestCustomAgents(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message == nil { + t.Fatalf("Expected a message, got nil") + } + if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "12") { t.Errorf("Expected message to contain '12', got: %v", message.Data.Content) } @@ -373,16 +379,15 @@ func TestCombinedConfiguration(t *testing.T) { t.Error("Expected non-empty session ID") } - _, err = session.Send(t.Context(), copilot.MessageOptions{ + message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "What is 7+7?", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - message, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get final message: %v", err) + if message == nil { + t.Fatalf("Expected a message, got nil") } if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "14") { diff --git a/go/internal/e2e/permissions_test.go b/go/internal/e2e/permissions_test.go index a891548c..cde53b1d 100644 --- a/go/internal/e2e/permissions_test.go +++ b/go/internal/e2e/permissions_test.go @@ -134,18 +134,13 @@ func TestPermissions(t *testing.T) { t.Fatalf("Failed to write test file: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{ + _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "Edit protected.txt and replace 'protected' with 'hacked'.", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - _, err = testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get final message: %v", err) - } - // Verify the file was NOT modified content, err := os.ReadFile(testFile) if err != nil { @@ -165,14 +160,13 @@ func TestPermissions(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 2+2?"}) + message, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 2+2?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - message, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get final message: %v", err) + if message == nil { + t.Fatal("Expected a message, got nil") } if message.Data.Content == nil || !strings.Contains(*message.Data.Content, "4") { diff --git a/go/internal/e2e/session_test.go b/go/internal/e2e/session_test.go index 62183286..6fb05051 100644 --- a/go/internal/e2e/session_test.go +++ b/go/internal/e2e/session_test.go @@ -2,6 +2,7 @@ package e2e import ( "regexp" + "slices" "strings" "testing" "time" @@ -68,6 +69,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") + } + if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "2") { t.Errorf("Expected assistant message to contain '2', got %v", assistantMessage.Data.Content) } @@ -77,6 +82,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send second message: %v", err) } + if secondMessage == nil { + t.Fatal("Expected second assistant message, got nil") + } + if secondMessage.Data.Content == nil || !strings.Contains(*secondMessage.Data.Content, "4") { t.Errorf("Expected second message to contain '4', got %v", secondMessage.Data.Content) } @@ -144,14 +153,13 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is your full name?"}) + assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is your full name?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") } content := "" @@ -190,16 +198,11 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - _, err = testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - // Validate that only the specified tools are present traffic, err := ctx.GetExchanges() if err != nil { @@ -228,16 +231,11 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + _, err = session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - _, err = testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - // Validate that excluded tool is not present but others are traffic, err := ctx.GetExchanges() if err != nil { @@ -295,14 +293,13 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is the secret number for key ALPHA?"}) + assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is the secret number for key ALPHA?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") } content := "" @@ -329,14 +326,13 @@ func TestSession(t *testing.T) { } sessionID := session1.SessionID - _, err = session1.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + answer, err := session1.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - answer, err := testharness.GetFinalAssistantMessage(t.Context(), session1) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) + if answer == nil { + t.Fatalf("Expected an answer, got nil") } if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "2") { @@ -353,13 +349,21 @@ func TestSession(t *testing.T) { t.Errorf("Expected resumed session ID to match, got %q vs %q", session2.SessionID, sessionID) } - answer2, err := testharness.GetFinalAssistantMessage(t.Context(), session2) + messages, err := session2.GetMessages(t.Context()) if err != nil { - t.Fatalf("Failed to get assistant message from resumed session: %v", err) + t.Fatalf("Failed to get messages: %v", err) + } + + answer2Idx := slices.IndexFunc(messages, func(m copilot.SessionEvent) bool { + return m.Type == "assistant.message" + }) + + if answer2Idx == -1 { + t.Fatalf("Expected to find an assistant.message in resumed session messages, got %v", messages) } - if answer2.Data.Content == nil || !strings.Contains(*answer2.Data.Content, "2") { - t.Errorf("Expected resumed session answer to contain '2', got %v", answer2.Data.Content) + if messages[answer2Idx].Data.Content == nil || !strings.Contains(*messages[answer2Idx].Data.Content, "2") { + t.Errorf("Expected resumed session answer to contain '2', got %v", messages[answer2Idx].Data.Content) } }) @@ -373,14 +377,13 @@ func TestSession(t *testing.T) { } sessionID := session1.SessionID - _, err = session1.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + answer, err := session1.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - answer, err := testharness.GetFinalAssistantMessage(t.Context(), session1) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) + if answer == nil { + t.Fatalf("Expected an answer, got nil") } if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "2") { @@ -546,6 +549,10 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to send message after abort: %v", err) } + if answer == nil { + t.Fatalf("Expected an answer after abort, got nil") + } + if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "4") { t.Errorf("Expected answer to contain '4', got %v", answer.Data.Content) } @@ -562,7 +569,6 @@ func TestSession(t *testing.T) { } var deltaContents []string - done := make(chan bool) session.On(func(event copilot.SessionEvent) { switch event.Type { @@ -570,21 +576,17 @@ func TestSession(t *testing.T) { if event.Data.DeltaContent != nil { deltaContents = append(deltaContents, *event.Data.DeltaContent) } - case "session.idle": - close(done) + case "assistant.message": } }) - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 2+2?"}) + assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 2+2?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - // Wait for completion - select { - case <-done: - case <-time.After(60 * time.Second): - t.Fatal("Timed out waiting for session.idle") + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") } // Should have received delta events @@ -592,12 +594,6 @@ func TestSession(t *testing.T) { t.Error("Expected to receive delta events, got none") } - // Get the final message to compare - assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - // Accumulated deltas should equal the final message accumulated := strings.Join(deltaContents, "") if assistantMessage.Data.Content != nil && accumulated != *assistantMessage.Data.Content { @@ -627,14 +623,13 @@ func TestSession(t *testing.T) { } // Session should still work normally - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") } if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "2") { @@ -651,29 +646,18 @@ func TestSession(t *testing.T) { } var receivedEvents []copilot.SessionEvent - idle := make(chan bool) - session.On(func(event copilot.SessionEvent) { receivedEvents = append(receivedEvents, event) - if event.Type == "session.idle" { - select { - case idle <- true: - default: - } - } }) // Send a message to trigger events - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 100+200?"}) + assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 100+200?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - // Wait for session to become idle - select { - case <-idle: - case <-time.After(60 * time.Second): - t.Fatal("Timed out waiting for session.idle") + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") } // Should have received multiple events @@ -705,11 +689,6 @@ func TestSession(t *testing.T) { t.Error("Expected to receive session.idle event") } - // Verify the assistant response contains the expected answer - assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "300") { t.Errorf("Expected assistant message to contain '300', got %v", assistantMessage.Data.Content) } @@ -732,14 +711,13 @@ func TestSession(t *testing.T) { } // Session should work normally with custom config dir - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) + assistantMessage, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 1+1?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - assistantMessage, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) + if assistantMessage == nil { + t.Fatal("Expected assistant message, got nil") } if assistantMessage.Data.Content == nil || !strings.Contains(*assistantMessage.Data.Content, "2") { diff --git a/go/internal/e2e/skills_test.go b/go/internal/e2e/skills_test.go index ed3578ab..52367422 100644 --- a/go/internal/e2e/skills_test.go +++ b/go/internal/e2e/skills_test.go @@ -71,6 +71,10 @@ func TestSkills(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message == nil { + t.Fatalf("Expected a message, got nil") + } + if message.Data.Content == nil || !strings.Contains(*message.Data.Content, skillMarker) { t.Errorf("Expected message to contain skill marker '%s', got: %v", skillMarker, message.Data.Content) } @@ -99,6 +103,10 @@ func TestSkills(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message == nil { + t.Fatalf("Expected a message, got nil") + } + if message.Data.Content != nil && strings.Contains(*message.Data.Content, skillMarker) { t.Errorf("Expected message to NOT contain skill marker '%s' when disabled, got: %v", skillMarker, *message.Data.Content) } @@ -125,6 +133,10 @@ func TestSkills(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message1 == nil { + t.Fatalf("Expected a message, got nil") + } + if message1.Data.Content != nil && strings.Contains(*message1.Data.Content, skillMarker) { t.Errorf("Expected message to NOT contain skill marker before skill was added, got: %v", *message1.Data.Content) } @@ -147,6 +159,10 @@ func TestSkills(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } + if message2 == nil { + t.Fatalf("Expected a message, got nil") + } + if message2.Data.Content == nil || !strings.Contains(*message2.Data.Content, skillMarker) { t.Errorf("Expected message to contain skill marker '%s' after resume, got: %v", skillMarker, message2.Data.Content) } diff --git a/go/internal/e2e/testharness/helper.go b/go/internal/e2e/testharness/helper.go index 05947c80..c523b6db 100644 --- a/go/internal/e2e/testharness/helper.go +++ b/go/internal/e2e/testharness/helper.go @@ -1,60 +1,12 @@ package testharness import ( - "context" "errors" "time" copilot "github.com/github/copilot-sdk/go" ) -// GetFinalAssistantMessage waits for and returns the final assistant message from a session turn. -func GetFinalAssistantMessage(ctx context.Context, session *copilot.Session) (*copilot.SessionEvent, error) { - result := make(chan *copilot.SessionEvent, 1) - errCh := make(chan error, 1) - - // Subscribe to future events - var finalAssistantMessage *copilot.SessionEvent - unsubscribe := session.On(func(event copilot.SessionEvent) { - switch event.Type { - case "assistant.message": - finalAssistantMessage = &event - case "session.idle": - if finalAssistantMessage != nil { - result <- finalAssistantMessage - } - case "session.error": - msg := "session error" - if event.Data.Message != nil { - msg = *event.Data.Message - } - errCh <- errors.New(msg) - } - }) - defer unsubscribe() - - // Also check existing messages in case the response already arrived - go func() { - existing, err := getExistingFinalResponse(ctx, session) - if err != nil { - errCh <- err - return - } - if existing != nil { - result <- existing - } - }() - - select { - case msg := <-result: - return msg, nil - case err := <-errCh: - return nil, err - case <-ctx.Done(): - return nil, errors.New("timeout waiting for assistant message") - } -} - // GetNextEventOfType waits for and returns the next event of the specified type from a session. func GetNextEventOfType(session *copilot.Session, eventType copilot.SessionEventType, timeout time.Duration) (*copilot.SessionEvent, error) { result := make(chan *copilot.SessionEvent, 1) @@ -89,57 +41,3 @@ func GetNextEventOfType(session *copilot.Session, eventType copilot.SessionEvent return nil, errors.New("timeout waiting for event: " + string(eventType)) } } - -func getExistingFinalResponse(ctx context.Context, session *copilot.Session) (*copilot.SessionEvent, error) { - messages, err := session.GetMessages(ctx) - if err != nil { - return nil, err - } - - // Find last user message - finalUserMessageIndex := -1 - for i := len(messages) - 1; i >= 0; i-- { - if messages[i].Type == "user.message" { - finalUserMessageIndex = i - break - } - } - - var currentTurnMessages []copilot.SessionEvent - if finalUserMessageIndex < 0 { - currentTurnMessages = messages - } else { - currentTurnMessages = messages[finalUserMessageIndex:] - } - - // Check for errors - for _, msg := range currentTurnMessages { - if msg.Type == "session.error" { - errMsg := "session error" - if msg.Data.Message != nil { - errMsg = *msg.Data.Message - } - return nil, errors.New(errMsg) - } - } - - // Find session.idle and get last assistant message before it - sessionIdleIndex := -1 - for i, msg := range currentTurnMessages { - if msg.Type == "session.idle" { - sessionIdleIndex = i - break - } - } - - if sessionIdleIndex != -1 { - // Find last assistant.message before session.idle - for i := sessionIdleIndex - 1; i >= 0; i-- { - if currentTurnMessages[i].Type == "assistant.message" { - return ¤tTurnMessages[i], nil - } - } - } - - return nil, nil -} diff --git a/go/internal/e2e/tools_test.go b/go/internal/e2e/tools_test.go index 5af9079c..b6af6ef0 100644 --- a/go/internal/e2e/tools_test.go +++ b/go/internal/e2e/tools_test.go @@ -30,14 +30,13 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "What's the first line of README.md in this directory?"}) + answer, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What's the first line of README.md in this directory?"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) + if answer == nil { + t.Fatalf("Expected an answer, got nil") } if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "ELIZA") { @@ -64,14 +63,13 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{Prompt: "Use encrypt_string to encrypt this string: Hello"}) + answer, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Use encrypt_string to encrypt this string: Hello"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) + if answer == nil { + t.Fatalf("Expected an answer, got nil") } if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "HELLO") { @@ -96,16 +94,15 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{ + answer, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "What is my location? If you can't find out, just say 'unknown'.", }) if err != nil { t.Fatalf("Failed to send message: %v", err) } - answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) + if answer == nil { + t.Fatalf("Expected an answer, got nil") } // Check the underlying traffic @@ -213,7 +210,7 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - _, err = session.Send(t.Context(), copilot.MessageOptions{ + answer, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ Prompt: "Perform a DB query for the 'cities' table using IDs 12 and 19, sorting ascending. " + "Reply only with lines of the form: [cityname] [population]", }) @@ -221,11 +218,6 @@ func TestTools(t *testing.T) { t.Fatalf("Failed to send message: %v", err) } - answer, err := testharness.GetFinalAssistantMessage(t.Context(), session) - if err != nil { - t.Fatalf("Failed to get assistant message: %v", err) - } - if answer == nil || answer.Data.Content == nil { t.Fatalf("Expected assistant message with content") } diff --git a/go/internal/jsonrpc2/jsonrpc2.go b/go/internal/jsonrpc2/jsonrpc2.go index 8e4a0f6a..e44e1231 100644 --- a/go/internal/jsonrpc2/jsonrpc2.go +++ b/go/internal/jsonrpc2/jsonrpc2.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "reflect" "sync" ) @@ -23,43 +24,39 @@ func (e *Error) Error() string { // Request represents a JSON-RPC 2.0 request type Request struct { JSONRPC string `json:"jsonrpc"` - ID json.RawMessage `json:"id"` + ID json.RawMessage `json:"id"` // nil for notifications Method string `json:"method"` - Params map[string]any `json:"params"` + Params json.RawMessage `json:"params"` +} + +func (r *Request) IsCall() bool { + return len(r.ID) > 0 } // Response represents a JSON-RPC 2.0 response type Response struct { JSONRPC string `json:"jsonrpc"` ID json.RawMessage `json:"id,omitempty"` - Result map[string]any `json:"result,omitempty"` + Result json.RawMessage `json:"result,omitempty"` Error *Error `json:"error,omitempty"` } -// Notification represents a JSON-RPC 2.0 notification -type Notification struct { - JSONRPC string `json:"jsonrpc"` - Method string `json:"method"` - Params map[string]any `json:"params"` -} - // NotificationHandler handles incoming notifications -type NotificationHandler func(method string, params map[string]any) +type NotificationHandler func(method string, params json.RawMessage) // RequestHandler handles incoming server requests and returns a result or error -type RequestHandler func(params map[string]any) (map[string]any, *Error) +type RequestHandler func(params json.RawMessage) (json.RawMessage, *Error) // Client is a minimal JSON-RPC 2.0 client for stdio transport type Client struct { - stdin io.WriteCloser - stdout io.ReadCloser - mu sync.Mutex - pendingRequests map[string]chan *Response - notificationHandler NotificationHandler - requestHandlers map[string]RequestHandler - running bool - stopChan chan struct{} - wg sync.WaitGroup + stdin io.WriteCloser + stdout io.ReadCloser + mu sync.Mutex + pendingRequests map[string]chan *Response + requestHandlers map[string]RequestHandler + running bool + stopChan chan struct{} + wg sync.WaitGroup } // NewClient creates a new JSON-RPC client @@ -96,11 +93,55 @@ func (c *Client) Stop() { c.wg.Wait() } -// SetNotificationHandler sets the handler for incoming notifications -func (c *Client) SetNotificationHandler(handler NotificationHandler) { - c.mu.Lock() - defer c.mu.Unlock() - c.notificationHandler = handler +func NotificationHandlerFor[In any](handler func(params In)) RequestHandler { + return func(params json.RawMessage) (json.RawMessage, *Error) { + var in In + // If In is a pointer type, allocate the underlying value and unmarshal into it directly + var target any = &in + if t := reflect.TypeFor[In](); t.Kind() == reflect.Pointer { + in = reflect.New(t.Elem()).Interface().(In) + target = in + } + if err := json.Unmarshal(params, target); err != nil { + return nil, &Error{ + Code: -32602, + Message: fmt.Sprintf("Invalid params: %v", err), + } + } + handler(in) + return nil, nil + } +} + +// RequestHandlerFor creates a RequestHandler from a typed function +func RequestHandlerFor[In, Out any](handler func(params In) (Out, *Error)) RequestHandler { + return func(params json.RawMessage) (json.RawMessage, *Error) { + var in In + // If In is a pointer type, allocate the underlying value and unmarshal into it directly + var target any = &in + if t := reflect.TypeOf(in); t != nil && t.Kind() == reflect.Pointer { + in = reflect.New(t.Elem()).Interface().(In) + target = in + } + if err := json.Unmarshal(params, target); err != nil { + return nil, &Error{ + Code: -32602, + Message: fmt.Sprintf("Invalid params: %v", err), + } + } + out, errj := handler(in) + if errj != nil { + return nil, errj + } + outData, err := json.Marshal(out) + if err != nil { + return nil, &Error{ + Code: -32603, + Message: fmt.Sprintf("Failed to marshal response: %v", err), + } + } + return outData, nil + } } // SetRequestHandler registers a handler for incoming requests from the server @@ -115,7 +156,7 @@ func (c *Client) SetRequestHandler(method string, handler RequestHandler) { } // Request sends a JSON-RPC request and waits for the response -func (c *Client) Request(method string, params map[string]any) (map[string]any, error) { +func (c *Client) Request(method string, params any) (json.RawMessage, error) { requestID := generateUUID() // Create response channel @@ -131,12 +172,17 @@ func (c *Client) Request(method string, params map[string]any) (map[string]any, c.mu.Unlock() }() + paramsData, err := json.Marshal(params) + if err != nil { + return nil, fmt.Errorf("failed to marshal params: %w", err) + } + // Send request request := Request{ JSONRPC: "2.0", ID: json.RawMessage(`"` + requestID + `"`), Method: method, - Params: params, + Params: json.RawMessage(paramsData), } if err := c.sendMessage(request); err != nil { @@ -156,11 +202,16 @@ func (c *Client) Request(method string, params map[string]any) (map[string]any, } // Notify sends a JSON-RPC notification (no response expected) -func (c *Client) Notify(method string, params map[string]any) error { - notification := Notification{ +func (c *Client) Notify(method string, params any) error { + paramsData, err := json.Marshal(params) + if err != nil { + return fmt.Errorf("failed to marshal params: %w", err) + } + + notification := Request{ JSONRPC: "2.0", Method: method, - Params: params, + Params: json.RawMessage(paramsData), } return c.sendMessage(notification) } @@ -231,7 +282,7 @@ func (c *Client) readLoop() { // Try to parse as request first (has both ID and Method) var request Request - if err := json.Unmarshal(body, &request); err == nil && request.Method != "" && len(request.ID) > 0 { + if err := json.Unmarshal(body, &request); err == nil && request.Method != "" { c.handleRequest(&request) continue } @@ -242,13 +293,6 @@ func (c *Client) readLoop() { c.handleResponse(&response) continue } - - // Try to parse as notification (has Method but no ID) - var notification Notification - if err := json.Unmarshal(body, ¬ification); err == nil && notification.Method != "" { - c.handleNotification(¬ification) - continue - } } } @@ -270,24 +314,21 @@ func (c *Client) handleResponse(response *Response) { } } -// handleNotification dispatches a notification to the handler -func (c *Client) handleNotification(notification *Notification) { - c.mu.Lock() - handler := c.notificationHandler - c.mu.Unlock() - - if handler != nil { - handler(notification.Method, notification.Params) - } -} - func (c *Client) handleRequest(request *Request) { c.mu.Lock() handler := c.requestHandlers[request.Method] c.mu.Unlock() if handler == nil { - c.sendErrorResponse(request.ID, -32601, fmt.Sprintf("Method not found: %s", request.Method), nil) + if request.IsCall() { + c.sendErrorResponse(request.ID, -32601, fmt.Sprintf("Method not found: %s", request.Method), nil) + } + return + } + + // Notifications run synchronously, calls run in a goroutine to avoid blocking + if !request.IsCall() { + handler(request.Params) return } @@ -303,14 +344,11 @@ func (c *Client) handleRequest(request *Request) { c.sendErrorResponse(request.ID, err.Code, err.Message, err.Data) return } - if result == nil { - result = make(map[string]any) - } c.sendResponse(request.ID, result) }() } -func (c *Client) sendResponse(id json.RawMessage, result map[string]any) { +func (c *Client) sendResponse(id json.RawMessage, result json.RawMessage) { response := Response{ JSONRPC: "2.0", ID: id, diff --git a/go/session.go b/go/session.go index e4f1473d..37cfe52f 100644 --- a/go/session.go +++ b/go/session.go @@ -106,29 +106,23 @@ func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) // log.Printf("Failed to send message: %v", err) // } func (s *Session) Send(ctx context.Context, options MessageOptions) (string, error) { - params := map[string]any{ - "sessionId": s.SessionID, - "prompt": options.Prompt, + req := sessionSendRequest{ + SessionID: s.SessionID, + Prompt: options.Prompt, + Attachments: options.Attachments, + Mode: options.Mode, } - if options.Attachments != nil { - params["attachments"] = options.Attachments - } - if options.Mode != "" { - params["mode"] = options.Mode - } - - result, err := s.client.Request("session.send", params) + result, err := s.client.Request("session.send", req) if err != nil { return "", fmt.Errorf("failed to send message: %w", err) } - messageID, ok := result["messageId"].(string) - if !ok { - return "", fmt.Errorf("invalid response: missing messageId") + var response sessionSendResponse + if err := json.Unmarshal(result, &response); err != nil { + return "", fmt.Errorf("failed to unmarshal send response: %w", err) } - - return messageID, nil + return response.MessageID, nil } // SendAndWait sends a message to this session and waits until the session becomes idle. @@ -306,7 +300,7 @@ func (s *Session) getPermissionHandler() PermissionHandler { // handlePermissionRequest handles a permission request from the Copilot CLI. // This is an internal method called by the SDK when the CLI requests permission. -func (s *Session) handlePermissionRequest(requestData map[string]any) (PermissionRequestResult, error) { +func (s *Session) handlePermissionRequest(request PermissionRequest) (PermissionRequestResult, error) { handler := s.getPermissionHandler() if handler == nil { @@ -315,16 +309,6 @@ func (s *Session) handlePermissionRequest(requestData map[string]any) (Permissio }, nil } - // Convert map to PermissionRequest struct - kind, _ := requestData["kind"].(string) - toolCallID, _ := requestData["toolCallId"].(string) - - request := PermissionRequest{ - Kind: kind, - ToolCallID: toolCallID, - Extra: requestData, - } - invocation := PermissionInvocation{ SessionID: s.SessionID, } @@ -388,7 +372,7 @@ func (s *Session) getHooks() *SessionHooks { // handleHooksInvoke handles a hook invocation from the Copilot CLI. // This is an internal method called by the SDK when the CLI invokes a hook. -func (s *Session) handleHooksInvoke(hookType string, input map[string]any) (any, error) { +func (s *Session) handleHooksInvoke(hookType string, rawInput json.RawMessage) (any, error) { hooks := s.getHooks() if hooks == nil { @@ -404,153 +388,66 @@ func (s *Session) handleHooksInvoke(hookType string, input map[string]any) (any, if hooks.OnPreToolUse == nil { return nil, nil } - hookInput := parsePreToolUseInput(input) - return hooks.OnPreToolUse(hookInput, invocation) + var input PreToolUseHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnPreToolUse(input, invocation) case "postToolUse": if hooks.OnPostToolUse == nil { return nil, nil } - hookInput := parsePostToolUseInput(input) - return hooks.OnPostToolUse(hookInput, invocation) + var input PostToolUseHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnPostToolUse(input, invocation) case "userPromptSubmitted": if hooks.OnUserPromptSubmitted == nil { return nil, nil } - hookInput := parseUserPromptSubmittedInput(input) - return hooks.OnUserPromptSubmitted(hookInput, invocation) + var input UserPromptSubmittedHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnUserPromptSubmitted(input, invocation) case "sessionStart": if hooks.OnSessionStart == nil { return nil, nil } - hookInput := parseSessionStartInput(input) - return hooks.OnSessionStart(hookInput, invocation) + var input SessionStartHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnSessionStart(input, invocation) case "sessionEnd": if hooks.OnSessionEnd == nil { return nil, nil } - hookInput := parseSessionEndInput(input) - return hooks.OnSessionEnd(hookInput, invocation) + var input SessionEndHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnSessionEnd(input, invocation) case "errorOccurred": if hooks.OnErrorOccurred == nil { return nil, nil } - hookInput := parseErrorOccurredInput(input) - return hooks.OnErrorOccurred(hookInput, invocation) - + var input ErrorOccurredHookInput + if err := json.Unmarshal(rawInput, &input); err != nil { + return nil, fmt.Errorf("invalid hook input: %w", err) + } + return hooks.OnErrorOccurred(input, invocation) default: return nil, fmt.Errorf("unknown hook type: %s", hookType) } } -// Helper functions to parse hook inputs - -func parsePreToolUseInput(input map[string]any) PreToolUseHookInput { - result := PreToolUseHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if name, ok := input["toolName"].(string); ok { - result.ToolName = name - } - result.ToolArgs = input["toolArgs"] - return result -} - -func parsePostToolUseInput(input map[string]any) PostToolUseHookInput { - result := PostToolUseHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if name, ok := input["toolName"].(string); ok { - result.ToolName = name - } - result.ToolArgs = input["toolArgs"] - result.ToolResult = input["toolResult"] - return result -} - -func parseUserPromptSubmittedInput(input map[string]any) UserPromptSubmittedHookInput { - result := UserPromptSubmittedHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if prompt, ok := input["prompt"].(string); ok { - result.Prompt = prompt - } - return result -} - -func parseSessionStartInput(input map[string]any) SessionStartHookInput { - result := SessionStartHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if source, ok := input["source"].(string); ok { - result.Source = source - } - if prompt, ok := input["initialPrompt"].(string); ok { - result.InitialPrompt = prompt - } - return result -} - -func parseSessionEndInput(input map[string]any) SessionEndHookInput { - result := SessionEndHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if reason, ok := input["reason"].(string); ok { - result.Reason = reason - } - if msg, ok := input["finalMessage"].(string); ok { - result.FinalMessage = msg - } - if errStr, ok := input["error"].(string); ok { - result.Error = errStr - } - return result -} - -func parseErrorOccurredInput(input map[string]any) ErrorOccurredHookInput { - result := ErrorOccurredHookInput{} - if ts, ok := input["timestamp"].(float64); ok { - result.Timestamp = int64(ts) - } - if cwd, ok := input["cwd"].(string); ok { - result.Cwd = cwd - } - if errMsg, ok := input["error"].(string); ok { - result.Error = errMsg - } - if ctx, ok := input["errorContext"].(string); ok { - result.ErrorContext = ctx - } - if rec, ok := input["recoverable"].(bool); ok { - result.Recoverable = rec - } - return result -} - // dispatchEvent dispatches an event to all registered handlers. // This is an internal method; handlers are called synchronously and any panics // are recovered to prevent crashing the event dispatcher. @@ -596,38 +493,17 @@ func (s *Session) dispatchEvent(event SessionEvent) { // } // } func (s *Session) GetMessages(ctx context.Context) ([]SessionEvent, error) { - params := map[string]any{ - "sessionId": s.SessionID, - } - result, err := s.client.Request("session.getMessages", params) + result, err := s.client.Request("session.getMessages", sessionGetMessagesRequest{SessionID: s.SessionID}) if err != nil { return nil, fmt.Errorf("failed to get messages: %w", err) } - eventsRaw, ok := result["events"].([]any) - if !ok { - return nil, fmt.Errorf("invalid response: missing events") - } - - // Convert to SessionEvent structs - events := make([]SessionEvent, 0, len(eventsRaw)) - for _, eventRaw := range eventsRaw { - // Marshal back to JSON and unmarshal into typed struct - eventJSON, err := json.Marshal(eventRaw) - if err != nil { - continue - } - - event, err := UnmarshalSessionEvent(eventJSON) - if err != nil { - continue - } - - events = append(events, event) + var response sessionGetMessagesResponse + if err := json.Unmarshal(result, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal get messages response: %w", err) } - - return events, nil + return response.Events, nil } // Destroy destroys this session and releases all associated resources. @@ -645,11 +521,7 @@ func (s *Session) GetMessages(ctx context.Context) ([]SessionEvent, error) { // log.Printf("Failed to destroy session: %v", err) // } func (s *Session) Destroy() error { - params := map[string]any{ - "sessionId": s.SessionID, - } - - _, err := s.client.Request("session.destroy", params) + _, err := s.client.Request("session.destroy", sessionDestroyRequest{SessionID: s.SessionID}) if err != nil { return fmt.Errorf("failed to destroy session: %w", err) } @@ -692,11 +564,7 @@ func (s *Session) Destroy() error { // log.Printf("Failed to abort: %v", err) // } func (s *Session) Abort(ctx context.Context) error { - params := map[string]any{ - "sessionId": s.SessionID, - } - - _, err := s.client.Request("session.abort", params) + _, err := s.client.Request("session.abort", sessionAbortRequest{SessionID: s.SessionID}) if err != nil { return fmt.Errorf("failed to abort session: %w", err) } diff --git a/go/types.go b/go/types.go index 7a1917f0..b421c85c 100644 --- a/go/types.go +++ b/go/types.go @@ -1,5 +1,7 @@ package copilot +import "encoding/json" + // ConnectionState represents the client connection state type ConnectionState string @@ -113,15 +115,15 @@ type PermissionInvocation struct { // UserInputRequest represents a request for user input from the agent type UserInputRequest struct { - Question string `json:"question"` - Choices []string `json:"choices,omitempty"` - AllowFreeform *bool `json:"allowFreeform,omitempty"` + Question string + Choices []string + AllowFreeform *bool } // UserInputResponse represents the user's response to an input request type UserInputResponse struct { - Answer string `json:"answer"` - WasFreeform bool `json:"wasFreeform"` + Answer string + WasFreeform bool } // UserInputHandler handles user input requests from the agent @@ -307,13 +309,13 @@ type CustomAgentConfig struct { // limits through background compaction and persist state to a workspace directory. type InfiniteSessionConfig struct { // Enabled controls whether infinite sessions are enabled (default: true) - Enabled *bool + Enabled *bool `json:"enabled,omitempty"` // BackgroundCompactionThreshold is the context utilization (0.0-1.0) at which // background compaction starts. Default: 0.80 - BackgroundCompactionThreshold *float64 + BackgroundCompactionThreshold *float64 `json:"backgroundCompactionThreshold,omitempty"` // BufferExhaustionThreshold is the context utilization (0.0-1.0) at which // the session blocks until compaction completes. Default: 0.95 - BufferExhaustionThreshold *float64 + BufferExhaustionThreshold *float64 `json:"bufferExhaustionThreshold,omitempty"` } // SessionConfig configures a new session @@ -369,10 +371,10 @@ type SessionConfig struct { // Tool describes a caller-implemented tool that can be invoked by Copilot type Tool struct { - Name string - Description string // optional - Parameters map[string]any - Handler ToolHandler + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters map[string]any `json:"parameters,omitempty"` + Handler ToolHandler `json:"-"` } // ToolInvocation describes a tool call initiated by Copilot @@ -477,43 +479,6 @@ type MessageOptions struct { // SessionEventHandler is a callback for session events type SessionEventHandler func(event SessionEvent) -// PingResponse is the response from a ping request -type PingResponse struct { - Message string `json:"message"` - Timestamp int64 `json:"timestamp"` - ProtocolVersion *int `json:"protocolVersion,omitempty"` -} - -// SessionCreateResponse is the response from session.create -type SessionCreateResponse struct { - SessionID string `json:"sessionId"` -} - -// SessionSendResponse is the response from session.send -type SessionSendResponse struct { - MessageID string `json:"messageId"` -} - -// SessionGetMessagesResponse is the response from session.getMessages -type SessionGetMessagesResponse struct { - Events []SessionEvent `json:"events"` -} - -// GetStatusResponse is the response from status.get -type GetStatusResponse struct { - Version string `json:"version"` - ProtocolVersion int `json:"protocolVersion"` -} - -// GetAuthStatusResponse is the response from auth.getStatus -type GetAuthStatusResponse struct { - IsAuthenticated bool `json:"isAuthenticated"` - AuthType *string `json:"authType,omitempty"` - Host *string `json:"host,omitempty"` - Login *string `json:"login,omitempty"` - StatusMessage *string `json:"statusMessage,omitempty"` -} - // ModelVisionLimits contains vision-specific limits type ModelVisionLimits struct { SupportedMediaTypes []string `json:"supported_media_types"` @@ -562,11 +527,6 @@ type ModelInfo struct { DefaultReasoningEffort string `json:"defaultReasoningEffort,omitempty"` } -// GetModelsResponse is the response from models.list -type GetModelsResponse struct { - Models []ModelInfo `json:"models"` -} - // SessionMetadata contains metadata about a session type SessionMetadata struct { SessionID string `json:"sessionId"` @@ -576,22 +536,6 @@ type SessionMetadata struct { IsRemote bool `json:"isRemote"` } -// ListSessionsResponse is the response from session.list -type ListSessionsResponse struct { - Sessions []SessionMetadata `json:"sessions"` -} - -// DeleteSessionRequest is the request for session.delete -type DeleteSessionRequest struct { - SessionID string `json:"sessionId"` -} - -// DeleteSessionResponse is the response from session.delete -type DeleteSessionResponse struct { - Success bool `json:"success"` - Error *string `json:"error,omitempty"` -} - // SessionLifecycleEventType represents the type of session lifecycle event type SessionLifecycleEventType string @@ -620,19 +564,218 @@ type SessionLifecycleEventMetadata struct { // SessionLifecycleHandler is a callback for session lifecycle events type SessionLifecycleHandler func(event SessionLifecycleEvent) -// GetForegroundSessionResponse is the response from session.getForeground -type GetForegroundSessionResponse struct { +// permissionRequestRequest represents the request data for a permission request +type permissionRequestRequest struct { + SessionID string `json:"sessionId"` + Request PermissionRequest `json:"permissionRequest"` +} + +// permissionRequestResponse represents the response to a permission request +type permissionRequestResponse struct { + Result PermissionRequestResult `json:"result"` +} + +// createSessionRequest is the request for session.create +type createSessionRequest struct { + Model string `json:"model,omitempty"` + SessionID string `json:"sessionId,omitempty"` + ReasoningEffort string `json:"reasoningEffort,omitempty"` + Tools []Tool `json:"tools,omitempty"` + SystemMessage *SystemMessageConfig `json:"systemMessage,omitempty"` + AvailableTools []string `json:"availableTools,omitempty"` + ExcludedTools []string `json:"excludedTools,omitempty"` + Provider *ProviderConfig `json:"provider,omitempty"` + RequestPermission *bool `json:"requestPermission,omitempty"` + RequestUserInput *bool `json:"requestUserInput,omitempty"` + Hooks *bool `json:"hooks,omitempty"` + WorkingDirectory string `json:"workingDirectory,omitempty"` + Streaming *bool `json:"streaming,omitempty"` + MCPServers map[string]MCPServerConfig `json:"mcpServers,omitempty"` + CustomAgents []CustomAgentConfig `json:"customAgents,omitempty"` + ConfigDir string `json:"configDir,omitempty"` + SkillDirectories []string `json:"skillDirectories,omitempty"` + DisabledSkills []string `json:"disabledSkills,omitempty"` + InfiniteSessions *InfiniteSessionConfig `json:"infiniteSessions,omitempty"` +} + +// createSessionResponse is the response from session.create +type createSessionResponse struct { + SessionID string `json:"sessionId"` + WorkspacePath string `json:"workspacePath"` +} + +// resumeSessionRequest is the request for session.resume +type resumeSessionRequest struct { + SessionID string `json:"sessionId"` + ReasoningEffort string `json:"reasoningEffort,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Provider *ProviderConfig `json:"provider,omitempty"` + RequestPermission *bool `json:"requestPermission,omitempty"` + RequestUserInput *bool `json:"requestUserInput,omitempty"` + Hooks *bool `json:"hooks,omitempty"` + WorkingDirectory string `json:"workingDirectory,omitempty"` + DisableResume *bool `json:"disableResume,omitempty"` + Streaming *bool `json:"streaming,omitempty"` + MCPServers map[string]MCPServerConfig `json:"mcpServers,omitempty"` + CustomAgents []CustomAgentConfig `json:"customAgents,omitempty"` + SkillDirectories []string `json:"skillDirectories,omitempty"` + DisabledSkills []string `json:"disabledSkills,omitempty"` +} + +// resumeSessionResponse is the response from session.resume +type resumeSessionResponse struct { + SessionID string `json:"sessionId"` + WorkspacePath string `json:"workspacePath"` +} + +type hooksInvokeRequest struct { + SessionID string `json:"sessionId"` + Type string `json:"hookType"` + Input json.RawMessage `json:"input"` +} + +// listSessionsRequest is the request for session.list +type listSessionsRequest struct{} + +// listSessionsResponse is the response from session.list +type listSessionsResponse struct { + Sessions []SessionMetadata `json:"sessions"` +} + +// deleteSessionRequest is the request for session.delete +type deleteSessionRequest struct { + SessionID string `json:"sessionId"` +} + +// deleteSessionResponse is the response from session.delete +type deleteSessionResponse struct { + Success bool `json:"success"` + Error *string `json:"error,omitempty"` +} + +// getForegroundSessionRequest is the request for session.getForeground +type getForegroundSessionRequest struct{} + +// getForegroundSessionResponse is the response from session.getForeground +type getForegroundSessionResponse struct { SessionID *string `json:"sessionId,omitempty"` WorkspacePath *string `json:"workspacePath,omitempty"` } -// SetForegroundSessionRequest is the request for session.setForeground -type SetForegroundSessionRequest struct { +// setForegroundSessionRequest is the request for session.setForeground +type setForegroundSessionRequest struct { SessionID string `json:"sessionId"` } -// SetForegroundSessionResponse is the response from session.setForeground -type SetForegroundSessionResponse struct { +// setForegroundSessionResponse is the response from session.setForeground +type setForegroundSessionResponse struct { Success bool `json:"success"` Error *string `json:"error,omitempty"` } + +type pingRequest struct { + Message string `json:"message,omitempty"` +} + +// PingResponse is the response from a ping request +type PingResponse struct { + Message string `json:"message"` + Timestamp int64 `json:"timestamp"` + ProtocolVersion *int `json:"protocolVersion,omitempty"` +} + +// getStatusRequest is the request for status.get +type getStatusRequest struct{} + +// GetStatusResponse is the response from status.get +type GetStatusResponse struct { + Version string `json:"version"` + ProtocolVersion int `json:"protocolVersion"` +} + +// getAuthStatusRequest is the request for auth.getStatus +type getAuthStatusRequest struct{} + +// GetAuthStatusResponse is the response from auth.getStatus +type GetAuthStatusResponse struct { + IsAuthenticated bool `json:"isAuthenticated"` + AuthType *string `json:"authType,omitempty"` + Host *string `json:"host,omitempty"` + Login *string `json:"login,omitempty"` + StatusMessage *string `json:"statusMessage,omitempty"` +} + +// listModelsRequest is the request for models.list +type listModelsRequest struct{} + +// listModelsResponse is the response from models.list +type listModelsResponse struct { + Models []ModelInfo `json:"models"` +} + +// sessionGetMessagesRequest is the request for session.getMessages +type sessionGetMessagesRequest struct { + SessionID string `json:"sessionId"` +} + +// sessionGetMessagesResponse is the response from session.getMessages +type sessionGetMessagesResponse struct { + Events []SessionEvent `json:"events"` +} + +// sessionDestroyRequest is the request for session.destroy +type sessionDestroyRequest struct { + SessionID string `json:"sessionId"` +} + +// sessionAbortRequest is the request for session.abort +type sessionAbortRequest struct { + SessionID string `json:"sessionId"` +} + +type sessionSendRequest struct { + SessionID string `json:"sessionId"` + Prompt string `json:"prompt"` + Attachments []Attachment `json:"attachments,omitempty"` + Mode string `json:"mode,omitempty"` +} + +// sessionSendResponse is the response from session.send +type sessionSendResponse struct { + MessageID string `json:"messageId"` +} + +// sessionEventRequest is the request for session event notifications +type sessionEventRequest struct { + SessionID string `json:"sessionId"` + Event SessionEvent `json:"event"` +} + +// toolCallRequest represents a tool call request from the server +// to the client for execution. +type toolCallRequest struct { + SessionID string `json:"sessionId"` + ToolCallID string `json:"toolCallId"` + ToolName string `json:"toolName"` + Arguments any `json:"arguments"` +} + +// toolCallResponse represents the response to a tool call request +// from the client back to the server. +type toolCallResponse struct { + Result ToolResult `json:"result"` +} + +// userInputRequest represents a request for user input from the agent +type userInputRequest struct { + SessionID string `json:"sessionId"` + Question string `json:"question"` + Choices []string `json:"choices,omitempty"` + AllowFreeform *bool `json:"allowFreeform,omitempty"` +} + +// userInputResponse represents the user's response to an input request +type userInputResponse struct { + Answer string `json:"answer"` + WasFreeform bool `json:"wasFreeform"` +}