From 79de65a8f5492779cef965051fd18672f8072be4 Mon Sep 17 00:00:00 2001 From: Stefan Vetter Date: Fri, 3 Apr 2026 14:12:18 +0200 Subject: [PATCH 1/3] feat: implement per-mode MCP server/tool filtering Adds granular per-mode control over which MCP servers and tools are available, replacing the previous all-or-nothing 'mcp' group behavior. Schema (packages/types/src/mode.ts): - Add mcpServerFilterSchema with disabled/allowedTools/disabledTools - Add mcpDefaultPolicy ('allow'|'deny') for deny-by-default support - Add mcpGroupOptionsSchema extending group options for MCP tuples - Add superRefine validation rejecting MCP options on non-mcp groups Core filtering (src/utils/mcp-filter.ts): - getMcpFilterForMode() resolves MCP config for a mode - isMcpServerAllowedForMode() checks server access with policy - isMcpToolAllowedForMode() checks tool access with allowlist/blocklist - Inlined getGroupName to avoid vscode import chain (ISSUE-16) Prerequisite fix (src/utils/mcp-name.ts): - Fix normalizeForComparison regex: /[-\s]+/g -> /[-\s]/g (ISSUE-10) - Add toLowerCase() for case-insensitive matching Prompt-level filtering: - filter-tools-for-mode.ts: Filter MCP tools from system prompt - build-tools.ts: Filter MCP tools for Gemini native function calling - ISSUE-19: Document native tools remain unfiltered for Gemini Execution-time guards: - validateToolUse.ts: Replace blanket MCP allow with filter checks - Server-level guard for use_mcp_tool/access_mcp_resource - Tool-level guard for dynamic mcp--server--tool names - ISSUE-21: Tool-level check for use_mcp_tool when tool_name available - presentAssistantMessage.ts mcp_tool_use: Add filter using cline.taskMode - ISSUE-17: validateToolUse call uses cline.taskMode (frozen at task start) instead of state.mode (live UI mode) - ISSUE-20: Remove dead ?? defaultModeSlug fallback Defense-in-depth: - UseMcpToolTool.execute(): Server + tool level filter before execution - AccessMcpResourceTool.execute(): Server level filter before execution - FLAG-E: Document 10-second TTL cache in CustomModesManager UI fix (webview-ui): - ModesView.tsx: Cache group tuple options on toggle-off, restore on toggle-on to prevent MCP config data loss (ISSUE-9/ISSUE-13) - Extract pure caching functions to groupOptionsCache.ts Tests: 89 new tests across 11 files, all passing --- .../src/__tests__/mcp-filter-schema.test.ts | 88 +++++++ packages/types/src/mode.ts | 92 ++++++- .../presentAssistantMessage.ts | 44 +++- .../prompts/tools/filter-tools-for-mode.ts | 19 +- src/core/task/build-tools.ts | 7 +- src/core/tools/UseMcpToolTool.ts | 42 +++ .../AccessMcpResourceTool-mcp-filter.test.ts | 130 +++++++++ .../UseMcpToolTool-mcp-filter.test.ts | 184 +++++++++++++ .../tools/__tests__/useMcpToolTool.spec.ts | 12 +- src/core/tools/accessMcpResourceTool.ts | 16 ++ src/core/tools/validateToolUse.ts | 60 ++++- .../presentAssistantMessage-issue17.test.ts | 178 +++++++++++++ .../presentAssistantMessage-mcp.test.ts | 210 +++++++++++++++ .../tools/filter-tools-for-mode-mcp.test.ts | 184 +++++++++++++ src/tests/core/task/build-tools-mcp.test.ts | 190 ++++++++++++++ .../core/tools/validateToolUse-mcp.test.ts | 168 ++++++++++++ src/tests/utils/mcp-filter.test.ts | 248 ++++++++++++++++++ src/tests/utils/mcp-name.test.ts | 31 +++ src/utils/mcp-filter.ts | 184 +++++++++++++ src/utils/mcp-name.ts | 13 +- .../__tests__/ModesView-groupChange.spec.tsx | 148 +++++++++++ webview-ui/src/components/modes/ModesView.tsx | 18 +- .../src/components/modes/groupOptionsCache.ts | 62 +++++ .../components/modes/useGroupOptionsCache.ts | 33 +++ 24 files changed, 2335 insertions(+), 26 deletions(-) create mode 100644 packages/types/src/__tests__/mcp-filter-schema.test.ts create mode 100644 src/core/tools/__tests__/AccessMcpResourceTool-mcp-filter.test.ts create mode 100644 src/core/tools/__tests__/UseMcpToolTool-mcp-filter.test.ts create mode 100644 src/tests/core/assistant-message/presentAssistantMessage-issue17.test.ts create mode 100644 src/tests/core/assistant-message/presentAssistantMessage-mcp.test.ts create mode 100644 src/tests/core/prompts/tools/filter-tools-for-mode-mcp.test.ts create mode 100644 src/tests/core/task/build-tools-mcp.test.ts create mode 100644 src/tests/core/tools/validateToolUse-mcp.test.ts create mode 100644 src/tests/utils/mcp-filter.test.ts create mode 100644 src/tests/utils/mcp-name.test.ts create mode 100644 src/utils/mcp-filter.ts create mode 100644 webview-ui/src/__tests__/ModesView-groupChange.spec.tsx create mode 100644 webview-ui/src/components/modes/groupOptionsCache.ts create mode 100644 webview-ui/src/components/modes/useGroupOptionsCache.ts diff --git a/packages/types/src/__tests__/mcp-filter-schema.test.ts b/packages/types/src/__tests__/mcp-filter-schema.test.ts new file mode 100644 index 00000000000..445420a3a48 --- /dev/null +++ b/packages/types/src/__tests__/mcp-filter-schema.test.ts @@ -0,0 +1,88 @@ +// npx vitest run src/__tests__/mcp-filter-schema.test.ts + +import { mcpServerFilterSchema, groupEntryArraySchema } from "../mode.js" + +describe("mcpServerFilterSchema", () => { + it("validates a valid filter with disabled: true", () => { + const result = mcpServerFilterSchema.safeParse({ disabled: true }) + expect(result.success).toBe(true) + }) + + it("validates a filter with allowedTools array", () => { + const result = mcpServerFilterSchema.safeParse({ + allowedTools: ["tool-a", "tool-b"], + }) + expect(result.success).toBe(true) + }) + + it("validates a filter with disabledTools array", () => { + const result = mcpServerFilterSchema.safeParse({ + disabledTools: ["tool-x"], + }) + expect(result.success).toBe(true) + }) + + it("rejects invalid shapes (wrong types)", () => { + const result = mcpServerFilterSchema.safeParse({ + disabled: "yes", + }) + expect(result.success).toBe(false) + }) + + it("rejects invalid shapes (allowedTools not array of strings)", () => { + const result = mcpServerFilterSchema.safeParse({ + allowedTools: [123, true], + }) + expect(result.success).toBe(false) + }) + + it("rejects completely invalid shape", () => { + const result = mcpServerFilterSchema.safeParse("not-an-object") + expect(result.success).toBe(false) + }) +}) + +describe("rawGroupEntryArraySchema with MCP filtering", () => { + it("rejects mcpServers on non-mcp groups", () => { + const result = groupEntryArraySchema.safeParse([["read", { mcpServers: {} }]]) + expect(result.success).toBe(false) + }) + + it("allows mcpServers on the mcp group", () => { + const result = groupEntryArraySchema.safeParse([["mcp", { mcpServers: { "server-name": { disabled: true } } }]]) + expect(result.success).toBe(true) + }) + + it("allows mcpDefaultPolicy on the mcp group", () => { + const result = groupEntryArraySchema.safeParse([["mcp", { mcpDefaultPolicy: "allow" }]]) + expect(result.success).toBe(true) + }) + + it("rejects mcpDefaultPolicy on non-mcp groups", () => { + const result = groupEntryArraySchema.safeParse([["edit", { mcpDefaultPolicy: "allow" }]]) + expect(result.success).toBe(false) + }) + + it("mcpDefaultPolicy only accepts allow or deny", () => { + const validAllow = groupEntryArraySchema.safeParse([["mcp", { mcpDefaultPolicy: "allow" }]]) + expect(validAllow.success).toBe(true) + + const validDeny = groupEntryArraySchema.safeParse([["mcp", { mcpDefaultPolicy: "deny" }]]) + expect(validDeny.success).toBe(true) + + const invalid = groupEntryArraySchema.safeParse([["mcp", { mcpDefaultPolicy: "block" }]]) + expect(invalid.success).toBe(false) + }) + + it("still allows plain string group entries", () => { + const result = groupEntryArraySchema.safeParse(["read", "edit", "mcp"]) + expect(result.success).toBe(true) + }) + + it("still allows tuple entries with standard options", () => { + const result = groupEntryArraySchema.safeParse([ + ["edit", { fileRegex: "\\.md$", description: "Markdown only" }], + ]) + expect(result.success).toBe(true) + }) +}) diff --git a/packages/types/src/mode.ts b/packages/types/src/mode.ts index f981ba7bf9a..3b729580505 100644 --- a/packages/types/src/mode.ts +++ b/packages/types/src/mode.ts @@ -2,6 +2,26 @@ import { z } from "zod" import { deprecatedToolGroups, toolGroupsSchema } from "./tool.js" +/** + * MCP Server Filter + */ + +export const mcpServerFilterSchema = z.object({ + disabled: z.boolean().optional(), + allowedTools: z.array(z.string()).optional(), + disabledTools: z.array(z.string()).optional(), +}) + +export type McpServerFilter = z.infer + +/** + * MCP Default Policy + */ + +export const mcpDefaultPolicySchema = z.enum(["allow", "deny"]) + +export type McpDefaultPolicy = z.infer + /** * GroupOptions */ @@ -30,11 +50,31 @@ export const groupOptionsSchema = z.object({ export type GroupOptions = z.infer +/** + * MCP Group Options - extends GroupOptions with MCP-specific fields + */ + +export const mcpGroupOptionsSchema = groupOptionsSchema.extend({ + mcpServers: z.record(z.string(), mcpServerFilterSchema).optional(), + mcpDefaultPolicy: mcpDefaultPolicySchema.optional(), +}) + +export type McpGroupOptions = z.infer + +/** + * Non-MCP tool groups for use in tuple entries with standard options. + */ +const nonMcpToolGroupSchema = toolGroupsSchema.exclude(["mcp"]) + /** * GroupEntry */ -export const groupEntrySchema = z.union([toolGroupsSchema, z.tuple([toolGroupsSchema, groupOptionsSchema])]) +export const groupEntrySchema = z.union([ + toolGroupsSchema, + z.tuple([nonMcpToolGroupSchema, groupOptionsSchema]), + z.tuple([z.literal("mcp"), mcpGroupOptionsSchema]), +]) export type GroupEntry = z.infer @@ -56,6 +96,23 @@ function isDeprecatedGroupEntry(entry: unknown): boolean { return false } +/** + * Checks if a raw group entry tuple contains MCP-specific options. + */ +function hasMcpOptions(entry: unknown): boolean { + if (!Array.isArray(entry) || entry.length < 2) { + return false + } + + const opts = entry[1] + + if (typeof opts !== "object" || opts === null) { + return false + } + + return "mcpServers" in opts || "mcpDefaultPolicy" in opts +} + /** * Raw schema for validating group entries after deprecated groups are stripped. */ @@ -83,15 +140,40 @@ const rawGroupEntryArraySchema = z.array(groupEntrySchema).refine( * tool groups (e.g., "browser") before validation, ensuring backward compatibility * with older user configs. * + * Also validates that MCP-specific options (mcpServers, mcpDefaultPolicy) + * only appear on the "mcp" group via superRefine on raw input. + * * The type assertion to `z.ZodType` is * required because `z.preprocess` erases the input type to `unknown`, which * propagates through `modeConfigSchema → rooCodeSettingsSchema → createRunSchema` * and breaks `zodResolver` generic inference in downstream consumers (e.g., web-evals). */ -export const groupEntryArraySchema = z.preprocess((val) => { - if (!Array.isArray(val)) return val - return val.filter((entry) => !isDeprecatedGroupEntry(entry)) -}, rawGroupEntryArraySchema) as z.ZodType +export const groupEntryArraySchema = z.preprocess( + (val) => { + if (!Array.isArray(val)) return val + return val.filter((entry) => !isDeprecatedGroupEntry(entry)) + }, + z + .array(z.any()) + .superRefine((entries, ctx) => { + for (let i = 0; i < entries.length; i++) { + const entry = entries[i] + + if (hasMcpOptions(entry)) { + const groupName = Array.isArray(entry) ? entry[0] : entry + + if (groupName !== "mcp") { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: 'mcpServers and mcpDefaultPolicy are only allowed on the "mcp" group', + path: [i], + }) + } + } + } + }) + .pipe(rawGroupEntryArraySchema), +) as z.ZodType export const modeConfigSchema = z.object({ slug: z.string().regex(/^[a-zA-Z0-9-]+$/, "Slug must contain only letters numbers and dashes"), diff --git a/src/core/assistant-message/presentAssistantMessage.ts b/src/core/assistant-message/presentAssistantMessage.ts index 7f5862be154..1275e148a4c 100644 --- a/src/core/assistant-message/presentAssistantMessage.ts +++ b/src/core/assistant-message/presentAssistantMessage.ts @@ -40,6 +40,25 @@ import { codebaseSearchTool } from "../tools/CodebaseSearchTool" import { formatResponse } from "../prompts/responses" import { sanitizeToolUseId } from "../../utils/tool-id" +import { isMcpToolAllowedForMode } from "../../utils/mcp-filter" +import type { ModeConfig } from "@roo-code/types" + +/** + * Step 5b: Check whether an MCP tool call is allowed for a given mode. + * + * This is a thin wrapper around isMcpToolAllowedForMode that makes the + * intent explicit and is easily testable in isolation. Called from the + * mcp_tool_use case in presentAssistantMessage with cline.taskMode + * (the mode frozen at task start), NOT the current UI mode. + */ +export function shouldAllowMcpToolUse( + serverName: string, + toolName: string, + modeSlug: string, + customModes?: ModeConfig[], +): boolean { + return isMcpToolAllowedForMode(serverName, toolName, modeSlug, customModes) +} /** * Processes and presents assistant message content to the user interface. @@ -250,6 +269,26 @@ export async function presentAssistantMessage(cline: Task) { } } + // Step 5b: MCP tool filtering using frozen task mode + if (!mcpBlock.partial) { + const taskCustomModes = await cline.providerRef.deref()?.customModesManager.getCustomModes() + // FLAG-E: getCustomModes() uses a 10-second TTL cache, no disk I/O on each call + if (!shouldAllowMcpToolUse(resolvedServerName, mcpBlock.toolName, cline.taskMode, taskCustomModes)) { + const errorMsg = + "MCP tool " + + resolvedServerName + + "/" + + mcpBlock.toolName + + " is not allowed in the current mode (" + + cline.taskMode + + ")." + await cline.say("error", errorMsg) + pushToolResult(formatResponse.toolError(errorMsg)) + cline.didRejectTool = true + break + } + } + // Execute the MCP tool using the same handler as use_mcp_tool // Create a synthetic ToolUse block that the useMcpToolTool can handle const syntheticToolUse: ToolUse<"use_mcp_tool"> = { @@ -594,9 +633,12 @@ export async function presentAssistantMessage(cline: Task) { {} as Record, ) ?? {} + // ISSUE-17: Use cline.taskMode (frozen at task start) instead of state.mode (live UI mode). + // This ensures permissions are locked to the mode active when the task started. + // ISSUE-20: No ?? defaultModeSlug fallback needed — cline.taskMode throws if uninitialized. validateToolUse( block.name as ToolName, - mode ?? defaultModeSlug, + cline.taskMode, customModes ?? [], toolRequirements, block.params, diff --git a/src/core/prompts/tools/filter-tools-for-mode.ts b/src/core/prompts/tools/filter-tools-for-mode.ts index fdd41e7e330..c195f571a56 100644 --- a/src/core/prompts/tools/filter-tools-for-mode.ts +++ b/src/core/prompts/tools/filter-tools-for-mode.ts @@ -6,6 +6,8 @@ import { defaultModeSlug } from "../../../shared/modes" import type { CodeIndexManager } from "../../../services/code-index/manager" import type { McpHub } from "../../../services/mcp/McpHub" import { isToolAllowedForMode } from "../../../core/tools/validateToolUse" +import { parseMcpToolName } from "../../../utils/mcp-name" +import { isMcpToolAllowedForMode } from "../../../utils/mcp-filter" /** * Reverse lookup map - maps alias name to canonical tool name. @@ -452,5 +454,20 @@ export function filterMcpToolsForMode( experiments ?? {}, ) - return isMcpAllowed ? mcpTools : [] + if (!isMcpAllowed) { + return [] + } + + // Apply per-server / per-tool MCP filtering + return mcpTools.filter((tool) => { + if (!("function" in tool) || !tool.function) { + return true + } + const parsed = parseMcpToolName(tool.function.name) + if (!parsed) { + // Not an MCP tool name — pass through unchanged + return true + } + return isMcpToolAllowedForMode(parsed.serverName, parsed.toolName, modeSlug, customModes) + }) } diff --git a/src/core/task/build-tools.ts b/src/core/task/build-tools.ts index c32d8f6f9b2..d46a8b08015 100644 --- a/src/core/task/build-tools.ts +++ b/src/core/task/build-tools.ts @@ -147,8 +147,11 @@ export async function buildNativeToolsArrayWithRestrictions(options: BuildToolsO // If includeAllToolsWithRestrictions is true, return ALL tools but provide // allowed names based on mode filtering if (includeAllToolsWithRestrictions) { - // Combine ALL tools (unfiltered native + all MCP + custom) - const allTools = [...nativeTools, ...mcpTools, ...nativeCustomTools] + // ISSUE-19: Native tools remain unfiltered in Gemini's tool list. + // Gemini uses allowedFunctionNames to restrict callable tools at the API level. + // MCP tools are filtered here for consistency with the prompt-level filter, + // since they are dynamically generated and not covered by allowedFunctionNames. + const allTools = [...nativeTools, ...filteredMcpTools, ...nativeCustomTools] // Extract names of tools that are allowed based on mode filtering. // Resolve any alias names to canonical names to ensure consistency with allTools diff --git a/src/core/tools/UseMcpToolTool.ts b/src/core/tools/UseMcpToolTool.ts index 7cbc09bfd7b..9424a1150af 100644 --- a/src/core/tools/UseMcpToolTool.ts +++ b/src/core/tools/UseMcpToolTool.ts @@ -5,6 +5,7 @@ import { formatResponse } from "../prompts/responses" import { t } from "../../i18n" import type { ToolUse } from "../../shared/tools" import { toolNamesMatch } from "../../utils/mcp-name" +import { isMcpServerAllowedForMode, isMcpToolAllowedForMode } from "../../utils/mcp-filter" import { BaseTool, ToolCallbacks } from "./BaseTool" @@ -38,6 +39,47 @@ export class UseMcpToolTool extends BaseTool<"use_mcp_tool"> { const { serverName, toolName, parsedArguments } = validation + // Defense-in-depth: check MCP server/tool filtering for the current mode. + // FLAG-E: 10-second TTL cache, no disk I/O per call + const customModes = await task.providerRef.deref()?.customModesManager?.getCustomModes() + if (!isMcpServerAllowedForMode(serverName, task.taskMode, customModes)) { + task.consecutiveMistakeCount++ + task.recordToolError("use_mcp_tool") + await task.say("error", 'MCP server "' + serverName + '" is not allowed in ' + task.taskMode + " mode") + pushToolResult( + formatResponse.toolError( + 'MCP server "' + serverName + '" is not allowed in ' + task.taskMode + " mode", + ), + ) + return + } + if (!isMcpToolAllowedForMode(serverName, toolName, task.taskMode, customModes)) { + task.consecutiveMistakeCount++ + task.recordToolError("use_mcp_tool") + await task.say( + "error", + 'MCP tool "' + + toolName + + '" on server "' + + serverName + + '" is not allowed in ' + + task.taskMode + + " mode", + ) + pushToolResult( + formatResponse.toolError( + 'MCP tool "' + + toolName + + '" on server "' + + serverName + + '" is not allowed in ' + + task.taskMode + + " mode", + ), + ) + return + } + // Validate that the tool exists on the server const toolValidation = await this.validateToolExists(task, serverName, toolName, pushToolResult) if (!toolValidation.isValid) { diff --git a/src/core/tools/__tests__/AccessMcpResourceTool-mcp-filter.test.ts b/src/core/tools/__tests__/AccessMcpResourceTool-mcp-filter.test.ts new file mode 100644 index 00000000000..46c6df7ded0 --- /dev/null +++ b/src/core/tools/__tests__/AccessMcpResourceTool-mcp-filter.test.ts @@ -0,0 +1,130 @@ +// npx vitest run core/tools/__tests__/AccessMcpResourceTool-mcp-filter.test.ts + +import { accessMcpResourceTool } from "../accessMcpResourceTool" +import { Task } from "../../task/Task" + +// Mock mcp-filter functions +vi.mock("../../../utils/mcp-filter", () => ({ + isMcpServerAllowedForMode: vi.fn().mockReturnValue(true), + isMcpToolAllowedForMode: vi.fn().mockReturnValue(true), +})) + +import { isMcpServerAllowedForMode } from "../../../utils/mcp-filter" + +// Mock formatResponse +vi.mock("../../prompts/responses", () => ({ + formatResponse: { + toolResult: vi.fn((result: string) => "Tool result: " + result), + toolError: vi.fn((error: string) => "Tool error: " + error), + toolDenied: vi.fn(() => "Tool denied"), + }, +})) + +describe("AccessMcpResourceTool - MCP filter defense-in-depth", () => { + let mockTask: Partial + let mockAskApproval: ReturnType + let mockHandleError: ReturnType + let mockPushToolResult: ReturnType + let mockProviderRef: any + + beforeEach(() => { + vi.clearAllMocks() + + mockAskApproval = vi.fn().mockResolvedValue(true) + mockHandleError = vi.fn() + mockPushToolResult = vi.fn() + + mockProviderRef = { + deref: vi.fn().mockReturnValue({ + customModesManager: { + getCustomModes: vi.fn().mockResolvedValue([]), + }, + getMcpHub: vi.fn().mockReturnValue({ + readResource: vi.fn().mockResolvedValue({ + contents: [{ text: "resource content" }], + }), + }), + }), + } + + mockTask = { + consecutiveMistakeCount: 0, + recordToolError: vi.fn(), + sayAndCreateMissingParamError: vi.fn(), + say: vi.fn(), + ask: vi.fn(), + providerRef: mockProviderRef, + taskMode: "code", + } + + vi.mocked(isMcpServerAllowedForMode).mockReturnValue(true) + }) + + it("should proceed when server is allowed", async () => { + vi.mocked(isMcpServerAllowedForMode).mockReturnValue(true) + + await accessMcpResourceTool.execute( + { + server_name: "test-server", + uri: "test://resource", + }, + mockTask as Task, + { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + }, + ) + + // Should have proceeded to approval + expect(mockAskApproval).toHaveBeenCalled() + }) + + it("should block execution when server is disabled", async () => { + vi.mocked(isMcpServerAllowedForMode).mockReturnValue(false) + + await accessMcpResourceTool.execute( + { + server_name: "blocked-server", + uri: "test://resource", + }, + mockTask as Task, + { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + }, + ) + + // Should NOT proceed to approval + expect(mockAskApproval).not.toHaveBeenCalled() + // Should push an error result containing the server name + expect(mockPushToolResult).toHaveBeenCalled() + const pushArg = mockPushToolResult.mock.calls[0][0] as string + expect(pushArg).toContain("not allowed") + expect(pushArg).toContain("blocked-server") + }) + + it("should use task.taskMode for the mode check", async () => { + Object.defineProperty(mockTask, "taskMode", { + get: () => "architect", + configurable: true, + }) + vi.mocked(isMcpServerAllowedForMode).mockReturnValue(true) + + await accessMcpResourceTool.execute( + { + server_name: "test-server", + uri: "test://resource", + }, + mockTask as Task, + { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + }, + ) + + expect(isMcpServerAllowedForMode).toHaveBeenCalledWith("test-server", "architect", expect.anything()) + }) +}) diff --git a/src/core/tools/__tests__/UseMcpToolTool-mcp-filter.test.ts b/src/core/tools/__tests__/UseMcpToolTool-mcp-filter.test.ts new file mode 100644 index 00000000000..906b2f658c1 --- /dev/null +++ b/src/core/tools/__tests__/UseMcpToolTool-mcp-filter.test.ts @@ -0,0 +1,184 @@ +// npx vitest run core/tools/__tests__/UseMcpToolTool-mcp-filter.test.ts + +import { useMcpToolTool } from "../UseMcpToolTool" +import { Task } from "../../task/Task" + +// Mock mcp-filter functions +vi.mock("../../../utils/mcp-filter", () => ({ + isMcpServerAllowedForMode: vi.fn().mockReturnValue(true), + isMcpToolAllowedForMode: vi.fn().mockReturnValue(true), +})) + +import { isMcpServerAllowedForMode, isMcpToolAllowedForMode } from "../../../utils/mcp-filter" + +// Mock formatResponse +vi.mock("../../prompts/responses", () => ({ + formatResponse: { + toolResult: vi.fn((result: string) => "Tool result: " + result), + toolError: vi.fn((error: string) => "Tool error: " + error), + toolDenied: vi.fn(() => "Tool denied"), + invalidMcpToolArgumentError: vi.fn((server: string, tool: string) => "Invalid args for " + server + ":" + tool), + unknownMcpToolError: vi.fn( + (server: string, tool: string, available: string[]) => "Tool '" + tool + "' not found on '" + server + "'", + ), + unknownMcpServerError: vi.fn((server: string, available: string[]) => "Server '" + server + "' not configured"), + }, +})) + +vi.mock("../../../i18n", () => ({ + t: vi.fn((key: string) => key), +})) + +describe("UseMcpToolTool - MCP filter defense-in-depth", () => { + let mockTask: Partial + let mockAskApproval: ReturnType + let mockHandleError: ReturnType + let mockPushToolResult: ReturnType + let mockProviderRef: any + + beforeEach(() => { + vi.clearAllMocks() + + mockAskApproval = vi.fn().mockResolvedValue(true) + mockHandleError = vi.fn() + mockPushToolResult = vi.fn() + + mockProviderRef = { + deref: vi.fn().mockReturnValue({ + customModesManager: { + getCustomModes: vi.fn().mockResolvedValue([]), + }, + getMcpHub: vi.fn().mockReturnValue({ + callTool: vi.fn().mockResolvedValue({ + content: [{ type: "text", text: "result" }], + isError: false, + }), + getAllServers: vi.fn().mockReturnValue([ + { + name: "test-server", + tools: [{ name: "test-tool", enabledForPrompt: true }], + }, + ]), + }), + postMessageToWebview: vi.fn(), + }), + } + + mockTask = { + consecutiveMistakeCount: 0, + recordToolError: vi.fn(), + sayAndCreateMissingParamError: vi.fn(), + say: vi.fn(), + ask: vi.fn(), + lastMessageTs: 123456789, + providerRef: mockProviderRef, + taskMode: "code", + didToolFailInCurrentTurn: false, + } + + // Default: allow everything + vi.mocked(isMcpServerAllowedForMode).mockReturnValue(true) + vi.mocked(isMcpToolAllowedForMode).mockReturnValue(true) + }) + + it("should proceed when server is allowed", async () => { + vi.mocked(isMcpServerAllowedForMode).mockReturnValue(true) + vi.mocked(isMcpToolAllowedForMode).mockReturnValue(true) + + await useMcpToolTool.execute( + { + server_name: "test-server", + tool_name: "test-tool", + arguments: { key: "value" }, + }, + mockTask as Task, + { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + }, + ) + + // Should NOT have been blocked — askApproval should have been called + expect(mockAskApproval).toHaveBeenCalled() + }) + + it("should block execution when server is disabled", async () => { + vi.mocked(isMcpServerAllowedForMode).mockReturnValue(false) + + await useMcpToolTool.execute( + { + server_name: "blocked-server", + tool_name: "some-tool", + arguments: {}, + }, + mockTask as Task, + { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + }, + ) + + // Should NOT proceed to approval + expect(mockAskApproval).not.toHaveBeenCalled() + // Should push an error result containing the server name + expect(mockPushToolResult).toHaveBeenCalled() + const pushArg = mockPushToolResult.mock.calls[0][0] as string + expect(pushArg).toContain("not allowed") + expect(pushArg).toContain("blocked-server") + }) + + it("should block execution when tool is in disabledTools", async () => { + vi.mocked(isMcpServerAllowedForMode).mockReturnValue(true) + vi.mocked(isMcpToolAllowedForMode).mockReturnValue(false) + + await useMcpToolTool.execute( + { + server_name: "test-server", + tool_name: "disabled-tool", + arguments: {}, + }, + mockTask as Task, + { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + }, + ) + + // Should NOT proceed to approval + expect(mockAskApproval).not.toHaveBeenCalled() + // Should push an error result containing the tool name + expect(mockPushToolResult).toHaveBeenCalled() + const pushArg = mockPushToolResult.mock.calls[0][0] as string + expect(pushArg).toContain("not allowed") + expect(pushArg).toContain("disabled-tool") + }) + + it("should use task.taskMode for the mode check", async () => { + // Set a specific mode + Object.defineProperty(mockTask, "taskMode", { + get: () => "architect", + configurable: true, + }) + vi.mocked(isMcpServerAllowedForMode).mockReturnValue(true) + vi.mocked(isMcpToolAllowedForMode).mockReturnValue(true) + + await useMcpToolTool.execute( + { + server_name: "test-server", + tool_name: "test-tool", + arguments: {}, + }, + mockTask as Task, + { + askApproval: mockAskApproval, + handleError: mockHandleError, + pushToolResult: mockPushToolResult, + }, + ) + + expect(isMcpServerAllowedForMode).toHaveBeenCalledWith("test-server", "architect", expect.anything()) + }) +}) diff --git a/src/core/tools/__tests__/useMcpToolTool.spec.ts b/src/core/tools/__tests__/useMcpToolTool.spec.ts index 5ee826774f4..06a284e6e1a 100644 --- a/src/core/tools/__tests__/useMcpToolTool.spec.ts +++ b/src/core/tools/__tests__/useMcpToolTool.spec.ts @@ -272,7 +272,8 @@ describe("useMcpToolTool", () => { } // Ensure server/tool validation passes so we actually reach askApproval. - mockProviderRef.deref.mockReturnValueOnce({ + const mockProvider = { + customModesManager: { getCustomModes: vi.fn().mockResolvedValue([]) }, getMcpHub: () => ({ getAllServers: vi .fn() @@ -282,7 +283,8 @@ describe("useMcpToolTool", () => { callTool: vi.fn(), }), postMessageToWebview: vi.fn(), - }) + } + mockProviderRef.deref.mockReturnValue(mockProvider) mockAskApproval.mockResolvedValue(false) @@ -315,7 +317,8 @@ describe("useMcpToolTool", () => { } // Ensure validation passes so askApproval is reached and throws - mockProviderRef.deref.mockReturnValueOnce({ + const mockProvider = { + customModesManager: { getCustomModes: vi.fn().mockResolvedValue([]) }, getMcpHub: () => ({ getAllServers: vi .fn() @@ -325,7 +328,8 @@ describe("useMcpToolTool", () => { callTool: vi.fn(), }), postMessageToWebview: vi.fn(), - }) + } + mockProviderRef.deref.mockReturnValue(mockProvider) const error = new Error("Unexpected error") mockAskApproval.mockRejectedValue(error) diff --git a/src/core/tools/accessMcpResourceTool.ts b/src/core/tools/accessMcpResourceTool.ts index 9df3b2256c5..bdf784fd63d 100644 --- a/src/core/tools/accessMcpResourceTool.ts +++ b/src/core/tools/accessMcpResourceTool.ts @@ -3,6 +3,7 @@ import type { ClineAskUseMcpServer } from "@roo-code/types" import type { ToolUse } from "../../shared/tools" import { Task } from "../task/Task" import { formatResponse } from "../prompts/responses" +import { isMcpServerAllowedForMode } from "../../utils/mcp-filter" import { BaseTool, ToolCallbacks } from "./BaseTool" @@ -33,6 +34,21 @@ export class AccessMcpResourceTool extends BaseTool<"access_mcp_resource"> { return } + // Defense-in-depth: check MCP server filtering for the current mode. + // FLAG-E: 10-second TTL cache, no disk I/O per call + const customModes = await task.providerRef.deref()?.customModesManager?.getCustomModes() + if (!isMcpServerAllowedForMode(server_name, task.taskMode, customModes)) { + task.consecutiveMistakeCount++ + task.recordToolError("access_mcp_resource") + await task.say("error", 'MCP server "' + server_name + '" is not allowed in ' + task.taskMode + " mode") + pushToolResult( + formatResponse.toolError( + 'MCP server "' + server_name + '" is not allowed in ' + task.taskMode + " mode", + ), + ) + return + } + task.consecutiveMistakeCount = 0 const completeMessage = JSON.stringify({ diff --git a/src/core/tools/validateToolUse.ts b/src/core/tools/validateToolUse.ts index 243a170ed90..ae9954ed0a1 100644 --- a/src/core/tools/validateToolUse.ts +++ b/src/core/tools/validateToolUse.ts @@ -5,6 +5,8 @@ import { customToolRegistry } from "@roo-code/core" import { type Mode, FileRestrictionError, getModeBySlug, getGroupName } from "../../shared/modes" import { EXPERIMENT_IDS } from "../../shared/experiments" import { TOOL_GROUPS, ALWAYS_AVAILABLE_TOOLS, TOOL_ALIASES } from "../../shared/tools" +import { isMcpServerAllowedForMode, isMcpToolAllowedForMode } from "../../utils/mcp-filter" +import { isMcpTool, parseMcpToolName } from "../../utils/mcp-name" /** * Checks if a tool name is a valid, known tool. @@ -21,8 +23,8 @@ export function isValidToolName(toolName: string, experiments?: Record ({ + validateToolUse: vi.fn(), + isValidToolName: vi.fn(() => true), +})) + +vi.mock("@roo-code/core", () => ({ + customToolRegistry: { + has: vi.fn().mockReturnValue(false), + get: vi.fn().mockReturnValue(undefined), + }, +})) + +vi.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + instance: { + captureToolUsage: vi.fn(), + captureConsecutiveMistakeError: vi.fn(), + captureEvent: vi.fn(), + }, + }, +})) + +describe("ISSUE-17: validateToolUse uses cline.taskMode", () => { + let mockTask: any + + beforeEach(() => { + vi.clearAllMocks() + + mockTask = { + taskId: "test-task-id", + instanceId: "test-instance", + abort: false, + presentAssistantMessageLocked: false, + presentAssistantMessageHasPendingUpdates: false, + currentStreamingContentIndex: 0, + assistantMessageContent: [], + userMessageContent: [], + didCompleteReadingStream: false, + didRejectTool: false, + didAlreadyUseTool: false, + consecutiveMistakeCount: 0, + clineMessages: [], + api: { + getModel: () => ({ id: "test-model", info: {} }), + }, + recordToolUsage: vi.fn(), + recordToolError: vi.fn(), + toolRepetitionDetector: { + check: vi.fn().mockReturnValue({ allowExecution: true }), + }, + // state.mode is 'code' (the live UI mode) + providerRef: { + deref: () => ({ + getState: vi.fn().mockResolvedValue({ + mode: "code", + customModes: [], + experiments: {}, + disabledTools: [], + }), + }), + }, + say: vi.fn().mockResolvedValue(undefined), + ask: vi.fn().mockResolvedValue({ response: "yesButtonClicked" }), + // ISSUE-17: taskMode is 'architect' (frozen at task start) + taskMode: "architect", + } + + mockTask.pushToolResultToUserContent = vi.fn().mockImplementation((toolResult: any) => { + const existing = mockTask.userMessageContent.find( + (b: any) => b.type === "tool_result" && b.tool_use_id === toolResult.tool_use_id, + ) + if (existing) { + return false + } + mockTask.userMessageContent.push(toolResult) + return true + }) + }) + + it("should pass cline.taskMode to validateToolUse, not state.mode", async () => { + const toolCallId = "issue17-test-001" + mockTask.assistantMessageContent = [ + { + type: "tool_use", + id: toolCallId, + name: "read_file", + params: { path: "test.txt" }, + nativeArgs: { path: "test.txt" }, + partial: false, + }, + ] + + await presentAssistantMessage(mockTask) + + const validateMock = vi.mocked(validateToolUse) + expect(validateMock).toHaveBeenCalled() + + // Second argument (index 1) is the mode parameter + const modeArg = validateMock.mock.calls[0][1] + expect(modeArg).toBe("architect") + }) + + it('should use taskMode="architect" even when state.mode="code"', async () => { + // state.mode is 'code', but cline.taskMode is 'architect' + const toolCallId = "issue17-test-002" + mockTask.taskMode = "architect" + mockTask.providerRef = { + deref: () => ({ + getState: vi.fn().mockResolvedValue({ + mode: "code", + customModes: [], + experiments: {}, + disabledTools: [], + }), + }), + } + mockTask.assistantMessageContent = [ + { + type: "tool_use", + id: toolCallId, + name: "read_file", + params: { path: "test.txt" }, + nativeArgs: { path: "test.txt" }, + partial: false, + }, + ] + + await presentAssistantMessage(mockTask) + + const validateMock = vi.mocked(validateToolUse) + expect(validateMock).toHaveBeenCalled() + // Must be 'architect' (from taskMode), NOT 'code' (from state) + expect(validateMock.mock.calls[0][1]).toBe("architect") + }) + + it("should NOT have a defaultModeSlug fallback (ISSUE-20)", async () => { + // cline.taskMode always returns a string (throws if uninitialized). + // So the call must be cline.taskMode directly, with no ?? fallback. + const toolCallId = "issue17-test-003" + mockTask.taskMode = "debug" + mockTask.assistantMessageContent = [ + { + type: "tool_use", + id: toolCallId, + name: "read_file", + params: { path: "test.txt" }, + nativeArgs: { path: "test.txt" }, + partial: false, + }, + ] + + await presentAssistantMessage(mockTask) + + const validateMock = vi.mocked(validateToolUse) + expect(validateMock).toHaveBeenCalled() + // The mode arg should be exactly 'debug', proving no fallback + expect(validateMock.mock.calls[0][1]).toBe("debug") + }) +}) diff --git a/src/tests/core/assistant-message/presentAssistantMessage-mcp.test.ts b/src/tests/core/assistant-message/presentAssistantMessage-mcp.test.ts new file mode 100644 index 00000000000..50100780391 --- /dev/null +++ b/src/tests/core/assistant-message/presentAssistantMessage-mcp.test.ts @@ -0,0 +1,210 @@ +/** + * Tests for Step 5b: MCP tool_use filter in presentAssistantMessage. + * + * Tests the shouldAllowMcpToolUse helper that gates MCP tool execution + * in the mcp_tool_use case block. Verifies that: + * - Allowed MCP tools proceed normally + * - Blocked MCP servers are rejected + * - Blocked MCP tools (tool-level filter) are rejected + * - The check uses cline.taskMode (frozen at task start), NOT state.mode + */ + +import type { ModeConfig } from "@roo-code/types" +import { shouldAllowMcpToolUse } from "../../../core/assistant-message/presentAssistantMessage" +import * as mcpFilter from "../../../utils/mcp-filter" + +describe("shouldAllowMcpToolUse", () => { + it("should return true when isMcpToolAllowedForMode returns true", () => { + const result = shouldAllowMcpToolUse("my-server", "my-tool", "code", undefined) + // With no custom modes and default built-in modes, + // mcp group is present in code mode so all tools are allowed + expect(result).toBe(true) + }) + + it("should return false when the MCP server is blocked for the mode", () => { + const customModes: ModeConfig[] = [ + { + slug: "restricted", + name: "Restricted", + roleDefinition: "A restricted mode", + groups: [ + [ + "mcp", + { + mcpServers: { + "blocked-server": { disabled: true }, + }, + }, + ], + ], + }, + ] + + const result = shouldAllowMcpToolUse("blocked-server", "any-tool", "restricted", customModes) + expect(result).toBe(false) + }) + + it("should return false when a specific tool is blocked via disabledTools", () => { + const customModes: ModeConfig[] = [ + { + slug: "tool-restricted", + name: "Tool Restricted", + roleDefinition: "A mode with tool-level restrictions", + groups: [ + [ + "mcp", + { + mcpServers: { + "my-server": { + disabled: false, + disabledTools: ["secret-tool"], + }, + }, + }, + ], + ], + }, + ] + + const result = shouldAllowMcpToolUse("my-server", "secret-tool", "tool-restricted", customModes) + expect(result).toBe(false) + }) + + it("should return true for a non-disabled tool on same server", () => { + const customModes: ModeConfig[] = [ + { + slug: "tool-restricted", + name: "Tool Restricted", + roleDefinition: "A mode with tool-level restrictions", + groups: [ + [ + "mcp", + { + mcpServers: { + "my-server": { + disabled: false, + disabledTools: ["secret-tool"], + }, + }, + }, + ], + ], + }, + ] + + const result = shouldAllowMcpToolUse("my-server", "public-tool", "tool-restricted", customModes) + expect(result).toBe(true) + }) + + it("should return false when tool not in allowedTools list", () => { + const customModes: ModeConfig[] = [ + { + slug: "allowlist-mode", + name: "Allowlist Mode", + roleDefinition: "A mode with allowedTools", + groups: [ + [ + "mcp", + { + mcpServers: { + "my-server": { + disabled: false, + allowedTools: ["tool-a", "tool-b"], + }, + }, + }, + ], + ], + }, + ] + + const result = shouldAllowMcpToolUse("my-server", "tool-c", "allowlist-mode", customModes) + expect(result).toBe(false) + }) + + it("should return true when tool is in allowedTools list", () => { + const customModes: ModeConfig[] = [ + { + slug: "allowlist-mode", + name: "Allowlist Mode", + roleDefinition: "A mode with allowedTools", + groups: [ + [ + "mcp", + { + mcpServers: { + "my-server": { + disabled: false, + allowedTools: ["tool-a", "tool-b"], + }, + }, + }, + ], + ], + }, + ] + + const result = shouldAllowMcpToolUse("my-server", "tool-a", "allowlist-mode", customModes) + expect(result).toBe(true) + }) + + it("should use the provided modeSlug (taskMode), not derive it from state", () => { + // This test verifies the function signature accepts modeSlug directly. + // The caller (presentAssistantMessage) passes cline.taskMode, not state.mode. + // We verify by passing a mode slug that blocks the server vs one that allows it. + + const customModes: ModeConfig[] = [ + { + slug: "strict-mode", + name: "Strict", + roleDefinition: "Strict mode", + groups: [ + [ + "mcp", + { + mcpServers: { + "test-server": { disabled: true }, + }, + }, + ], + ], + }, + { + slug: "lax-mode", + name: "Lax", + roleDefinition: "Lax mode", + groups: ["mcp"], + }, + ] + + // strict-mode blocks test-server + expect(shouldAllowMcpToolUse("test-server", "any-tool", "strict-mode", customModes)).toBe(false) + + // lax-mode allows everything (plain 'mcp' string = no filtering) + expect(shouldAllowMcpToolUse("test-server", "any-tool", "lax-mode", customModes)).toBe(true) + }) + + it("should delegate to isMcpToolAllowedForMode from mcp-filter", () => { + const spy = vi.spyOn(mcpFilter, "isMcpToolAllowedForMode") + + shouldAllowMcpToolUse("srv", "tl", "code", undefined) + + expect(spy).toHaveBeenCalledWith("srv", "tl", "code", undefined) + spy.mockRestore() + }) + + it("should return true when mode has no mcp group at all", () => { + const customModes: ModeConfig[] = [ + { + slug: "no-mcp", + name: "No MCP", + roleDefinition: "Mode without mcp group", + groups: ["read", "edit"], + }, + ] + + const result = shouldAllowMcpToolUse("any-server", "any-tool", "no-mcp", customModes) + // No mcp group → no filtering → allow + expect(result).toBe(true) + }) +}) diff --git a/src/tests/core/prompts/tools/filter-tools-for-mode-mcp.test.ts b/src/tests/core/prompts/tools/filter-tools-for-mode-mcp.test.ts new file mode 100644 index 00000000000..7c497570aa2 --- /dev/null +++ b/src/tests/core/prompts/tools/filter-tools-for-mode-mcp.test.ts @@ -0,0 +1,184 @@ +import type OpenAI from "openai" +import type { ModeConfig } from "@roo-code/types" + +import { filterMcpToolsForMode } from "../../../../core/prompts/tools/filter-tools-for-mode" + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function makeMcpTool(serverName: string, toolName: string): OpenAI.Chat.ChatCompletionTool { + return { + type: "function", + function: { + name: "mcp--" + serverName + "--" + toolName, + description: serverName + " / " + toolName, + parameters: { type: "object", properties: {} }, + }, + } +} + +function makeNativeTool(name: string): OpenAI.Chat.ChatCompletionTool { + return { + type: "function", + function: { + name: name, + description: "native tool " + name, + parameters: { type: "object", properties: {} }, + }, + } +} + +// A mode with 'mcp' group and NO filtering config (plain string) +const modeNoFilter: ModeConfig = { + slug: "mode-no-filter", + name: "No Filter Mode", + roleDefinition: "test role", + groups: ["read", "mcp"], +} + +// A mode with 'mcp' group and a disabled server +const modeServerDisabled: ModeConfig = { + slug: "mode-server-disabled", + name: "Server Disabled Mode", + roleDefinition: "test role", + groups: [ + "read", + [ + "mcp", + { + mcpServers: { + "weather-server": { disabled: true }, + }, + }, + ], + ], +} + +// A mode with 'mcp' group and a tool in disabledTools +const modeToolDisabled: ModeConfig = { + slug: "mode-tool-disabled", + name: "Tool Disabled Mode", + roleDefinition: "test role", + groups: [ + "read", + [ + "mcp", + { + mcpServers: { + "weather-server": { + disabled: false, + disabledTools: ["get_forecast"], + }, + }, + }, + ], + ], +} + +// A mode with 'mcp' group and allowedTools whitelist +const modeAllowedTools: ModeConfig = { + slug: "mode-allowed-tools", + name: "Allowed Tools Mode", + roleDefinition: "test role", + groups: [ + "read", + [ + "mcp", + { + mcpServers: { + "weather-server": { + disabled: false, + allowedTools: ["get_forecast"], + }, + }, + }, + ], + ], +} + +// A mode with mcpDefaultPolicy = 'deny' +const modeDenyPolicy: ModeConfig = { + slug: "mode-deny-policy", + name: "Deny Policy Mode", + roleDefinition: "test role", + groups: [ + "read", + [ + "mcp", + { + mcpDefaultPolicy: "deny", + mcpServers: { + "weather-server": { disabled: false }, + }, + }, + ], + ], +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("filterMcpToolsForMode - per-server/tool MCP filtering", () => { + const weatherTool1 = makeMcpTool("weather-server", "get_forecast") + const weatherTool2 = makeMcpTool("weather-server", "get_alerts") + const dbTool = makeMcpTool("db-server", "query") + const experiments = {} + + it("passes all MCP tools when mode has no MCP filtering config", () => { + const tools = [weatherTool1, weatherTool2, dbTool] + const result = filterMcpToolsForMode(tools, "mode-no-filter", [modeNoFilter], experiments) + expect(result).toHaveLength(3) + }) + + it("excludes tools from a disabled server", () => { + const tools = [weatherTool1, weatherTool2, dbTool] + const result = filterMcpToolsForMode(tools, "mode-server-disabled", [modeServerDisabled], experiments) + // weather-server disabled, only db-server tool remains + expect(result).toHaveLength(1) + expect((result[0] as any).function.name).toBe("mcp--db-server--query") + }) + + it("excludes a tool that is in disabledTools", () => { + const tools = [weatherTool1, weatherTool2, dbTool] + const result = filterMcpToolsForMode(tools, "mode-tool-disabled", [modeToolDisabled], experiments) + // get_forecast disabled, get_alerts and db query remain + expect(result).toHaveLength(2) + const names = result.map((t: any) => t.function.name) + expect(names).toContain("mcp--weather-server--get_alerts") + expect(names).toContain("mcp--db-server--query") + expect(names).not.toContain("mcp--weather-server--get_forecast") + }) + + it("only allows tools in allowedTools whitelist", () => { + const tools = [weatherTool1, weatherTool2, dbTool] + const result = filterMcpToolsForMode(tools, "mode-allowed-tools", [modeAllowedTools], experiments) + // Only get_forecast from weather-server + db-server (no filter on db) + expect(result).toHaveLength(2) + const names = result.map((t: any) => t.function.name) + expect(names).toContain("mcp--weather-server--get_forecast") + expect(names).toContain("mcp--db-server--query") + expect(names).not.toContain("mcp--weather-server--get_alerts") + }) + + it("excludes unlisted servers when mcpDefaultPolicy is deny", () => { + const tools = [weatherTool1, dbTool] + const result = filterMcpToolsForMode(tools, "mode-deny-policy", [modeDenyPolicy], experiments) + // weather-server listed (not disabled), db-server not listed + deny policy + expect(result).toHaveLength(1) + expect((result[0] as any).function.name).toBe("mcp--weather-server--get_forecast") + }) + + it("does not affect non-MCP tools (passthrough)", () => { + const nativeTool = makeNativeTool("read_file") + // filterMcpToolsForMode receives only MCP tools in practice, + // but if a non-MCP tool sneaks in, it should be preserved. + const tools = [nativeTool, weatherTool1] + const result = filterMcpToolsForMode(tools, "mode-server-disabled", [modeServerDisabled], experiments) + // native tool passes through, weather-server disabled + const names = result.map((t: any) => t.function.name) + expect(names).toContain("read_file") + expect(names).not.toContain("mcp--weather-server--get_forecast") + }) +}) diff --git a/src/tests/core/task/build-tools-mcp.test.ts b/src/tests/core/task/build-tools-mcp.test.ts new file mode 100644 index 00000000000..5c963f70b94 --- /dev/null +++ b/src/tests/core/task/build-tools-mcp.test.ts @@ -0,0 +1,190 @@ +import type OpenAI from "openai" +import type { ModeConfig, McpServer, McpTool } from "@roo-code/types" + +import type { McpHub } from "../../../services/mcp/McpHub" + +// --------------------------------------------------------------------------- +// Mock setup — must be before imports that trigger module resolution +// --------------------------------------------------------------------------- + +// Mock vscode +vi.mock("vscode", () => ({})) + +// Mock CodeIndexManager +vi.mock("../../../services/code-index/manager", () => ({ + CodeIndexManager: { + getInstance: vi.fn().mockReturnValue(undefined), + }, +})) + +// Mock customToolRegistry +vi.mock("@roo-code/core", () => ({ + customToolRegistry: { + loadFromDirectoriesIfStale: vi.fn().mockResolvedValue(undefined), + getAllSerialized: vi.fn().mockReturnValue([]), + }, + formatNative: vi.fn(), +})) + +// Mock getRooDirectoriesForCwd +vi.mock("../../../services/roo-config/index.js", () => ({ + getRooDirectoriesForCwd: vi.fn().mockReturnValue([]), +})) + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function makeMcpTool(serverName: string, toolName: string): OpenAI.Chat.ChatCompletionTool { + return { + type: "function", + function: { + name: "mcp--" + serverName + "--" + toolName, + description: serverName + " / " + toolName, + parameters: { type: "object", properties: {} }, + }, + } +} + +function makeNativeTool(name: string): OpenAI.Chat.ChatCompletionTool { + return { + type: "function", + function: { + name: name, + description: "native " + name, + parameters: { type: "object", properties: {} }, + }, + } +} + +function createMockMcpServer(name: string, tools: McpTool[]): McpServer { + return { + name: name, + config: JSON.stringify({ type: "stdio", command: "test" }), + status: "connected", + source: "global", + tools: tools, + } as McpServer +} + +function createMockMcpHub(servers: McpServer[]): Partial { + return { + getServers: vi.fn().mockReturnValue(servers), + } +} + +// A mode with 'mcp' group + server disabled +const modeServerDisabled: ModeConfig = { + slug: "mode-server-disabled", + name: "Server Disabled", + roleDefinition: "test", + groups: [ + "read", + "edit", + "command", + [ + "mcp", + { + mcpServers: { + "weather-server": { disabled: true }, + }, + }, + ], + ], +} + +// A mode with 'mcp' group and no filtering +const modeNoFilter: ModeConfig = { + slug: "mode-no-filter", + name: "No Filter", + roleDefinition: "test", + groups: ["read", "edit", "command", "mcp"], +} + +// --------------------------------------------------------------------------- +// Import under test (after mocks) +// --------------------------------------------------------------------------- +import { buildNativeToolsArrayWithRestrictions } from "../../../core/task/build-tools" + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("build-tools MCP filtering (Gemini path)", () => { + function createMockProvider(mcpHub: Partial) { + return { + getMcpHub: vi.fn().mockReturnValue(mcpHub), + context: {} as any, + } as any + } + + it("uses filteredMcpTools in allTools when includeAllToolsWithRestrictions is true", async () => { + const mcpHub = createMockMcpHub([ + createMockMcpServer("weather-server", [{ name: "get_forecast", description: "forecast" } as McpTool]), + createMockMcpServer("db-server", [{ name: "query", description: "query db" } as McpTool]), + ]) + + const result = await buildNativeToolsArrayWithRestrictions({ + provider: createMockProvider(mcpHub), + cwd: "/test", + mode: "mode-server-disabled", + customModes: [modeServerDisabled], + experiments: {}, + apiConfiguration: undefined, + includeAllToolsWithRestrictions: true, + }) + + // allTools should use filteredMcpTools, so weather-server tools excluded + const toolNames = result.tools.map((t: any) => t.function.name) + expect(toolNames).not.toContain("mcp--weather-server--get_forecast") + expect(toolNames).toContain("mcp--db-server--query") + }) + + // ISSUE-19: Native tools remain unfiltered in Gemini's tool list. + it("native tools remain unfiltered in allTools (ISSUE-19)", async () => { + const mcpHub = createMockMcpHub([]) + + const result = await buildNativeToolsArrayWithRestrictions({ + provider: createMockProvider(mcpHub), + cwd: "/test", + mode: "mode-no-filter", + customModes: [modeNoFilter], + experiments: {}, + apiConfiguration: undefined, + includeAllToolsWithRestrictions: true, + }) + + // allTools should contain unfiltered native tools (e.g. write_to_file) + // even though mode filtering would restrict some. This is intentional + // because Gemini uses allowedFunctionNames to restrict callable tools. + const toolNames = result.tools.map((t: any) => t.function.name) + // Native tools should be present (unfiltered in allTools) + expect(toolNames.length).toBeGreaterThan(0) + // The tools array should contain native tools + const hasNativeTools = toolNames.some((n: string) => !n.startsWith("mcp--")) + expect(hasNativeTools).toBe(true) + }) + + it("excludes all tools from a disabled server in allTools", async () => { + const mcpHub = createMockMcpHub([ + createMockMcpServer("weather-server", [ + { name: "get_forecast", description: "f" } as McpTool, + { name: "get_alerts", description: "a" } as McpTool, + ]), + ]) + + const result = await buildNativeToolsArrayWithRestrictions({ + provider: createMockProvider(mcpHub), + cwd: "/test", + mode: "mode-server-disabled", + customModes: [modeServerDisabled], + experiments: {}, + apiConfiguration: undefined, + includeAllToolsWithRestrictions: true, + }) + + const mcpToolNames = result.tools.map((t: any) => t.function.name).filter((n: string) => n.startsWith("mcp--")) + + expect(mcpToolNames).toHaveLength(0) + }) +}) diff --git a/src/tests/core/tools/validateToolUse-mcp.test.ts b/src/tests/core/tools/validateToolUse-mcp.test.ts new file mode 100644 index 00000000000..f909cc76d03 --- /dev/null +++ b/src/tests/core/tools/validateToolUse-mcp.test.ts @@ -0,0 +1,168 @@ +// cd src && npx vitest run tests/core/tools/validateToolUse-mcp.test.ts + +import type { ModeConfig } from "@roo-code/types" + +import { validateToolUse } from "../../../core/tools/validateToolUse" + +// --------------------------------------------------------------------------- +// Helper: build a custom mode with MCP filtering options +// --------------------------------------------------------------------------- + +function buildMcpMode(slug: string, mcpOptions?: Record): ModeConfig { + const mcpGroup = mcpOptions ? ["mcp", mcpOptions] : "mcp" + return { + slug, + name: slug, + roleDefinition: "test mode", + groups: ["read", mcpGroup] as any, + } +} + +// --------------------------------------------------------------------------- +// Shared fixtures +// --------------------------------------------------------------------------- + +// Mode that allows serverA but disables serverB +const modeWithFilter = buildMcpMode("filtered", { + mcpServers: { + serverA: { disabled: false }, + serverB: { disabled: true }, + }, +}) + +// Mode with deny-default policy (unlisted servers are blocked) +const modeDenyDefault = buildMcpMode("deny-default", { + mcpDefaultPolicy: "deny", + mcpServers: { + allowedServer: { disabled: false }, + }, +}) + +// Mode with tool-level filtering on serverC +const modeToolFilter = buildMcpMode("tool-filter", { + mcpServers: { + serverC: { + disabled: false, + disabledTools: ["blocked_tool"], + }, + }, +}) + +// Mode with no MCP filtering (plain 'mcp' group) +const modeNoFilter = buildMcpMode("no-filter") + +const customModes: ModeConfig[] = [modeWithFilter, modeDenyDefault, modeToolFilter, modeNoFilter] + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("validateToolUse — MCP filtering", () => { + // ----- use_mcp_tool ----- + + describe("use_mcp_tool", () => { + it("does not throw when server is allowed", () => { + expect(() => + validateToolUse("use_mcp_tool", "filtered", customModes, undefined, { + server_name: "serverA", + tool_name: "any_tool", + }), + ).not.toThrow() + }) + + it("throws when server is disabled", () => { + expect(() => + validateToolUse("use_mcp_tool", "filtered", customModes, undefined, { + server_name: "serverB", + tool_name: "any_tool", + }), + ).toThrow('MCP server "serverB" is not allowed in filtered mode') + }) + + it("extracts server_name from toolParams", () => { + // serverB is disabled — the function must read server_name from params + expect(() => + validateToolUse("use_mcp_tool", "filtered", customModes, undefined, { server_name: "serverB" }), + ).toThrow("serverB") + }) + + it("ISSUE-21: also checks tool-level when tool_name is available", () => { + expect(() => + validateToolUse("use_mcp_tool", "tool-filter", customModes, undefined, { + server_name: "serverC", + tool_name: "blocked_tool", + }), + ).toThrow('MCP tool "blocked_tool" on server "serverC" is not allowed') + }) + + it("ISSUE-21: allows tool when not in disabledTools", () => { + expect(() => + validateToolUse("use_mcp_tool", "tool-filter", customModes, undefined, { + server_name: "serverC", + tool_name: "ok_tool", + }), + ).not.toThrow() + }) + }) + + // ----- access_mcp_resource ----- + + describe("access_mcp_resource", () => { + it("does not throw when server is allowed", () => { + expect(() => + validateToolUse("access_mcp_resource", "filtered", customModes, undefined, { + server_name: "serverA", + uri: "res://x", + }), + ).not.toThrow() + }) + + it("throws when server is disabled", () => { + expect(() => + validateToolUse("access_mcp_resource", "filtered", customModes, undefined, { + server_name: "serverB", + uri: "res://x", + }), + ).toThrow('MCP server "serverB" is not allowed in filtered mode') + }) + }) + + // ----- Dynamic MCP tools (mcp--server--tool) ----- + + describe("dynamic MCP tools", () => { + it("allows when server and tool are permitted", () => { + expect(() => validateToolUse("mcp--serverA--some_tool" as any, "filtered", customModes)).not.toThrow() + }) + + it("throws when server is disabled", () => { + expect(() => validateToolUse("mcp--serverB--some_tool" as any, "filtered", customModes)).toThrow( + "not allowed in filtered mode", + ) + }) + + it("throws when tool is in disabledTools", () => { + expect(() => validateToolUse("mcp--serverC--blocked_tool" as any, "tool-filter", customModes)).toThrow( + "not allowed in tool-filter mode", + ) + }) + + it("throws with deny default policy and unlisted server", () => { + expect(() => validateToolUse("mcp--unknownServer--tool" as any, "deny-default", customModes)).toThrow( + "not allowed in deny-default mode", + ) + }) + + it("allows with deny policy when server is explicitly allowed", () => { + expect(() => validateToolUse("mcp--allowedServer--tool" as any, "deny-default", customModes)).not.toThrow() + }) + }) + + // ----- Non-MCP tools unaffected ----- + + describe("non-MCP tools", () => { + it("are unaffected by MCP filtering", () => { + // read_file is a regular tool in the read group — should still work + expect(() => validateToolUse("read_file", "filtered", customModes)).not.toThrow() + }) + }) +}) diff --git a/src/tests/utils/mcp-filter.test.ts b/src/tests/utils/mcp-filter.test.ts new file mode 100644 index 00000000000..c8a637ad80f --- /dev/null +++ b/src/tests/utils/mcp-filter.test.ts @@ -0,0 +1,248 @@ +import type { ModeConfig, GroupEntry } from "@roo-code/types" + +import { getGroupName } from "../../shared/modes" +import { getMcpFilterForMode, isMcpServerAllowedForMode, isMcpToolAllowedForMode } from "../../utils/mcp-filter" + +// --------------------------------------------------------------------------- +// Helpers – reusable mode fixtures +// --------------------------------------------------------------------------- + +function makeModeWithMcpGroup(slug: string, mcpEntry: GroupEntry): ModeConfig { + return { + slug, + name: "Test Mode", + roleDefinition: "test", + groups: ["read", mcpEntry], + } +} + +function makeModeWithoutMcp(slug: string): ModeConfig { + return { + slug, + name: "No MCP", + roleDefinition: "test", + groups: ["read", "edit"], + } +} + +// --------------------------------------------------------------------------- +// getMcpFilterForMode +// --------------------------------------------------------------------------- + +describe("getMcpFilterForMode", () => { + test("returns undefined when mode has no mcp group", () => { + const modes: ModeConfig[] = [makeModeWithoutMcp("no-mcp")] + expect(getMcpFilterForMode("no-mcp", modes)).toBeUndefined() + }) + + test("returns empty options when mcp group is a plain string", () => { + const modes: ModeConfig[] = [makeModeWithMcpGroup("plain", "mcp")] + const result = getMcpFilterForMode("plain", modes) + expect(result).toEqual({}) + }) + + test("returns mcpServers and mcpDefaultPolicy from tuple options", () => { + const modes: ModeConfig[] = [ + makeModeWithMcpGroup("filtered", [ + "mcp", + { + mcpServers: { "my-server": { disabled: true } }, + mcpDefaultPolicy: "deny", + }, + ]), + ] + const result = getMcpFilterForMode("filtered", modes) + expect(result).toEqual({ + mcpServers: { "my-server": { disabled: true } }, + mcpDefaultPolicy: "deny", + }) + }) + + test("returns undefined for unknown mode slug", () => { + expect(getMcpFilterForMode("nonexistent", [])).toBeUndefined() + }) + + test("falls back to built-in modes (e.g. code mode has mcp group)", () => { + // Passing no custom modes should still find the built-in 'code' mode + const result = getMcpFilterForMode("code") + // 'code' mode has a plain 'mcp' string entry → empty options + expect(result).toEqual({}) + }) +}) + +// --------------------------------------------------------------------------- +// isMcpServerAllowedForMode +// --------------------------------------------------------------------------- + +describe("isMcpServerAllowedForMode", () => { + test("returns true when mode has no mcp group config (default allow)", () => { + const modes: ModeConfig[] = [makeModeWithoutMcp("no-mcp")] + expect(isMcpServerAllowedForMode("any-server", "no-mcp", modes)).toBe(true) + }) + + test("returns false when server is explicitly disabled", () => { + const modes: ModeConfig[] = [ + makeModeWithMcpGroup("m", ["mcp", { mcpServers: { "blocked-server": { disabled: true } } }]), + ] + expect(isMcpServerAllowedForMode("blocked-server", "m", modes)).toBe(false) + }) + + test("returns true when server is not in the filter (default allow policy)", () => { + const modes: ModeConfig[] = [ + makeModeWithMcpGroup("m", ["mcp", { mcpServers: { "other-server": { disabled: true } } }]), + ] + expect(isMcpServerAllowedForMode("unlisted-server", "m", modes)).toBe(true) + }) + + test("returns false when server is not in the filter with deny default policy", () => { + const modes: ModeConfig[] = [ + makeModeWithMcpGroup("m", [ + "mcp", + { + mcpServers: { "allowed-server": {} }, + mcpDefaultPolicy: "deny", + }, + ]), + ] + expect(isMcpServerAllowedForMode("unlisted-server", "m", modes)).toBe(false) + }) + + test("returns true when server is in the filter and not disabled (deny policy)", () => { + const modes: ModeConfig[] = [ + makeModeWithMcpGroup("m", [ + "mcp", + { + mcpServers: { "allowed-server": {} }, + mcpDefaultPolicy: "deny", + }, + ]), + ] + expect(isMcpServerAllowedForMode("allowed-server", "m", modes)).toBe(true) + }) + + test("name matching is case-insensitive and separator-normalized", () => { + const modes: ModeConfig[] = [ + makeModeWithMcpGroup("m", ["mcp", { mcpServers: { "My-Server": { disabled: true } } }]), + ] + // 'my_server' should match 'My-Server' after normalization + expect(isMcpServerAllowedForMode("my_server", "m", modes)).toBe(false) + expect(isMcpServerAllowedForMode("MY SERVER", "m", modes)).toBe(false) + }) +}) + +// --------------------------------------------------------------------------- +// isMcpToolAllowedForMode +// --------------------------------------------------------------------------- + +describe("isMcpToolAllowedForMode", () => { + test("returns true when server has no tool-level filtering", () => { + const modes: ModeConfig[] = [makeModeWithMcpGroup("m", ["mcp", { mcpServers: { srv: {} } }])] + expect(isMcpToolAllowedForMode("srv", "any-tool", "m", modes)).toBe(true) + }) + + test("returns true when tool is in allowedTools", () => { + const modes: ModeConfig[] = [ + makeModeWithMcpGroup("m", ["mcp", { mcpServers: { srv: { allowedTools: ["tool-a", "tool-b"] } } }]), + ] + expect(isMcpToolAllowedForMode("srv", "tool-a", "m", modes)).toBe(true) + }) + + test("returns false when tool is NOT in allowedTools", () => { + const modes: ModeConfig[] = [ + makeModeWithMcpGroup("m", ["mcp", { mcpServers: { srv: { allowedTools: ["tool-a"] } } }]), + ] + expect(isMcpToolAllowedForMode("srv", "tool-x", "m", modes)).toBe(false) + }) + + test("returns false when tool is in disabledTools", () => { + const modes: ModeConfig[] = [ + makeModeWithMcpGroup("m", ["mcp", { mcpServers: { srv: { disabledTools: ["bad-tool"] } } }]), + ] + expect(isMcpToolAllowedForMode("srv", "bad-tool", "m", modes)).toBe(false) + }) + + test("returns true when tool is NOT in disabledTools", () => { + const modes: ModeConfig[] = [ + makeModeWithMcpGroup("m", ["mcp", { mcpServers: { srv: { disabledTools: ["bad-tool"] } } }]), + ] + expect(isMcpToolAllowedForMode("srv", "good-tool", "m", modes)).toBe(true) + }) + + test("allowedTools takes precedence over disabledTools", () => { + const modes: ModeConfig[] = [ + makeModeWithMcpGroup("m", [ + "mcp", + { + mcpServers: { + srv: { + allowedTools: ["tool-a"], + disabledTools: ["tool-a"], + }, + }, + }, + ]), + ] + // allowedTools is checked first; tool-a is in allowed list → true + expect(isMcpToolAllowedForMode("srv", "tool-a", "m", modes)).toBe(true) + }) + + test("returns false when server itself is disabled", () => { + const modes: ModeConfig[] = [ + makeModeWithMcpGroup("m", [ + "mcp", + { + mcpServers: { + srv: { + disabled: true, + allowedTools: ["tool-a"], + }, + }, + }, + ]), + ] + expect(isMcpToolAllowedForMode("srv", "tool-a", "m", modes)).toBe(false) + }) +}) + +// --------------------------------------------------------------------------- +// Cross-validation: inlined getGroupName vs real getGroupName (ISSUE-16) +// --------------------------------------------------------------------------- + +describe("ISSUE-16: inlined getGroupName cross-validation", () => { + const sampleEntries: GroupEntry[] = [ + "read", + "edit", + "mcp", + "command", + ["mcp", { mcpServers: { s: {} } }], + ["edit", { fileRegex: "\\.md$", description: "Markdown only" }], + ] + + test("inlined getGroupName matches real getGroupName for all sample entries", () => { + // The inlined helper in mcp-filter.ts is not exported directly, + // but getMcpFilterForMode uses it internally. We verify equivalence + // by checking that the real getGroupName produces expected values + // and that getMcpFilterForMode behaves consistently with those values. + for (const entry of sampleEntries) { + const realName = getGroupName(entry) + // inlined logic: typeof entry === 'string' ? entry : entry[0] + const inlinedName = typeof entry === "string" ? entry : entry[0] + expect(inlinedName).toBe(realName) + } + }) + + test("getMcpFilterForMode finds mcp group correctly for tuple entry", () => { + const modes: ModeConfig[] = [ + { + slug: "cross-val", + name: "Cross Val", + roleDefinition: "test", + groups: ["read", ["mcp", { mcpServers: { "test-srv": { disabled: true } }, mcpDefaultPolicy: "deny" }]], + }, + ] + const result = getMcpFilterForMode("cross-val", modes) + expect(result).toBeDefined() + expect(result!.mcpServers).toEqual({ "test-srv": { disabled: true } }) + expect(result!.mcpDefaultPolicy).toBe("deny") + }) +}) diff --git a/src/tests/utils/mcp-name.test.ts b/src/tests/utils/mcp-name.test.ts new file mode 100644 index 00000000000..0b48af49fe2 --- /dev/null +++ b/src/tests/utils/mcp-name.test.ts @@ -0,0 +1,31 @@ +import { normalizeForComparison } from "../../utils/mcp-name" + +describe("normalizeForComparison", () => { + it("converts to lowercase", () => { + expect(normalizeForComparison("MyServer")).toBe("myserver") + }) + + it("replaces hyphens with underscores", () => { + expect(normalizeForComparison("my-server")).toBe("my_server") + }) + + it("replaces spaces with underscores", () => { + expect(normalizeForComparison("my server")).toBe("my_server") + }) + + it("handles multiple hyphens individually (not collapsed)", () => { + expect(normalizeForComparison("my--server")).toBe("my__server") + }) + + it("handles mixed separators", () => { + expect(normalizeForComparison("My-Cool Server")).toBe("my_cool_server") + }) + + it("preserves dots and colons (FLAG-D known limitation)", () => { + expect(normalizeForComparison("server.v2:main")).toBe("server.v2:main") + }) + + it("returns empty string for empty input", () => { + expect(normalizeForComparison("")).toBe("") + }) +}) diff --git a/src/utils/mcp-filter.ts b/src/utils/mcp-filter.ts new file mode 100644 index 00000000000..dc821e4d788 --- /dev/null +++ b/src/utils/mcp-filter.ts @@ -0,0 +1,184 @@ +/** + * MCP server/tool filtering helpers for per-mode access control. + * + * ISSUE-16 (M1): This module must NOT import from 'vscode' or any module + * that transitively imports 'vscode'. The getGroupName / getGroupOptions + * helpers are inlined instead of imported from src/shared/modes.ts. + */ + +import type { GroupEntry, McpGroupOptions, McpServerFilter, ModeConfig } from "@roo-code/types" + +import { DEFAULT_MODES } from "@roo-code/types" + +import { normalizeForComparison } from "./mcp-name" + +// --------------------------------------------------------------------------- +// Inlined helpers (M1 — avoids vscode import chain via src/shared/modes.ts) +// --------------------------------------------------------------------------- + +/** + * Extract the group name from a GroupEntry, which can be either a plain + * string ('mcp') or a tuple (['mcp', { ... }]). + */ +function getGroupName(entry: GroupEntry): string { + if (typeof entry === "string") { + return entry + } + return entry[0] +} + +/** + * Extract the options object from a GroupEntry tuple. Returns undefined + * when the entry is a plain string. + */ +function getGroupOptions(entry: GroupEntry): Record | undefined { + if (typeof entry === "string") { + return undefined + } + return entry[1] as Record | undefined +} + +// --------------------------------------------------------------------------- +// Mode lookup (inlined to avoid vscode dependency) +// --------------------------------------------------------------------------- + +function findMode(modeSlug: string, customModes?: ModeConfig[]): ModeConfig | undefined { + // Custom modes take precedence + const custom = customModes?.find((m) => m.slug === modeSlug) + if (custom) { + return custom + } + // Fall back to built-in modes + return DEFAULT_MODES.find((m) => m.slug === modeSlug) +} + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +/** + * Retrieve the MCP group options for a mode. Returns `undefined` when the + * mode does not exist or does not include an 'mcp' group. Returns an empty + * object `{}` when the mcp group is a plain string (no filtering configured). + */ +export function getMcpFilterForMode(modeSlug: string, customModes?: ModeConfig[]): McpGroupOptions | undefined { + const mode = findMode(modeSlug, customModes) + if (!mode) { + return undefined + } + + const mcpEntry = mode.groups.find((g) => getGroupName(g) === "mcp") + + if (!mcpEntry) { + return undefined + } + + const opts = getGroupOptions(mcpEntry) + if (!opts) { + // Plain string 'mcp' — no filtering configured + return {} + } + + return opts as McpGroupOptions +} + +/** + * Determine whether a given MCP server is allowed for a mode. + * + * Rules: + * - No mcp group at all → true (no filtering) + * - Server explicitly disabled → false + * - Server not listed + allow policy (default) → true + * - Server not listed + deny policy → false + * - Server listed + not disabled → true + */ +export function isMcpServerAllowedForMode(serverName: string, modeSlug: string, customModes?: ModeConfig[]): boolean { + const filter = getMcpFilterForMode(modeSlug, customModes) + + // No mcp group at all → allow everything + if (filter === undefined) { + return true + } + + const servers = filter.mcpServers + if (!servers) { + // mcp group exists but no server-level config → allow + return true + } + + const normalizedInput = normalizeForComparison(serverName) + let matchedFilter: McpServerFilter | undefined + + for (const [configName, configFilter] of Object.entries(servers)) { + if (normalizeForComparison(configName) === normalizedInput) { + matchedFilter = configFilter + break + } + } + + if (matchedFilter !== undefined) { + // Server is explicitly listed + return !matchedFilter.disabled + } + + // Server not listed — check default policy + const policy = filter.mcpDefaultPolicy || "allow" + return policy === "allow" +} + +/** + * Determine whether a specific tool on an MCP server is allowed for a mode. + * + * Rules: + * - Server disabled → false (regardless of tool config) + * - No tool-level filtering → true + * - allowedTools exists → tool must be in list (takes precedence) + * - disabledTools exists → tool must NOT be in list + * - Default → true + */ +export function isMcpToolAllowedForMode( + serverName: string, + toolName: string, + modeSlug: string, + customModes?: ModeConfig[], +): boolean { + // First check server-level access + if (!isMcpServerAllowedForMode(serverName, modeSlug, customModes)) { + return false + } + + const filter = getMcpFilterForMode(modeSlug, customModes) + if (!filter || !filter.mcpServers) { + return true + } + + // Find the server filter entry using normalized comparison + const normalizedServer = normalizeForComparison(serverName) + let serverFilter: McpServerFilter | undefined + + for (const [configName, configFilter] of Object.entries(filter.mcpServers)) { + if (normalizeForComparison(configName) === normalizedServer) { + serverFilter = configFilter + break + } + } + + if (!serverFilter) { + // Server not in the filter list → already allowed by isMcpServerAllowedForMode + return true + } + + const normalizedTool = normalizeForComparison(toolName) + + // allowedTools takes precedence + if (serverFilter.allowedTools && serverFilter.allowedTools.length > 0) { + return serverFilter.allowedTools.some((t) => normalizeForComparison(t) === normalizedTool) + } + + // disabledTools check + if (serverFilter.disabledTools && serverFilter.disabledTools.length > 0) { + return !serverFilter.disabledTools.some((t) => normalizeForComparison(t) === normalizedTool) + } + + return true +} diff --git a/src/utils/mcp-name.ts b/src/utils/mcp-name.ts index 5f75f49c64e..79b37b29142 100644 --- a/src/utils/mcp-name.ts +++ b/src/utils/mcp-name.ts @@ -18,14 +18,19 @@ export const MCP_TOOL_SEPARATOR = "--" export const MCP_TOOL_PREFIX = "mcp" /** - * Normalize a string for comparison by treating hyphens and underscores as equivalent. - * This is used to match tool names when models convert hyphens to underscores. + * Normalize a string for comparison by lowercasing and treating hyphens, + * spaces, and underscores as equivalent. + * This is used to match tool/server names when models convert hyphens to + * underscores or when config names use different separators. + * + * NOTE (FLAG-D): Dots and colons are NOT stripped. Server names like + * "my.server:v2" will only match if the config uses the same pattern. * * @param name - The name to normalize - * @returns The normalized name with all hyphens converted to underscores + * @returns The normalized name lowercased with hyphens/spaces converted to underscores */ export function normalizeForComparison(name: string): string { - return name.replace(/-/g, "_") + return name.toLowerCase().replace(/[-\s]/g, "_") } /** diff --git a/webview-ui/src/__tests__/ModesView-groupChange.spec.tsx b/webview-ui/src/__tests__/ModesView-groupChange.spec.tsx new file mode 100644 index 00000000000..596b91630eb --- /dev/null +++ b/webview-ui/src/__tests__/ModesView-groupChange.spec.tsx @@ -0,0 +1,148 @@ +import type { GroupEntry } from "@roo-code/types" + +import { syncCacheFromGroups, removeGroupWithCache, addGroupWithCache } from "../components/modes/groupOptionsCache" + +// --------------------------------------------------------------------------- +// Test data +// --------------------------------------------------------------------------- + +const mcpOptions = { + mcpServers: { + "my-server": { disabled: false, allowedTools: ["tool-a"] }, + }, + mcpDefaultPolicy: "deny" as const, +} + +const mcpTuple: GroupEntry = ["mcp", mcpOptions] + +const readGroup: GroupEntry = "read" +const editGroup: GroupEntry = "edit" + +const groupsWithMcpTuple: GroupEntry[] = [readGroup, editGroup, mcpTuple] +const groupsPlainOnly: GroupEntry[] = [readGroup, editGroup, "mcp"] + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe("groupOptionsCache", () => { + describe("removeGroupWithCache — caching tuple options", () => { + it("caches options when removing a group with tuple entry", () => { + const cache = new Map() + const result = removeGroupWithCache(cache, groupsWithMcpTuple, "mcp") + + // The mcp group should be removed + expect(result).toEqual([readGroup, editGroup]) + + // The cache should contain the mcp options + expect(cache.get("mcp")).toEqual(mcpOptions) + }) + + it("does not cache anything for a plain string group", () => { + const cache = new Map() + const result = removeGroupWithCache(cache, groupsPlainOnly, "mcp") + + expect(result).toEqual([readGroup, editGroup]) + + // Cache should NOT have an entry for mcp + expect(cache.has("mcp")).toBe(false) + }) + }) + + describe("addGroupWithCache — restoring cached options", () => { + it("restores cached tuple options when re-adding a group", () => { + const cache = new Map() + + // First remove to populate cache + const afterRemove = removeGroupWithCache(cache, groupsWithMcpTuple, "mcp") + + // Now re-add + const afterAdd = addGroupWithCache(cache, afterRemove, "mcp") + + // Should restore as tuple with cached options + const mcpEntry = afterAdd.find((g) => (Array.isArray(g) ? g[0] === "mcp" : g === "mcp")) + expect(Array.isArray(mcpEntry)).toBe(true) + expect(mcpEntry).toEqual(["mcp", mcpOptions]) + }) + + it("adds as plain string when no cached options exist", () => { + const cache = new Map() + + // Remove plain 'mcp' — nothing to cache + const afterRemove = removeGroupWithCache(cache, groupsPlainOnly, "mcp") + + // Re-add — should be plain string since no cache + const afterAdd = addGroupWithCache(cache, afterRemove, "mcp") + + const mcpEntry = afterAdd.find((g) => (Array.isArray(g) ? g[0] === "mcp" : g === "mcp")) + expect(mcpEntry).toBe("mcp") + }) + }) + + describe("syncCacheFromGroups — external state sync", () => { + it("populates cache from groups containing tuples", () => { + const cache = new Map() + + syncCacheFromGroups(cache, groupsWithMcpTuple) + + expect(cache.get("mcp")).toEqual(mcpOptions) + }) + + it("does not populate cache from plain string groups", () => { + const cache = new Map() + + syncCacheFromGroups(cache, groupsPlainOnly) + + expect(cache.has("mcp")).toBe(false) + }) + + it("updates cache when called with new tuple data", () => { + const cache = new Map() + const updatedOptions = { + mcpServers: { + "new-server": { disabled: false }, + }, + mcpDefaultPolicy: "allow" as const, + } + const updatedTuple: GroupEntry = ["mcp", updatedOptions] + + // First sync with original data + syncCacheFromGroups(cache, groupsWithMcpTuple) + expect(cache.get("mcp")).toEqual(mcpOptions) + + // Sync again with updated data + syncCacheFromGroups(cache, [readGroup, editGroup, updatedTuple]) + expect(cache.get("mcp")).toEqual(updatedOptions) + }) + }) + + describe("MCP round-trip — toggle off then on preserves config", () => { + it("MCP group with mcpServers config survives toggle off/on", () => { + const cache = new Map() + + // Sync cache from initial state (simulates useEffect) + syncCacheFromGroups(cache, groupsWithMcpTuple) + + // Toggle off (uncheck) + const afterUncheck = removeGroupWithCache(cache, groupsWithMcpTuple, "mcp") + + // Verify mcp is removed + expect(afterUncheck.some((g) => (Array.isArray(g) ? g[0] === "mcp" : g === "mcp"))).toBe(false) + + // Toggle on (re-check) + const afterRecheck = addGroupWithCache(cache, afterUncheck, "mcp") + + // Verify mcp is restored with full config + const restored = afterRecheck.find((g) => (Array.isArray(g) ? g[0] === "mcp" : g === "mcp")) + expect(restored).toEqual(["mcp", mcpOptions]) + + // Specifically check mcpServers survived + expect((restored as [string, typeof mcpOptions])[1].mcpServers).toEqual({ + "my-server": { + disabled: false, + allowedTools: ["tool-a"], + }, + }) + }) + }) +}) diff --git a/webview-ui/src/components/modes/ModesView.tsx b/webview-ui/src/components/modes/ModesView.tsx index eeeaf026cc2..4da0b17ec09 100644 --- a/webview-ui/src/components/modes/ModesView.tsx +++ b/webview-ui/src/components/modes/ModesView.tsx @@ -24,6 +24,8 @@ import { } from "@roo/modes" import { TOOL_GROUPS } from "@roo/tools" +import { syncCacheFromGroups, removeGroupWithCache, addGroupWithCache } from "./groupOptionsCache" + import { vscode } from "@src/utils/vscode" import { buildDocLink } from "@src/utils/docLinks" import { useAppTranslation } from "@src/i18n/TranslationContext" @@ -117,6 +119,9 @@ const ModesView = () => { const [renameInputValue, setRenameInputValue] = useState("") const renameInputRef = useRef(null) + // Cache for group tuple options so toggling off/on preserves MCP filter config + const groupOptionsCache = useRef>(new Map()) + // Optimistic rename map so search reflects new names immediately const [localRenames, setLocalRenames] = useState>({}) // Display list that overlays optimistic names @@ -462,6 +467,15 @@ const ModesView = () => { setIsCreateModeDialogOpen(true) }, [generateSlug, isNameOrSlugTaken]) + // Sync group options cache whenever custom modes change so that + // externally-loaded tuple options (e.g. mcpServers config) are preserved + // when a user toggles a group off and back on. + useEffect(() => { + for (const cm of customModes || []) { + syncCacheFromGroups(groupOptionsCache.current, cm.groups || []) + } + }, [customModes]) + // Handler for group checkbox changes const handleGroupChange = useCallback( (group: ToolGroup, isCustomMode: boolean, customMode: ModeConfig | undefined) => @@ -472,9 +486,9 @@ const ModesView = () => { const oldGroups = customMode?.groups || [] let newGroups: GroupEntry[] if (checked) { - newGroups = [...oldGroups, group] + newGroups = addGroupWithCache(groupOptionsCache.current, oldGroups, group) } else { - newGroups = oldGroups.filter((g) => getGroupName(g) !== group) + newGroups = removeGroupWithCache(groupOptionsCache.current, oldGroups, group) } if (customMode) { const source = customMode.source || "global" diff --git a/webview-ui/src/components/modes/groupOptionsCache.ts b/webview-ui/src/components/modes/groupOptionsCache.ts new file mode 100644 index 00000000000..3a3cfa259b0 --- /dev/null +++ b/webview-ui/src/components/modes/groupOptionsCache.ts @@ -0,0 +1,62 @@ +import type { GroupEntry, ToolGroup } from "@roo-code/types" + +/** + * Helper to extract the group name from a GroupEntry. + * A GroupEntry is either a plain string or a tuple [name, options]. + */ +export function getGroupName(entry: GroupEntry): string { + if (typeof entry === "string") { + return entry + } + return entry[0] +} + +/** + * Synchronise a cache map with the current groups array. + * For every tuple entry, upsert its options into the cache. + */ +export function syncCacheFromGroups(cache: Map, groups: GroupEntry[]): void { + for (const entry of groups) { + if (Array.isArray(entry) && entry[1]) { + cache.set(entry[0], entry[1]) + } + } +} + +/** + * Remove a group by name. When the entry being removed is a + * tuple (i.e. it carries options), stash those options in the + * cache so they can be restored later. + * + * Returns the filtered groups array. + */ +export function removeGroupWithCache( + cache: Map, + groups: GroupEntry[], + groupName: string, +): GroupEntry[] { + const entry = groups.find((g) => getGroupName(g) === groupName) + if (entry && Array.isArray(entry) && entry[1]) { + cache.set(entry[0], entry[1]) + } + return groups.filter((g) => getGroupName(g) !== groupName) +} + +/** + * Add a group by name. If the cache contains previously-saved + * options for this group, restore it as a tuple [name, options]. + * Otherwise add it as a plain string. + * + * Returns the new groups array. + */ +export function addGroupWithCache( + cache: Map, + groups: GroupEntry[], + groupName: ToolGroup, +): GroupEntry[] { + const cached = cache.get(groupName) + if (cached) { + return [...groups, [groupName, cached] as GroupEntry] + } + return [...groups, groupName] +} diff --git a/webview-ui/src/components/modes/useGroupOptionsCache.ts b/webview-ui/src/components/modes/useGroupOptionsCache.ts new file mode 100644 index 00000000000..8e18354af83 --- /dev/null +++ b/webview-ui/src/components/modes/useGroupOptionsCache.ts @@ -0,0 +1,33 @@ +import { useRef, useEffect, useCallback } from "react" + +import type { GroupEntry, ToolGroup } from "@roo-code/types" + +import { syncCacheFromGroups, removeGroupWithCache, addGroupWithCache } from "./groupOptionsCache" + +/** + * Custom hook that caches group options (tuple second element) when + * groups are removed, and restores them when groups are re-added. + * + * This prevents data loss when toggling a group like "mcp" off and + * back on — without the cache, the MCP filter config (mcpServers, + * mcpDefaultPolicy) would be discarded on uncheck. + */ +export function useGroupOptionsCache(groups: GroupEntry[]) { + const groupOptionsCache = useRef>(new Map()) + + // Sync cache with external state: if external state has tuple + // entries, update the cache so that toggles preserve them. + useEffect(() => { + syncCacheFromGroups(groupOptionsCache.current, groups) + }, [groups]) + + const removeGroup = useCallback((currentGroups: GroupEntry[], groupName: string): GroupEntry[] => { + return removeGroupWithCache(groupOptionsCache.current, currentGroups, groupName) + }, []) + + const addGroup = useCallback((currentGroups: GroupEntry[], groupName: ToolGroup): GroupEntry[] => { + return addGroupWithCache(groupOptionsCache.current, currentGroups, groupName) + }, []) + + return { removeGroup, addGroup, cache: groupOptionsCache } +} From 08c501213179e32b1f39ae275082c10a79ca7656 Mon Sep 17 00:00:00 2001 From: Stefan Vetter Date: Tue, 7 Apr 2026 09:21:23 +0200 Subject: [PATCH 2/3] feat: add MCP filter configuration UI for per-mode MCP server/tool filtering Adds UI components to the mode editor for configuring which MCP servers and tools are available per mode, completing the MCP filtering feature. New components: - McpFilterConfig.tsx: Main panel with default policy selector (allow/deny) and server list, shown when 'mcp' group is enabled in mode editor - McpServerFilterRow.tsx: Per-server row with enable/disable toggle and expandable tool-level filtering (allowlist/blocklist/allow-all modes) ModesView.tsx integration: - Edit mode: McpFilterConfig renders below mcp checkbox when checked - Read-only mode: Shows filter summary via McpFilterConfig - Helper functions getMcpOptionsFromGroups/updateMcpOptionsInGroups for extracting/updating MCP tuple options in groups array - groupOptionsCache sync ensures filter config survives toggle cycles Version bump: 3.51.1 -> 3.52.0 --- src/package.json | 2 +- .../src/components/modes/McpFilterConfig.tsx | 167 +++++++++ .../components/modes/McpServerFilterRow.tsx | 324 ++++++++++++++++++ webview-ui/src/components/modes/ModesView.tsx | 144 ++++++-- .../__tests__/McpServerFilterRow.spec.tsx | 275 +++++++++++++++ 5 files changed, 876 insertions(+), 36 deletions(-) create mode 100644 webview-ui/src/components/modes/McpFilterConfig.tsx create mode 100644 webview-ui/src/components/modes/McpServerFilterRow.tsx create mode 100644 webview-ui/src/components/modes/__tests__/McpServerFilterRow.spec.tsx diff --git a/src/package.json b/src/package.json index 7c4889abd89..e7e198ea694 100644 --- a/src/package.json +++ b/src/package.json @@ -3,7 +3,7 @@ "displayName": "%extension.displayName%", "description": "%extension.description%", "publisher": "RooVeterinaryInc", - "version": "3.51.1", + "version": "3.52.0", "icon": "assets/icons/icon.png", "galleryBanner": { "color": "#617A91", diff --git a/webview-ui/src/components/modes/McpFilterConfig.tsx b/webview-ui/src/components/modes/McpFilterConfig.tsx new file mode 100644 index 00000000000..59f82bd79f9 --- /dev/null +++ b/webview-ui/src/components/modes/McpFilterConfig.tsx @@ -0,0 +1,167 @@ +import { useCallback } from "react" + +import type { McpServer, McpGroupOptions, McpServerFilter } from "@roo-code/types" + +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui" +import { McpServerFilterRow } from "./McpServerFilterRow" + +export interface McpFilterConfigProps { + mcpServers: McpServer[] + mcpGroupOptions: McpGroupOptions | undefined + onOptionsChange: (options: McpGroupOptions | undefined) => void + isEditing: boolean +} + +function getDefaultPolicy(options: McpGroupOptions | undefined): "allow" | "deny" { + if (options && options.mcpDefaultPolicy) { + return options.mcpDefaultPolicy + } + return "allow" +} + +function hasAnyFilters(options: McpGroupOptions | undefined): boolean { + if (!options) { + return false + } + if (options.mcpDefaultPolicy && options.mcpDefaultPolicy !== "allow") { + return true + } + if (options.mcpServers && Object.keys(options.mcpServers).length > 0) { + return true + } + return false +} + +function getServerFilter(options: McpGroupOptions | undefined, serverName: string): McpServerFilter | undefined { + if (!options || !options.mcpServers) { + return undefined + } + return options.mcpServers[serverName] +} + +function buildCleanOptions( + policy: "allow" | "deny", + servers: Record | undefined, +): McpGroupOptions | undefined { + var hasServers = servers && Object.keys(servers).length > 0 + if (policy === "allow" && !hasServers) { + return undefined + } + var result: McpGroupOptions = {} + if (policy !== "allow") { + result.mcpDefaultPolicy = policy + } + if (hasServers) { + result.mcpServers = servers + } + return result +} + +export function McpFilterConfig({ mcpServers, mcpGroupOptions, onOptionsChange, isEditing }: McpFilterConfigProps) { + var policy = getDefaultPolicy(mcpGroupOptions) + var serverCount = mcpServers.length + + var handlePolicyChange = useCallback( + function (newPolicy: string) { + var typedPolicy = newPolicy as "allow" | "deny" + var currentServers = mcpGroupOptions?.mcpServers + var updated = buildCleanOptions(typedPolicy, currentServers) + onOptionsChange(updated) + }, + [mcpGroupOptions, onOptionsChange], + ) + + var handleServerFilterChange = useCallback( + function (serverName: string, filter: McpServerFilter | undefined) { + var currentServers = mcpGroupOptions?.mcpServers || {} + var updatedServers: Record + + if (filter) { + updatedServers = { ...currentServers, [serverName]: filter } + } else { + updatedServers = { ...currentServers } + delete updatedServers[serverName] + } + + var currentPolicy = getDefaultPolicy(mcpGroupOptions) + var updated = buildCleanOptions(currentPolicy, updatedServers) + onOptionsChange(updated) + }, + [mcpGroupOptions, onOptionsChange], + ) + + // Read-only mode + if (!isEditing) { + return ( +
+ {!hasAnyFilters(mcpGroupOptions) ? ( + All servers and tools allowed + ) : ( +
+
{"Default policy: " + policy}
+ {mcpGroupOptions?.mcpServers && Object.keys(mcpGroupOptions.mcpServers).length > 0 && ( +
+ {Object.keys(mcpGroupOptions.mcpServers).length + " server(s) with custom filters"} +
+ )} +
+ )} +
+ ) + } + + return ( +
+ {/* Default Policy selector */} +
+ + +
+ + {/* Server list */} +
+
+ MCP Servers + + {serverCount + (serverCount === 1 ? " server" : " servers")} + +
+ + {serverCount === 0 ? ( +
+ No MCP servers connected +
+ ) : ( +
+ {mcpServers.map(function (server) { + var tools = (server.tools || []).map(function (t) { + return { name: t.name, description: t.description } + }) + return ( + + ) + })} +
+ )} +
+
+ ) +} diff --git a/webview-ui/src/components/modes/McpServerFilterRow.tsx b/webview-ui/src/components/modes/McpServerFilterRow.tsx new file mode 100644 index 00000000000..b88cb901a61 --- /dev/null +++ b/webview-ui/src/components/modes/McpServerFilterRow.tsx @@ -0,0 +1,324 @@ +import { useState, useCallback } from "react" + +import type { McpServerFilter } from "@roo-code/types" + +import { Checkbox } from "@/components/ui/checkbox" +import { ToggleSwitch } from "@/components/ui/toggle-switch" + +type FilterMode = "allowAll" | "allowlist" | "blocklist" + +export interface McpServerFilterRowProps { + serverName: string + serverStatus: string + availableTools: Array<{ name: string; description?: string }> + filter: McpServerFilter | undefined + onFilterChange: (serverName: string, filter: McpServerFilter | undefined) => void + isEditing: boolean +} + +function getFilterMode(filter: McpServerFilter | undefined): FilterMode { + if (!filter) { + return "allowAll" + } + if (filter.allowedTools) { + return "allowlist" + } + if (filter.disabledTools) { + return "blocklist" + } + return "allowAll" +} + +function isToolEnabled(toolName: string, filter: McpServerFilter | undefined): boolean { + if (!filter) { + return true + } + if (filter.allowedTools) { + return filter.allowedTools.includes(toolName) + } + if (filter.disabledTools) { + return !filter.disabledTools.includes(toolName) + } + return true +} + +function getEnabledToolCount(tools: Array<{ name: string }>, filter: McpServerFilter | undefined): number { + return tools.filter(function (t) { + return isToolEnabled(t.name, filter) + }).length +} + +function getStatusDotClass(status: string, isDisabled: boolean): string { + if (isDisabled) { + return "bg-vscode-descriptionForeground" + } + if (status === "connected") { + return "bg-vscode-charts-green" + } + if (status === "connecting") { + return "bg-vscode-charts-yellow" + } + return "bg-vscode-descriptionForeground" +} + +const FILTER_MODE_LABELS: Record = { + allowAll: "Allow All", + allowlist: "Allowlist", + blocklist: "Blocklist", +} + +export function McpServerFilterRow({ + serverName, + serverStatus, + availableTools, + filter, + onFilterChange, + isEditing, +}: McpServerFilterRowProps) { + const [isExpanded, setIsExpanded] = useState(false) + + const isDisabled = filter?.disabled === true + const filterMode = getFilterMode(filter) + const enabledCount = getEnabledToolCount(availableTools, filter) + const totalCount = availableTools.length + + const handleToggleServer = useCallback( + function () { + if (isDisabled) { + // Re-enable: remove disabled flag, keep other filter settings + var updated: McpServerFilter | undefined = filter ? { ...filter, disabled: undefined } : undefined + // Clean up empty object + if (updated && !updated.allowedTools && !updated.disabledTools && !updated.disabled) { + updated = undefined + } + onFilterChange(serverName, updated) + } else { + // Disable the server + var newFilter: McpServerFilter = filter ? { ...filter, disabled: true } : { disabled: true } + onFilterChange(serverName, newFilter) + } + }, + [isDisabled, filter, serverName, onFilterChange], + ) + + var handleToggleExpand = useCallback( + function () { + if (!isDisabled) { + setIsExpanded(function (prev) { + return !prev + }) + } + }, + [isDisabled], + ) + + var handleFilterModeChange = useCallback( + function (newMode: FilterMode) { + if (newMode === "allowAll") { + var cleaned: McpServerFilter | undefined = filter ? { disabled: filter.disabled } : undefined + if (cleaned && !cleaned.disabled) { + cleaned = undefined + } + onFilterChange(serverName, cleaned) + } else if (newMode === "allowlist") { + // Start allowlist with all tools included + var allNames = availableTools.map(function (t) { + return t.name + }) + onFilterChange(serverName, { + ...filter, + allowedTools: allNames, + disabledTools: undefined, + }) + } else { + // Start blocklist with none disabled + onFilterChange(serverName, { + ...filter, + disabledTools: [], + allowedTools: undefined, + }) + } + }, + [filter, serverName, availableTools, onFilterChange], + ) + + var handleToggleTool = useCallback( + function (toolName: string) { + var currentlyEnabled = isToolEnabled(toolName, filter) + + if (filterMode === "allowlist") { + var currentAllowed = filter?.allowedTools || [] + var newAllowed = currentlyEnabled + ? currentAllowed.filter(function (n) { + return n !== toolName + }) + : currentAllowed.concat([toolName]) + onFilterChange(serverName, { + ...filter, + allowedTools: newAllowed, + }) + } else if (filterMode === "blocklist") { + var currentDisabled = filter?.disabledTools || [] + var newDisabled = currentlyEnabled + ? currentDisabled.concat([toolName]) + : currentDisabled.filter(function (n) { + return n !== toolName + }) + onFilterChange(serverName, { + ...filter, + disabledTools: newDisabled, + }) + } + }, + [filter, filterMode, serverName, onFilterChange], + ) + + // Read-only mode + if (!isEditing) { + return ( +
+
+ {serverName} + + {isDisabled ? "disabled" : enabledCount + " of " + totalCount + " tools allowed"} + +
+ ) + } + + // Editable mode + return ( +
+ {/* Server header row */} +
+ {/* Chevron */} + {!isDisabled && ( + + )} + + {/* Status dot */} +
+ + {/* Server name */} + {serverName} + + {/* Tool count */} + {totalCount + " tools"} + + {/* Spacer */} +
+ + {/* Toggle switch */} +
+ +
+
+ + {/* Expandable tool list */} + {!isDisabled && isExpanded && ( +
+ {/* Filter mode selector */} +
+ Mode: + {(["allowAll", "allowlist", "blocklist"] as FilterMode[]).map(function (mode) { + return ( + + ) + })} +
+ + {/* Tool rows */} + {filterMode !== "allowAll" && availableTools.length > 0 && ( +
+ {availableTools.map(function (tool) { + var enabled = isToolEnabled(tool.name, filter) + return ( +
+ + + {tool.name} + + {tool.description && ( + + {tool.description} + + )} +
+ ) + })} +
+ )} + + {/* Allow all message */} + {filterMode === "allowAll" && ( +
+ All tools are enabled for this mode. +
+ )} + + {/* No tools message */} + {availableTools.length === 0 && ( +
No tools available.
+ )} +
+ )} +
+ ) +} + +export default McpServerFilterRow diff --git a/webview-ui/src/components/modes/ModesView.tsx b/webview-ui/src/components/modes/ModesView.tsx index 4da0b17ec09..5713a29355e 100644 --- a/webview-ui/src/components/modes/ModesView.tsx +++ b/webview-ui/src/components/modes/ModesView.tsx @@ -10,7 +10,7 @@ import { import { Trans } from "react-i18next" import { ChevronDown, X, Upload, Download } from "lucide-react" -import { ModeConfig, GroupEntry, PromptComponent, ToolGroup, modeConfigSchema } from "@roo-code/types" +import { ModeConfig, GroupEntry, McpGroupOptions, PromptComponent, ToolGroup, modeConfigSchema } from "@roo-code/types" import { Mode, @@ -25,6 +25,7 @@ import { import { TOOL_GROUPS } from "@roo/tools" import { syncCacheFromGroups, removeGroupWithCache, addGroupWithCache } from "./groupOptionsCache" +import { McpFilterConfig } from "./McpFilterConfig" import { vscode } from "@src/utils/vscode" import { buildDocLink } from "@src/utils/docLinks" @@ -65,6 +66,28 @@ function getGroupName(group: GroupEntry): ToolGroup { return Array.isArray(group) ? group[0] : group } +// Extract MCP options from a groups array +function getMcpOptionsFromGroups(groups: GroupEntry[]): McpGroupOptions | undefined { + for (var i = 0; i < groups.length; i++) { + var entry = groups[i] + if (Array.isArray(entry) && entry[0] === "mcp" && entry[1]) { + return entry[1] as McpGroupOptions + } + } + return undefined +} + +// Update MCP options in a groups array +function updateMcpOptionsInGroups(groups: GroupEntry[], options: McpGroupOptions | undefined): GroupEntry[] { + return groups.map(function (entry) { + var name = typeof entry === "string" ? entry : entry[0] + if (name === "mcp") { + return options ? (["mcp", options] as GroupEntry) : "mcp" + } + return entry + }) +} + const ModesView = () => { const { t } = useAppTranslation() @@ -76,6 +99,7 @@ const ModesView = () => { customInstructions, setCustomInstructions, customModes, + mcpServers, } = useExtensionState() // Use a local state to track the visually active mode @@ -1140,40 +1164,75 @@ const ModesView = () => {
)} {isToolsEditMode && findModeBySlug(visualMode, customModes) ? ( -
- {availableGroups.map((group) => { - const currentMode = getCurrentMode() - const isCustomMode = findModeBySlug(visualMode, customModes) - const customMode = isCustomMode - const isGroupEnabled = isCustomMode - ? customMode?.groups?.some((g) => getGroupName(g) === group) - : currentMode?.groups?.some((g) => getGroupName(g) === group) - + <> +
+ {availableGroups.map((group) => { + const currentMode = getCurrentMode() + const isCustomMode = findModeBySlug(visualMode, customModes) + const customMode = isCustomMode + const isGroupEnabled = isCustomMode + ? customMode?.groups?.some((g) => getGroupName(g) === group) + : currentMode?.groups?.some((g) => getGroupName(g) === group) + + return ( + + {t(`prompts:tools.toolNames.${group}`)} + {group === "edit" && ( +
+ {t("prompts:tools.allowedFiles")}{" "} + {(() => { + const currentMode = getCurrentMode() + const editGroup = currentMode?.groups?.find( + (g) => + Array.isArray(g) && + g[0] === "edit" && + g[1]?.fileRegex, + ) + if (!Array.isArray(editGroup)) return t("prompts:allFiles") + return ( + editGroup[1].description || + `/${editGroup[1].fileRegex}/` + ) + })()} +
+ )} +
+ ) + })} +
+ {(() => { + const customMode = findModeBySlug(visualMode, customModes) + const isMcpEnabled = customMode?.groups?.some((g) => getGroupName(g) === "mcp") + if (!isMcpEnabled || !customMode) return null + const mcpOptions = getMcpOptionsFromGroups(customMode.groups || []) return ( - - {t(`prompts:tools.toolNames.${group}`)} - {group === "edit" && ( -
- {t("prompts:tools.allowedFiles")}{" "} - {(() => { - const currentMode = getCurrentMode() - const editGroup = currentMode?.groups?.find( - (g) => - Array.isArray(g) && g[0] === "edit" && g[1]?.fileRegex, - ) - if (!Array.isArray(editGroup)) return t("prompts:allFiles") - return editGroup[1].description || `/${editGroup[1].fileRegex}/` - })()} -
- )} -
+ { + const oldGroups = customMode.groups || [] + const newGroups = updateMcpOptionsInGroups(oldGroups, options) + // Also update the cache so toggle off/on preserves config + if (options) { + groupOptionsCache.current.set("mcp", options) + } else { + groupOptionsCache.current.delete("mcp") + } + updateCustomMode(customMode.slug, { + ...customMode, + groups: newGroups, + source: customMode.source || "global", + }) + }} + isEditing={true} + /> ) - })} -
+ })()} + ) : (
{(() => { @@ -1190,13 +1249,28 @@ const ModesView = () => { const groupName = getGroupName(group) const displayName = t(`prompts:tools.toolNames.${groupName}`) if (Array.isArray(group) && group[1]?.fileRegex) { - const description = group[1].description || `/${group[1].fileRegex}/` - return `${displayName} (${description})` + const description = + group[1].description || "/" + group[1].fileRegex + "/" + return displayName + " (" + description + ")" } return displayName }) .join(", ") })()} + {(() => { + const currentMode = getCurrentMode() + const mcpOptions = getMcpOptionsFromGroups(currentMode?.groups || []) + const isMcpEnabled = currentMode?.groups?.some((g) => getGroupName(g) === "mcp") + if (!isMcpEnabled) return null + return ( + {}} + isEditing={false} + /> + ) + })()}
)}
diff --git a/webview-ui/src/components/modes/__tests__/McpServerFilterRow.spec.tsx b/webview-ui/src/components/modes/__tests__/McpServerFilterRow.spec.tsx new file mode 100644 index 00000000000..22760c8ea86 --- /dev/null +++ b/webview-ui/src/components/modes/__tests__/McpServerFilterRow.spec.tsx @@ -0,0 +1,275 @@ +import React from "react" +import { render, fireEvent, screen } from "@/utils/test-utils" + +import { McpServerFilterRow } from "../McpServerFilterRow" +import type { McpServerFilterRowProps } from "../McpServerFilterRow" + +vi.mock("@/components/ui/checkbox", function () { + return { + Checkbox: function MockCheckbox({ checked, onCheckedChange, "aria-label": ariaLabel }: any) { + return ( + + ) + }, + checkboxVariants: function () { + return "" + }, + } +}) + +vi.mock("@/components/ui/toggle-switch", function () { + return { + ToggleSwitch: function MockToggleSwitch({ + checked, + onChange, + "aria-label": ariaLabel, + "data-testid": testId, + }: any) { + return ( + + }, + SelectContent: function MockSelectContent({ children }: any) { + return
{children}
+ }, + SelectItem: function MockSelectItem({ children, value }: any) { + return
{children}
+ }, + SelectValue: function MockSelectValue() { + return mock-value + }, + } +}) + +/** + * Mock McpServerFilterRow to avoid testing child component internals. + */ +vi.mock("../McpServerFilterRow", function () { + return { + McpServerFilterRow: function MockRow({ serverName }: any) { + return
{serverName}
+ }, + } +}) + +/** + * Creates a minimal McpServer mock with required fields. + */ +function createMockServer(name: string, status?: "connected" | "connecting" | "disconnected"): McpServer { + return { + name: name, + config: "{}", + status: status || "connected", + tools: [{ name: "tool-1", description: "A test tool", inputSchema: undefined }], + } as McpServer +} + +/** + * Helper to render McpFilterConfig with sensible defaults. + * Allows overriding any prop via partial overrides. + */ +function renderConfig(overrides: Partial = {}) { + const defaultProps: McpFilterConfigProps = { + mcpServers: [], + mcpGroupOptions: undefined, + onOptionsChange: vi.fn(), + isEditing: true, + ...overrides, + } + return render() +} + +describe("McpFilterConfig loading state", function () { + beforeEach(function () { + vi.clearAllMocks() + }) + + /** + * When isLoading is true and no servers have been fetched yet, + * the component should display a loading spinner instead of + * the empty "No MCP servers connected" message. + */ + it("shows loading spinner when isLoading is true and no servers", function () { + renderConfig({ + isLoading: true, + mcpServers: [], + isEditing: true, + } as any) + + // Verify spinner is present + const spinner = screen.getByTestId("mcp-loading-spinner") + expect(spinner).toBeInTheDocument() + + // Verify loading text is shown + expect(screen.getByText("Loading MCP servers...")).toBeInTheDocument() + + // Verify empty state message is NOT shown + expect(screen.queryByText("No MCP servers connected")).not.toBeInTheDocument() + }) + + /** + * When isLoading is false and no servers exist, the component + * should show the standard empty state message. + */ + it("shows empty message when isLoading is false and no servers", function () { + renderConfig({ + isLoading: false, + mcpServers: [], + isEditing: true, + } as any) + + // Verify empty message is present + expect(screen.getByText("No MCP servers connected")).toBeInTheDocument() + + // Verify spinner is NOT present + expect(screen.queryByTestId("mcp-loading-spinner")).not.toBeInTheDocument() + }) + + /** + * When servers have loaded successfully, the component should + * render the server list without any spinner or empty message. + */ + it("shows server list when isLoading is false and servers exist", function () { + const servers = [createMockServer("my-mcp-server"), createMockServer("another-server")] + + renderConfig({ + isLoading: false, + mcpServers: servers, + isEditing: true, + } as any) + + // Verify spinner is NOT present + expect(screen.queryByTestId("mcp-loading-spinner")).not.toBeInTheDocument() + + // Verify empty message is NOT present + expect(screen.queryByText("No MCP servers connected")).not.toBeInTheDocument() + + // Verify server rows are rendered + expect(screen.getByTestId("mcp-server-filter-row-my-mcp-server")).toBeInTheDocument() + expect(screen.getByTestId("mcp-server-filter-row-another-server")).toBeInTheDocument() + }) + + /** + * When isLoading is still true but servers have already been provided, + * the server list should take priority over the loading spinner. + * This handles the case where data arrives before the loading flag resets. + */ + it("does not show spinner when isLoading is true but servers already loaded", function () { + const servers = [createMockServer("existing-server")] + + renderConfig({ + isLoading: true, + mcpServers: servers, + isEditing: true, + } as any) + + // Spinner should NOT be present because servers override loading + expect(screen.queryByTestId("mcp-loading-spinner")).not.toBeInTheDocument() + + // Loading text should NOT be present + expect(screen.queryByText("Loading MCP servers...")).not.toBeInTheDocument() + + // Server content should be present + expect(screen.getByTestId("mcp-server-filter-row-existing-server")).toBeInTheDocument() + }) + + /** + * The loading spinner should also appear in read-only mode + * (isEditing=false) when data hasn't loaded yet. + */ + it("shows spinner in read-only mode when loading", function () { + renderConfig({ + isLoading: true, + mcpServers: [], + isEditing: false, + } as any) + + // Verify spinner is present even in read-only mode + const spinner = screen.getByTestId("mcp-loading-spinner") + expect(spinner).toBeInTheDocument() + + // Verify loading text is shown + expect(screen.getByText("Loading MCP servers...")).toBeInTheDocument() + }) +}) diff --git a/webview-ui/src/components/modes/__tests__/McpServerFilterRow.spec.tsx b/webview-ui/src/components/modes/__tests__/McpServerFilterRow.spec.tsx index 22760c8ea86..4aba9e633a5 100644 --- a/webview-ui/src/components/modes/__tests__/McpServerFilterRow.spec.tsx +++ b/webview-ui/src/components/modes/__tests__/McpServerFilterRow.spec.tsx @@ -46,14 +46,14 @@ vi.mock("@/components/ui/toggle-switch", function () { } }) -var mockTools = [ +const mockTools = [ { name: "tool-a", description: "First tool" }, { name: "tool-b", description: "Second tool" }, { name: "tool-c" }, ] function renderRow(overrides: Partial = {}) { - var defaultProps: McpServerFilterRowProps = { + const defaultProps: McpServerFilterRowProps = { serverName: "test-server", serverStatus: "connected", availableTools: mockTools, @@ -108,48 +108,48 @@ describe("McpServerFilterRow", function () { }) it("shows green dot when connected", function () { - var { container } = renderRow({ serverStatus: "connected" }) + const { container } = renderRow({ serverStatus: "connected" }) expect(container.querySelector(".bg-vscode-charts-green")).toBeInTheDocument() }) it("shows yellow dot when connecting", function () { - var { container } = renderRow({ serverStatus: "connecting" }) + const { container } = renderRow({ serverStatus: "connecting" }) expect(container.querySelector(".bg-vscode-charts-yellow")).toBeInTheDocument() }) it("shows gray dot when disconnected", function () { - var { container } = renderRow({ serverStatus: "disconnected" }) + const { container } = renderRow({ serverStatus: "disconnected" }) expect(container.querySelector(".bg-vscode-descriptionForeground")).toBeInTheDocument() }) it("toggle checked when server enabled", function () { renderRow({ filter: undefined }) - var toggle = screen.getByRole("switch", { name: "Toggle test-server server" }) + const toggle = screen.getByRole("switch", { name: "Toggle test-server server" }) expect(toggle).toHaveAttribute("aria-checked", "true") }) it("toggle unchecked when server disabled", function () { renderRow({ filter: { disabled: true } }) - var toggle = screen.getByRole("switch", { name: "Toggle test-server server" }) + const toggle = screen.getByRole("switch", { name: "Toggle test-server server" }) expect(toggle).toHaveAttribute("aria-checked", "false") }) }) describe("toggle server", function () { it("disables server on toggle when enabled", function () { - var { onFilterChange } = renderRow({ filter: undefined }) + const { onFilterChange } = renderRow({ filter: undefined }) fireEvent.click(screen.getByRole("switch", { name: "Toggle test-server server" })) expect(onFilterChange).toHaveBeenCalledWith("test-server", { disabled: true }) }) it("enables server on toggle when disabled", function () { - var { onFilterChange } = renderRow({ filter: { disabled: true } }) + const { onFilterChange } = renderRow({ filter: { disabled: true } }) fireEvent.click(screen.getByRole("switch", { name: "Toggle test-server server" })) expect(onFilterChange).toHaveBeenCalledWith("test-server", undefined) }) it("preserves allowedTools when re-enabling", function () { - var { onFilterChange } = renderRow({ + const { onFilterChange } = renderRow({ filter: { disabled: true, allowedTools: ["tool-a"] }, }) fireEvent.click(screen.getByRole("switch", { name: "Toggle test-server server" })) @@ -174,14 +174,14 @@ describe("McpServerFilterRow", function () { it("does not expand when server is disabled", function () { renderRow({ filter: { disabled: true } }) - var row = screen.getByTestId("mcp-server-filter-row-test-server") + const row = screen.getByTestId("mcp-server-filter-row-test-server") expect(row.querySelector(".codicon-chevron-right")).not.toBeInTheDocument() }) }) describe("filter mode selector", function () { function expandRow(overrides: Partial = {}) { - var result = renderRow(overrides) + const result = renderRow(overrides) fireEvent.click(screen.getByTestId("mcp-server-header-test-server")) return result } @@ -192,7 +192,7 @@ describe("McpServerFilterRow", function () { }) it("switches to allowlist mode", function () { - var { onFilterChange } = expandRow() + const { onFilterChange } = expandRow() fireEvent.click(screen.getByTestId("mcp-filter-mode-btn-allowlist")) expect(onFilterChange).toHaveBeenCalledWith("test-server", { allowedTools: ["tool-a", "tool-b", "tool-c"], @@ -201,7 +201,7 @@ describe("McpServerFilterRow", function () { }) it("switches to blocklist mode", function () { - var { onFilterChange } = expandRow() + const { onFilterChange } = expandRow() fireEvent.click(screen.getByTestId("mcp-filter-mode-btn-blocklist")) expect(onFilterChange).toHaveBeenCalledWith("test-server", { disabledTools: [], @@ -210,7 +210,7 @@ describe("McpServerFilterRow", function () { }) it("switches back to allow all mode", function () { - var { onFilterChange } = expandRow({ filter: { allowedTools: ["tool-a"] } }) + const { onFilterChange } = expandRow({ filter: { allowedTools: ["tool-a"] } }) fireEvent.click(screen.getByTestId("mcp-filter-mode-btn-allowAll")) expect(onFilterChange).toHaveBeenCalledWith("test-server", undefined) }) @@ -226,14 +226,14 @@ describe("McpServerFilterRow", function () { }) it("removes tool from allowlist when unchecked", function () { - var { onFilterChange } = renderRow({ filter: { allowedTools: ["tool-a", "tool-b"] } }) + const { onFilterChange } = renderRow({ filter: { allowedTools: ["tool-a", "tool-b"] } }) fireEvent.click(screen.getByTestId("mcp-server-header-test-server")) fireEvent.click(screen.getByRole("checkbox", { name: "Disable tool tool-a" })) expect(onFilterChange).toHaveBeenCalledWith("test-server", { allowedTools: ["tool-b"] }) }) it("adds tool to allowlist when checked", function () { - var { onFilterChange } = renderRow({ filter: { allowedTools: ["tool-a"] } }) + const { onFilterChange } = renderRow({ filter: { allowedTools: ["tool-a"] } }) fireEvent.click(screen.getByTestId("mcp-server-header-test-server")) fireEvent.click(screen.getByRole("checkbox", { name: "Enable tool tool-b" })) expect(onFilterChange).toHaveBeenCalledWith("test-server", { @@ -244,14 +244,14 @@ describe("McpServerFilterRow", function () { describe("tool checkboxes in blocklist mode", function () { it("adds tool to disabledTools when unchecked", function () { - var { onFilterChange } = renderRow({ filter: { disabledTools: [] } }) + const { onFilterChange } = renderRow({ filter: { disabledTools: [] } }) fireEvent.click(screen.getByTestId("mcp-server-header-test-server")) fireEvent.click(screen.getByRole("checkbox", { name: "Disable tool tool-a" })) expect(onFilterChange).toHaveBeenCalledWith("test-server", { disabledTools: ["tool-a"] }) }) it("removes tool from disabledTools when re-checked", function () { - var { onFilterChange } = renderRow({ filter: { disabledTools: ["tool-b"] } }) + const { onFilterChange } = renderRow({ filter: { disabledTools: ["tool-b"] } }) fireEvent.click(screen.getByTestId("mcp-server-header-test-server")) fireEvent.click(screen.getByRole("checkbox", { name: "Enable tool tool-b" })) expect(onFilterChange).toHaveBeenCalledWith("test-server", { disabledTools: [] }) @@ -268,7 +268,7 @@ describe("McpServerFilterRow", function () { it("does not render description span when not present", function () { renderRow({ filter: { disabledTools: [] } }) fireEvent.click(screen.getByTestId("mcp-server-header-test-server")) - var toolC = screen.getByTestId("mcp-tool-filter-tool-c") + const toolC = screen.getByTestId("mcp-tool-filter-tool-c") expect(toolC.querySelectorAll("span").length).toBe(1) }) }) diff --git a/webview-ui/src/components/modes/__tests__/groupOptionsCache.spec.ts b/webview-ui/src/components/modes/__tests__/groupOptionsCache.spec.ts new file mode 100644 index 00000000000..f091fe9973e --- /dev/null +++ b/webview-ui/src/components/modes/__tests__/groupOptionsCache.spec.ts @@ -0,0 +1,75 @@ +/** + * SE-2: Tests for groupOptionsCache with mode-scoped keys. + * + * The updated functions accept an additional `modeSlug` parameter and + * store / retrieve cache entries under composite keys of the form + * 'modeSlug:groupName' so that options from different modes never + * collide. + */ + +import type { GroupEntry, ToolGroup } from "@roo-code/types" + +import { syncCacheFromGroups, removeGroupWithCache, addGroupWithCache } from "../groupOptionsCache" + +describe("groupOptionsCache mode-scoped keys", () => { + it("syncCacheFromGroups stores options under mode-scoped key", () => { + const cache = new Map() + const groups: GroupEntry[] = [["mcp", { mcpServers: { s1: {} } }]] + + syncCacheFromGroups(cache, groups, "modeA") + + expect(cache.get("modeA:mcp")).toEqual({ mcpServers: { s1: {} } }) + // Old un-scoped key must NOT be used + expect(cache.get("mcp")).toBeUndefined() + }) + + it("syncCacheFromGroups keeps separate entries for different modes", () => { + const cache = new Map() + + const groupsA: GroupEntry[] = [["mcp", { mcpDefaultPolicy: "deny" }]] + const groupsB: GroupEntry[] = [["mcp", { mcpServers: { x: { disabled: true } } }]] + + syncCacheFromGroups(cache, groupsA, "modeA") + syncCacheFromGroups(cache, groupsB, "modeB") + + expect(cache.get("modeA:mcp")).toEqual({ mcpDefaultPolicy: "deny" }) + expect(cache.get("modeB:mcp")).toEqual({ + mcpServers: { x: { disabled: true } }, + }) + }) + + it("removeGroupWithCache stores options under mode-scoped key", () => { + const cache = new Map() + const groups: GroupEntry[] = [["mcp", { mcpServers: { s1: {} } }], "read"] + + const result = removeGroupWithCache(cache, groups, "mcp", "modeA") + + expect(result).toHaveLength(1) + expect(result[0]).toBe("read") + expect(cache.get("modeA:mcp")).toEqual({ mcpServers: { s1: {} } }) + }) + + it("addGroupWithCache restores correct mode options", () => { + const cache = new Map() + cache.set("modeA:mcp", { mcpDefaultPolicy: "deny" }) + cache.set("modeB:mcp", { mcpServers: { x: {} } }) + + const groups: GroupEntry[] = ["read"] + const result = addGroupWithCache(cache, groups, "mcp" as ToolGroup, "modeA") + + expect(result).toHaveLength(2) + expect(result[1]).toEqual(["mcp", { mcpDefaultPolicy: "deny" }]) + }) + + it("addGroupWithCache returns plain string when no cache entry for mode", () => { + const cache = new Map() + cache.set("modeA:mcp", { mcpDefaultPolicy: "deny" }) + + const groups: GroupEntry[] = ["read"] + // Request for modeB which has NO cached entry + const result = addGroupWithCache(cache, groups, "mcp" as ToolGroup, "modeB") + + expect(result).toHaveLength(2) + expect(result[1]).toBe("mcp") + }) +}) diff --git a/webview-ui/src/components/modes/groupOptionsCache.ts b/webview-ui/src/components/modes/groupOptionsCache.ts index 3a3cfa259b0..03449f680bc 100644 --- a/webview-ui/src/components/modes/groupOptionsCache.ts +++ b/webview-ui/src/components/modes/groupOptionsCache.ts @@ -11,14 +11,22 @@ export function getGroupName(entry: GroupEntry): string { return entry[0] } +/** + * Build a mode-scoped cache key: 'modeSlug:groupName'. + */ +function cacheKey(modeSlug: string, groupName: string): string { + return modeSlug + ":" + groupName +} + /** * Synchronise a cache map with the current groups array. - * For every tuple entry, upsert its options into the cache. + * For every tuple entry, upsert its options into the cache + * under a mode-scoped key. */ -export function syncCacheFromGroups(cache: Map, groups: GroupEntry[]): void { +export function syncCacheFromGroups(cache: Map, groups: GroupEntry[], modeSlug: string): void { for (const entry of groups) { if (Array.isArray(entry) && entry[1]) { - cache.set(entry[0], entry[1]) + cache.set(cacheKey(modeSlug, entry[0]), entry[1]) } } } @@ -26,7 +34,7 @@ export function syncCacheFromGroups(cache: Map, groups: GroupEnt /** * Remove a group by name. When the entry being removed is a * tuple (i.e. it carries options), stash those options in the - * cache so they can be restored later. + * cache under a mode-scoped key so they can be restored later. * * Returns the filtered groups array. */ @@ -34,18 +42,19 @@ export function removeGroupWithCache( cache: Map, groups: GroupEntry[], groupName: string, + modeSlug: string, ): GroupEntry[] { const entry = groups.find((g) => getGroupName(g) === groupName) if (entry && Array.isArray(entry) && entry[1]) { - cache.set(entry[0], entry[1]) + cache.set(cacheKey(modeSlug, entry[0]), entry[1]) } return groups.filter((g) => getGroupName(g) !== groupName) } /** * Add a group by name. If the cache contains previously-saved - * options for this group, restore it as a tuple [name, options]. - * Otherwise add it as a plain string. + * options for this group under the current mode, restore it as + * a tuple [name, options]. Otherwise add it as a plain string. * * Returns the new groups array. */ @@ -53,8 +62,9 @@ export function addGroupWithCache( cache: Map, groups: GroupEntry[], groupName: ToolGroup, + modeSlug: string, ): GroupEntry[] { - const cached = cache.get(groupName) + const cached = cache.get(cacheKey(modeSlug, groupName)) if (cached) { return [...groups, [groupName, cached] as GroupEntry] } diff --git a/webview-ui/src/components/modes/useGroupOptionsCache.ts b/webview-ui/src/components/modes/useGroupOptionsCache.ts index 8e18354af83..709c92bf6ab 100644 --- a/webview-ui/src/components/modes/useGroupOptionsCache.ts +++ b/webview-ui/src/components/modes/useGroupOptionsCache.ts @@ -12,22 +12,28 @@ import { syncCacheFromGroups, removeGroupWithCache, addGroupWithCache } from "./ * back on — without the cache, the MCP filter config (mcpServers, * mcpDefaultPolicy) would be discarded on uncheck. */ -export function useGroupOptionsCache(groups: GroupEntry[]) { +export function useGroupOptionsCache(groups: GroupEntry[], modeSlug: string) { const groupOptionsCache = useRef>(new Map()) // Sync cache with external state: if external state has tuple // entries, update the cache so that toggles preserve them. useEffect(() => { - syncCacheFromGroups(groupOptionsCache.current, groups) - }, [groups]) + syncCacheFromGroups(groupOptionsCache.current, groups, modeSlug) + }, [groups, modeSlug]) - const removeGroup = useCallback((currentGroups: GroupEntry[], groupName: string): GroupEntry[] => { - return removeGroupWithCache(groupOptionsCache.current, currentGroups, groupName) - }, []) + const removeGroup = useCallback( + (currentGroups: GroupEntry[], groupName: string): GroupEntry[] => { + return removeGroupWithCache(groupOptionsCache.current, currentGroups, groupName, modeSlug) + }, + [modeSlug], + ) - const addGroup = useCallback((currentGroups: GroupEntry[], groupName: ToolGroup): GroupEntry[] => { - return addGroupWithCache(groupOptionsCache.current, currentGroups, groupName) - }, []) + const addGroup = useCallback( + (currentGroups: GroupEntry[], groupName: ToolGroup): GroupEntry[] => { + return addGroupWithCache(groupOptionsCache.current, currentGroups, groupName, modeSlug) + }, + [modeSlug], + ) return { removeGroup, addGroup, cache: groupOptionsCache } } diff --git a/webview-ui/src/context/ExtensionStateContext.tsx b/webview-ui/src/context/ExtensionStateContext.tsx index ce7a607d9a8..a9515e97df8 100644 --- a/webview-ui/src/context/ExtensionStateContext.tsx +++ b/webview-ui/src/context/ExtensionStateContext.tsx @@ -37,6 +37,7 @@ export interface ExtensionStateContextType extends ExtensionState { showWelcome: boolean theme: any mcpServers: McpServer[] + mcpServersLoaded: boolean currentCheckpoint?: string currentTaskTodos?: TodoItem[] // Initial todos for the current task filePaths: string[] @@ -272,6 +273,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode const [openedTabs, setOpenedTabs] = useState>([]) const [commands, setCommands] = useState([]) const [mcpServers, setMcpServers] = useState([]) + const [mcpServersLoaded, setMcpServersLoaded] = useState(false) const [currentCheckpoint, setCurrentCheckpoint] = useState() const [extensionRouterModels, setExtensionRouterModels] = useState(undefined) const [marketplaceItems, setMarketplaceItems] = useState([]) @@ -400,6 +402,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode } case "mcpServers": { setMcpServers(message.mcpServers ?? []) + setMcpServersLoaded(true) break } case "currentCheckpointUpdated": { @@ -492,6 +495,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode showWelcome, theme, mcpServers, + mcpServersLoaded, currentCheckpoint, filePaths, openedTabs,