From e53103a2d55a25abafb95eeadf689edc60d193c2 Mon Sep 17 00:00:00 2001 From: Ravi Suhag Date: Sat, 28 Mar 2026 14:35:01 -0500 Subject: [PATCH] feat: add MCP server for AI agent tool access Add Model Context Protocol (MCP) server integration to Compass, enabling AI coding tools to discover and query assets via the /mcp endpoint. Also adds .mcp.json to .gitignore and updates docker-compose postgres port. --- .gitignore | 1 + docker-compose.yaml | 2 +- go.mod | 5 +- go.sum | 10 +- internal/mcp/format.go | 201 +++++++++++++++++++++++ internal/mcp/handlers.go | 127 +++++++++++++++ internal/mcp/handlers_test.go | 298 ++++++++++++++++++++++++++++++++++ internal/mcp/server.go | 57 +++++++ internal/mcp/server_test.go | 229 ++++++++++++++++++++++++++ internal/mcp/tools.go | 77 +++++++++ internal/server/bootstrap.go | 5 + internal/server/server.go | 8 + 12 files changed, 1016 insertions(+), 4 deletions(-) create mode 100644 internal/mcp/format.go create mode 100644 internal/mcp/handlers.go create mode 100644 internal/mcp/handlers_test.go create mode 100644 internal/mcp/server.go create mode 100644 internal/mcp/server_test.go create mode 100644 internal/mcp/tools.go diff --git a/.gitignore b/.gitignore index cc1ee884..e79a4f01 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ vendor/ /config.yaml compass.yaml temp/ +.mcp.json diff --git a/docker-compose.yaml b/docker-compose.yaml index 30ae9053..8ee67e65 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -15,7 +15,7 @@ services: postgres: image: postgres:13 ports: - - 5432:5432 + - 5433:5432 environment: POSTGRES_USER: compass POSTGRES_PASSWORD: compass_password diff --git a/go.mod b/go.mod index e604f03a..8ff15abe 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/jmoiron/sqlx v1.4.0 github.com/lestrrat-go/jwx/v2 v2.1.6 github.com/lib/pq v1.12.0 + github.com/mark3labs/mcp-go v0.46.0 github.com/ory/dockertest/v3 v3.12.0 github.com/peterbourgon/mergemap v0.0.1 github.com/r3labs/diff/v3 
v3.0.2 @@ -135,6 +136,7 @@ require ( github.com/go-playground/validator v9.31.0+incompatible // indirect github.com/go-viper/mapstructure/v2 v2.1.0 // indirect github.com/google/cel-go v0.26.1 // indirect + github.com/google/jsonschema-go v0.4.2 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect @@ -145,12 +147,13 @@ require ( github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect - github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/cast v1.7.1 // indirect github.com/spf13/viper v1.19.0 // indirect github.com/stoewer/go-strcase v1.3.1 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0 // indirect go.opentelemetry.io/otel/metric v1.42.0 // indirect diff --git a/go.sum b/go.sum index b4053782..aa4a3369 100644 --- a/go.sum +++ b/go.sum @@ -139,6 +139,8 @@ github.com/google/cel-go v0.26.1 h1:iPbVVEdkhTX++hpe3lzSk7D3G3QSYqLGoHOcEio+UXQ= github.com/google/cel-go v0.26.1/go.mod h1:A9O8OU9rdvrK5MQyrqfIxo1a0u4g3sF8KB6PUIaryMM= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= +github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod 
h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -201,6 +203,8 @@ github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69 github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo= +github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -290,8 +294,8 @@ github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9yS github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= -github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= -github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= @@ -329,6 +333,8 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo 
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.3/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= diff --git a/internal/mcp/format.go b/internal/mcp/format.go new file mode 100644 index 00000000..1759f2f9 --- /dev/null +++ b/internal/mcp/format.go @@ -0,0 +1,201 @@ +package mcp + +import ( + "fmt" + "strings" + + "github.com/raystack/compass/core/asset" +) + +// formatAsset formats an asset as LLM-friendly markdown text. 
+func formatAsset(a asset.Asset) string { + var b strings.Builder + + fmt.Fprintf(&b, "## %s (%s)\n", a.Name, a.Type) + fmt.Fprintf(&b, "Service: %s | URN: %s\n", a.Service, a.URN) + + if a.Description != "" { + fmt.Fprintf(&b, "Description: %s\n", a.Description) + } + + if len(a.Owners) > 0 { + names := make([]string, 0, len(a.Owners)) + for _, o := range a.Owners { + if o.Email != "" { + names = append(names, o.Email) + } else { + names = append(names, o.UUID) + } + } + fmt.Fprintf(&b, "Owners: %s\n", strings.Join(names, ", ")) + } + + if a.URL != "" { + fmt.Fprintf(&b, "URL: %s\n", a.URL) + } + + if len(a.Labels) > 0 { + pairs := make([]string, 0, len(a.Labels)) + for k, v := range a.Labels { + pairs = append(pairs, fmt.Sprintf("%s=%s", k, v)) + } + fmt.Fprintf(&b, "Labels: %s\n", strings.Join(pairs, ", ")) + } + + formatAssetData(&b, a.Data) + + return b.String() +} + +// formatAssetData formats the Data map, extracting schema columns if present. +func formatAssetData(b *strings.Builder, data map[string]interface{}) { + if data == nil { + return + } + + // Extract schema/columns if present (common in table/topic assets) + if columns, ok := extractColumns(data); ok && len(columns) > 0 { + fmt.Fprintf(b, "\nColumns (%d):\n", len(columns)) + for _, col := range columns { + name, _ := col["name"].(string) + dataType, _ := col["data_type"].(string) + desc, _ := col["description"].(string) + + if desc != "" { + fmt.Fprintf(b, " - %s (%s): %s\n", name, dataType, desc) + } else { + fmt.Fprintf(b, " - %s (%s)\n", name, dataType) + } + } + } +} + +// extractColumns tries to find column definitions in asset data. 
+func extractColumns(data map[string]interface{}) ([]map[string]interface{}, bool) { + // Try common paths: data.columns, data.schema.columns + if cols, ok := data["columns"]; ok { + return toMapSlice(cols) + } + if schema, ok := data["schema"].(map[string]interface{}); ok { + if cols, ok := schema["columns"]; ok { + return toMapSlice(cols) + } + } + return nil, false +} + +func toMapSlice(v interface{}) ([]map[string]interface{}, bool) { + slice, ok := v.([]interface{}) + if !ok { + return nil, false + } + result := make([]map[string]interface{}, 0, len(slice)) + for _, item := range slice { + if m, ok := item.(map[string]interface{}); ok { + result = append(result, m) + } + } + return result, len(result) > 0 +} + +// formatSearchResult formats a search result as a compact line. +func formatSearchResult(sr asset.SearchResult) string { + var b strings.Builder + fmt.Fprintf(&b, "- **%s** (%s) — service: %s, urn: %s", sr.Title, sr.Type, sr.Service, sr.URN) + if sr.Description != "" { + desc := sr.Description + if len(desc) > 120 { + desc = desc[:120] + "..." + } + fmt.Fprintf(&b, "\n %s", desc) + } + return b.String() +} + +// formatSearchResults formats a list of search results. +func formatSearchResults(results []asset.SearchResult) string { + if len(results) == 0 { + return "No assets found." + } + + var b strings.Builder + fmt.Fprintf(&b, "Found %d assets:\n\n", len(results)) + for _, sr := range results { + b.WriteString(formatSearchResult(sr)) + b.WriteString("\n") + } + return b.String() +} + +// formatLineage formats lineage data as readable text. 
+func formatLineage(urn string, lineage asset.Lineage) string { + if len(lineage.Edges) == 0 { + return fmt.Sprintf("No lineage found for %s.", urn) + } + + var b strings.Builder + fmt.Fprintf(&b, "Lineage for %s (%d edges):\n\n", urn, len(lineage.Edges)) + + upstreams := make([]string, 0) + downstreams := make([]string, 0) + + for _, edge := range lineage.Edges { + if edge.Target == urn { + upstreams = append(upstreams, edge.Source) + } else if edge.Source == urn { + downstreams = append(downstreams, edge.Target) + } else { + // Transitive edges + fmt.Fprintf(&b, " %s → %s\n", edge.Source, edge.Target) + } + } + + if len(upstreams) > 0 { + b.WriteString("Upstream (sources):\n") + for _, u := range upstreams { + fmt.Fprintf(&b, " ← %s\n", u) + } + } + + if len(downstreams) > 0 { + b.WriteString("Downstream (consumers):\n") + for _, d := range downstreams { + fmt.Fprintf(&b, " → %s\n", d) + } + } + + return b.String() +} + +// formatTypes formats asset type counts. +func formatTypes(types map[asset.Type]int) string { + if len(types) == 0 { + return "No asset types found." + } + + var b strings.Builder + b.WriteString("Asset types:\n\n") + for t, count := range types { + fmt.Fprintf(&b, "- %s: %d assets\n", t, count) + } + return b.String() +} + +// formatAssets formats a list of assets as a summary list. +func formatAssets(assets []asset.Asset, total uint32) string { + if len(assets) == 0 { + return "No assets found." 
+ } + + var b strings.Builder + if total > 0 { + fmt.Fprintf(&b, "Showing %d of %d assets:\n\n", len(assets), total) + } else { + fmt.Fprintf(&b, "Found %d assets:\n\n", len(assets)) + } + + for _, a := range assets { + fmt.Fprintf(&b, "- **%s** (%s) — service: %s, urn: %s\n", a.Name, a.Type, a.Service, a.URN) + } + return b.String() +} diff --git a/internal/mcp/handlers.go b/internal/mcp/handlers.go new file mode 100644 index 00000000..66aac1f2 --- /dev/null +++ b/internal/mcp/handlers.go @@ -0,0 +1,127 @@ +package mcp + +import ( + "context" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/raystack/compass/core/asset" +) + +func (s *Server) handleSearchAssets(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + text := mcp.ParseString(req, "text", "") + if text == "" { + return mcp.NewToolResultError("'text' parameter is required"), nil + } + + size := mcp.ParseInt(req, "size", 10) + + cfg := asset.SearchConfig{ + Text: strings.TrimSpace(text), + MaxResults: size, + Namespace: s.namespace, + } + + if types := mcp.ParseString(req, "types", ""); types != "" { + cfg.Filters = map[string][]string{ + "type": strings.Split(types, ","), + } + } + if services := mcp.ParseString(req, "services", ""); services != "" { + if cfg.Filters == nil { + cfg.Filters = make(map[string][]string) + } + cfg.Filters["service"] = strings.Split(services, ",") + } + + results, err := s.assetService.SearchAssets(ctx, cfg) + if err != nil { + return mcp.NewToolResultError("failed to search assets: " + err.Error()), nil + } + + return mcp.NewToolResultText(formatSearchResults(results)), nil +} + +func (s *Server) handleGetAsset(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + id := mcp.ParseString(req, "id", "") + if id == "" { + return mcp.NewToolResultError("'id' parameter is required"), nil + } + + a, err := s.assetService.GetAssetByID(ctx, id) + if err != nil { + return mcp.NewToolResultError("failed to get asset: " + 
err.Error()), nil + } + + return mcp.NewToolResultText(formatAsset(a)), nil +} + +func (s *Server) handleGetLineage(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + urn := mcp.ParseString(req, "urn", "") + if urn == "" { + return mcp.NewToolResultError("'urn' parameter is required"), nil + } + + direction := asset.LineageDirection(mcp.ParseString(req, "direction", "")) + if direction != "" && direction != asset.LineageDirectionUpstream && direction != asset.LineageDirectionDownstream { + return mcp.NewToolResultError("'direction' must be 'upstream', 'downstream', or empty for both"), nil + } + + level := mcp.ParseInt(req, "level", 1) + + lineage, err := s.assetService.GetLineage(ctx, urn, asset.LineageQuery{ + Level: level, + Direction: direction, + WithAttributes: true, + }) + if err != nil { + return mcp.NewToolResultError("failed to get lineage: " + err.Error()), nil + } + + return mcp.NewToolResultText(formatLineage(urn, lineage)), nil +} + +func (s *Server) handleListTypes(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + flt, err := asset.NewFilterBuilder().Build() + if err != nil { + return mcp.NewToolResultError("failed to build filter: " + err.Error()), nil + } + + types, err := s.assetService.GetTypes(ctx, flt) + if err != nil { + return mcp.NewToolResultError("failed to list types: " + err.Error()), nil + } + + return mcp.NewToolResultText(formatTypes(types)), nil +} + +func (s *Server) handleGetAllAssets(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + size := mcp.ParseInt(req, "size", 20) + offset := mcp.ParseInt(req, "offset", 0) + + fb := asset.NewFilterBuilder(). + Size(size). 
+ Offset(offset) + + if types := mcp.ParseString(req, "types", ""); types != "" { + fb = fb.Types(types) + } + if services := mcp.ParseString(req, "services", ""); services != "" { + fb = fb.Services(services) + } + if q := mcp.ParseString(req, "q", ""); q != "" { + fb = fb.Q(q) + } + + flt, err := fb.Build() + if err != nil { + return mcp.NewToolResultError("invalid filter: " + err.Error()), nil + } + + assets, total, err := s.assetService.GetAllAssets(ctx, flt, true) + if err != nil { + return mcp.NewToolResultError("failed to get assets: " + err.Error()), nil + } + + return mcp.NewToolResultText(formatAssets(assets, total)), nil +} diff --git a/internal/mcp/handlers_test.go b/internal/mcp/handlers_test.go new file mode 100644 index 00000000..b916dbbe --- /dev/null +++ b/internal/mcp/handlers_test.go @@ -0,0 +1,298 @@ +package mcp + +import ( + "context" + "fmt" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/raystack/compass/core/asset" + "github.com/raystack/compass/core/namespace" + "github.com/raystack/compass/core/user" +) + +// mockAssetService is a test double for the AssetService interface. 
+type mockAssetService struct { + searchAssetsFunc func(ctx context.Context, cfg asset.SearchConfig) ([]asset.SearchResult, error) + getAssetByIDFunc func(ctx context.Context, id string) (asset.Asset, error) + getLineageFunc func(ctx context.Context, urn string, query asset.LineageQuery) (asset.Lineage, error) + getTypesFunc func(ctx context.Context, flt asset.Filter) (map[asset.Type]int, error) + getAllAssetsFunc func(ctx context.Context, flt asset.Filter, withTotal bool) ([]asset.Asset, uint32, error) +} + +func (m *mockAssetService) SearchAssets(ctx context.Context, cfg asset.SearchConfig) ([]asset.SearchResult, error) { + if m.searchAssetsFunc != nil { + return m.searchAssetsFunc(ctx, cfg) + } + return nil, nil +} + +func (m *mockAssetService) GetAssetByID(ctx context.Context, id string) (asset.Asset, error) { + if m.getAssetByIDFunc != nil { + return m.getAssetByIDFunc(ctx, id) + } + return asset.Asset{}, nil +} + +func (m *mockAssetService) GetLineage(ctx context.Context, urn string, query asset.LineageQuery) (asset.Lineage, error) { + if m.getLineageFunc != nil { + return m.getLineageFunc(ctx, urn, query) + } + return asset.Lineage{}, nil +} + +func (m *mockAssetService) GetTypes(ctx context.Context, flt asset.Filter) (map[asset.Type]int, error) { + if m.getTypesFunc != nil { + return m.getTypesFunc(ctx, flt) + } + return nil, nil +} + +func (m *mockAssetService) GetAllAssets(ctx context.Context, flt asset.Filter, withTotal bool) ([]asset.Asset, uint32, error) { + if m.getAllAssetsFunc != nil { + return m.getAllAssetsFunc(ctx, flt, withTotal) + } + return nil, 0, nil +} + +func newTestServer(svc *mockAssetService) *Server { + return New(svc, namespace.DefaultNamespace) +} + +func callToolRequest(args map[string]any) mcp.CallToolRequest { + return mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: args, + }, + } +} + +func TestHandleSearchAssets(t *testing.T) { + ctx := context.Background() + + t.Run("returns error when text is empty", func(t 
*testing.T) { + s := newTestServer(&mockAssetService{}) + result, err := s.handleSearchAssets(ctx, callToolRequest(map[string]any{})) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.IsError { + t.Error("expected error result") + } + }) + + t.Run("returns search results", func(t *testing.T) { + svc := &mockAssetService{ + searchAssetsFunc: func(_ context.Context, cfg asset.SearchConfig) ([]asset.SearchResult, error) { + if cfg.Text != "orders" { + t.Errorf("expected text 'orders', got '%s'", cfg.Text) + } + return []asset.SearchResult{ + {ID: "1", URN: "urn:bq:orders", Title: "orders", Type: "table", Service: "bigquery"}, + }, nil + }, + } + s := newTestServer(svc) + result, err := s.handleSearchAssets(ctx, callToolRequest(map[string]any{"text": "orders"})) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.IsError { + t.Error("unexpected error result") + } + text := getTextContent(result) + if text == "" { + t.Error("expected non-empty text content") + } + }) + + t.Run("passes filters correctly", func(t *testing.T) { + svc := &mockAssetService{ + searchAssetsFunc: func(_ context.Context, cfg asset.SearchConfig) ([]asset.SearchResult, error) { + if cfg.Filters["type"][0] != "table" { + t.Errorf("expected type filter 'table', got %v", cfg.Filters["type"]) + } + if cfg.Filters["service"][0] != "bigquery" { + t.Errorf("expected service filter 'bigquery', got %v", cfg.Filters["service"]) + } + return []asset.SearchResult{}, nil + }, + } + s := newTestServer(svc) + _, err := s.handleSearchAssets(ctx, callToolRequest(map[string]any{ + "text": "test", + "types": "table", + "services": "bigquery", + })) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("returns error on service failure", func(t *testing.T) { + svc := &mockAssetService{ + searchAssetsFunc: func(_ context.Context, _ asset.SearchConfig) ([]asset.SearchResult, error) { + return nil, fmt.Errorf("connection refused") + }, + } + s := 
newTestServer(svc) + result, err := s.handleSearchAssets(ctx, callToolRequest(map[string]any{"text": "test"})) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.IsError { + t.Error("expected error result") + } + }) +} + +func TestHandleGetAsset(t *testing.T) { + ctx := context.Background() + + t.Run("returns error when id is empty", func(t *testing.T) { + s := newTestServer(&mockAssetService{}) + result, err := s.handleGetAsset(ctx, callToolRequest(map[string]any{})) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.IsError { + t.Error("expected error result") + } + }) + + t.Run("returns asset details", func(t *testing.T) { + svc := &mockAssetService{ + getAssetByIDFunc: func(_ context.Context, id string) (asset.Asset, error) { + return asset.Asset{ + ID: "123", + URN: "urn:bq:orders", + Name: "orders", + Type: asset.Type("table"), + Service: "bigquery", + Description: "Main orders table", + Owners: []user.User{{Email: "alice@co.com"}}, + }, nil + }, + } + s := newTestServer(svc) + result, err := s.handleGetAsset(ctx, callToolRequest(map[string]any{"id": "123"})) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + text := getTextContent(result) + if text == "" { + t.Error("expected non-empty text content") + } + }) +} + +func TestHandleGetLineage(t *testing.T) { + ctx := context.Background() + + t.Run("returns error when urn is empty", func(t *testing.T) { + s := newTestServer(&mockAssetService{}) + result, err := s.handleGetLineage(ctx, callToolRequest(map[string]any{})) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.IsError { + t.Error("expected error result") + } + }) + + t.Run("returns lineage", func(t *testing.T) { + svc := &mockAssetService{ + getLineageFunc: func(_ context.Context, urn string, q asset.LineageQuery) (asset.Lineage, error) { + return asset.Lineage{ + Edges: []asset.LineageEdge{ + {Source: "urn:bq:raw_orders", Target: urn}, + {Source: urn, Target: 
"urn:bq:order_summary"}, + }, + }, nil + }, + } + s := newTestServer(svc) + result, err := s.handleGetLineage(ctx, callToolRequest(map[string]any{"urn": "urn:bq:orders"})) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + text := getTextContent(result) + if text == "" { + t.Error("expected non-empty text content") + } + }) + + t.Run("validates direction", func(t *testing.T) { + s := newTestServer(&mockAssetService{}) + result, err := s.handleGetLineage(ctx, callToolRequest(map[string]any{ + "urn": "urn:bq:orders", + "direction": "invalid", + })) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !result.IsError { + t.Error("expected error result for invalid direction") + } + }) +} + +func TestHandleListTypes(t *testing.T) { + ctx := context.Background() + + t.Run("returns types", func(t *testing.T) { + svc := &mockAssetService{ + getTypesFunc: func(_ context.Context, _ asset.Filter) (map[asset.Type]int, error) { + return map[asset.Type]int{ + "table": 42, + "topic": 10, + }, nil + }, + } + s := newTestServer(svc) + result, err := s.handleListTypes(ctx, callToolRequest(nil)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + text := getTextContent(result) + if text == "" { + t.Error("expected non-empty text content") + } + }) +} + +func TestHandleGetAllAssets(t *testing.T) { + ctx := context.Background() + + t.Run("returns assets with pagination", func(t *testing.T) { + svc := &mockAssetService{ + getAllAssetsFunc: func(_ context.Context, flt asset.Filter, withTotal bool) ([]asset.Asset, uint32, error) { + if !withTotal { + t.Error("expected withTotal to be true") + } + return []asset.Asset{ + {ID: "1", URN: "urn:bq:orders", Name: "orders", Type: "table", Service: "bigquery"}, + }, 100, nil + }, + } + s := newTestServer(svc) + result, err := s.handleGetAllAssets(ctx, callToolRequest(map[string]any{"size": 10})) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + text := getTextContent(result) + if text == "" { + 
t.Error("expected non-empty text content") + } + }) +} + +// getTextContent extracts text from the first TextContent in a CallToolResult. +func getTextContent(result *mcp.CallToolResult) string { + for _, c := range result.Content { + if tc, ok := c.(mcp.TextContent); ok { + return tc.Text + } + } + return "" +} diff --git a/internal/mcp/server.go b/internal/mcp/server.go new file mode 100644 index 00000000..0cd55b76 --- /dev/null +++ b/internal/mcp/server.go @@ -0,0 +1,57 @@ +package mcp + +import ( + "context" + "net/http" + + mcpserver "github.com/mark3labs/mcp-go/server" + "github.com/raystack/compass/core/asset" + "github.com/raystack/compass/core/namespace" +) + +// AssetService defines the asset operations needed by the MCP server. +type AssetService interface { + SearchAssets(ctx context.Context, cfg asset.SearchConfig) ([]asset.SearchResult, error) + GetAssetByID(ctx context.Context, id string) (asset.Asset, error) + GetLineage(ctx context.Context, urn string, query asset.LineageQuery) (asset.Lineage, error) + GetTypes(ctx context.Context, flt asset.Filter) (map[asset.Type]int, error) + GetAllAssets(ctx context.Context, flt asset.Filter, withTotal bool) ([]asset.Asset, uint32, error) +} + +// Server is the MCP server that exposes Compass catalog as AI-agent tools. +type Server struct { + assetService AssetService + namespace *namespace.Namespace + mcpServer *mcpserver.MCPServer + httpServer *mcpserver.StreamableHTTPServer +} + +// New creates a new MCP server with the given dependencies. 
+func New(assetSvc AssetService, ns *namespace.Namespace) *Server { + s := &Server{ + assetService: assetSvc, + namespace: ns, + } + + mcpSrv := mcpserver.NewMCPServer( + "compass", + "0.1.0", + mcpserver.WithToolCapabilities(false), + ) + + mcpSrv.AddTool(searchAssetsTool(), s.handleSearchAssets) + mcpSrv.AddTool(getAssetTool(), s.handleGetAsset) + mcpSrv.AddTool(getLineageTool(), s.handleGetLineage) + mcpSrv.AddTool(listTypesTool(), s.handleListTypes) + mcpSrv.AddTool(getAllAssetsTool(), s.handleGetAllAssets) + + s.mcpServer = mcpSrv + s.httpServer = mcpserver.NewStreamableHTTPServer(mcpSrv) + + return s +} + +// Handler returns an http.Handler for mounting the MCP server on an existing mux. +func (s *Server) Handler() http.Handler { + return s.httpServer +} diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go new file mode 100644 index 00000000..112ca107 --- /dev/null +++ b/internal/mcp/server_test.go @@ -0,0 +1,229 @@ +package mcp + +import ( + "context" + "strings" + "testing" + + mcpclient "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" + "github.com/raystack/compass/core/asset" + "github.com/raystack/compass/core/namespace" + "github.com/raystack/compass/core/user" +) + +// TestMCPEndToEnd tests the full MCP flow: initialize → list tools → call tools +// using the in-process transport (no network, but exercises the full mcp-go stack). 
+func TestMCPEndToEnd(t *testing.T) { + svc := &mockAssetService{ + searchAssetsFunc: func(_ context.Context, cfg asset.SearchConfig) ([]asset.SearchResult, error) { + return []asset.SearchResult{ + {ID: "1", URN: "urn:bq:dataset.orders", Title: "orders", Type: "table", Service: "bigquery", Description: "Main orders table"}, + {ID: "2", URN: "urn:bq:dataset.customers", Title: "customers", Type: "table", Service: "bigquery", Description: "Customer records"}, + }, nil + }, + getAssetByIDFunc: func(_ context.Context, id string) (asset.Asset, error) { + return asset.Asset{ + ID: "1", + URN: "urn:bq:dataset.orders", + Name: "orders", + Type: "table", + Service: "bigquery", + Description: "Main orders table with all customer transactions", + Owners: []user.User{{Email: "alice@company.com"}, {Email: "bob@company.com"}}, + Data: map[string]interface{}{ + "columns": []interface{}{ + map[string]interface{}{"name": "order_id", "data_type": "INTEGER", "description": "Primary key"}, + map[string]interface{}{"name": "customer_id", "data_type": "INTEGER", "description": "FK to customers"}, + map[string]interface{}{"name": "amount", "data_type": "FLOAT", "description": "Order total in USD"}, + }, + }, + }, nil + }, + getLineageFunc: func(_ context.Context, urn string, q asset.LineageQuery) (asset.Lineage, error) { + return asset.Lineage{ + Edges: []asset.LineageEdge{ + {Source: "urn:bq:raw_orders", Target: urn}, + {Source: urn, Target: "urn:bq:order_summary"}, + }, + }, nil + }, + getTypesFunc: func(_ context.Context, _ asset.Filter) (map[asset.Type]int, error) { + return map[asset.Type]int{ + "table": 42, + "topic": 15, + "dashboard": 8, + }, nil + }, + getAllAssetsFunc: func(_ context.Context, flt asset.Filter, _ bool) ([]asset.Asset, uint32, error) { + return []asset.Asset{ + {ID: "1", URN: "urn:bq:dataset.orders", Name: "orders", Type: "table", Service: "bigquery"}, + {ID: "2", URN: "urn:bq:dataset.customers", Name: "customers", Type: "table", Service: "bigquery"}, + }, 100, 
nil + }, + } + + srv := New(svc, namespace.DefaultNamespace) + ctx := context.Background() + + // Create in-process MCP client (exercises the full mcp-go protocol stack) + client, err := mcpclient.NewInProcessClient(srv.mcpServer) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + if err := client.Start(ctx); err != nil { + t.Fatalf("failed to start client: %v", err) + } + + // Step 1: Initialize + initResult, err := client.Initialize(ctx, mcp.InitializeRequest{}) + if err != nil { + t.Fatalf("initialize failed: %v", err) + } + t.Logf("Server: %s v%s", initResult.ServerInfo.Name, initResult.ServerInfo.Version) + + if initResult.ServerInfo.Name != "compass" { + t.Errorf("expected server name 'compass', got '%s'", initResult.ServerInfo.Name) + } + + // Step 2: List tools + tools, err := client.ListTools(ctx, mcp.ListToolsRequest{}) + if err != nil { + t.Fatalf("list tools failed: %v", err) + } + t.Logf("Available tools (%d):", len(tools.Tools)) + for _, tool := range tools.Tools { + t.Logf(" - %s: %s", tool.Name, tool.Description) + } + + expectedTools := []string{"search_assets", "get_asset", "get_lineage", "list_types", "get_all_assets"} + if len(tools.Tools) != len(expectedTools) { + t.Fatalf("expected %d tools, got %d", len(expectedTools), len(tools.Tools)) + } + toolNames := make(map[string]bool) + for _, tool := range tools.Tools { + toolNames[tool.Name] = true + } + for _, name := range expectedTools { + if !toolNames[name] { + t.Errorf("missing tool: %s", name) + } + } + + // Step 3: Call search_assets + t.Run("search_assets", func(t *testing.T) { + result, err := client.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "search_assets", + Arguments: map[string]any{"text": "orders"}, + }, + }) + if err != nil { + t.Fatalf("call tool failed: %v", err) + } + text := extractText(result) + t.Logf("Result:\n%s", text) + + if !strings.Contains(text, "orders") { + t.Error("expected result to contain 
'orders'") + } + if !strings.Contains(text, "Found 2 assets") { + t.Error("expected result to mention 2 assets found") + } + }) + + // Step 4: Call get_asset + t.Run("get_asset", func(t *testing.T) { + result, err := client.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "get_asset", + Arguments: map[string]any{"id": "urn:bq:dataset.orders"}, + }, + }) + if err != nil { + t.Fatalf("call tool failed: %v", err) + } + text := extractText(result) + t.Logf("Result:\n%s", text) + + if !strings.Contains(text, "orders") { + t.Error("expected result to contain 'orders'") + } + if !strings.Contains(text, "alice@company.com") { + t.Error("expected result to contain owner email") + } + if !strings.Contains(text, "order_id") { + t.Error("expected result to contain column info") + } + }) + + // Step 5: Call get_lineage + t.Run("get_lineage", func(t *testing.T) { + result, err := client.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "get_lineage", + Arguments: map[string]any{"urn": "urn:bq:dataset.orders"}, + }, + }) + if err != nil { + t.Fatalf("call tool failed: %v", err) + } + text := extractText(result) + t.Logf("Result:\n%s", text) + + if !strings.Contains(text, "raw_orders") { + t.Error("expected result to contain upstream 'raw_orders'") + } + if !strings.Contains(text, "order_summary") { + t.Error("expected result to contain downstream 'order_summary'") + } + }) + + // Step 6: Call list_types + t.Run("list_types", func(t *testing.T) { + result, err := client.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "list_types", + }, + }) + if err != nil { + t.Fatalf("call tool failed: %v", err) + } + text := extractText(result) + t.Logf("Result:\n%s", text) + + if !strings.Contains(text, "table") { + t.Error("expected result to contain 'table'") + } + }) + + // Step 7: Call get_all_assets + t.Run("get_all_assets", func(t *testing.T) { + result, err := client.CallTool(ctx, mcp.CallToolRequest{ + Params: 
mcp.CallToolParams{ + Name: "get_all_assets", + Arguments: map[string]any{"size": 10}, + }, + }) + if err != nil { + t.Fatalf("call tool failed: %v", err) + } + text := extractText(result) + t.Logf("Result:\n%s", text) + + if !strings.Contains(text, "Showing 2 of 100 assets") { + t.Error("expected result to show pagination info") + } + }) +}
+
+func extractText(result *mcp.CallToolResult) string { // extractText returns the Text of the first mcp.TextContent found in result.Content; result is dereferenced without a nil check.
+	for _, c := range result.Content {
+		if tc, ok := c.(mcp.TextContent); ok { // only elements whose dynamic type is mcp.TextContent match; every other element is skipped
+			return tc.Text
+		}
+	}
+	return "" // no element of result.Content was an mcp.TextContent
+}
diff --git a/internal/mcp/tools.go b/internal/mcp/tools.go
new file mode 100644
index 00000000..543d7413
--- /dev/null
+++ b/internal/mcp/tools.go
@@ -0,0 +1,77 @@
+package mcp
+
+import (
+	"github.com/mark3labs/mcp-go/mcp" // in this file the identifier mcp refers to this library package, which shares its name with the enclosing package
+)
+
+func searchAssetsTool() mcp.Tool { // searchAssetsTool returns the mcp.Tool definition named "search_assets".
+	return mcp.NewTool("search_assets",
+		mcp.WithDescription("Search for data assets in the Compass catalog. Returns matching tables, topics, dashboards, and other assets."),
+		mcp.WithString("text", // the only argument of this tool marked mcp.Required()
+			mcp.Required(),
+			mcp.Description("Search query text"),
+		),
+		mcp.WithString("types", // passed as one comma-separated string, not as a list
+			mcp.Description("Comma-separated asset types to filter by (e.g. table,topic,dashboard)"),
+		),
+		mcp.WithString("services", // passed as one comma-separated string, not as a list
+			mcp.Description("Comma-separated services to filter by (e.g. bigquery,kafka)"),
+		),
+		mcp.WithNumber("size", // NOTE(review): the default of 10 is only stated in the description text here - confirm the handler applies it
+			mcp.Description("Maximum number of results to return (default: 10)"),
+		),
+	)
+}
+
+func getAssetTool() mcp.Tool { // getAssetTool returns the mcp.Tool definition named "get_asset".
+	return mcp.NewTool("get_asset",
+		mcp.WithDescription("Get full details of a data asset by its ID (UUID) or URN. Returns schema, owners, description, labels, and metadata."),
+		mcp.WithString("id", // single argument, marked mcp.Required(); the description allows either a UUID or a URN in this one field
+			mcp.Required(),
+			mcp.Description("Asset ID (UUID) or URN"),
+		),
+	)
+}
+
+func getLineageTool() mcp.Tool { // getLineageTool returns the mcp.Tool definition named "get_lineage".
+	return mcp.NewTool("get_lineage",
+		mcp.WithDescription("Get the lineage graph for a data asset. Shows upstream sources and downstream consumers."),
+		mcp.WithString("urn", // the only argument of this tool marked mcp.Required()
+			mcp.Required(),
+			mcp.Description("URN of the asset to get lineage for"),
+		),
+		mcp.WithString("direction", // NOTE(review): the accepted values and the "both" default are only stated in the description - confirm the handler validates them
+			mcp.Description("Lineage direction: upstream, downstream, or both (default: both)"),
+		),
+		mcp.WithNumber("level", // NOTE(review): the default of 1 is only stated in the description text here - confirm the handler applies it
+			mcp.Description("Number of hops to traverse (default: 1)"),
+		),
+	)
+}
+
+func listTypesTool() mcp.Tool { // listTypesTool returns the mcp.Tool definition named "list_types"; it declares no arguments.
+	return mcp.NewTool("list_types",
+		mcp.WithDescription("List all asset types in the catalog with their counts."),
+	)
+}
+
+func getAllAssetsTool() mcp.Tool { // getAllAssetsTool returns the mcp.Tool definition named "get_all_assets"; none of its arguments is marked mcp.Required().
+	return mcp.NewTool("get_all_assets",
+		mcp.WithDescription("Browse and filter data assets in the catalog with pagination."),
+		mcp.WithString("types", // passed as one comma-separated string, not as a list
+			mcp.Description("Comma-separated asset types to filter by (e.g. table,topic)"),
+		),
+		mcp.WithString("services", // passed as one comma-separated string, not as a list
+			mcp.Description("Comma-separated services to filter by (e.g. bigquery,kafka)"),
+		),
+		mcp.WithString("q",
+			mcp.Description("Query string to filter assets by name"),
+		),
+		mcp.WithNumber("size", // NOTE(review): the default of 20 is only stated in the description text here - confirm the handler applies it
+			mcp.Description("Number of results per page (default: 20)"),
+		),
+		mcp.WithNumber("offset", // NOTE(review): the default of 0 is only stated in the description text here - confirm the handler applies it
+			mcp.Description("Offset for pagination (default: 0)"),
+		),
+	)
+}
diff --git a/internal/server/bootstrap.go b/internal/server/bootstrap.go index 029bd076..f9bbfa3b 100644 --- a/internal/server/bootstrap.go +++ b/internal/server/bootstrap.go @@ -15,6 +15,7 @@ import ( "github.com/raystack/compass/core/tag" "github.com/raystack/compass/core/user" "github.com/raystack/compass/internal/config" + compassmcp "github.com/raystack/compass/internal/mcp" "github.com/raystack/compass/internal/telemetry" esStore "github.com/raystack/compass/store/elasticsearch" "github.com/raystack/compass/store/postgres" @@ -115,9 +116,13 @@ func Start(ctx context.Context, cfg *config.Config, version string) error { // init namespace namespaceService := namespace.NewService(postgres.NewNamespaceRepository(pgClient), discoveryRepository) + // init MCP server + mcpServer := 
compassmcp.New(assetService, namespace.DefaultNamespace) + return Serve( ctx, cfg.Service, + mcpServer, namespaceService, assetService, starService, diff --git a/internal/server/server.go b/internal/server/server.go index b094d7af..1b1d5c98 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -16,6 +16,7 @@ import ( "github.com/raystack/compass/gen/raystack/compass/v1beta1/compassv1beta1connect" "github.com/raystack/compass/handler" "github.com/raystack/compass/internal/config" + compassmcp "github.com/raystack/compass/internal/mcp" "github.com/raystack/compass/internal/middleware" "github.com/rs/cors" "golang.org/x/net/http2" @@ -25,6 +26,7 @@ import ( func Serve( ctx context.Context, cfg config.ServerConfig, + mcpServer *compassmcp.Server, namespaceService handler.NamespaceService, assetService handler.AssetService, starService handler.StarService, @@ -88,6 +90,12 @@ func Serve( _, _ = w.Write([]byte("pong")) }) + // MCP server for AI agent tool access + if mcpServer != nil { + mux.Handle("/mcp", mcpServer.Handler()) + logger.InfoContext(ctx, "MCP server enabled at /mcp") + } + // CORS middleware corsHandler := cors.New(cors.Options{ AllowedOrigins: cfg.CORS.AllowedOrigins,