diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index ad7601701..47f4d2338 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -969,7 +969,7 @@ func (r *LocalRuntime) finalizeEventChannel(ctx context.Context, sess *session.S // RunStream starts the agent's interaction loop and returns a channel of events func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-chan Event { - slog.Debug("Starting runtime stream", "agent", r.CurrentAgentName(), "session_id", sess.ID) + slog.Debug("Starting runtime stream", "agent", r.CurrentAgentName(), "session_agent", sess.AgentName, "session_id", sess.ID) events := make(chan Event, 128) go func() { @@ -1039,9 +1039,6 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c runtimeMaxIterations := sess.MaxIterations for { - // Set elicitation handler on all MCP toolsets before getting tools - a := r.CurrentAgent() - r.emitAgentWarnings(a, events) r.configureToolsetHandlers(a, events) diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index ba53b5e5c..aa600be73 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -423,8 +423,9 @@ func getFallbackModelsForAgent(ctx context.Context, cfg *latest.Config, a *lates // getToolsForAgent returns the tool definitions for an agent based on its configuration func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir string, runConfig *config.RuntimeConfig, registry *ToolsetRegistry) ([]tools.ToolSet, []string) { var ( - toolSets []tools.ToolSet - warnings []string + toolSets []tools.ToolSet + warnings []string + lspBackends []builtin.LSPBackend ) deferredToolset := builtin.NewDeferredToolset() @@ -456,9 +457,29 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri } } + // Collect LSP backends for multiplexing when there are multiple. + // Instead of adding them individually (which causes duplicate tool names), + // they are combined into a single LSPMultiplexer after the loop. + if toolset.Type == "lsp" { + if lspTool, ok := tool.(*builtin.LSPTool); ok { + lspBackends = append(lspBackends, builtin.LSPBackend{LSP: lspTool, Toolset: wrapped}) + continue + } + slog.Warn("Toolset configured as type 'lsp' but registry returned unexpected type; treating as regular toolset", + "type", fmt.Sprintf("%T", tool), "command", toolset.Command) + } + toolSets = append(toolSets, wrapped) } + // Merge LSP backends: if there are multiple, combine them into a single + // multiplexer so the LLM sees one set of lsp_* tools instead of duplicates. + if len(lspBackends) > 1 { + toolSets = append(toolSets, builtin.NewLSPMultiplexer(lspBackends)) + } else if len(lspBackends) == 1 { + toolSets = append(toolSets, lspBackends[0].Toolset) + } + if deferredToolset.HasSources() { toolSets = append(toolSets, deferredToolset) } diff --git a/pkg/teamloader/teamloader_test.go b/pkg/teamloader/teamloader_test.go index 05a85c3a8..9b2d0c7d2 100644 --- a/pkg/teamloader/teamloader_test.go +++ b/pkg/teamloader/teamloader_test.go @@ -386,6 +386,85 @@ agents: assert.Equal(t, expected, rootAgent.AddPromptFiles()) } +func TestGetToolsForAgent_MultipleLSPToolsetsAreCombined(t *testing.T) { + t.Parallel() + + a := &latest.AgentConfig{ + Instruction: "test", + Toolsets: []latest.Toolset{ + { + Type: "lsp", + Command: "gopls", + FileTypes: []string{".go"}, + }, + { + Type: "lsp", + Command: "gopls", + FileTypes: []string{".mod"}, + }, + }, + } + + runConfig := config.RuntimeConfig{ + EnvProviderForTests: &noEnvProvider{}, + } + + got, warnings := getToolsForAgent(t.Context(), a, ".", &runConfig, NewDefaultToolsetRegistry()) + require.Empty(t, warnings) + + // Should have exactly one toolset (the multiplexer) + require.Len(t, got, 1) + + // Verify that we get no duplicate tool names + allTools, err := got[0].Tools(t.Context()) + require.NoError(t, err) + + seen := make(map[string]bool) + for _, tool := range allTools { + assert.False(t, seen[tool.Name], "duplicate tool name: %s", tool.Name) + seen[tool.Name] = true + } + + // Verify LSP tools are present + assert.True(t, seen["lsp_hover"]) + assert.True(t, seen["lsp_definition"]) +} + +func TestGetToolsForAgent_SingleLSPToolsetNotWrapped(t *testing.T) { + t.Parallel() + + a := &latest.AgentConfig{ + Instruction: "test", + Toolsets: []latest.Toolset{ + { + Type: "lsp", + Command: "gopls", + FileTypes: []string{".go"}, + }, + }, + } + + runConfig := config.RuntimeConfig{ + EnvProviderForTests: &noEnvProvider{}, + } + + got, warnings := getToolsForAgent(t.Context(), a, ".", &runConfig, NewDefaultToolsetRegistry()) + require.Empty(t, warnings) + + // Should have exactly one toolset that provides LSP tools. + require.Len(t, got, 1) + + allTools, err := got[0].Tools(t.Context()) + require.NoError(t, err) + + var names []string + for _, tool := range allTools { + names = append(names, tool.Name) + } + assert.Contains(t, names, "lsp_hover") + assert.Contains(t, names, "lsp_definition") +} + func TestExternalDepthContext(t *testing.T) { t.Parallel() diff --git a/pkg/tools/builtin/lsp.go b/pkg/tools/builtin/lsp.go index 3688be0b2..02412e8f2 100644 --- a/pkg/tools/builtin/lsp.go +++ b/pkg/tools/builtin/lsp.go @@ -62,6 +62,7 @@ type lspHandler struct { stdout *bufio.Reader initialized atomic.Bool requestID atomic.Int64 + done chan struct{} // closed by stop() to signal background goroutines // Configuration command string @@ -501,12 +502,22 @@ func (h *lspHandler) start(ctx context.Context) error { h.mu.Lock() defer h.mu.Unlock() + return h.startLocked(ctx) +} + +// startLocked starts the LSP server process. The caller must hold h.mu. +func (h *lspHandler) startLocked(ctx context.Context) error { if h.cmd != nil { - return errors.New("LSP server already running") + return nil } slog.Debug("Starting LSP server", "command", h.command, "args", h.args) + // Detach from the caller's context so the LSP process outlives the + // request or sub-session that triggered the start. The process is + // explicitly terminated by stop(). + ctx = context.WithoutCancel(ctx) + cmd := exec.CommandContext(ctx, h.command, h.args...) cmd.Env = append(os.Environ(), h.env...) cmd.Dir = h.workingDir @@ -533,8 +544,9 @@ func (h *lspHandler) start(ctx context.Context) error { h.cmd = cmd h.stdin = stdin h.stdout = bufio.NewReader(stdout) + h.done = make(chan struct{}) - go h.readNotifications(ctx, &stderrBuf) + go h.readNotifications(h.done, &stderrBuf) slog.Debug("LSP server started successfully") return nil @@ -550,6 +562,8 @@ func (h *lspHandler) stop(_ context.Context) error { slog.Debug("Stopping LSP server") + close(h.done) + if h.initialized.Load() { _, _ = h.sendRequestLocked("shutdown", nil) _ = h.sendNotificationLocked("exit", nil) @@ -590,12 +604,9 @@ func (h *lspHandler) ensureInitialized(ctx context.Context) error { } if h.cmd == nil { - h.mu.Unlock() - if err := h.start(ctx); err != nil { - h.mu.Lock() + if err := h.startLocked(ctx); err != nil { return fmt.Errorf("failed to start LSP server: %w", err) } - h.mu.Lock() } if !h.initialized.Load() { @@ -1455,13 +1466,13 @@ func (h *lspHandler) readMessageLocked() ([]byte, error) { return body, nil } -func (h *lspHandler) readNotifications(ctx context.Context, stderrBuf *bytes.Buffer) { +func (h *lspHandler) readNotifications(done <-chan struct{}, stderrBuf *bytes.Buffer) { ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() for { select { - case <-ctx.Done(): + case <-done: return case <-ticker.C: if stderrBuf.Len() > 0 { diff --git a/pkg/tools/builtin/lsp_multiplexer.go b/pkg/tools/builtin/lsp_multiplexer.go new file mode 100644 index 000000000..fc569c782 --- /dev/null +++ b/pkg/tools/builtin/lsp_multiplexer.go @@ -0,0 +1,172 @@ +package builtin + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/docker/cagent/pkg/tools" +) + +// LSPMultiplexer combines multiple LSP backends into a single toolset. +// It presents one set of lsp_* tools and routes each call to the appropriate +// backend based on the file extension in the tool arguments. +type LSPMultiplexer struct { + backends []LSPBackend +} + +// LSPBackend pairs a raw LSPTool (used for file-type routing) with an +// optionally-wrapped ToolSet (used for tool enumeration, so that per-toolset +// config like tool filters, instructions, or toon wrappers are respected). +type LSPBackend struct { + LSP *LSPTool + Toolset tools.ToolSet +} + +// lspRouteTarget pairs a backend with the tool handler it produced for a given tool name. +type lspRouteTarget struct { + lsp *LSPTool + handler tools.ToolHandler +} + +// Verify interface compliance. +var ( + _ tools.ToolSet = (*LSPMultiplexer)(nil) + _ tools.Startable = (*LSPMultiplexer)(nil) + _ tools.Instructable = (*LSPMultiplexer)(nil) +) + +// NewLSPMultiplexer creates a multiplexer that routes LSP tool calls +// to the appropriate backend based on file type. +func NewLSPMultiplexer(backends []LSPBackend) *LSPMultiplexer { + return &LSPMultiplexer{backends: append([]LSPBackend{}, backends...)} +} + +func (m *LSPMultiplexer) Start(ctx context.Context) error { + var started int + for _, b := range m.backends { + if err := b.LSP.Start(ctx); err != nil { + // Clean up previously started backends to avoid resource leaks. + for _, s := range m.backends[:started] { + _ = s.LSP.Stop(ctx) + } + return fmt.Errorf("starting LSP backend %q: %w", b.LSP.handler.command, err) + } + started++ + } + return nil +} + +func (m *LSPMultiplexer) Stop(ctx context.Context) error { + var errs []error + for _, b := range m.backends { + if err := b.LSP.Stop(ctx); err != nil { + errs = append(errs, fmt.Errorf("stopping LSP backend %q: %w", b.LSP.handler.command, err)) + } + } + return errors.Join(errs...) +} + +func (m *LSPMultiplexer) Instructions() string { + // Combine instructions from all backends, deduplicating identical ones. + // Typically they share the same base LSP instructions, but individual + // toolsets may override them via the Instruction config field. + var parts []string + seen := make(map[string]bool) + for _, b := range m.backends { + instr := tools.GetInstructions(b.Toolset) + if instr != "" && !seen[instr] { + seen[instr] = true + parts = append(parts, instr) + } + } + return strings.Join(parts, "\n\n") +} + +func (m *LSPMultiplexer) Tools(ctx context.Context) ([]tools.Tool, error) { + // Collect each backend's tools keyed by name. We build the union of all + // tool names (not just the first backend's) so that per-backend tool + // filters don't accidentally hide tools that other backends expose. + handlersByName := make(map[string][]lspRouteTarget) + seenTools := make(map[string]tools.Tool) // first definition wins (for schema/description) + var toolOrder []string // preserve insertion order + for _, b := range m.backends { + bTools, err := b.Toolset.Tools(ctx) + if err != nil { + return nil, fmt.Errorf("getting tools from LSP backend %q: %w", b.LSP.handler.command, err) + } + for _, t := range bTools { + handlersByName[t.Name] = append(handlersByName[t.Name], lspRouteTarget{b.LSP, t.Handler}) + if _, exists := seenTools[t.Name]; !exists { + seenTools[t.Name] = t + toolOrder = append(toolOrder, t.Name) + } + } + } + + result := make([]tools.Tool, 0, len(toolOrder)) + for _, name := range toolOrder { + t := seenTools[name] + handlers := handlersByName[name] + if name == ToolNameLSPWorkspace || name == ToolNameLSPWorkspaceSymbols { + t.Handler = broadcastLSP(handlers) + } else { + t.Handler = routeByFile(handlers) + } + result = append(result, t) + } + return result, nil +} + +// routeByFile returns a handler that extracts the "file" field from the JSON +// arguments and dispatches to the backend whose file-type filter matches. +func routeByFile(handlers []lspRouteTarget) tools.ToolHandler { + return func(ctx context.Context, tc tools.ToolCall) (*tools.ToolCallResult, error) { + var args struct { + File string `json:"file"` + } + if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil { + return tools.ResultError(fmt.Sprintf("failed to parse file argument: %s", err)), nil + } + if args.File == "" { + return tools.ResultError("file argument is required"), nil + } + for _, h := range handlers { + if h.lsp.HandlesFile(args.File) { + return h.handler(ctx, tc) + } + } + return tools.ResultError(fmt.Sprintf("no LSP server configured for file: %s", args.File)), nil + } +} + +// broadcastLSP returns a handler that calls every backend best-effort and +// merges the outputs. Individual backend failures are collected rather than +// aborting the entire operation. +func broadcastLSP(handlers []lspRouteTarget) tools.ToolHandler { + return func(ctx context.Context, tc tools.ToolCall) (*tools.ToolCallResult, error) { + var sections []string + var errs []error + for _, h := range handlers { + result, err := h.handler(ctx, tc) + if err != nil { + errs = append(errs, fmt.Errorf("backend %s: %w", h.lsp.handler.command, err)) + continue + } + if result.IsError { + sections = append(sections, fmt.Sprintf("[LSP %s] Error: %s", h.lsp.handler.command, result.Output)) + } else if result.Output != "" { + sections = append(sections, result.Output) + } + } + if len(sections) == 0 && len(errs) > 0 { + return nil, errors.Join(errs...) + } + if len(sections) == 0 { + return tools.ResultSuccess("No results"), nil + } + return tools.ResultSuccess(strings.Join(sections, "\n---\n")), nil + } +} diff --git a/pkg/tools/builtin/lsp_multiplexer_test.go b/pkg/tools/builtin/lsp_multiplexer_test.go new file mode 100644 index 000000000..6a8b0f2d2 --- /dev/null +++ b/pkg/tools/builtin/lsp_multiplexer_test.go @@ -0,0 +1,149 @@ +package builtin + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/tools" +) + +// newTestMultiplexer creates a multiplexer with a Go and Python backend. +func newTestMultiplexer() (*LSPMultiplexer, *LSPTool) { + goTool := NewLSPTool("gopls", nil, nil, "/tmp") + goTool.SetFileTypes([]string{".go", ".mod"}) + + pyTool := NewLSPTool("pyright", nil, nil, "/tmp") + pyTool.SetFileTypes([]string{".py"}) + + mux := NewLSPMultiplexer([]LSPBackend{ + {LSP: goTool, Toolset: goTool}, + {LSP: pyTool, Toolset: pyTool}, + }) + return mux, goTool +} + +// findTool returns the tool with the given name, or fails the test. +func findTool(t *testing.T, allTools []tools.Tool, name string) tools.Tool { + t.Helper() + for _, tool := range allTools { + if tool.Name == name { + return tool + } + } + t.Fatalf("tool %q not found", name) + return tools.Tool{} +} + +// callHover is a shorthand to invoke lsp_hover on a given file through the multiplexer. +func callHover(t *testing.T, mux *LSPMultiplexer, args string) *tools.ToolCallResult { + t.Helper() + allTools, err := mux.Tools(t.Context()) + require.NoError(t, err) + hover := findTool(t, allTools, ToolNameLSPHover) + tc := tools.ToolCall{Function: tools.FunctionCall{Name: ToolNameLSPHover, Arguments: args}} + result, err := hover.Handler(t.Context(), tc) + require.NoError(t, err) + return result +} + +func TestLSPMultiplexer_Tools_NoDuplicates(t *testing.T) { + t.Parallel() + + mux, goTool := newTestMultiplexer() + + allTools, err := mux.Tools(t.Context()) + require.NoError(t, err) + + // Should have the same number of tools as a single LSP backend. + singleTools, err := goTool.Tools(t.Context()) + require.NoError(t, err) + assert.Len(t, allTools, len(singleTools)) + + // No duplicate tool names. + seen := make(map[string]bool) + for _, tool := range allTools { + assert.False(t, seen[tool.Name], "duplicate tool name: %s", tool.Name) + seen[tool.Name] = true + } +} + +func TestLSPMultiplexer_RoutesToCorrectBackend(t *testing.T) { + t.Parallel() + + mux, _ := newTestMultiplexer() + + // .go → routes to gopls, .py → routes to pyright. + // Both backends are not running so they will auto-init and respond with + // some output — we just verify routing produces a non-empty response. + for _, file := range []string{"/tmp/main.go", "/tmp/app.py"} { + result := callHover(t, mux, `{"file": "`+file+`", "line": 1, "character": 1}`) + assert.NotEmpty(t, result.Output, "expected output for %s", file) + } +} + +func TestLSPMultiplexer_NoBackendForFile(t *testing.T) { + t.Parallel() + + mux, _ := newTestMultiplexer() + result := callHover(t, mux, `{"file": "/tmp/main.rs", "line": 1, "character": 1}`) + assert.True(t, result.IsError) + assert.Contains(t, result.Output, "no LSP server configured for file") +} + +func TestLSPMultiplexer_EmptyFileArgument(t *testing.T) { + t.Parallel() + + mux, _ := newTestMultiplexer() + result := callHover(t, mux, `{"line": 1, "character": 1}`) + assert.True(t, result.IsError) + assert.Contains(t, result.Output, "file argument is required") +} + +func TestLSPMultiplexer_InvalidJSON(t *testing.T) { + t.Parallel() + + mux, _ := newTestMultiplexer() + result := callHover(t, mux, `{invalid`) + assert.True(t, result.IsError) + assert.Contains(t, result.Output, "failed to parse file argument") +} + +func TestLSPMultiplexer_Instructions(t *testing.T) { + t.Parallel() + + mux, _ := newTestMultiplexer() + instructions := mux.Instructions() + assert.Contains(t, instructions, "lsp_hover") + assert.Contains(t, instructions, "Stateless") + + // Both backends share the same instructions — "Stateless" should appear only once. + assert.Equal(t, 1, strings.Count(instructions, "Stateless"), + "identical instructions should be deduplicated") +} + +func TestLSPMultiplexer_WorkspaceToolBroadcasts(t *testing.T) { + t.Parallel() + + mux, _ := newTestMultiplexer() + + allTools, err := mux.Tools(t.Context()) + require.NoError(t, err) + workspace := findTool(t, allTools, ToolNameLSPWorkspace) + + args, _ := json.Marshal(WorkspaceArgs{}) + tc := tools.ToolCall{Function: tools.FunctionCall{Name: ToolNameLSPWorkspace, Arguments: string(args)}} + result, err := workspace.Handler(t.Context(), tc) + require.NoError(t, err) + assert.NotEmpty(t, result.Output) +} + +func TestLSPMultiplexer_Stop_NotStarted(t *testing.T) { + t.Parallel() + + mux, _ := newTestMultiplexer() + require.NoError(t, mux.Stop(t.Context())) +} diff --git a/pkg/tools/codemode/codemode.go b/pkg/tools/codemode/codemode.go index 940ca464e..ec4a47a05 100644 --- a/pkg/tools/codemode/codemode.go +++ b/pkg/tools/codemode/codemode.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "log/slog" "strings" "github.com/docker/cagent/pkg/tools" @@ -104,9 +105,9 @@ func (c *codeModeTool) Tools(ctx context.Context) ([]tools.Tool, error) { func (c *codeModeTool) Start(ctx context.Context) error { for _, t := range c.toolsets { - if startable, ok := t.(tools.Startable); ok { + if startable, ok := tools.As[tools.Startable](t); ok { if err := startable.Start(ctx); err != nil { - return err + slog.Warn("Code mode: toolset start failed; continuing", "toolset", fmt.Sprintf("%T", t), "error", err) } } } @@ -118,7 +119,7 @@ func (c *codeModeTool) Stop(ctx context.Context) error { var errs []error for _, t := range c.toolsets { - if startable, ok := t.(tools.Startable); ok { + if startable, ok := tools.As[tools.Startable](t); ok { if err := startable.Stop(ctx); err != nil { errs = append(errs, err) } diff --git a/pkg/tools/codemode/codemode_test.go b/pkg/tools/codemode/codemode_test.go index 43a9c61e9..4298e4d39 100644 --- a/pkg/tools/codemode/codemode_test.go +++ b/pkg/tools/codemode/codemode_test.go @@ -213,6 +213,30 @@ func (t *testToolSet) Stop(context.Context) error { return nil } +// failingToolSet always returns an error on Start. +type failingToolSet struct { + testToolSet +} + +func (f *failingToolSet) Start(context.Context) error { + f.start++ + return assert.AnError +} + +// wrappingToolSet wraps another ToolSet without implementing Startable, +// but implements Unwrapper so tools.As can find the inner Startable. +type wrappingToolSet struct { + inner tools.ToolSet +} + +func (w *wrappingToolSet) Tools(ctx context.Context) ([]tools.Tool, error) { + return w.inner.Tools(ctx) +} + +func (w *wrappingToolSet) Unwrap() tools.ToolSet { + return w.inner +} + // TestCodeModeTool_SuccessNoToolCalls verifies that successful execution does not include tool calls. func TestCodeModeTool_SuccessNoToolCalls(t *testing.T) { tool := Wrap(&testToolSet{ @@ -373,3 +397,32 @@ func TestCodeModeTool_FailureIncludesToolArguments(t *testing.T) { assert.Equal(t, map[string]any{"value": "test123"}, scriptResult.ToolCalls[0].Arguments) assert.Equal(t, "result", scriptResult.ToolCalls[0].Result) } + +func TestCodeModeTool_Start_ContinuesOnFailure(t *testing.T) { + first := &testToolSet{} + failing := &failingToolSet{} + third := &testToolSet{} + + tool := Wrap(first, failing, third) + + startable := tool.(tools.Startable) + err := startable.Start(t.Context()) + require.NoError(t, err) + + assert.Equal(t, 1, first.start, "first toolset should be started") + assert.Equal(t, 1, failing.start, "failing toolset should have been attempted") + assert.Equal(t, 1, third.start, "third toolset should be started despite earlier failure") +} + +func TestCodeModeTool_Start_UnwrapsToFindStartable(t *testing.T) { + inner := &testToolSet{} + wrapped := &wrappingToolSet{inner: inner} + + tool := Wrap(wrapped) + + startable := tool.(tools.Startable) + err := startable.Start(t.Context()) + require.NoError(t, err) + + assert.Equal(t, 1, inner.start, "should unwrap to find the Startable inner toolset") +}