Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 117 additions & 57 deletions core/src/main/java/com/google/adk/tools/mcp/McpToolset.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import com.google.adk.agents.ReadonlyContext;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.BaseToolset;
import com.google.adk.tools.ToolPredicate;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Booleans;
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.ServerParameters;
Expand All @@ -32,6 +34,7 @@
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -51,7 +54,7 @@ public class McpToolset implements BaseToolset {
private final McpSessionManager mcpSessionManager;
private McpSyncClient mcpSession;
private final ObjectMapper objectMapper;
private final Optional<Object> toolFilter;
private final @Nullable Object toolFilter;

private static final int MAX_RETRIES = 3;
private static final long RETRY_DELAY_MILLIS = 100;
Expand All @@ -62,17 +65,29 @@ public class McpToolset implements BaseToolset {
*
* @param connectionParams The SSE connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolFilter An Optional containing either a ToolPredicate or a List of tool names.
* @param toolPredicate A {@link ToolPredicate}
*/
public McpToolset(
SseServerParameters connectionParams,
ObjectMapper objectMapper,
Optional<Object> toolFilter) {
Objects.requireNonNull(connectionParams);
Objects.requireNonNull(objectMapper);
this.objectMapper = objectMapper;
this.mcpSessionManager = new McpSessionManager(connectionParams);
this.toolFilter = toolFilter;
ToolPredicate toolPredicate) {
this.objectMapper = Objects.requireNonNull(objectMapper);
this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams));
this.toolFilter = Objects.requireNonNull(toolPredicate);
}

/**
* Initializes the McpToolset with SSE server parameters.
*
* @param connectionParams The SSE connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolNames A list of tool names
*/
public McpToolset(
SseServerParameters connectionParams, ObjectMapper objectMapper, List<String> toolNames) {
this.objectMapper = Objects.requireNonNull(objectMapper);
this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams));
this.toolFilter = ImmutableList.copyOf(toolNames);
}

/**
Expand All @@ -82,44 +97,49 @@ public McpToolset(
* @param objectMapper An ObjectMapper instance for parsing schemas.
*/
public McpToolset(SseServerParameters connectionParams, ObjectMapper objectMapper) {
this(connectionParams, objectMapper, Optional.empty());
this.objectMapper = Objects.requireNonNull(objectMapper);
this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams));
this.toolFilter = null;
}

/**
* Initializes the McpToolset with local server parameters.
*
* @param connectionParams The local server connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolFilter An Optional containing either a ToolPredicate or a List of tool names.
* @param toolPredicate A {@link ToolPredicate}
*/
public McpToolset(
ServerParameters connectionParams, ObjectMapper objectMapper, Optional<Object> toolFilter) {
Objects.requireNonNull(connectionParams);
Objects.requireNonNull(objectMapper);
this.objectMapper = objectMapper;
this.mcpSessionManager = new McpSessionManager(connectionParams);
this.toolFilter = toolFilter;
ServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolPredicate) {
this.objectMapper = Objects.requireNonNull(objectMapper);
this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams));
this.toolFilter = Objects.requireNonNull(toolPredicate);
}

/**
* Initializes the McpToolset with local server parameters and no tool filter.
* Initializes the McpToolset with local server parameters.
*
* @param connectionParams The local server connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolNames A list of tool names
*/
public McpToolset(ServerParameters connectionParams, ObjectMapper objectMapper) {
this(connectionParams, objectMapper, Optional.empty());
public McpToolset(
ServerParameters connectionParams, ObjectMapper objectMapper, List<String> toolNames) {
this.objectMapper = Objects.requireNonNull(objectMapper);
this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams));
this.toolFilter = ImmutableList.copyOf(toolNames);
}

/**
* Initializes the McpToolset with SSE server parameters, using the ObjectMapper used across the
* ADK.
* Initializes the McpToolset with local server parameters and no tool filter.
*
* @param connectionParams The SSE connection parameters to the MCP server.
* @param toolFilter An Optional containing either a ToolPredicate or a List of tool names.
* @param connectionParams The local server connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
*/
public McpToolset(SseServerParameters connectionParams, Optional<Object> toolFilter) {
this(connectionParams, JsonBaseModel.getMapper(), toolFilter);
public McpToolset(ServerParameters connectionParams, ObjectMapper objectMapper) {
this.objectMapper = Objects.requireNonNull(objectMapper);
this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams));
this.toolFilter = null;
}

/**
Expand All @@ -129,62 +149,101 @@ public McpToolset(SseServerParameters connectionParams, Optional<Object> toolFil
* @param connectionParams The SSE connection parameters to the MCP server.
*/
public McpToolset(SseServerParameters connectionParams) {
this(connectionParams, JsonBaseModel.getMapper(), Optional.empty());
this(connectionParams, JsonBaseModel.getMapper());
}

/**
* Initializes the McpToolset with local server parameters, using the ObjectMapper used across the
* ADK.
* ADK and no tool filter.
*
* @param connectionParams The local server connection parameters to the MCP server.
* @param toolFilter An Optional containing either a ToolPredicate or a List of tool names.
*/
public McpToolset(ServerParameters connectionParams, Optional<Object> toolFilter) {
this(connectionParams, JsonBaseModel.getMapper(), toolFilter);
public McpToolset(ServerParameters connectionParams) {
this(connectionParams, JsonBaseModel.getMapper());
}

/**
* Initializes the McpToolset with local server parameters, using the ObjectMapper used across the
* ADK and no tool filter.
* Initializes the McpToolset with an McpSessionManager.
*
* @param connectionParams The local server connection parameters to the MCP server.
* @param mcpSessionManager A McpSessionManager instance for testing.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolPredicate A {@link ToolPredicate}
*/
public McpToolset(ServerParameters connectionParams) {
this(connectionParams, JsonBaseModel.getMapper(), Optional.empty());
public McpToolset(
McpSessionManager mcpSessionManager, ObjectMapper objectMapper, ToolPredicate toolPredicate) {
this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager);
this.objectMapper = Objects.requireNonNull(objectMapper);
this.toolFilter = Objects.requireNonNull(toolPredicate);
}

/**
* Initializes the McpToolset with an McpSessionManager.
*
* @param mcpSessionManager A McpSessionManager instance for testing.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolFilter An Optional containing either a ToolPredicate or a List of tool names.
* @param toolNames A list of tool names
*/
public McpToolset(
McpSessionManager mcpSessionManager, ObjectMapper objectMapper, Optional<Object> toolFilter) {
Objects.requireNonNull(mcpSessionManager);
Objects.requireNonNull(objectMapper);
this.mcpSessionManager = mcpSessionManager;
this.objectMapper = objectMapper;
this.toolFilter = toolFilter;
McpSessionManager mcpSessionManager, ObjectMapper objectMapper, List<String> toolNames) {
this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager);
this.objectMapper = Objects.requireNonNull(objectMapper);
this.toolFilter = ImmutableList.copyOf(toolNames);
}

/**
* Initializes the McpToolset with an McpSessionManager and no tool filter.
*
* @param mcpSessionManager A McpSessionManager instance for testing.
* @param objectMapper An ObjectMapper instance for parsing schemas.
*/
public McpToolset(McpSessionManager mcpSessionManager, ObjectMapper objectMapper) {
this.mcpSessionManager = Objects.requireNonNull(mcpSessionManager);
this.objectMapper = Objects.requireNonNull(objectMapper);
this.toolFilter = null;
}

/**
* Initializes the McpToolset with Steamable HTTP server parameters.
* Initializes the McpToolset with Streamable HTTP server parameters.
*
* @param connectionParams The Streamable HTTP connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolFilter An Optional containing either a ToolPredicate or a List of tool names.
* @param toolPredicate A {@link ToolPredicate}
*/
public McpToolset(
StreamableHttpServerParameters connectionParams,
ObjectMapper objectMapper,
Optional<Object> toolFilter) {
Objects.requireNonNull(connectionParams);
Objects.requireNonNull(objectMapper);
this.objectMapper = objectMapper;
this.mcpSessionManager = new McpSessionManager(connectionParams);
this.toolFilter = toolFilter;
ToolPredicate toolPredicate) {
this.objectMapper = Objects.requireNonNull(objectMapper);
this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams));
this.toolFilter = Objects.requireNonNull(toolPredicate);
}

/**
* Initializes the McpToolset with Streamable HTTP server parameters.
*
* @param connectionParams The Streamable HTTP connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolNames A list of tool names
*/
public McpToolset(
StreamableHttpServerParameters connectionParams,
ObjectMapper objectMapper,
List<String> toolNames) {
this.objectMapper = Objects.requireNonNull(objectMapper);
this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams));
this.toolFilter = ImmutableList.copyOf(toolNames);
}

/**
* Initializes the McpToolset with Streamable HTTP server parameters and no tool filter.
*
* @param connectionParams The Streamable HTTP connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
*/
public McpToolset(StreamableHttpServerParameters connectionParams, ObjectMapper objectMapper) {
this.objectMapper = Objects.requireNonNull(objectMapper);
this.mcpSessionManager = new McpSessionManager(Objects.requireNonNull(connectionParams));
this.toolFilter = null;
}

/**
Expand All @@ -194,7 +253,7 @@ public McpToolset(
* @param connectionParams The Streamable HTTP connection parameters to the MCP server.
*/
public McpToolset(StreamableHttpServerParameters connectionParams) {
this(connectionParams, JsonBaseModel.getMapper(), Optional.empty());
this(connectionParams, JsonBaseModel.getMapper());
}

@Override
Expand All @@ -215,8 +274,7 @@ public Flowable<BaseTool> getTools(ReadonlyContext readonlyContext) {
tool ->
new McpTool(
tool, this.mcpSession, this.mcpSessionManager, this.objectMapper))
.filter(
tool -> isToolSelected(tool, toolFilter.orElse(null), readonlyContext)));
.filter(tool -> isToolSelected(tool, toolFilter, readonlyContext)));
})
.retryWhen(
errorObservable ->
Expand Down Expand Up @@ -357,16 +415,18 @@ public static McpToolset fromConfig(BaseTool.ToolConfig config, String configAbs
+ " for McpToolset");
}

// Convert tool filter to Optional<Object>
Optional<Object> toolFilter = Optional.ofNullable(mcpToolsetConfig.toolFilter());

List<String> toolNames = mcpToolsetConfig.toolFilter();
Object connectionParameters =
Optional.<Object>ofNullable(mcpToolsetConfig.stdioConnectionParams())
.or(() -> Optional.ofNullable(mcpToolsetConfig.sseServerParams()))
.orElse(mcpToolsetConfig.stdioConnectionParams());

// Create McpToolset with McpSessionManager having appropriate connection parameters
return new McpToolset(new McpSessionManager(connectionParameters), mapper, toolFilter);
if (toolNames != null) {
return new McpToolset(new McpSessionManager(connectionParameters), mapper, toolNames);
} else {
return new McpToolset(new McpSessionManager(connectionParameters), mapper);
}
} catch (IllegalArgumentException e) {
throw new ConfigurationException("Failed to parse McpToolsetConfig from ToolArgsConfig", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import io.modelcontextprotocol.json.McpJsonMapper;
import io.modelcontextprotocol.spec.McpSchema;
import java.util.List;
import java.util.Optional;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -324,7 +323,7 @@ public void getTools_withToolFilter_returnsFilteredTools() {
when(mockMcpSyncClient.listTools()).thenReturn(mockResult);

McpToolset toolset =
new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.of(toolFilter));
new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), toolFilter);

List<BaseTool> tools = toolset.getTools(mockReadonlyContext).toList().blockingGet();

Expand All @@ -340,8 +339,7 @@ public void getTools_retriesAndFailsAfterMaxRetries() {
when(mockMcpSessionManager.createSession()).thenReturn(mockMcpSyncClient);
when(mockMcpSyncClient.listTools()).thenThrow(new RuntimeException("Test Exception"));

McpToolset toolset =
new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.empty());
McpToolset toolset = new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper());

toolset
.getTools(mockReadonlyContext)
Expand All @@ -362,8 +360,7 @@ public void getTools_succeedsOnLastRetryAttempt() {
.thenThrow(new RuntimeException("Attempt 2 failed"))
.thenReturn(mockResult);

McpToolset toolset =
new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper(), Optional.empty());
McpToolset toolset = new McpToolset(mockMcpSessionManager, JsonBaseModel.getMapper());

List<BaseTool> tools = toolset.getTools(mockReadonlyContext).toList().blockingGet();

Expand Down
Loading