diff --git a/.gitmodules b/.gitmodules index d5cd4211..18a3014b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,4 @@ [submodule "third_party/gopher-orch"] path = third_party/gopher-orch url = https://github.com/GopherSecurity/gopher-orch.git + branch = br_release diff --git a/build.sh b/build.sh index 0b2f3162..924aa378 100755 --- a/build.sh +++ b/build.sh @@ -167,10 +167,34 @@ echo "" echo -e "${YELLOW}Step 5: Running tests...${NC}" mvn test -q 2>/dev/null && echo -e "${GREEN}✓ Tests passed${NC}" || echo -e "${YELLOW}⚠ Some tests may have failed (native library required)${NC}" -# Package JAR -echo -e "${YELLOW}Step 6: Packaging JAR...${NC}" -mvn package -q -DskipTests -echo -e "${GREEN}✓ JAR packaged successfully${NC}" +# Step 6: Package and install JAR to local Maven repo +echo -e "${YELLOW}Step 6: Packaging and installing JAR...${NC}" +mvn install -q -DskipTests -Dmaven.javadoc.skip=true +echo -e "${GREEN}✓ JAR packaged and installed to local Maven repo${NC}" +echo "" + +# Step 7: Build and test auth example +echo -e "${YELLOW}Step 7: Building and testing auth example...${NC}" +AUTH_EXAMPLE_DIR="${SCRIPT_DIR}/examples/auth" + +if [ -d "${AUTH_EXAMPLE_DIR}" ]; then + cd "${AUTH_EXAMPLE_DIR}" + + echo -e "${YELLOW} Compiling auth example...${NC}" + mvn compile -q + echo -e "${GREEN} ✓ Auth example compiled${NC}" + + echo -e "${YELLOW} Running auth example tests...${NC}" + mvn test -q 2>/dev/null && echo -e "${GREEN} ✓ Auth example tests passed${NC}" || echo -e "${YELLOW} ⚠ Some auth example tests may have failed${NC}" + + echo -e "${YELLOW} Packaging auth example...${NC}" + mvn package -q -DskipTests + echo -e "${GREEN} ✓ Auth example packaged${NC}" + + cd "${SCRIPT_DIR}" +else + echo -e "${YELLOW}⚠ Auth example not found at ${AUTH_EXAMPLE_DIR}${NC}" +fi echo "" echo -e "${GREEN}======================================${NC}" @@ -182,3 +206,4 @@ echo -e "Native headers: ${YELLOW}${NATIVE_INCLUDE_DIR}${NC}" echo -e "Run tests: ${YELLOW}mvn test${NC}" echo -e "Run example: ${YELLOW}mvn exec:java${NC}" echo -e "Package JAR: ${YELLOW}mvn package${NC}" +echo -e "Run auth example: ${YELLOW}cd examples/auth && ./run_example.sh --no-auth${NC}" diff --git a/examples/auth/README.md b/examples/auth/README.md new file mode 100644 index 00000000..d88b181d --- /dev/null +++ b/examples/auth/README.md @@ -0,0 +1,180 @@ +# Java Auth MCP Server + +An OAuth-protected MCP (Model Context Protocol) server demonstrating JWT token validation and scope-based access control. + +## Prerequisites + +- Java 17 or later +- Maven 3.8 or later + +## Building + +```bash +mvn package +``` + +## Running + +### Development Mode (No Auth) + +```bash +./run_example.sh --no-auth +``` + +### With Configuration File + +```bash +./run_example.sh --config server.config +``` + +### Direct Java Execution + +```bash +java -jar target/auth-mcp-server-1.0.0.jar server.config +``` + +## Configuration + +Create a `server.config` file with the following options: + +```ini +# Server settings +host=0.0.0.0 +port=3001 + +# OAuth/IDP settings +client_id=my-client +client_secret=my-secret +auth_server_url=https://keycloak.example.com/realms/mcp + +# Scopes +allowed_scopes=openid profile email mcp:read mcp:admin + +# Cache settings +jwks_cache_duration=3600 +jwks_auto_refresh=true +request_timeout=30 + +# Auth bypass mode (for development) +auth_disabled=true +``` + +## Endpoints + +### Health Check + +```bash +curl http://localhost:3001/health +``` + +### OAuth Discovery + +```bash +# Protected Resource Metadata (RFC 9728) +curl http://localhost:3001/.well-known/oauth-protected-resource + +# Authorization Server Metadata (RFC 8414) +curl http://localhost:3001/.well-known/oauth-authorization-server + +# OpenID Configuration +curl http://localhost:3001/.well-known/openid-configuration +``` + +### MCP Endpoints + +#### Initialize + +```bash +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -d '{"jsonrpc":"2.0","id":1,"method":"initialize"}' +``` + +#### List Tools + +```bash +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -d '{"jsonrpc":"2.0","id":2,"method":"tools/list"}' +``` + +#### Call Tool (get-weather) + +```bash +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -d '{ + "jsonrpc":"2.0", + "id":3, + "method":"tools/call", + "params":{ + "name":"get-weather", + "arguments":{"city":"London"} + } + }' +``` + +#### Call Tool (get-forecast) - Requires mcp:read scope + +```bash +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer YOUR_TOKEN" \ + -d '{ + "jsonrpc":"2.0", + "id":4, + "method":"tools/call", + "params":{ + "name":"get-forecast", + "arguments":{"city":"Tokyo"} + } + }' +``` + +#### Call Tool (get-weather-alerts) - Requires mcp:admin scope + +```bash +curl -X POST http://localhost:3001/mcp \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer YOUR_TOKEN" \ + -d '{ + "jsonrpc":"2.0", + "id":5, + "method":"tools/call", + "params":{ + "name":"get-weather-alerts", + "arguments":{"region":"California"} + } + }' +``` + +## Available Tools + +| Tool | Description | Required Scope | +|------|-------------|----------------| +| get-weather | Get current weather for a city | None | +| get-forecast | Get 5-day weather forecast | mcp:read | +| get-weather-alerts | Get weather alerts for a region | mcp:admin | + +## Authentication + +When authentication is enabled, protected endpoints require a valid JWT bearer token: + +```bash +curl -X POST http://localhost:3001/mcp \ + -H "Authorization: Bearer eyJhbGciOiJSUzI1NiJ9..." \ + -H "Content-Type: application/json" \ + -d '{"jsonrpc":"2.0","id":1,"method":"tools/list"}' +``` + +The token can also be passed as a query parameter: + +```bash +curl -X POST "http://localhost:3001/mcp?access_token=eyJhbGciOiJSUzI1NiJ9..." \ + -H "Content-Type: application/json" \ + -d '{"jsonrpc":"2.0","id":1,"method":"tools/list"}' +``` + +## License + +MIT License diff --git a/examples/auth/pom.xml b/examples/auth/pom.xml new file mode 100644 index 00000000..d76c561a --- /dev/null +++ b/examples/auth/pom.xml @@ -0,0 +1,131 @@ + + + 4.0.0 + + com.gophersecurity + auth-mcp-server + 1.0.0 + jar + + Java Auth MCP Server + OAuth-protected MCP server example with JWT validation and scope-based access control + + + 17 + 17 + UTF-8 + 6.1.3 + 2.16.1 + 2.0.12 + 5.10.2 + + + + + + com.gophersecurity + gopher-orch + 0.1.2 + + + + + io.javalin + javalin + ${javalin.version} + + + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + + + + + org.slf4j + slf4j-simple + ${slf4j.version} + + + + + org.junit.jupiter + junit-jupiter + ${junit.version} + test + + + + org.mockito + mockito-core + 5.10.0 + test + + + + org.mockito + mockito-junit-jupiter + 5.10.0 + test + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.12.1 + + 17 + 17 + + + + + org.apache.maven.plugins + maven-jar-plugin + 3.3.0 + + + + com.gophersecurity.mcp.auth.Application + + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.5.1 + + + package + + shade + + + + + com.gophersecurity.mcp.auth.Application + + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 3.2.5 + + + + diff --git a/examples/auth/run_example.sh b/examples/auth/run_example.sh new file mode 100755 index 00000000..5d74ac6c --- /dev/null +++ b/examples/auth/run_example.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# Java Auth MCP Server - Run Script +# +# Usage: +# ./run_example.sh # Run with default config (server.config) +# ./run_example.sh --no-auth # Run with auth disabled +# ./run_example.sh --config FILE # Run with custom config file + +set -e + +# Change to script directory +cd "$(dirname "$0")" + +# Parse arguments +CONFIG_FILE="server.config" +NO_AUTH=false + +while [[ $# -gt 0 ]]; do + case $1 in + --no-auth) + NO_AUTH=true + shift + ;; + --config) + CONFIG_FILE="$2" + shift 2 + ;; + -h|--help) + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Options:" + echo " --no-auth Run with authentication disabled" + echo " --config FILE Use custom configuration file" + echo " -h, --help Show this help message" + exit 0 + ;; + *) + echo "Unknown option: $1" + echo "Use --help for usage information" + exit 1 + ;; + esac +done + +# Always rebuild to pick up code changes +echo "Building project..." +mvn package -DskipTests -q + +# Create temporary config if --no-auth +if [ "$NO_AUTH" = true ]; then + TEMP_CONFIG=$(mktemp) + cat > "$TEMP_CONFIG" << 'EOF' +# Temporary config with auth disabled +host=0.0.0.0 +port=3001 +auth_disabled=true +allowed_scopes=mcp:read mcp:admin +EOF + CONFIG_FILE="$TEMP_CONFIG" + trap "rm -f $TEMP_CONFIG" EXIT +fi + +# Run the server +echo "Starting Java Auth MCP Server..." +java -jar target/auth-mcp-server-1.0.0.jar "$CONFIG_FILE" diff --git a/examples/auth/server.config b/examples/auth/server.config new file mode 100644 index 00000000..cd0188a5 --- /dev/null +++ b/examples/auth/server.config @@ -0,0 +1,75 @@ +# Java Auth MCP Server Configuration +# ==================================== +# INI-style configuration file for the OAuth-protected MCP server. +# Lines starting with # are comments. Empty lines are ignored. + +# ============================================================================= +# Server Settings +# ============================================================================= + +# Server bind address +# Use 0.0.0.0 to listen on all interfaces, or 127.0.0.1 for localhost only +host=0.0.0.0 + +# Server port number +port=3001 + +# Public server URL (used in OAuth metadata endpoints) +# If not specified, derived from host and port (with localhost substitution) +# server_url=https://marni-nightcapped-nonmeditatively.ngrok-free.dev + +# ============================================================================= +# OAuth/IDP Settings +# ============================================================================= + +# OAuth client credentials +client_id=oauth_0a650b79c5a64c3b920ae8c2b20599d9 +client_secret=6BiU2beUi2wIBxY3MUBLyYqoWKa4t0U_kJVm9mvSOKw +auth_server_url=https://auth-test.gopher.security/realms/gopher-mcp-auth +oauth_authorize_url=https://api-test.gopher.security/oauth/authorize + +# Base URL of the authorization server (e.g., Keycloak realm URL) +# When provided, the following endpoints are automatically derived: +# - jwks_uri: {auth_server_url}/protocol/openid-connect/certs +# - issuer: {auth_server_url} +# - oauth_authorize_url: {auth_server_url}/protocol/openid-connect/auth +# - oauth_token_url: {auth_server_url}/protocol/openid-connect/token + +# Direct OAuth endpoint URLs (optional, override derived values) +# jwks_uri=https://keycloak.example.com/realms/mcp/protocol/openid-connect/certs +# issuer=https://keycloak.example.com/realms/mcp +# oauth_authorize_url=https://keycloak.example.com/realms/mcp/protocol/openid-connect/auth +# oauth_token_url=https://keycloak.example.com/realms/mcp/protocol/openid-connect/token + +# ============================================================================= +# Scopes +# ============================================================================= + +# Space-separated list of allowed scopes for token validation +# Tools can require specific scopes for access control +exchange_idps=oauth-idp-714982830194556929-google +allowed_scopes=openid profile email scope-001 + +# ============================================================================= +# Cache Settings +# ============================================================================= + +# JWKS cache duration in seconds (how long to cache the JSON Web Key Set) +jwks_cache_duration=3600 + +# Whether to automatically refresh JWKS before expiration +# Values: true, false, 1, 0 +jwks_auto_refresh=true + +# HTTP request timeout in milliseconds for JWKS fetch and token validation +request_timeout=5000 + +# ============================================================================= +# Development Settings +# ============================================================================= + +# Auth bypass mode - disable authentication for development/testing +# When true, all requests are treated as authenticated with full scopes +# WARNING: Never enable in production! +# Values: true, false, 1, 0 +auth_disabled=false diff --git a/examples/auth/src/main/java/com/gophersecurity/mcp/auth/Application.java b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/Application.java new file mode 100644 index 00000000..4b809325 --- /dev/null +++ b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/Application.java @@ -0,0 +1,172 @@ +package com.gophersecurity.mcp.auth; + +import com.gophersecurity.orch.auth.GopherAuthClient; +import com.gophersecurity.mcp.auth.config.AuthServerConfig; +import com.gophersecurity.mcp.auth.middleware.CorsFilter; +import com.gophersecurity.mcp.auth.middleware.OAuthAuthMiddleware; +import com.gophersecurity.mcp.auth.routes.HealthHandler; +import com.gophersecurity.mcp.auth.routes.McpHandler; +import com.gophersecurity.mcp.auth.routes.OAuthEndpoints; +import com.gophersecurity.mcp.auth.tools.WeatherTools; +import io.javalin.Javalin; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Java Auth MCP Server Application. + * + * OAuth-protected MCP server with JWT validation and scope-based access control. + */ +public class Application { + + private static final Logger logger = LoggerFactory.getLogger(Application.class); + private static final String VERSION = "1.0.0"; + + public static void main(String[] args) { + // Print banner + printBanner(); + + // Load configuration + String configPath = args.length > 0 ? args[0] : "server.config"; + AuthServerConfig config = loadConfig(configPath); + + // Initialize auth client (if auth enabled) + final GopherAuthClient authClient = config.isAuthDisabled() ? null : initAuthClient(config); + + // Create components + OAuthAuthMiddleware authMiddleware = new OAuthAuthMiddleware(authClient, config); + OAuthEndpoints oauthEndpoints = new OAuthEndpoints(config); + McpHandler mcpHandler = new McpHandler(authMiddleware); + HealthHandler healthHandler = new HealthHandler(VERSION); + + // Register weather tools + WeatherTools.register(mcpHandler, authMiddleware); + + // Create Javalin app + Javalin app = Javalin.create(javalinConfig -> { + javalinConfig.showJavalinBanner = false; + }); + + // Global CORS handler + app.before(ctx -> CorsFilter.setCorsHeaders(ctx)); + + // Health endpoint + app.get("/health", healthHandler::handle); + + // OAuth endpoints + oauthEndpoints.registerRoutes(app); + + // Auth middleware for protected paths + app.before("/mcp", authMiddleware); + app.before("/mcp/*", authMiddleware); + app.before("/rpc", authMiddleware); + app.before("/rpc/*", authMiddleware); + + // MCP/RPC endpoints + app.options("/mcp", CorsFilter::handlePreflight); + app.options("/rpc", CorsFilter::handlePreflight); + app.post("/mcp", mcpHandler); + app.post("/rpc", mcpHandler); + + // Print endpoints + printEndpoints(config); + + // Print auth status + printAuthStatus(config, authClient); + + // Start server + app.start(config.getHost(), config.getPort()); + logger.info("Server started on {}:{}", config.getHost(), config.getPort()); + + // Shutdown hook + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + logger.info("Shutting down..."); + app.stop(); + if (authClient != null) { + // authClient.destroy(); // If destroy method exists + } + System.out.println("Goodbye!"); + })); + } + + private static void printBanner() { + System.out.println(); + System.out.println("╔══════════════════════════════════════╗"); + System.out.println("║ Java Auth MCP Server ║"); + System.out.println("║ Version " + VERSION + " ║"); + System.out.println("╚══════════════════════════════════════╝"); + System.out.println(); + } + + private static AuthServerConfig loadConfig(String path) { + try { + AuthServerConfig config = AuthServerConfig.fromFile(path); + logger.info("Loaded configuration from {}", path); + return config; + } catch (Exception e) { + logger.warn("Failed to load config from {}: {}. Using defaults with auth disabled.", + path, e.getMessage()); + return AuthServerConfig.defaultDisabled(); + } + } + + private static GopherAuthClient initAuthClient(AuthServerConfig config) { + try { + // TODO: Initialize actual GopherAuthClient via FFI + // For now, return null and let middleware handle it + logger.info("Auth client initialization skipped (FFI not available)"); + return null; + } catch (Exception e) { + logger.warn("Failed to initialize auth client: {}. Continuing without auth.", + e.getMessage()); + return null; + } + } + + private static void printEndpoints(AuthServerConfig config) { + String baseUrl = config.getServerUrl(); + + System.out.println("Available Endpoints:"); + System.out.println("────────────────────────────────────────"); + System.out.println(" Health: GET " + baseUrl + "/health"); + System.out.println(); + System.out.println(" OAuth Discovery:"); + System.out.println(" GET " + baseUrl + "/.well-known/oauth-protected-resource"); + System.out.println(" GET " + baseUrl + "/.well-known/oauth-authorization-server"); + System.out.println(" GET " + baseUrl + "/.well-known/openid-configuration"); + System.out.println(); + System.out.println(" OAuth Endpoints:"); + System.out.println(" GET " + baseUrl + "/oauth/authorize"); + System.out.println(" POST " + baseUrl + "/oauth/register"); + System.out.println(); + System.out.println(" MCP Endpoints:"); + System.out.println(" POST " + baseUrl + "/mcp"); + System.out.println(" POST " + baseUrl + "/rpc"); + System.out.println("────────────────────────────────────────"); + System.out.println(); + } + + private static void printAuthStatus(AuthServerConfig config, GopherAuthClient authClient) { + System.out.println("Authentication Status:"); + System.out.println("────────────────────────────────────────"); + + if (config.isAuthDisabled()) { + System.out.println(" Status: DISABLED (development mode)"); + System.out.println(" All protected endpoints are accessible without tokens"); + } else if (authClient == null) { + System.out.println(" Status: ENABLED (OAuth flow, no validation)"); + System.out.println(" Missing token: 401 (triggers OAuth flow)"); + System.out.println(" With token: allowed (validation skipped)"); + System.out.println(" JWKS URI: " + config.getJwksUri()); + System.out.println(" Issuer: " + config.getIssuer()); + } else { + System.out.println(" Status: ENABLED"); + System.out.println(" JWKS URI: " + config.getJwksUri()); + System.out.println(" Issuer: " + config.getIssuer()); + System.out.println(" Scopes: " + config.getAllowedScopes()); + } + + System.out.println("────────────────────────────────────────"); + System.out.println(); + } +} diff --git a/examples/auth/src/main/java/com/gophersecurity/mcp/auth/config/AuthServerConfig.java b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/config/AuthServerConfig.java new file mode 100644 index 00000000..40b19937 --- /dev/null +++ b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/config/AuthServerConfig.java @@ -0,0 +1,353 @@ +package com.gophersecurity.mcp.auth.config; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +/** + * Configuration for the OAuth-protected MCP server. + * + * Supports INI-style configuration files with automatic endpoint derivation + * from the auth_server_url setting. + */ +public class AuthServerConfig { + + // Server settings + private String host = "0.0.0.0"; + private int port = 3001; + private String serverUrl = "http://localhost:3001"; + + // OAuth/IDP settings + private String authServerUrl = ""; + private String jwksUri = ""; + private String issuer = ""; + private String clientId = ""; + private String clientSecret = ""; + private String oauthAuthorizeUrl = ""; + private String oauthTokenUrl = ""; + + // Scopes + private String allowedScopes = "mcp:read mcp:admin"; + + // Cache settings + private int jwksCacheDuration = 3600; + private boolean jwksAutoRefresh = true; + private int requestTimeout = 5000; + + // Auth bypass + private boolean authDisabled = false; + + /** + * Default constructor with default values. + */ + public AuthServerConfig() { + } + + /** + * Create a configuration with authentication disabled. + * + * Useful for development and testing. + * + * @return config with authDisabled=true + */ + public static AuthServerConfig defaultDisabled() { + AuthServerConfig config = new AuthServerConfig(); + config.authDisabled = true; + return config; + } + + /** + * Build configuration from a key-value map. + * + * When auth_server_url is provided, automatically derives: + * - jwks_uri: {auth_server_url}/protocol/openid-connect/certs + * - issuer: {auth_server_url} + * - oauth_authorize_url: {auth_server_url}/protocol/openid-connect/auth + * - oauth_token_url: {auth_server_url}/protocol/openid-connect/token + * + * @param map configuration key-value pairs + * @return configured instance + */ + public static AuthServerConfig buildFromMap(Map map) { + AuthServerConfig config = new AuthServerConfig(); + + // Server settings + config.host = map.getOrDefault("host", config.host); + config.port = parseIntOrDefault(map.get("port"), config.port); + + // Derive server_url with localhost substitution for 0.0.0.0 + String displayHost = config.host.equals("0.0.0.0") ? "localhost" : config.host; + config.serverUrl = map.getOrDefault("server_url", + String.format("http://%s:%d", displayHost, config.port)); + + // OAuth settings with endpoint derivation + config.authServerUrl = map.getOrDefault("auth_server_url", ""); + + if (!config.authServerUrl.isEmpty()) { + config.jwksUri = map.getOrDefault("jwks_uri", + config.authServerUrl + "/protocol/openid-connect/certs"); + config.issuer = map.getOrDefault("issuer", config.authServerUrl); + config.oauthAuthorizeUrl = map.getOrDefault("oauth_authorize_url", + config.authServerUrl + "/protocol/openid-connect/auth"); + config.oauthTokenUrl = map.getOrDefault("oauth_token_url", + config.authServerUrl + "/protocol/openid-connect/token"); + } else { + config.jwksUri = map.getOrDefault("jwks_uri", ""); + config.issuer = map.getOrDefault("issuer", ""); + config.oauthAuthorizeUrl = map.getOrDefault("oauth_authorize_url", ""); + config.oauthTokenUrl = map.getOrDefault("oauth_token_url", ""); + } + + config.clientId = map.getOrDefault("client_id", ""); + config.clientSecret = map.getOrDefault("client_secret", ""); + config.allowedScopes = map.getOrDefault("allowed_scopes", config.allowedScopes); + + // Cache settings + config.jwksCacheDuration = parseIntOrDefault(map.get("jwks_cache_duration"), + config.jwksCacheDuration); + config.jwksAutoRefresh = parseBooleanOrDefault(map.get("jwks_auto_refresh"), + config.jwksAutoRefresh); + config.requestTimeout = parseIntOrDefault(map.get("request_timeout"), + config.requestTimeout); + + // Auth bypass + config.authDisabled = parseBooleanOrDefault(map.get("auth_disabled"), + config.authDisabled); + + return config; + } + + /** + * Validate the configuration. + * + * When authentication is enabled, validates that required fields are present: + * - client_id is not empty + * - client_secret is not empty + * - jwks_uri is not empty + * + * Validation is skipped when authDisabled is true. + * + * @throws IllegalArgumentException if validation fails + */ + public void validate() throws IllegalArgumentException { + if (authDisabled) { + return; + } + + if (clientId == null || clientId.isEmpty()) { + throw new IllegalArgumentException( + "client_id is required when authentication is enabled"); + } + + if (clientSecret == null || clientSecret.isEmpty()) { + throw new IllegalArgumentException( + "client_secret is required when authentication is enabled"); + } + + if (jwksUri == null || jwksUri.isEmpty()) { + throw new IllegalArgumentException( + "jwks_uri is required when authentication is enabled " + + "(provide jwks_uri or auth_server_url)"); + } + } + + /** + * Load configuration from an INI-style file. + * + * Reads the file, parses it, builds the config, and validates it. + * + * @param path path to the configuration file + * @return validated configuration + * @throws IOException if the file cannot be read + * @throws IllegalArgumentException if validation fails + */ + public static AuthServerConfig fromFile(String path) throws IOException { + Map map = parseConfigFile(path); + AuthServerConfig config = buildFromMap(map); + config.validate(); + return config; + } + + private static int parseIntOrDefault(String value, int defaultValue) { + if (value == null || value.isEmpty()) { + return defaultValue; + } + try { + return Integer.parseInt(value); + } catch (NumberFormatException e) { + return defaultValue; + } + } + + private static boolean parseBooleanOrDefault(String value, boolean defaultValue) { + if (value == null || value.isEmpty()) { + return defaultValue; + } + return "true".equalsIgnoreCase(value) || "1".equals(value); + } + + /** + * Parse an INI-style configuration file. + * + * Handles: + * - Comments (lines starting with #) + * - Empty lines + * - Values containing = characters (splits only on first =) + * - Whitespace trimming for keys and values + * + * @param path path to the configuration file + * @return map of configuration key-value pairs + * @throws IOException if the file cannot be read + */ + public static Map parseConfigFile(String path) throws IOException { + Map map = new HashMap<>(); + + for (String line : Files.readAllLines(Path.of(path))) { + String trimmed = line.trim(); + + // Skip empty lines and comments + if (trimmed.isEmpty() || trimmed.startsWith("#")) { + continue; + } + + // Split on first '=' only to handle values containing '=' + int eqIndex = trimmed.indexOf('='); + if (eqIndex > 0) { + String key = trimmed.substring(0, eqIndex).trim(); + String value = trimmed.substring(eqIndex + 1).trim(); + + if (!key.isEmpty()) { + map.put(key, value); + } + } + } + + return map; + } + + // Getters + + public String getHost() { + return host; + } + + public int getPort() { + return port; + } + + public String getServerUrl() { + return serverUrl; + } + + public String getAuthServerUrl() { + return authServerUrl; + } + + public String getJwksUri() { + return jwksUri; + } + + public String getIssuer() { + return issuer; + } + + public String getClientId() { + return clientId; + } + + public String getClientSecret() { + return clientSecret; + } + + public String getOauthAuthorizeUrl() { + return oauthAuthorizeUrl; + } + + public String getOauthTokenUrl() { + return oauthTokenUrl; + } + + public String getAllowedScopes() { + return allowedScopes; + } + + public int getJwksCacheDuration() { + return jwksCacheDuration; + } + + public boolean isJwksAutoRefresh() { + return jwksAutoRefresh; + } + + public int getRequestTimeout() { + return requestTimeout; + } + + public boolean isAuthDisabled() { + return authDisabled; + } + + // Package-private setters for buildFromMap + + void setHost(String host) { + this.host = host; + } + + void setPort(int port) { + this.port = port; + } + + void setServerUrl(String serverUrl) { + this.serverUrl = serverUrl; + } + + void setAuthServerUrl(String authServerUrl) { + this.authServerUrl = authServerUrl; + } + + void setJwksUri(String jwksUri) { + this.jwksUri = jwksUri; + } + + void setIssuer(String issuer) { + this.issuer = issuer; + } + + void setClientId(String clientId) { + this.clientId = clientId; + } + + void setClientSecret(String clientSecret) { + this.clientSecret = clientSecret; + } + + void setOauthAuthorizeUrl(String oauthAuthorizeUrl) { + this.oauthAuthorizeUrl = oauthAuthorizeUrl; + } + + void setOauthTokenUrl(String oauthTokenUrl) { + this.oauthTokenUrl = oauthTokenUrl; + } + + void setAllowedScopes(String allowedScopes) { + this.allowedScopes = allowedScopes; + } + + void setJwksCacheDuration(int jwksCacheDuration) { + this.jwksCacheDuration = jwksCacheDuration; + } + + void setJwksAutoRefresh(boolean jwksAutoRefresh) { + this.jwksAutoRefresh = jwksAutoRefresh; + } + + void setRequestTimeout(int requestTimeout) { + this.requestTimeout = requestTimeout; + } + + void setAuthDisabled(boolean authDisabled) { + this.authDisabled = authDisabled; + } +} diff --git a/examples/auth/src/main/java/com/gophersecurity/mcp/auth/middleware/CorsFilter.java b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/middleware/CorsFilter.java new file mode 100644 index 00000000..74955047 --- /dev/null +++ b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/middleware/CorsFilter.java @@ -0,0 +1,45 @@ +package com.gophersecurity.mcp.auth.middleware; + +import io.javalin.http.Context; + +/** + * CORS filter for MCP server. + * Sets appropriate headers for cross-origin requests. + */ +public class CorsFilter { + + private static final String ALLOWED_METHODS = + "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD"; + + private static final String ALLOWED_HEADERS = + "Accept, Accept-Language, Content-Language, Content-Type, Authorization, " + + "X-Requested-With, Origin, Cache-Control, Pragma, Mcp-Session-Id, Mcp-Protocol-Version"; + + private static final String EXPOSED_HEADERS = + "WWW-Authenticate, Content-Length, Content-Type"; + + private static final String MAX_AGE = "86400"; + + /** + * Set CORS headers on the response. + * + * @param ctx Javalin context + */ + public static void setCorsHeaders(Context ctx) { + ctx.header("Access-Control-Allow-Origin", "*"); + ctx.header("Access-Control-Allow-Methods", ALLOWED_METHODS); + ctx.header("Access-Control-Allow-Headers", ALLOWED_HEADERS); + ctx.header("Access-Control-Expose-Headers", EXPOSED_HEADERS); + ctx.header("Access-Control-Max-Age", MAX_AGE); + } + + /** + * Handle CORS preflight request. + * + * @param ctx Javalin context + */ + public static void handlePreflight(Context ctx) { + setCorsHeaders(ctx); + ctx.status(204); + } +} diff --git a/examples/auth/src/main/java/com/gophersecurity/mcp/auth/middleware/OAuthAuthMiddleware.java b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/middleware/OAuthAuthMiddleware.java new file mode 100644 index 00000000..34905b5b --- /dev/null +++ b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/middleware/OAuthAuthMiddleware.java @@ -0,0 +1,302 @@ +package com.gophersecurity.mcp.auth.middleware; + +import com.gophersecurity.orch.auth.GopherAuthClient; +import com.gophersecurity.orch.auth.TokenPayload; +import com.gophersecurity.orch.auth.ValidationResult; +import com.gophersecurity.mcp.auth.config.AuthServerConfig; +import com.gophersecurity.orch.auth.AuthContext; +import io.javalin.http.Context; +import io.javalin.http.Handler; + +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * OAuth authentication middleware for JWT token validation. + * + * Validates bearer tokens on protected endpoints and maintains + * the current authentication context. + */ +public class OAuthAuthMiddleware implements Handler { + + private static final List PUBLIC_PATHS = Arrays.asList( + "/health", + "/.well-known/", + "/oauth/", + "/favicon.ico" + ); + + private static final List PROTECTED_PREFIXES = Arrays.asList( + "/mcp", + "/rpc", + "/events", + "/sse" + ); + + private final GopherAuthClient authClient; + private final AuthServerConfig config; + private final ThreadLocal currentAuthContext; + + /** + * Create OAuth middleware. + * + * @param authClient GopherAuthClient for token validation (may be null) + * @param config server configuration + */ + public OAuthAuthMiddleware(GopherAuthClient authClient, AuthServerConfig config) { + this.authClient = authClient; + this.config = config; + this.currentAuthContext = ThreadLocal.withInitial(AuthContext::empty); + } + + /** + * Extract bearer token from request. + * + * Checks Authorization header first (case-insensitive "Bearer" prefix), + * then falls back to "access_token" query parameter. + * + * @param ctx Javalin context + * @return token string or null if not found + */ + public String extractToken(Context ctx) { + // Try Authorization header first + String authHeader = ctx.header("Authorization"); + if (authHeader != null && authHeader.length() > 7) { + String prefix = authHeader.substring(0, 7); + if (prefix.equalsIgnoreCase("Bearer ")) { + return authHeader.substring(7); + } + } + + // Try query parameter + String queryToken = ctx.queryParam("access_token"); + if (queryToken != null && !queryToken.isEmpty()) { + return queryToken; + } + + return null; + } + + /** + * Get the current authentication context. + * + * @return current auth context or empty context if not set + */ + public AuthContext getAuthContext() { + AuthContext ctx = currentAuthContext.get(); + return ctx != null ? ctx : AuthContext.empty(); + } + + /** + * Set the current authentication context. + * + * @param context auth context to set + */ + protected void setAuthContext(AuthContext context) { + currentAuthContext.set(context); + } + + /** + * Check if authentication is disabled. + * + * @return true if auth is disabled in config + */ + public boolean isAuthDisabled() { + return config.isAuthDisabled(); + } + + /** + * Get the auth client. + * + * @return auth client or null + */ + protected GopherAuthClient getAuthClient() { + return authClient; + } + + /** + * Get the configuration. + * + * @return server configuration + */ + protected AuthServerConfig getConfig() { + return config; + } + + /** + * Check if a path is public (no auth required). + * + * @param path request path + * @return true if path is public + */ + public boolean isPublicPath(String path) { + if (path == null) { + return false; + } + for (String publicPath : PUBLIC_PATHS) { + if (path.equals(publicPath) || path.startsWith(publicPath)) { + return true; + } + } + return false; + } + + /** + * Check if a path requires authentication. + * + * Returns false if: + * - Auth is disabled in config + * - Path is public + * + * Returns true if path matches a protected prefix. + * Returns false for unknown paths (default open). + * + * @param path request path + * @return true if authentication is required + */ + public boolean requiresAuth(String path) { + // Auth globally disabled + if (config.isAuthDisabled()) { + return false; + } + + // Public paths don't require auth + if (isPublicPath(path)) { + return false; + } + + // Check protected prefixes + if (path != null) { + for (String prefix : PROTECTED_PREFIXES) { + if (path.equals(prefix) || path.startsWith(prefix + "/")) { + return true; + } + } + } + + // Unknown paths don't require auth by default + return false; + } + + /** + * Send 401 Unauthorized response with WWW-Authenticate header. + * + * @param ctx Javalin context + * @param error OAuth error code + * @param description human-readable error description + */ + public void sendUnauthorized(Context ctx, String error, String description) { + // Build WWW-Authenticate header + String wwwAuthenticate = String.format( + "Bearer realm=\"%s\", resource_metadata=\"%s/.well-known/oauth-protected-resource\", " + + "scope=\"%s\", error=\"%s\", error_description=\"%s\"", + config.getServerUrl(), + config.getServerUrl(), + config.getAllowedScopes(), + escapeHeaderValue(error), + escapeHeaderValue(description) + ); + + // Set CORS headers + CorsFilter.setCorsHeaders(ctx); + + // Build response body + Map body = new LinkedHashMap<>(); + body.put("error", error); + body.put("error_description", description); + + // Send response + ctx.status(401); + ctx.header("WWW-Authenticate", wwwAuthenticate); + ctx.contentType("application/json"); + ctx.json(body); + } + + /** + * Escape special characters in header values. + * + * @param value value to escape + * @return escaped value + */ + private String escapeHeaderValue(String value) { + if (value == null) { + return ""; + } + // Escape backslashes and quotes + return value.replace("\\", "\\\\").replace("\"", "\\\""); + } + + /** + * Handle incoming request as Javalin middleware. + * + * Validates authentication for protected paths and sets the auth context. + * + * @param ctx Javalin context + * @throws Exception if request handling fails + */ + @Override + public void handle(Context ctx) throws Exception { + String path = ctx.path(); + + // Public paths - set empty context and proceed + if (isPublicPath(path)) { + setAuthContext(AuthContext.empty()); + return; + } + + // Auth disabled - set anonymous context and proceed + if (config.isAuthDisabled()) { + setAuthContext(AuthContext.anonymous(config.getAllowedScopes())); + return; + } + + // Extract token + String token = extractToken(ctx); + if (token == null) { + sendUnauthorized(ctx, "invalid_request", "Missing bearer token"); + return; + } + + // If auth client is not available, trust the token and allow request + // (OAuth flow completed, but native validation is not available) + if (authClient == null) { + setAuthContext(AuthContext.anonymous(config.getAllowedScopes())); + return; + } + + // Validate token + ValidationResult result = authClient.validateToken(token, 30); + if (!result.isValid()) { + sendUnauthorized(ctx, "invalid_token", result.getErrorMessage()); + return; + } + + // Extract payload and create auth context + try { + TokenPayload payload = authClient.extractPayload(token); + AuthContext authContext = new AuthContext( + payload.getSubject(), + payload.getScopes(), + payload.getAudience(), + payload.getExpiration(), + true + ); + setAuthContext(authContext); + } catch (Exception e) { + // Payload extraction failed but token is valid + setAuthContext(new AuthContext("", "", "", 0, true)); + } + } + + /** + * Check if the current auth context has a required scope. + * + * @param requiredScope scope to check + * @return true if scope is present + */ + public boolean hasScope(String requiredScope) { + return getAuthContext().hasScope(requiredScope); + } +} diff --git a/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/JsonRpcError.java b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/JsonRpcError.java new file mode 100644 index 00000000..fd4cc391 --- /dev/null +++ b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/JsonRpcError.java @@ -0,0 +1,131 @@ +package com.gophersecurity.mcp.auth.model; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * JSON-RPC 2.0 error object. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class JsonRpcError { + + // Standard JSON-RPC 2.0 error codes + public static final int PARSE_ERROR = -32700; + public static final int INVALID_REQUEST = -32600; + public static final int METHOD_NOT_FOUND = -32601; + public static final int INVALID_PARAMS = -32602; + public static final int INTERNAL_ERROR = -32603; + + @JsonProperty("code") + private int code; + + @JsonProperty("message") + private String message; + + @JsonProperty("data") + private Object data; + + /** + * Default constructor for Jackson. + */ + public JsonRpcError() { + } + + /** + * Create an error with code and message. + * + * @param code error code + * @param message error message + */ + public JsonRpcError(int code, String message) { + this.code = code; + this.message = message; + } + + /** + * Create an error with code, message, and data. + * + * @param code error code + * @param message error message + * @param data additional error data + */ + public JsonRpcError(int code, String message, Object data) { + this.code = code; + this.message = message; + this.data = data; + } + + /** + * Create a parse error. + * + * @param data additional error details + * @return parse error + */ + public static JsonRpcError parseError(String data) { + return new JsonRpcError(PARSE_ERROR, "Parse error", data); + } + + /** + * Create an invalid request error. + * + * @param data additional error details + * @return invalid request error + */ + public static JsonRpcError invalidRequest(String data) { + return new JsonRpcError(INVALID_REQUEST, "Invalid Request", data); + } + + /** + * Create a method not found error. + * + * @param data additional error details + * @return method not found error + */ + public static JsonRpcError methodNotFound(String data) { + return new JsonRpcError(METHOD_NOT_FOUND, "Method not found", data); + } + + /** + * Create an invalid params error. + * + * @param data additional error details + * @return invalid params error + */ + public static JsonRpcError invalidParams(String data) { + return new JsonRpcError(INVALID_PARAMS, "Invalid params", data); + } + + /** + * Create an internal error. + * + * @param data additional error details + * @return internal error + */ + public static JsonRpcError internalError(String data) { + return new JsonRpcError(INTERNAL_ERROR, "Internal error", data); + } + + public int getCode() { + return code; + } + + public void setCode(int code) { + this.code = code; + } + + public String getMessage() { + return message; + } + + public void setMessage(String message) { + this.message = message; + } + + public Object getData() { + return data; + } + + public void setData(Object data) { + this.data = data; + } +} diff --git a/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/JsonRpcRequest.java b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/JsonRpcRequest.java new file mode 100644 index 00000000..156d97b1 --- /dev/null +++ b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/JsonRpcRequest.java @@ -0,0 +1,78 @@ +package com.gophersecurity.mcp.auth.model; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Map; + +/** + * JSON-RPC 2.0 request object. + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public class JsonRpcRequest { + + @JsonProperty("jsonrpc") + private String jsonrpc; + + @JsonProperty("id") + private Object id; + + @JsonProperty("method") + private String method; + + @JsonProperty("params") + private Map params; + + /** + * Default constructor for Jackson. + */ + public JsonRpcRequest() { + } + + /** + * Create a request with all fields. + * + * @param jsonrpc protocol version (should be "2.0") + * @param id request identifier + * @param method method name + * @param params method parameters + */ + public JsonRpcRequest(String jsonrpc, Object id, String method, Map params) { + this.jsonrpc = jsonrpc; + this.id = id; + this.method = method; + this.params = params; + } + + public String getJsonrpc() { + return jsonrpc; + } + + public void setJsonrpc(String jsonrpc) { + this.jsonrpc = jsonrpc; + } + + public Object getId() { + return id; + } + + public void setId(Object id) { + this.id = id; + } + + public String getMethod() { + return method; + } + + public void setMethod(String method) { + this.method = method; + } + + public Map getParams() { + return params; + } + + public void setParams(Map params) { + this.params = params; + } +} diff --git a/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/JsonRpcResponse.java b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/JsonRpcResponse.java new file mode 100644 index 00000000..2634b24a --- /dev/null +++ b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/JsonRpcResponse.java @@ -0,0 +1,89 @@ +package com.gophersecurity.mcp.auth.model; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * JSON-RPC 2.0 response object. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class JsonRpcResponse { + + @JsonProperty("jsonrpc") + private String jsonrpc = "2.0"; + + @JsonProperty("id") + private Object id; + + @JsonProperty("result") + private Object result; + + @JsonProperty("error") + private JsonRpcError error; + + /** + * Default constructor for Jackson. + */ + public JsonRpcResponse() { + } + + /** + * Create a success response. + * + * @param id request identifier + * @param result method result + * @return success response + */ + public static JsonRpcResponse success(Object id, Object result) { + JsonRpcResponse response = new JsonRpcResponse(); + response.id = id; + response.result = result; + return response; + } + + /** + * Create an error response. + * + * @param id request identifier + * @param error error object + * @return error response + */ + public static JsonRpcResponse error(Object id, JsonRpcError error) { + JsonRpcResponse response = new JsonRpcResponse(); + response.id = id; + response.error = error; + return response; + } + + public String getJsonrpc() { + return jsonrpc; + } + + public void setJsonrpc(String jsonrpc) { + this.jsonrpc = jsonrpc; + } + + public Object getId() { + return id; + } + + public void setId(Object id) { + this.id = id; + } + + public Object getResult() { + return result; + } + + public void setResult(Object result) { + this.result = result; + } + + public JsonRpcError getError() { + return error; + } + + public void setError(JsonRpcError error) { + this.error = error; + } +} diff --git a/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/ToolContent.java b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/ToolContent.java new file mode 100644 index 00000000..4a6ba7ba --- /dev/null +++ b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/ToolContent.java @@ -0,0 +1,89 @@ +package com.gophersecurity.mcp.auth.model; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Content item for tool results. + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ToolContent { + + @JsonProperty("type") + private String type; + + @JsonProperty("text") + private String text; + + @JsonProperty("data") + private String data; + + @JsonProperty("mimeType") + private String mimeType; + + /** + * Default constructor for Jackson. + */ + public ToolContent() { + } + + /** + * Create a text content item. + * + * @param text text content + * @return text content item + */ + public static ToolContent text(String text) { + ToolContent content = new ToolContent(); + content.type = "text"; + content.text = text; + return content; + } + + /** + * Create an image content item. + * + * @param base64Data base64-encoded image data + * @param mimeType image MIME type (e.g., "image/png") + * @return image content item + */ + public static ToolContent image(String base64Data, String mimeType) { + ToolContent content = new ToolContent(); + content.type = "image"; + content.data = base64Data; + content.mimeType = mimeType; + return content; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public String getText() { + return text; + } + + public void setText(String text) { + this.text = text; + } + + public String getData() { + return data; + } + + public void setData(String data) { + this.data = data; + } + + public String getMimeType() { + return mimeType; + } + + public void setMimeType(String mimeType) { + this.mimeType = mimeType; + } +} diff --git a/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/ToolResult.java b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/ToolResult.java new file mode 100644 index 00000000..edc2e8a5 --- /dev/null +++ b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/ToolResult.java @@ -0,0 +1,70 @@ +package com.gophersecurity.mcp.auth.model; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * Result from a tool execution. + */ +public class ToolResult { + + private final List content; + private final boolean isError; + + /** + * Create a tool result. + * + * @param content list of content items + * @param isError whether this is an error result + */ + public ToolResult(List content, boolean isError) { + this.content = content; + this.isError = isError; + } + + /** + * Create a successful text result. + * + * @param text text content + * @return text result + */ + public static ToolResult text(String text) { + List content = new ArrayList<>(); + content.add(ToolContent.text(text)); + return new ToolResult(content, false); + } + + /** + * Create an error result. + * + * @param message error message + * @return error result + */ + public static ToolResult error(String message) { + List content = new ArrayList<>(); + content.add(ToolContent.text(message)); + return new ToolResult(content, true); + } + + /** + * Convert to a map for JSON serialization. + * + * @return map representation + */ + public Map toMap() { + Map map = new LinkedHashMap<>(); + map.put("content", content); + map.put("isError", isError); + return map; + } + + public List getContent() { + return content; + } + + public boolean isError() { + return isError; + } +} diff --git a/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/ToolSpec.java b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/ToolSpec.java new file mode 100644 index 00000000..0e015c9a --- /dev/null +++ b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/model/ToolSpec.java @@ -0,0 +1,52 @@ +package com.gophersecurity.mcp.auth.model; + +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * MCP tool specification. + */ +public class ToolSpec { + + private final String name; + private final String description; + private final Map inputSchema; + + /** + * Create a tool specification. + * + * @param name tool name + * @param description tool description + * @param inputSchema JSON schema for tool input + */ + public ToolSpec(String name, String description, Map inputSchema) { + this.name = name; + this.description = description; + this.inputSchema = inputSchema; + } + + /** + * Convert to a map for JSON serialization. + * + * @return map representation + */ + public Map toMap() { + Map map = new LinkedHashMap<>(); + map.put("name", name); + map.put("description", description); + map.put("inputSchema", inputSchema); + return map; + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } + + public Map getInputSchema() { + return inputSchema; + } +} diff --git a/examples/auth/src/main/java/com/gophersecurity/mcp/auth/routes/HealthHandler.java b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/routes/HealthHandler.java new file mode 100644 index 00000000..e4b796dc --- /dev/null +++ b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/routes/HealthHandler.java @@ -0,0 +1,49 @@ +package com.gophersecurity.mcp.auth.routes; + +import io.javalin.http.Context; + +import java.time.Instant; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Health check endpoint handler. + */ +public class HealthHandler { + + private final String version; + + /** + * Create a health handler without version. + */ + public HealthHandler() { + this.version = null; + } + + /** + * Create a health handler with version. + * + * @param version application version + */ + public HealthHandler(String version) { + this.version = version; + } + + /** + * Handle health check request. + * + * @param ctx Javalin context + */ + public void handle(Context ctx) { + Map response = new LinkedHashMap<>(); + response.put("status", "ok"); + response.put("timestamp", Instant.now().toString()); + + if (version != null && !version.isEmpty()) { + response.put("version", version); + } + + ctx.contentType("application/json"); + ctx.json(response); + } +} diff --git a/examples/auth/src/main/java/com/gophersecurity/mcp/auth/routes/McpHandler.java b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/routes/McpHandler.java new file mode 100644 index 00000000..ab00fa41 --- /dev/null +++ b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/routes/McpHandler.java @@ -0,0 +1,226 @@ +package com.gophersecurity.mcp.auth.routes; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.gophersecurity.mcp.auth.middleware.CorsFilter; +import com.gophersecurity.mcp.auth.middleware.OAuthAuthMiddleware; +import com.gophersecurity.mcp.auth.model.JsonRpcError; +import com.gophersecurity.mcp.auth.model.JsonRpcRequest; +import com.gophersecurity.mcp.auth.model.JsonRpcResponse; +import com.gophersecurity.mcp.auth.model.ToolResult; +import com.gophersecurity.mcp.auth.model.ToolSpec; +import io.javalin.http.Context; +import io.javalin.http.Handler; + +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * MCP (Model Context Protocol) JSON-RPC handler. + */ +public class McpHandler implements Handler { + + // JSON-RPC error codes + public static final int PARSE_ERROR = -32700; + public static final int INVALID_REQUEST = -32600; + public static final int METHOD_NOT_FOUND = -32601; + public static final int INVALID_PARAMS = -32602; + public static final int INTERNAL_ERROR = -32603; + + private final ObjectMapper mapper = new ObjectMapper(); + private final Map tools = new HashMap<>(); + private final Map, Context, ToolResult>> toolHandlers = new HashMap<>(); + private final OAuthAuthMiddleware authMiddleware; + + /** + * Create MCP handler. + * + * @param authMiddleware authentication middleware + */ + public McpHandler(OAuthAuthMiddleware authMiddleware) { + this.authMiddleware = authMiddleware; + } + + /** + * Handle incoming JSON-RPC request. + * + * @param ctx Javalin context + * @throws Exception if request handling fails + */ + @Override + public void handle(Context ctx) throws Exception { + CorsFilter.setCorsHeaders(ctx); + + // Parse request + JsonRpcRequest request; + try { + request = mapper.readValue(ctx.body(), JsonRpcRequest.class); + } catch (Exception e) { + sendError(ctx, null, PARSE_ERROR, "Parse error", e.getMessage()); + return; + } + + // Validate JSON-RPC version + if (!"2.0".equals(request.getJsonrpc())) { + sendError(ctx, request.getId(), INVALID_REQUEST, "Invalid Request", + "jsonrpc must be \"2.0\""); + return; + } + + // Route by method + String method = request.getMethod(); + Object result; + + try { + switch (method) { + case "initialize": + result = handleInitialize(); + break; + case "tools/list": + result = handleToolsList(); + break; + case "tools/call": + result = handleToolsCall(request.getParams(), ctx); + break; + case "ping": + result = Map.of(); + break; + default: + sendError(ctx, request.getId(), METHOD_NOT_FOUND, "Method not found", + "Unknown method: " + method); + return; + } + + sendSuccess(ctx, request.getId(), result); + } catch (Exception e) { + sendError(ctx, request.getId(), INTERNAL_ERROR, "Internal error", e.getMessage()); + } + } + + /** + * Handle initialize request. + * + * @return initialization response + */ + public Map handleInitialize() { + Map response = new LinkedHashMap<>(); + response.put("protocolVersion", "2024-11-05"); + + Map capabilities = new LinkedHashMap<>(); + capabilities.put("tools", Map.of()); + response.put("capabilities", capabilities); + + Map serverInfo = new LinkedHashMap<>(); + serverInfo.put("name", "java-auth-mcp-server"); + serverInfo.put("version", "1.0.0"); + response.put("serverInfo", serverInfo); + + return response; + } + + /** + * Handle tools/list request. + * + * @return list of available tools + */ + public Map handleToolsList() { + Map response = new LinkedHashMap<>(); + response.put("tools", tools.values().stream() + .map(ToolSpec::toMap) + .toList()); + return response; + } + + /** + * Handle tools/call request. + * + * @param params request parameters + * @param ctx Javalin context + * @return tool execution result + */ + @SuppressWarnings("unchecked") + public Map handleToolsCall(Map params, Context ctx) { + if (params == null) { + return ToolResult.error("Missing params").toMap(); + } + + String name = (String) params.get("name"); + if (name == null || name.isEmpty()) { + return ToolResult.error("Missing tool name").toMap(); + } + + BiFunction, Context, ToolResult> handler = toolHandlers.get(name); + if (handler == null) { + return ToolResult.error("Tool not found: " + name).toMap(); + } + + Map arguments = (Map) params.get("arguments"); + if (arguments == null) { + arguments = Map.of(); + } + + ToolResult result = handler.apply(arguments, ctx); + return result.toMap(); + } + + /** + * Register a tool with its handler. + * + * @param name tool name + * @param spec tool specification + * @param handler tool execution handler + */ + public void registerTool(String name, ToolSpec spec, + BiFunction, Context, ToolResult> handler) { + tools.put(name, spec); + toolHandlers.put(name, handler); + } + + /** + * Send success response. + * + * @param ctx Javalin context + * @param id request id + * @param result method result + */ + public void sendSuccess(Context ctx, Object id, Object result) { + JsonRpcResponse response = JsonRpcResponse.success(id, result); + ctx.contentType("application/json"); + ctx.json(response); + } + + /** + * Send error response. + * + * @param ctx Javalin context + * @param id request id + * @param code error code + * @param message error message + * @param data additional error data + */ + public void sendError(Context ctx, Object id, int code, String message, String data) { + JsonRpcError error = new JsonRpcError(code, message, data); + JsonRpcResponse response = JsonRpcResponse.error(id, error); + ctx.contentType("application/json"); + ctx.json(response); + } + + /** + * Get registered tools map. + * + * @return tools map + */ + public Map getTools() { + return tools; + } + + /** + * Get registered tool handlers map. + * + * @return tool handlers map + */ + public Map, Context, ToolResult>> getToolHandlers() { + return toolHandlers; + } +} diff --git a/examples/auth/src/main/java/com/gophersecurity/mcp/auth/routes/OAuthEndpoints.java b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/routes/OAuthEndpoints.java new file mode 100644 index 00000000..6ed12a0f --- /dev/null +++ b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/routes/OAuthEndpoints.java @@ -0,0 +1,279 @@ +package com.gophersecurity.mcp.auth.routes; + +import com.gophersecurity.mcp.auth.config.AuthServerConfig; +import com.gophersecurity.mcp.auth.middleware.CorsFilter; +import io.javalin.Javalin; +import io.javalin.http.Context; +import io.javalin.http.HttpStatus; + +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +/** + * OAuth 2.0 and OpenID Connect endpoint handlers. + */ +public class OAuthEndpoints { + + private static final SecureRandom RANDOM = new SecureRandom(); + private static final String ALPHANUMERIC = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + + private final AuthServerConfig config; + + /** + * Create OAuth endpoints with configuration. + * + * @param config server configuration + */ + public OAuthEndpoints(AuthServerConfig config) { + this.config = config; + } + + /** + * Register all OAuth routes with the Javalin app. + * + * @param app Javalin application + */ + public void registerRoutes(Javalin app) { + // Protected resource metadata (RFC 9728) + app.get("/.well-known/oauth-protected-resource", this::protectedResourceMetadata); + app.get("/.well-known/oauth-protected-resource/mcp", this::protectedResourceMetadata); + app.options("/.well-known/oauth-protected-resource", CorsFilter::handlePreflight); + app.options("/.well-known/oauth-protected-resource/mcp", CorsFilter::handlePreflight); + + // Authorization server metadata (RFC 8414) + app.get("/.well-known/oauth-authorization-server", this::authorizationServerMetadata); + app.options("/.well-known/oauth-authorization-server", CorsFilter::handlePreflight); + + // OpenID Connect discovery + app.get("/.well-known/openid-configuration", this::openidConfiguration); + app.options("/.well-known/openid-configuration", CorsFilter::handlePreflight); + + // OAuth endpoints + app.get("/oauth/authorize", this::authorize); + app.options("/oauth/authorize", CorsFilter::handlePreflight); + + app.post("/oauth/register", this::register); + app.options("/oauth/register", CorsFilter::handlePreflight); + } + + /** + * Handle protected resource metadata request (RFC 9728). + * + * @param ctx Javalin context + */ + public void protectedResourceMetadata(Context ctx) { + Map response = new LinkedHashMap<>(); + response.put("resource", config.getServerUrl() + "/mcp"); + response.put("authorization_servers", List.of(config.getServerUrl())); + response.put("scopes_supported", splitScopes(config.getAllowedScopes())); + response.put("bearer_methods_supported", List.of("header", "query")); + response.put("resource_documentation", config.getServerUrl() + "/docs"); + + CorsFilter.setCorsHeaders(ctx); + ctx.contentType("application/json"); + ctx.json(response); + } + + /** + * Handle authorization server metadata request (RFC 8414). + * + * @param ctx Javalin context + */ + public void authorizationServerMetadata(Context ctx) { + Map response = new LinkedHashMap<>(); + response.put("issuer", config.getIssuer().isEmpty() ? config.getServerUrl() : config.getIssuer()); + response.put("authorization_endpoint", config.getOauthAuthorizeUrl()); + response.put("token_endpoint", config.getOauthTokenUrl()); + response.put("jwks_uri", config.getJwksUri()); + response.put("registration_endpoint", config.getServerUrl() + "/oauth/register"); + response.put("scopes_supported", splitScopes(config.getAllowedScopes())); + response.put("response_types_supported", List.of("code")); + response.put("grant_types_supported", + List.of("authorization_code", "refresh_token")); + response.put("token_endpoint_auth_methods_supported", + List.of("client_secret_basic", "client_secret_post", "none")); + response.put("code_challenge_methods_supported", List.of("S256")); + + CorsFilter.setCorsHeaders(ctx); + ctx.contentType("application/json"); + ctx.json(response); + } + + /** + * Handle OpenID Connect discovery request. + * + * Extends authorization server metadata with OIDC-specific fields. + * + * @param ctx Javalin context + */ + public void openidConfiguration(Context ctx) { + // Merge base OIDC scopes with configured scopes + List baseScopes = Arrays.asList("openid", "profile", "email"); + List configuredScopes = splitScopes(config.getAllowedScopes()); + List allScopes = new java.util.ArrayList<>(baseScopes); + for (String scope : configuredScopes) { + if (!allScopes.contains(scope)) { + allScopes.add(scope); + } + } + + Map response = new LinkedHashMap<>(); + response.put("issuer", config.getIssuer().isEmpty() ? config.getServerUrl() : config.getIssuer()); + response.put("authorization_endpoint", config.getOauthAuthorizeUrl()); + response.put("token_endpoint", config.getOauthTokenUrl()); + response.put("jwks_uri", config.getJwksUri()); + if (!config.getAuthServerUrl().isEmpty()) { + response.put("userinfo_endpoint", config.getAuthServerUrl() + "/protocol/openid-connect/userinfo"); + } + response.put("scopes_supported", allScopes); + response.put("response_types_supported", List.of("code")); + response.put("grant_types_supported", + List.of("authorization_code", "refresh_token")); + response.put("token_endpoint_auth_methods_supported", + List.of("client_secret_basic", "client_secret_post", "none")); + response.put("subject_types_supported", List.of("public")); + response.put("id_token_signing_alg_values_supported", List.of("RS256")); + + CorsFilter.setCorsHeaders(ctx); + ctx.contentType("application/json"); + ctx.json(response); + } + + /** + * Handle OAuth authorization redirect. + * + * Redirects to the OAuth provider's authorization endpoint with + * all necessary query parameters. + * + * @param ctx Javalin context + */ + public void authorize(Context ctx) { + String authEndpoint = config.getOauthAuthorizeUrl(); + + try { + // Build URL and forward all query parameters + StringBuilder url = new StringBuilder(authEndpoint); + Map> queryParams = ctx.queryParamMap(); + + if (!queryParams.isEmpty()) { + url.append("?"); + boolean first = true; + for (Map.Entry> entry : queryParams.entrySet()) { + String key = entry.getKey(); + List values = entry.getValue(); + if (values != null && !values.isEmpty()) { + if (!first) { + url.append("&"); + } + url.append(urlEncode(key)).append("=").append(urlEncode(values.get(0))); + first = false; + } + } + } + + CorsFilter.setCorsHeaders(ctx); + ctx.redirect(url.toString(), HttpStatus.FOUND); + } catch (Exception e) { + CorsFilter.setCorsHeaders(ctx); + ctx.status(500); + ctx.json(Map.of( + "error", "server_error", + "error_description", "Failed to construct authorization URL" + )); + } + } + + /** + * Handle dynamic client registration. + * + * Generates client credentials and returns a registration response. + * + * @param ctx Javalin context + */ + @SuppressWarnings("unchecked") + public void register(Context ctx) { + Map body; + try { + body = ctx.bodyAsClass(Map.class); + } catch (Exception e) { + body = Map.of(); + } + + Object redirectUris = body.get("redirect_uris"); + + // Return pre-configured credentials (stateless mode for MCP) + // This allows MCP clients to "register" and receive the server's OAuth credentials + Map response = new LinkedHashMap<>(); + response.put("client_id", config.getClientId()); + + String clientSecret = config.getClientSecret(); + if (clientSecret != null && !clientSecret.isEmpty()) { + response.put("client_secret", clientSecret); + } + + response.put("client_id_issued_at", System.currentTimeMillis() / 1000); + response.put("client_secret_expires_at", 0); + + if (redirectUris != null) { + response.put("redirect_uris", redirectUris); + } + + // Use client_secret_post if secret is configured, otherwise none + String authMethod = (clientSecret != null && !clientSecret.isEmpty()) + ? "client_secret_post" : "none"; + response.put("token_endpoint_auth_method", authMethod); + response.put("grant_types", List.of("authorization_code", "refresh_token")); + response.put("response_types", List.of("code")); + + CorsFilter.setCorsHeaders(ctx); + ctx.contentType("application/json"); + ctx.status(201); + ctx.json(response); + } + + /** + * Split space-separated scopes into a list. + * + * @param scopes space-separated scope string + * @return list of scopes + */ + List splitScopes(String scopes) { + if (scopes == null || scopes.isEmpty()) { + return List.of(); + } + return Arrays.asList(scopes.trim().split("\\s+")); + } + + /** + * URL-encode a string. + * + * @param value value to encode + * @return URL-encoded value + */ + String urlEncode(String value) { + if (value == null) { + return ""; + } + return URLEncoder.encode(value, StandardCharsets.UTF_8); + } + + /** + * Generate a random alphanumeric string. + * + * @param length length of the string + * @return random string + */ + String generateRandomString(int length) { + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append(ALPHANUMERIC.charAt(RANDOM.nextInt(ALPHANUMERIC.length()))); + } + return sb.toString(); + } +} diff --git a/examples/auth/src/main/java/com/gophersecurity/mcp/auth/tools/WeatherTools.java b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/tools/WeatherTools.java new file mode 100644 index 00000000..bf3eacc9 --- /dev/null +++ b/examples/auth/src/main/java/com/gophersecurity/mcp/auth/tools/WeatherTools.java @@ -0,0 +1,223 @@ +package com.gophersecurity.mcp.auth.tools; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.gophersecurity.mcp.auth.middleware.OAuthAuthMiddleware; +import com.gophersecurity.mcp.auth.model.ToolResult; +import com.gophersecurity.mcp.auth.model.ToolSpec; +import com.gophersecurity.mcp.auth.routes.McpHandler; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * Weather-related MCP tools. + */ +public class WeatherTools { + + private static final ObjectMapper mapper = new ObjectMapper(); + private static final String[] CONDITIONS = {"Sunny", "Cloudy", "Rainy", "Windy", "Snowy"}; + + /** + * Register all weather tools with the MCP handler. + * + * @param mcp MCP handler + * @param authMiddleware authentication middleware + */ + public static void register(McpHandler mcp, OAuthAuthMiddleware authMiddleware) { + // get-weather: no scope required + mcp.registerTool("get-weather", + new ToolSpec("get-weather", "Get current weather for a city", + Map.of( + "type", "object", + "properties", Map.of( + "city", Map.of("type", "string", "description", "City name") + ), + "required", List.of("city") + )), + (args, ctx) -> getWeather(args)); + + // get-forecast: requires mcp:read + mcp.registerTool("get-forecast", + new ToolSpec("get-forecast", "Get 5-day weather forecast for a city", + Map.of( + "type", "object", + "properties", Map.of( + "city", Map.of("type", "string", "description", "City name") + ), + "required", List.of("city") + )), + (args, ctx) -> { + if (requiresScope(authMiddleware, "mcp:read")) { + if (!authMiddleware.hasScope("mcp:read")) { + return accessDenied("mcp:read"); + } + } + return getForecast(args); + }); + + // get-weather-alerts: requires mcp:admin + mcp.registerTool("get-weather-alerts", + new ToolSpec("get-weather-alerts", "Get weather alerts for a region", + Map.of( + "type", "object", + "properties", Map.of( + "region", Map.of("type", "string", "description", "Region name") + ), + "required", List.of("region") + )), + (args, ctx) -> { + if (requiresScope(authMiddleware, "mcp:admin")) { + if (!authMiddleware.hasScope("mcp:admin")) { + return accessDenied("mcp:admin"); + } + } + return getWeatherAlerts(args); + }); + } + + /** + * Check if scope checking is required. + * + * @param authMiddleware auth middleware + * @param scope scope to check + * @return true if scope should be checked + */ + static boolean requiresScope(OAuthAuthMiddleware authMiddleware, String scope) { + return authMiddleware != null && !authMiddleware.isAuthDisabled(); + } + + /** + * Get current weather for a city. + * + * @param args tool arguments + * @return weather result + */ + static ToolResult getWeather(Map args) { + String city = (String) args.get("city"); + if (city == null || city.isEmpty()) { + return ToolResult.error("Missing city parameter"); + } + + int hash = Math.abs(city.hashCode()); + int temperature = 15 + (hash % 20); + String condition = getCondition(hash); + int humidity = 40 + (hash % 40); + + Map weather = new LinkedHashMap<>(); + weather.put("city", city); + weather.put("temperature", temperature); + weather.put("condition", condition); + weather.put("humidity", humidity); + weather.put("unit", "celsius"); + + try { + return ToolResult.text(mapper.writeValueAsString(weather)); + } catch (JsonProcessingException e) { + return ToolResult.error("Failed to serialize weather data"); + } + } + + /** + * Get 5-day forecast for a city. + * + * @param args tool arguments + * @return forecast result + */ + static ToolResult getForecast(Map args) { + String city = (String) args.get("city"); + if (city == null || city.isEmpty()) { + return ToolResult.error("Missing city parameter"); + } + + int hash = Math.abs(city.hashCode()); + List> forecast = new ArrayList<>(); + + for (int i = 0; i < 5; i++) { + int dayHash = hash + i * 7; + Map day = new LinkedHashMap<>(); + day.put("day", i + 1); + day.put("temperature", 15 + (dayHash % 20)); + day.put("condition", getCondition(dayHash)); + forecast.add(day); + } + + Map result = new LinkedHashMap<>(); + result.put("city", city); + result.put("forecast", forecast); + + try { + return ToolResult.text(mapper.writeValueAsString(result)); + } catch (JsonProcessingException e) { + return ToolResult.error("Failed to serialize forecast data"); + } + } + + /** + * Get weather alerts for a region. + * + * @param args tool arguments + * @return alerts result + */ + static ToolResult getWeatherAlerts(Map args) { + String region = (String) args.get("region"); + if (region == null || region.isEmpty()) { + return ToolResult.error("Missing region parameter"); + } + + int hash = Math.abs(region.hashCode()); + List> alerts = new ArrayList<>(); + + // Generate 0-2 alerts based on hash + int alertCount = hash % 3; + String[] alertTypes = {"Storm Warning", "Heat Advisory", "Flood Watch", "Wind Advisory"}; + + for (int i = 0; i < alertCount; i++) { + Map alert = new LinkedHashMap<>(); + alert.put("type", alertTypes[(hash + i) % alertTypes.length]); + alert.put("severity", (hash + i) % 2 == 0 ? "moderate" : "severe"); + alert.put("message", "Weather alert for " + region); + alerts.add(alert); + } + + Map result = new LinkedHashMap<>(); + result.put("region", region); + result.put("alerts", alerts); + + try { + return ToolResult.text(mapper.writeValueAsString(result)); + } catch (JsonProcessingException e) { + return ToolResult.error("Failed to serialize alerts data"); + } + } + + /** + * Create access denied error result. + * + * @param requiredScope the required scope + * @return error result + */ + static ToolResult accessDenied(String requiredScope) { + Map error = new LinkedHashMap<>(); + error.put("error", "access_denied"); + error.put("message", "Access denied. Required scope: " + requiredScope); + + try { + return ToolResult.error(mapper.writeValueAsString(error)); + } catch (JsonProcessingException e) { + return ToolResult.error("Access denied. Required scope: " + requiredScope); + } + } + + /** + * Get weather condition based on hash. + * + * @param hash hash value + * @return weather condition + */ + static String getCondition(int hash) { + return CONDITIONS[Math.abs(hash) % CONDITIONS.length]; + } +} diff --git a/examples/auth/src/test/java/com/gophersecurity/mcp/auth/config/AuthServerConfigTest.java b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/config/AuthServerConfigTest.java new file mode 100644 index 00000000..3084689d --- /dev/null +++ b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/config/AuthServerConfigTest.java @@ -0,0 +1,325 @@ +package com.gophersecurity.mcp.auth.config; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for AuthServerConfig. + */ +class AuthServerConfigTest { + + @TempDir + Path tempDir; + + @Test + void testParseBasicKeyValue() throws IOException { + Path configFile = tempDir.resolve("test.config"); + Files.writeString(configFile, "host=localhost\nport=3001"); + + Map map = AuthServerConfig.parseConfigFile(configFile.toString()); + + assertEquals("localhost", map.get("host")); + assertEquals("3001", map.get("port")); + } + + @Test + void testParseCommentsSkipped() throws IOException { + Path configFile = tempDir.resolve("test.config"); + Files.writeString(configFile, + "# This is a comment\n" + + "host=localhost\n" + + "# Another comment\n" + + "port=3001"); + + Map map = AuthServerConfig.parseConfigFile(configFile.toString()); + + assertEquals(2, map.size()); + assertEquals("localhost", map.get("host")); + assertEquals("3001", map.get("port")); + } + + @Test + void testParseEmptyLinesSkipped() throws IOException { + Path configFile = tempDir.resolve("test.config"); + Files.writeString(configFile, + "host=localhost\n" + + "\n" + + "\n" + + "port=3001\n" + + "\n"); + + Map map = AuthServerConfig.parseConfigFile(configFile.toString()); + + assertEquals(2, map.size()); + assertEquals("localhost", map.get("host")); + assertEquals("3001", map.get("port")); + } + + @Test + void testParseValuesWithEquals() throws IOException { + Path configFile = tempDir.resolve("test.config"); + Files.writeString(configFile, + "auth_url=https://auth.example.com?param=value&other=123"); + + Map map = AuthServerConfig.parseConfigFile(configFile.toString()); + + assertEquals("https://auth.example.com?param=value&other=123", map.get("auth_url")); + } + + @Test + void testParseWhitespaceTrimmed() throws IOException { + Path configFile = tempDir.resolve("test.config"); + Files.writeString(configFile, + " host = localhost \n" + + " port= 3001"); + + Map map = AuthServerConfig.parseConfigFile(configFile.toString()); + + assertEquals("localhost", map.get("host")); + assertEquals("3001", map.get("port")); + } + + @Test + void testParseEmptyValue() throws IOException { + Path configFile = tempDir.resolve("test.config"); + Files.writeString(configFile, "empty_key="); + + Map map = AuthServerConfig.parseConfigFile(configFile.toString()); + + assertEquals("", map.get("empty_key")); + } + + @Test + void testDefaultValues() { + AuthServerConfig config = new AuthServerConfig(); + + assertEquals("0.0.0.0", config.getHost()); + assertEquals(3001, config.getPort()); + assertEquals("http://localhost:3001", config.getServerUrl()); + assertEquals("mcp:read mcp:admin", config.getAllowedScopes()); + assertEquals(3600, config.getJwksCacheDuration()); + assertTrue(config.isJwksAutoRefresh()); + assertEquals(5000, config.getRequestTimeout()); + assertFalse(config.isAuthDisabled()); + } + + @Test + void testDefaultDisabled() { + AuthServerConfig config = AuthServerConfig.defaultDisabled(); + + assertTrue(config.isAuthDisabled()); + assertEquals("0.0.0.0", config.getHost()); + assertEquals(3001, config.getPort()); + } + + @Test + void testBuildFromMapDefaults() { + Map map = new HashMap<>(); + AuthServerConfig config = AuthServerConfig.buildFromMap(map); + + assertEquals("0.0.0.0", config.getHost()); + assertEquals(3001, config.getPort()); + // server_url uses localhost when host is 0.0.0.0 + assertEquals("http://localhost:3001", config.getServerUrl()); + assertFalse(config.isAuthDisabled()); + } + + @Test + void testBuildFromMapCustomValues() { + Map map = new HashMap<>(); + map.put("host", "127.0.0.1"); + map.put("port", "8080"); + map.put("client_id", "my-client"); + map.put("auth_disabled", "true"); + + AuthServerConfig config = AuthServerConfig.buildFromMap(map); + + assertEquals("127.0.0.1", config.getHost()); + assertEquals(8080, config.getPort()); + assertEquals("http://127.0.0.1:8080", config.getServerUrl()); + assertEquals("my-client", config.getClientId()); + assertTrue(config.isAuthDisabled()); + } + + @Test + void testBuildFromMapEndpointDerivation() { + Map map = new HashMap<>(); + map.put("auth_server_url", "https://auth.example.com/realms/test"); + + AuthServerConfig config = AuthServerConfig.buildFromMap(map); + + assertEquals("https://auth.example.com/realms/test/protocol/openid-connect/certs", + config.getJwksUri()); + assertEquals("https://auth.example.com/realms/test", + config.getIssuer()); + assertEquals("https://auth.example.com/realms/test/protocol/openid-connect/auth", + config.getOauthAuthorizeUrl()); + assertEquals("https://auth.example.com/realms/test/protocol/openid-connect/token", + config.getOauthTokenUrl()); + } + + @Test + void testBuildFromMapExplicitEndpointsOverride() { + Map map = new HashMap<>(); + map.put("auth_server_url", "https://auth.example.com/realms/test"); + map.put("jwks_uri", "https://custom.example.com/jwks"); + + AuthServerConfig config = AuthServerConfig.buildFromMap(map); + + // Explicit value should override derived + assertEquals("https://custom.example.com/jwks", config.getJwksUri()); + // Other endpoints still derived + assertEquals("https://auth.example.com/realms/test/protocol/openid-connect/auth", + config.getOauthAuthorizeUrl()); + } + + @Test + void testBuildFromMapLocalhostSubstitution() { + Map map = new HashMap<>(); + map.put("host", "0.0.0.0"); + map.put("port", "3001"); + + AuthServerConfig config = AuthServerConfig.buildFromMap(map); + + // server_url should use localhost when host is 0.0.0.0 + assertEquals("http://localhost:3001", config.getServerUrl()); + } + + @Test + void testBuildFromMapBooleanParsing() { + // Test "1" as true + Map map1 = new HashMap<>(); + map1.put("auth_disabled", "1"); + AuthServerConfig config1 = AuthServerConfig.buildFromMap(map1); + assertTrue(config1.isAuthDisabled()); + + // Test "true" as true + Map map2 = new HashMap<>(); + map2.put("jwks_auto_refresh", "true"); + AuthServerConfig config2 = AuthServerConfig.buildFromMap(map2); + assertTrue(config2.isJwksAutoRefresh()); + + // Test "false" as false + Map map3 = new HashMap<>(); + map3.put("jwks_auto_refresh", "false"); + AuthServerConfig config3 = AuthServerConfig.buildFromMap(map3); + assertFalse(config3.isJwksAutoRefresh()); + } + + @Test + void testBuildFromMapCacheSettings() { + Map map = new HashMap<>(); + map.put("jwks_cache_duration", "7200"); + map.put("jwks_auto_refresh", "false"); + map.put("request_timeout", "10000"); + + AuthServerConfig config = AuthServerConfig.buildFromMap(map); + + assertEquals(7200, config.getJwksCacheDuration()); + assertFalse(config.isJwksAutoRefresh()); + assertEquals(10000, config.getRequestTimeout()); + } + + @Test + void testValidatePassesWithValidConfig() { + Map map = new HashMap<>(); + map.put("client_id", "my-client"); + map.put("client_secret", "secret"); + map.put("auth_server_url", "https://auth.example.com"); + + AuthServerConfig config = AuthServerConfig.buildFromMap(map); + + assertDoesNotThrow(() -> config.validate()); + } + + @Test + void testValidateFailsMissingClientId() { + Map map = new HashMap<>(); + map.put("client_secret", "secret"); + map.put("jwks_uri", "https://example.com/jwks"); + + AuthServerConfig config = AuthServerConfig.buildFromMap(map); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> config.validate()); + assertTrue(ex.getMessage().contains("client_id")); + } + + @Test + void testValidateFailsMissingClientSecret() { + Map map = new HashMap<>(); + map.put("client_id", "my-client"); + map.put("jwks_uri", "https://example.com/jwks"); + + AuthServerConfig config = AuthServerConfig.buildFromMap(map); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> config.validate()); + assertTrue(ex.getMessage().contains("client_secret")); + } + + @Test + void testValidateFailsMissingJwksUri() { + Map map = new HashMap<>(); + map.put("client_id", "my-client"); + map.put("client_secret", "secret"); + + AuthServerConfig config = AuthServerConfig.buildFromMap(map); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> config.validate()); + assertTrue(ex.getMessage().contains("jwks_uri")); + } + + @Test + void testValidateSkippedWhenAuthDisabled() { + // Empty config should fail validation normally + AuthServerConfig config1 = new AuthServerConfig(); + assertThrows(IllegalArgumentException.class, () -> config1.validate()); + + // But should pass when auth is disabled + AuthServerConfig config2 = AuthServerConfig.defaultDisabled(); + assertDoesNotThrow(() -> config2.validate()); + } + + @Test + void testFromFileSuccess() throws IOException { + Path configFile = tempDir.resolve("test.config"); + Files.writeString(configFile, + "# Test configuration\n" + + "host=127.0.0.1\n" + + "port=8080\n" + + "client_id=test-client\n" + + "client_secret=test-secret\n" + + "auth_server_url=https://auth.example.com/realms/test\n"); + + AuthServerConfig config = AuthServerConfig.fromFile(configFile.toString()); + + assertEquals("127.0.0.1", config.getHost()); + assertEquals(8080, config.getPort()); + assertEquals("test-client", config.getClientId()); + assertEquals("test-secret", config.getClientSecret()); + assertEquals("https://auth.example.com/realms/test/protocol/openid-connect/certs", + config.getJwksUri()); + } + + @Test + void testFromFileValidationError() throws IOException { + Path configFile = tempDir.resolve("test.config"); + Files.writeString(configFile, + "host=127.0.0.1\n" + + "port=8080\n" + + "# Missing client_id, client_secret, auth_server_url\n"); + + assertThrows(IllegalArgumentException.class, + () -> AuthServerConfig.fromFile(configFile.toString())); + } +} diff --git a/examples/auth/src/test/java/com/gophersecurity/mcp/auth/middleware/CorsFilterTest.java b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/middleware/CorsFilterTest.java new file mode 100644 index 00000000..d8bdd12e --- /dev/null +++ b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/middleware/CorsFilterTest.java @@ -0,0 +1,73 @@ +package com.gophersecurity.mcp.auth.middleware; + +import io.javalin.http.Context; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for CorsFilter. + */ +@ExtendWith(MockitoExtension.class) +class CorsFilterTest { + + @Mock + private Context ctx; + + @Test + void testSetCorsHeadersSetsAllRequiredHeaders() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + + CorsFilter.setCorsHeaders(ctx); + + verify(ctx).header("Access-Control-Allow-Origin", "*"); + verify(ctx).header("Access-Control-Allow-Methods", + "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD"); + verify(ctx).header("Access-Control-Allow-Headers", + "Accept, Accept-Language, Content-Language, Content-Type, Authorization, " + + "X-Requested-With, Origin, Cache-Control, Pragma, Mcp-Session-Id, Mcp-Protocol-Version"); + verify(ctx).header("Access-Control-Expose-Headers", + "WWW-Authenticate, Content-Length, Content-Type"); + verify(ctx).header("Access-Control-Max-Age", "86400"); + } + + @Test + void testSetCorsHeadersIncludesMcpHeaders() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + ArgumentCaptor valueCaptor = ArgumentCaptor.forClass(String.class); + + CorsFilter.setCorsHeaders(ctx); + + verify(ctx).header(eq("Access-Control-Allow-Headers"), valueCaptor.capture()); + String allowedHeaders = valueCaptor.getValue(); + assertTrue(allowedHeaders.contains("Mcp-Session-Id")); + assertTrue(allowedHeaders.contains("Mcp-Protocol-Version")); + } + + @Test + void testHandlePreflightSetsCorsHeaders() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + + CorsFilter.handlePreflight(ctx); + + verify(ctx).header("Access-Control-Allow-Origin", "*"); + verify(ctx).header("Access-Control-Allow-Methods", + "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD"); + } + + @Test + void testHandlePreflightSets204Status() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + + CorsFilter.handlePreflight(ctx); + + verify(ctx).status(204); + } +} diff --git a/examples/auth/src/test/java/com/gophersecurity/mcp/auth/middleware/OAuthAuthMiddlewareTest.java b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/middleware/OAuthAuthMiddlewareTest.java new file mode 100644 index 00000000..dab05347 --- /dev/null +++ b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/middleware/OAuthAuthMiddlewareTest.java @@ -0,0 +1,447 @@ +package com.gophersecurity.mcp.auth.middleware; + +import com.gophersecurity.orch.auth.GopherAuthClient; +import com.gophersecurity.mcp.auth.config.AuthServerConfig; +import io.javalin.http.Context; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; +import static org.mockito.ArgumentMatchers.contains; + +/** + * Unit tests for OAuthAuthMiddleware. + */ +@ExtendWith(MockitoExtension.class) +class OAuthAuthMiddlewareTest { + + @Mock + private Context ctx; + + @Mock + private GopherAuthClient authClient; + + private AuthServerConfig config; + private OAuthAuthMiddleware middleware; + + @BeforeEach + void setUp() { + Map configMap = new HashMap<>(); + configMap.put("server_url", "http://localhost:3001"); + configMap.put("auth_server_url", "https://auth.example.com"); + configMap.put("allowed_scopes", "mcp:read mcp:admin"); + config = AuthServerConfig.buildFromMap(configMap); + middleware = new OAuthAuthMiddleware(authClient, config); + } + + // Token extraction tests + + @Test + void testExtractTokenFromAuthorizationHeader() { + when(ctx.header("Authorization")).thenReturn("Bearer eyJhbGciOiJSUzI1NiJ9.test"); + + String token = middleware.extractToken(ctx); + + assertEquals("eyJhbGciOiJSUzI1NiJ9.test", token); + } + + @Test + void testExtractTokenHandlesLowercaseBearer() { + when(ctx.header("Authorization")).thenReturn("bearer eyJhbGciOiJSUzI1NiJ9.test"); + + String token = middleware.extractToken(ctx); + + assertEquals("eyJhbGciOiJSUzI1NiJ9.test", token); + } + + @Test + void testExtractTokenHandlesUppercaseBearer() { + when(ctx.header("Authorization")).thenReturn("BEARER eyJhbGciOiJSUzI1NiJ9.test"); + + String token = middleware.extractToken(ctx); + + assertEquals("eyJhbGciOiJSUzI1NiJ9.test", token); + } + + @Test + void testExtractTokenFromQueryParameter() { + when(ctx.header("Authorization")).thenReturn(null); + when(ctx.queryParam("access_token")).thenReturn("query-token-123"); + + String token = middleware.extractToken(ctx); + + assertEquals("query-token-123", token); + } + + @Test + void testExtractTokenReturnsNullWhenNoToken() { + when(ctx.header("Authorization")).thenReturn(null); + when(ctx.queryParam("access_token")).thenReturn(null); + + String token = middleware.extractToken(ctx); + + assertNull(token); + } + + @Test + void testExtractTokenReturnsNullForEmptyQueryParam() { + when(ctx.header("Authorization")).thenReturn(null); + when(ctx.queryParam("access_token")).thenReturn(""); + + String token = middleware.extractToken(ctx); + + assertNull(token); + } + + @Test + void testHeaderTakesPriorityOverQueryParam() { + when(ctx.header("Authorization")).thenReturn("Bearer header-token"); + // Note: queryParam should not be called when header is present + + String token = middleware.extractToken(ctx); + + assertEquals("header-token", token); + } + + @Test + void testExtractTokenIgnoresNonBearerHeader() { + when(ctx.header("Authorization")).thenReturn("Basic dXNlcjpwYXNz"); + when(ctx.queryParam("access_token")).thenReturn("fallback-token"); + + String token = middleware.extractToken(ctx); + + assertEquals("fallback-token", token); + } + + // Auth context tests + + @Test + void testGetAuthContextReturnsEmptyByDefault() { + assertFalse(middleware.getAuthContext().isAuthenticated()); + assertEquals("", middleware.getAuthContext().getUserId()); + } + + @Test + void testIsAuthDisabledDelegatesToConfig() { + assertFalse(middleware.isAuthDisabled()); + + Map disabledConfigMap = new HashMap<>(); + disabledConfigMap.put("auth_disabled", "true"); + AuthServerConfig disabledConfig = AuthServerConfig.buildFromMap(disabledConfigMap); + OAuthAuthMiddleware disabledMiddleware = new OAuthAuthMiddleware(authClient, disabledConfig); + + assertTrue(disabledMiddleware.isAuthDisabled()); + } + + // Public path tests + + @Test + void testHealthIsPublic() { + assertTrue(middleware.isPublicPath("/health")); + } + + @Test + void testWellKnownOAuthProtectedResourceIsPublic() { + assertTrue(middleware.isPublicPath("/.well-known/oauth-protected-resource")); + } + + @Test + void testWellKnownOpenidConfigIsPublic() { + assertTrue(middleware.isPublicPath("/.well-known/openid-configuration")); + } + + @Test + void testOauthAuthorizeIsPublic() { + assertTrue(middleware.isPublicPath("/oauth/authorize")); + } + + @Test + void testOauthRegisterIsPublic() { + assertTrue(middleware.isPublicPath("/oauth/register")); + } + + @Test + void testFaviconIsPublic() { + assertTrue(middleware.isPublicPath("/favicon.ico")); + } + + @Test + void testMcpIsNotPublic() { + assertFalse(middleware.isPublicPath("/mcp")); + } + + // Requires auth tests + + @Test + void testMcpRequiresAuth() { + assertTrue(middleware.requiresAuth("/mcp")); + assertTrue(middleware.requiresAuth("/mcp/tools")); + } + + @Test + void testRpcRequiresAuth() { + assertTrue(middleware.requiresAuth("/rpc")); + assertTrue(middleware.requiresAuth("/rpc/call")); + } + + @Test + void testEventsRequiresAuth() { + assertTrue(middleware.requiresAuth("/events")); + assertTrue(middleware.requiresAuth("/events/subscribe")); + } + + @Test + void testSseRequiresAuth() { + assertTrue(middleware.requiresAuth("/sse")); + assertTrue(middleware.requiresAuth("/sse/stream")); + } + + @Test + void testUnknownPathDoesNotRequireAuth() { + assertFalse(middleware.requiresAuth("/foo")); + assertFalse(middleware.requiresAuth("/bar/baz")); + assertFalse(middleware.requiresAuth("/api/unknown")); + } + + @Test + void testRequiresAuthReturnsFalseWhenAuthDisabled() { + Map disabledConfigMap = new HashMap<>(); + disabledConfigMap.put("auth_disabled", "true"); + AuthServerConfig disabledConfig = AuthServerConfig.buildFromMap(disabledConfigMap); + OAuthAuthMiddleware disabledMiddleware = new OAuthAuthMiddleware(authClient, disabledConfig); + + assertFalse(disabledMiddleware.requiresAuth("/mcp")); + assertFalse(disabledMiddleware.requiresAuth("/rpc")); + } + + @Test + void testRequiresAuthReturnsTrueWhenAuthClientIsNull() { + // Auth is still required even when client is null - this triggers OAuth flow + OAuthAuthMiddleware nullClientMiddleware = new OAuthAuthMiddleware(null, config); + + assertTrue(nullClientMiddleware.requiresAuth("/mcp")); + assertTrue(nullClientMiddleware.requiresAuth("/rpc")); + } + + @Test + void testPublicPathsNeverRequireAuth() { + assertTrue(middleware.isPublicPath("/health")); + assertFalse(middleware.requiresAuth("/health")); + + assertTrue(middleware.isPublicPath("/.well-known/oauth-protected-resource")); + assertFalse(middleware.requiresAuth("/.well-known/oauth-protected-resource")); + } + + // Unauthorized response tests + + @Test + void testSendUnauthorizedSets401Status() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + + middleware.sendUnauthorized(ctx, "invalid_token", "Token expired"); + + verify(ctx).status(401); + } + + @Test + void testSendUnauthorizedSetsWwwAuthenticateHeader() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + + middleware.sendUnauthorized(ctx, "invalid_token", "Token expired"); + + verify(ctx).header(eq("WWW-Authenticate"), contains("Bearer realm=")); + verify(ctx).header(eq("WWW-Authenticate"), contains("resource_metadata=")); + verify(ctx).header(eq("WWW-Authenticate"), contains("error=\"invalid_token\"")); + verify(ctx).header(eq("WWW-Authenticate"), contains("error_description=\"Token expired\"")); + } + + @Test + void testSendUnauthorizedSetsCorsHeaders() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + + middleware.sendUnauthorized(ctx, "invalid_token", "Token expired"); + + verify(ctx).header("Access-Control-Allow-Origin", "*"); + } + + @Test + void testSendUnauthorizedReturnsJsonBody() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + org.mockito.ArgumentCaptor> captor = + org.mockito.ArgumentCaptor.forClass(Map.class); + + middleware.sendUnauthorized(ctx, "invalid_request", "Missing token"); + + verify(ctx).json(captor.capture()); + Map body = captor.getValue(); + assertEquals("invalid_request", body.get("error")); + assertEquals("Missing token", body.get("error_description")); + } + + @Test + void testSendUnauthorizedSetsContentType() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + + middleware.sendUnauthorized(ctx, "invalid_token", "Token expired"); + + verify(ctx).contentType("application/json"); + } + + @Test + void testSendUnauthorizedEscapesSpecialCharsInDescription() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + + middleware.sendUnauthorized(ctx, "error", "Description with \"quotes\" and \\backslash"); + + // Verify header escaping + verify(ctx).header(eq("WWW-Authenticate"), + contains("error_description=\"Description with \\\"quotes\\\" and \\\\backslash\"")); + } + + // Handle method tests + + @Test + void testHandlePublicPathSetsEmptyContext() throws Exception { + when(ctx.path()).thenReturn("/health"); + + middleware.handle(ctx); + + assertFalse(middleware.getAuthContext().isAuthenticated()); + assertEquals("", middleware.getAuthContext().getUserId()); + } + + @Test + void testHandleAuthDisabledSetsAnonymousContext() throws Exception { + Map disabledConfigMap = new HashMap<>(); + disabledConfigMap.put("auth_disabled", "true"); + disabledConfigMap.put("allowed_scopes", "mcp:read mcp:admin"); + AuthServerConfig disabledConfig = AuthServerConfig.buildFromMap(disabledConfigMap); + OAuthAuthMiddleware disabledMiddleware = new OAuthAuthMiddleware(authClient, disabledConfig); + + when(ctx.path()).thenReturn("/mcp/tools"); + + disabledMiddleware.handle(ctx); + + assertTrue(disabledMiddleware.getAuthContext().isAuthenticated()); + assertEquals("anonymous", disabledMiddleware.getAuthContext().getUserId()); + assertTrue(disabledMiddleware.getAuthContext().hasScope("mcp:read")); + } + + @Test + void testHandleNullAuthClientReturns401() throws Exception { + // When auth is enabled but client is null, should still require auth (triggers OAuth flow) + OAuthAuthMiddleware nullClientMiddleware = new OAuthAuthMiddleware(null, config); + + when(ctx.path()).thenReturn("/mcp/tools"); + when(ctx.header("Authorization")).thenReturn(null); + when(ctx.queryParam("access_token")).thenReturn(null); + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + + nullClientMiddleware.handle(ctx); + + verify(ctx).status(401); + } + + @Test + void testHandleNullAuthClientWithTokenAllowsRequest() throws Exception { + // When auth client is null but token IS provided, allow request (trust OAuth token) + OAuthAuthMiddleware nullClientMiddleware = new OAuthAuthMiddleware(null, config); + + when(ctx.path()).thenReturn("/mcp/tools"); + when(ctx.header("Authorization")).thenReturn("Bearer some-oauth-token"); + + nullClientMiddleware.handle(ctx); + + // Should set anonymous context and NOT return 401 + verify(ctx, never()).status(401); + assertTrue(nullClientMiddleware.getAuthContext().isAuthenticated()); + } + + @Test + void testHandleMissingTokenReturns401() throws Exception { + when(ctx.path()).thenReturn("/mcp/tools"); + when(ctx.header("Authorization")).thenReturn(null); + when(ctx.queryParam("access_token")).thenReturn(null); + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + + middleware.handle(ctx); + + verify(ctx).status(401); + } + + @Test + void testHandleInvalidTokenReturns401() throws Exception { + when(ctx.path()).thenReturn("/mcp/tools"); + when(ctx.header("Authorization")).thenReturn("Bearer invalid-token"); + when(authClient.validateToken("invalid-token", 30)) + .thenReturn(com.gophersecurity.orch.auth.ValidationResult.failure("Token expired")); + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + + middleware.handle(ctx); + + verify(ctx).status(401); + } + + @Test + void testHandleValidTokenSetsProperContext() throws Exception { + when(ctx.path()).thenReturn("/mcp/tools"); + when(ctx.header("Authorization")).thenReturn("Bearer valid-token"); + when(authClient.validateToken("valid-token", 30)) + .thenReturn(com.gophersecurity.orch.auth.ValidationResult.success()); + when(authClient.extractPayload("valid-token")) + .thenReturn(new com.gophersecurity.orch.auth.TokenPayload( + "user123", "mcp:read mcp:admin", "api", 9999999999L)); + + middleware.handle(ctx); + + assertTrue(middleware.getAuthContext().isAuthenticated()); + assertEquals("user123", middleware.getAuthContext().getUserId()); + assertEquals("mcp:read mcp:admin", middleware.getAuthContext().getScopes()); + assertTrue(middleware.hasScope("mcp:read")); + assertTrue(middleware.hasScope("mcp:admin")); + } + + // hasScope helper tests + + @Test + void testHasScopeDelegatesToAuthContext() throws Exception { + when(ctx.path()).thenReturn("/mcp/tools"); + when(ctx.header("Authorization")).thenReturn("Bearer valid-token"); + when(authClient.validateToken("valid-token", 30)) + .thenReturn(com.gophersecurity.orch.auth.ValidationResult.success()); + when(authClient.extractPayload("valid-token")) + .thenReturn(new com.gophersecurity.orch.auth.TokenPayload( + "user123", "read write", "api", 9999999999L)); + + middleware.handle(ctx); + + assertTrue(middleware.hasScope("read")); + assertTrue(middleware.hasScope("write")); + assertFalse(middleware.hasScope("admin")); + } +} diff --git a/examples/auth/src/test/java/com/gophersecurity/mcp/auth/model/JsonRpcTest.java b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/model/JsonRpcTest.java new file mode 100644 index 00000000..523be6db --- /dev/null +++ b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/model/JsonRpcTest.java @@ -0,0 +1,132 @@ +package com.gophersecurity.mcp.auth.model; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for JSON-RPC model classes. + */ +class JsonRpcTest { + + private final ObjectMapper mapper = new ObjectMapper(); + + @Test + void testJsonRpcRequestDeserialization() throws Exception { + String json = """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list", + "params": {"name": "test"} + } + """; + + JsonRpcRequest request = mapper.readValue(json, JsonRpcRequest.class); + + assertEquals("2.0", request.getJsonrpc()); + assertEquals(1, request.getId()); + assertEquals("tools/list", request.getMethod()); + assertEquals("test", request.getParams().get("name")); + } + + @Test + void testJsonRpcRequestMinimal() throws Exception { + String json = """ + { + "jsonrpc": "2.0", + "id": "abc", + "method": "ping" + } + """; + + JsonRpcRequest request = mapper.readValue(json, JsonRpcRequest.class); + + assertEquals("2.0", request.getJsonrpc()); + assertEquals("abc", request.getId()); + assertEquals("ping", request.getMethod()); + assertNull(request.getParams()); + } + + @Test + void testJsonRpcResponseSuccess() throws Exception { + JsonRpcResponse response = JsonRpcResponse.success(1, Map.of("status", "ok")); + String json = mapper.writeValueAsString(response); + + assertTrue(json.contains("\"jsonrpc\":\"2.0\"")); + assertTrue(json.contains("\"id\":1")); + assertTrue(json.contains("\"result\"")); + assertTrue(json.contains("\"status\":\"ok\"")); + assertFalse(json.contains("\"error\"")); + } + + @Test + void testJsonRpcResponseError() throws Exception { + JsonRpcError error = JsonRpcError.methodNotFound("Unknown method: foo"); + JsonRpcResponse response = JsonRpcResponse.error(1, error); + String json = mapper.writeValueAsString(response); + + assertTrue(json.contains("\"jsonrpc\":\"2.0\"")); + assertTrue(json.contains("\"id\":1")); + assertTrue(json.contains("\"error\"")); + assertTrue(json.contains("\"code\":-32601")); + assertTrue(json.contains("\"message\":\"Method not found\"")); + assertFalse(json.contains("\"result\"")); + } + + @Test + void testJsonRpcErrorCodes() { + assertEquals(-32700, JsonRpcError.PARSE_ERROR); + assertEquals(-32600, JsonRpcError.INVALID_REQUEST); + assertEquals(-32601, JsonRpcError.METHOD_NOT_FOUND); + assertEquals(-32602, JsonRpcError.INVALID_PARAMS); + assertEquals(-32603, JsonRpcError.INTERNAL_ERROR); + } + + @Test + void testJsonRpcErrorFactories() { + JsonRpcError parseError = JsonRpcError.parseError("unexpected token"); + assertEquals(-32700, parseError.getCode()); + assertEquals("Parse error", parseError.getMessage()); + assertEquals("unexpected token", parseError.getData()); + + JsonRpcError invalidRequest = JsonRpcError.invalidRequest("missing jsonrpc"); + assertEquals(-32600, invalidRequest.getCode()); + assertEquals("Invalid Request", invalidRequest.getMessage()); + + JsonRpcError methodNotFound = JsonRpcError.methodNotFound("foo/bar"); + assertEquals(-32601, methodNotFound.getCode()); + assertEquals("Method not found", methodNotFound.getMessage()); + + JsonRpcError invalidParams = JsonRpcError.invalidParams("missing name"); + assertEquals(-32602, invalidParams.getCode()); + assertEquals("Invalid params", invalidParams.getMessage()); + + JsonRpcError internalError = JsonRpcError.internalError("database error"); + assertEquals(-32603, internalError.getCode()); + assertEquals("Internal error", internalError.getMessage()); + } + + @Test + void testJsonRpcErrorOmitsNullData() throws Exception { + JsonRpcError error = new JsonRpcError(-32600, "Invalid Request"); + String json = mapper.writeValueAsString(error); + + assertTrue(json.contains("\"code\":-32600")); + assertTrue(json.contains("\"message\":\"Invalid Request\"")); + assertFalse(json.contains("\"data\"")); + } + + @Test + void testJsonRpcResponseNullFieldsOmitted() throws Exception { + JsonRpcResponse response = JsonRpcResponse.success(1, "result"); + String json = mapper.writeValueAsString(response); + + // result should be present, error should be omitted + assertTrue(json.contains("\"result\"")); + assertFalse(json.contains("\"error\"")); + } +} diff --git a/examples/auth/src/test/java/com/gophersecurity/mcp/auth/model/ToolModelTest.java b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/model/ToolModelTest.java new file mode 100644 index 00000000..a598a42d --- /dev/null +++ b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/model/ToolModelTest.java @@ -0,0 +1,110 @@ +package com.gophersecurity.mcp.auth.model; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; + +import java.util.LinkedHashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for Tool model classes. + */ +class ToolModelTest { + + private final ObjectMapper mapper = new ObjectMapper(); + + @Test + void testToolSpecToMap() { + Map schema = new LinkedHashMap<>(); + schema.put("type", "object"); + schema.put("properties", Map.of("name", Map.of("type", "string"))); + + ToolSpec spec = new ToolSpec("greet", "Greet a user", schema); + Map map = spec.toMap(); + + assertEquals("greet", map.get("name")); + assertEquals("Greet a user", map.get("description")); + assertNotNull(map.get("inputSchema")); + } + + @Test + void testToolSpecGetters() { + Map schema = Map.of("type", "object"); + ToolSpec spec = new ToolSpec("test", "Test tool", schema); + + assertEquals("test", spec.getName()); + assertEquals("Test tool", spec.getDescription()); + assertEquals(schema, spec.getInputSchema()); + } + + @Test + void testToolContentText() { + ToolContent content = ToolContent.text("Hello, World!"); + + assertEquals("text", content.getType()); + assertEquals("Hello, World!", content.getText()); + assertNull(content.getData()); + assertNull(content.getMimeType()); + } + + @Test + void testToolContentImage() { + ToolContent content = ToolContent.image("base64data==", "image/png"); + + assertEquals("image", content.getType()); + assertNull(content.getText()); + assertEquals("base64data==", content.getData()); + assertEquals("image/png", content.getMimeType()); + } + + @Test + void testToolContentOmitsNullFields() throws Exception { + ToolContent content = ToolContent.text("Hello"); + String json = mapper.writeValueAsString(content); + + assertTrue(json.contains("\"type\":\"text\"")); + assertTrue(json.contains("\"text\":\"Hello\"")); + assertFalse(json.contains("\"data\"")); + assertFalse(json.contains("\"mimeType\"")); + } + + @Test + void testToolResultText() { + ToolResult result = ToolResult.text("Success!"); + + assertFalse(result.isError()); + assertEquals(1, result.getContent().size()); + assertEquals("text", result.getContent().get(0).getType()); + assertEquals("Success!", result.getContent().get(0).getText()); + } + + @Test + void testToolResultError() { + ToolResult result = ToolResult.error("Something went wrong"); + + assertTrue(result.isError()); + assertEquals(1, result.getContent().size()); + assertEquals("text", result.getContent().get(0).getType()); + assertEquals("Something went wrong", result.getContent().get(0).getText()); + } + + @Test + void testToolResultToMap() { + ToolResult result = ToolResult.text("OK"); + Map map = result.toMap(); + + assertNotNull(map.get("content")); + assertEquals(false, map.get("isError")); + } + + @Test + void testToolResultErrorToMap() { + ToolResult result = ToolResult.error("Error"); + Map map = result.toMap(); + + assertNotNull(map.get("content")); + assertEquals(true, map.get("isError")); + } +} diff --git a/examples/auth/src/test/java/com/gophersecurity/mcp/auth/routes/HealthHandlerTest.java b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/routes/HealthHandlerTest.java new file mode 100644 index 00000000..64e7950a --- /dev/null +++ b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/routes/HealthHandlerTest.java @@ -0,0 +1,103 @@ +package com.gophersecurity.mcp.auth.routes; + +import io.javalin.http.Context; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for HealthHandler. + */ +@ExtendWith(MockitoExtension.class) +class HealthHandlerTest { + + @Mock + private Context ctx; + + @Test + void testResponseContainsStatusOk() { + when(ctx.contentType(anyString())).thenReturn(ctx); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + HealthHandler handler = new HealthHandler(); + handler.handle(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + assertEquals("ok", response.get("status")); + } + + @Test + void testResponseContainsTimestampInISO8601() { + when(ctx.contentType(anyString())).thenReturn(ctx); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + HealthHandler handler = new HealthHandler(); + handler.handle(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + String timestamp = (String) response.get("timestamp"); + assertNotNull(timestamp); + // ISO8601 format contains 'T' separator + assertTrue(timestamp.contains("T")); + // Should contain time zone indicator (Z for UTC) + assertTrue(timestamp.endsWith("Z")); + } + + @Test + void testResponseIncludesVersionWhenProvided() { + when(ctx.contentType(anyString())).thenReturn(ctx); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + HealthHandler handler = new HealthHandler("1.0.0"); + handler.handle(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + assertEquals("1.0.0", response.get("version")); + } + + @Test + void testResponseOmitsVersionWhenNotProvided() { + when(ctx.contentType(anyString())).thenReturn(ctx); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + HealthHandler handler = new HealthHandler(); + handler.handle(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + assertFalse(response.containsKey("version")); + } + + @Test + void testResponseOmitsVersionWhenEmpty() { + when(ctx.contentType(anyString())).thenReturn(ctx); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + HealthHandler handler = new HealthHandler(""); + handler.handle(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + assertFalse(response.containsKey("version")); + } + + @Test + void testSetsContentTypeToJson() { + when(ctx.contentType(anyString())).thenReturn(ctx); + + HealthHandler handler = new HealthHandler(); + handler.handle(ctx); + + verify(ctx).contentType("application/json"); + } +} diff --git a/examples/auth/src/test/java/com/gophersecurity/mcp/auth/routes/McpHandlerTest.java b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/routes/McpHandlerTest.java new file mode 100644 index 00000000..d9c1a59c --- /dev/null +++ b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/routes/McpHandlerTest.java @@ -0,0 +1,244 @@ +package com.gophersecurity.mcp.auth.routes; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.gophersecurity.mcp.auth.config.AuthServerConfig; +import com.gophersecurity.mcp.auth.middleware.OAuthAuthMiddleware; +import com.gophersecurity.mcp.auth.model.JsonRpcResponse; +import com.gophersecurity.mcp.auth.model.ToolResult; +import com.gophersecurity.mcp.auth.model.ToolSpec; +import io.javalin.http.Context; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +/** + * Unit tests for McpHandler. + */ +@ExtendWith(MockitoExtension.class) +class McpHandlerTest { + + @Mock + private Context ctx; + + private McpHandler handler; + private ObjectMapper mapper = new ObjectMapper(); + + @BeforeEach + void setUp() { + AuthServerConfig config = AuthServerConfig.defaultDisabled(); + OAuthAuthMiddleware authMiddleware = new OAuthAuthMiddleware(null, config); + handler = new McpHandler(authMiddleware); + } + + // Parse error tests + + @Test + void testParseErrorResponse() throws Exception { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + when(ctx.body()).thenReturn("invalid json {{{"); + ArgumentCaptor captor = ArgumentCaptor.forClass(JsonRpcResponse.class); + + handler.handle(ctx); + + verify(ctx).json(captor.capture()); + JsonRpcResponse response = captor.getValue(); + assertNotNull(response.getError()); + assertEquals(-32700, response.getError().getCode()); + } + + // Invalid request tests + + @Test + void testInvalidRequestWrongVersion() throws Exception { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + when(ctx.body()).thenReturn("{\"jsonrpc\":\"1.0\",\"id\":1,\"method\":\"ping\"}"); + ArgumentCaptor captor = ArgumentCaptor.forClass(JsonRpcResponse.class); + + handler.handle(ctx); + + verify(ctx).json(captor.capture()); + JsonRpcResponse response = captor.getValue(); + assertNotNull(response.getError()); + assertEquals(-32600, response.getError().getCode()); + } + + // Method not found tests + + @Test + void testMethodNotFound() throws Exception { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + when(ctx.body()).thenReturn("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"unknown/method\"}"); + ArgumentCaptor captor = ArgumentCaptor.forClass(JsonRpcResponse.class); + + handler.handle(ctx); + + verify(ctx).json(captor.capture()); + JsonRpcResponse response = captor.getValue(); + assertNotNull(response.getError()); + assertEquals(-32601, response.getError().getCode()); + } + + // Initialize tests + + @Test + void testInitializeReturnsCorrectStructure() throws Exception { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + when(ctx.body()).thenReturn("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\"}"); + ArgumentCaptor captor = ArgumentCaptor.forClass(JsonRpcResponse.class); + + handler.handle(ctx); + + verify(ctx).json(captor.capture()); + JsonRpcResponse response = captor.getValue(); + assertNull(response.getError()); + + @SuppressWarnings("unchecked") + Map result = (Map) response.getResult(); + assertEquals("2024-11-05", result.get("protocolVersion")); + assertNotNull(result.get("capabilities")); + assertNotNull(result.get("serverInfo")); + + @SuppressWarnings("unchecked") + Map serverInfo = (Map) result.get("serverInfo"); + assertEquals("java-auth-mcp-server", serverInfo.get("name")); + assertEquals("1.0.0", serverInfo.get("version")); + } + + // Ping tests + + @Test + void testPingReturnsEmptyResult() throws Exception { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + when(ctx.body()).thenReturn("{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"ping\"}"); + ArgumentCaptor captor = ArgumentCaptor.forClass(JsonRpcResponse.class); + + handler.handle(ctx); + + verify(ctx).json(captor.capture()); + JsonRpcResponse response = captor.getValue(); + assertNull(response.getError()); + + @SuppressWarnings("unchecked") + Map result = (Map) response.getResult(); + assertTrue(result.isEmpty()); + } + + // Error code constants tests + + @Test + void testErrorCodeConstants() { + assertEquals(-32700, McpHandler.PARSE_ERROR); + assertEquals(-32600, McpHandler.INVALID_REQUEST); + assertEquals(-32601, McpHandler.METHOD_NOT_FOUND); + assertEquals(-32602, McpHandler.INVALID_PARAMS); + assertEquals(-32603, McpHandler.INTERNAL_ERROR); + } + + // Tool registration tests + + @Test + void testRegisterToolAddsToBothMaps() { + ToolSpec spec = new ToolSpec("test-tool", "A test tool", Map.of("type", "object")); + + handler.registerTool("test-tool", spec, (args, c) -> ToolResult.text("OK")); + + assertTrue(handler.getTools().containsKey("test-tool")); + assertTrue(handler.getToolHandlers().containsKey("test-tool")); + assertEquals(spec, handler.getTools().get("test-tool")); + } + + // Tools list tests + + @Test + void testHandleToolsListReturnsRegisteredTools() { + ToolSpec spec = new ToolSpec("my-tool", "My tool", Map.of("type", "object")); + handler.registerTool("my-tool", spec, (args, c) -> ToolResult.text("OK")); + + Map result = handler.handleToolsList(); + + @SuppressWarnings("unchecked") + List> tools = (List>) result.get("tools"); + assertEquals(1, tools.size()); + assertEquals("my-tool", tools.get(0).get("name")); + } + + @Test + void testHandleToolsListReturnsEmptyListWhenNoTools() { + Map result = handler.handleToolsList(); + + @SuppressWarnings("unchecked") + List> tools = (List>) result.get("tools"); + assertTrue(tools.isEmpty()); + } + + // Tools call tests + + @Test + void testHandleToolsCallExecutesCorrectHandler() { + handler.registerTool("greet", new ToolSpec("greet", "Greet", Map.of()), + (args, c) -> ToolResult.text("Hello, " + args.get("name"))); + + Map params = new HashMap<>(); + params.put("name", "greet"); + params.put("arguments", Map.of("name", "World")); + + Map result = handler.handleToolsCall(params, ctx); + + assertFalse((Boolean) result.get("isError")); + } + + @Test + void testHandleToolsCallPassesArgumentsToHandler() { + handler.registerTool("echo", new ToolSpec("echo", "Echo", Map.of()), + (args, c) -> ToolResult.text("Received: " + args.get("message"))); + + Map params = new HashMap<>(); + params.put("name", "echo"); + params.put("arguments", Map.of("message", "test123")); + + Map result = handler.handleToolsCall(params, ctx); + + assertFalse((Boolean) result.get("isError")); + } + + @Test + void testHandleToolsCallReturnsErrorForUnknownTool() { + Map params = new HashMap<>(); + params.put("name", "unknown-tool"); + + Map result = handler.handleToolsCall(params, ctx); + + assertTrue((Boolean) result.get("isError")); + } + + @Test + void testHandleToolsCallReturnsErrorForMissingName() { + Map params = new HashMap<>(); + + Map result = handler.handleToolsCall(params, ctx); + + assertTrue((Boolean) result.get("isError")); + } + + @Test + void testHandleToolsCallReturnsErrorForNullParams() { + Map result = handler.handleToolsCall(null, ctx); + + assertTrue((Boolean) result.get("isError")); + } +} diff --git a/examples/auth/src/test/java/com/gophersecurity/mcp/auth/routes/OAuthEndpointsTest.java b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/routes/OAuthEndpointsTest.java new file mode 100644 index 00000000..b633e748 --- /dev/null +++ b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/routes/OAuthEndpointsTest.java @@ -0,0 +1,419 @@ +package com.gophersecurity.mcp.auth.routes; + +import com.gophersecurity.mcp.auth.config.AuthServerConfig; +import io.javalin.http.Context; +import io.javalin.http.HttpStatus; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.lenient; + +/** + * Unit tests for OAuthEndpoints. + */ +@ExtendWith(MockitoExtension.class) +class OAuthEndpointsTest { + + @Mock + private Context ctx; + + private AuthServerConfig config; + private OAuthEndpoints endpoints; + + @BeforeEach + void setUp() { + Map configMap = new HashMap<>(); + configMap.put("server_url", "http://localhost:3001"); + configMap.put("auth_server_url", "https://auth.example.com"); + configMap.put("allowed_scopes", "mcp:read mcp:admin openid"); + configMap.put("client_id", "test-client-id"); + configMap.put("client_secret", "test-client-secret"); + config = AuthServerConfig.buildFromMap(configMap); + endpoints = new OAuthEndpoints(config); + } + + @Test + void testProtectedResourceMetadataStructure() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + endpoints.protectedResourceMetadata(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + + assertEquals("http://localhost:3001/mcp", response.get("resource")); + assertEquals(List.of("http://localhost:3001"), response.get("authorization_servers")); + assertEquals(List.of("header", "query"), response.get("bearer_methods_supported")); + assertEquals("http://localhost:3001/docs", response.get("resource_documentation")); + } + + @Test + void testProtectedResourceMetadataScopesSplit() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + endpoints.protectedResourceMetadata(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + + @SuppressWarnings("unchecked") + List scopes = (List) response.get("scopes_supported"); + assertEquals(3, scopes.size()); + assertTrue(scopes.contains("mcp:read")); + assertTrue(scopes.contains("mcp:admin")); + assertTrue(scopes.contains("openid")); + } + + @Test + void testProtectedResourceMetadataSetsCorsHeaders() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + + endpoints.protectedResourceMetadata(ctx); + + verify(ctx).header("Access-Control-Allow-Origin", "*"); + } + + @Test + void testProtectedResourceMetadataSetsContentType() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + + endpoints.protectedResourceMetadata(ctx); + + verify(ctx).contentType("application/json"); + } + + @Test + void testSplitScopesHandlesEmptyString() { + List result = endpoints.splitScopes(""); + assertTrue(result.isEmpty()); + } + + @Test + void testSplitScopesHandlesNull() { + List result = endpoints.splitScopes(null); + assertTrue(result.isEmpty()); + } + + @Test + void testSplitScopesHandlesMultipleSpaces() { + List result = endpoints.splitScopes("read write admin"); + assertEquals(3, result.size()); + assertEquals("read", result.get(0)); + assertEquals("write", result.get(1)); + assertEquals("admin", result.get(2)); + } + + // Authorization Server Metadata tests (RFC 8414) + + @Test + void testAuthorizationServerMetadataStructure() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + endpoints.authorizationServerMetadata(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + + assertEquals("https://auth.example.com", response.get("issuer")); + assertEquals("https://auth.example.com/protocol/openid-connect/auth", + response.get("authorization_endpoint")); + assertEquals("https://auth.example.com/protocol/openid-connect/token", + response.get("token_endpoint")); + assertEquals("https://auth.example.com/protocol/openid-connect/certs", + response.get("jwks_uri")); + } + + @Test + void testAuthorizationServerMetadataRequiredFields() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + endpoints.authorizationServerMetadata(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + + // Check all required fields are present + assertTrue(response.containsKey("issuer")); + assertTrue(response.containsKey("authorization_endpoint")); + assertTrue(response.containsKey("token_endpoint")); + assertTrue(response.containsKey("jwks_uri")); + assertTrue(response.containsKey("scopes_supported")); + assertTrue(response.containsKey("response_types_supported")); + assertTrue(response.containsKey("grant_types_supported")); + assertTrue(response.containsKey("token_endpoint_auth_methods_supported")); + } + + @Test + void testAuthorizationServerMetadataResponseTypes() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + endpoints.authorizationServerMetadata(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + + @SuppressWarnings("unchecked") + List responseTypes = (List) response.get("response_types_supported"); + assertTrue(responseTypes.contains("code")); + } + + @Test + void testAuthorizationServerMetadataGrantTypes() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + endpoints.authorizationServerMetadata(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + + @SuppressWarnings("unchecked") + List grantTypes = (List) response.get("grant_types_supported"); + assertTrue(grantTypes.contains("authorization_code")); + assertTrue(grantTypes.contains("refresh_token")); + } + + @Test + void testAuthorizationServerMetadataSetsCorsHeaders() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + + endpoints.authorizationServerMetadata(ctx); + + verify(ctx).header("Access-Control-Allow-Origin", "*"); + } + + // OpenID Configuration tests + + @Test + void testOpenidConfigurationIncludesOidcFields() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + endpoints.openidConfiguration(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + + // OIDC-specific fields + assertTrue(response.containsKey("userinfo_endpoint")); + assertTrue(response.containsKey("id_token_signing_alg_values_supported")); + + assertEquals("https://auth.example.com/protocol/openid-connect/userinfo", + response.get("userinfo_endpoint")); + + @SuppressWarnings("unchecked") + List algs = (List) response.get("id_token_signing_alg_values_supported"); + assertTrue(algs.contains("RS256")); + } + + @Test + void testOpenidConfigurationExtendsAuthServerMetadata() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + endpoints.openidConfiguration(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + + // Should include all auth server metadata fields + assertTrue(response.containsKey("issuer")); + assertTrue(response.containsKey("authorization_endpoint")); + assertTrue(response.containsKey("token_endpoint")); + assertTrue(response.containsKey("jwks_uri")); + assertTrue(response.containsKey("scopes_supported")); + } + + // Authorize endpoint tests + + @Test + void testAuthorizeBuildsCorrectRedirectUrl() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + Map> queryParams = new HashMap<>(); + queryParams.put("response_type", List.of("code")); + queryParams.put("client_id", List.of("test-client")); + queryParams.put("redirect_uri", List.of("http://localhost/callback")); + queryParams.put("state", List.of("abc123")); + when(ctx.queryParamMap()).thenReturn(queryParams); + ArgumentCaptor urlCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(HttpStatus.class); + + endpoints.authorize(ctx); + + verify(ctx).redirect(urlCaptor.capture(), statusCaptor.capture()); + String redirectUrl = urlCaptor.getValue(); + + assertTrue(redirectUrl.startsWith("https://auth.example.com/protocol/openid-connect/auth?")); + assertTrue(redirectUrl.contains("response_type=code")); + assertTrue(redirectUrl.contains("client_id=test-client")); + assertTrue(redirectUrl.contains("state=abc123")); + assertEquals(HttpStatus.FOUND, statusCaptor.getValue()); + } + + @Test + void testAuthorizeUrlEncodesParameters() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + Map> queryParams = new HashMap<>(); + queryParams.put("redirect_uri", List.of("http://localhost/callback?foo=bar")); + queryParams.put("scope", List.of("openid profile")); + when(ctx.queryParamMap()).thenReturn(queryParams); + ArgumentCaptor urlCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(HttpStatus.class); + + endpoints.authorize(ctx); + + verify(ctx).redirect(urlCaptor.capture(), statusCaptor.capture()); + String redirectUrl = urlCaptor.getValue(); + + // URL-encoded characters + assertTrue(redirectUrl.contains("redirect_uri=http%3A%2F%2Flocalhost%2Fcallback%3Ffoo%3Dbar")); + assertTrue(redirectUrl.contains("scope=openid+profile")); + } + + @Test + void testAuthorizeIncludesCodeChallenge() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + Map> queryParams = new HashMap<>(); + queryParams.put("code_challenge", List.of("challenge123")); + queryParams.put("code_challenge_method", List.of("S256")); + when(ctx.queryParamMap()).thenReturn(queryParams); + ArgumentCaptor urlCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(HttpStatus.class); + + endpoints.authorize(ctx); + + verify(ctx).redirect(urlCaptor.capture(), statusCaptor.capture()); + String redirectUrl = urlCaptor.getValue(); + + assertTrue(redirectUrl.contains("code_challenge=challenge123")); + assertTrue(redirectUrl.contains("code_challenge_method=S256")); + } + + // Register endpoint tests + + @Test + void testRegisterReturnsConfiguredCredentials() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + when(ctx.bodyAsClass(Map.class)).thenReturn(Map.of("redirect_uris", List.of("http://localhost/callback"))); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + endpoints.register(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + + String clientId = (String) response.get("client_id"); + String clientSecret = (String) response.get("client_secret"); + + // Returns configured credentials, not random ones + assertEquals("test-client-id", clientId); + assertEquals("test-client-secret", clientSecret); + } + + @Test + void testRegisterReturns201Status() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + when(ctx.bodyAsClass(Map.class)).thenReturn(Map.of()); + + endpoints.register(ctx); + + verify(ctx).status(201); + } + + @Test + void testRegisterIncludesRedirectUris() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + List uris = List.of("http://localhost/callback", "http://example.com/callback"); + when(ctx.bodyAsClass(Map.class)).thenReturn(Map.of("redirect_uris", uris)); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + endpoints.register(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + + assertEquals(uris, response.get("redirect_uris")); + } + + @Test + void testRegisterIncludesRequiredFields() { + when(ctx.header(anyString(), anyString())).thenReturn(ctx); + when(ctx.contentType(anyString())).thenReturn(ctx); + when(ctx.status(anyInt())).thenReturn(ctx); + when(ctx.bodyAsClass(Map.class)).thenReturn(Map.of()); + ArgumentCaptor> captor = ArgumentCaptor.forClass(Map.class); + + endpoints.register(ctx); + + verify(ctx).json(captor.capture()); + Map response = captor.getValue(); + + assertTrue(response.containsKey("client_id")); + assertTrue(response.containsKey("client_secret")); + assertTrue(response.containsKey("client_id_issued_at")); + assertTrue(response.containsKey("client_secret_expires_at")); + assertTrue(response.containsKey("token_endpoint_auth_method")); + assertTrue(response.containsKey("grant_types")); + assertTrue(response.containsKey("response_types")); + } + + // URL encoding tests + + @Test + void testUrlEncodeHandlesNull() { + String result = endpoints.urlEncode(null); + assertEquals("", result); + } + + @Test + void testUrlEncodeEncodesSpecialChars() { + String result = endpoints.urlEncode("hello world&foo=bar"); + assertEquals("hello+world%26foo%3Dbar", result); + } + + // Random string generation tests + + @Test + void testGenerateRandomStringLength() { + String result = endpoints.generateRandomString(16); + assertEquals(16, result.length()); + } + + @Test + void testGenerateRandomStringAlphanumeric() { + String result = endpoints.generateRandomString(100); + assertTrue(result.matches("[A-Za-z0-9]+")); + } +} diff --git a/examples/auth/src/test/java/com/gophersecurity/mcp/auth/tools/WeatherToolsTest.java b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/tools/WeatherToolsTest.java new file mode 100644 index 00000000..d8bcd831 --- /dev/null +++ b/examples/auth/src/test/java/com/gophersecurity/mcp/auth/tools/WeatherToolsTest.java @@ -0,0 +1,192 @@ +package com.gophersecurity.mcp.auth.tools; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.gophersecurity.mcp.auth.config.AuthServerConfig; +import com.gophersecurity.mcp.auth.middleware.OAuthAuthMiddleware; +import com.gophersecurity.mcp.auth.model.ToolResult; +import com.gophersecurity.mcp.auth.routes.McpHandler; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for WeatherTools. + */ +class WeatherToolsTest { + + private ObjectMapper mapper = new ObjectMapper(); + + @BeforeEach + void setUp() { + } + + // get-weather tests + + @Test + void testGetWeatherReturnsDeterministicData() throws Exception { + Map args = Map.of("city", "London"); + + ToolResult result1 = WeatherTools.getWeather(args); + ToolResult result2 = WeatherTools.getWeather(args); + + assertFalse(result1.isError()); + assertEquals(result1.getContent().get(0).getText(), result2.getContent().get(0).getText()); + } + + @Test + void testGetWeatherReturnsValidJson() throws Exception { + Map args = Map.of("city", "Paris"); + + ToolResult result = WeatherTools.getWeather(args); + + assertFalse(result.isError()); + String json = result.getContent().get(0).getText(); + + @SuppressWarnings("unchecked") + Map weather = mapper.readValue(json, Map.class); + assertEquals("Paris", weather.get("city")); + assertNotNull(weather.get("temperature")); + assertNotNull(weather.get("condition")); + assertNotNull(weather.get("humidity")); + assertEquals("celsius", weather.get("unit")); + } + + @Test + void testGetWeatherReturnsErrorForMissingCity() { + Map args = new HashMap<>(); + + ToolResult result = WeatherTools.getWeather(args); + + assertTrue(result.isError()); + } + + // get-forecast tests + + @Test + void testGetForecastReturns5Days() throws Exception { + Map args = Map.of("city", "Tokyo"); + + ToolResult result = WeatherTools.getForecast(args); + + assertFalse(result.isError()); + String json = result.getContent().get(0).getText(); + + @SuppressWarnings("unchecked") + Map forecast = mapper.readValue(json, Map.class); + assertEquals("Tokyo", forecast.get("city")); + + @SuppressWarnings("unchecked") + java.util.List> days = (java.util.List>) forecast.get("forecast"); + assertEquals(5, days.size()); + } + + @Test + void testGetForecastChecksScope() { + // Create middleware with auth enabled + Map configMap = new HashMap<>(); + configMap.put("auth_server_url", "https://auth.example.com"); + configMap.put("client_id", "test"); + configMap.put("client_secret", "secret"); + AuthServerConfig config = AuthServerConfig.buildFromMap(configMap); + OAuthAuthMiddleware authMiddleware = new OAuthAuthMiddleware(null, config); + + // Should require scope when auth is enabled (but authClient is null, so scope check is bypassed) + assertTrue(WeatherTools.requiresScope(authMiddleware, "mcp:read")); + } + + @Test + void testGetForecastDoesNotCheckScopeWhenAuthDisabled() { + AuthServerConfig config = AuthServerConfig.defaultDisabled(); + OAuthAuthMiddleware authMiddleware = new OAuthAuthMiddleware(null, config); + + assertFalse(WeatherTools.requiresScope(authMiddleware, "mcp:read")); + } + + // get-weather-alerts tests + + @Test + void testGetWeatherAlertsReturnsValidJson() throws Exception { + Map args = Map.of("region", "California"); + + ToolResult result = WeatherTools.getWeatherAlerts(args); + + assertFalse(result.isError()); + String json = result.getContent().get(0).getText(); + + @SuppressWarnings("unchecked") + Map alerts = mapper.readValue(json, Map.class); + assertEquals("California", alerts.get("region")); + assertNotNull(alerts.get("alerts")); + } + + @Test + void testGetWeatherAlertsChecksScope() { + Map configMap = new HashMap<>(); + configMap.put("auth_server_url", "https://auth.example.com"); + configMap.put("client_id", "test"); + configMap.put("client_secret", "secret"); + AuthServerConfig config = AuthServerConfig.buildFromMap(configMap); + OAuthAuthMiddleware authMiddleware = new OAuthAuthMiddleware(null, config); + + assertTrue(WeatherTools.requiresScope(authMiddleware, "mcp:admin")); + } + + // getCondition tests + + @Test + void testGetConditionReturnsValidConditions() { + String[] validConditions = {"Sunny", "Cloudy", "Rainy", "Windy", "Snowy"}; + + for (int i = 0; i < 100; i++) { + String condition = WeatherTools.getCondition(i); + boolean found = false; + for (String valid : validConditions) { + if (valid.equals(condition)) { + found = true; + break; + } + } + assertTrue(found, "Invalid condition: " + condition); + } + } + + @Test + void testGetConditionIsDeterministic() { + assertEquals(WeatherTools.getCondition(42), WeatherTools.getCondition(42)); + assertEquals(WeatherTools.getCondition(100), WeatherTools.getCondition(100)); + } + + // accessDenied tests + + @Test + void testAccessDeniedFormat() throws Exception { + ToolResult result = WeatherTools.accessDenied("mcp:admin"); + + assertTrue(result.isError()); + String json = result.getContent().get(0).getText(); + + @SuppressWarnings("unchecked") + Map error = mapper.readValue(json, Map.class); + assertEquals("access_denied", error.get("error")); + assertTrue(error.get("message").contains("mcp:admin")); + } + + // Registration tests + + @Test + void testRegisterAddsAllTools() { + AuthServerConfig config = AuthServerConfig.defaultDisabled(); + OAuthAuthMiddleware authMiddleware = new OAuthAuthMiddleware(null, config); + McpHandler mcp = new McpHandler(authMiddleware); + + WeatherTools.register(mcp, authMiddleware); + + assertTrue(mcp.getTools().containsKey("get-weather")); + assertTrue(mcp.getTools().containsKey("get-forecast")); + assertTrue(mcp.getTools().containsKey("get-weather-alerts")); + } +} diff --git a/pom.xml b/pom.xml index fcc91ac5..a853e0f1 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ com.gophersecurity gopher-orch - 0.1.0 + 0.1.2 jar gopher-orch @@ -74,7 +74,7 @@ - + org.codehaus.mojo build-helper-maven-plugin @@ -95,7 +95,7 @@ - + org.apache.maven.plugins maven-compiler-plugin @@ -103,6 +103,12 @@ 1.8 1.8 + + auth/** + + + auth/** + diff --git a/src/main/java/com/gophersecurity/orch/auth/AuthContext.java b/src/main/java/com/gophersecurity/orch/auth/AuthContext.java new file mode 100644 index 00000000..a2a4039c --- /dev/null +++ b/src/main/java/com/gophersecurity/orch/auth/AuthContext.java @@ -0,0 +1,131 @@ +package com.gophersecurity.orch.auth; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +/** + * Authentication context from JWT token validation. + * + *

Contains user information extracted from a validated token, including user ID, scopes, + * audience, and expiration time. + */ +public class AuthContext { + + private final String userId; + private final String scopes; + private final String audience; + private final long tokenExpiry; + private final boolean authenticated; + private final Set scopeSet; + + /** + * Create an authentication context. + * + * @param userId user identifier from token subject + * @param scopes space-separated list of scopes + * @param audience token audience + * @param tokenExpiry token expiration timestamp (Unix epoch seconds) + * @param authenticated whether the user is authenticated + */ + public AuthContext( + String userId, + String scopes, + String audience, + long tokenExpiry, + boolean authenticated) { + this.userId = userId != null ? userId : ""; + this.scopes = scopes != null ? scopes : ""; + this.audience = audience != null ? audience : ""; + this.tokenExpiry = tokenExpiry; + this.authenticated = authenticated; + + // Pre-compute scope set for efficient lookups + if (this.scopes.isEmpty()) { + this.scopeSet = Collections.emptySet(); + } else { + this.scopeSet = new HashSet<>(Arrays.asList(this.scopes.split("\\s+"))); + } + } + + /** + * Create an empty, unauthenticated context. + * + * @return empty auth context + */ + public static AuthContext empty() { + return new AuthContext("", "", "", 0, false); + } + + /** + * Create an anonymous authenticated context with specified scopes. + * + *

Useful for development mode when authentication is disabled but scope checking is still + * active. + * + * @param scopes space-separated list of scopes + * @return anonymous auth context with given scopes + */ + public static AuthContext anonymous(String scopes) { + return new AuthContext("anonymous", scopes, "", 0, true); + } + + /** + * Check if the context has a specific scope. + * + * @param requiredScope the scope to check for + * @return true if the scope is present + */ + public boolean hasScope(String requiredScope) { + if (requiredScope == null || requiredScope.isEmpty()) { + return true; + } + return scopeSet.contains(requiredScope); + } + + /** + * Get the user identifier. + * + * @return user ID from token subject + */ + public String getUserId() { + return userId; + } + + /** + * Get the space-separated scopes string. + * + * @return scopes string + */ + public String getScopes() { + return scopes; + } + + /** + * Get the token audience. + * + * @return audience string + */ + public String getAudience() { + return audience; + } + + /** + * Get the token expiration timestamp. + * + * @return Unix epoch seconds + */ + public long getTokenExpiry() { + return tokenExpiry; + } + + /** + * Check if the context represents an authenticated user. + * + * @return true if authenticated + */ + public boolean isAuthenticated() { + return authenticated; + } +} diff --git a/src/main/java/com/gophersecurity/orch/auth/GopherAuthClient.java b/src/main/java/com/gophersecurity/orch/auth/GopherAuthClient.java new file mode 100644 index 00000000..df214685 --- /dev/null +++ b/src/main/java/com/gophersecurity/orch/auth/GopherAuthClient.java @@ -0,0 +1,34 @@ +package com.gophersecurity.orch.auth; + +/** + * Interface for JWT token validation using gopher-auth. + * + *

This interface abstracts the FFI calls to the gopher-auth native library. + */ +public interface GopherAuthClient { + + /** + * Validate a JWT token. + * + * @param token JWT token string + * @param clockSkewSeconds allowed clock skew in seconds + * @return validation result + */ + ValidationResult validateToken(String token, int clockSkewSeconds); + + /** + * Extract payload from a JWT token. + * + * @param token JWT token string + * @return extracted payload + * @throws RuntimeException if payload extraction fails + */ + TokenPayload extractPayload(String token); + + /** + * Check if the client is initialized and ready. + * + * @return true if client is ready + */ + boolean isReady(); +} diff --git a/src/main/java/com/gophersecurity/orch/auth/TokenPayload.java b/src/main/java/com/gophersecurity/orch/auth/TokenPayload.java new file mode 100644 index 00000000..a3c2f442 --- /dev/null +++ b/src/main/java/com/gophersecurity/orch/auth/TokenPayload.java @@ -0,0 +1,41 @@ +package com.gophersecurity.orch.auth; + +/** Extracted JWT token payload. */ +public class TokenPayload { + + private final String subject; + private final String scopes; + private final String audience; + private final long expiration; + + /** + * Create a token payload. + * + * @param subject user identifier (sub claim) + * @param scopes space-separated scopes (scope claim) + * @param audience token audience (aud claim) + * @param expiration expiration timestamp (exp claim) + */ + public TokenPayload(String subject, String scopes, String audience, long expiration) { + this.subject = subject; + this.scopes = scopes; + this.audience = audience; + this.expiration = expiration; + } + + public String getSubject() { + return subject; + } + + public String getScopes() { + return scopes; + } + + public String getAudience() { + return audience; + } + + public long getExpiration() { + return expiration; + } +} diff --git a/src/main/java/com/gophersecurity/orch/auth/ValidationResult.java b/src/main/java/com/gophersecurity/orch/auth/ValidationResult.java new file mode 100644 index 00000000..19a22cb8 --- /dev/null +++ b/src/main/java/com/gophersecurity/orch/auth/ValidationResult.java @@ -0,0 +1,46 @@ +package com.gophersecurity.orch.auth; + +/** Result of JWT token validation. */ +public class ValidationResult { + + private final boolean valid; + private final String errorMessage; + + /** + * Create a validation result. + * + * @param valid whether the token is valid + * @param errorMessage error message if invalid, null otherwise + */ + public ValidationResult(boolean valid, String errorMessage) { + this.valid = valid; + this.errorMessage = errorMessage; + } + + /** + * Create a successful validation result. + * + * @return successful result + */ + public static ValidationResult success() { + return new ValidationResult(true, null); + } + + /** + * Create a failed validation result. + * + * @param errorMessage error description + * @return failed result + */ + public static ValidationResult failure(String errorMessage) { + return new ValidationResult(false, errorMessage); + } + + public boolean isValid() { + return valid; + } + + public String getErrorMessage() { + return errorMessage; + } +} diff --git a/src/test/java/com/gophersecurity/orch/auth/AuthContextTest.java b/src/test/java/com/gophersecurity/orch/auth/AuthContextTest.java new file mode 100644 index 00000000..a87a01f4 --- /dev/null +++ b/src/test/java/com/gophersecurity/orch/auth/AuthContextTest.java @@ -0,0 +1,94 @@ +package com.gophersecurity.orch.auth; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +/** Unit tests for AuthContext. */ +class AuthContextTest { + + @Test + void testHasScopePresent() { + AuthContext ctx = new AuthContext("user1", "read write admin", "api", 12345, true); + + assertTrue(ctx.hasScope("read")); + assertTrue(ctx.hasScope("write")); + assertTrue(ctx.hasScope("admin")); + } + + @Test + void testHasScopeAbsent() { + AuthContext ctx = new AuthContext("user1", "read write", "api", 12345, true); + + assertFalse(ctx.hasScope("admin")); + assertFalse(ctx.hasScope("delete")); + } + + @Test + void testHasScopeEmptyScopes() { + AuthContext ctx = new AuthContext("user1", "", "api", 12345, true); + + assertFalse(ctx.hasScope("read")); + assertFalse(ctx.hasScope("admin")); + } + + @Test + void testHasScopeNullScopes() { + AuthContext ctx = new AuthContext("user1", null, "api", 12345, true); + + assertFalse(ctx.hasScope("read")); + assertFalse(ctx.hasScope("admin")); + } + + @Test + void testHasScopeEmptyRequired() { + AuthContext ctx = new AuthContext("user1", "read write", "api", 12345, true); + + // Empty or null required scope should return true + assertTrue(ctx.hasScope("")); + assertTrue(ctx.hasScope(null)); + } + + @Test + void testEmpty() { + AuthContext ctx = AuthContext.empty(); + + assertEquals("", ctx.getUserId()); + assertEquals("", ctx.getScopes()); + assertEquals("", ctx.getAudience()); + assertEquals(0, ctx.getTokenExpiry()); + assertFalse(ctx.isAuthenticated()); + } + + @Test + void testAnonymous() { + AuthContext ctx = AuthContext.anonymous("mcp:read mcp:admin"); + + assertEquals("anonymous", ctx.getUserId()); + assertEquals("mcp:read mcp:admin", ctx.getScopes()); + assertTrue(ctx.isAuthenticated()); + assertTrue(ctx.hasScope("mcp:read")); + assertTrue(ctx.hasScope("mcp:admin")); + } + + @Test + void testGetters() { + AuthContext ctx = new AuthContext("user123", "scope1 scope2", "my-api", 9999999999L, true); + + assertEquals("user123", ctx.getUserId()); + assertEquals("scope1 scope2", ctx.getScopes()); + assertEquals("my-api", ctx.getAudience()); + assertEquals(9999999999L, ctx.getTokenExpiry()); + assertTrue(ctx.isAuthenticated()); + } + + @Test + void testMultipleSpacesBetweenScopes() { + // Test that multiple spaces are handled correctly + AuthContext ctx = new AuthContext("user1", "read write admin", "api", 12345, true); + + assertTrue(ctx.hasScope("read")); + assertTrue(ctx.hasScope("write")); + assertTrue(ctx.hasScope("admin")); + } +} diff --git a/third_party/gopher-orch b/third_party/gopher-orch index 6b45ffbb..c8e7c406 160000 --- a/third_party/gopher-orch +++ b/third_party/gopher-orch @@ -1 +1 @@ -Subproject commit 6b45ffbbee74d5ae034008fc2cb2a927f3131992 +Subproject commit c8e7c40606db330142632ecf90aaa8777bc42a3a