Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions e2e/binary/binary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func TestExecMissingKeys(t *testing.T) {
require.Contains(t, res.Stderr, "OPENAI_API_KEY")
})
}

func TestAutoComplete(t *testing.T) {
t.Run("cli plugin auto-complete docker-agent", func(t *testing.T) {
res, err := Exec(binDir+"/docker-agent", "__complete", "ser")
Expand Down
25 changes: 23 additions & 2 deletions pkg/teamloader/teamloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,9 @@ func getFallbackModelsForAgent(ctx context.Context, cfg *latest.Config, a *lates
// getToolsForAgent returns the tool definitions for an agent based on its configuration
func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir string, runConfig *config.RuntimeConfig, registry *ToolsetRegistry, configName string) ([]tools.ToolSet, []string) {
var (
toolSets []tools.ToolSet
warnings []string
toolSets []tools.ToolSet
warnings []string
lspBackends []builtin.LSPBackend
)

deferredToolset := builtin.NewDeferredToolset()
Expand Down Expand Up @@ -460,9 +461,29 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri
}
}

// Collect LSP backends for multiplexing when there are multiple.
// Instead of adding them individually (which causes duplicate tool names),
// they are combined into a single LSPMultiplexer after the loop.
if toolset.Type == "lsp" {
if lspTool, ok := tool.(*builtin.LSPTool); ok {
Comment thread
dgageot marked this conversation as resolved.
lspBackends = append(lspBackends, builtin.LSPBackend{LSP: lspTool, Toolset: wrapped})
continue
}
slog.Warn("Toolset configured as type 'lsp' but registry returned unexpected type; treating as regular toolset",
"type", fmt.Sprintf("%T", tool), "command", toolset.Command)
}

toolSets = append(toolSets, wrapped)
}

// Merge LSP backends: if there are multiple, combine them into a single
// multiplexer so the LLM sees one set of lsp_* tools instead of duplicates.
if len(lspBackends) > 1 {
toolSets = append(toolSets, builtin.NewLSPMultiplexer(lspBackends))
} else if len(lspBackends) == 1 {
toolSets = append(toolSets, lspBackends[0].Toolset)
}

if deferredToolset.HasSources() {
toolSets = append(toolSets, deferredToolset)
}
Expand Down
82 changes: 82 additions & 0 deletions pkg/teamloader/teamloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,88 @@ agents:
assert.Equal(t, expected, rootAgent.AddPromptFiles())
}

func TestGetToolsForAgent_MultipleLSPToolsetsAreCombined(t *testing.T) {
t.Parallel()

a := &latest.AgentConfig{
Instruction: "test",
Toolsets: []latest.Toolset{
{
Type: "lsp",
Command: "gopls",
Version: "golang/tools@v0.21.0",
FileTypes: []string{".go"},
},
{
Type: "lsp",
Command: "gopls",
Version: "golang/tools@v0.21.0",
FileTypes: []string{".mod"},
},
},
}

runConfig := config.RuntimeConfig{
EnvProviderForTests: &noEnvProvider{},
}

got, warnings := getToolsForAgent(t.Context(), a, ".", &runConfig, NewDefaultToolsetRegistry(), "test-config")
require.Empty(t, warnings)

// Should have exactly one toolset (the multiplexer)
require.Len(t, got, 1)

// Verify that we get no duplicate tool names
allTools, err := got[0].Tools(t.Context())
require.NoError(t, err)

seen := make(map[string]bool)
for _, tool := range allTools {
assert.False(t, seen[tool.Name], "duplicate tool name: %s", tool.Name)
seen[tool.Name] = true
}

// Verify LSP tools are present
assert.True(t, seen["lsp_hover"])
assert.True(t, seen["lsp_definition"])
}

func TestGetToolsForAgent_SingleLSPToolsetNotWrapped(t *testing.T) {
t.Parallel()

a := &latest.AgentConfig{
Instruction: "test",
Toolsets: []latest.Toolset{
{
Type: "lsp",
Command: "gopls",
Version: "golang/tools@v0.21.0",
FileTypes: []string{".go"},
},
},
}

runConfig := config.RuntimeConfig{
EnvProviderForTests: &noEnvProvider{},
}

got, warnings := getToolsForAgent(t.Context(), a, ".", &runConfig, NewDefaultToolsetRegistry(), "test-config")
require.Empty(t, warnings)

// Should have exactly one toolset that provides LSP tools.
require.Len(t, got, 1)

allTools, err := got[0].Tools(t.Context())
require.NoError(t, err)

var names []string
for _, tool := range allTools {
names = append(names, tool.Name)
}
assert.Contains(t, names, "lsp_hover")
assert.Contains(t, names, "lsp_definition")
}

func TestExternalDepthContext(t *testing.T) {
t.Parallel()

Expand Down
162 changes: 58 additions & 104 deletions pkg/toolinstall/registry.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package toolinstall

import (
"bytes"
"context"
"encoding/json"
"fmt"
Expand All @@ -10,9 +11,9 @@ import (
"path/filepath"
"strings"
"sync"
"time"

"github.com/goccy/go-yaml"
"github.com/natefinch/atomic"
)

// githubToken returns a GitHub personal access token from the environment,
Expand All @@ -35,8 +36,8 @@ func setGitHubAuth(req *http.Request) {
}

const (
registryBaseURL = "https://raw.githubusercontent.com/aquaproj/aqua-registry/main"
registryCacheTTL = 24 * time.Hour
registryBaseURL = "https://raw.githubusercontent.com/aquaproj/aqua-registry/main"
registryIndexFile = "registry.yaml"
)

// Package represents a parsed aqua registry package definition.
Expand Down Expand Up @@ -104,11 +105,6 @@ type Registry struct {
httpClient *http.Client
baseURL string
cacheDir string

// In-memory cache for the parsed registry index, populated once via sync.Once.
indexOnce sync.Once
cachedIndex *registryIndex
indexErr error
}

var (
Expand Down Expand Up @@ -157,8 +153,8 @@ func (r *Registry) LookupByName(ctx context.Context, name string) (*Package, err
}
}

// Fallback: fetch the per-package YAML file.
data, err := r.fetchCached(ctx, fmt.Sprintf("pkgs/%s/%s/registry.yaml", owner, repo), 0)
// Fallback: fetch the per-package YAML file directly (no caching).
data, err := r.getBody(ctx, r.baseURL+"/"+fmt.Sprintf("pkgs/%s/%s/registry.yaml", owner, repo))
if err != nil {
return nil, fmt.Errorf("fetching package %s: %w", name, err)
}
Expand Down Expand Up @@ -213,88 +209,32 @@ func providesCommand(pkg *Package, command string) bool {
return false
}

// fetchIndex fetches and parses the full registry index, with caching.
// The parsed result is cached in memory so that repeated calls within the
// same Registry instance skip both the HTTP fetch and YAML deserialization.
// fetchIndex fetches and parses the full registry index.
// The raw YAML is cached to disk; on fetch failure the cached copy is used.
// The YAML is re-parsed on every call — there is no in-memory cache.
func (r *Registry) fetchIndex(ctx context.Context) (*registryIndex, error) {
r.indexOnce.Do(func() {
var data []byte
data, r.indexErr = r.fetchCached(ctx, "registry.yaml", registryCacheTTL)
if r.indexErr != nil {
return
}

var index registryIndex
if err := yaml.Unmarshal(data, &index); err != nil {
r.indexErr = fmt.Errorf("parsing registry index: %w", err)
return
}
r.cachedIndex = &index
})

return r.cachedIndex, r.indexErr
}

// fetchCached fetches a file from the registry, using a local file cache.
// A ttl of 0 means the cache never expires.
func (r *Registry) fetchCached(ctx context.Context, path string, ttl time.Duration) ([]byte, error) {
cachePath := filepath.Join(r.cacheDir, path)

// Return cached data if still fresh.
if info, err := os.Stat(cachePath); err == nil {
if ttl == 0 || time.Since(info.ModTime()) < ttl {
return os.ReadFile(cachePath)
}
}

// Fetch from remote.
url := r.baseURL + "/" + path
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
if data, readErr := os.ReadFile(cachePath); readErr == nil {
return data, nil
}
return nil, fmt.Errorf("creating request for %s: %w", url, err)
}
setGitHubAuth(req)
cachePath := filepath.Join(r.cacheDir, registryIndexFile)

resp, err := r.httpClient.Do(req)
data, err := r.getBody(ctx, r.baseURL+"/"+registryIndexFile)
if err != nil {
if data, readErr := os.ReadFile(cachePath); readErr == nil {
return data, nil // stale cache beats no data
// Fallback to stale disk cache.
if cached, readErr := os.ReadFile(cachePath); readErr == nil {
data = cached
} else {
return nil, err
}
return nil, fmt.Errorf("fetching %s: %w", url, err)
} else {
// Best-effort: persist to disk for future fallback.
_ = os.MkdirAll(filepath.Dir(cachePath), 0o755)
_ = atomic.WriteFile(cachePath, bytes.NewReader(data))
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
if data, readErr := os.ReadFile(cachePath); readErr == nil {
return data, nil
}
return nil, fmt.Errorf("fetching %s: HTTP %d", url, resp.StatusCode)
var index registryIndex
if err := yaml.Unmarshal(data, &index); err != nil {
return nil, fmt.Errorf("parsing registry index: %w", err)
}

data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("reading response from %s: %w", url, err)
}

// Write to cache atomically (best-effort): write to a temp file in the
// same directory, then rename. This avoids races when multiple goroutines
// fetch the same path concurrently.
if err := os.MkdirAll(filepath.Dir(cachePath), 0o755); err == nil {
if tmpFile, tmpErr := os.CreateTemp(filepath.Dir(cachePath), ".cache-*.tmp"); tmpErr == nil {
if _, writeErr := tmpFile.Write(data); writeErr == nil {
tmpFile.Close()
_ = os.Rename(tmpFile.Name(), cachePath)
} else {
tmpFile.Close()
_ = os.Remove(tmpFile.Name())
}
}
}

return data, nil
return &index, nil
}

// githubRelease represents the relevant fields from the GitHub releases API.
Expand All @@ -307,7 +247,7 @@ func (r *Registry) latestVersion(ctx context.Context, owner, repo string) (strin
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", owner, repo)

var release githubRelease
if err := r.fetchGitHubJSON(ctx, url, &release); err != nil {
if err := r.getJSON(ctx, url, &release); err != nil {
return "", fmt.Errorf("fetching latest release for %s/%s: %w", owner, repo, err)
}

Expand All @@ -324,7 +264,7 @@ func (r *Registry) latestVersionFiltered(ctx context.Context, owner, repo, tagPr
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases?per_page=50", owner, repo)

var releases []githubRelease
if err := r.fetchGitHubJSON(ctx, url, &releases); err != nil {
if err := r.getJSON(ctx, url, &releases); err != nil {
return "", fmt.Errorf("fetching releases for %s/%s: %w", owner, repo, err)
}

Expand All @@ -337,45 +277,59 @@ func (r *Registry) latestVersionFiltered(ctx context.Context, owner, repo, tagPr
return "", fmt.Errorf("no release found for %s/%s with tag prefix %q", owner, repo, tagPrefix)
}

// fetchGitHubJSON fetches a GitHub API endpoint and decodes the JSON response.
func (r *Registry) fetchGitHubJSON(ctx context.Context, url string, target any) error {
// doGet performs an authenticated GET request and returns the response.
// The caller is responsible for closing the response body.
func (r *Registry) doGet(ctx context.Context, url string, headers map[string]string) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
return err
return nil, err
}
for k, v := range headers {
req.Header.Set(k, v)
}
req.Header.Set("Accept", "application/vnd.github+json")
setGitHubAuth(req)

resp, err := r.httpClient.Do(req)
if err != nil {
return err
return nil, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("HTTP %d", resp.StatusCode)
resp.Body.Close()
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}

return json.NewDecoder(resp.Body).Decode(target)
return resp, nil
}

// download opens an HTTP connection to the given URL and returns the
// response body as an io.ReadCloser. The caller is responsible for closing it.
func (r *Registry) download(ctx context.Context, url string) (io.ReadCloser, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
// getBody performs a GET request and returns the full response body.
func (r *Registry) getBody(ctx context.Context, url string) ([]byte, error) {
resp, err := r.doGet(ctx, url, nil)
if err != nil {
return nil, err
}
setGitHubAuth(req)
defer resp.Body.Close()

resp, err := r.httpClient.Do(req)
return io.ReadAll(resp.Body)
}

// getJSON performs a GET request and decodes the JSON response into target.
func (r *Registry) getJSON(ctx context.Context, url string, target any) error {
resp, err := r.doGet(ctx, url, map[string]string{"Accept": "application/vnd.github+json"})
if err != nil {
return nil, err
return err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
return json.NewDecoder(resp.Body).Decode(target)
}

// download opens an HTTP connection to the given URL and returns the
// response body as an io.ReadCloser. The caller is responsible for closing it.
func (r *Registry) download(ctx context.Context, url string) (io.ReadCloser, error) {
resp, err := r.doGet(ctx, url, nil)
if err != nil {
return nil, err
}

return resp.Body, nil
Expand Down
Loading