diff --git a/docs/configuration/tools/index.md b/docs/configuration/tools/index.md index 5dfaa5fdf..c5ed38802 100644 --- a/docs/configuration/tools/index.md +++ b/docs/configuration/tools/index.md @@ -101,17 +101,29 @@ toolsets: ### Memory -Persistent key-value storage backed by SQLite. Data survives across sessions, letting agents remember context, user preferences, and past decisions. +Persistent key-value storage backed by SQLite. Data survives across sessions, letting agents remember context, user preferences, and past decisions. Memories can be organized with categories and searched by keyword. + +Each agent gets its own database at `~/.cagent/memory//memory.db` by default. ```yaml toolsets: - type: memory - path: ./agent_memory.db # optional: custom database path + path: ./agent_memory.db # optional: override the default location ``` -| Property | Type | Default | Description | -| -------- | ------ | --------- | ---------------------------------------------------------------------- | -| `path` | string | automatic | Path to the SQLite database file. If omitted, uses a default location. | +| Property | Type | Default | Description | +| -------- | ------ | -------------------------------------------- | ------------------------------------ | +| `path` | string | `~/.cagent/memory//memory.db` | Path to the SQLite database file | + +| Operation | Description | +| ------------------ | ------------------------------------------------------------------- | +| `add_memory` | Store a new memory with optional category | +| `get_memories` | Retrieve all stored memories | +| `delete_memory` | Delete a specific memory by ID | +| `search_memories` | Search memories by keywords and/or category (more efficient than get_all) | +| `update_memory` | Update an existing memory's content and/or category by ID | + +Memories support an optional `category` field (e.g., `preference`, `fact`, `project`, `decision`) for organization and filtering. ### Fetch diff --git a/pkg/acp/registry.go b/pkg/acp/registry.go index 101397c46..50c8a7e7a 100644 --- a/pkg/acp/registry.go +++ b/pkg/acp/registry.go @@ -14,7 +14,7 @@ import ( func createToolsetRegistry(agent *Agent) *teamloader.ToolsetRegistry { registry := teamloader.NewDefaultToolsetRegistry() - registry.Register("filesystem", func(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { + registry.Register("filesystem", func(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { wd := runConfig.WorkingDir if wd == "" { var err error diff --git a/pkg/config/latest/validate.go b/pkg/config/latest/validate.go index 74e1742f6..d9f4cf310 100644 --- a/pkg/config/latest/validate.go +++ b/pkg/config/latest/validate.go @@ -116,9 +116,7 @@ func (t *Toolset) validate() error { case "shell": // no additional validation needed case "memory": - if t.Path == "" { - return errors.New("memory toolset requires a path to be set") - } + // path is optional; defaults to ~/.cagent/memory//memory.db case "tasks": // path defaults to ./tasks.json if not set case "mcp": diff --git a/pkg/creator/agent.go b/pkg/creator/agent.go index eac940ee8..99d7a1711 100644 --- a/pkg/creator/agent.go +++ b/pkg/creator/agent.go @@ -120,7 +120,7 @@ func createToolsetRegistry(workingDir string) *teamloader.ToolsetRegistry { } registry := teamloader.NewDefaultToolsetRegistry() - registry.Register("filesystem", func(context.Context, latest.Toolset, string, *config.RuntimeConfig) (tools.ToolSet, error) { + registry.Register("filesystem", func(context.Context, latest.Toolset, string, *config.RuntimeConfig, string) (tools.ToolSet, error) { return tracker, nil }) diff --git a/pkg/creator/agent_test.go b/pkg/creator/agent_test.go index 182388a2c..f39aea009 100644 --- a/pkg/creator/agent_test.go +++ b/pkg/creator/agent_test.go @@ -158,7 +158,7 @@ func TestFileWriteTracker(t *testing.T) { require.NotNil(t, registry) // Create the toolset through the registry - toolset, err := registry.CreateTool(ctx, latest.Toolset{Type: "filesystem"}, runConfig.WorkingDir, runConfig) + toolset, err := registry.CreateTool(ctx, latest.Toolset{Type: "filesystem"}, runConfig.WorkingDir, runConfig, "test-agent") require.NoError(t, err) require.NotNil(t, toolset) diff --git a/pkg/memory/database/database.go b/pkg/memory/database/database.go index 9ca97f061..bdcd93189 100644 --- a/pkg/memory/database/database.go +++ b/pkg/memory/database/database.go @@ -5,16 +5,22 @@ import ( "errors" ) -var ErrEmptyID = errors.New("memory ID cannot be empty") +var ( + ErrEmptyID = errors.New("memory ID cannot be empty") + ErrMemoryNotFound = errors.New("memory not found") +) type UserMemory struct { - ID string `description:"The ID of the memory"` - CreatedAt string `description:"The creation timestamp of the memory"` - Memory string `description:"The content of the memory"` + ID string `json:"id" description:"The ID of the memory"` + CreatedAt string `json:"created_at" description:"The creation timestamp of the memory"` + Memory string `json:"memory" description:"The content of the memory"` + Category string `json:"category,omitempty" description:"The category of the memory"` } type Database interface { AddMemory(ctx context.Context, memory UserMemory) error GetMemories(ctx context.Context) ([]UserMemory, error) DeleteMemory(ctx context.Context, memory UserMemory) error + SearchMemories(ctx context.Context, query, category string) ([]UserMemory, error) + UpdateMemory(ctx context.Context, memory UserMemory) error } diff --git a/pkg/memory/database/sqlite/sqlite.go b/pkg/memory/database/sqlite/sqlite.go index deebaaec7..62eaf27db 100644 --- a/pkg/memory/database/sqlite/sqlite.go +++ b/pkg/memory/database/sqlite/sqlite.go @@ -3,6 +3,8 @@ package sqlite import ( "context" "database/sql" + "fmt" + "strings" "github.com/docker/cagent/pkg/memory/database" "github.com/docker/cagent/pkg/sqliteutil" @@ -26,6 +28,14 @@ func NewMemoryDatabase(path string) (database.Database, error) { return nil, err } + // Add category column if it doesn't exist (transparent migration) + if _, err := db.ExecContext(context.Background(), "ALTER TABLE memories ADD COLUMN category TEXT DEFAULT ''"); err != nil { + if !strings.Contains(err.Error(), "duplicate column name") { + db.Close() + return nil, fmt.Errorf("memory database migration failed: %w", err) + } + } + return &MemoryDatabase{db: db}, nil } @@ -33,13 +43,13 @@ func (m *MemoryDatabase) AddMemory(ctx context.Context, memory database.UserMemo if memory.ID == "" { return database.ErrEmptyID } - _, err := m.db.ExecContext(ctx, "INSERT INTO memories (id, created_at, memory) VALUES (?, ?, ?)", - memory.ID, memory.CreatedAt, memory.Memory) + _, err := m.db.ExecContext(ctx, "INSERT INTO memories (id, created_at, memory, category) VALUES (?, ?, ?, ?)", + memory.ID, memory.CreatedAt, memory.Memory, memory.Category) return err } func (m *MemoryDatabase) GetMemories(ctx context.Context) ([]database.UserMemory, error) { - rows, err := m.db.QueryContext(ctx, "SELECT id, created_at, memory FROM memories") + rows, err := m.db.QueryContext(ctx, "SELECT id, created_at, memory, COALESCE(category, '') FROM memories") if err != nil { return nil, err } @@ -48,7 +58,7 @@ func (m *MemoryDatabase) GetMemories(ctx context.Context) ([]database.UserMemory var memories []database.UserMemory for rows.Next() { var memory database.UserMemory - err := rows.Scan(&memory.ID, &memory.CreatedAt, &memory.Memory) + err := rows.Scan(&memory.ID, &memory.CreatedAt, &memory.Memory, &memory.Category) if err != nil { return nil, err } @@ -66,3 +76,73 @@ func (m *MemoryDatabase) DeleteMemory(ctx context.Context, memory database.UserM _, err := m.db.ExecContext(ctx, "DELETE FROM memories WHERE id = ?", memory.ID) return err } + +func (m *MemoryDatabase) SearchMemories(ctx context.Context, query, category string) ([]database.UserMemory, error) { + var conditions []string + var args []any + + if query != "" { + words := strings.Fields(query) + for _, word := range words { + conditions = append(conditions, "LOWER(memory) LIKE LOWER(?) ESCAPE '\\'") + escaped := strings.ReplaceAll(word, `\`, `\\`) + escaped = strings.ReplaceAll(escaped, `%`, `\%`) + escaped = strings.ReplaceAll(escaped, `_`, `\_`) + args = append(args, "%"+escaped+"%") + } + } + + if category != "" { + conditions = append(conditions, "LOWER(category) = LOWER(?)") + args = append(args, category) + } + + stmt := "SELECT id, created_at, memory, COALESCE(category, '') FROM memories" + if len(conditions) > 0 { + stmt += " WHERE " + strings.Join(conditions, " AND ") + } + + rows, err := m.db.QueryContext(ctx, stmt, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var memories []database.UserMemory + for rows.Next() { + var memory database.UserMemory + err := rows.Scan(&memory.ID, &memory.CreatedAt, &memory.Memory, &memory.Category) + if err != nil { + return nil, err + } + memories = append(memories, memory) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return memories, nil +} + +func (m *MemoryDatabase) UpdateMemory(ctx context.Context, memory database.UserMemory) error { + if memory.ID == "" { + return database.ErrEmptyID + } + + result, err := m.db.ExecContext(ctx, "UPDATE memories SET memory = ?, category = ? WHERE id = ?", + memory.Memory, memory.Category, memory.ID) + if err != nil { + return err + } + + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return fmt.Errorf("%w: %s", database.ErrMemoryNotFound, memory.ID) + } + + return nil +} diff --git a/pkg/memory/database/sqlite/sqlite_test.go b/pkg/memory/database/sqlite/sqlite_test.go index 1fb642a26..1b3b95b23 100644 --- a/pkg/memory/database/sqlite/sqlite_test.go +++ b/pkg/memory/database/sqlite/sqlite_test.go @@ -65,6 +65,25 @@ func TestAddMemory(t *testing.T) { require.Error(t, err, "Adding memory with empty ID should fail") } +func TestAddMemoryWithCategory(t *testing.T) { + db := setupTestDB(t) + + memory := database.UserMemory{ + ID: "cat-1", + CreatedAt: time.Now().Format(time.RFC3339), + Memory: "User prefers dark mode", + Category: "preference", + } + + err := db.AddMemory(t.Context(), memory) + require.NoError(t, err) + + memories, err := db.GetMemories(t.Context()) + require.NoError(t, err) + require.Len(t, memories, 1) + assert.Equal(t, "preference", memories[0].Category) +} + func TestGetMemories(t *testing.T) { db := setupTestDB(t) @@ -139,6 +158,149 @@ func TestDeleteMemory(t *testing.T) { require.NoError(t, err, "Deleting non-existent memory should not return an error") } +func TestSearchMemories(t *testing.T) { + db := setupTestDB(t) + ctx := t.Context() + + testMemories := []database.UserMemory{ + {ID: "1", CreatedAt: time.Now().Format(time.RFC3339), Memory: "User prefers dark mode", Category: "preference"}, + {ID: "2", CreatedAt: time.Now().Format(time.RFC3339), Memory: "Project uses Go and React", Category: "project"}, + {ID: "3", CreatedAt: time.Now().Format(time.RFC3339), Memory: "User likes Go for backend", Category: "preference"}, + {ID: "4", CreatedAt: time.Now().Format(time.RFC3339), Memory: "Deploy to AWS us-east-1", Category: "project"}, + } + for _, m := range testMemories { + require.NoError(t, db.AddMemory(ctx, m)) + } + + t.Run("single keyword", func(t *testing.T) { + results, err := db.SearchMemories(ctx, "Go", "") + require.NoError(t, err) + assert.Len(t, results, 2) + }) + + t.Run("multi-word AND", func(t *testing.T) { + results, err := db.SearchMemories(ctx, "Go backend", "") + require.NoError(t, err) + assert.Len(t, results, 1) + assert.Equal(t, "3", results[0].ID) + }) + + t.Run("category filter only", func(t *testing.T) { + results, err := db.SearchMemories(ctx, "", "preference") + require.NoError(t, err) + assert.Len(t, results, 2) + }) + + t.Run("keyword plus category", func(t *testing.T) { + results, err := db.SearchMemories(ctx, "Go", "project") + require.NoError(t, err) + assert.Len(t, results, 1) + assert.Equal(t, "2", results[0].ID) + }) + + t.Run("empty query returns all", func(t *testing.T) { + results, err := db.SearchMemories(ctx, "", "") + require.NoError(t, err) + assert.Len(t, results, 4) + }) + + t.Run("no matches", func(t *testing.T) { + results, err := db.SearchMemories(ctx, "nonexistent", "") + require.NoError(t, err) + assert.Empty(t, results) + }) + + t.Run("case insensitive", func(t *testing.T) { + results, err := db.SearchMemories(ctx, "go", "") + require.NoError(t, err) + assert.Len(t, results, 2) + }) + + t.Run("case insensitive category", func(t *testing.T) { + results, err := db.SearchMemories(ctx, "", "PREFERENCE") + require.NoError(t, err) + assert.Len(t, results, 2) + }) +} + +func TestUpdateMemory(t *testing.T) { + db := setupTestDB(t) + ctx := t.Context() + + memory := database.UserMemory{ + ID: "upd-1", + CreatedAt: time.Now().Format(time.RFC3339), + Memory: "Original content", + Category: "fact", + } + require.NoError(t, db.AddMemory(ctx, memory)) + + t.Run("update content and category", func(t *testing.T) { + err := db.UpdateMemory(ctx, database.UserMemory{ + ID: "upd-1", + Memory: "Updated content", + Category: "decision", + }) + require.NoError(t, err) + + memories, err := db.GetMemories(ctx) + require.NoError(t, err) + require.Len(t, memories, 1) + assert.Equal(t, "Updated content", memories[0].Memory) + assert.Equal(t, "decision", memories[0].Category) + // CreatedAt should be preserved + assert.Equal(t, memory.CreatedAt, memories[0].CreatedAt) + }) + + t.Run("not found", func(t *testing.T) { + err := db.UpdateMemory(ctx, database.UserMemory{ + ID: "nonexistent", + Memory: "something", + }) + require.Error(t, err) + assert.ErrorIs(t, err, database.ErrMemoryNotFound) + }) + + t.Run("empty ID", func(t *testing.T) { + err := db.UpdateMemory(ctx, database.UserMemory{ + ID: "", + Memory: "something", + }) + require.Error(t, err) + assert.ErrorIs(t, err, database.ErrEmptyID) + }) +} + +func TestMigrationAddsCategory(t *testing.T) { + tmpFile := t.TempDir() + "/migrate.db" + + // Create a DB with the old schema (no category column) + db1, err := NewMemoryDatabase(tmpFile) + require.NoError(t, err) + memDB1 := db1.(*MemoryDatabase) + + // Add a memory (which now includes category column from migration) + err = db1.AddMemory(t.Context(), database.UserMemory{ + ID: "old-1", + CreatedAt: time.Now().Format(time.RFC3339), + Memory: "Old memory without category", + }) + require.NoError(t, err) + memDB1.db.Close() + + // Reopen - migration should be idempotent + db2, err := NewMemoryDatabase(tmpFile) + require.NoError(t, err) + memDB2 := db2.(*MemoryDatabase) + defer memDB2.db.Close() + + memories, err := db2.GetMemories(t.Context()) + require.NoError(t, err) + require.Len(t, memories, 1) + assert.Equal(t, "Old memory without category", memories[0].Memory) + assert.Empty(t, memories[0].Category) +} + func TestDatabaseOperationsWithCanceledContext(t *testing.T) { db := setupTestDB(t) @@ -159,6 +321,12 @@ func TestDatabaseOperationsWithCanceledContext(t *testing.T) { err = db.DeleteMemory(ctx, memory) require.Error(t, err, "DeleteMemory should fail with canceled context") + + _, err = db.SearchMemories(ctx, "test", "") + require.Error(t, err, "SearchMemories should fail with canceled context") + + err = db.UpdateMemory(ctx, memory) + require.Error(t, err, "UpdateMemory should fail with canceled context") } func TestDatabaseWithMultipleInstances(t *testing.T) { diff --git a/pkg/teamloader/registry.go b/pkg/teamloader/registry.go index 0028a9837..22ae10b9a 100644 --- a/pkg/teamloader/registry.go +++ b/pkg/teamloader/registry.go @@ -14,6 +14,7 @@ import ( "github.com/docker/cagent/pkg/js" "github.com/docker/cagent/pkg/memory/database/sqlite" "github.com/docker/cagent/pkg/path" + "github.com/docker/cagent/pkg/paths" "github.com/docker/cagent/pkg/toolinstall" "github.com/docker/cagent/pkg/tools" "github.com/docker/cagent/pkg/tools/a2a" @@ -21,8 +22,9 @@ import ( "github.com/docker/cagent/pkg/tools/mcp" ) -// ToolsetCreator is a function that creates a toolset based on the provided configuration -type ToolsetCreator func(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) +// ToolsetCreator is a function that creates a toolset based on the provided configuration. +// configName identifies the agent config file (e.g. "memory_agent" from "memory_agent.yaml"). +type ToolsetCreator func(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, configName string) (tools.ToolSet, error) // ToolsetRegistry manages the registration of toolset creators by type type ToolsetRegistry struct { @@ -48,12 +50,12 @@ func (r *ToolsetRegistry) Get(toolsetType string) (ToolsetCreator, bool) { } // CreateTool creates a toolset using the registered creator for the given type -func (r *ToolsetRegistry) CreateTool(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { +func (r *ToolsetRegistry) CreateTool(ctx context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, agentName string) (tools.ToolSet, error) { creator, ok := r.Get(toolset.Type) if !ok { return nil, fmt.Errorf("unknown toolset type: %s", toolset.Type) } - return creator(ctx, toolset, parentDir, runConfig) + return creator(ctx, toolset, parentDir, runConfig, agentName) } func NewDefaultToolsetRegistry() *ToolsetRegistry { @@ -77,14 +79,14 @@ func NewDefaultToolsetRegistry() *ToolsetRegistry { return r } -func createTodoTool(_ context.Context, toolset latest.Toolset, _ string, _ *config.RuntimeConfig) (tools.ToolSet, error) { +func createTodoTool(_ context.Context, toolset latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { if toolset.Shared { return builtin.NewSharedTodoTool(), nil } return builtin.NewTodoTool(), nil } -func createTasksTool(_ context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { +func createTasksTool(_ context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { toolsetPath := toolset.Path if toolsetPath == "" { toolsetPath = "tasks.json" @@ -110,20 +112,33 @@ func createTasksTool(_ context.Context, toolset latest.Toolset, parentDir string return builtin.NewTasksTool(validatedPath), nil } -func createMemoryTool(_ context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { - var memoryPath string - if filepath.IsAbs(toolset.Path) { - memoryPath = "" - } else if wd := runConfig.WorkingDir; wd != "" { - memoryPath = wd +func createMemoryTool(_ context.Context, toolset latest.Toolset, parentDir string, runConfig *config.RuntimeConfig, configName string) (tools.ToolSet, error) { + var validatedMemoryPath string + + if toolset.Path != "" { + // Explicit path provided - resolve relative to working dir or parent dir + var basePath string + if filepath.IsAbs(toolset.Path) { + basePath = "" + } else if wd := runConfig.WorkingDir; wd != "" { + basePath = wd + } else { + basePath = parentDir + } + + var err error + validatedMemoryPath, err = path.ValidatePathInDirectory(toolset.Path, basePath) + if err != nil { + return nil, fmt.Errorf("invalid memory database path: %w", err) + } } else { - memoryPath = parentDir + // Default: ~/.cagent/memory//memory.db + if configName == "" { + configName = "default" + } + validatedMemoryPath = filepath.Join(paths.GetDataDir(), "memory", configName, "memory.db") } - validatedMemoryPath, err := path.ValidatePathInDirectory(toolset.Path, memoryPath) - if err != nil { - return nil, fmt.Errorf("invalid memory database path: %w", err) - } if err := os.MkdirAll(filepath.Dir(validatedMemoryPath), 0o700); err != nil { return nil, fmt.Errorf("failed to create memory database directory: %w", err) } @@ -136,11 +151,11 @@ func createMemoryTool(_ context.Context, toolset latest.Toolset, parentDir strin return builtin.NewMemoryToolWithPath(db, validatedMemoryPath), nil } -func createThinkTool(_ context.Context, _ latest.Toolset, _ string, _ *config.RuntimeConfig) (tools.ToolSet, error) { +func createThinkTool(_ context.Context, _ latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { return builtin.NewThinkTool(), nil } -func createShellTool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { +func createShellTool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), runConfig.EnvProvider()) if err != nil { return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err) @@ -150,7 +165,7 @@ func createShellTool(ctx context.Context, toolset latest.Toolset, _ string, runC return builtin.NewShellTool(env, runConfig), nil } -func createScriptTool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { +func createScriptTool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { if len(toolset.Shell) == 0 { return nil, fmt.Errorf("shell is required for script toolset") } @@ -163,7 +178,7 @@ func createScriptTool(ctx context.Context, toolset latest.Toolset, _ string, run return builtin.NewScriptShellTool(toolset.Shell, env) } -func createFilesystemTool(_ context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { +func createFilesystemTool(_ context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { wd := runConfig.WorkingDir if wd == "" { var err error @@ -197,7 +212,7 @@ func createFilesystemTool(_ context.Context, toolset latest.Toolset, _ string, r return builtin.NewFilesystemTool(wd, opts...), nil } -func createAPITool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { +func createAPITool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { if toolset.APIConfig.Endpoint == "" { return nil, fmt.Errorf("api tool requires an endpoint in api_config") } @@ -209,7 +224,7 @@ func createAPITool(ctx context.Context, toolset latest.Toolset, _ string, runCon return builtin.NewAPITool(toolset.APIConfig, expander), nil } -func createFetchTool(_ context.Context, toolset latest.Toolset, _ string, _ *config.RuntimeConfig) (tools.ToolSet, error) { +func createFetchTool(_ context.Context, toolset latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { var opts []builtin.FetchToolOption if toolset.Timeout > 0 { timeout := time.Duration(toolset.Timeout) * time.Second @@ -218,7 +233,7 @@ func createFetchTool(_ context.Context, toolset latest.Toolset, _ string, _ *con return builtin.NewFetchTool(opts...), nil } -func createMCPTool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { +func createMCPTool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { envProvider := runConfig.EnvProvider() switch { @@ -280,7 +295,7 @@ func createMCPTool(ctx context.Context, toolset latest.Toolset, _ string, runCon } } -func createA2ATool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { +func createA2ATool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { expander := js.NewJsExpander(runConfig.EnvProvider()) headers := expander.ExpandMap(ctx, toolset.Headers) @@ -288,7 +303,7 @@ func createA2ATool(ctx context.Context, toolset latest.Toolset, _ string, runCon return a2a.NewToolset(toolset.Name, toolset.URL, headers), nil } -func createLSPTool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { +func createLSPTool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { // Auto-install missing command binary if needed resolvedCommand, err := toolinstall.EnsureCommand(ctx, toolset.Command, toolset.Version) if err != nil { @@ -312,11 +327,11 @@ func createLSPTool(ctx context.Context, toolset latest.Toolset, _ string, runCon return tool, nil } -func createUserPromptTool(_ context.Context, _ latest.Toolset, _ string, _ *config.RuntimeConfig) (tools.ToolSet, error) { +func createUserPromptTool(_ context.Context, _ latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { return builtin.NewUserPromptTool(), nil } -func createOpenAPITool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { +func createOpenAPITool(ctx context.Context, toolset latest.Toolset, _ string, runConfig *config.RuntimeConfig, _ string) (tools.ToolSet, error) { expander := js.NewJsExpander(runConfig.EnvProvider()) specURL := expander.Expand(ctx, toolset.URL, nil) @@ -325,7 +340,7 @@ func createOpenAPITool(ctx context.Context, toolset latest.Toolset, _ string, ru return builtin.NewOpenAPITool(specURL, headers), nil } -func createModelPickerTool(_ context.Context, toolset latest.Toolset, _ string, _ *config.RuntimeConfig) (tools.ToolSet, error) { +func createModelPickerTool(_ context.Context, toolset latest.Toolset, _ string, _ *config.RuntimeConfig, _ string) (tools.ToolSet, error) { if len(toolset.Models) == 0 { return nil, fmt.Errorf("model_picker toolset requires at least one model") } diff --git a/pkg/teamloader/registry_test.go b/pkg/teamloader/registry_test.go index 2e40d1cc4..350f81fbe 100644 --- a/pkg/teamloader/registry_test.go +++ b/pkg/teamloader/registry_test.go @@ -22,7 +22,7 @@ func TestCreateShellTool(t *testing.T) { EnvProviderForTests: environment.NewOsEnvProvider(), } - tool, err := registry.CreateTool(t.Context(), toolset, ".", runConfig) + tool, err := registry.CreateTool(t.Context(), toolset, ".", runConfig, "test-agent") require.NoError(t, err) require.NotNil(t, tool) } diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index ba53b5e5c..e11660dfc 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -3,9 +3,12 @@ package teamloader import ( "cmp" "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" "log/slog" + "path/filepath" "strings" "sync" @@ -135,6 +138,7 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c // Create RAG managers parentDir := cmp.Or(agentSource.ParentDir(), runConfig.WorkingDir) + configName := configNameFromSource(agentSource.Name()) ragManagers, err := rag.NewManagers(ctx, cfg, rag.ManagersBuildConfig{ ParentDir: parentDir, ModelsGateway: runConfig.ModelsGateway, @@ -214,7 +218,7 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c ) } - agentTools, warnings := getToolsForAgent(ctx, &agentConfig, parentDir, runConfig, loadOpts.toolsetRegistry) + agentTools, warnings := getToolsForAgent(ctx, &agentConfig, parentDir, runConfig, loadOpts.toolsetRegistry, configName) if len(warnings) > 0 { opts = append(opts, agent.WithLoadTimeWarnings(warnings)) } @@ -421,7 +425,7 @@ func getFallbackModelsForAgent(ctx context.Context, cfg *latest.Config, a *lates } // getToolsForAgent returns the tool definitions for an agent based on its configuration -func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir string, runConfig *config.RuntimeConfig, registry *ToolsetRegistry) ([]tools.ToolSet, []string) { +func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir string, runConfig *config.RuntimeConfig, registry *ToolsetRegistry, configName string) ([]tools.ToolSet, []string) { var ( toolSets []tools.ToolSet warnings []string @@ -432,7 +436,7 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri for i := range a.Toolsets { toolset := a.Toolsets[i] - tool, err := registry.CreateTool(ctx, toolset, parentDir, runConfig) + tool, err := registry.CreateTool(ctx, toolset, parentDir, runConfig, configName) if err != nil { // Collect error but continue loading other toolsets slog.Warn("Toolset configuration failed; skipping", "type", toolset.Type, "ref", toolset.Ref, "command", toolset.Command, "error", err) @@ -480,6 +484,24 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri return toolSets, warnings } +// configNameFromSource extracts a clean config name from a source name. +// The result is "-" where basename comes from the file name +// (e.g. "memory_agent" from "/path/to/memory_agent.yaml") and hash is a short +// SHA-256 of the full source name to prevent collisions between identically +// named configs in different directories. +func configNameFromSource(sourceName string) string { + base := filepath.Base(sourceName) + ext := filepath.Ext(base) + if ext != "" { + base = base[:len(base)-len(ext)] + } + if base == "" || base == "." || base == ".." { + base = "default" + } + h := sha256.Sum256([]byte(sourceName)) + return base + "-" + hex.EncodeToString(h[:4]) +} + // resolveAgentRefs resolves a list of agent references to agent instances. // References that match a locally-defined agent name are looked up directly. // References that are external (OCI or URL) are loaded on-demand and cached diff --git a/pkg/teamloader/teamloader_test.go b/pkg/teamloader/teamloader_test.go index 05a85c3a8..736092aac 100644 --- a/pkg/teamloader/teamloader_test.go +++ b/pkg/teamloader/teamloader_test.go @@ -66,7 +66,7 @@ func TestGetToolsForAgent_ContinuesOnCreateToolError(t *testing.T) { EnvProviderForTests: &noEnvProvider{}, } - got, warnings := getToolsForAgent(t.Context(), a, ".", &runConfig, NewToolsetRegistry()) + got, warnings := getToolsForAgent(t.Context(), a, ".", &runConfig, NewToolsetRegistry(), "test-config") require.Empty(t, got) require.NotEmpty(t, warnings) diff --git a/pkg/tools/builtin/memory.go b/pkg/tools/builtin/memory.go index 6d9e60326..23b5d9d7c 100644 --- a/pkg/tools/builtin/memory.go +++ b/pkg/tools/builtin/memory.go @@ -11,15 +11,19 @@ import ( ) const ( - ToolNameAddMemory = "add_memory" - ToolNameGetMemories = "get_memories" - ToolNameDeleteMemory = "delete_memory" + ToolNameAddMemory = "add_memory" + ToolNameGetMemories = "get_memories" + ToolNameDeleteMemory = "delete_memory" + ToolNameSearchMemories = "search_memories" + ToolNameUpdateMemory = "update_memory" ) type DB interface { AddMemory(ctx context.Context, memory database.UserMemory) error GetMemories(ctx context.Context) ([]database.UserMemory, error) DeleteMemory(ctx context.Context, memory database.UserMemory) error + SearchMemories(ctx context.Context, query, category string) ([]database.UserMemory, error) + UpdateMemory(ctx context.Context, memory database.UserMemory) error } type MemoryTool struct { @@ -58,18 +62,46 @@ func (t *MemoryTool) Describe() string { } type AddMemoryArgs struct { - Memory string `json:"memory" jsonschema:"The memory content to store"` + Memory string `json:"memory" jsonschema:"The memory content to store"` + Category string `json:"category,omitempty" jsonschema:"Optional category to organize the memory (e.g. preference, fact, project)"` } type DeleteMemoryArgs struct { ID string `json:"id" jsonschema:"The ID of the memory to delete"` } +type SearchMemoriesArgs struct { + Query string `json:"query,omitempty" jsonschema:"Keywords to search for in memory content (space-separated, all must match)"` + Category string `json:"category,omitempty" jsonschema:"Optional category to filter by"` +} + +type UpdateMemoryArgs struct { + ID string `json:"id" jsonschema:"The ID of the memory to update"` + Memory string `json:"memory" jsonschema:"The new memory content"` + Category string `json:"category,omitempty" jsonschema:"Optional new category for the memory"` +} + func (t *MemoryTool) Instructions() string { return `## Using the memory tool -Before taking any action or responding, use "get_memories" to recall stored information about the user. -Use the memory tool generously to remember things about the user. Do not mention using this tool.` +Before taking any action or responding, check stored memories for relevant context. +Use the memory tool generously to remember things about the user. Do not mention using this tool. + +### When to remember +- User preferences, corrections, and explicit requests to remember something +- Key facts, decisions, and context that may be useful in future conversations +- Project-specific conventions and patterns + +### Categories +Organize memories with a category when adding or updating (e.g. "preference", "fact", "project", "decision"). + +### Searching vs getting all +- Use "search_memories" with keywords and/or a category to find specific memories efficiently. +- Use "get_memories" only when you need a full dump of all stored memories. + +### Updating vs creating +- Use "update_memory" to edit an existing memory by ID instead of deleting and re-adding. +- Use "add_memory" only for genuinely new information.` } func (t *MemoryTool) Tools(context.Context) ([]tools.Tool, error) { @@ -107,6 +139,29 @@ func (t *MemoryTool) Tools(context.Context) ([]tools.Tool, error) { Title: "Delete Memory", }, }, + { + Name: ToolNameSearchMemories, + Category: "memory", + Description: "Search memories by keywords and/or category. More efficient than retrieving all memories.", + Parameters: tools.MustSchemaFor[SearchMemoriesArgs](), + OutputSchema: tools.MustSchemaFor[[]database.UserMemory](), + Handler: tools.NewHandler(t.handleSearchMemories), + Annotations: tools.ToolAnnotations{ + ReadOnlyHint: true, + Title: "Search Memories", + }, + }, + { + Name: ToolNameUpdateMemory, + Category: "memory", + Description: "Update an existing memory's content and/or category by ID", + Parameters: tools.MustSchemaFor[UpdateMemoryArgs](), + OutputSchema: tools.MustSchemaFor[string](), + Handler: tools.NewHandler(t.handleUpdateMemory), + Annotations: tools.ToolAnnotations{ + Title: "Update Memory", + }, + }, }, nil } @@ -115,6 +170,7 @@ func (t *MemoryTool) handleAddMemory(ctx context.Context, args AddMemoryArgs) (* ID: fmt.Sprintf("%d", time.Now().UnixNano()), CreatedAt: time.Now().Format(time.RFC3339), Memory: args.Memory, + Category: args.Category, } if err := t.db.AddMemory(ctx, memory); err != nil { @@ -149,3 +205,31 @@ func (t *MemoryTool) handleDeleteMemory(ctx context.Context, args DeleteMemoryAr return tools.ResultSuccess(fmt.Sprintf("Memory with ID %s deleted successfully", args.ID)), nil } + +func (t *MemoryTool) handleSearchMemories(ctx context.Context, args SearchMemoriesArgs) (*tools.ToolCallResult, error) { + memories, err := t.db.SearchMemories(ctx, args.Query, args.Category) + if err != nil { + return nil, fmt.Errorf("failed to search memories: %w", err) + } + + result, err := json.Marshal(memories) + if err != nil { + return nil, fmt.Errorf("failed to marshal memories: %w", err) + } + + return tools.ResultSuccess(string(result)), nil +} + +func (t *MemoryTool) handleUpdateMemory(ctx context.Context, args UpdateMemoryArgs) (*tools.ToolCallResult, error) { + memory := database.UserMemory{ + ID: args.ID, + Memory: args.Memory, + Category: args.Category, + } + + if err := t.db.UpdateMemory(ctx, memory); err != nil { + return nil, fmt.Errorf("failed to update memory: %w", err) + } + + return tools.ResultSuccess(fmt.Sprintf("Memory with ID %s updated successfully", args.ID)), nil +} diff --git a/pkg/tools/builtin/memory_test.go b/pkg/tools/builtin/memory_test.go index 42d8cdd13..07ae35ee1 100644 --- a/pkg/tools/builtin/memory_test.go +++ b/pkg/tools/builtin/memory_test.go @@ -34,12 +34,25 @@ func (m *MockDB) DeleteMemory(ctx context.Context, memory database.UserMemory) e return args.Error(0) } +func (m *MockDB) SearchMemories(ctx context.Context, query, category string) ([]database.UserMemory, error) { + args := m.Called(ctx, query, category) + return args.Get(0).([]database.UserMemory), args.Error(1) +} + +func (m *MockDB) UpdateMemory(ctx context.Context, memory database.UserMemory) error { + args := m.Called(ctx, memory) + return args.Error(0) +} + func TestMemoryTool_Instructions(t *testing.T) { manager := new(MockDB) tool := NewMemoryTool(manager) instructions := tool.Instructions() assert.Contains(t, instructions, "Using the memory tool") + assert.Contains(t, instructions, "search_memories") + assert.Contains(t, instructions, "update_memory") + assert.Contains(t, instructions, "Categories") } func TestMemoryTool_DisplayNames(t *testing.T) { @@ -71,6 +84,23 @@ func TestMemoryTool_HandleAddMemory(t *testing.T) { manager.AssertExpectations(t) } +func TestMemoryTool_HandleAddMemoryWithCategory(t *testing.T) { + manager := new(MockDB) + tool := NewMemoryTool(manager) + + manager.On("AddMemory", mock.Anything, mock.MatchedBy(func(memory database.UserMemory) bool { + return memory.Memory == "prefers dark mode" && memory.Category == "preference" + })).Return(nil) + + result, err := tool.handleAddMemory(t.Context(), AddMemoryArgs{ + Memory: "prefers dark mode", + Category: "preference", + }) + require.NoError(t, err) + assert.Contains(t, result.Output, "Memory added successfully") + manager.AssertExpectations(t) +} + func TestMemoryTool_HandleGetMemories(t *testing.T) { manager := new(MockDB) tool := NewMemoryTool(manager) @@ -118,6 +148,61 @@ func TestMemoryTool_HandleDeleteMemory(t *testing.T) { manager.AssertExpectations(t) } +func TestMemoryTool_HandleSearchMemories(t *testing.T) { + manager := new(MockDB) + tool := NewMemoryTool(manager) + + memories := []database.UserMemory{ + { + ID: "1", + CreatedAt: time.Now().Format(time.RFC3339), + Memory: "User prefers dark mode", + Category: "preference", + }, + } + manager.On("SearchMemories", mock.Anything, "dark mode", "preference").Return(memories, nil) + + result, err := tool.handleSearchMemories(t.Context(), SearchMemoriesArgs{ + Query: "dark mode", + Category: "preference", + }) + require.NoError(t, err) + + var returnedMemories []database.UserMemory + err = json.Unmarshal([]byte(result.Output), &returnedMemories) + require.NoError(t, err) + + assert.Len(t, returnedMemories, 1) + assert.Equal(t, "User prefers dark mode", returnedMemories[0].Memory) + manager.AssertExpectations(t) +} + +func TestMemoryTool_HandleUpdateMemory(t *testing.T) { + manager := new(MockDB) + tool := NewMemoryTool(manager) + + manager.On("UpdateMemory", mock.Anything, mock.MatchedBy(func(memory database.UserMemory) bool { + return memory.ID == "42" && memory.Memory == "updated content" && memory.Category == "fact" + })).Return(nil) + + result, err := tool.handleUpdateMemory(t.Context(), UpdateMemoryArgs{ + ID: "42", + Memory: "updated content", + Category: "fact", + }) + require.NoError(t, err) + assert.Contains(t, result.Output, "Memory with ID 42 updated successfully") + manager.AssertExpectations(t) +} + +func TestMemoryTool_ToolCount(t *testing.T) { + tool := NewMemoryTool(nil) + + allTools, err := tool.Tools(t.Context()) + require.NoError(t, err) + assert.Len(t, allTools, 5, "Should have 5 tools: add, get, delete, search, update") +} + func TestMemoryTool_OutputSchema(t *testing.T) { tool := NewMemoryTool(nil)