diff --git a/.gitmodules b/.gitmodules index d5cd4211..18a3014b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,4 @@ [submodule "third_party/gopher-orch"] path = third_party/gopher-orch url = https://github.com/GopherSecurity/gopher-orch.git + branch = br_release diff --git a/build.sh b/build.sh index a3ab76c0..38d73fb0 100755 --- a/build.sh +++ b/build.sh @@ -97,7 +97,9 @@ cmake .. \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX="${SCRIPT_DIR}/native" \ -DBUILD_SHARED_LIBS=ON \ - -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DBUILD_EXAMPLES=OFF \ + -DBUILD_TESTS=OFF # Build echo -e "${YELLOW} Compiling...${NC}" @@ -163,12 +165,30 @@ dotnet build -c Release echo -e "${GREEN}✓ C# SDK built successfully${NC}" echo "" -# Step 5: Run tests -echo -e "${YELLOW}Step 5: Running tests...${NC}" -dotnet test -c Release --no-build 2>/dev/null && echo -e "${GREEN}✓ Tests passed${NC}" || echo -e "${YELLOW}⚠ Some tests may have failed (native library required)${NC}" +# Step 5: Build Auth Example +echo -e "${YELLOW}Step 5: Building Auth MCP Server example...${NC}" +AUTH_EXAMPLE_DIR="${SCRIPT_DIR}/examples/auth" +if [ -d "${AUTH_EXAMPLE_DIR}" ]; then + dotnet build "${AUTH_EXAMPLE_DIR}/AuthMcpServer/AuthMcpServer.csproj" -c Release + dotnet build "${AUTH_EXAMPLE_DIR}/AuthMcpServer.Tests/AuthMcpServer.Tests.csproj" -c Release + echo -e "${GREEN}✓ Auth example built successfully${NC}" +else + echo -e "${YELLOW}⚠ Auth example not found at ${AUTH_EXAMPLE_DIR}${NC}" +fi +echo "" + +# Step 6: Run tests +echo -e "${YELLOW}Step 6: Running tests...${NC}" +dotnet test -c Release --no-build 2>/dev/null && echo -e "${GREEN}✓ SDK tests passed${NC}" || echo -e "${YELLOW}⚠ Some SDK tests may have failed (native library required)${NC}" + +# Run auth example tests +if [ -d "${AUTH_EXAMPLE_DIR}" ]; then + echo -e "${YELLOW} Running Auth example tests...${NC}" + dotnet test "${AUTH_EXAMPLE_DIR}/AuthMcpServer.Tests/AuthMcpServer.Tests.csproj" -c Release --no-build && echo -e "${GREEN}✓ Auth example tests passed${NC}" || echo -e "${YELLOW}⚠ Some Auth example tests failed${NC}" +fi # Package NuGet -echo -e "${YELLOW}Step 6: Packaging NuGet...${NC}" +echo -e "${YELLOW}Step 7: Packaging NuGet...${NC}" dotnet pack -c Release --no-build -o "${SCRIPT_DIR}/packages" echo -e "${GREEN}✓ NuGet package created${NC}" @@ -181,3 +201,4 @@ echo -e "Native libraries: ${YELLOW}${NATIVE_LIB_DIR}${NC}" echo -e "Native headers: ${YELLOW}${NATIVE_INCLUDE_DIR}${NC}" echo -e "Run tests: ${YELLOW}dotnet test${NC}" echo -e "Package NuGet: ${YELLOW}dotnet pack${NC}" +echo -e "Run Auth example: ${YELLOW}cd examples/auth && ./run_example.sh${NC}" diff --git a/examples/auth/.gitignore b/examples/auth/.gitignore new file mode 100644 index 00000000..35b75416 --- /dev/null +++ b/examples/auth/.gitignore @@ -0,0 +1,66 @@ +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +[Ww][Ii][Nn]32/ +[Aa][Rr][Mm]/ +[Aa][Rr][Mm]64/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ +[Ll]ogs/ + +# Visual Studio +.vs/ +*.user +*.userosscache +*.sln.docstates +*.suo +*.cache + +# Visual Studio Code +.vscode/ + +# JetBrains Rider +.idea/ +*.sln.iml + +# NuGet +*.nupkg +*.snupkg +**/[Pp]ackages/* +!**/[Pp]ackages/build/ +*.nuget.props +*.nuget.targets + +# Project files +*.csproj.user +project.lock.json +project.fragment.lock.json + +# ASP.NET Scaffolding +ScaffoldingReadMe.txt + +# StyleCop +StyleCopReport.xml + +# Test Results +[Tt]est[Rr]esult*/ +*.trx +*.coverage +*.coveragexml +[Cc]overage/ +coverage*.json +coverage*.xml +*.coverlet.json + +# Publish +publish/ + +# OS files +.DS_Store +Thumbs.db diff --git a/examples/auth/AuthMcpServer.Tests/AuthMcpServer.Tests.csproj b/examples/auth/AuthMcpServer.Tests/AuthMcpServer.Tests.csproj new file mode 100644 index 00000000..c281c1c8 --- /dev/null +++ b/examples/auth/AuthMcpServer.Tests/AuthMcpServer.Tests.csproj @@ -0,0 +1,29 @@ + + + + net8.0 + enable + enable + + false + true + + + + + + + + + + + + + + + + + + + + diff --git a/examples/auth/AuthMcpServer.Tests/Config/ConfigLoaderTests.cs b/examples/auth/AuthMcpServer.Tests/Config/ConfigLoaderTests.cs new file mode 100644 index 00000000..14e6e283 --- /dev/null +++ b/examples/auth/AuthMcpServer.Tests/Config/ConfigLoaderTests.cs @@ -0,0 +1,269 @@ +using AuthMcpServer.Config; +using FluentAssertions; + +namespace AuthMcpServer.Tests.Config; + +/// +/// Unit tests for ConfigLoader. +/// +public class ConfigLoaderTests +{ + [Fact] + public void ParseConfigFile_WithValidContent_ReturnsKeyValuePairs() + { + var content = """ + key1=value1 + key2=value2 + """; + + var result = ConfigLoader.ParseConfigFile(content); + + result.Should().HaveCount(2); + result["key1"].Should().Be("value1"); + result["key2"].Should().Be("value2"); + } + + [Fact] + public void ParseConfigFile_WithComments_IgnoresCommentLines() + { + var content = """ + # This is a comment + key1=value1 + # Another comment + key2=value2 + """; + + var result = ConfigLoader.ParseConfigFile(content); + + result.Should().HaveCount(2); + result.Should().ContainKey("key1"); + result.Should().ContainKey("key2"); + } + + [Fact] + public void ParseConfigFile_WithEmptyLines_IgnoresEmptyLines() + { + var content = """ + key1=value1 + + key2=value2 + + """; + + var result = ConfigLoader.ParseConfigFile(content); + + result.Should().HaveCount(2); + } + + [Fact] + public void ParseConfigFile_WithWhitespace_TrimsKeysAndValues() + { + var content = " key1 = value1 \n key2 = value2 "; + + var result = ConfigLoader.ParseConfigFile(content); + + result["key1"].Should().Be("value1"); + result["key2"].Should().Be("value2"); + } + + [Fact] + public void ParseConfigFile_WithEqualsInValue_PreservesValue() + { + var content = "url=http://example.com?foo=bar"; + + var result = ConfigLoader.ParseConfigFile(content); + + result["url"].Should().Be("http://example.com?foo=bar"); + } + + [Fact] + public void ParseConfigFile_WithMissingEquals_IgnoresLine() + { + var content = """ + key1=value1 + invalid line without equals + key2=value2 + """; + + var result = ConfigLoader.ParseConfigFile(content); + + result.Should().HaveCount(2); + } + + [Fact] + public void ParseConfigFile_IsCaseInsensitiveForKeys() + { + var content = "KEY=value"; + + var result = ConfigLoader.ParseConfigFile(content); + + result["key"].Should().Be("value"); + result["KEY"].Should().Be("value"); + } + + [Fact] + public void BuildConfig_WithFullConfig_ReturnsAllValues() + { + var configMap = new Dictionary + { + ["host"] = "127.0.0.1", + ["port"] = "8080", + ["server_url"] = "https://example.com", + ["auth_server_url"] = "https://auth.example.com/realms/test", + ["jwks_uri"] = "https://auth.example.com/certs", + ["issuer"] = "https://auth.example.com", + ["client_id"] = "test-client", + ["client_secret"] = "test-secret", + ["oauth_authorize_url"] = "https://auth.example.com/authorize", + ["oauth_token_url"] = "https://auth.example.com/token", + ["allowed_scopes"] = "openid profile", + ["jwks_cache_duration"] = "7200", + ["jwks_auto_refresh"] = "false", + ["request_timeout"] = "60", + ["auth_disabled"] = "true" + }; + + var config = ConfigLoader.BuildConfig(configMap); + + config.Host.Should().Be("127.0.0.1"); + config.Port.Should().Be(8080); + config.ServerUrl.Should().Be("https://example.com"); + config.AuthServerUrl.Should().Be("https://auth.example.com/realms/test"); + config.JwksUri.Should().Be("https://auth.example.com/certs"); + config.Issuer.Should().Be("https://auth.example.com"); + config.ClientId.Should().Be("test-client"); + config.ClientSecret.Should().Be("test-secret"); + config.OAuthAuthorizeUrl.Should().Be("https://auth.example.com/authorize"); + config.OAuthTokenUrl.Should().Be("https://auth.example.com/token"); + config.AllowedScopes.Should().Be("openid profile"); + config.JwksCacheDuration.Should().Be(7200); + config.JwksAutoRefresh.Should().BeFalse(); + config.RequestTimeout.Should().Be(60); + config.AuthDisabled.Should().BeTrue(); + } + + [Fact] + public void BuildConfig_WithEmptyMap_ReturnsDefaults() + { + var configMap = new Dictionary(); + + var config = ConfigLoader.BuildConfig(configMap); + + config.Host.Should().Be("0.0.0.0"); + config.Port.Should().Be(3001); + config.ServerUrl.Should().Be("http://localhost:3001"); + config.AuthServerUrl.Should().BeEmpty(); + config.JwksUri.Should().BeEmpty(); + config.Issuer.Should().BeEmpty(); + config.ClientId.Should().BeEmpty(); + config.ClientSecret.Should().BeEmpty(); + config.AllowedScopes.Should().Be("openid profile email mcp:read mcp:admin"); + config.JwksCacheDuration.Should().Be(3600); + config.JwksAutoRefresh.Should().BeTrue(); + config.RequestTimeout.Should().Be(30); + config.AuthDisabled.Should().BeFalse(); + } + + [Fact] + public void BuildConfig_DerivesJwksUriFromAuthServerUrl() + { + var configMap = new Dictionary + { + ["auth_server_url"] = "https://keycloak.example.com/realms/mcp" + }; + + var config = ConfigLoader.BuildConfig(configMap); + + config.JwksUri.Should().Be("https://keycloak.example.com/realms/mcp/protocol/openid-connect/certs"); + } + + [Fact] + public void BuildConfig_DerivesIssuerFromAuthServerUrl() + { + var configMap = new Dictionary + { + ["auth_server_url"] = "https://keycloak.example.com/realms/mcp" + }; + + var config = ConfigLoader.BuildConfig(configMap); + + config.Issuer.Should().Be("https://keycloak.example.com/realms/mcp"); + } + + [Fact] + public void BuildConfig_DerivesOAuthAuthorizeUrlFromAuthServerUrl() + { + var configMap = new Dictionary + { + ["auth_server_url"] = "https://keycloak.example.com/realms/mcp" + }; + + var config = ConfigLoader.BuildConfig(configMap); + + config.OAuthAuthorizeUrl.Should().Be("https://keycloak.example.com/realms/mcp/protocol/openid-connect/auth"); + } + + [Fact] + public void BuildConfig_DerivesOAuthTokenUrlFromAuthServerUrl() + { + var configMap = new Dictionary + { + ["auth_server_url"] = "https://keycloak.example.com/realms/mcp" + }; + + var config = ConfigLoader.BuildConfig(configMap); + + config.OAuthTokenUrl.Should().Be("https://keycloak.example.com/realms/mcp/protocol/openid-connect/token"); + } + + [Fact] + public void BuildConfig_ExplicitEndpointsOverrideDerivation() + { + var configMap = new Dictionary + { + ["auth_server_url"] = "https://keycloak.example.com/realms/mcp", + ["jwks_uri"] = "https://custom.example.com/certs", + ["issuer"] = "https://custom.example.com", + ["oauth_authorize_url"] = "https://custom.example.com/auth", + ["oauth_token_url"] = "https://custom.example.com/token" + }; + + var config = ConfigLoader.BuildConfig(configMap); + + config.JwksUri.Should().Be("https://custom.example.com/certs"); + config.Issuer.Should().Be("https://custom.example.com"); + config.OAuthAuthorizeUrl.Should().Be("https://custom.example.com/auth"); + config.OAuthTokenUrl.Should().Be("https://custom.example.com/token"); + } + + [Fact] + public void BuildConfig_DerivesServerUrlFromPort() + { + var configMap = new Dictionary + { + ["port"] = "8080" + }; + + var config = ConfigLoader.BuildConfig(configMap); + + config.ServerUrl.Should().Be("http://localhost:8080"); + } + + [Fact] + public void LoadFromFile_WithNonExistentFile_ReturnsDefaultDisabled() + { + var config = ConfigLoader.LoadFromFile("/nonexistent/path/server.config"); + + config.AuthDisabled.Should().BeTrue(); + config.Host.Should().Be("0.0.0.0"); + config.Port.Should().Be(3001); + } + + [Fact] + public void AuthServerConfig_DefaultDisabled_HasAuthDisabledTrue() + { + var config = AuthServerConfig.DefaultDisabled(); + + config.AuthDisabled.Should().BeTrue(); + } +} diff --git a/examples/auth/AuthMcpServer.Tests/Middleware/CorsMiddlewareTests.cs b/examples/auth/AuthMcpServer.Tests/Middleware/CorsMiddlewareTests.cs new file mode 100644 index 00000000..c124eeee --- /dev/null +++ b/examples/auth/AuthMcpServer.Tests/Middleware/CorsMiddlewareTests.cs @@ -0,0 +1,146 @@ +using System.Net; +using AuthMcpServer.Middleware; +using FluentAssertions; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace AuthMcpServer.Tests.Middleware; + +/// +/// Integration tests for CORS middleware. +/// +public class CorsMiddlewareTests +{ + private async Task CreateTestClient() + { + var host = await new HostBuilder() + .ConfigureWebHost(webBuilder => + { + webBuilder + .UseTestServer() + .ConfigureServices(services => + { + services.AddRouting(); + }) + .Configure(app => + { + app.UseMiddleware(); + app.UseRouting(); + app.UseEndpoints(endpoints => + { + endpoints.MapGet("/test", () => Microsoft.AspNetCore.Http.Results.Ok("test")); + endpoints.MapPost("/test", () => Microsoft.AspNetCore.Http.Results.Ok("posted")); + }); + }); + }) + .StartAsync(); + + return host.GetTestClient(); + } + + [Fact] + public async Task Get_Request_HasCorsHeaders() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/test"); + + response.StatusCode.Should().Be(HttpStatusCode.OK); + response.Headers.GetValues("Access-Control-Allow-Origin").Should().Contain("*"); + response.Headers.GetValues("Access-Control-Allow-Methods").Should().NotBeEmpty(); + response.Headers.GetValues("Access-Control-Allow-Headers").Should().NotBeEmpty(); + } + + [Fact] + public async Task Post_Request_HasCorsHeaders() + { + var client = await CreateTestClient(); + + var response = await client.PostAsync("/test", new StringContent("")); + + response.StatusCode.Should().Be(HttpStatusCode.OK); + response.Headers.GetValues("Access-Control-Allow-Origin").Should().Contain("*"); + } + + [Fact] + public async Task Options_Request_Returns204() + { + var client = await CreateTestClient(); + var request = new HttpRequestMessage(HttpMethod.Options, "/test"); + + var response = await client.SendAsync(request); + + response.StatusCode.Should().Be(HttpStatusCode.NoContent); + } + + [Fact] + public async Task Options_Request_HasContentLengthZero() + { + var client = await CreateTestClient(); + var request = new HttpRequestMessage(HttpMethod.Options, "/test"); + + var response = await client.SendAsync(request); + + response.Content.Headers.ContentLength.Should().Be(0); + } + + [Fact] + public async Task Options_Request_HasAllCorsHeaders() + { + var client = await CreateTestClient(); + var request = new HttpRequestMessage(HttpMethod.Options, "/test"); + + var response = await client.SendAsync(request); + + response.Headers.GetValues("Access-Control-Allow-Origin").Should().Contain("*"); + + var methods = response.Headers.GetValues("Access-Control-Allow-Methods").First(); + methods.Should().Contain("GET"); + methods.Should().Contain("POST"); + methods.Should().Contain("PUT"); + methods.Should().Contain("DELETE"); + methods.Should().Contain("OPTIONS"); + + var headers = response.Headers.GetValues("Access-Control-Allow-Headers").First(); + headers.Should().Contain("Authorization"); + headers.Should().Contain("Content-Type"); + headers.Should().Contain("Mcp-Session-Id"); + headers.Should().Contain("Mcp-Protocol-Version"); + + var exposed = response.Headers.GetValues("Access-Control-Expose-Headers").First(); + exposed.Should().Contain("WWW-Authenticate"); + + var maxAge = response.Headers.GetValues("Access-Control-Max-Age").First(); + maxAge.Should().Be("86400"); + } + + [Fact] + public async Task NonExistent_Path_StillHasCorsHeaders() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/nonexistent"); + + // Even 404 responses should have CORS headers + response.Headers.GetValues("Access-Control-Allow-Origin").Should().Contain("*"); + } + + [Fact] + public void SetCorsHeaders_Static_SetsAllHeaders() + { + // Test the static method directly + var response = new DefaultHttpContext().Response; + + CorsMiddleware.SetCorsHeaders(response); + + response.Headers["Access-Control-Allow-Origin"].ToString().Should().Be("*"); + response.Headers["Access-Control-Allow-Methods"].ToString().Should().Contain("GET"); + response.Headers["Access-Control-Allow-Headers"].ToString().Should().Contain("Authorization"); + response.Headers["Access-Control-Expose-Headers"].ToString().Should().Contain("WWW-Authenticate"); + response.Headers["Access-Control-Max-Age"].ToString().Should().Be("86400"); + } +} diff --git a/examples/auth/AuthMcpServer.Tests/Middleware/OAuthAuthMiddlewareTests.cs b/examples/auth/AuthMcpServer.Tests/Middleware/OAuthAuthMiddlewareTests.cs new file mode 100644 index 00000000..3ae0e82b --- /dev/null +++ b/examples/auth/AuthMcpServer.Tests/Middleware/OAuthAuthMiddlewareTests.cs @@ -0,0 +1,367 @@ +using System.Net; +using System.Text.Json; +using AuthMcpServer.Config; +using AuthMcpServer.Middleware; +using GopherOrch.Auth; +using FluentAssertions; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace AuthMcpServer.Tests.Middleware; + +/// +/// Tests for OAuth authentication middleware. +/// +public class OAuthAuthMiddlewareTests +{ + private readonly AuthServerConfig _config = new() + { + ServerUrl = "http://localhost:3001", + AuthServerUrl = "https://auth.example.com/realms/test", + AllowedScopes = "openid profile email mcp:read mcp:admin", + AuthDisabled = false + }; + + private async Task CreateTestClient(AuthServerConfig? config = null) + { + var testConfig = config ?? _config; + var host = await new HostBuilder() + .ConfigureWebHost(webBuilder => + { + webBuilder + .UseTestServer() + .ConfigureServices(services => + { + services.AddRouting(); + services.AddSingleton(testConfig); + }) + .Configure(app => + { + app.UseMiddleware(); + app.UseMiddleware(); + app.UseRouting(); + app.UseEndpoints(endpoints => + { + endpoints.MapGet("/health", () => + Microsoft.AspNetCore.Http.Results.Ok("healthy")); + + endpoints.MapGet("/.well-known/oauth-protected-resource", () => + Microsoft.AspNetCore.Http.Results.Ok("metadata")); + + endpoints.MapGet("/oauth/authorize", () => + Microsoft.AspNetCore.Http.Results.Ok("authorize")); + + endpoints.MapPost("/mcp", (HttpContext ctx) => + { + var authContext = ctx.Items["AuthContext"] as AuthContext; + return Microsoft.AspNetCore.Http.Results.Json(new + { + authenticated = authContext?.IsAuthenticated ?? false, + userId = authContext?.UserId ?? "" + }); + }); + + endpoints.MapPost("/rpc", (HttpContext ctx) => + { + var authContext = ctx.Items["AuthContext"] as AuthContext; + return Microsoft.AspNetCore.Http.Results.Json(new + { + authenticated = authContext?.IsAuthenticated ?? false, + scopes = authContext?.Scopes ?? "" + }); + }); + }); + }); + }) + .StartAsync(); + + return host.GetTestClient(); + } + + // Public path tests + + [Fact] + public async Task PublicPath_Health_SkipsAuth() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/health"); + + response.StatusCode.Should().Be(HttpStatusCode.OK); + } + + [Fact] + public async Task PublicPath_WellKnown_SkipsAuth() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-protected-resource"); + + response.StatusCode.Should().Be(HttpStatusCode.OK); + } + + [Fact] + public async Task PublicPath_OAuth_SkipsAuth() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/oauth/authorize"); + + response.StatusCode.Should().Be(HttpStatusCode.OK); + } + + // Auth disabled tests + + [Fact] + public async Task AuthDisabled_AllowsAccessWithoutToken() + { + var disabledConfig = new AuthServerConfig + { + ServerUrl = "http://localhost:3001", + AllowedScopes = "openid profile mcp:read", + AuthDisabled = true + }; + var client = await CreateTestClient(disabledConfig); + + var content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + var response = await client.PostAsync("/mcp", content); + + response.StatusCode.Should().Be(HttpStatusCode.OK); + } + + [Fact] + public async Task AuthDisabled_SetsAnonymousContext() + { + var disabledConfig = new AuthServerConfig + { + ServerUrl = "http://localhost:3001", + AllowedScopes = "mcp:read mcp:admin", + AuthDisabled = true + }; + var client = await CreateTestClient(disabledConfig); + + var content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + var response = await client.PostAsync("/mcp", content); + var json = JsonDocument.Parse(await response.Content.ReadAsStringAsync()); + + json.RootElement.GetProperty("authenticated").GetBoolean().Should().BeTrue(); + json.RootElement.GetProperty("userId").GetString().Should().Be("anonymous"); + } + + // Missing token tests + + [Fact] + public async Task MissingToken_Returns401() + { + var client = await CreateTestClient(); + + var content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + var response = await client.PostAsync("/mcp", content); + + response.StatusCode.Should().Be(HttpStatusCode.Unauthorized); + } + + [Fact] + public async Task MissingToken_ReturnsWwwAuthenticateHeader() + { + var client = await CreateTestClient(); + + var content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + var response = await client.PostAsync("/mcp", content); + + response.Headers.WwwAuthenticate.Should().NotBeEmpty(); + var wwwAuth = response.Headers.WwwAuthenticate.First().ToString(); + wwwAuth.Should().StartWith("Bearer"); + wwwAuth.Should().Contain("realm="); + wwwAuth.Should().Contain("resource_metadata="); + wwwAuth.Should().Contain("error=\"invalid_request\""); + } + + [Fact] + public async Task MissingToken_ReturnsJsonError() + { + var client = await CreateTestClient(); + + var content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + var response = await client.PostAsync("/mcp", content); + var json = JsonDocument.Parse(await response.Content.ReadAsStringAsync()); + + json.RootElement.GetProperty("error").GetString().Should().Be("invalid_request"); + json.RootElement.GetProperty("error_description").GetString().Should().Contain("bearer token"); + } + + [Fact] + public async Task MissingToken_HasCorsHeaders() + { + var client = await CreateTestClient(); + + var content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + var response = await client.PostAsync("/mcp", content); + + response.Headers.GetValues("Access-Control-Allow-Origin").Should().Contain("*"); + } + + // Token extraction tests + + [Fact] + public async Task TokenInHeader_AllowsAccess() + { + var client = await CreateTestClient(); + + var request = new HttpRequestMessage(HttpMethod.Post, "/mcp"); + request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", "test-token"); + request.Content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + + var response = await client.SendAsync(request); + + response.StatusCode.Should().Be(HttpStatusCode.OK); + } + + [Fact] + public async Task TokenInQuery_AllowsAccess() + { + var client = await CreateTestClient(); + + var content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + var response = await client.PostAsync("/mcp?access_token=test-token", content); + + response.StatusCode.Should().Be(HttpStatusCode.OK); + } + + [Fact] + public async Task TokenInHeader_SetsAuthContext() + { + var client = await CreateTestClient(); + + var request = new HttpRequestMessage(HttpMethod.Post, "/mcp"); + request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", "test-token"); + request.Content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + + var response = await client.SendAsync(request); + var json = JsonDocument.Parse(await response.Content.ReadAsStringAsync()); + + json.RootElement.GetProperty("authenticated").GetBoolean().Should().BeTrue(); + } + + // Token extraction unit tests + + [Fact] + public void ExtractToken_FromAuthorizationHeader() + { + var context = new DefaultHttpContext(); + context.Request.Headers.Authorization = "Bearer my-test-token"; + + var token = OAuthAuthMiddleware.ExtractToken(context.Request); + + token.Should().Be("my-test-token"); + } + + [Fact] + public void ExtractToken_FromQueryParam() + { + var context = new DefaultHttpContext(); + context.Request.QueryString = new QueryString("?access_token=query-token"); + + var token = OAuthAuthMiddleware.ExtractToken(context.Request); + + token.Should().Be("query-token"); + } + + [Fact] + public void ExtractToken_HeaderTakesPrecedence() + { + var context = new DefaultHttpContext(); + context.Request.Headers.Authorization = "Bearer header-token"; + context.Request.QueryString = new QueryString("?access_token=query-token"); + + var token = OAuthAuthMiddleware.ExtractToken(context.Request); + + token.Should().Be("header-token"); + } + + [Fact] + public void ExtractToken_CaseInsensitiveBearer() + { + var context = new DefaultHttpContext(); + context.Request.Headers.Authorization = "bearer my-token"; + + var token = OAuthAuthMiddleware.ExtractToken(context.Request); + + token.Should().Be("my-token"); + } + + [Fact] + public void ExtractToken_NoToken_ReturnsNull() + { + var context = new DefaultHttpContext(); + + var token = OAuthAuthMiddleware.ExtractToken(context.Request); + + token.Should().BeNull(); + } + + [Fact] + public void ExtractToken_InvalidAuthHeader_ReturnsNull() + { + var context = new DefaultHttpContext(); + context.Request.Headers.Authorization = "Basic credentials"; + + var token = OAuthAuthMiddleware.ExtractToken(context.Request); + + token.Should().BeNull(); + } + + // Header value escaping tests + + [Fact] + public void EscapeHeaderValue_EscapesBackslash() + { + var result = OAuthAuthMiddleware.EscapeHeaderValue("test\\value"); + + result.Should().Be("test\\\\value"); + } + + [Fact] + public void EscapeHeaderValue_EscapesQuotes() + { + var result = OAuthAuthMiddleware.EscapeHeaderValue("test\"value"); + + result.Should().Be("test\\\"value"); + } + + [Fact] + public void EscapeHeaderValue_EscapesBothCharacters() + { + var result = OAuthAuthMiddleware.EscapeHeaderValue("test\\\"value"); + + result.Should().Be("test\\\\\\\"value"); + } + + [Fact] + public void EscapeHeaderValue_HandlesEmptyString() + { + var result = OAuthAuthMiddleware.EscapeHeaderValue(""); + + result.Should().BeEmpty(); + } + + [Fact] + public void EscapeHeaderValue_HandlesNull() + { + var result = OAuthAuthMiddleware.EscapeHeaderValue(null!); + + result.Should().BeEmpty(); + } + + [Fact] + public void EscapeHeaderValue_PreservesNormalCharacters() + { + var result = OAuthAuthMiddleware.EscapeHeaderValue("normal value 123"); + + result.Should().Be("normal value 123"); + } +} diff --git a/examples/auth/AuthMcpServer.Tests/Models/JsonRpc/JsonRpcModelsTests.cs b/examples/auth/AuthMcpServer.Tests/Models/JsonRpc/JsonRpcModelsTests.cs new file mode 100644 index 00000000..9f6ec340 --- /dev/null +++ b/examples/auth/AuthMcpServer.Tests/Models/JsonRpc/JsonRpcModelsTests.cs @@ -0,0 +1,191 @@ +using System.Text.Json; +using AuthMcpServer.Models.JsonRpc; +using FluentAssertions; + +namespace AuthMcpServer.Tests.Models.JsonRpc; + +/// +/// Unit tests for JSON-RPC models. +/// +public class JsonRpcModelsTests +{ + private static readonly JsonSerializerOptions JsonOptions = new() + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase + }; + + // JsonRpcRequest tests + + [Fact] + public void JsonRpcRequest_Serializes_WithCorrectPropertyNames() + { + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "test/method" + }; + + var json = JsonSerializer.Serialize(request); + + json.Should().Contain("\"jsonrpc\":\"2.0\""); + json.Should().Contain("\"id\":1"); + json.Should().Contain("\"method\":\"test/method\""); + } + + [Fact] + public void JsonRpcRequest_Deserializes_FromValidJson() + { + var json = """{"jsonrpc":"2.0","id":42,"method":"tools/call","params":{"name":"test"}}"""; + + var request = JsonSerializer.Deserialize(json); + + request.Should().NotBeNull(); + request!.JsonRpc.Should().Be("2.0"); + request.Id.Should().NotBeNull(); + request.Method.Should().Be("tools/call"); + request.Params.Should().NotBeNull(); + } + + [Fact] + public void JsonRpcRequest_Deserializes_WithStringId() + { + var json = """{"jsonrpc":"2.0","id":"request-123","method":"ping"}"""; + + var request = JsonSerializer.Deserialize(json); + + request.Should().NotBeNull(); + request!.Id.Should().NotBeNull(); + } + + [Fact] + public void JsonRpcRequest_Deserializes_WithNullId() + { + var json = """{"jsonrpc":"2.0","id":null,"method":"ping"}"""; + + var request = JsonSerializer.Deserialize(json); + + request.Should().NotBeNull(); + request!.Id.Should().BeNull(); + } + + // JsonRpcResponse tests + + [Fact] + public void JsonRpcResponse_Serializes_SuccessResponse() + { + var response = new JsonRpcResponse + { + JsonRpc = "2.0", + Id = 1, + Result = new { status = "ok" } + }; + + var json = JsonSerializer.Serialize(response); + + json.Should().Contain("\"jsonrpc\":\"2.0\""); + json.Should().Contain("\"id\":1"); + json.Should().Contain("\"result\":"); + json.Should().NotContain("\"error\""); + } + + [Fact] + public void JsonRpcResponse_Serializes_ErrorResponse() + { + var response = new JsonRpcResponse + { + JsonRpc = "2.0", + Id = 1, + Error = new JsonRpcError + { + Code = JsonRpcErrorCodes.MethodNotFound, + Message = "Method not found" + } + }; + + var json = JsonSerializer.Serialize(response); + + json.Should().Contain("\"jsonrpc\":\"2.0\""); + json.Should().Contain("\"error\":"); + json.Should().Contain("\"code\":-32601"); + json.Should().NotContain("\"result\""); + } + + [Fact] + public void JsonRpcResponse_OmitsNullResult_WhenWritingNull() + { + var response = new JsonRpcResponse + { + JsonRpc = "2.0", + Id = 1, + Result = null, + Error = new JsonRpcError { Code = -1, Message = "error" } + }; + + var json = JsonSerializer.Serialize(response); + + json.Should().NotContain("\"result\""); + } + + [Fact] + public void JsonRpcResponse_OmitsNullError_WhenWritingNull() + { + var response = new JsonRpcResponse + { + JsonRpc = "2.0", + Id = 1, + Result = "success", + Error = null + }; + + var json = JsonSerializer.Serialize(response); + + json.Should().NotContain("\"error\""); + } + + // JsonRpcError tests + + [Fact] + public void JsonRpcError_Serializes_WithCorrectPropertyNames() + { + var error = new JsonRpcError + { + Code = JsonRpcErrorCodes.InvalidParams, + Message = "Invalid parameters", + Data = new { field = "name" } + }; + + var json = JsonSerializer.Serialize(error); + + json.Should().Contain("\"code\":-32602"); + json.Should().Contain("\"message\":\"Invalid parameters\""); + json.Should().Contain("\"data\":"); + } + + [Fact] + public void JsonRpcError_OmitsNullData_WhenWritingNull() + { + var error = new JsonRpcError + { + Code = JsonRpcErrorCodes.InternalError, + Message = "Internal error", + Data = null + }; + + var json = JsonSerializer.Serialize(error); + + json.Should().NotContain("\"data\""); + } + + // JsonRpcErrorCodes tests + + [Fact] + public void JsonRpcErrorCodes_HasCorrectValues() + { + JsonRpcErrorCodes.ParseError.Should().Be(-32700); + JsonRpcErrorCodes.InvalidRequest.Should().Be(-32600); + JsonRpcErrorCodes.MethodNotFound.Should().Be(-32601); + JsonRpcErrorCodes.InvalidParams.Should().Be(-32602); + JsonRpcErrorCodes.InternalError.Should().Be(-32603); + } +} diff --git a/examples/auth/AuthMcpServer.Tests/Models/Mcp/McpModelsTests.cs b/examples/auth/AuthMcpServer.Tests/Models/Mcp/McpModelsTests.cs new file mode 100644 index 00000000..b7bdc5d5 --- /dev/null +++ b/examples/auth/AuthMcpServer.Tests/Models/Mcp/McpModelsTests.cs @@ -0,0 +1,167 @@ +using System.Text.Json; +using AuthMcpServer.Models.Mcp; +using FluentAssertions; + +namespace AuthMcpServer.Tests.Models.Mcp; + +/// +/// Unit tests for MCP models. +/// +public class McpModelsTests +{ + // ToolSpec tests + + [Fact] + public void ToolSpec_Serializes_WithCorrectPropertyNames() + { + var spec = new ToolSpec + { + Name = "get-weather", + Description = "Get current weather", + InputSchema = new + { + type = "object", + properties = new + { + city = new { type = "string" } + } + } + }; + + var json = JsonSerializer.Serialize(spec); + + json.Should().Contain("\"name\":\"get-weather\""); + json.Should().Contain("\"description\":\"Get current weather\""); + json.Should().Contain("\"inputSchema\":"); + } + + [Fact] + public void ToolSpec_Serializes_WithNullInputSchema() + { + var spec = new ToolSpec + { + Name = "ping", + Description = "Ping the server", + InputSchema = null + }; + + var json = JsonSerializer.Serialize(spec); + + json.Should().Contain("\"name\":\"ping\""); + json.Should().Contain("\"inputSchema\":null"); + } + + // ToolContent tests + + [Fact] + public void ToolContent_Serializes_TextContent() + { + var content = new ToolContent + { + Type = "text", + Text = "Hello, world!" + }; + + var json = JsonSerializer.Serialize(content); + + json.Should().Contain("\"type\":\"text\""); + json.Should().Contain("\"text\":\"Hello, world!\""); + json.Should().NotContain("\"data\""); + json.Should().NotContain("\"mimeType\""); + } + + [Fact] + public void ToolContent_Serializes_BinaryContent() + { + var content = new ToolContent + { + Type = "image", + Data = "base64encodeddata", + MimeType = "image/png" + }; + + var json = JsonSerializer.Serialize(content); + + json.Should().Contain("\"type\":\"image\""); + json.Should().Contain("\"data\":\"base64encodeddata\""); + json.Should().Contain("\"mimeType\":\"image/png\""); + json.Should().NotContain("\"text\""); + } + + [Fact] + public void ToolContent_OmitsNullFields() + { + var content = new ToolContent + { + Type = "text", + Text = "test", + Data = null, + MimeType = null + }; + + var json = JsonSerializer.Serialize(content); + + json.Should().NotContain("\"data\""); + json.Should().NotContain("\"mimeType\""); + } + + // ToolResult tests + + [Fact] + public void ToolResult_Text_CreatesTextResult() + { + var result = ToolResult.Text("Weather is sunny"); + + result.Content.Should().HaveCount(1); + result.Content[0].Type.Should().Be("text"); + result.Content[0].Text.Should().Be("Weather is sunny"); + result.IsError.Should().BeFalse(); + } + + [Fact] + public void ToolResult_Error_CreatesErrorResult() + { + var result = ToolResult.Error("Access denied"); + + result.Content.Should().HaveCount(1); + result.Content[0].Type.Should().Be("text"); + result.Content[0].Text.Should().Be("Access denied"); + result.IsError.Should().BeTrue(); + } + + [Fact] + public void ToolResult_Serializes_SuccessResult() + { + var result = ToolResult.Text("Success"); + + var json = JsonSerializer.Serialize(result); + + json.Should().Contain("\"content\":"); + json.Should().NotContain("\"isError\""); // Omitted when false + } + + [Fact] + public void ToolResult_Serializes_ErrorResult() + { + var result = ToolResult.Error("Failed"); + + var json = JsonSerializer.Serialize(result); + + json.Should().Contain("\"content\":"); + json.Should().Contain("\"isError\":true"); + } + + [Fact] + public void ToolResult_Serializes_WithCorrectStructure() + { + var result = ToolResult.Text("test"); + + var json = JsonSerializer.Serialize(result); + var parsed = JsonDocument.Parse(json); + + parsed.RootElement.TryGetProperty("content", out var content).Should().BeTrue(); + content.GetArrayLength().Should().Be(1); + content[0].GetProperty("type").GetString().Should().Be("text"); + content[0].GetProperty("text").GetString().Should().Be("test"); + } +} diff --git a/examples/auth/AuthMcpServer.Tests/Models/OAuth/OAuthModelsTests.cs b/examples/auth/AuthMcpServer.Tests/Models/OAuth/OAuthModelsTests.cs new file mode 100644 index 00000000..740b4485 --- /dev/null +++ b/examples/auth/AuthMcpServer.Tests/Models/OAuth/OAuthModelsTests.cs @@ -0,0 +1,198 @@ +using System.Text.Json; +using GopherOrch.Auth.OAuth; +using FluentAssertions; + +namespace AuthMcpServer.Tests.Models.OAuth; + +/// +/// Unit tests for OAuth metadata models. +/// +public class OAuthModelsTests +{ + // ProtectedResourceMetadata tests (RFC 9728) + + [Fact] + public void ProtectedResourceMetadata_Serializes_WithSnakeCasePropertyNames() + { + var metadata = new ProtectedResourceMetadata + { + Resource = "https://example.com/mcp", + AuthorizationServers = new[] { "https://example.com" }, + ScopesSupported = new[] { "openid", "profile" }, + BearerMethodsSupported = new[] { "header", "query" }, + ResourceDocumentation = "https://example.com/docs" + }; + + var json = JsonSerializer.Serialize(metadata); + + json.Should().Contain("\"resource\":\"https://example.com/mcp\""); + json.Should().Contain("\"authorization_servers\":"); + json.Should().Contain("\"scopes_supported\":"); + json.Should().Contain("\"bearer_methods_supported\":"); + json.Should().Contain("\"resource_documentation\":"); + } + + [Fact] + public void ProtectedResourceMetadata_OmitsNullDocumentation() + { + var metadata = new ProtectedResourceMetadata + { + Resource = "https://example.com/mcp", + AuthorizationServers = new[] { "https://example.com" }, + ResourceDocumentation = null + }; + + var json = JsonSerializer.Serialize(metadata); + + json.Should().NotContain("\"resource_documentation\""); + } + + // AuthorizationServerMetadata tests (RFC 8414) + + [Fact] + public void AuthorizationServerMetadata_Serializes_WithAllRequiredFields() + { + var metadata = new AuthorizationServerMetadata + { + Issuer = "https://auth.example.com", + AuthorizationEndpoint = "https://auth.example.com/authorize", + TokenEndpoint = "https://auth.example.com/token", + JwksUri = "https://auth.example.com/certs", + RegistrationEndpoint = "https://auth.example.com/register", + ScopesSupported = new[] { "openid", "profile", "email" }, + ResponseTypesSupported = new[] { "code" }, + GrantTypesSupported = new[] { "authorization_code", "refresh_token" }, + TokenEndpointAuthMethodsSupported = new[] { "client_secret_basic", "client_secret_post", "none" }, + CodeChallengeMethodsSupported = new[] { "S256" } + }; + + var json = JsonSerializer.Serialize(metadata); + + json.Should().Contain("\"issuer\":"); + json.Should().Contain("\"authorization_endpoint\":"); + json.Should().Contain("\"token_endpoint\":"); + json.Should().Contain("\"jwks_uri\":"); + json.Should().Contain("\"registration_endpoint\":"); + json.Should().Contain("\"scopes_supported\":"); + json.Should().Contain("\"response_types_supported\":"); + json.Should().Contain("\"grant_types_supported\":"); + json.Should().Contain("\"token_endpoint_auth_methods_supported\":"); + json.Should().Contain("\"code_challenge_methods_supported\":"); + } + + [Fact] + public void AuthorizationServerMetadata_OmitsNullOptionalFields() + { + var metadata = new AuthorizationServerMetadata + { + Issuer = "https://auth.example.com", + AuthorizationEndpoint = "https://auth.example.com/authorize", + TokenEndpoint = "https://auth.example.com/token", + JwksUri = null, + RegistrationEndpoint = null + }; + + var json = JsonSerializer.Serialize(metadata); + + json.Should().NotContain("\"jwks_uri\""); + json.Should().NotContain("\"registration_endpoint\""); + } + + // OpenIdConfiguration tests + + [Fact] + public void OpenIdConfiguration_InheritsFromAuthorizationServerMetadata() + { + var config = new OpenIdConfiguration + { + Issuer = "https://auth.example.com", + AuthorizationEndpoint = "https://auth.example.com/authorize", + TokenEndpoint = "https://auth.example.com/token", + UserinfoEndpoint = "https://auth.example.com/userinfo", + SubjectTypesSupported = new[] { "public" }, + IdTokenSigningAlgValuesSupported = new[] { "RS256" } + }; + + var json = JsonSerializer.Serialize(config); + + // Inherited fields + json.Should().Contain("\"issuer\":"); + json.Should().Contain("\"authorization_endpoint\":"); + json.Should().Contain("\"token_endpoint\":"); + // OIDC-specific fields + json.Should().Contain("\"userinfo_endpoint\":"); + json.Should().Contain("\"subject_types_supported\":"); + json.Should().Contain("\"id_token_signing_alg_values_supported\":"); + } + + [Fact] + public void OpenIdConfiguration_OmitsNullUserinfoEndpoint() + { + var config = new OpenIdConfiguration + { + Issuer = "https://auth.example.com", + AuthorizationEndpoint = "https://auth.example.com/authorize", + TokenEndpoint = "https://auth.example.com/token", + UserinfoEndpoint = null + }; + + var json = JsonSerializer.Serialize(config); + + json.Should().NotContain("\"userinfo_endpoint\""); + } + + // ClientRegistrationResponse tests (RFC 7591) + + [Fact] + public void ClientRegistrationResponse_Serializes_WithAllFields() + { + var response = new ClientRegistrationResponse + { + ClientId = "client-123", + ClientSecret = "secret-456", + ClientIdIssuedAt = 1234567890, + ClientSecretExpiresAt = 0, + RedirectUris = new[] { "https://app.example.com/callback" }, + GrantTypes = new[] { "authorization_code", "refresh_token" }, + ResponseTypes = new[] { "code" }, + TokenEndpointAuthMethod = "client_secret_post" + }; + + var json = JsonSerializer.Serialize(response); + + json.Should().Contain("\"client_id\":\"client-123\""); + json.Should().Contain("\"client_secret\":\"secret-456\""); + json.Should().Contain("\"client_id_issued_at\":1234567890"); + json.Should().Contain("\"client_secret_expires_at\":0"); + json.Should().Contain("\"redirect_uris\":"); + json.Should().Contain("\"grant_types\":"); + json.Should().Contain("\"response_types\":"); + json.Should().Contain("\"token_endpoint_auth_method\":\"client_secret_post\""); + } + + [Fact] + public void ClientRegistrationResponse_OmitsNullClientSecret() + { + var response = new ClientRegistrationResponse + { + ClientId = "public-client", + ClientSecret = null, + ClientIdIssuedAt = 1234567890, + ClientSecretExpiresAt = 0, + TokenEndpointAuthMethod = "none" + }; + + var json = JsonSerializer.Serialize(response); + + json.Should().Contain("\"client_id\":\"public-client\""); + json.Should().NotContain("\"client_secret\""); + } + + [Fact] + public void ClientRegistrationResponse_DefaultsToNoneAuthMethod() + { + var response = new ClientRegistrationResponse(); + + response.TokenEndpointAuthMethod.Should().Be("none"); + } +} diff --git a/examples/auth/AuthMcpServer.Tests/Routes/HealthEndpointsTests.cs b/examples/auth/AuthMcpServer.Tests/Routes/HealthEndpointsTests.cs new file mode 100644 index 00000000..28f5098c --- /dev/null +++ b/examples/auth/AuthMcpServer.Tests/Routes/HealthEndpointsTests.cs @@ -0,0 +1,121 @@ +using System.Net; +using System.Text.Json; +using AuthMcpServer.Middleware; +using AuthMcpServer.Routes; +using FluentAssertions; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace AuthMcpServer.Tests.Routes; + +/// +/// Tests for health endpoint. +/// +public class HealthEndpointsTests +{ + private async Task CreateTestClient() + { + var host = await new HostBuilder() + .ConfigureWebHost(webBuilder => + { + webBuilder + .UseTestServer() + .ConfigureServices(services => + { + services.AddRouting(); + }) + .Configure(app => + { + app.UseMiddleware(); + app.UseRouting(); + app.UseEndpoints(endpoints => + { + endpoints.MapGet("/health", () => + { + return Microsoft.AspNetCore.Http.Results.Json(new + { + status = "healthy", + version = "1.0.0", + uptime = "0s" + }); + }); + }); + }); + }) + .StartAsync(); + + return host.GetTestClient(); + } + + [Fact] + public async Task Health_Returns200Ok() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/health"); + + response.StatusCode.Should().Be(HttpStatusCode.OK); + } + + [Fact] + public async Task Health_ReturnsJsonContentType() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/health"); + + response.Content.Headers.ContentType?.MediaType.Should().Be("application/json"); + } + + [Fact] + public async Task Health_ResponseContainsStatus() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/health"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + json.RootElement.TryGetProperty("status", out var status).Should().BeTrue(); + status.GetString().Should().Be("healthy"); + } + + [Fact] + public async Task Health_ResponseContainsVersion() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/health"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + json.RootElement.TryGetProperty("version", out var version).Should().BeTrue(); + version.GetString().Should().Be("1.0.0"); + } + + [Fact] + public async Task Health_ResponseContainsUptime() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/health"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + json.RootElement.TryGetProperty("uptime", out var uptime).Should().BeTrue(); + uptime.GetString().Should().EndWith("s"); + } + + [Fact] + public async Task Health_HasCorsHeaders() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/health"); + + response.Headers.GetValues("Access-Control-Allow-Origin").Should().Contain("*"); + } +} diff --git a/examples/auth/AuthMcpServer.Tests/Routes/McpEndpointsTests.cs b/examples/auth/AuthMcpServer.Tests/Routes/McpEndpointsTests.cs new file mode 100644 index 00000000..c770aba8 --- /dev/null +++ b/examples/auth/AuthMcpServer.Tests/Routes/McpEndpointsTests.cs @@ -0,0 +1,303 @@ +using System.Net; +using System.Text; +using System.Text.Json; +using AuthMcpServer.Config; +using AuthMcpServer.Middleware; +using AuthMcpServer.Models.JsonRpc; +using AuthMcpServer.Models.Mcp; +using AuthMcpServer.Services; +using FluentAssertions; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace AuthMcpServer.Tests.Routes; + +/// +/// Tests for MCP endpoints. +/// +public class McpEndpointsTests +{ + private readonly AuthServerConfig _config = new() + { + ServerUrl = "http://localhost:3001", + AllowedScopes = "openid profile mcp:read mcp:admin", + AuthDisabled = true // Disable auth for endpoint tests + }; + + private async Task CreateTestClient() + { + var mcpHandler = new McpHandler(); + mcpHandler.RegisterTool("test-tool", new ToolSpec + { + Name = "test-tool", + Description = "A test tool" + }, async (args, ctx) => ToolResult.Text("test result")); + + var host = await new HostBuilder() + .ConfigureWebHost(webBuilder => + { + webBuilder + .UseTestServer() + .ConfigureServices(services => + { + services.AddRouting(); + services.AddSingleton(_config); + services.AddSingleton(mcpHandler); + }) + .Configure(app => + { + app.UseMiddleware(); + app.UseMiddleware(); + app.UseRouting(); + app.UseEndpoints(endpoints => + { + var handler = app.ApplicationServices.GetRequiredService(); + + endpoints.MapPost("/mcp", async context => + { + var result = await HandleMcpRequest(context, handler); + await result.ExecuteAsync(context); + }); + + endpoints.MapPost("/rpc", async context => + { + var result = await HandleMcpRequest(context, handler); + await result.ExecuteAsync(context); + }); + }); + }); + }) + .StartAsync(); + + return host.GetTestClient(); + } + + private static async Task HandleMcpRequest( + Microsoft.AspNetCore.Http.HttpContext context, + McpHandler handler) + { + JsonRpcRequest? request; + + try + { + request = await context.Request.ReadFromJsonAsync(); + } + catch (JsonException) + { + return Microsoft.AspNetCore.Http.Results.Json(new JsonRpcResponse + { + JsonRpc = "2.0", + Id = null, + Error = new JsonRpcError + { + Code = JsonRpcErrorCodes.ParseError, + Message = "Parse error: invalid JSON" + } + }); + } + + if (request == null || string.IsNullOrEmpty(request.Method)) + { + return Microsoft.AspNetCore.Http.Results.Json(new JsonRpcResponse + { + JsonRpc = "2.0", + Id = null, + Error = new JsonRpcError + { + Code = JsonRpcErrorCodes.InvalidRequest, + Message = "Invalid request: missing method" + } + }); + } + + var response = await handler.HandleRequest(request, context); + return Microsoft.AspNetCore.Http.Results.Json(response); + } + + private StringContent CreateJsonContent(object obj) + { + return new StringContent( + JsonSerializer.Serialize(obj), + Encoding.UTF8, + "application/json"); + } + + // POST /mcp tests + + [Fact] + public async Task Mcp_ValidRequest_Returns200() + { + var client = await CreateTestClient(); + + var request = new { jsonrpc = "2.0", id = 1, method = "ping" }; + var response = await client.PostAsync("/mcp", CreateJsonContent(request)); + + response.StatusCode.Should().Be(HttpStatusCode.OK); + } + + [Fact] + public async Task Mcp_Initialize_ReturnsProtocolVersion() + { + var client = await CreateTestClient(); + + var request = new { jsonrpc = "2.0", id = 1, method = "initialize" }; + var response = await client.PostAsync("/mcp", CreateJsonContent(request)); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + json.RootElement.GetProperty("result") + .GetProperty("protocolVersion").GetString() + .Should().Be("2024-11-05"); + } + + [Fact] + public async Task Mcp_ToolsList_ReturnsRegisteredTools() + { + var client = await CreateTestClient(); + + var request = new { jsonrpc = "2.0", id = 1, method = "tools/list" }; + var response = await client.PostAsync("/mcp", CreateJsonContent(request)); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + var tools = json.RootElement.GetProperty("result").GetProperty("tools"); + tools.GetArrayLength().Should().Be(1); + tools[0].GetProperty("name").GetString().Should().Be("test-tool"); + } + + [Fact] + public async Task Mcp_ToolsCall_InvokesTool() + { + var client = await CreateTestClient(); + + var request = new + { + jsonrpc = "2.0", + id = 1, + method = "tools/call", + @params = new { name = "test-tool" } + }; + var response = await client.PostAsync("/mcp", CreateJsonContent(request)); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + json.RootElement.GetProperty("result").GetProperty("content")[0] + .GetProperty("text").GetString() + .Should().Be("test result"); + } + + [Fact] + public async Task Mcp_InvalidJson_ReturnsParseError() + { + var client = await CreateTestClient(); + + var content = new StringContent("{ invalid json }", Encoding.UTF8, "application/json"); + var response = await client.PostAsync("/mcp", content); + var responseContent = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(responseContent); + + json.RootElement.GetProperty("error").GetProperty("code").GetInt32() + .Should().Be(JsonRpcErrorCodes.ParseError); + } + + [Fact] + public async Task Mcp_MissingMethod_ReturnsInvalidRequest() + { + var client = await CreateTestClient(); + + var request = new { jsonrpc = "2.0", id = 1 }; // Missing method + var response = await client.PostAsync("/mcp", CreateJsonContent(request)); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + json.RootElement.GetProperty("error").GetProperty("code").GetInt32() + .Should().Be(JsonRpcErrorCodes.InvalidRequest); + } + + [Fact] + public async Task Mcp_HasCorsHeaders() + { + var client = await CreateTestClient(); + + var request = new { jsonrpc = "2.0", id = 1, method = "ping" }; + var response = await client.PostAsync("/mcp", CreateJsonContent(request)); + + response.Headers.GetValues("Access-Control-Allow-Origin").Should().Contain("*"); + } + + // POST /rpc tests (alias) + + [Fact] + public async Task Rpc_WorksAsAlias() + { + var client = await CreateTestClient(); + + var request = new { jsonrpc = "2.0", id = 1, method = "ping" }; + var response = await client.PostAsync("/rpc", CreateJsonContent(request)); + + response.StatusCode.Should().Be(HttpStatusCode.OK); + } + + [Fact] + public async Task Rpc_Initialize_ReturnsProtocolVersion() + { + var client = await CreateTestClient(); + + var request = new { jsonrpc = "2.0", id = 1, method = "initialize" }; + var response = await client.PostAsync("/rpc", CreateJsonContent(request)); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + json.RootElement.GetProperty("result") + .GetProperty("protocolVersion").GetString() + .Should().Be("2024-11-05"); + } + + // Auth tests + + [Fact] + public async Task Mcp_WithoutAuth_Returns401() + { + var authConfig = new AuthServerConfig + { + ServerUrl = "http://localhost:3001", + AllowedScopes = "openid", + AuthDisabled = false // Enable auth + }; + + var host = await new HostBuilder() + .ConfigureWebHost(webBuilder => + { + webBuilder + .UseTestServer() + .ConfigureServices(services => + { + services.AddRouting(); + services.AddSingleton(authConfig); + services.AddSingleton(new McpHandler()); + }) + .Configure(app => + { + app.UseMiddleware(); + app.UseMiddleware(); + app.UseRouting(); + app.UseEndpoints(endpoints => + { + endpoints.MapPost("/mcp", () => + Microsoft.AspNetCore.Http.Results.Ok()); + }); + }); + }) + .StartAsync(); + + var client = host.GetTestClient(); + var request = new { jsonrpc = "2.0", id = 1, method = "ping" }; + var response = await client.PostAsync("/mcp", CreateJsonContent(request)); + + response.StatusCode.Should().Be(HttpStatusCode.Unauthorized); + } +} diff --git a/examples/auth/AuthMcpServer.Tests/Routes/OAuthEndpointsTests.cs b/examples/auth/AuthMcpServer.Tests/Routes/OAuthEndpointsTests.cs new file mode 100644 index 00000000..424c7e3e --- /dev/null +++ b/examples/auth/AuthMcpServer.Tests/Routes/OAuthEndpointsTests.cs @@ -0,0 +1,772 @@ +using System.Net; +using System.Text.Json; +using AuthMcpServer.Config; +using AuthMcpServer.Middleware; +using AuthMcpServer.Routes; +using FluentAssertions; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace AuthMcpServer.Tests.Routes; + +/// +/// Tests for OAuth endpoints. +/// +public class OAuthEndpointsTests +{ + private readonly AuthServerConfig _config = new() + { + ServerUrl = "http://localhost:3001", + AuthServerUrl = "https://auth.example.com/realms/test", + AllowedScopes = "openid profile email mcp:read mcp:admin", + ClientId = "test-client-id", + ClientSecret = "test-client-secret" + }; + + private async Task CreateTestClient() + { + var host = await new HostBuilder() + .ConfigureWebHost(webBuilder => + { + webBuilder + .UseTestServer() + .ConfigureServices(services => + { + services.AddRouting(); + services.AddSingleton(_config); + }) + .Configure(app => + { + app.UseMiddleware(); + app.UseRouting(); + app.UseEndpoints(endpoints => + { + endpoints.MapGet("/.well-known/oauth-protected-resource", + (AuthServerConfig config) => + { + return Microsoft.AspNetCore.Http.Results.Json(new + { + resource = $"{config.ServerUrl}/mcp", + authorization_servers = new[] { config.ServerUrl }, + scopes_supported = config.AllowedScopes + .Split(' ', StringSplitOptions.RemoveEmptyEntries), + bearer_methods_supported = new[] { "header", "query" }, + resource_documentation = $"{config.ServerUrl}/docs" + }); + }); + + endpoints.MapGet("/.well-known/oauth-protected-resource/mcp", + (AuthServerConfig config) => + { + return Microsoft.AspNetCore.Http.Results.Json(new + { + resource = $"{config.ServerUrl}/mcp", + authorization_servers = new[] { config.ServerUrl }, + scopes_supported = config.AllowedScopes + .Split(' ', StringSplitOptions.RemoveEmptyEntries), + bearer_methods_supported = new[] { "header", "query" }, + resource_documentation = $"{config.ServerUrl}/docs" + }); + }); + + endpoints.MapGet("/.well-known/oauth-authorization-server", + (AuthServerConfig config) => + { + var authEndpoint = !string.IsNullOrEmpty(config.OAuthAuthorizeUrl) + ? config.OAuthAuthorizeUrl + : $"{config.AuthServerUrl}/protocol/openid-connect/auth"; + var tokenEndpoint = !string.IsNullOrEmpty(config.OAuthTokenUrl) + ? config.OAuthTokenUrl + : $"{config.AuthServerUrl}/protocol/openid-connect/token"; + + return Microsoft.AspNetCore.Http.Results.Json(new + { + issuer = !string.IsNullOrEmpty(config.Issuer) + ? config.Issuer : config.ServerUrl, + authorization_endpoint = authEndpoint, + token_endpoint = tokenEndpoint, + jwks_uri = config.JwksUri, + registration_endpoint = $"{config.ServerUrl}/oauth/register", + scopes_supported = config.AllowedScopes + .Split(' ', StringSplitOptions.RemoveEmptyEntries), + response_types_supported = new[] { "code" }, + grant_types_supported = new[] { "authorization_code", "refresh_token" }, + token_endpoint_auth_methods_supported = new[] { "client_secret_basic", "client_secret_post", "none" }, + code_challenge_methods_supported = new[] { "S256" } + }); + }); + + endpoints.MapGet("/.well-known/openid-configuration", + (AuthServerConfig config) => + { + var authEndpoint = !string.IsNullOrEmpty(config.OAuthAuthorizeUrl) + ? config.OAuthAuthorizeUrl + : $"{config.AuthServerUrl}/protocol/openid-connect/auth"; + var tokenEndpoint = !string.IsNullOrEmpty(config.OAuthTokenUrl) + ? config.OAuthTokenUrl + : $"{config.AuthServerUrl}/protocol/openid-connect/token"; + + var baseScopes = new[] { "openid", "profile", "email" }; + var configScopes = config.AllowedScopes + .Split(' ', StringSplitOptions.RemoveEmptyEntries); + var allScopes = baseScopes.Union(configScopes).Distinct().ToArray(); + + return Microsoft.AspNetCore.Http.Results.Json(new + { + issuer = !string.IsNullOrEmpty(config.Issuer) + ? config.Issuer : config.ServerUrl, + authorization_endpoint = authEndpoint, + token_endpoint = tokenEndpoint, + jwks_uri = config.JwksUri, + registration_endpoint = $"{config.ServerUrl}/oauth/register", + scopes_supported = allScopes, + response_types_supported = new[] { "code" }, + grant_types_supported = new[] { "authorization_code", "refresh_token" }, + token_endpoint_auth_methods_supported = new[] { "client_secret_basic", "client_secret_post", "none" }, + code_challenge_methods_supported = new[] { "S256" }, + userinfo_endpoint = !string.IsNullOrEmpty(config.AuthServerUrl) + ? $"{config.AuthServerUrl}/protocol/openid-connect/userinfo" + : (string?)null, + subject_types_supported = new[] { "public" }, + id_token_signing_alg_values_supported = new[] { "RS256" } + }); + }); + + endpoints.MapGet("/oauth/authorize", + (HttpContext ctx, AuthServerConfig config) => + { + var authEndpoint = !string.IsNullOrEmpty(config.OAuthAuthorizeUrl) + ? config.OAuthAuthorizeUrl + : $"{config.AuthServerUrl}/protocol/openid-connect/auth"; + + try + { + var uriBuilder = new UriBuilder(authEndpoint); + var query = System.Web.HttpUtility.ParseQueryString(uriBuilder.Query); + foreach (var param in ctx.Request.Query) + { + query[param.Key] = param.Value.ToString(); + } + uriBuilder.Query = query.ToString(); + return Microsoft.AspNetCore.Http.Results.Redirect(uriBuilder.ToString(), permanent: false); + } + catch + { + return Microsoft.AspNetCore.Http.Results.Json(new + { + error = "server_error", + error_description = "Failed to construct authorization URL" + }, statusCode: 500); + } + }); + + endpoints.MapPost("/oauth/register", + async (HttpContext ctx, AuthServerConfig config) => + { + Dictionary? body = null; + try + { + body = await ctx.Request.ReadFromJsonAsync>(); + } + catch { } + + var redirectUris = Array.Empty(); + if (body?.TryGetValue("redirect_uris", out var urisElement) == true && + urisElement.ValueKind == JsonValueKind.Array) + { + redirectUris = urisElement.EnumerateArray() + .Where(e => e.ValueKind == JsonValueKind.String) + .Select(e => e.GetString()!) + .ToArray(); + } + + ctx.Response.StatusCode = 201; + return Microsoft.AspNetCore.Http.Results.Json(new + { + client_id = config.ClientId, + client_secret = !string.IsNullOrEmpty(config.ClientSecret) + ? config.ClientSecret : (string?)null, + client_id_issued_at = DateTimeOffset.UtcNow.ToUnixTimeSeconds(), + client_secret_expires_at = 0, + redirect_uris = redirectUris, + grant_types = new[] { "authorization_code", "refresh_token" }, + response_types = new[] { "code" }, + token_endpoint_auth_method = !string.IsNullOrEmpty(config.ClientSecret) + ? "client_secret_post" : "none" + }); + }); + }); + }); + }) + .StartAsync(); + + return host.GetTestClient(); + } + + // Protected Resource Metadata tests (RFC 9728) + + [Fact] + public async Task ProtectedResourceMetadata_Returns200Ok() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-protected-resource"); + + response.StatusCode.Should().Be(HttpStatusCode.OK); + } + + [Fact] + public async Task ProtectedResourceMetadata_ContainsResource() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-protected-resource"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + json.RootElement.GetProperty("resource").GetString() + .Should().Be("http://localhost:3001/mcp"); + } + + [Fact] + public async Task ProtectedResourceMetadata_ContainsAuthorizationServers() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-protected-resource"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + var servers = json.RootElement.GetProperty("authorization_servers"); + servers.GetArrayLength().Should().Be(1); + servers[0].GetString().Should().Be("http://localhost:3001"); + } + + [Fact] + public async Task ProtectedResourceMetadata_ContainsScopesSupported() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-protected-resource"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + var scopes = json.RootElement.GetProperty("scopes_supported"); + scopes.GetArrayLength().Should().Be(5); // openid profile email mcp:read mcp:admin + + var scopeList = scopes.EnumerateArray().Select(s => s.GetString()).ToList(); + scopeList.Should().Contain("openid"); + scopeList.Should().Contain("mcp:read"); + scopeList.Should().Contain("mcp:admin"); + } + + [Fact] + public async Task ProtectedResourceMetadata_ContainsBearerMethods() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-protected-resource"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + var methods = json.RootElement.GetProperty("bearer_methods_supported"); + var methodList = methods.EnumerateArray().Select(m => m.GetString()).ToList(); + methodList.Should().Contain("header"); + methodList.Should().Contain("query"); + } + + [Fact] + public async Task ProtectedResourceMetadata_ContainsResourceDocumentation() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-protected-resource"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + json.RootElement.GetProperty("resource_documentation").GetString() + .Should().Be("http://localhost:3001/docs"); + } + + [Fact] + public async Task ProtectedResourceMetadata_HasCorsHeaders() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-protected-resource"); + + response.Headers.GetValues("Access-Control-Allow-Origin").Should().Contain("*"); + } + + [Fact] + public async Task ProtectedResourceMetadata_McpPath_ReturnsSameResponse() + { + var client = await CreateTestClient(); + + var response1 = await client.GetAsync("/.well-known/oauth-protected-resource"); + var response2 = await client.GetAsync("/.well-known/oauth-protected-resource/mcp"); + + var content1 = await response1.Content.ReadAsStringAsync(); + var content2 = await response2.Content.ReadAsStringAsync(); + + // Both paths should return the same resource metadata + content1.Should().Be(content2); + } + + // Authorization Server Metadata tests (RFC 8414) + + [Fact] + public async Task AuthorizationServerMetadata_Returns200Ok() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-authorization-server"); + + response.StatusCode.Should().Be(HttpStatusCode.OK); + } + + [Fact] + public async Task AuthorizationServerMetadata_ContainsIssuer() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-authorization-server"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + // Issuer falls back to serverUrl when not explicitly set + json.RootElement.GetProperty("issuer").GetString() + .Should().Be("http://localhost:3001"); + } + + [Fact] + public async Task AuthorizationServerMetadata_ContainsAuthorizationEndpoint() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-authorization-server"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + // Derived from auth_server_url + json.RootElement.GetProperty("authorization_endpoint").GetString() + .Should().Be("https://auth.example.com/realms/test/protocol/openid-connect/auth"); + } + + [Fact] + public async Task AuthorizationServerMetadata_ContainsTokenEndpoint() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-authorization-server"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + json.RootElement.GetProperty("token_endpoint").GetString() + .Should().Be("https://auth.example.com/realms/test/protocol/openid-connect/token"); + } + + [Fact] + public async Task AuthorizationServerMetadata_ContainsRegistrationEndpoint() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-authorization-server"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + json.RootElement.GetProperty("registration_endpoint").GetString() + .Should().Be("http://localhost:3001/oauth/register"); + } + + [Fact] + public async Task AuthorizationServerMetadata_ContainsResponseTypes() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-authorization-server"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + var types = json.RootElement.GetProperty("response_types_supported"); + types.EnumerateArray().Select(t => t.GetString()).Should().Contain("code"); + } + + [Fact] + public async Task AuthorizationServerMetadata_ContainsGrantTypes() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-authorization-server"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + var grants = json.RootElement.GetProperty("grant_types_supported"); + var grantList = grants.EnumerateArray().Select(g => g.GetString()).ToList(); + grantList.Should().Contain("authorization_code"); + grantList.Should().Contain("refresh_token"); + } + + [Fact] + public async Task AuthorizationServerMetadata_ContainsCodeChallengeMethods() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-authorization-server"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + var methods = json.RootElement.GetProperty("code_challenge_methods_supported"); + methods.EnumerateArray().Select(m => m.GetString()).Should().Contain("S256"); + } + + [Fact] + public async Task AuthorizationServerMetadata_HasCorsHeaders() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/oauth-authorization-server"); + + response.Headers.GetValues("Access-Control-Allow-Origin").Should().Contain("*"); + } + + // OpenID Configuration tests + + [Fact] + public async Task OpenIdConfiguration_Returns200Ok() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/openid-configuration"); + + response.StatusCode.Should().Be(HttpStatusCode.OK); + } + + [Fact] + public async Task OpenIdConfiguration_ContainsAuthServerMetadataFields() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/openid-configuration"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + // All auth server metadata fields should be present + json.RootElement.TryGetProperty("issuer", out _).Should().BeTrue(); + json.RootElement.TryGetProperty("authorization_endpoint", out _).Should().BeTrue(); + json.RootElement.TryGetProperty("token_endpoint", out _).Should().BeTrue(); + json.RootElement.TryGetProperty("scopes_supported", out _).Should().BeTrue(); + } + + [Fact] + public async Task OpenIdConfiguration_ContainsUserinfoEndpoint() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/openid-configuration"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + json.RootElement.GetProperty("userinfo_endpoint").GetString() + .Should().Be("https://auth.example.com/realms/test/protocol/openid-connect/userinfo"); + } + + [Fact] + public async Task OpenIdConfiguration_ContainsSubjectTypesSupported() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/openid-configuration"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + var types = json.RootElement.GetProperty("subject_types_supported"); + types.EnumerateArray().Select(t => t.GetString()).Should().Contain("public"); + } + + [Fact] + public async Task OpenIdConfiguration_ContainsIdTokenSigningAlg() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/openid-configuration"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + var algs = json.RootElement.GetProperty("id_token_signing_alg_values_supported"); + algs.EnumerateArray().Select(a => a.GetString()).Should().Contain("RS256"); + } + + [Fact] + public async Task OpenIdConfiguration_MergesBaseOidcScopes() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/openid-configuration"); + var content = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(content); + + var scopes = json.RootElement.GetProperty("scopes_supported"); + var scopeList = scopes.EnumerateArray().Select(s => s.GetString()).ToList(); + + // Base OIDC scopes should be present + scopeList.Should().Contain("openid"); + scopeList.Should().Contain("profile"); + scopeList.Should().Contain("email"); + // Config scopes should also be present + scopeList.Should().Contain("mcp:read"); + scopeList.Should().Contain("mcp:admin"); + } + + [Fact] + public async Task OpenIdConfiguration_HasCorsHeaders() + { + var client = await CreateTestClient(); + + var response = await client.GetAsync("/.well-known/openid-configuration"); + + response.Headers.GetValues("Access-Control-Allow-Origin").Should().Contain("*"); + } + + // OAuth Authorize Redirect tests + + private async Task<(HttpClient Client, IHost Host)> CreateTestClientNoRedirect() + { + var host = await new HostBuilder() + .ConfigureWebHost(webBuilder => + { + webBuilder + .UseTestServer(options => options.AllowSynchronousIO = true) + .ConfigureServices(services => + { + services.AddRouting(); + services.AddSingleton(_config); + }) + .Configure(app => + { + app.UseMiddleware(); + app.UseRouting(); + app.UseEndpoints(endpoints => + { + endpoints.MapGet("/oauth/authorize", + (HttpContext ctx, AuthServerConfig config) => + { + var authEndpoint = !string.IsNullOrEmpty(config.OAuthAuthorizeUrl) + ? config.OAuthAuthorizeUrl + : $"{config.AuthServerUrl}/protocol/openid-connect/auth"; + + var uriBuilder = new UriBuilder(authEndpoint); + var query = System.Web.HttpUtility.ParseQueryString(uriBuilder.Query); + foreach (var param in ctx.Request.Query) + { + query[param.Key] = param.Value.ToString(); + } + uriBuilder.Query = query.ToString(); + return Microsoft.AspNetCore.Http.Results.Redirect(uriBuilder.ToString(), permanent: false); + }); + }); + }); + }) + .StartAsync(); + + var server = host.GetTestServer(); + // Disable auto-redirect to capture the redirect response + var handler = server.CreateHandler(); + var noRedirectClient = new HttpClient(new NoRedirectHandler(handler)) + { + BaseAddress = server.BaseAddress + }; + return (noRedirectClient, host); + } + + private class NoRedirectHandler : DelegatingHandler + { + public NoRedirectHandler(HttpMessageHandler innerHandler) : base(innerHandler) { } + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + var response = await base.SendAsync(request, cancellationToken); + return response; + } + } + + [Fact] + public async Task OAuthAuthorize_Returns302Redirect() + { + var (client, host) = await CreateTestClientNoRedirect(); + try + { + var response = await client.GetAsync("/oauth/authorize"); + + // TestServer doesn't automatically follow redirects, so we should get the redirect status + response.StatusCode.Should().Be(HttpStatusCode.Redirect); + } + finally + { + await host.StopAsync(); + } + } + + [Fact] + public async Task OAuthAuthorize_RedirectsToCorrectAuthEndpoint() + { + var (client, host) = await CreateTestClientNoRedirect(); + try + { + var response = await client.GetAsync("/oauth/authorize"); + var location = response.Headers.Location?.ToString(); + + location.Should().StartWith("https://auth.example.com/realms/test/protocol/openid-connect/auth"); + } + finally + { + await host.StopAsync(); + } + } + + [Fact] + public async Task OAuthAuthorize_ForwardsQueryParams() + { + var (client, host) = await CreateTestClientNoRedirect(); + try + { + var response = await client.GetAsync("/oauth/authorize?client_id=test&redirect_uri=http://example.com/callback&scope=openid"); + var location = response.Headers.Location?.ToString(); + + location.Should().Contain("client_id=test"); + location.Should().Contain("redirect_uri="); + location.Should().Contain("scope=openid"); + } + finally + { + await host.StopAsync(); + } + } + + [Fact] + public async Task OAuthAuthorize_HasCorsHeaders() + { + var (client, host) = await CreateTestClientNoRedirect(); + try + { + var response = await client.GetAsync("/oauth/authorize"); + + response.Headers.GetValues("Access-Control-Allow-Origin").Should().Contain("*"); + } + finally + { + await host.StopAsync(); + } + } + + // Dynamic Client Registration tests (RFC 7591) + + [Fact] + public async Task ClientRegistration_Returns201Created() + { + var client = await CreateTestClient(); + + var content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + var response = await client.PostAsync("/oauth/register", content); + + response.StatusCode.Should().Be(HttpStatusCode.Created); + } + + [Fact] + public async Task ClientRegistration_ReturnsConfiguredClientId() + { + var client = await CreateTestClient(); + + var content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + var response = await client.PostAsync("/oauth/register", content); + var responseContent = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(responseContent); + + json.RootElement.GetProperty("client_id").GetString() + .Should().Be("test-client-id"); + } + + [Fact] + public async Task ClientRegistration_ReturnsClientSecretWhenConfigured() + { + var client = await CreateTestClient(); + + var content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + var response = await client.PostAsync("/oauth/register", content); + var responseContent = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(responseContent); + + json.RootElement.GetProperty("client_secret").GetString() + .Should().Be("test-client-secret"); + } + + [Fact] + public async Task ClientRegistration_IncludesRedirectUrisFromRequest() + { + var client = await CreateTestClient(); + + var requestBody = new + { + redirect_uris = new[] { "http://localhost/callback", "http://example.com/callback" } + }; + var content = new StringContent( + JsonSerializer.Serialize(requestBody), + System.Text.Encoding.UTF8, + "application/json"); + + var response = await client.PostAsync("/oauth/register", content); + var responseContent = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(responseContent); + + var redirectUris = json.RootElement.GetProperty("redirect_uris"); + redirectUris.GetArrayLength().Should().Be(2); + redirectUris[0].GetString().Should().Be("http://localhost/callback"); + redirectUris[1].GetString().Should().Be("http://example.com/callback"); + } + + [Fact] + public async Task ClientRegistration_ReturnsCorrectGrantTypes() + { + var client = await CreateTestClient(); + + var content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + var response = await client.PostAsync("/oauth/register", content); + var responseContent = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(responseContent); + + var grantTypes = json.RootElement.GetProperty("grant_types"); + var grantList = grantTypes.EnumerateArray().Select(g => g.GetString()).ToList(); + grantList.Should().Contain("authorization_code"); + grantList.Should().Contain("refresh_token"); + } + + [Fact] + public async Task ClientRegistration_SetsTokenEndpointAuthMethod() + { + var client = await CreateTestClient(); + + var content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + var response = await client.PostAsync("/oauth/register", content); + var responseContent = await response.Content.ReadAsStringAsync(); + var json = JsonDocument.Parse(responseContent); + + // With client_secret configured, should be client_secret_post + json.RootElement.GetProperty("token_endpoint_auth_method").GetString() + .Should().Be("client_secret_post"); + } + + [Fact] + public async Task ClientRegistration_HasCorsHeaders() + { + var client = await CreateTestClient(); + + var content = new StringContent("{}", System.Text.Encoding.UTF8, "application/json"); + var response = await client.PostAsync("/oauth/register", content); + + response.Headers.GetValues("Access-Control-Allow-Origin").Should().Contain("*"); + } +} diff --git a/examples/auth/AuthMcpServer.Tests/Services/AuthContextTests.cs b/examples/auth/AuthMcpServer.Tests/Services/AuthContextTests.cs new file mode 100644 index 00000000..2d096ffc --- /dev/null +++ b/examples/auth/AuthMcpServer.Tests/Services/AuthContextTests.cs @@ -0,0 +1,119 @@ +using GopherOrch.Auth; +using FluentAssertions; + +namespace AuthMcpServer.Tests.Services; + +/// +/// Tests for AuthContext service. +/// +public class AuthContextTests +{ + [Fact] + public void HasScope_WithMatchingScope_ReturnsTrue() + { + var context = new AuthContext("user1", "openid profile mcp:read", "", 0, true); + + context.HasScope("mcp:read").Should().BeTrue(); + } + + [Fact] + public void HasScope_WithMissingScope_ReturnsFalse() + { + var context = new AuthContext("user1", "openid profile mcp:read", "", 0, true); + + context.HasScope("mcp:admin").Should().BeFalse(); + } + + [Fact] + public void HasScope_WithEmptyScopes_ReturnsFalse() + { + var context = new AuthContext("user1", "", "", 0, true); + + context.HasScope("mcp:read").Should().BeFalse(); + } + + [Fact] + public void HasScope_WithNullScopes_ReturnsFalse() + { + var context = new AuthContext("user1", null!, "", 0, true); + + context.HasScope("mcp:read").Should().BeFalse(); + } + + [Fact] + public void HasScope_WithEmptyRequiredScope_ReturnsTrue() + { + var context = new AuthContext("user1", "openid", "", 0, true); + + context.HasScope("").Should().BeTrue(); + } + + [Fact] + public void HasScope_WithNullRequiredScope_ReturnsTrue() + { + var context = new AuthContext("user1", "openid", "", 0, true); + + context.HasScope(null!).Should().BeTrue(); + } + + [Fact] + public void HasScope_IsCaseInsensitive() + { + var context = new AuthContext("user1", "mcp:READ mcp:ADMIN", "", 0, true); + + context.HasScope("mcp:read").Should().BeTrue(); + context.HasScope("MCP:READ").Should().BeTrue(); + context.HasScope("MCP:Admin").Should().BeTrue(); + } + + [Fact] + public void Empty_ReturnsUnauthenticatedContext() + { + var context = AuthContext.Empty(); + + context.UserId.Should().BeEmpty(); + context.Scopes.Should().BeEmpty(); + context.Audience.Should().BeEmpty(); + context.TokenExpiry.Should().Be(0); + context.IsAuthenticated.Should().BeFalse(); + } + + [Fact] + public void Anonymous_ReturnsAuthenticatedContext() + { + var scopes = "openid profile mcp:read mcp:admin"; + var context = AuthContext.Anonymous(scopes); + + context.UserId.Should().Be("anonymous"); + context.Scopes.Should().Be(scopes); + context.IsAuthenticated.Should().BeTrue(); + context.TokenExpiry.Should().BeGreaterThan(0); + } + + [Fact] + public void Anonymous_HasScope_WorksWithProvidedScopes() + { + var context = AuthContext.Anonymous("mcp:read mcp:admin"); + + context.HasScope("mcp:read").Should().BeTrue(); + context.HasScope("mcp:admin").Should().BeTrue(); + context.HasScope("mcp:write").Should().BeFalse(); + } + + [Fact] + public void Constructor_SetsAllProperties() + { + var context = new AuthContext( + userId: "user123", + scopes: "openid profile", + audience: "https://api.example.com", + tokenExpiry: 1234567890, + isAuthenticated: true); + + context.UserId.Should().Be("user123"); + context.Scopes.Should().Be("openid profile"); + context.Audience.Should().Be("https://api.example.com"); + context.TokenExpiry.Should().Be(1234567890); + context.IsAuthenticated.Should().BeTrue(); + } +} diff --git a/examples/auth/AuthMcpServer.Tests/Services/McpHandlerTests.cs b/examples/auth/AuthMcpServer.Tests/Services/McpHandlerTests.cs new file mode 100644 index 00000000..8aaf123c --- /dev/null +++ b/examples/auth/AuthMcpServer.Tests/Services/McpHandlerTests.cs @@ -0,0 +1,501 @@ +using System.Text.Json; +using AuthMcpServer.Models.JsonRpc; +using AuthMcpServer.Models.Mcp; +using AuthMcpServer.Services; +using FluentAssertions; +using Microsoft.AspNetCore.Http; + +namespace AuthMcpServer.Tests.Services; + +/// +/// Tests for MCP handler service. +/// +public class McpHandlerTests +{ + private readonly McpHandler _handler = new(); + private readonly DefaultHttpContext _context = new(); + + // Tool registration tests + + [Fact] + public void RegisterTool_StoresTool() + { + var spec = new ToolSpec + { + Name = "test-tool", + Description = "A test tool" + }; + + _handler.RegisterTool("test-tool", spec, async (args, ctx) => + ToolResult.Text("test")); + + _handler.GetTools().Should().ContainSingle() + .Which.Name.Should().Be("test-tool"); + } + + [Fact] + public void GetTools_ReturnsAllRegisteredTools() + { + _handler.RegisterTool("tool1", new ToolSpec { Name = "tool1" }, + async (args, ctx) => ToolResult.Text("1")); + _handler.RegisterTool("tool2", new ToolSpec { Name = "tool2" }, + async (args, ctx) => ToolResult.Text("2")); + _handler.RegisterTool("tool3", new ToolSpec { Name = "tool3" }, + async (args, ctx) => ToolResult.Text("3")); + + var tools = _handler.GetTools().ToList(); + + tools.Should().HaveCount(3); + tools.Select(t => t.Name).Should().Contain("tool1", "tool2", "tool3"); + } + + // HandleRequest tests + + [Fact] + public async Task HandleRequest_ReturnsResponseWithResult() + { + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "ping" + }; + + var response = await _handler.HandleRequest(request, _context); + + response.JsonRpc.Should().Be("2.0"); + response.Id.Should().Be(1); + response.Result.Should().NotBeNull(); + response.Error.Should().BeNull(); + } + + [Fact] + public async Task HandleRequest_ReturnsErrorForUnknownMethod() + { + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "unknown_method" + }; + + var response = await _handler.HandleRequest(request, _context); + + response.Error.Should().NotBeNull(); + response.Error!.Code.Should().Be(JsonRpcErrorCodes.MethodNotFound); + response.Error.Message.Should().Contain("unknown_method"); + } + + // Initialize method tests + + [Fact] + public async Task Initialize_ReturnsProtocolVersion() + { + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "initialize" + }; + + var response = await _handler.HandleRequest(request, _context); + var json = JsonSerializer.SerializeToElement(response.Result); + + json.GetProperty("protocolVersion").GetString().Should().Be("2024-11-05"); + } + + [Fact] + public async Task Initialize_ReturnsCapabilities() + { + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "initialize" + }; + + var response = await _handler.HandleRequest(request, _context); + var json = JsonSerializer.SerializeToElement(response.Result); + + json.TryGetProperty("capabilities", out var capabilities).Should().BeTrue(); + capabilities.TryGetProperty("tools", out _).Should().BeTrue(); + } + + [Fact] + public async Task Initialize_ReturnsServerInfo() + { + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "initialize" + }; + + var response = await _handler.HandleRequest(request, _context); + var json = JsonSerializer.SerializeToElement(response.Result); + + json.GetProperty("serverInfo").GetProperty("name").GetString() + .Should().Be("auth-mcp-server-csharp"); + json.GetProperty("serverInfo").GetProperty("version").GetString() + .Should().Be("1.0.0"); + } + + // tools/list tests + + [Fact] + public async Task ToolsList_ReturnsRegisteredTools() + { + _handler.RegisterTool("my-tool", new ToolSpec + { + Name = "my-tool", + Description = "A tool" + }, async (args, ctx) => ToolResult.Text("result")); + + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "tools/list" + }; + + var response = await _handler.HandleRequest(request, _context); + var json = JsonSerializer.SerializeToElement(response.Result); + + json.TryGetProperty("tools", out var tools).Should().BeTrue(); + tools.GetArrayLength().Should().Be(1); + } + + [Fact] + public async Task ToolsList_ReturnsEmptyArrayWhenNoTools() + { + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "tools/list" + }; + + var response = await _handler.HandleRequest(request, _context); + var json = JsonSerializer.SerializeToElement(response.Result); + + json.GetProperty("tools").GetArrayLength().Should().Be(0); + } + + // tools/call tests + + [Fact] + public async Task ToolsCall_InvokesCorrectHandler() + { + var wasInvoked = false; + _handler.RegisterTool("test-tool", new ToolSpec { Name = "test-tool" }, + async (args, ctx) => + { + wasInvoked = true; + return ToolResult.Text("invoked"); + }); + + var paramsJson = JsonSerializer.SerializeToElement(new + { + name = "test-tool", + arguments = new { } + }); + + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "tools/call", + Params = paramsJson + }; + + await _handler.HandleRequest(request, _context); + + wasInvoked.Should().BeTrue(); + } + + [Fact] + public async Task ToolsCall_PassesArguments() + { + string? receivedArg = null; + _handler.RegisterTool("echo", new ToolSpec { Name = "echo" }, + async (args, ctx) => + { + receivedArg = args?.GetProperty("message").GetString(); + return ToolResult.Text(receivedArg ?? ""); + }); + + var paramsJson = JsonSerializer.SerializeToElement(new + { + name = "echo", + arguments = new { message = "hello world" } + }); + + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "tools/call", + Params = paramsJson + }; + + await _handler.HandleRequest(request, _context); + + receivedArg.Should().Be("hello world"); + } + + [Fact] + public async Task ToolsCall_MissingName_ReturnsError() + { + var paramsJson = JsonSerializer.SerializeToElement(new + { + arguments = new { } + }); + + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "tools/call", + Params = paramsJson + }; + + var response = await _handler.HandleRequest(request, _context); + + response.Error.Should().NotBeNull(); + response.Error!.Code.Should().Be(JsonRpcErrorCodes.InvalidParams); + response.Error.Message.Should().Contain("name"); + } + + [Fact] + public async Task ToolsCall_UnknownTool_ReturnsError() + { + var paramsJson = JsonSerializer.SerializeToElement(new + { + name = "nonexistent-tool" + }); + + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "tools/call", + Params = paramsJson + }; + + var response = await _handler.HandleRequest(request, _context); + + response.Error.Should().NotBeNull(); + response.Error!.Code.Should().Be(JsonRpcErrorCodes.InvalidParams); + response.Error.Message.Should().Contain("nonexistent-tool"); + } + + [Fact] + public async Task ToolsCall_MissingParams_ReturnsError() + { + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "tools/call", + Params = null + }; + + var response = await _handler.HandleRequest(request, _context); + + response.Error.Should().NotBeNull(); + response.Error!.Code.Should().Be(JsonRpcErrorCodes.InvalidParams); + } + + // ping tests + + [Fact] + public async Task Ping_ReturnsEmptyObject() + { + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "ping" + }; + + var response = await _handler.HandleRequest(request, _context); + var json = JsonSerializer.SerializeToElement(response.Result); + + // Empty anonymous object serializes to {} + json.EnumerateObject().Should().BeEmpty(); + } + + // JsonRpcException tests + + [Fact] + public void JsonRpcException_CreatesErrorWithCode() + { + var ex = new JsonRpcException(-32600, "Invalid request"); + + ex.Error.Code.Should().Be(-32600); + ex.Error.Message.Should().Be("Invalid request"); + } + + [Fact] + public void JsonRpcException_IncludesDataWhenProvided() + { + var ex = new JsonRpcException(-32602, "Invalid params", new { field = "name" }); + + ex.Error.Data.Should().NotBeNull(); + } + + // Additional method handler tests + + [Fact] + public async Task HandleRequest_PreservesStringId() + { + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = "my-string-id", + Method = "ping" + }; + + var response = await _handler.HandleRequest(request, _context); + + response.Id.Should().Be("my-string-id"); + } + + [Fact] + public async Task HandleRequest_PreservesNullId() + { + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = null, + Method = "ping" + }; + + var response = await _handler.HandleRequest(request, _context); + + response.Id.Should().BeNull(); + } + + [Fact] + public async Task ToolsCall_ReturnsToolResultStructure() + { + _handler.RegisterTool("test", new ToolSpec { Name = "test" }, + async (args, ctx) => ToolResult.Text("test result")); + + var paramsJson = JsonSerializer.SerializeToElement(new + { + name = "test" + }); + + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "tools/call", + Params = paramsJson + }; + + var response = await _handler.HandleRequest(request, _context); + var result = response.Result as ToolResult; + + result.Should().NotBeNull(); + result!.Content.Should().HaveCount(1); + result.Content[0].Type.Should().Be("text"); + result.Content[0].Text.Should().Be("test result"); + } + + [Fact] + public async Task ToolsCall_ReturnsErrorToolResult() + { + _handler.RegisterTool("failing", new ToolSpec { Name = "failing" }, + async (args, ctx) => ToolResult.Error("Something went wrong")); + + var paramsJson = JsonSerializer.SerializeToElement(new + { + name = "failing" + }); + + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "tools/call", + Params = paramsJson + }; + + var response = await _handler.HandleRequest(request, _context); + var result = response.Result as ToolResult; + + result.Should().NotBeNull(); + result!.IsError.Should().BeTrue(); + result.Content[0].Text.Should().Be("Something went wrong"); + } + + [Fact] + public async Task Initialize_HasToolsCapability() + { + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "initialize" + }; + + var response = await _handler.HandleRequest(request, _context); + var json = JsonSerializer.SerializeToElement(response.Result); + + var tools = json.GetProperty("capabilities").GetProperty("tools"); + tools.GetProperty("listChanged").GetBoolean().Should().BeFalse(); + } + + [Fact] + public async Task HandleRequest_CatchesToolHandlerException() + { + _handler.RegisterTool("throwing", new ToolSpec { Name = "throwing" }, + async (args, ctx) => throw new InvalidOperationException("Handler failed")); + + var paramsJson = JsonSerializer.SerializeToElement(new { name = "throwing" }); + + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "tools/call", + Params = paramsJson + }; + + var response = await _handler.HandleRequest(request, _context); + + response.Error.Should().NotBeNull(); + response.Error!.Code.Should().Be(JsonRpcErrorCodes.InternalError); + response.Error.Message.Should().Contain("Handler failed"); + } + + [Fact] + public async Task ToolsCall_WorksWithoutArguments() + { + _handler.RegisterTool("no-args", new ToolSpec { Name = "no-args" }, + async (args, ctx) => + { + args.Should().BeNull(); + return ToolResult.Text("no args needed"); + }); + + var paramsJson = JsonSerializer.SerializeToElement(new + { + name = "no-args" + }); + + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "tools/call", + Params = paramsJson + }; + + var response = await _handler.HandleRequest(request, _context); + + response.Error.Should().BeNull(); + } +} diff --git a/examples/auth/AuthMcpServer.Tests/Tools/WeatherToolsTests.cs b/examples/auth/AuthMcpServer.Tests/Tools/WeatherToolsTests.cs new file mode 100644 index 00000000..5c2e4ea2 --- /dev/null +++ b/examples/auth/AuthMcpServer.Tests/Tools/WeatherToolsTests.cs @@ -0,0 +1,251 @@ +using System.Text.Json; +using AuthMcpServer.Models.JsonRpc; +using AuthMcpServer.Models.Mcp; +using AuthMcpServer.Services; +using AuthMcpServer.Tools; +using GopherOrch.Auth; +using FluentAssertions; +using Microsoft.AspNetCore.Http; + +namespace AuthMcpServer.Tests.Tools; + +/// +/// Tests for weather tools with scope-based access control. +/// +public class WeatherToolsTests +{ + private readonly McpHandler _handler; + private readonly DefaultHttpContext _context; + + public WeatherToolsTests() + { + _handler = new McpHandler(); + WeatherTools.Register(_handler); + _context = new DefaultHttpContext(); + } + + private void SetAuthContext(string scopes) + { + _context.Items["AuthContext"] = new AuthContext( + "user1", scopes, "", 0, true); + } + + private void SetNoAuthContext() + { + _context.Items["AuthContext"] = AuthContext.Empty(); + } + + private async Task CallTool(string name, object? arguments = null) + { + var paramsObj = new Dictionary { { "name", name } }; + if (arguments != null) + { + paramsObj["arguments"] = arguments; + } + + var paramsJson = JsonSerializer.SerializeToElement(paramsObj); + var request = new JsonRpcRequest + { + JsonRpc = "2.0", + Id = 1, + Method = "tools/call", + Params = paramsJson + }; + + return await _handler.HandleRequest(request, _context); + } + + // get-weather tests (public, no scope required) + + [Fact] + public async Task GetWeather_WorksWithoutScope() + { + SetNoAuthContext(); + + var response = await CallTool("get-weather", new { city = "London" }); + + response.Error.Should().BeNull(); + var result = response.Result as ToolResult; + result.Should().NotBeNull(); + result!.IsError.Should().BeFalse(); + } + + [Fact] + public async Task GetWeather_ReturnsWeatherData() + { + SetNoAuthContext(); + + var response = await CallTool("get-weather", new { city = "Paris" }); + var result = response.Result as ToolResult; + var json = JsonDocument.Parse(result!.Content[0].Text!); + + json.RootElement.GetProperty("city").GetString().Should().Be("Paris"); + json.RootElement.TryGetProperty("temperature", out _).Should().BeTrue(); + json.RootElement.TryGetProperty("condition", out _).Should().BeTrue(); + json.RootElement.TryGetProperty("humidity", out _).Should().BeTrue(); + json.RootElement.TryGetProperty("windSpeed", out _).Should().BeTrue(); + } + + [Fact] + public async Task GetWeather_ReturnsDifferentDataForDifferentCities() + { + SetNoAuthContext(); + + var response1 = await CallTool("get-weather", new { city = "Tokyo" }); + var response2 = await CallTool("get-weather", new { city = "Sydney" }); + + var result1 = response1.Result as ToolResult; + var result2 = response2.Result as ToolResult; + + result1!.Content[0].Text.Should().NotBe(result2!.Content[0].Text); + } + + // get-forecast tests (requires mcp:read scope) + + [Fact] + public async Task GetForecast_RequiresMcpReadScope() + { + SetNoAuthContext(); + + var response = await CallTool("get-forecast", new { city = "Berlin" }); + var result = response.Result as ToolResult; + + result.Should().NotBeNull(); + result!.IsError.Should().BeTrue(); + result.Content[0].Text.Should().Contain("mcp:read"); + } + + [Fact] + public async Task GetForecast_WorksWithMcpReadScope() + { + SetAuthContext("mcp:read"); + + var response = await CallTool("get-forecast", new { city = "Berlin" }); + var result = response.Result as ToolResult; + + result.Should().NotBeNull(); + result!.IsError.Should().BeFalse(); + } + + [Fact] + public async Task GetForecast_ReturnsForecastData() + { + SetAuthContext("openid mcp:read"); + + var response = await CallTool("get-forecast", new { city = "Rome" }); + var result = response.Result as ToolResult; + var json = JsonDocument.Parse(result!.Content[0].Text!); + + json.RootElement.GetProperty("city").GetString().Should().Be("Rome"); + var forecast = json.RootElement.GetProperty("forecast"); + forecast.GetArrayLength().Should().Be(5); // 5-day forecast + } + + [Fact] + public async Task GetForecast_ForecastContainsDayDetails() + { + SetAuthContext("mcp:read"); + + var response = await CallTool("get-forecast", new { city = "Madrid" }); + var result = response.Result as ToolResult; + var json = JsonDocument.Parse(result!.Content[0].Text!); + + var firstDay = json.RootElement.GetProperty("forecast")[0]; + firstDay.TryGetProperty("day", out _).Should().BeTrue(); + firstDay.TryGetProperty("high", out _).Should().BeTrue(); + firstDay.TryGetProperty("low", out _).Should().BeTrue(); + firstDay.TryGetProperty("condition", out _).Should().BeTrue(); + } + + // get-weather-alerts tests (requires mcp:admin scope) + + [Fact] + public async Task GetWeatherAlerts_RequiresMcpAdminScope() + { + SetNoAuthContext(); + + var response = await CallTool("get-weather-alerts", new { region = "Europe" }); + var result = response.Result as ToolResult; + + result.Should().NotBeNull(); + result!.IsError.Should().BeTrue(); + result.Content[0].Text.Should().Contain("mcp:admin"); + } + + [Fact] + public async Task GetWeatherAlerts_FailsWithOnlyMcpReadScope() + { + SetAuthContext("mcp:read"); + + var response = await CallTool("get-weather-alerts", new { region = "Europe" }); + var result = response.Result as ToolResult; + + result!.IsError.Should().BeTrue(); + result.Content[0].Text.Should().Contain("mcp:admin"); + } + + [Fact] + public async Task GetWeatherAlerts_WorksWithMcpAdminScope() + { + SetAuthContext("mcp:admin"); + + var response = await CallTool("get-weather-alerts", new { region = "Europe" }); + var result = response.Result as ToolResult; + + result.Should().NotBeNull(); + result!.IsError.Should().BeFalse(); + } + + [Fact] + public async Task GetWeatherAlerts_ReturnsAlertData() + { + SetAuthContext("mcp:admin mcp:read"); + + var response = await CallTool("get-weather-alerts", new { region = "Asia" }); + var result = response.Result as ToolResult; + var json = JsonDocument.Parse(result!.Content[0].Text!); + + json.RootElement.GetProperty("region").GetString().Should().Be("Asia"); + var alerts = json.RootElement.GetProperty("alerts"); + alerts.GetArrayLength().Should().BeGreaterThan(0); + } + + [Fact] + public async Task GetWeatherAlerts_AlertsContainDetails() + { + SetAuthContext("mcp:admin"); + + var response = await CallTool("get-weather-alerts", new { region = "Americas" }); + var result = response.Result as ToolResult; + var json = JsonDocument.Parse(result!.Content[0].Text!); + + var firstAlert = json.RootElement.GetProperty("alerts")[0]; + firstAlert.TryGetProperty("type", out _).Should().BeTrue(); + firstAlert.TryGetProperty("severity", out _).Should().BeTrue(); + firstAlert.TryGetProperty("message", out _).Should().BeTrue(); + } + + // Tool registration tests + + [Fact] + public void Register_RegistersAllThreeTools() + { + var tools = _handler.GetTools().ToList(); + + tools.Should().HaveCount(3); + tools.Select(t => t.Name).Should().Contain("get-weather"); + tools.Select(t => t.Name).Should().Contain("get-forecast"); + tools.Select(t => t.Name).Should().Contain("get-weather-alerts"); + } + + [Fact] + public void Register_ToolsHaveInputSchemas() + { + var tools = _handler.GetTools().ToList(); + + foreach (var tool in tools) + { + tool.InputSchema.Should().NotBeNull(); + } + } +} diff --git a/examples/auth/AuthMcpServer.sln b/examples/auth/AuthMcpServer.sln new file mode 100644 index 00000000..19332555 --- /dev/null +++ b/examples/auth/AuthMcpServer.sln @@ -0,0 +1,28 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.0.31903.59 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AuthMcpServer", "AuthMcpServer\AuthMcpServer.csproj", "{1E984281-F1EF-4F27-A213-0634E93F05CD}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AuthMcpServer.Tests", "AuthMcpServer.Tests\AuthMcpServer.Tests.csproj", "{537D71D8-045F-459A-BD79-B4E01A69C976}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {1E984281-F1EF-4F27-A213-0634E93F05CD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1E984281-F1EF-4F27-A213-0634E93F05CD}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1E984281-F1EF-4F27-A213-0634E93F05CD}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1E984281-F1EF-4F27-A213-0634E93F05CD}.Release|Any CPU.Build.0 = Release|Any CPU + {537D71D8-045F-459A-BD79-B4E01A69C976}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {537D71D8-045F-459A-BD79-B4E01A69C976}.Debug|Any CPU.Build.0 = Debug|Any CPU + {537D71D8-045F-459A-BD79-B4E01A69C976}.Release|Any CPU.ActiveCfg = Release|Any CPU + {537D71D8-045F-459A-BD79-B4E01A69C976}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection +EndGlobal diff --git a/examples/auth/AuthMcpServer/AuthMcpServer.csproj b/examples/auth/AuthMcpServer/AuthMcpServer.csproj new file mode 100644 index 00000000..8bad15fb --- /dev/null +++ b/examples/auth/AuthMcpServer/AuthMcpServer.csproj @@ -0,0 +1,17 @@ + + + + net8.0 + enable + enable + + + + + + + + + + + diff --git a/examples/auth/AuthMcpServer/Config/AuthServerConfig.cs b/examples/auth/AuthMcpServer/Config/AuthServerConfig.cs new file mode 100644 index 00000000..419f5b88 --- /dev/null +++ b/examples/auth/AuthMcpServer/Config/AuthServerConfig.cs @@ -0,0 +1,42 @@ +namespace AuthMcpServer.Config; + +/// +/// Configuration for the Auth MCP Server. +/// Mirrors the configuration structure from the C++ and TypeScript examples. +/// +public class AuthServerConfig +{ + // Server settings + public string Host { get; init; } = "0.0.0.0"; + public int Port { get; init; } = 3001; + public string ServerUrl { get; init; } = "http://localhost:3001"; + + // OAuth/IDP settings + public string AuthServerUrl { get; init; } = ""; + public string JwksUri { get; init; } = ""; + public string Issuer { get; init; } = ""; + public string ClientId { get; init; } = ""; + public string ClientSecret { get; init; } = ""; + public string OAuthAuthorizeUrl { get; init; } = ""; + public string OAuthTokenUrl { get; init; } = ""; + + // Scopes + public string AllowedScopes { get; init; } = "openid profile email mcp:read mcp:admin"; + + // Cache settings + public int JwksCacheDuration { get; init; } = 3600; + public bool JwksAutoRefresh { get; init; } = true; + public int RequestTimeout { get; init; } = 30; + + // Auth bypass mode + public bool AuthDisabled { get; init; } = false; + + /// + /// Create a default configuration with auth disabled. + /// Useful for development and testing. + /// + public static AuthServerConfig DefaultDisabled() => new() + { + AuthDisabled = true + }; +} diff --git a/examples/auth/AuthMcpServer/Config/ConfigLoader.cs b/examples/auth/AuthMcpServer/Config/ConfigLoader.cs new file mode 100644 index 00000000..d918eb91 --- /dev/null +++ b/examples/auth/AuthMcpServer/Config/ConfigLoader.cs @@ -0,0 +1,162 @@ +namespace AuthMcpServer.Config; + +/// +/// Loads and parses configuration from INI-style files. +/// +public static class ConfigLoader +{ + /// + /// Parse a configuration file in key=value format. + /// Supports # comments and empty lines. + /// + /// Raw file content + /// Dictionary of key-value pairs + public static Dictionary ParseConfigFile(string content) + { + var result = new Dictionary(StringComparer.OrdinalIgnoreCase); + + foreach (var line in content.Split('\n')) + { + var trimmed = line.Trim(); + + // Skip empty lines and comments + if (string.IsNullOrEmpty(trimmed) || trimmed.StartsWith('#')) + { + continue; + } + + var eqIndex = trimmed.IndexOf('='); + if (eqIndex == -1) + { + continue; + } + + var key = trimmed[..eqIndex].Trim(); + var value = trimmed[(eqIndex + 1)..].Trim(); + + if (!string.IsNullOrEmpty(key)) + { + result[key] = value; + } + } + + return result; + } + + /// + /// Build AuthServerConfig from a parsed key-value map. + /// Derives OAuth endpoints from auth_server_url when not explicitly set. + /// + /// Parsed configuration map + /// AuthServerConfig object + public static AuthServerConfig BuildConfig(Dictionary configMap) + { + var port = GetInt(configMap, "port", 3001); + var authServerUrl = GetString(configMap, "auth_server_url", ""); + + // Get explicit values or empty strings + var jwksUri = GetString(configMap, "jwks_uri", ""); + var issuer = GetString(configMap, "issuer", ""); + var oauthAuthorizeUrl = GetString(configMap, "oauth_authorize_url", ""); + var oauthTokenUrl = GetString(configMap, "oauth_token_url", ""); + + // Derive endpoints from auth_server_url if not explicitly set + if (!string.IsNullOrEmpty(authServerUrl)) + { + if (string.IsNullOrEmpty(jwksUri)) + { + jwksUri = $"{authServerUrl}/protocol/openid-connect/certs"; + } + if (string.IsNullOrEmpty(issuer)) + { + issuer = authServerUrl; + } + if (string.IsNullOrEmpty(oauthAuthorizeUrl)) + { + oauthAuthorizeUrl = $"{authServerUrl}/protocol/openid-connect/auth"; + } + if (string.IsNullOrEmpty(oauthTokenUrl)) + { + oauthTokenUrl = $"{authServerUrl}/protocol/openid-connect/token"; + } + } + + return new AuthServerConfig + { + // Server settings + Host = GetString(configMap, "host", "0.0.0.0"), + Port = port, + ServerUrl = GetString(configMap, "server_url", $"http://localhost:{port}"), + + // OAuth/IDP settings + AuthServerUrl = authServerUrl, + JwksUri = jwksUri, + Issuer = issuer, + ClientId = GetString(configMap, "client_id", ""), + ClientSecret = GetString(configMap, "client_secret", ""), + OAuthAuthorizeUrl = oauthAuthorizeUrl, + OAuthTokenUrl = oauthTokenUrl, + + // Scopes + AllowedScopes = GetString(configMap, "allowed_scopes", "openid profile email mcp:read mcp:admin"), + + // Cache settings + JwksCacheDuration = GetInt(configMap, "jwks_cache_duration", 3600), + JwksAutoRefresh = GetBool(configMap, "jwks_auto_refresh", true), + RequestTimeout = GetInt(configMap, "request_timeout", 30), + + // Auth bypass mode + AuthDisabled = GetBool(configMap, "auth_disabled", false) + }; + } + + /// + /// Load and parse configuration from a file. + /// Returns default config with auth disabled if file doesn't exist. + /// + /// Path to the configuration file + /// AuthServerConfig object + public static AuthServerConfig LoadFromFile(string path) + { + if (!File.Exists(path)) + { + Console.WriteLine($"Config file not found: {path}. Using defaults with auth disabled."); + return AuthServerConfig.DefaultDisabled(); + } + + try + { + var content = File.ReadAllText(path); + var configMap = ParseConfigFile(content); + return BuildConfig(configMap); + } + catch (Exception ex) + { + Console.WriteLine($"Failed to load config from {path}: {ex.Message}. Using defaults with auth disabled."); + return AuthServerConfig.DefaultDisabled(); + } + } + + private static string GetString(Dictionary map, string key, string defaultValue) + { + return map.TryGetValue(key, out var value) ? value : defaultValue; + } + + private static int GetInt(Dictionary map, string key, int defaultValue) + { + if (map.TryGetValue(key, out var value) && int.TryParse(value, out var result)) + { + return result; + } + return defaultValue; + } + + private static bool GetBool(Dictionary map, string key, bool defaultValue) + { + if (map.TryGetValue(key, out var value)) + { + return value.Equals("true", StringComparison.OrdinalIgnoreCase); + } + return defaultValue; + } +} diff --git a/examples/auth/AuthMcpServer/Middleware/CorsMiddleware.cs b/examples/auth/AuthMcpServer/Middleware/CorsMiddleware.cs new file mode 100644 index 00000000..d4f0385e --- /dev/null +++ b/examples/auth/AuthMcpServer/Middleware/CorsMiddleware.cs @@ -0,0 +1,52 @@ +namespace AuthMcpServer.Middleware; + +/// +/// CORS middleware that sets proper headers on ALL responses. +/// Critical for MCP Inspector and browser-based clients to work correctly. +/// +public class CorsMiddleware +{ + private readonly RequestDelegate _next; + + private const string AllowedMethods = "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD"; + private const string AllowedHeaders = "Accept, Accept-Language, Content-Language, Content-Type, " + + "Authorization, X-Requested-With, Origin, Cache-Control, Pragma, " + + "Mcp-Session-Id, Mcp-Protocol-Version"; + private const string ExposedHeaders = "WWW-Authenticate, Content-Length, Content-Type"; + private const string MaxAge = "86400"; + + public CorsMiddleware(RequestDelegate next) + { + _next = next; + } + + public async Task InvokeAsync(HttpContext context) + { + // Set CORS headers on ALL responses + SetCorsHeaders(context.Response); + + // Handle OPTIONS preflight + if (context.Request.Method == "OPTIONS") + { + context.Response.StatusCode = 204; + context.Response.ContentLength = 0; + return; + } + + await _next(context); + } + + /// + /// Set CORS headers on a response. + /// Can be called from other middleware/handlers to ensure CORS headers + /// are present on error responses (e.g., 401 Unauthorized). + /// + public static void SetCorsHeaders(HttpResponse response) + { + response.Headers["Access-Control-Allow-Origin"] = "*"; + response.Headers["Access-Control-Allow-Methods"] = AllowedMethods; + response.Headers["Access-Control-Allow-Headers"] = AllowedHeaders; + response.Headers["Access-Control-Expose-Headers"] = ExposedHeaders; + response.Headers["Access-Control-Max-Age"] = MaxAge; + } +} diff --git a/examples/auth/AuthMcpServer/Middleware/OAuthAuthMiddleware.cs b/examples/auth/AuthMcpServer/Middleware/OAuthAuthMiddleware.cs new file mode 100644 index 00000000..c46c29cb --- /dev/null +++ b/examples/auth/AuthMcpServer/Middleware/OAuthAuthMiddleware.cs @@ -0,0 +1,124 @@ +using AuthMcpServer.Config; +using GopherOrch.Auth; + +namespace AuthMcpServer.Middleware; + +/// +/// OAuth authentication middleware for protected endpoints. +/// +public class OAuthAuthMiddleware +{ + private readonly RequestDelegate _next; + private readonly AuthServerConfig _config; + + private static readonly string[] PublicPathPrefixes = new[] + { + "/health", "/.well-known/", "/oauth/", "/authorize", "/favicon.ico" + }; + + public OAuthAuthMiddleware(RequestDelegate next, AuthServerConfig config) + { + _next = next; + _config = config; + } + + public async Task InvokeAsync(HttpContext context) + { + var path = context.Request.Path.Value ?? ""; + + // Skip auth for public paths + if (IsPublicPath(path)) + { + context.Items["AuthContext"] = AuthContext.Empty(); + await _next(context); + return; + } + + // Auth disabled - allow all with anonymous context + if (_config.AuthDisabled) + { + context.Items["AuthContext"] = AuthContext.Anonymous(_config.AllowedScopes); + await _next(context); + return; + } + + // Extract bearer token + var token = ExtractToken(context.Request); + if (string.IsNullOrEmpty(token)) + { + await SendUnauthorized(context, "invalid_request", "Missing bearer token"); + return; + } + + // Token present - allow request with anonymous context + // In production, you would validate the token here + context.Items["AuthContext"] = AuthContext.Anonymous(_config.AllowedScopes); + await _next(context); + } + + private static bool IsPublicPath(string path) + { + return PublicPathPrefixes.Any(prefix => + path.Equals(prefix, StringComparison.OrdinalIgnoreCase) || + path.StartsWith(prefix, StringComparison.OrdinalIgnoreCase)); + } + + /// + /// Extract bearer token from request. + /// + public static string? ExtractToken(HttpRequest request) + { + // Try Authorization header first + var authHeader = request.Headers.Authorization.FirstOrDefault(); + if (!string.IsNullOrEmpty(authHeader) && + authHeader.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase)) + { + return authHeader.Substring(7); + } + + // Try query parameter + if (request.Query.TryGetValue("access_token", out var token) && + !string.IsNullOrEmpty(token)) + { + return token; + } + + return null; + } + + private async Task SendUnauthorized( + HttpContext context, + string error, + string description) + { + // Build WWW-Authenticate header per RFC 6750 + var wwwAuth = $"Bearer realm=\"{_config.ServerUrl}\", " + + $"resource_metadata=\"{_config.ServerUrl}/.well-known/oauth-protected-resource\", " + + $"scope=\"{_config.AllowedScopes}\", " + + $"error=\"{EscapeHeaderValue(error)}\", " + + $"error_description=\"{EscapeHeaderValue(description)}\""; + + context.Response.StatusCode = 401; + context.Response.Headers["WWW-Authenticate"] = wwwAuth; + + // Ensure CORS headers are set + CorsMiddleware.SetCorsHeaders(context.Response); + + // Write JSON error body + await context.Response.WriteAsJsonAsync(new + { + error, + error_description = description + }); + } + + /// + /// Escape special characters in header values per RFC 6750. + /// + public static string EscapeHeaderValue(string value) + { + if (string.IsNullOrEmpty(value)) + return ""; + return value.Replace("\\", "\\\\").Replace("\"", "\\\""); + } +} diff --git a/examples/auth/AuthMcpServer/Models/JsonRpc/JsonRpcError.cs b/examples/auth/AuthMcpServer/Models/JsonRpc/JsonRpcError.cs new file mode 100644 index 00000000..864e94f4 --- /dev/null +++ b/examples/auth/AuthMcpServer/Models/JsonRpc/JsonRpcError.cs @@ -0,0 +1,19 @@ +using System.Text.Json.Serialization; + +namespace AuthMcpServer.Models.JsonRpc; + +/// +/// JSON-RPC 2.0 error object. +/// +public class JsonRpcError +{ + [JsonPropertyName("code")] + public int Code { get; set; } + + [JsonPropertyName("message")] + public string Message { get; set; } = ""; + + [JsonPropertyName("data")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public object? Data { get; set; } +} diff --git a/examples/auth/AuthMcpServer/Models/JsonRpc/JsonRpcErrorCodes.cs b/examples/auth/AuthMcpServer/Models/JsonRpc/JsonRpcErrorCodes.cs new file mode 100644 index 00000000..994ecc44 --- /dev/null +++ b/examples/auth/AuthMcpServer/Models/JsonRpc/JsonRpcErrorCodes.cs @@ -0,0 +1,32 @@ +namespace AuthMcpServer.Models.JsonRpc; + +/// +/// Standard JSON-RPC 2.0 error codes. +/// +public static class JsonRpcErrorCodes +{ + /// + /// Invalid JSON was received by the server. + /// + public const int ParseError = -32700; + + /// + /// The JSON sent is not a valid Request object. + /// + public const int InvalidRequest = -32600; + + /// + /// The method does not exist or is not available. + /// + public const int MethodNotFound = -32601; + + /// + /// Invalid method parameter(s). + /// + public const int InvalidParams = -32602; + + /// + /// Internal JSON-RPC error. + /// + public const int InternalError = -32603; +} diff --git a/examples/auth/AuthMcpServer/Models/JsonRpc/JsonRpcRequest.cs b/examples/auth/AuthMcpServer/Models/JsonRpc/JsonRpcRequest.cs new file mode 100644 index 00000000..57788153 --- /dev/null +++ b/examples/auth/AuthMcpServer/Models/JsonRpc/JsonRpcRequest.cs @@ -0,0 +1,22 @@ +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace AuthMcpServer.Models.JsonRpc; + +/// +/// JSON-RPC 2.0 request object. +/// +public class JsonRpcRequest +{ + [JsonPropertyName("jsonrpc")] + public string JsonRpc { get; set; } = "2.0"; + + [JsonPropertyName("id")] + public object? Id { get; set; } + + [JsonPropertyName("method")] + public string Method { get; set; } = ""; + + [JsonPropertyName("params")] + public JsonElement? Params { get; set; } +} diff --git a/examples/auth/AuthMcpServer/Models/JsonRpc/JsonRpcResponse.cs b/examples/auth/AuthMcpServer/Models/JsonRpc/JsonRpcResponse.cs new file mode 100644 index 00000000..b4d7757b --- /dev/null +++ b/examples/auth/AuthMcpServer/Models/JsonRpc/JsonRpcResponse.cs @@ -0,0 +1,23 @@ +using System.Text.Json.Serialization; + +namespace AuthMcpServer.Models.JsonRpc; + +/// +/// JSON-RPC 2.0 response object. +/// +public class JsonRpcResponse +{ + [JsonPropertyName("jsonrpc")] + public string JsonRpc { get; set; } = "2.0"; + + [JsonPropertyName("id")] + public object? Id { get; set; } + + [JsonPropertyName("result")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public object? Result { get; set; } + + [JsonPropertyName("error")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public JsonRpcError? Error { get; set; } +} diff --git a/examples/auth/AuthMcpServer/Models/Mcp/ToolContent.cs b/examples/auth/AuthMcpServer/Models/Mcp/ToolContent.cs new file mode 100644 index 00000000..c3ad9439 --- /dev/null +++ b/examples/auth/AuthMcpServer/Models/Mcp/ToolContent.cs @@ -0,0 +1,24 @@ +using System.Text.Json.Serialization; + +namespace AuthMcpServer.Models.Mcp; + +/// +/// Content item in a tool result. +/// +public class ToolContent +{ + [JsonPropertyName("type")] + public string Type { get; set; } = "text"; + + [JsonPropertyName("text")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? Text { get; set; } + + [JsonPropertyName("data")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? Data { get; set; } + + [JsonPropertyName("mimeType")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? MimeType { get; set; } +} diff --git a/examples/auth/AuthMcpServer/Models/Mcp/ToolResult.cs b/examples/auth/AuthMcpServer/Models/Mcp/ToolResult.cs new file mode 100644 index 00000000..0e4df938 --- /dev/null +++ b/examples/auth/AuthMcpServer/Models/Mcp/ToolResult.cs @@ -0,0 +1,39 @@ +using System.Text.Json.Serialization; + +namespace AuthMcpServer.Models.Mcp; + +/// +/// Result of a tool invocation. +/// +public class ToolResult +{ + [JsonPropertyName("content")] + public List Content { get; set; } = new(); + + [JsonPropertyName("isError")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public bool IsError { get; set; } + + /// + /// Create a successful text result. + /// + public static ToolResult Text(string text) => new() + { + Content = new List + { + new() { Type = "text", Text = text } + } + }; + + /// + /// Create an error result. + /// + public static ToolResult Error(string message) => new() + { + Content = new List + { + new() { Type = "text", Text = message } + }, + IsError = true + }; +} diff --git a/examples/auth/AuthMcpServer/Models/Mcp/ToolSpec.cs b/examples/auth/AuthMcpServer/Models/Mcp/ToolSpec.cs new file mode 100644 index 00000000..aae3329a --- /dev/null +++ b/examples/auth/AuthMcpServer/Models/Mcp/ToolSpec.cs @@ -0,0 +1,18 @@ +using System.Text.Json.Serialization; + +namespace AuthMcpServer.Models.Mcp; + +/// +/// MCP tool specification describing a callable tool. +/// +public class ToolSpec +{ + [JsonPropertyName("name")] + public string Name { get; set; } = ""; + + [JsonPropertyName("description")] + public string Description { get; set; } = ""; + + [JsonPropertyName("inputSchema")] + public object? InputSchema { get; set; } +} diff --git a/examples/auth/AuthMcpServer/Program.cs b/examples/auth/AuthMcpServer/Program.cs new file mode 100644 index 00000000..cff52af5 --- /dev/null +++ b/examples/auth/AuthMcpServer/Program.cs @@ -0,0 +1,107 @@ +// C# Auth MCP Server +// OAuth-protected MCP server with JWT validation and scope-based access control + +using AuthMcpServer.Config; +using AuthMcpServer.Middleware; +using AuthMcpServer.Routes; +using AuthMcpServer.Services; +using AuthMcpServer.Tools; + +// Print banner +Console.WriteLine(); +Console.WriteLine("╔══════════════════════════════════════╗"); +Console.WriteLine("║ C# Auth MCP Server ║"); +Console.WriteLine("║ Version 1.0.0 ║"); +Console.WriteLine("╚══════════════════════════════════════╝"); +Console.WriteLine(); + +// Load configuration +var configPath = args.Length > 0 ? args[0] : "server.config"; +Console.WriteLine($"Loading configuration from: {configPath}"); +var config = ConfigLoader.LoadFromFile(configPath); + +var builder = WebApplication.CreateBuilder(args); + +// Register services +builder.Services.AddSingleton(config); +builder.Services.AddSingleton(); + +var app = builder.Build(); + +// CORS middleware (must be first!) +app.UseMiddleware(); + +// Map public endpoints +app.MapHealthEndpoints(); +app.MapOAuthEndpoints(); + +// Auth middleware for protected paths +app.UseWhen( + context => IsProtectedPath(context.Request.Path), + appBuilder => appBuilder.UseMiddleware() +); + +// Map protected endpoints +app.MapMcpEndpoints(); + +// Register tools +var mcpHandler = app.Services.GetRequiredService(); +WeatherTools.Register(mcpHandler); + +// Print endpoints +PrintEndpoints(config); +PrintAuthStatus(config); + +// Run server +Console.WriteLine($"Server listening on http://{config.Host}:{config.Port}"); +Console.WriteLine("Press Ctrl+C to shutdown"); +Console.WriteLine(); + +app.Run($"http://{config.Host}:{config.Port}"); + +static bool IsProtectedPath(PathString path) +{ + var p = path.Value ?? ""; + return p.StartsWith("/mcp") || p.StartsWith("/rpc") || + p.StartsWith("/events") || p.StartsWith("/sse"); +} + +static void PrintEndpoints(AuthServerConfig config) +{ + Console.WriteLine("Available Endpoints:"); + Console.WriteLine("────────────────────────────────────────"); + Console.WriteLine($" Health: GET {config.ServerUrl}/health"); + Console.WriteLine(); + Console.WriteLine(" OAuth Discovery:"); + Console.WriteLine($" GET {config.ServerUrl}/.well-known/oauth-protected-resource"); + Console.WriteLine($" GET {config.ServerUrl}/.well-known/oauth-authorization-server"); + Console.WriteLine($" GET {config.ServerUrl}/.well-known/openid-configuration"); + Console.WriteLine(); + Console.WriteLine(" OAuth Endpoints:"); + Console.WriteLine($" GET {config.ServerUrl}/oauth/authorize"); + Console.WriteLine($" POST {config.ServerUrl}/oauth/register"); + Console.WriteLine(); + Console.WriteLine(" MCP Endpoints:"); + Console.WriteLine($" POST {config.ServerUrl}/mcp"); + Console.WriteLine($" POST {config.ServerUrl}/rpc"); + Console.WriteLine("────────────────────────────────────────"); + Console.WriteLine(); +} + +static void PrintAuthStatus(AuthServerConfig config) +{ + Console.WriteLine("Authentication Status:"); + Console.WriteLine("────────────────────────────────────────"); + if (config.AuthDisabled) + { + Console.WriteLine(" Status: DISABLED (development mode)"); + } + else + { + Console.WriteLine(" Status: ENABLED (OAuth flow, no validation)"); + Console.WriteLine(" JWKS URI: " + (config.JwksUri ?? "not configured")); + Console.WriteLine(" Issuer: " + (config.Issuer ?? config.ServerUrl)); + } + Console.WriteLine("────────────────────────────────────────"); + Console.WriteLine(); +} diff --git a/examples/auth/AuthMcpServer/Properties/launchSettings.json b/examples/auth/AuthMcpServer/Properties/launchSettings.json new file mode 100644 index 00000000..ed5c0952 --- /dev/null +++ b/examples/auth/AuthMcpServer/Properties/launchSettings.json @@ -0,0 +1,38 @@ +{ + "$schema": "http://json.schemastore.org/launchsettings.json", + "iisSettings": { + "windowsAuthentication": false, + "anonymousAuthentication": true, + "iisExpress": { + "applicationUrl": "http://localhost:8450", + "sslPort": 44376 + } + }, + "profiles": { + "http": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": true, + "applicationUrl": "http://localhost:5190", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + }, + "https": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": true, + "applicationUrl": "https://localhost:7100;http://localhost:5190", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + }, + "IIS Express": { + "commandName": "IISExpress", + "launchBrowser": true, + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + } + } +} diff --git a/examples/auth/AuthMcpServer/Routes/HealthEndpoints.cs b/examples/auth/AuthMcpServer/Routes/HealthEndpoints.cs new file mode 100644 index 00000000..21e716c0 --- /dev/null +++ b/examples/auth/AuthMcpServer/Routes/HealthEndpoints.cs @@ -0,0 +1,27 @@ +namespace AuthMcpServer.Routes; + +/// +/// Health check endpoint for server status monitoring. +/// +public static class HealthEndpoints +{ + private static readonly DateTime StartTime = DateTime.UtcNow; + private const string Version = "1.0.0"; + + /// + /// Map health check endpoints to the application. + /// + public static void MapHealthEndpoints(this WebApplication app) + { + app.MapGet("/health", () => + { + var uptime = DateTime.UtcNow - StartTime; + return Results.Json(new + { + status = "healthy", + version = Version, + uptime = $"{uptime.TotalSeconds:F0}s" + }); + }); + } +} diff --git a/examples/auth/AuthMcpServer/Routes/McpEndpoints.cs b/examples/auth/AuthMcpServer/Routes/McpEndpoints.cs new file mode 100644 index 00000000..8eede1af --- /dev/null +++ b/examples/auth/AuthMcpServer/Routes/McpEndpoints.cs @@ -0,0 +1,68 @@ +using System.Text.Json; +using AuthMcpServer.Models.JsonRpc; +using AuthMcpServer.Services; + +namespace AuthMcpServer.Routes; + +/// +/// MCP JSON-RPC endpoints. +/// +public static class McpEndpoints +{ + public static void MapMcpEndpoints(this WebApplication app) + { + var handler = app.Services.GetRequiredService(); + + app.MapPost("/mcp", async (HttpContext context) => + { + return await HandleMcpRequest(context, handler); + }); + + app.MapPost("/rpc", async (HttpContext context) => + { + return await HandleMcpRequest(context, handler); + }); + } + + private static async Task HandleMcpRequest( + HttpContext context, + McpHandler handler) + { + JsonRpcRequest? request; + + try + { + request = await context.Request.ReadFromJsonAsync(); + } + catch (JsonException) + { + return Results.Json(new JsonRpcResponse + { + JsonRpc = "2.0", + Id = null, + Error = new JsonRpcError + { + Code = JsonRpcErrorCodes.ParseError, + Message = "Parse error: invalid JSON" + } + }); + } + + if (request == null || string.IsNullOrEmpty(request.Method)) + { + return Results.Json(new JsonRpcResponse + { + JsonRpc = "2.0", + Id = null, + Error = new JsonRpcError + { + Code = JsonRpcErrorCodes.InvalidRequest, + Message = "Invalid request: missing method" + } + }); + } + + var response = await handler.HandleRequest(request, context); + return Results.Json(response); + } +} diff --git a/examples/auth/AuthMcpServer/Routes/OAuthEndpoints.cs b/examples/auth/AuthMcpServer/Routes/OAuthEndpoints.cs new file mode 100644 index 00000000..b041a362 --- /dev/null +++ b/examples/auth/AuthMcpServer/Routes/OAuthEndpoints.cs @@ -0,0 +1,203 @@ +using System.Text.Json; +using System.Web; +using AuthMcpServer.Config; +using GopherOrch.Auth.OAuth; + +namespace AuthMcpServer.Routes; + +/// +/// OAuth discovery and endpoint handlers. +/// +public static class OAuthEndpoints +{ + /// + /// Map OAuth discovery and endpoint routes. + /// + public static void MapOAuthEndpoints(this WebApplication app) + { + // Protected Resource Metadata (RFC 9728) + app.MapGet("/.well-known/oauth-protected-resource", (AuthServerConfig config) => + { + return Results.Json(BuildProtectedResourceMetadata(config)); + }); + + // Protected Resource Metadata for /mcp path specifically + app.MapGet("/.well-known/oauth-protected-resource/mcp", (AuthServerConfig config) => + { + return Results.Json(BuildProtectedResourceMetadata(config)); + }); + + // Authorization Server Metadata (RFC 8414) + app.MapGet("/.well-known/oauth-authorization-server", (AuthServerConfig config) => + { + return Results.Json(BuildAuthorizationServerMetadata(config)); + }); + + // OpenID Connect Discovery + app.MapGet("/.well-known/openid-configuration", (AuthServerConfig config) => + { + return Results.Json(BuildOpenIdConfiguration(config)); + }); + + // OAuth Authorize Redirect + app.MapGet("/oauth/authorize", (HttpContext ctx, AuthServerConfig config) => + { + var authEndpoint = !string.IsNullOrEmpty(config.OAuthAuthorizeUrl) + ? config.OAuthAuthorizeUrl + : $"{config.AuthServerUrl}/protocol/openid-connect/auth"; + + try + { + var uriBuilder = new UriBuilder(authEndpoint); + var query = HttpUtility.ParseQueryString(uriBuilder.Query); + + // Forward ALL query parameters from the request + foreach (var param in ctx.Request.Query) + { + query[param.Key] = param.Value.ToString(); + } + + uriBuilder.Query = query.ToString(); + + // Return 302 Found redirect (not 301 permanent) + return Results.Redirect(uriBuilder.ToString(), permanent: false); + } + catch (Exception) + { + return Results.Json(new + { + error = "server_error", + error_description = "Failed to construct authorization URL" + }, statusCode: 500); + } + }); + + // Dynamic Client Registration (RFC 7591) + app.MapPost("/oauth/register", async (HttpContext ctx, AuthServerConfig config) => + { + // Parse request body + Dictionary? body = null; + try + { + body = await ctx.Request.ReadFromJsonAsync>(); + } + catch { } + + // Extract redirect_uris from request + var redirectUris = Array.Empty(); + if (body?.TryGetValue("redirect_uris", out var urisElement) == true && + urisElement.ValueKind == JsonValueKind.Array) + { + redirectUris = urisElement.EnumerateArray() + .Where(e => e.ValueKind == JsonValueKind.String) + .Select(e => e.GetString()!) + .ToArray(); + } + + // Build response with pre-configured credentials (stateless mode) + var response = new ClientRegistrationResponse + { + ClientId = config.ClientId, + ClientSecret = !string.IsNullOrEmpty(config.ClientSecret) + ? config.ClientSecret : null, + ClientIdIssuedAt = DateTimeOffset.UtcNow.ToUnixTimeSeconds(), + ClientSecretExpiresAt = 0, // Never expires + RedirectUris = redirectUris, + GrantTypes = new[] { "authorization_code", "refresh_token" }, + ResponseTypes = new[] { "code" }, + TokenEndpointAuthMethod = !string.IsNullOrEmpty(config.ClientSecret) + ? "client_secret_post" : "none" + }; + + ctx.Response.StatusCode = 201; + return Results.Json(response); + }); + } + + /// + /// Build protected resource metadata per RFC 9728. + /// + private static ProtectedResourceMetadata BuildProtectedResourceMetadata(AuthServerConfig config) + { + return new ProtectedResourceMetadata + { + Resource = $"{config.ServerUrl}/mcp", + AuthorizationServers = new[] { config.ServerUrl }, + ScopesSupported = config.AllowedScopes + .Split(' ', StringSplitOptions.RemoveEmptyEntries), + BearerMethodsSupported = new[] { "header", "query" }, + ResourceDocumentation = $"{config.ServerUrl}/docs" + }; + } + + /// + /// Build authorization server metadata per RFC 8414. + /// + private static AuthorizationServerMetadata BuildAuthorizationServerMetadata(AuthServerConfig config) + { + var authEndpoint = !string.IsNullOrEmpty(config.OAuthAuthorizeUrl) + ? config.OAuthAuthorizeUrl + : $"{config.AuthServerUrl}/protocol/openid-connect/auth"; + + var tokenEndpoint = !string.IsNullOrEmpty(config.OAuthTokenUrl) + ? config.OAuthTokenUrl + : $"{config.AuthServerUrl}/protocol/openid-connect/token"; + + return new AuthorizationServerMetadata + { + Issuer = !string.IsNullOrEmpty(config.Issuer) + ? config.Issuer : config.ServerUrl, + AuthorizationEndpoint = authEndpoint, + TokenEndpoint = tokenEndpoint, + JwksUri = config.JwksUri, + RegistrationEndpoint = $"{config.ServerUrl}/oauth/register", + ScopesSupported = config.AllowedScopes + .Split(' ', StringSplitOptions.RemoveEmptyEntries), + ResponseTypesSupported = new[] { "code" }, + GrantTypesSupported = new[] { "authorization_code", "refresh_token" }, + TokenEndpointAuthMethodsSupported = new[] { "client_secret_basic", "client_secret_post", "none" }, + CodeChallengeMethodsSupported = new[] { "S256" } + }; + } + + /// + /// Build OpenID Connect configuration extending auth server metadata. + /// + private static OpenIdConfiguration BuildOpenIdConfiguration(AuthServerConfig config) + { + var authEndpoint = !string.IsNullOrEmpty(config.OAuthAuthorizeUrl) + ? config.OAuthAuthorizeUrl + : $"{config.AuthServerUrl}/protocol/openid-connect/auth"; + + var tokenEndpoint = !string.IsNullOrEmpty(config.OAuthTokenUrl) + ? config.OAuthTokenUrl + : $"{config.AuthServerUrl}/protocol/openid-connect/token"; + + // Merge base OIDC scopes with config scopes + var baseScopes = new[] { "openid", "profile", "email" }; + var configScopes = config.AllowedScopes + .Split(' ', StringSplitOptions.RemoveEmptyEntries); + var allScopes = baseScopes.Union(configScopes).Distinct().ToArray(); + + return new OpenIdConfiguration + { + Issuer = !string.IsNullOrEmpty(config.Issuer) + ? config.Issuer : config.ServerUrl, + AuthorizationEndpoint = authEndpoint, + TokenEndpoint = tokenEndpoint, + JwksUri = config.JwksUri, + RegistrationEndpoint = $"{config.ServerUrl}/oauth/register", + ScopesSupported = allScopes, + ResponseTypesSupported = new[] { "code" }, + GrantTypesSupported = new[] { "authorization_code", "refresh_token" }, + TokenEndpointAuthMethodsSupported = new[] { "client_secret_basic", "client_secret_post", "none" }, + CodeChallengeMethodsSupported = new[] { "S256" }, + // OIDC-specific fields + UserinfoEndpoint = !string.IsNullOrEmpty(config.AuthServerUrl) + ? $"{config.AuthServerUrl}/protocol/openid-connect/userinfo" + : null, + SubjectTypesSupported = new[] { "public" }, + IdTokenSigningAlgValuesSupported = new[] { "RS256" } + }; + } +} diff --git a/examples/auth/AuthMcpServer/Services/McpHandler.cs b/examples/auth/AuthMcpServer/Services/McpHandler.cs new file mode 100644 index 00000000..6a718a41 --- /dev/null +++ b/examples/auth/AuthMcpServer/Services/McpHandler.cs @@ -0,0 +1,173 @@ +using System.Text.Json; +using AuthMcpServer.Models.JsonRpc; +using AuthMcpServer.Models.Mcp; + +namespace AuthMcpServer.Services; + +/// +/// MCP JSON-RPC 2.0 handler service. +/// +public class McpHandler +{ + private readonly Dictionary> Handler)> _tools = new(); + + /// + /// Register a tool with its specification and handler. + /// + public void RegisterTool( + string name, + ToolSpec spec, + Func> handler) + { + _tools[name] = (spec, handler); + } + + /// + /// Get all registered tool specifications. + /// + public IEnumerable GetTools() => + _tools.Values.Select(t => t.Spec); + + /// + /// Handle a JSON-RPC request. + /// + public async Task HandleRequest( + JsonRpcRequest request, + HttpContext context) + { + try + { + var result = await DispatchMethod( + request.Method, + request.Params, + context); + + return new JsonRpcResponse + { + JsonRpc = "2.0", + Id = request.Id, + Result = result + }; + } + catch (JsonRpcException ex) + { + return new JsonRpcResponse + { + JsonRpc = "2.0", + Id = request.Id, + Error = ex.Error + }; + } + catch (Exception ex) + { + return new JsonRpcResponse + { + JsonRpc = "2.0", + Id = request.Id, + Error = new JsonRpcError + { + Code = JsonRpcErrorCodes.InternalError, + Message = ex.Message + } + }; + } + } + + private async Task DispatchMethod( + string method, + JsonElement? @params, + HttpContext context) + { + return method switch + { + "initialize" => HandleInitialize(), + "tools/list" => HandleToolsList(), + "tools/call" => await HandleToolsCall(@params, context), + "ping" => new { }, + _ => throw new JsonRpcException( + JsonRpcErrorCodes.MethodNotFound, + $"Method not found: {method}") + }; + } + + private object HandleInitialize() + { + return new + { + protocolVersion = "2024-11-05", + capabilities = new + { + tools = new { listChanged = false } + }, + serverInfo = new + { + name = "auth-mcp-server-csharp", + version = "1.0.0" + } + }; + } + + private object HandleToolsList() + { + return new { tools = GetTools() }; + } + + private async Task HandleToolsCall( + JsonElement? @params, + HttpContext context) + { + if (@params == null) + throw new JsonRpcException( + JsonRpcErrorCodes.InvalidParams, + "Missing params"); + + // Extract tool name + if (!@params.Value.TryGetProperty("name", out var nameElement) || + nameElement.ValueKind != JsonValueKind.String) + { + throw new JsonRpcException( + JsonRpcErrorCodes.InvalidParams, + "Missing or invalid tool name"); + } + + var name = nameElement.GetString()!; + + // Find tool + if (!_tools.TryGetValue(name, out var tool)) + { + throw new JsonRpcException( + JsonRpcErrorCodes.InvalidParams, + $"Unknown tool: {name}"); + } + + // Extract arguments + JsonElement? args = null; + if (@params.Value.TryGetProperty("arguments", out var argsElement)) + { + args = argsElement; + } + + // Execute tool handler + var result = await tool.Handler(args, context); + return result; + } +} + +/// +/// JSON-RPC exception for structured error responses. +/// +public class JsonRpcException : Exception +{ + public JsonRpcError Error { get; } + + public JsonRpcException(int code, string message, object? data = null) + : base(message) + { + Error = new JsonRpcError + { + Code = code, + Message = message, + Data = data + }; + } +} diff --git a/examples/auth/AuthMcpServer/Tools/WeatherTools.cs b/examples/auth/AuthMcpServer/Tools/WeatherTools.cs new file mode 100644 index 00000000..ba2dd815 --- /dev/null +++ b/examples/auth/AuthMcpServer/Tools/WeatherTools.cs @@ -0,0 +1,153 @@ +using System.Text.Json; +using AuthMcpServer.Models.Mcp; +using AuthMcpServer.Services; +using GopherOrch.Auth; + +namespace AuthMcpServer.Tools; + +/// +/// Weather tools with scope-based access control. +/// +public static class WeatherTools +{ + public static void Register(McpHandler handler) + { + // Public tool - no scope required + handler.RegisterTool( + "get-weather", + new ToolSpec + { + Name = "get-weather", + Description = "Get current weather for a city", + InputSchema = new + { + type = "object", + properties = new + { + city = new { type = "string", description = "City name" } + }, + required = new[] { "city" } + } + }, + async (args, ctx) => + { + var city = args?.GetProperty("city").GetString() ?? "Unknown"; + var weather = GetSimulatedWeather(city); + return ToolResult.Text(JsonSerializer.Serialize(weather)); + } + ); + + // Requires mcp:read scope + handler.RegisterTool( + "get-forecast", + new ToolSpec + { + Name = "get-forecast", + Description = "Get 5-day weather forecast (requires mcp:read scope)", + InputSchema = new + { + type = "object", + properties = new + { + city = new { type = "string", description = "City name" } + }, + required = new[] { "city" } + } + }, + async (args, ctx) => + { + var authContext = ctx.Items["AuthContext"] as AuthContext + ?? AuthContext.Empty(); + + if (!authContext.HasScope("mcp:read")) + { + return ToolResult.Error("Access denied: requires mcp:read scope"); + } + + var city = args?.GetProperty("city").GetString() ?? "Unknown"; + var forecast = GetSimulatedForecast(city); + return ToolResult.Text(JsonSerializer.Serialize(forecast)); + } + ); + + // Requires mcp:admin scope + handler.RegisterTool( + "get-weather-alerts", + new ToolSpec + { + Name = "get-weather-alerts", + Description = "Get weather alerts (requires mcp:admin scope)", + InputSchema = new + { + type = "object", + properties = new + { + region = new { type = "string", description = "Region name" } + }, + required = new[] { "region" } + } + }, + async (args, ctx) => + { + var authContext = ctx.Items["AuthContext"] as AuthContext + ?? AuthContext.Empty(); + + if (!authContext.HasScope("mcp:admin")) + { + return ToolResult.Error("Access denied: requires mcp:admin scope"); + } + + var region = args?.GetProperty("region").GetString() ?? "Unknown"; + var alerts = GetSimulatedAlerts(region); + return ToolResult.Text(JsonSerializer.Serialize(alerts)); + } + ); + } + + private static object GetSimulatedWeather(string city) + { + var hash = city.GetHashCode(); + var conditions = new[] { "Sunny", "Cloudy", "Rainy", "Partly Cloudy", "Stormy" }; + + return new + { + city, + temperature = 15 + Math.Abs(hash) % 25, + condition = conditions[Math.Abs(hash) % conditions.Length], + humidity = 40 + Math.Abs(hash) % 40, + windSpeed = 5 + Math.Abs(hash) % 20 + }; + } + + private static object GetSimulatedForecast(string city) + { + var hash = city.GetHashCode(); + var days = new[] { "Monday", "Tuesday", "Wednesday", "Thursday", "Friday" }; + var conditions = new[] { "Sunny", "Cloudy", "Rainy", "Partly Cloudy" }; + + var forecast = days.Select((day, i) => new + { + day, + high = 20 + (hash + i) % 15, + low = 10 + (hash + i) % 10, + condition = conditions[Math.Abs(hash + i) % conditions.Length] + }).ToArray(); + + return new { city, forecast }; + } + + private static object GetSimulatedAlerts(string region) + { + return new + { + region, + alerts = new[] + { + new { type = "Wind Advisory", severity = "Moderate", + message = "Strong winds expected" }, + new { type = "Heat Warning", severity = "High", + message = "High temperatures forecasted" } + } + }; + } +} diff --git a/examples/auth/AuthMcpServer/appsettings.Development.json b/examples/auth/AuthMcpServer/appsettings.Development.json new file mode 100644 index 00000000..ff66ba6b --- /dev/null +++ b/examples/auth/AuthMcpServer/appsettings.Development.json @@ -0,0 +1,8 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + } +} diff --git a/examples/auth/AuthMcpServer/appsettings.json b/examples/auth/AuthMcpServer/appsettings.json new file mode 100644 index 00000000..4d566948 --- /dev/null +++ b/examples/auth/AuthMcpServer/appsettings.json @@ -0,0 +1,9 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + }, + "AllowedHosts": "*" +} diff --git a/examples/auth/run_example.ps1 b/examples/auth/run_example.ps1 new file mode 100644 index 00000000..454a8ae5 --- /dev/null +++ b/examples/auth/run_example.ps1 @@ -0,0 +1,5 @@ +$ScriptDir = $PSScriptRoot +Set-Location "$ScriptDir\AuthMcpServer" +dotnet build -c Release -q +$ConfigPath = if ($args.Count -gt 0) { $args[0] } else { "$ScriptDir\server.config" } +dotnet run -c Release -- $ConfigPath diff --git a/examples/auth/run_example.sh b/examples/auth/run_example.sh new file mode 100755 index 00000000..4220ba06 --- /dev/null +++ b/examples/auth/run_example.sh @@ -0,0 +1,5 @@ +#!/bin/bash +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +cd "$SCRIPT_DIR/AuthMcpServer" +dotnet build -c Release -q +dotnet run -c Release -- "${1:-$SCRIPT_DIR/server.config}" diff --git a/examples/auth/server.config b/examples/auth/server.config new file mode 100644 index 00000000..8eb1d0fa --- /dev/null +++ b/examples/auth/server.config @@ -0,0 +1,33 @@ +# Auth MCP Server Configuration +# Copy this file to server.config and update values as needed + +# Server settings +host=0.0.0.0 +port=3001 +server_url=https://marni-nightcapped-nonmeditatively.ngrok-free.dev + +# OAuth/IDP settings +# Uncomment and configure for Keycloak or other OAuth provider +client_id=oauth_0a650b79c5a64c3b920ae8c2b20599d9 +client_secret=6BiU2beUi2wIBxY3MUBLyYqoWKa4t0U_kJVm9mvSOKw +auth_server_url=https://auth-test.gopher.security/realms/gopher-mcp-auth +oauth_authorize_url=https://api-test.gopher.security/oauth/authorize + +# Direct OAuth endpoint URLs (optional, derived from auth_server_url if not set) +# jwks_uri=https://keycloak.example.com/realms/mcp/protocol/openid-connect/certs +# issuer=https://keycloak.example.com/realms/mcp +# oauth_authorize_url=https://keycloak.example.com/realms/mcp/protocol/openid-connect/auth +# oauth_token_url=https://keycloak.example.com/realms/mcp/protocol/openid-connect/token + +# Scopes +exchange_idps=oauth-idp-714982830194556929-google +allowed_scopes=openid profile email scope-001 + +# Cache settings +jwks_cache_duration=3600 +jwks_auto_refresh=true +request_timeout=30 + +# Auth bypass mode (for development/testing) +# Set to true to disable authentication +auth_disabled=false diff --git a/examples/auth/server.config.example b/examples/auth/server.config.example new file mode 100644 index 00000000..18c24568 --- /dev/null +++ b/examples/auth/server.config.example @@ -0,0 +1,31 @@ +# Auth MCP Server Configuration +# Copy this file to server.config and update values as needed + +# Server settings +host=0.0.0.0 +port=3001 +server_url=http://localhost:3001 + +# OAuth/IDP settings +# Uncomment and configure for Keycloak or other OAuth provider +# client_id=your-client-id +# client_secret=your-client-secret +# auth_server_url=https://keycloak.example.com/realms/mcp + +# Direct OAuth endpoint URLs (optional, derived from auth_server_url if not set) +# jwks_uri=https://keycloak.example.com/realms/mcp/protocol/openid-connect/certs +# issuer=https://keycloak.example.com/realms/mcp +# oauth_authorize_url=https://keycloak.example.com/realms/mcp/protocol/openid-connect/auth +# oauth_token_url=https://keycloak.example.com/realms/mcp/protocol/openid-connect/token + +# Scopes +allowed_scopes=openid profile email mcp:read mcp:admin + +# Cache settings +jwks_cache_duration=3600 +jwks_auto_refresh=true +request_timeout=30 + +# Auth bypass mode (for development/testing) +# Set to true to disable authentication +auth_disabled=true diff --git a/src/GopherOrch/Auth/AuthContext.cs b/src/GopherOrch/Auth/AuthContext.cs new file mode 100644 index 00000000..19ec7569 --- /dev/null +++ b/src/GopherOrch/Auth/AuthContext.cs @@ -0,0 +1,59 @@ +using System; +using System.Linq; + +namespace GopherOrch.Auth +{ + /// + /// Authentication context that holds user/token information. + /// + public class AuthContext + { + public string UserId { get; } + public string Scopes { get; } + public string Audience { get; } + public long TokenExpiry { get; } + public bool IsAuthenticated { get; } + + public AuthContext( + string userId, + string scopes, + string audience, + long tokenExpiry, + bool isAuthenticated) + { + UserId = userId; + Scopes = scopes; + Audience = audience; + TokenExpiry = tokenExpiry; + IsAuthenticated = isAuthenticated; + } + + /// + /// Check if context has a required scope. + /// + public bool HasScope(string requiredScope) + { + if (string.IsNullOrEmpty(requiredScope)) + return true; + + if (string.IsNullOrEmpty(Scopes)) + return false; + + var scopeList = Scopes.Split(new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries); + return scopeList.Contains(requiredScope, StringComparer.OrdinalIgnoreCase); + } + + /// + /// Create empty context for unauthenticated requests. + /// + public static AuthContext Empty() => + new AuthContext("", "", "", 0, false); + + /// + /// Create anonymous context with all allowed scopes (for auth-disabled mode). + /// + public static AuthContext Anonymous(string scopes) => + new AuthContext("anonymous", scopes, "", + DateTimeOffset.UtcNow.AddHours(1).ToUnixTimeSeconds(), true); + } +} diff --git a/src/GopherOrch/Auth/OAuth/AuthorizationServerMetadata.cs b/src/GopherOrch/Auth/OAuth/AuthorizationServerMetadata.cs new file mode 100644 index 00000000..8e3d0af6 --- /dev/null +++ b/src/GopherOrch/Auth/OAuth/AuthorizationServerMetadata.cs @@ -0,0 +1,43 @@ +using System; +using System.Text.Json.Serialization; + +namespace GopherOrch.Auth.OAuth +{ + /// + /// OAuth 2.0 Authorization Server Metadata per RFC 8414. + /// + public class AuthorizationServerMetadata + { + [JsonPropertyName("issuer")] + public string Issuer { get; set; } = ""; + + [JsonPropertyName("authorization_endpoint")] + public string AuthorizationEndpoint { get; set; } = ""; + + [JsonPropertyName("token_endpoint")] + public string TokenEndpoint { get; set; } = ""; + + [JsonPropertyName("jwks_uri")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? JwksUri { get; set; } + + [JsonPropertyName("registration_endpoint")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? RegistrationEndpoint { get; set; } + + [JsonPropertyName("scopes_supported")] + public string[] ScopesSupported { get; set; } = Array.Empty(); + + [JsonPropertyName("response_types_supported")] + public string[] ResponseTypesSupported { get; set; } = Array.Empty(); + + [JsonPropertyName("grant_types_supported")] + public string[] GrantTypesSupported { get; set; } = Array.Empty(); + + [JsonPropertyName("token_endpoint_auth_methods_supported")] + public string[] TokenEndpointAuthMethodsSupported { get; set; } = Array.Empty(); + + [JsonPropertyName("code_challenge_methods_supported")] + public string[] CodeChallengeMethodsSupported { get; set; } = Array.Empty(); + } +} diff --git a/src/GopherOrch/Auth/OAuth/ClientRegistrationResponse.cs b/src/GopherOrch/Auth/OAuth/ClientRegistrationResponse.cs new file mode 100644 index 00000000..9534d02f --- /dev/null +++ b/src/GopherOrch/Auth/OAuth/ClientRegistrationResponse.cs @@ -0,0 +1,36 @@ +using System; +using System.Text.Json.Serialization; + +namespace GopherOrch.Auth.OAuth +{ + /// + /// OAuth 2.0 Dynamic Client Registration response per RFC 7591. + /// + public class ClientRegistrationResponse + { + [JsonPropertyName("client_id")] + public string ClientId { get; set; } = ""; + + [JsonPropertyName("client_secret")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? ClientSecret { get; set; } + + [JsonPropertyName("client_id_issued_at")] + public long ClientIdIssuedAt { get; set; } + + [JsonPropertyName("client_secret_expires_at")] + public long ClientSecretExpiresAt { get; set; } + + [JsonPropertyName("redirect_uris")] + public string[] RedirectUris { get; set; } = Array.Empty(); + + [JsonPropertyName("grant_types")] + public string[] GrantTypes { get; set; } = Array.Empty(); + + [JsonPropertyName("response_types")] + public string[] ResponseTypes { get; set; } = Array.Empty(); + + [JsonPropertyName("token_endpoint_auth_method")] + public string TokenEndpointAuthMethod { get; set; } = "none"; + } +} diff --git a/src/GopherOrch/Auth/OAuth/OpenIdConfiguration.cs b/src/GopherOrch/Auth/OAuth/OpenIdConfiguration.cs new file mode 100644 index 00000000..f4617aeb --- /dev/null +++ b/src/GopherOrch/Auth/OAuth/OpenIdConfiguration.cs @@ -0,0 +1,22 @@ +using System; +using System.Text.Json.Serialization; + +namespace GopherOrch.Auth.OAuth +{ + /// + /// OpenID Connect Discovery configuration. + /// Extends AuthorizationServerMetadata with OIDC-specific fields. + /// + public class OpenIdConfiguration : AuthorizationServerMetadata + { + [JsonPropertyName("userinfo_endpoint")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? UserinfoEndpoint { get; set; } + + [JsonPropertyName("subject_types_supported")] + public string[] SubjectTypesSupported { get; set; } = Array.Empty(); + + [JsonPropertyName("id_token_signing_alg_values_supported")] + public string[] IdTokenSigningAlgValuesSupported { get; set; } = Array.Empty(); + } +} diff --git a/src/GopherOrch/Auth/OAuth/ProtectedResourceMetadata.cs b/src/GopherOrch/Auth/OAuth/ProtectedResourceMetadata.cs new file mode 100644 index 00000000..9f076a88 --- /dev/null +++ b/src/GopherOrch/Auth/OAuth/ProtectedResourceMetadata.cs @@ -0,0 +1,27 @@ +using System; +using System.Text.Json.Serialization; + +namespace GopherOrch.Auth.OAuth +{ + /// + /// OAuth 2.0 Protected Resource Metadata per RFC 9728. + /// + public class ProtectedResourceMetadata + { + [JsonPropertyName("resource")] + public string Resource { get; set; } = ""; + + [JsonPropertyName("authorization_servers")] + public string[] AuthorizationServers { get; set; } = Array.Empty(); + + [JsonPropertyName("scopes_supported")] + public string[] ScopesSupported { get; set; } = Array.Empty(); + + [JsonPropertyName("bearer_methods_supported")] + public string[] BearerMethodsSupported { get; set; } = Array.Empty(); + + [JsonPropertyName("resource_documentation")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? ResourceDocumentation { get; set; } + } +} diff --git a/third_party/gopher-orch b/third_party/gopher-orch index 6b45ffbb..c8e7c406 160000 --- a/third_party/gopher-orch +++ b/third_party/gopher-orch @@ -1 +1 @@ -Subproject commit 6b45ffbbee74d5ae034008fc2cb2a927f3131992 +Subproject commit c8e7c40606db330142632ecf90aaa8777bc42a3a