diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c8875e5..ee88466 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -45,6 +45,39 @@ jobs: - name: Run E2E smoke test run: make e2e-smoke-test + - name: Set up Java + uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '11' + + - name: Setup proto files + run: ./scripts/setup-proto-files.sh + + - name: Generate proto descriptors + run: make proto-generate + + - name: Download WireMock + run: make mock-download + + - name: Start WireMock + run: make mock-start + + - name: Run integration tests + run: make test-integration-coverage + + - name: Stop WireMock + if: always() + run: make mock-stop + + - name: Upload WireMock logs on failure + if: failure() + uses: actions/upload-artifact@v4 + with: + name: wiremock-logs + path: wiremock/wiremock.log + if-no-files-found: ignore + - name: Upload test results to Codecov uses: codecov/test-results-action@v1 with: @@ -56,3 +89,14 @@ jobs: files: ./coverage.out token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false + flags: unit + name: unit-tests + + - name: Upload integration test coverage to Codecov + uses: codecov/codecov-action@v5 + with: + file: ./coverage-integration.out + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false + flags: integration + name: integration-tests diff --git a/cmd/stackrox-mcp/main.go b/cmd/stackrox-mcp/main.go index c1739df..3810e82 100644 --- a/cmd/stackrox-mcp/main.go +++ b/cmd/stackrox-mcp/main.go @@ -4,28 +4,12 @@ package main import ( "context" "flag" - "log/slog" - "os" - "os/signal" - "syscall" - "github.com/stackrox/stackrox-mcp/internal/client" + "github.com/stackrox/stackrox-mcp/internal/app" "github.com/stackrox/stackrox-mcp/internal/config" "github.com/stackrox/stackrox-mcp/internal/logging" - "github.com/stackrox/stackrox-mcp/internal/server" - "github.com/stackrox/stackrox-mcp/internal/toolsets" - toolsetConfig "github.com/stackrox/stackrox-mcp/internal/toolsets/config" - toolsetVulnerability "github.com/stackrox/stackrox-mcp/internal/toolsets/vulnerability" ) -// getToolsets initializes and returns all available toolsets. -func getToolsets(cfg *config.Config, c *client.Client) []toolsets.Toolset { - return []toolsets.Toolset{ - toolsetConfig.NewToolset(cfg, c), - toolsetVulnerability.NewToolset(cfg, c), - } -} - func main() { logging.SetupLogging() @@ -38,38 +22,9 @@ func main() { logging.Fatal("Failed to load configuration", err) } - // Log full configuration with sensitive data redacted. - slog.Info("Configuration loaded successfully", "config", cfg.Redacted()) - - stackroxClient, err := client.NewClient(&cfg.Central) - if err != nil { - logging.Fatal("Failed to create StackRox client", err) - } - - registry := toolsets.NewRegistry(cfg, getToolsets(cfg, stackroxClient)) - srv := server.NewServer(cfg, registry) - - // Set up context with signal handling for graceful shutdown. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - err = stackroxClient.Connect(ctx) - if err != nil { - logging.Fatal("Failed to connect to StackRox server", err) - } - - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) - - go func() { - <-sigChan - slog.Info("Received shutdown signal") - cancel() - }() - - slog.Info("Starting StackRox MCP server") + ctx := context.Background() - if err := srv.Start(ctx); err != nil { + if err := app.Run(ctx, cfg, nil, nil); err != nil { logging.Fatal("Server error", err) } } diff --git a/cmd/stackrox-mcp/main_test.go b/cmd/stackrox-mcp/main_test.go index 0202fcb..34bdd04 100644 --- a/cmd/stackrox-mcp/main_test.go +++ b/cmd/stackrox-mcp/main_test.go @@ -14,10 +14,20 @@ import ( "github.com/stackrox/stackrox-mcp/internal/server" "github.com/stackrox/stackrox-mcp/internal/testutil" "github.com/stackrox/stackrox-mcp/internal/toolsets" + toolsetConfig "github.com/stackrox/stackrox-mcp/internal/toolsets/config" + toolsetVulnerability "github.com/stackrox/stackrox-mcp/internal/toolsets/vulnerability" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// getToolsets initializes and returns all available toolsets. +func getToolsets(cfg *config.Config, c *client.Client) []toolsets.Toolset { + return []toolsets.Toolset{ + toolsetConfig.NewToolset(cfg, c), + toolsetVulnerability.NewToolset(cfg, c), + } +} + func TestGetToolsets(t *testing.T) { allToolsets := getToolsets(&config.Config{}, &client.Client{}) @@ -46,7 +56,7 @@ func TestGracefulShutdown(t *testing.T) { errChan := make(chan error, 1) go func() { - errChan <- srv.Start(ctx) + errChan <- srv.Start(ctx, nil, nil) }() serverURL := "http://" + net.JoinHostPort(cfg.Server.Address, strconv.Itoa(cfg.Server.Port)) diff --git a/go.mod b/go.mod index 486f29d..9673b45 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/stretchr/testify v1.11.1 golang.stackrox.io/grpc-http1 v0.5.1 google.golang.org/grpc v1.79.2 + google.golang.org/protobuf v1.36.10 ) require ( @@ -40,7 +41,6 @@ require ( golang.org/x/text v0.32.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect - google.golang.org/protobuf v1.36.10 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/integration/fixtures.go b/integration/fixtures.go new file mode 100644 index 0000000..fcdb5b9 --- /dev/null +++ b/integration/fixtures.go @@ -0,0 +1,29 @@ +//go:build integration + +package integration + +// Log4ShellFixture contains expected data from log4j_cve.json fixture. +var Log4ShellFixture = struct { + CVEName string + DeploymentCount int + DeploymentNames []string +}{ + CVEName: "CVE-2021-44228", + DeploymentCount: 3, + DeploymentNames: []string{"elasticsearch", "kafka-broker", "spring-boot-app"}, +} + +// AllClustersFixture contains expected data from all_clusters.json fixture. +var AllClustersFixture = struct { + TotalCount int + ClusterNames []string +}{ + TotalCount: 5, + ClusterNames: []string{ + "production-cluster", + "staging-cluster", + "staging-central-cluster", + "development-cluster", + "production-cluster-eu", + }, +} diff --git a/integration/integration_test.go b/integration/integration_test.go new file mode 100644 index 0000000..e0a3f9e --- /dev/null +++ b/integration/integration_test.go @@ -0,0 +1,183 @@ +//go:build integration + +package integration + +import ( + "context" + "io" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stackrox/stackrox-mcp/internal/app" + "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stackrox/stackrox-mcp/internal/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// setupInitializedClient creates an initialized MCP client for testing. +func setupInitializedClient(t *testing.T) *testutil.MCPTestClient { + t.Helper() + + client, err := createMCPClient(t) + require.NoError(t, err, "Failed to create MCP client") + t.Cleanup(func() { client.Close() }) + + return client +} + +// callToolAndGetResult calls a tool and verifies it succeeds. +func callToolAndGetResult(t *testing.T, client *testutil.MCPTestClient, toolName string, args map[string]any) *mcp.CallToolResult { + t.Helper() + + ctx := context.Background() + result, err := client.CallTool(ctx, toolName, args) + require.NoError(t, err) + testutil.RequireNoError(t, result) + + return result +} + +// getTextContent extracts text from the first content item. +func getTextContent(t *testing.T, result *mcp.CallToolResult) string { + t.Helper() + require.NotEmpty(t, result.Content, "should have content in response") + + textContent, ok := result.Content[0].(*mcp.TextContent) + require.True(t, ok, "expected TextContent, got %T", result.Content[0]) + + return textContent.Text +} + +// TestIntegration_ListTools verifies that all expected tools are registered. +func TestIntegration_ListTools(t *testing.T) { + client := setupInitializedClient(t) + + ctx := context.Background() + result, err := client.ListTools(ctx) + require.NoError(t, err) + + // Verify we have tools registered + assert.NotEmpty(t, result.Tools, "should have tools registered") + + // Check for specific tools we expect + toolNames := make([]string, 0, len(result.Tools)) + for _, tool := range result.Tools { + toolNames = append(toolNames, tool.Name) + } + + assert.Contains(t, toolNames, "get_deployments_for_cve", "should have get_deployments_for_cve tool") + assert.Contains(t, toolNames, "list_clusters", "should have list_clusters tool") +} + +// TestIntegration_ToolCalls tests successful tool calls using table-driven tests. +func TestIntegration_ToolCalls(t *testing.T) { + tests := map[string]struct { + toolName string + args map[string]any + expectedInText []string // strings that must appear in response + }{ + "get_deployments_for_cve with Log4Shell": { + toolName: "get_deployments_for_cve", + args: map[string]any{"cveName": Log4ShellFixture.CVEName}, + expectedInText: Log4ShellFixture.DeploymentNames, + }, + "get_deployments_for_cve with non-existent CVE": { + toolName: "get_deployments_for_cve", + args: map[string]any{"cveName": "CVE-9999-99999"}, + expectedInText: []string{`"deployments":[]`}, + }, + "list_clusters": { + toolName: "list_clusters", + args: map[string]any{}, + expectedInText: AllClustersFixture.ClusterNames, + }, + "get_clusters_with_orchestrator_cve": { + toolName: "get_clusters_with_orchestrator_cve", + args: map[string]any{"cveName": "CVE-2099-00001"}, + expectedInText: []string{`"clusters":`}, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + client := setupInitializedClient(t) + result := callToolAndGetResult(t, client, tt.toolName, tt.args) + + responseText := getTextContent(t, result) + for _, expected := range tt.expectedInText { + assert.Contains(t, responseText, expected) + } + }) + } +} + +// TestIntegration_ToolCallErrors tests error handling using table-driven tests. +func TestIntegration_ToolCallErrors(t *testing.T) { + tests := map[string]struct { + toolName string + args map[string]any + expectedErrorMsg string + }{ + "get_deployments_for_cve missing CVE name": { + toolName: "get_deployments_for_cve", + args: map[string]any{}, + expectedErrorMsg: "cveName", + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + client := setupInitializedClient(t) + + ctx := context.Background() + _, err := client.CallTool(ctx, tt.toolName, tt.args) + + // Validation errors are returned as protocol errors, not tool errors + require.Error(t, err, "should receive protocol error for invalid params") + assert.Contains(t, err.Error(), tt.expectedErrorMsg) + }) + } +} + +// createTestConfig creates a test configuration for the MCP server. +func createTestConfig() *config.Config { + return &config.Config{ + Central: config.CentralConfig{ + URL: "localhost:8081", + AuthType: "static", + APIToken: "test-token-admin", + InsecureSkipTLSVerify: true, + RequestTimeout: 30 * time.Second, + MaxRetries: 3, + InitialBackoff: time.Second, + MaxBackoff: 10 * time.Second, + }, + Server: config.ServerConfig{ + Type: config.ServerTypeStdio, + }, + Tools: config.ToolsConfig{ + Vulnerability: config.ToolsetVulnerabilityConfig{ + Enabled: true, + }, + ConfigManager: config.ToolConfigManagerConfig{ + Enabled: true, + }, + }, + } +} + +// createMCPClient is a helper function that creates an MCP client with the test configuration. +func createMCPClient(t *testing.T) (*testutil.MCPTestClient, error) { + t.Helper() + + cfg := createTestConfig() + + // Create a run function that wraps app.Run with the config + runFunc := func(ctx context.Context, stdin io.ReadCloser, stdout io.WriteCloser) error { + return app.Run(ctx, cfg, stdin, stdout) + } + + return testutil.NewMCPTestClient(t, runFunc) +} diff --git a/internal/app/app.go b/internal/app/app.go new file mode 100644 index 0000000..567cc0f --- /dev/null +++ b/internal/app/app.go @@ -0,0 +1,67 @@ +// Package app contains the main application logic for the stackrox-mcp server. +// This is separated from the main package to allow tests to run the server in-process. +package app + +import ( + "context" + "io" + "log/slog" + "os" + "os/signal" + "syscall" + + "github.com/pkg/errors" + "github.com/stackrox/stackrox-mcp/internal/client" + "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stackrox/stackrox-mcp/internal/server" + "github.com/stackrox/stackrox-mcp/internal/toolsets" + toolsetConfig "github.com/stackrox/stackrox-mcp/internal/toolsets/config" + toolsetVulnerability "github.com/stackrox/stackrox-mcp/internal/toolsets/vulnerability" +) + +// getToolsets initializes and returns all available toolsets. +func getToolsets(cfg *config.Config, c *client.Client) []toolsets.Toolset { + return []toolsets.Toolset{ + toolsetConfig.NewToolset(cfg, c), + toolsetVulnerability.NewToolset(cfg, c), + } +} + +// Run executes the MCP server with the given configuration and I/O streams. +// If stdin/stdout are nil, os.Stdin/os.Stdout will be used. +// This function is extracted from main() to allow tests to run the server in-process. +func Run(ctx context.Context, cfg *config.Config, stdin io.ReadCloser, stdout io.WriteCloser) error { + // Log full configuration with sensitive data redacted. + slog.Info("Configuration loaded successfully", "config", cfg.Redacted()) + + stackroxClient, err := client.NewClient(&cfg.Central) + if err != nil { + return errors.Wrap(err, "failed to create client") + } + + registry := toolsets.NewRegistry(cfg, getToolsets(cfg, stackroxClient)) + srv := server.NewServer(cfg, registry) + + err = stackroxClient.Connect(ctx) + if err != nil { + return errors.Wrap(err, "failed to connect to central") + } + + // Set up signal handling for graceful shutdown. + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + // Create a cancellable context from the input context + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + <-sigChan + slog.Info("Received shutdown signal") + cancel() + }() + + slog.Info("Starting StackRox MCP server") + + return errors.Wrap(srv.Start(ctx, stdin, stdout), "failed to start server") +} diff --git a/internal/server/server.go b/internal/server/server.go index 5309b86..8740f5f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" + "io" "log/slog" "net" "net/http" @@ -50,13 +51,35 @@ func NewServer(cfg *config.Config, registry *toolsets.Registry) *Server { } // Start starts the HTTP server with Streamable HTTP transport. -func (s *Server) Start(ctx context.Context) error { +// If stdin/stdout are provided (non-nil), they will be used for stdio transport. +// If they are nil, os.Stdin/os.Stdout will be used (production mode). +func (s *Server) Start(ctx context.Context, stdin io.ReadCloser, stdout io.WriteCloser) error { s.registerTools() if s.cfg.Server.Type == config.ServerTypeStdio { - return errors.Wrap(s.mcp.Run(ctx, &mcp.StdioTransport{}), "running mcp over stdio") + return s.startStdio(ctx, stdin, stdout) } + return s.startHTTP(ctx) +} + +func (s *Server) startStdio(ctx context.Context, stdin io.ReadCloser, stdout io.WriteCloser) error { + var transport mcp.Transport + if stdin != nil && stdout != nil { + // Use custom stdin/stdout (for testing) + transport = &mcp.IOTransport{ + Reader: stdin, + Writer: stdout, + } + } else { + // Use os.Stdin/os.Stdout (production) + transport = &mcp.StdioTransport{} + } + + return errors.Wrap(s.mcp.Run(ctx, transport), "running mcp over stdio") +} + +func (s *Server) startHTTP(ctx context.Context) error { // Create a new ServeMux for routing. mux := http.NewServeMux() s.registerRouteHealth(mux) diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 156b22a..2968781 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -133,7 +133,7 @@ func TestServer_Start(t *testing.T) { errChan := make(chan error, 1) go func() { - errChan <- srv.Start(ctx) + errChan <- srv.Start(ctx, nil, nil) }() serverURL := "http://" + net.JoinHostPort(cfg.Server.Address, strconv.Itoa(cfg.Server.Port)) @@ -180,7 +180,7 @@ func TestServer_HealthEndpoint(t *testing.T) { errChan := make(chan error, 1) go func() { - errChan <- srv.Start(ctx) + errChan <- srv.Start(ctx, nil, nil) }() serverURL := "http://" + net.JoinHostPort(cfg.Server.Address, strconv.Itoa(cfg.Server.Port)) diff --git a/internal/testutil/command.go b/internal/testutil/command.go new file mode 100644 index 0000000..45dd6b4 --- /dev/null +++ b/internal/testutil/command.go @@ -0,0 +1,21 @@ +package testutil + +import ( + "context" + "os/exec" + "strings" +) + +// RunCommand executes a shell command and returns the output and error. +func RunCommand(command string) (string, error) { + parts := strings.Fields(command) + if len(parts) == 0 { + return "", nil + } + + // #nosec G204 - This is a test utility function with controlled input + cmd := exec.CommandContext(context.Background(), parts[0], parts[1:]...) + output, err := cmd.CombinedOutput() + + return string(output), err +} diff --git a/internal/testutil/mcp_client.go b/internal/testutil/mcp_client.go new file mode 100644 index 0000000..5f7f828 --- /dev/null +++ b/internal/testutil/mcp_client.go @@ -0,0 +1,146 @@ +package testutil + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// MCPTestClient wraps the official MCP SDK client for testing purposes. +type MCPTestClient struct { + session *mcp.ClientSession + cancel context.CancelFunc + errCh chan error + t *testing.T +} + +// ServerRunFunc is a function that runs the MCP server with the given context and I/O streams. +// This allows tests to inject the server run function without creating circular dependencies. +type ServerRunFunc func(ctx context.Context, stdin io.ReadCloser, stdout io.WriteCloser) error + +// NewMCPTestClient creates a new MCP test client that starts the MCP server in-process with stdio transport. +// The runFunc parameter should be a function that starts the MCP server (typically app.Run wrapped with config). +func NewMCPTestClient(t *testing.T, runFunc ServerRunFunc) (*MCPTestClient, error) { + t.Helper() + + // Create pipes for bidirectional communication + // Server reads from serverStdin, client writes to clientStdout (same pipe) + serverStdin, clientStdout := io.Pipe() + // Server writes to serverStdout, client reads from clientStdin (same pipe) + clientStdin, serverStdout := io.Pipe() + + // Start server in goroutine + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 1) + + go func() { + err := runFunc(ctx, serverStdin, serverStdout) + if err != nil && !errors.Is(err, context.Canceled) { + t.Logf("MCP server error: %v", err) + + errCh <- err + } + }() + + // Create MCP client using official SDK + client := mcp.NewClient( + &mcp.Implementation{ + Name: "mcp-test-client", + Version: "1.0.0", + }, + nil, // No custom options needed for basic testing + ) + + // Create IO transport for client + transport := &mcp.IOTransport{ + Reader: clientStdin, // Client reads from this pipe (server writes) + Writer: clientStdout, // Client writes to this pipe (server reads) + } + + // Connect and initialize + session, err := client.Connect(ctx, transport, &mcp.ClientSessionOptions{}) + if err != nil { + cancel() + + _ = clientStdout.Close() + _ = clientStdin.Close() + + return nil, errors.New("failed to connect to MCP server: " + err.Error()) + } + + return &MCPTestClient{ + session: session, + cancel: cancel, + errCh: errCh, + t: t, + }, nil +} + +// Close stops the MCP server and cleans up resources. +func (c *MCPTestClient) Close() error { + if err := c.session.Close(); err != nil { + c.t.Logf("Error closing session: %v", err) + } + + c.cancel() + + // Wait for server to finish (with timeout) + select { + case err := <-c.errCh: + if err != nil && !errors.Is(err, context.Canceled) { + return err + } + default: + // Server is still running or finished cleanly + } + + return nil +} + +// ListTools returns all available tools from the server. +func (c *MCPTestClient) ListTools(ctx context.Context) (*mcp.ListToolsResult, error) { + result, err := c.session.ListTools(ctx, nil) + if err != nil { + return nil, errors.New("failed to list tools: " + err.Error()) + } + + return result, nil +} + +// CallTool invokes a tool with the given name and arguments. +func (c *MCPTestClient) CallTool( + ctx context.Context, + toolName string, + args map[string]any, +) (*mcp.CallToolResult, error) { + result, err := c.session.CallTool(ctx, &mcp.CallToolParams{ + Name: toolName, + Arguments: args, + }) + if err != nil { + return nil, errors.New("failed to call tool: " + err.Error()) + } + + return result, nil +} + +// RequireNoError asserts that the tool call result does not contain an error. +func RequireNoError(t *testing.T, result *mcp.CallToolResult) { + t.Helper() + + if result.IsError { + // Extract error message from content + errMsg := "unknown error" + + if len(result.Content) > 0 { + if textContent, ok := result.Content[0].(*mcp.TextContent); ok { + errMsg = textContent.Text + } + } + + t.Fatalf("expected no error, got: %s", errMsg) + } +}