Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion spring-ai-modules/spring-ai-mcp/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

<properties>
<java.version>21</java.version>
<spring-ai.version>1.0.1</spring-ai.version>
<spring-ai.version>1.1.2</spring-ai.version>
</properties>

<profiles>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.baeldung.springai.mcp.test;

import org.springaicommunity.mcp.annotation.McpTool;
import org.springaicommunity.mcp.annotation.McpToolParam;
import org.springframework.stereotype.Component;

@Component
public class ExchangeRateMcpTool {

private final ExchangeRateService exchangeRateService;

public ExchangeRateMcpTool(ExchangeRateService exchangeRateService) {
this.exchangeRateService = exchangeRateService;
}

@McpTool(description = "Get the latest exchange rates for a base currency")
public ExchangeRateResponse getLatestExchangeRate(
@McpToolParam(description = "Base currency code, e.g. GBP, USD, EUR", required = true) String base) {
return exchangeRateService.getLatestExchangeRate(base);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.baeldung.springai.mcp.test;

import java.util.Map;

public record ExchangeRateResponse(double amount, String base, String date, Map<String, Double> rates) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.baeldung.springai.mcp.test;

import org.springframework.stereotype.Service;
import org.springframework.web.client.RestClient;

@Service
public class ExchangeRateService {

private static final String FRANKFURTER_URL = "https://api.frankfurter.dev/v1/latest?base={base}";

private final RestClient restClient;

public ExchangeRateService(RestClient.Builder restClientBuilder) {
this.restClient = restClientBuilder.build();
}

public ExchangeRateResponse getLatestExchangeRate(String base) {
if (base == null || base.isBlank()) {
throw new IllegalArgumentException("base is required");
}
return restClient.get()
.uri(FRANKFURTER_URL, base.trim().toUpperCase())
.retrieve()
.body(ExchangeRateResponse.class);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.baeldung.springai.mcp.test;

import org.springframework.ai.model.anthropic.autoconfigure.AnthropicChatAutoConfiguration;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.PropertySource;

/**
* Excluding the below auto-configuration to avoid start up
* failure. Its corresponding starter is present on the classpath but is
* only needed by the MCP client application.
*/
@SpringBootApplication(exclude = {
AnthropicChatAutoConfiguration.class
})
@PropertySource("classpath:application-test-mcp-server.properties")
public class TestMcpApplication {

public static void main(String[] args) {
SpringApplication.run(TestMcpApplication.class, args);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
com.baeldung.author-tools.enabled=false
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package com.baeldung.springai.mcp.test;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.when;

import java.util.Map;
import java.util.Objects;

import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.spec.McpSchema;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.boot.test.web.server.LocalServerPort;

@SpringBootTest(
webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT
)
class ExchangeRateMcpToolSseIntegrationTest {

@LocalServerPort
private int port;

@Autowired
private TestMcpClientFactory testMcpClientFactory;

@MockBean
private ExchangeRateService exchangeRateService;

private McpSyncClient client;

@BeforeEach
void setUp() {
client = testMcpClientFactory.create("http://localhost:" + port);
client.initialize();
}

@AfterEach
void cleanUp() {
client.closeGracefully();
}

@Test
void whenMcpClientListTools_thenTheToolIsRegistered() {
boolean registered = client.listTools().tools().stream()
.anyMatch(tool -> Objects.equals(tool.name(), "getLatestExchangeRate"));
assertThat(registered).isTrue();
}

@Test
void whenMcpClientCallTool_thenTheToolReturnsMockedResponse() {
when(exchangeRateService.getLatestExchangeRate("GBP")).thenReturn(
new ExchangeRateResponse(1.0, "GBP", "2026-03-08", Map.of("USD", 1.27))
);

McpSchema.Tool exchangeRateTool = client.listTools().tools().stream()
.filter(tool -> "getLatestExchangeRate".equals(tool.name()))
.findFirst()
.orElseThrow();

String argumentName = exchangeRateTool.inputSchema().properties().keySet().stream()
.findFirst()
.orElseThrow();

McpSchema.CallToolResult result = client.callTool(
new McpSchema.CallToolRequest("getLatestExchangeRate", Map.of(argumentName, "GBP"))
);

assertThat(result).isNotNull();
assertThat(result.isError()).isFalse();
assertTrue(result.toString().contains("GBP"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package com.baeldung.springai.mcp.test;

import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.spec.McpSchema;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.boot.test.web.server.LocalServerPort;

import java.util.Map;
import java.util.Objects;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.when;

@SpringBootTest(
webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
properties = "spring.ai.mcp.server.protocol=streamable"
)
class ExchangeRateMcpToolStreamableIntegrationTest {

@LocalServerPort
private int port;

@MockBean
private ExchangeRateService exchangeRateService;

@Autowired
private TestMcpClientFactory testMcpClientFactory;

private McpSyncClient client;

@BeforeEach
void setUp() {
client = testMcpClientFactory.create("http://localhost:" + port);
client.initialize();
}

@AfterEach
void cleanUp() {
client.close();
}

@Test
void whenMcpClientListTools_thenTheToolIsRegistered() {
boolean registered = client.listTools().tools().stream()
.anyMatch(tool -> Objects.equals(tool.name(), "getLatestExchangeRate"));
assertThat(registered).isTrue();
}

@Test
void whenMcpClientCallTool_thenTheToolReturnsMockedResponse() {
when(exchangeRateService.getLatestExchangeRate("GBP")).thenReturn(
new ExchangeRateResponse(1.0, "GBP", "2026-03-08", Map.of("USD", 1.27))
);

McpSchema.Tool exchangeRateTool = client.listTools().tools().stream()
.filter(tool -> "getLatestExchangeRate".equals(tool.name()))
.findFirst()
.orElseThrow();

String argumentName = exchangeRateTool.inputSchema().properties().keySet().stream()
.findFirst()
.orElseThrow();

McpSchema.CallToolResult result = client.callTool(
new McpSchema.CallToolRequest("getLatestExchangeRate", Map.of(argumentName, "GBP"))
);

assertThat(result).isNotNull();
assertThat(result.isError()).isFalse();
assertTrue(result.toString().contains("GBP"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.baeldung.springai.mcp.test;

import org.junit.jupiter.api.Test;

import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

class ExchangeRateMcpToolUnitTest {

@Test
void whenBaseIsNotBlank_thenGetExchangeRateShouldReturnResponse() {
ExchangeRateService exchangeRateService = mock(ExchangeRateService.class);
ExchangeRateResponse expected = new ExchangeRateResponse(
1.0,
"GBP",
"2026-03-08",
Map.of("USD", 1.27, "EUR", 1.17)
);
when(exchangeRateService.getLatestExchangeRate("gbp")).thenReturn(expected);

ExchangeRateMcpTool tool = new ExchangeRateMcpTool(exchangeRateService);
ExchangeRateResponse actual = tool.getLatestExchangeRate("gbp");

assertThat(actual).isEqualTo(expected);
verify(exchangeRateService).getLatestExchangeRate("gbp");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.baeldung.springai.mcp.test;

import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

@Component
public class TestMcpClientFactory {

private final String protocol;

public TestMcpClientFactory(@Value("${spring.ai.mcp.server.protocol:sse}") String protocol) {
this.protocol = protocol;
}

public McpSyncClient create(String baseUrl) {
String resolvedProtocol = protocol.trim().toLowerCase();
return switch (resolvedProtocol) {
case "sse" -> McpClient.sync(HttpClientSseClientTransport.builder(baseUrl)
.sseEndpoint("/sse")
.build()
).build();
case "streamable" -> McpClient.sync(HttpClientStreamableHttpTransport.builder(baseUrl)
.endpoint("/mcp")
.build()
).build();
default -> throw new IllegalArgumentException("Unknown MCP protocol: " + protocol);
};
}
}