diff --git a/config.example.yml b/config.example.yml index c86cd98..bc6a6b5 100644 --- a/config.example.yml +++ b/config.example.yml @@ -58,6 +58,13 @@ infrastructure: host: "" port: 6379 password: "" +# cluster: +# enabled: false +# server_name: "" # defaults to OS hostname +# advertise_url: "" # reachable URL for this agent (e.g. https://my-server:8090) +# health_interval: "30s" +# request_timeout: "10s" + security: enabled: true realtime_capture: false diff --git a/internal/api/cluster_handlers.go b/internal/api/cluster_handlers.go new file mode 100644 index 0000000..9759c94 --- /dev/null +++ b/internal/api/cluster_handlers.go @@ -0,0 +1,342 @@ +package api + +import ( + "crypto/rand" + "database/sql" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/internal/cluster" + "github.com/flatrun/agent/pkg/version" + "github.com/gin-gonic/gin" +) + +func (s *Server) clusterStatus(c *gin.Context) { + if s.clusterManager == nil { + c.JSON(http.StatusOK, gin.H{ + "enabled": false, + }) + return + } + + peers := s.clusterManager.ListPeers() + c.JSON(http.StatusOK, gin.H{ + "enabled": true, + "server_name": s.clusterManager.ServerName(), + "peer_count": len(peers), + "version": version.Get(), + }) +} + +func (s *Server) clusterListPeers(c *gin.Context) { + if s.clusterManager == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cluster is not enabled"}) + return + } + + peers := s.clusterManager.ListPeers() + c.JSON(http.StatusOK, gin.H{"peers": peers}) +} + +func (s *Server) clusterInvite(c *gin.Context) { + if s.clusterManager == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cluster is not enabled"}) + return + } + + actor := auth.GetActorFromContext(c) + if actor == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Not authenticated"}) + return + } + + tokenBytes := make([]byte, 32) + if _, err := rand.Read(tokenBytes); err != nil { + 
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate invite token"}) + return + } + token := base64.URLEncoding.EncodeToString(tokenBytes) + tokenHash := cluster.HashToken(token) + + invite := &cluster.Invite{ + TokenHash: tokenHash, + Status: "pending", + CreatedBy: actor.UserID, + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + if _, err := s.clusterManager.DB().CreateInvite(invite); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create invite"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "invite_token": token, + "expires_at": invite.ExpiresAt, + }) +} + +func (s *Server) clusterAccept(c *gin.Context) { + if s.clusterManager == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cluster is not enabled"}) + return + } + + var req struct { + InviteToken string `json:"invite_token" binding:"required"` + PeerURL string `json:"peer_url" binding:"required"` + CallbackURL string `json:"callback_url"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + callbackURL := req.CallbackURL + if callbackURL == "" { + callbackURL = s.config.Cluster.AdvertiseURL + } + if callbackURL == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "No callback URL available. 
Set cluster.advertise_url in config or pass callback_url in the request.", + }) + return + } + + apiKeyBytes := make([]byte, 32) + if _, err := rand.Read(apiKeyBytes); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate API key"}) + return + } + ourAPIKeyForThem := base64.URLEncoding.EncodeToString(apiKeyBytes) + + exchangeReq := exchangeRequest{ + InviteToken: req.InviteToken, + URL: callbackURL, + APIKey: ourAPIKeyForThem, + Name: s.clusterManager.ServerName(), + } + + body, err := json.Marshal(exchangeReq) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to encode exchange request"}) + return + } + + tempClient := cluster.NewClient(req.PeerURL, "", 10*time.Second) + respData, status, err := tempClient.Post(c.Request.Context(), "/api/cluster/exchange", body) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("Failed to contact peer: %v", err)}) + return + } + if status != http.StatusOK { + c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("Peer rejected exchange: %s", string(respData))}) + return + } + + var exchangeResp exchangeResponse + if err := json.Unmarshal(respData, &exchangeResp); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse peer response"}) + return + } + + if err := s.clusterManager.AddPeer(exchangeResp.Name, req.PeerURL, exchangeResp.APIKey); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to store peer: %v", err)}) + return + } + + if s.authManager != nil { + s.createClusterAPIKey(ourAPIKeyForThem, exchangeResp.Name) + } + + c.JSON(http.StatusOK, gin.H{ + "peer_name": exchangeResp.Name, + "peer_url": req.PeerURL, + "status": "peered", + }) +} + +type exchangeRequest struct { + InviteToken string `json:"invite_token"` + URL string `json:"url"` + APIKey string `json:"api_key"` + Name string `json:"name"` +} + +type exchangeResponse struct { + APIKey string 
`json:"api_key"` + Name string `json:"name"` +} + +func (s *Server) clusterExchange(c *gin.Context) { + if s.clusterManager == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cluster is not enabled"}) + return + } + + var req exchangeRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + tokenHash := cluster.HashToken(req.InviteToken) + invite, err := s.clusterManager.DB().GetInviteByHash(tokenHash) + if err != nil { + if err == sql.ErrNoRows { + c.JSON(http.StatusNotFound, gin.H{"error": "Invalid or expired invite token"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to look up invite"}) + return + } + + if invite.Status != "pending" { + c.JSON(http.StatusConflict, gin.H{"error": "Invite has already been used"}) + return + } + + if time.Now().After(invite.ExpiresAt) { + c.JSON(http.StatusGone, gin.H{"error": "Invite has expired"}) + return + } + + if err := s.clusterManager.DB().ConsumeInvite(tokenHash, req.Name); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to consume invite"}) + return + } + + apiKeyBytes := make([]byte, 32) + if _, err := rand.Read(apiKeyBytes); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate API key"}) + return + } + ourAPIKeyForThem := base64.URLEncoding.EncodeToString(apiKeyBytes) + + if err := s.clusterManager.AddPeer(req.Name, req.URL, req.APIKey); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to store peer: %v", err)}) + return + } + + if s.authManager != nil { + s.createClusterAPIKey(ourAPIKeyForThem, req.Name) + } + + c.JSON(http.StatusOK, exchangeResponse{ + APIKey: ourAPIKeyForThem, + Name: s.clusterManager.ServerName(), + }) +} + +func (s *Server) createClusterAPIKey(rawKey, peerName string) { + if s.authManager == nil { + return + } + _, _ = s.authManager.CreateAPIKeyFromRaw( + rawKey, + 1, + 
fmt.Sprintf("cluster-peer-%s", peerName), + fmt.Sprintf("Auto-generated API key for cluster peer %s", peerName), + auth.RoleAdmin, + nil, + nil, + time.Time{}, + ) +} + +func (s *Server) clusterRemovePeer(c *gin.Context) { + if s.clusterManager == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cluster is not enabled"}) + return + } + + name := c.Param("name") + if err := s.clusterManager.RemovePeer(name); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"status": "removed", "peer": name}) +} + +func (s *Server) clusterProxy(c *gin.Context) { + if s.clusterManager == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cluster is not enabled"}) + return + } + + name := c.Param("name") + path := c.Param("path") + + client, err := s.clusterManager.GetPeer(name) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + var body io.Reader + if c.Request.Body != nil { + body = c.Request.Body + } + + data, status, headers, err := client.Forward(c.Request.Context(), c.Request.Method, "/api"+path, body) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("Failed to proxy request: %v", err)}) + return + } + + for k, v := range headers { + if k != "Content-Length" && k != "Transfer-Encoding" { + c.Header(k, v) + } + } + c.Data(status, "application/json", data) +} + +func (s *Server) clusterAggregateDeployments(c *gin.Context) { + if s.clusterManager == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cluster is not enabled"}) + return + } + + deployments, err := s.manager.ListDeployments() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + localData, err := json.Marshal(deployments) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to marshal local deployments"}) + return + } + + result := cluster.AggregateFromPeers(c.Request.Context(), 
localData, s.clusterManager, "/api/deployments") + c.JSON(http.StatusOK, result) +} + +func (s *Server) clusterAggregateStats(c *gin.Context) { + if s.clusterManager == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cluster is not enabled"}) + return + } + + localStats := gin.H{ + "status": "healthy", + "version": version.Get(), + } + + localData, err := json.Marshal(localStats) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to marshal local stats"}) + return + } + + result := cluster.AggregateFromPeers(c.Request.Context(), localData, s.clusterManager, "/api/health") + c.JSON(http.StatusOK, result) +} diff --git a/internal/api/cluster_handlers_test.go b/internal/api/cluster_handlers_test.go new file mode 100644 index 0000000..77b7e18 --- /dev/null +++ b/internal/api/cluster_handlers_test.go @@ -0,0 +1,621 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/internal/cluster" + "github.com/flatrun/agent/pkg/config" + "github.com/gin-gonic/gin" +) + +type testClusterEnv struct { + server *Server + router *gin.Engine + tmpDir string + cleanup func() +} + +func setupClusterTestServer(t *testing.T, serverName string, clusterEnabled bool) *testClusterEnv { + t.Helper() + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "cluster_api_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + cfg := &config.Config{ + DeploymentsPath: tmpDir, + API: config.APIConfig{ + Host: "127.0.0.1", + Port: 8090, + }, + Auth: config.AuthConfig{ + Enabled: true, + JWTSecret: "test-jwt-secret-for-cluster", + APIKeys: []string{"legacy-test-key"}, + }, + Cluster: config.ClusterConfig{ + Enabled: clusterEnabled, + ServerName: serverName, + HealthInterval: "30s", + RequestTimeout: "5s", + }, + } + + os.Setenv("FLATRUN_ADMIN_PASSWORD", "testadminpass") + + authManager, 
err := auth.NewManager(tmpDir, &cfg.Auth) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create auth manager: %v", err) + } + + var clusterManager *cluster.Manager + if clusterEnabled { + clusterDB, err := cluster.NewDB(tmpDir) + if err != nil { + authManager.Close() + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create cluster DB: %v", err) + } + clusterManager = cluster.NewManager(clusterDB, serverName, 30*time.Second, 5*time.Second, cfg.Auth.JWTSecret) + if err := clusterManager.Start(context.Background()); err != nil { + authManager.Close() + os.RemoveAll(tmpDir) + t.Fatalf("Failed to start cluster manager: %v", err) + } + } + + server := &Server{ + config: cfg, + authManager: authManager, + clusterManager: clusterManager, + } + + router := gin.New() + authMiddleware := auth.NewMiddlewareWithManager(&cfg.Auth, authManager) + + api := router.Group("/api") + api.POST("/auth/login", authMiddleware.Login) + + api.POST("/cluster/exchange", server.clusterExchange) + + protected := api.Group("") + protected.Use(authMiddleware.RequireAuth()) + { + clusterGroup := protected.Group("/cluster") + clusterGroup.Use(authMiddleware.RequirePermission(auth.PermClusterRead)) + { + clusterGroup.GET("/status", server.clusterStatus) + clusterGroup.GET("/peers", server.clusterListPeers) + clusterGroup.POST("/invite", authMiddleware.RequirePermission(auth.PermClusterWrite), server.clusterInvite) + clusterGroup.POST("/accept", authMiddleware.RequirePermission(auth.PermClusterWrite), server.clusterAccept) + clusterGroup.DELETE("/peers/:name", authMiddleware.RequirePermission(auth.PermClusterWrite), server.clusterRemovePeer) + clusterGroup.Any("/peers/:name/proxy/*path", authMiddleware.RequirePermission(auth.PermClusterWrite), server.clusterProxy) + clusterGroup.GET("/deployments", server.clusterAggregateDeployments) + clusterGroup.GET("/stats", server.clusterAggregateStats) + } + } + + cleanup := func() { + if clusterManager != nil { + clusterManager.Stop() + } + 
authManager.Close() + os.RemoveAll(tmpDir) + } + + return &testClusterEnv{ + server: server, + router: router, + tmpDir: tmpDir, + cleanup: cleanup, + } +} + +func clusterLogin(t *testing.T, router *gin.Engine) string { + t.Helper() + body, _ := json.Marshal(map[string]string{ + "username": "admin", + "password": "testadminpass", + }) + + req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Login failed: %d - %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + return resp["token"].(string) +} + +func TestClusterStatusDisabled(t *testing.T) { + env := setupClusterTestServer(t, "disabled-server", false) + defer env.cleanup() + + token := clusterLogin(t, env.router) + + req := httptest.NewRequest(http.MethodGet, "/api/cluster/status", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + if resp["enabled"] != false { + t.Error("Expected enabled=false when cluster is disabled") + } +} + +func TestClusterStatusEnabled(t *testing.T) { + env := setupClusterTestServer(t, "my-server", true) + defer env.cleanup() + + token := clusterLogin(t, env.router) + + req := httptest.NewRequest(http.MethodGet, "/api/cluster/status", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + if resp["enabled"] != true { + t.Error("Expected enabled=true") + } + if 
resp["server_name"] != "my-server" { + t.Errorf("server_name = %v, want my-server", resp["server_name"]) + } + if resp["peer_count"] != float64(0) { + t.Errorf("peer_count = %v, want 0", resp["peer_count"]) + } +} + +func TestClusterListPeersEmpty(t *testing.T) { + env := setupClusterTestServer(t, "server-a", true) + defer env.cleanup() + + token := clusterLogin(t, env.router) + + req := httptest.NewRequest(http.MethodGet, "/api/cluster/peers", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &resp) + peers := resp["peers"] + if peers != nil { + peerList, ok := peers.([]interface{}) + if ok && len(peerList) != 0 { + t.Errorf("Expected empty peers list, got %d peers", len(peerList)) + } + } +} + +func TestClusterListPeersDisabled(t *testing.T) { + env := setupClusterTestServer(t, "disabled", false) + defer env.cleanup() + + token := clusterLogin(t, env.router) + + req := httptest.NewRequest(http.MethodGet, "/api/cluster/peers", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected 400 when cluster disabled, got %d", w.Code) + } +} + +func TestClusterInviteAndExchange(t *testing.T) { + env := setupClusterTestServer(t, "server-a", true) + defer env.cleanup() + + token := clusterLogin(t, env.router) + + // Step 1: Create invite + req := httptest.NewRequest(http.MethodPost, "/api/cluster/invite", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Invite failed: %d - %s", w.Code, w.Body.String()) + } + + var inviteResp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &inviteResp) + + inviteToken, ok 
:= inviteResp["invite_token"].(string) + if !ok || inviteToken == "" { + t.Fatal("Expected non-empty invite_token in response") + } + + if _, ok := inviteResp["expires_at"]; !ok { + t.Error("Expected expires_at in response") + } + + // Step 2: Exchange (simulate Server B calling exchange on Server A) + exchangeBody, _ := json.Marshal(map[string]string{ + "invite_token": inviteToken, + "url": "https://server-b.example.com:8090", + "api_key": "server-b-api-key-for-a", + "name": "server-b", + }) + + req = httptest.NewRequest(http.MethodPost, "/api/cluster/exchange", bytes.NewBuffer(exchangeBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Exchange failed: %d - %s", w.Code, w.Body.String()) + } + + var exchangeResp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &exchangeResp) + + if exchangeResp["name"] != "server-a" { + t.Errorf("Exchange name = %v, want server-a", exchangeResp["name"]) + } + apiKey, ok := exchangeResp["api_key"].(string) + if !ok || apiKey == "" { + t.Error("Expected non-empty api_key in exchange response") + } + + // Step 3: Verify peer was added + req = httptest.NewRequest(http.MethodGet, "/api/cluster/peers", nil) + req.Header.Set("Authorization", "Bearer "+token) + w = httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + var peersResp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &peersResp) + + peers, ok := peersResp["peers"].([]interface{}) + if !ok || len(peers) != 1 { + t.Fatalf("Expected 1 peer after exchange, got %v", peersResp["peers"]) + } + + peer := peers[0].(map[string]interface{}) + if peer["name"] != "server-b" { + t.Errorf("Peer name = %v, want server-b", peer["name"]) + } +} + +func TestClusterExchangeInvalidToken(t *testing.T) { + env := setupClusterTestServer(t, "server-a", true) + defer env.cleanup() + + exchangeBody, _ := json.Marshal(map[string]string{ + "invite_token": "invalid-token", + 
"url": "https://b.example.com", + "api_key": "key", + "name": "server-b", + }) + + req := httptest.NewRequest(http.MethodPost, "/api/cluster/exchange", bytes.NewBuffer(exchangeBody)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("Expected 404 for invalid token, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestClusterExchangeTokenReuse(t *testing.T) { + env := setupClusterTestServer(t, "server-a", true) + defer env.cleanup() + + token := clusterLogin(t, env.router) + + // Create invite + req := httptest.NewRequest(http.MethodPost, "/api/cluster/invite", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + var inviteResp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &inviteResp) + inviteToken := inviteResp["invite_token"].(string) + + // First exchange succeeds + exchangeBody, _ := json.Marshal(map[string]string{ + "invite_token": inviteToken, + "url": "https://b.example.com", + "api_key": "key-b", + "name": "server-b", + }) + req = httptest.NewRequest(http.MethodPost, "/api/cluster/exchange", bytes.NewBuffer(exchangeBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("First exchange should succeed: %d", w.Code) + } + + // Second exchange with same token fails + exchangeBody, _ = json.Marshal(map[string]string{ + "invite_token": inviteToken, + "url": "https://c.example.com", + "api_key": "key-c", + "name": "server-c", + }) + req = httptest.NewRequest(http.MethodPost, "/api/cluster/exchange", bytes.NewBuffer(exchangeBody)) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("Reused token should return 409 Conflict, got %d: %s", w.Code, 
w.Body.String()) + } +} + +func TestClusterRemovePeer(t *testing.T) { + env := setupClusterTestServer(t, "server-a", true) + defer env.cleanup() + + token := clusterLogin(t, env.router) + + // Add a peer via exchange + invReq := httptest.NewRequest(http.MethodPost, "/api/cluster/invite", nil) + invReq.Header.Set("Authorization", "Bearer "+token) + invW := httptest.NewRecorder() + env.router.ServeHTTP(invW, invReq) + + var invResp map[string]interface{} + json.Unmarshal(invW.Body.Bytes(), &invResp) + + exchangeBody, _ := json.Marshal(map[string]string{ + "invite_token": invResp["invite_token"].(string), + "url": "https://b.example.com", + "api_key": "key-b", + "name": "server-b", + }) + exReq := httptest.NewRequest(http.MethodPost, "/api/cluster/exchange", bytes.NewBuffer(exchangeBody)) + exReq.Header.Set("Content-Type", "application/json") + exW := httptest.NewRecorder() + env.router.ServeHTTP(exW, exReq) + + // Remove peer + req := httptest.NewRequest(http.MethodDelete, "/api/cluster/peers/server-b", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify removed + req = httptest.NewRequest(http.MethodGet, "/api/cluster/peers", nil) + req.Header.Set("Authorization", "Bearer "+token) + w = httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + var peersResp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &peersResp) + + if peersResp["peers"] != nil { + peerList, ok := peersResp["peers"].([]interface{}) + if ok && len(peerList) != 0 { + t.Errorf("Expected 0 peers after removal, got %d", len(peerList)) + } + } +} + +func TestClusterProxyForwardsToPeer(t *testing.T) { + peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "proxied_path": 
r.URL.Path, + "proxied_method": r.Method, + }) + })) + defer peerServer.Close() + + env := setupClusterTestServer(t, "primary", true) + defer env.cleanup() + + _ = env.server.clusterManager.AddPeer("remote", peerServer.URL, "key") + + token := clusterLogin(t, env.router) + + req := httptest.NewRequest(http.MethodGet, "/api/cluster/peers/remote/proxy/deployments", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Proxy failed: %d - %s", w.Code, w.Body.String()) + } + + var resp map[string]string + json.Unmarshal(w.Body.Bytes(), &resp) + if resp["proxied_path"] != "/api/deployments" { + t.Errorf("Proxied path = %s, want /api/deployments", resp["proxied_path"]) + } + if resp["proxied_method"] != "GET" { + t.Errorf("Proxied method = %s, want GET", resp["proxied_method"]) + } +} + +func TestClusterProxyUnknownPeer(t *testing.T) { + env := setupClusterTestServer(t, "primary", true) + defer env.cleanup() + + token := clusterLogin(t, env.router) + + req := httptest.NewRequest(http.MethodGet, "/api/cluster/peers/unknown/proxy/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + w := httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("Expected 404 for unknown peer, got %d", w.Code) + } +} + +func TestClusterUnauthorizedAccess(t *testing.T) { + env := setupClusterTestServer(t, "server", true) + defer env.cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/cluster/status", nil) + w := httptest.NewRecorder() + env.router.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Expected 401, got %d", w.Code) + } +} + +func TestClusterFullPeeringE2E(t *testing.T) { + // This test simulates the full peering flow between two servers + // using httptest servers as live HTTP endpoints. 
+ + envA := setupClusterTestServer(t, "server-a", true) + defer envA.cleanup() + + envB := setupClusterTestServer(t, "server-b", true) + defer envB.cleanup() + + // Stand up real HTTP servers so the agents can reach each other + httpServerA := httptest.NewServer(envA.router) + defer httpServerA.Close() + + httpServerB := httptest.NewServer(envB.router) + defer httpServerB.Close() + + tokenA := clusterLogin(t, envA.router) + tokenB := clusterLogin(t, envB.router) + + // Step 1: Server A creates an invite + req := httptest.NewRequest(http.MethodPost, "/api/cluster/invite", nil) + req.Header.Set("Authorization", "Bearer "+tokenA) + w := httptest.NewRecorder() + envA.router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Invite failed: %d - %s", w.Code, w.Body.String()) + } + + var inviteResp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &inviteResp) + inviteToken := inviteResp["invite_token"].(string) + + // Step 2: Server B accepts the invite (contacts Server A via live HTTP) + // callback_url tells A how to reach B back + acceptBody, _ := json.Marshal(map[string]string{ + "invite_token": inviteToken, + "peer_url": httpServerA.URL, + "callback_url": httpServerB.URL, + }) + + req = httptest.NewRequest(http.MethodPost, "/api/cluster/accept", bytes.NewBuffer(acceptBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+tokenB) + w = httptest.NewRecorder() + envB.router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Accept failed: %d - %s", w.Code, w.Body.String()) + } + + var acceptResp map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &acceptResp) + if acceptResp["peer_name"] != "server-a" { + t.Errorf("Expected peer_name=server-a, got %v", acceptResp["peer_name"]) + } + if acceptResp["status"] != "peered" { + t.Errorf("Expected status=peered, got %v", acceptResp["status"]) + } + + // Step 3: Verify Server A sees Server B as a peer + req = httptest.NewRequest(http.MethodGet, 
"/api/cluster/peers", nil) + req.Header.Set("Authorization", "Bearer "+tokenA) + w = httptest.NewRecorder() + envA.router.ServeHTTP(w, req) + + var peersA map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &peersA) + peersListA, _ := peersA["peers"].([]interface{}) + if len(peersListA) != 1 { + t.Fatalf("Server A expected 1 peer, got %d", len(peersListA)) + } + peerOnA := peersListA[0].(map[string]interface{}) + if peerOnA["name"] != "server-b" { + t.Errorf("Server A peer name = %v, want server-b", peerOnA["name"]) + } + + // Step 4: Verify Server B sees Server A as a peer + req = httptest.NewRequest(http.MethodGet, "/api/cluster/peers", nil) + req.Header.Set("Authorization", "Bearer "+tokenB) + w = httptest.NewRecorder() + envB.router.ServeHTTP(w, req) + + var peersB map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &peersB) + peersListB, _ := peersB["peers"].([]interface{}) + if len(peersListB) != 1 { + t.Fatalf("Server B expected 1 peer, got %d", len(peersListB)) + } + peerOnB := peersListB[0].(map[string]interface{}) + if peerOnB["name"] != "server-a" { + t.Errorf("Server B peer name = %v, want server-a", peerOnB["name"]) + } + + // Step 5: Server A removes the peer + req = httptest.NewRequest(http.MethodDelete, "/api/cluster/peers/server-b", nil) + req.Header.Set("Authorization", "Bearer "+tokenA) + w = httptest.NewRecorder() + envA.router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Remove peer failed: %d - %s", w.Code, w.Body.String()) + } + + // Step 6: Verify Server A no longer has the peer + req = httptest.NewRequest(http.MethodGet, "/api/cluster/peers", nil) + req.Header.Set("Authorization", "Bearer "+tokenA) + w = httptest.NewRecorder() + envA.router.ServeHTTP(w, req) + + var afterRemove map[string]interface{} + json.Unmarshal(w.Body.Bytes(), &afterRemove) + if afterRemove["peers"] != nil { + peerList, ok := afterRemove["peers"].([]interface{}) + if ok && len(peerList) != 0 { + t.Errorf("Server A expected 0 peers after 
removal, got %d", len(peerList)) + } + } +} diff --git a/internal/api/server.go b/internal/api/server.go index d0a24fc..e55ffe3 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -18,6 +18,7 @@ import ( "github.com/flatrun/agent/internal/audit" "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/internal/cluster" "github.com/flatrun/agent/internal/backup" "github.com/flatrun/agent/internal/certs" "github.com/flatrun/agent/internal/credentials" @@ -68,6 +69,7 @@ type Server struct { auditManager *audit.Manager auditMiddleware *audit.Middleware powerDNSManager *dns.PowerDNSManager + clusterManager *cluster.Manager } func New(cfg *config.Config, configPath string) *Server { @@ -170,6 +172,28 @@ func New(cfg *config.Config, configPath string) *Server { powerDNSManager := dns.NewPowerDNSManager(cfg) + var clusterManager *cluster.Manager + if cfg.Cluster.Enabled { + clusterDB, clusterErr := cluster.NewDB(cfg.DeploymentsPath) + if clusterErr != nil { + log.Printf("Warning: Failed to initialize cluster database: %v", clusterErr) + } else { + healthInterval, _ := time.ParseDuration(cfg.Cluster.HealthInterval) + if healthInterval == 0 { + healthInterval = 30 * time.Second + } + requestTimeout, _ := time.ParseDuration(cfg.Cluster.RequestTimeout) + if requestTimeout == 0 { + requestTimeout = 10 * time.Second + } + clusterManager = cluster.NewManager(clusterDB, cfg.Cluster.ServerName, healthInterval, requestTimeout, cfg.Auth.JWTSecret) + if startErr := clusterManager.Start(context.Background()); startErr != nil { + log.Printf("Warning: Failed to start cluster manager: %v", startErr) + clusterManager = nil + } + } + } + s := &Server{ config: cfg, configPath: configPath, @@ -192,6 +216,7 @@ func New(cfg *config.Config, configPath string) *Server { auditManager: auditManager, auditMiddleware: auditMiddleware, powerDNSManager: powerDNSManager, + clusterManager: clusterManager, } if backupManager != nil { @@ -492,8 +517,25 @@ func (s *Server) 
setupRoutes() { // PowerDNS routes NewPowerDNSHandlers(s.powerDNSManager).RegisterRoutes(protected) } + + // Cluster endpoints + clusterGroup := protected.Group("/cluster") + clusterGroup.Use(s.authMiddleware.RequirePermission(auth.PermClusterRead)) + { + clusterGroup.GET("/status", s.clusterStatus) + clusterGroup.GET("/peers", s.clusterListPeers) + clusterGroup.POST("/invite", s.authMiddleware.RequirePermission(auth.PermClusterWrite), s.clusterInvite) + clusterGroup.POST("/accept", s.authMiddleware.RequirePermission(auth.PermClusterWrite), s.clusterAccept) + clusterGroup.DELETE("/peers/:name", s.authMiddleware.RequirePermission(auth.PermClusterWrite), s.clusterRemovePeer) + clusterGroup.Any("/peers/:name/proxy/*path", s.authMiddleware.RequirePermission(auth.PermClusterWrite), s.clusterProxy) + clusterGroup.GET("/deployments", s.clusterAggregateDeployments) + clusterGroup.GET("/stats", s.clusterAggregateStats) + } } + // Cluster exchange endpoint (no auth - uses invite token) + api.POST("/cluster/exchange", s.clusterExchange) + // Ingest endpoints (no auth - called by nginx Lua) api.POST("/security/events/ingest", s.ingestSecurityEvent) api.POST("/traffic/ingest", s.ingestTrafficLog) @@ -516,6 +558,9 @@ func (s *Server) Start() error { } func (s *Server) Stop() error { + if s.clusterManager != nil { + s.clusterManager.Stop() + } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() return s.server.Shutdown(ctx) diff --git a/internal/auth/manager.go b/internal/auth/manager.go index 9f4bed6..f5c6b52 100644 --- a/internal/auth/manager.go +++ b/internal/auth/manager.go @@ -1,7 +1,9 @@ package auth import ( + "crypto/rand" "database/sql" + "encoding/hex" "errors" "fmt" "log" @@ -216,6 +218,46 @@ func (m *Manager) CreateAPIKey(userID int64, name, description string, role Role return key, plainKey, nil } +func (m *Manager) CreateAPIKeyFromRaw(rawKey string, userID int64, name, description string, role Role, permissions, deployments 
[]string, expiresAt time.Time) (*APIKey, error) { + keyHash := HashAPIKey(rawKey) + + idBytes := make([]byte, keyIDLength/2) + if _, err := rand.Read(idBytes); err != nil { + return nil, fmt.Errorf("failed to generate key ID: %w", err) + } + keyID := hex.EncodeToString(idBytes) + + var prefix string + if len(rawKey) >= 12 { + prefix = rawKey[:12] + "..." + } else { + prefix = rawKey + "..." + } + + key := &APIKey{ + KeyID: keyID, + UserID: userID, + Name: name, + Description: description, + KeyHash: keyHash, + KeyPrefix: prefix, + Role: role, + Permissions: permissions, + Deployments: deployments, + ExpiresAt: expiresAt, + IsActive: true, + CreatedAt: time.Now(), + } + + id, err := m.db.CreateAPIKey(key) + if err != nil { + return nil, err + } + + key.ID = id + return key, nil +} + func (m *Manager) GetAPIKey(id int64) (*APIKey, error) { key, err := m.db.GetAPIKeyByID(id) if err == sql.ErrNoRows { diff --git a/internal/auth/permissions.go b/internal/auth/permissions.go index a8d1780..7d87153 100644 --- a/internal/auth/permissions.go +++ b/internal/auth/permissions.go @@ -73,6 +73,9 @@ const ( PermTrafficRead Permission = "traffic:read" PermTrafficWrite Permission = "traffic:write" + + PermClusterRead Permission = "cluster:read" + PermClusterWrite Permission = "cluster:write" ) var adminPermissions = []Permission{ @@ -96,6 +99,7 @@ var adminPermissions = []Permission{ PermRegistriesRead, PermRegistriesWrite, PermRegistriesDelete, PermTemplatesRead, PermTemplatesWrite, PermTrafficRead, PermTrafficWrite, + PermClusterRead, PermClusterWrite, } var operatorPermissions = []Permission{ diff --git a/internal/cluster/aggregator.go b/internal/cluster/aggregator.go new file mode 100644 index 0000000..27c9d5c --- /dev/null +++ b/internal/cluster/aggregator.go @@ -0,0 +1,59 @@ +package cluster + +import ( + "context" + "encoding/json" + "fmt" +) + +type ServerResult struct { + Name string `json:"name"` + Online bool `json:"online"` + Data json.RawMessage `json:"data,omitempty"` 
+ Error string `json:"error,omitempty"` +} + +type AggregatedResponse struct { + Servers map[string]ServerResult `json:"servers"` +} + +func AggregateFromPeers(ctx context.Context, localData []byte, mgr *Manager, path string) *AggregatedResponse { + resp := &AggregatedResponse{ + Servers: make(map[string]ServerResult), + } + + resp.Servers[mgr.ServerName()] = ServerResult{ + Name: mgr.ServerName(), + Online: true, + Data: json.RawMessage(localData), + } + + results := mgr.ForEachPeer(ctx, func(ctx context.Context, name string, client *Client) ([]byte, error) { + data, status, err := client.Get(ctx, path) + if err != nil { + return nil, err + } + if status < 200 || status >= 300 { + return nil, fmt.Errorf("peer returned status %d", status) + } + return data, nil + }) + + for name, result := range results { + if result.Error != "" { + resp.Servers[name] = ServerResult{ + Name: name, + Online: false, + Error: result.Error, + } + } else { + resp.Servers[name] = ServerResult{ + Name: name, + Online: true, + Data: json.RawMessage(result.Data), + } + } + } + + return resp +} diff --git a/internal/cluster/aggregator_test.go b/internal/cluster/aggregator_test.go new file mode 100644 index 0000000..f1a7a3f --- /dev/null +++ b/internal/cluster/aggregator_test.go @@ -0,0 +1,173 @@ +package cluster + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" +) + +func TestAggregateFromPeers(t *testing.T) { + peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode([]map[string]string{ + {"name": "remote-app", "status": "running"}, + }) + })) + defer peerServer.Close() + + tmpDir, err := os.MkdirTemp("", "cluster_agg_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + db, err := NewDB(tmpDir) + if err != nil { + t.Fatalf("Failed to create DB: %v", err) + } + defer db.Close() + + mgr := NewManager(db, 
"primary", 30*time.Second, 5*time.Second, "test-secret") + _ = mgr.AddPeer("remote-server", peerServer.URL, "key") + + localData, _ := json.Marshal([]map[string]string{ + {"name": "local-app", "status": "running"}, + }) + + result := AggregateFromPeers(context.Background(), localData, mgr, "/api/deployments") + + if len(result.Servers) != 2 { + t.Fatalf("Expected 2 servers, got %d", len(result.Servers)) + } + + local, ok := result.Servers["primary"] + if !ok { + t.Fatal("Expected primary server in results") + } + if !local.Online { + t.Error("Local server should be online") + } + if local.Data == nil { + t.Error("Local server data should not be nil") + } + + remote, ok := result.Servers["remote-server"] + if !ok { + t.Fatal("Expected remote-server in results") + } + if !remote.Online { + t.Error("Remote server should be online") + } + if remote.Data == nil { + t.Error("Remote server data should not be nil") + } +} + +func TestAggregateFromPeersWithOfflinePeer(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "cluster_agg_offline_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + db, err := NewDB(tmpDir) + if err != nil { + t.Fatalf("Failed to create DB: %v", err) + } + defer db.Close() + + mgr := NewManager(db, "primary", 30*time.Second, 1*time.Second, "test-secret") + _ = mgr.AddPeer("offline-peer", "http://127.0.0.1:1", "key") + + localData, _ := json.Marshal([]string{"local-deployment"}) + + result := AggregateFromPeers(context.Background(), localData, mgr, "/api/deployments") + + if len(result.Servers) != 2 { + t.Fatalf("Expected 2 servers, got %d", len(result.Servers)) + } + + offline, ok := result.Servers["offline-peer"] + if !ok { + t.Fatal("Expected offline-peer in results") + } + if offline.Online { + t.Error("Offline peer should not be online") + } + if offline.Error == "" { + t.Error("Offline peer should have error") + } +} + +func TestAggregateFromPeersNoPeers(t *testing.T) { + tmpDir, err := 
os.MkdirTemp("", "cluster_agg_none_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + db, err := NewDB(tmpDir) + if err != nil { + t.Fatalf("Failed to create DB: %v", err) + } + defer db.Close() + + mgr := NewManager(db, "solo", 30*time.Second, 5*time.Second, "test-secret") + + localData, _ := json.Marshal(map[string]string{"status": "healthy"}) + + result := AggregateFromPeers(context.Background(), localData, mgr, "/api/health") + + if len(result.Servers) != 1 { + t.Fatalf("Expected 1 server (local only), got %d", len(result.Servers)) + } + + local, ok := result.Servers["solo"] + if !ok { + t.Fatal("Expected solo server in results") + } + if !local.Online { + t.Error("Local server should be online") + } +} + +func TestAggregateFromPeersBadStatusCode(t *testing.T) { + peerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error":"internal error"}`)) + })) + defer peerServer.Close() + + tmpDir, err := os.MkdirTemp("", "cluster_agg_bad_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + db, err := NewDB(tmpDir) + if err != nil { + t.Fatalf("Failed to create DB: %v", err) + } + defer db.Close() + + mgr := NewManager(db, "primary", 30*time.Second, 5*time.Second, "test-secret") + _ = mgr.AddPeer("error-peer", peerServer.URL, "key") + + localData, _ := json.Marshal([]string{}) + + result := AggregateFromPeers(context.Background(), localData, mgr, "/api/deployments") + + errorPeer, ok := result.Servers["error-peer"] + if !ok { + t.Fatal("Expected error-peer in results") + } + if errorPeer.Online { + t.Error("Peer that returned 500 should be marked offline") + } + if errorPeer.Error == "" { + t.Error("Peer that returned 500 should have error message") + } +} diff --git a/internal/cluster/client.go b/internal/cluster/client.go new file mode 100644 
index 0000000..82ac4f8 --- /dev/null +++ b/internal/cluster/client.go @@ -0,0 +1,110 @@ +package cluster + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/url" + "time" +) + +type Client struct { + baseURL string + apiKey string + httpClient *http.Client +} + +func NewClient(baseURL, apiKey string, timeout time.Duration) *Client { + return &Client{ + baseURL: baseURL, + apiKey: apiKey, + httpClient: &http.Client{ + Timeout: timeout, + }, + } +} + +func (c *Client) Do(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { + fullURL, err := url.JoinPath(c.baseURL, path) + if err != nil { + return nil, fmt.Errorf("invalid URL path %q: %w", path, err) + } + + req, err := http.NewRequestWithContext(ctx, method, fullURL, body) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("Content-Type", "application/json") + + return c.httpClient.Do(req) +} + +func (c *Client) Health(ctx context.Context) error { + resp, err := c.Do(ctx, "GET", "/api/health", nil) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("health check failed: status %d", resp.StatusCode) + } + return nil +} + +func (c *Client) Get(ctx context.Context, path string) ([]byte, int, error) { + resp, err := c.Do(ctx, "GET", path, nil) + if err != nil { + return nil, 0, err + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, resp.StatusCode, err + } + return data, resp.StatusCode, nil +} + +func (c *Client) Post(ctx context.Context, path string, body []byte) ([]byte, int, error) { + var reader io.Reader + if body != nil { + reader = bytes.NewReader(body) + } + + resp, err := c.Do(ctx, "POST", path, reader) + if err != nil { + return nil, 0, err + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, resp.StatusCode, err + } + return data, 
resp.StatusCode, nil +} + +func (c *Client) Forward(ctx context.Context, method, path string, body io.Reader) ([]byte, int, map[string]string, error) { + resp, err := c.Do(ctx, method, path, body) + if err != nil { + return nil, 0, nil, err + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, resp.StatusCode, nil, err + } + + headers := make(map[string]string) + for key := range resp.Header { + headers[key] = resp.Header.Get(key) + } + + return data, resp.StatusCode, headers, nil +} diff --git a/internal/cluster/client_test.go b/internal/cluster/client_test.go new file mode 100644 index 0000000..21fad7f --- /dev/null +++ b/internal/cluster/client_test.go @@ -0,0 +1,173 @@ +package cluster + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestNewClient(t *testing.T) { + c := NewClient("https://example.com", "test-key", 10*time.Second) + + if c.baseURL != "https://example.com" { + t.Errorf("baseURL = %s, want https://example.com", c.baseURL) + } + if c.apiKey != "test-key" { + t.Errorf("apiKey = %s, want test-key", c.apiKey) + } +} + +func TestClientSetsAuthHeader(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth != "Bearer my-api-key" { + t.Errorf("Authorization = %s, want Bearer my-api-key", auth) + } + ct := r.Header.Get("Content-Type") + if ct != "application/json" { + t.Errorf("Content-Type = %s, want application/json", ct) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + ctx := context.Background() + c := NewClient(server.URL, "my-api-key", 5*time.Second) + resp, err := c.Do(ctx, "GET", "/test", nil) + if err != nil { + t.Fatalf("Do failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Status = %d, want 200", resp.StatusCode) + } +} + +func TestClientHealth(t 
*testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/health" { + t.Errorf("Path = %s, want /api/health", r.URL.Path) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + c := NewClient(server.URL, "key", 5*time.Second) + err := c.Health(context.Background()) + if err != nil { + t.Errorf("Health check failed: %v", err) + } +} + +func TestClientHealthFailure(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer server.Close() + + c := NewClient(server.URL, "key", 5*time.Second) + err := c.Health(context.Background()) + if err == nil { + t.Error("Health should fail for 503 response") + } +} + +func TestClientHealthUnreachable(t *testing.T) { + c := NewClient("http://127.0.0.1:1", "key", 1*time.Second) + err := c.Health(context.Background()) + if err == nil { + t.Error("Health should fail for unreachable server") + } +} + +func TestClientGet(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Errorf("Method = %s, want GET", r.Method) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + })) + defer server.Close() + + c := NewClient(server.URL, "key", 5*time.Second) + data, status, err := c.Get(context.Background(), "/api/test") + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if status != http.StatusOK { + t.Errorf("Status = %d, want 200", status) + } + + var resp map[string]string + if err := json.Unmarshal(data, &resp); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + if resp["status"] != "ok" { + t.Errorf("status = %s, want ok", resp["status"]) + } +} + +func TestClientPost(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { + if r.Method != "POST" { + t.Errorf("Method = %s, want POST", r.Method) + } + body, _ := io.ReadAll(r.Body) + var req map[string]string + json.Unmarshal(body, &req) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"received": req["msg"]}) + })) + defer server.Close() + + c := NewClient(server.URL, "key", 5*time.Second) + body, _ := json.Marshal(map[string]string{"msg": "hello"}) + data, status, err := c.Post(context.Background(), "/api/test", body) + if err != nil { + t.Fatalf("Post failed: %v", err) + } + if status != http.StatusOK { + t.Errorf("Status = %d, want 200", status) + } + + var resp map[string]string + json.Unmarshal(data, &resp) + if resp["received"] != "hello" { + t.Errorf("received = %s, want hello", resp["received"]) + } +} + +func TestClientForward(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Custom-Header", "test-value") + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{"forwarded":true}`)) + })) + defer server.Close() + + c := NewClient(server.URL, "key", 5*time.Second) + data, status, headers, err := c.Forward(context.Background(), "PUT", "/api/resource", strings.NewReader(`{"data":"test"}`)) + if err != nil { + t.Fatalf("Forward failed: %v", err) + } + if status != http.StatusCreated { + t.Errorf("Status = %d, want 201", status) + } + if headers["X-Custom-Header"] != "test-value" { + t.Errorf("X-Custom-Header = %s, want test-value", headers["X-Custom-Header"]) + } + + var resp map[string]bool + json.Unmarshal(data, &resp) + if !resp["forwarded"] { + t.Error("Expected forwarded=true") + } +} diff --git a/internal/cluster/db.go b/internal/cluster/db.go new file mode 100644 index 0000000..8dd1ac3 --- /dev/null +++ b/internal/cluster/db.go @@ -0,0 +1,251 @@ +package cluster + +import ( + "crypto/sha256" + "database/sql" + "encoding/hex" + "os" + "path/filepath" + "sync" + "time" + + _ 
"github.com/mattn/go-sqlite3" +) + +type Peer struct { + ID int64 `json:"id"` + Name string `json:"name"` + URL string `json:"url"` + APIKeyHash string `json:"-"` + APIKeyEncrypted string `json:"-"` + Status string `json:"status"` + CreatedAt time.Time `json:"created_at"` + LastSeenAt time.Time `json:"last_seen_at,omitempty"` +} + +type Invite struct { + ID int64 `json:"id"` + TokenHash string `json:"-"` + Status string `json:"status"` + CreatedBy int64 `json:"created_by"` + AcceptedPeer string `json:"accepted_peer,omitempty"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` +} + +type DB struct { + conn *sql.DB + path string + mu sync.RWMutex +} + +func NewDB(deploymentsPath string) (*DB, error) { + dbDir := filepath.Join(deploymentsPath, ".flatrun") + if err := os.MkdirAll(dbDir, 0755); err != nil { + return nil, err + } + + dbPath := filepath.Join(dbDir, "cluster.db") + conn, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_busy_timeout=5000") + if err != nil { + return nil, err + } + + conn.SetMaxOpenConns(10) + conn.SetMaxIdleConns(5) + conn.SetConnMaxLifetime(time.Hour) + + db := &DB{conn: conn, path: dbPath} + if err := db.migrate(); err != nil { + conn.Close() + return nil, err + } + + return db, nil +} + +func (db *DB) Close() error { + db.mu.Lock() + defer db.mu.Unlock() + return db.conn.Close() +} + +func (db *DB) migrate() error { + schema := ` + CREATE TABLE IF NOT EXISTS peers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + url TEXT NOT NULL, + api_key_hash TEXT NOT NULL, + api_key_encrypted TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'active', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + last_seen_at DATETIME + ); + + CREATE TABLE IF NOT EXISTS invites ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token_hash TEXT UNIQUE NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + created_by INTEGER NOT NULL, + accepted_peer TEXT, + expires_at DATETIME NOT NULL, + created_at DATETIME 
DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_peers_name ON peers(name); + CREATE INDEX IF NOT EXISTS idx_peers_status ON peers(status); + CREATE INDEX IF NOT EXISTS idx_invites_token_hash ON invites(token_hash); + CREATE INDEX IF NOT EXISTS idx_invites_status ON invites(status); + ` + + _, err := db.conn.Exec(schema) + return err +} + +func HashToken(token string) string { + h := sha256.Sum256([]byte(token)) + return hex.EncodeToString(h[:]) +} + +func (db *DB) CreatePeer(peer *Peer) (int64, error) { + db.mu.Lock() + defer db.mu.Unlock() + + result, err := db.conn.Exec(` + INSERT INTO peers (name, url, api_key_hash, api_key_encrypted, status, created_at) + VALUES (?, ?, ?, ?, ?, ?)`, + peer.Name, peer.URL, peer.APIKeyHash, peer.APIKeyEncrypted, peer.Status, time.Now(), + ) + if err != nil { + return 0, err + } + return result.LastInsertId() +} + +func (db *DB) GetPeer(name string) (*Peer, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + var p Peer + var lastSeen sql.NullTime + + err := db.conn.QueryRow(` + SELECT id, name, url, api_key_hash, api_key_encrypted, status, created_at, last_seen_at + FROM peers WHERE name = ?`, name).Scan( + &p.ID, &p.Name, &p.URL, &p.APIKeyHash, &p.APIKeyEncrypted, &p.Status, &p.CreatedAt, &lastSeen, + ) + if err != nil { + return nil, err + } + + if lastSeen.Valid { + p.LastSeenAt = lastSeen.Time + } + return &p, nil +} + +func (db *DB) ListPeers() ([]Peer, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + rows, err := db.conn.Query(` + SELECT id, name, url, api_key_hash, api_key_encrypted, status, created_at, last_seen_at + FROM peers ORDER BY created_at DESC`) + if err != nil { + return nil, err + } + defer rows.Close() + + var peers []Peer + for rows.Next() { + var p Peer + var lastSeen sql.NullTime + + if err := rows.Scan( + &p.ID, &p.Name, &p.URL, &p.APIKeyHash, &p.APIKeyEncrypted, &p.Status, &p.CreatedAt, &lastSeen, + ); err != nil { + return nil, err + } + + if lastSeen.Valid { + p.LastSeenAt = 
lastSeen.Time + } + peers = append(peers, p) + } + return peers, nil +} + +func (db *DB) DeletePeer(name string) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec(`DELETE FROM peers WHERE name = ?`, name) + return err +} + +func (db *DB) UpdateLastSeen(name string) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec(`UPDATE peers SET last_seen_at = ? WHERE name = ?`, time.Now(), name) + return err +} + +func (db *DB) CreateInvite(invite *Invite) (int64, error) { + db.mu.Lock() + defer db.mu.Unlock() + + result, err := db.conn.Exec(` + INSERT INTO invites (token_hash, status, created_by, expires_at, created_at) + VALUES (?, ?, ?, ?, ?)`, + invite.TokenHash, invite.Status, invite.CreatedBy, invite.ExpiresAt, time.Now(), + ) + if err != nil { + return 0, err + } + return result.LastInsertId() +} + +func (db *DB) GetInviteByHash(tokenHash string) (*Invite, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + var inv Invite + var acceptedPeer sql.NullString + + err := db.conn.QueryRow(` + SELECT id, token_hash, status, created_by, accepted_peer, expires_at, created_at + FROM invites WHERE token_hash = ?`, tokenHash).Scan( + &inv.ID, &inv.TokenHash, &inv.Status, &inv.CreatedBy, &acceptedPeer, &inv.ExpiresAt, &inv.CreatedAt, + ) + if err != nil { + return nil, err + } + + inv.AcceptedPeer = acceptedPeer.String + return &inv, nil +} + +func (db *DB) ConsumeInvite(tokenHash string, peerName string) error { + db.mu.Lock() + defer db.mu.Unlock() + + result, err := db.conn.Exec(` + UPDATE invites SET status = 'accepted', accepted_peer = ? + WHERE token_hash = ? 
AND status = 'pending' AND expires_at > ?`, + peerName, tokenHash, time.Now(), + ) + if err != nil { + return err + } + + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return sql.ErrNoRows + } + return nil +} diff --git a/internal/cluster/db_test.go b/internal/cluster/db_test.go new file mode 100644 index 0000000..3e195d8 --- /dev/null +++ b/internal/cluster/db_test.go @@ -0,0 +1,345 @@ +package cluster + +import ( + "database/sql" + "os" + "path/filepath" + "testing" + "time" +) + +func setupTestDB(t *testing.T) (*DB, func()) { + tmpDir, err := os.MkdirTemp("", "cluster_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + db, err := NewDB(tmpDir) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create DB: %v", err) + } + + cleanup := func() { + db.Close() + os.RemoveAll(tmpDir) + } + + return db, cleanup +} + +func TestNewDB(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + if db == nil { + t.Fatal("NewDB returned nil") + } + if db.conn == nil { + t.Fatal("DB connection is nil") + } +} + +func TestDBPath(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "cluster_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + db, err := NewDB(tmpDir) + if err != nil { + t.Fatalf("Failed to create DB: %v", err) + } + defer db.Close() + + expectedPath := filepath.Join(tmpDir, ".flatrun", "cluster.db") + if db.path != expectedPath { + t.Errorf("DB path = %s, want %s", db.path, expectedPath) + } + + if _, err := os.Stat(expectedPath); os.IsNotExist(err) { + t.Error("Database file was not created") + } +} + +func TestCreateAndGetPeer(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + peer := &Peer{ + Name: "hetzner-1", + URL: "https://hetzner-1.example.com:8090", + APIKeyHash: HashToken("test-api-key"), + APIKeyEncrypted: "encrypted-key-data", + Status: "active", + } + + id, err := db.CreatePeer(peer) 
+ if err != nil { + t.Fatalf("CreatePeer failed: %v", err) + } + if id <= 0 { + t.Error("CreatePeer should return positive ID") + } + + retrieved, err := db.GetPeer("hetzner-1") + if err != nil { + t.Fatalf("GetPeer failed: %v", err) + } + + if retrieved.Name != "hetzner-1" { + t.Errorf("Name = %s, want hetzner-1", retrieved.Name) + } + if retrieved.URL != "https://hetzner-1.example.com:8090" { + t.Errorf("URL = %s, want https://hetzner-1.example.com:8090", retrieved.URL) + } + if retrieved.Status != "active" { + t.Errorf("Status = %s, want active", retrieved.Status) + } +} + +func TestGetPeerNotFound(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + _, err := db.GetPeer("nonexistent") + if err == nil { + t.Error("GetPeer should fail for nonexistent peer") + } +} + +func TestCreatePeerDuplicateName(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + peer := &Peer{ + Name: "dupe", + URL: "https://a.example.com", + APIKeyHash: "hash1", + APIKeyEncrypted: "enc1", + Status: "active", + } + _, err := db.CreatePeer(peer) + if err != nil { + t.Fatalf("First CreatePeer failed: %v", err) + } + + peer2 := &Peer{ + Name: "dupe", + URL: "https://b.example.com", + APIKeyHash: "hash2", + APIKeyEncrypted: "enc2", + Status: "active", + } + _, err = db.CreatePeer(peer2) + if err == nil { + t.Error("CreatePeer should fail for duplicate name") + } +} + +func TestListPeers(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + peers, err := db.ListPeers() + if err != nil { + t.Fatalf("ListPeers failed: %v", err) + } + if len(peers) != 0 { + t.Errorf("Expected 0 peers, got %d", len(peers)) + } + + _, _ = db.CreatePeer(&Peer{ + Name: "peer-a", URL: "https://a.example.com", + APIKeyHash: "h1", APIKeyEncrypted: "e1", Status: "active", + }) + _, _ = db.CreatePeer(&Peer{ + Name: "peer-b", URL: "https://b.example.com", + APIKeyHash: "h2", APIKeyEncrypted: "e2", Status: "active", + }) + + peers, err = db.ListPeers() + if err != nil { + 
t.Fatalf("ListPeers failed: %v", err) + } + if len(peers) != 2 { + t.Errorf("Expected 2 peers, got %d", len(peers)) + } +} + +func TestDeletePeer(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + _, _ = db.CreatePeer(&Peer{ + Name: "to-delete", URL: "https://delete.example.com", + APIKeyHash: "h", APIKeyEncrypted: "e", Status: "active", + }) + + err := db.DeletePeer("to-delete") + if err != nil { + t.Fatalf("DeletePeer failed: %v", err) + } + + _, err = db.GetPeer("to-delete") + if err == nil { + t.Error("GetPeer should fail after deletion") + } +} + +func TestUpdateLastSeen(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + _, _ = db.CreatePeer(&Peer{ + Name: "seen-peer", URL: "https://seen.example.com", + APIKeyHash: "h", APIKeyEncrypted: "e", Status: "active", + }) + + before, _ := db.GetPeer("seen-peer") + if !before.LastSeenAt.IsZero() { + t.Error("LastSeenAt should be zero initially") + } + + err := db.UpdateLastSeen("seen-peer") + if err != nil { + t.Fatalf("UpdateLastSeen failed: %v", err) + } + + after, _ := db.GetPeer("seen-peer") + if after.LastSeenAt.IsZero() { + t.Error("LastSeenAt should be set after update") + } +} + +func TestCreateAndGetInvite(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + token := "test-invite-token" + tokenHash := HashToken(token) + + invite := &Invite{ + TokenHash: tokenHash, + Status: "pending", + CreatedBy: 1, + ExpiresAt: time.Now().Add(1 * time.Hour), + } + + id, err := db.CreateInvite(invite) + if err != nil { + t.Fatalf("CreateInvite failed: %v", err) + } + if id <= 0 { + t.Error("CreateInvite should return positive ID") + } + + retrieved, err := db.GetInviteByHash(tokenHash) + if err != nil { + t.Fatalf("GetInviteByHash failed: %v", err) + } + + if retrieved.Status != "pending" { + t.Errorf("Status = %s, want pending", retrieved.Status) + } + if retrieved.CreatedBy != 1 { + t.Errorf("CreatedBy = %d, want 1", retrieved.CreatedBy) + } +} + +func 
TestGetInviteNotFound(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + _, err := db.GetInviteByHash("nonexistent-hash") + if err == nil { + t.Error("GetInviteByHash should fail for nonexistent invite") + } +} + +func TestConsumeInvite(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + token := "consume-token" + tokenHash := HashToken(token) + + _, _ = db.CreateInvite(&Invite{ + TokenHash: tokenHash, + Status: "pending", + CreatedBy: 1, + ExpiresAt: time.Now().Add(1 * time.Hour), + }) + + err := db.ConsumeInvite(tokenHash, "new-peer") + if err != nil { + t.Fatalf("ConsumeInvite failed: %v", err) + } + + invite, _ := db.GetInviteByHash(tokenHash) + if invite.Status != "accepted" { + t.Errorf("Status = %s, want accepted", invite.Status) + } + if invite.AcceptedPeer != "new-peer" { + t.Errorf("AcceptedPeer = %s, want new-peer", invite.AcceptedPeer) + } +} + +func TestConsumeInviteAlreadyAccepted(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + token := "already-used" + tokenHash := HashToken(token) + + _, _ = db.CreateInvite(&Invite{ + TokenHash: tokenHash, + Status: "pending", + CreatedBy: 1, + ExpiresAt: time.Now().Add(1 * time.Hour), + }) + + _ = db.ConsumeInvite(tokenHash, "first-peer") + + err := db.ConsumeInvite(tokenHash, "second-peer") + if err != sql.ErrNoRows { + t.Errorf("Expected sql.ErrNoRows, got %v", err) + } +} + +func TestConsumeInviteExpired(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + token := "expired-token" + tokenHash := HashToken(token) + + _, _ = db.CreateInvite(&Invite{ + TokenHash: tokenHash, + Status: "pending", + CreatedBy: 1, + ExpiresAt: time.Now().Add(-1 * time.Hour), + }) + + err := db.ConsumeInvite(tokenHash, "late-peer") + if err != sql.ErrNoRows { + t.Errorf("Expected sql.ErrNoRows for expired invite, got %v", err) + } +} + +func TestHashToken(t *testing.T) { + hash1 := HashToken("token-a") + hash2 := HashToken("token-b") + hash1again := 
HashToken("token-a") + + if hash1 == hash2 { + t.Error("Different tokens should produce different hashes") + } + if hash1 != hash1again { + t.Error("Same token should produce same hash") + } + if len(hash1) != 64 { + t.Errorf("Hash length = %d, want 64", len(hash1)) + } +} diff --git a/internal/cluster/manager.go b/internal/cluster/manager.go new file mode 100644 index 0000000..1d8cb6f --- /dev/null +++ b/internal/cluster/manager.go @@ -0,0 +1,304 @@ +package cluster + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + "io" + "log" + "sync" + "time" +) + +type PeerStatus struct { + Name string `json:"name"` + URL string `json:"url"` + Online bool `json:"online"` + LastSeen time.Time `json:"last_seen"` + Error string `json:"error,omitempty"` +} + +type Result struct { + Data []byte `json:"data,omitempty"` + Error string `json:"error,omitempty"` +} + +type Manager struct { + db *DB + clients map[string]*Client + status map[string]*PeerStatus + mu sync.RWMutex + serverName string + healthInterval time.Duration + requestTimeout time.Duration + encryptionKey []byte + cancel context.CancelFunc +} + +func NewManager(db *DB, serverName string, healthInterval, requestTimeout time.Duration, jwtSecret string) *Manager { + key := sha256.Sum256([]byte(jwtSecret)) + return &Manager{ + db: db, + clients: make(map[string]*Client), + status: make(map[string]*PeerStatus), + serverName: serverName, + healthInterval: healthInterval, + requestTimeout: requestTimeout, + encryptionKey: key[:], + } +} + +func (m *Manager) Start(ctx context.Context) error { + ctx, m.cancel = context.WithCancel(ctx) + + peers, err := m.db.ListPeers() + if err != nil { + return fmt.Errorf("failed to load peers: %w", err) + } + + m.mu.Lock() + for _, p := range peers { + if p.Status != "active" { + continue + } + apiKey, err := m.decrypt(p.APIKeyEncrypted) + if err != nil { + log.Printf("Warning: Failed to decrypt API key for peer %s: %v", 
p.Name, err)
			continue
		}
		// Rebuild the in-memory client and status entry for each peer
		// loaded from the database.
		m.clients[p.Name] = NewClient(p.URL, apiKey, m.requestTimeout)
		m.status[p.Name] = &PeerStatus{
			Name:     p.Name,
			URL:      p.URL,
			LastSeen: p.LastSeenAt,
		}
	}
	m.mu.Unlock()

	// Background health checking runs until ctx is cancelled.
	go m.healthLoop(ctx)
	return nil
}

// Stop cancels the context driving the health-check loop.
// NOTE(review): it does not wait for healthLoop to actually return.
func (m *Manager) Stop() {
	if m.cancel != nil {
		m.cancel()
	}
}

// healthLoop probes every known peer once immediately, then again on each
// tick of m.healthInterval, until ctx is cancelled.
func (m *Manager) healthLoop(ctx context.Context) {
	ticker := time.NewTicker(m.healthInterval)
	defer ticker.Stop()

	m.checkAllPeers(ctx)

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			m.checkAllPeers(ctx)
		}
	}
}

// checkAllPeers performs one health-check pass over all peers.
// It snapshots the client map under RLock so the (potentially slow) HTTP
// probes run without holding the lock, updates each PeerStatus under the
// write lock, and finally persists last-seen timestamps for the peers
// that answered. UpdateLastSeen failures are deliberately best-effort.
func (m *Manager) checkAllPeers(ctx context.Context) {
	m.mu.RLock()
	type peerEntry struct {
		name   string
		client *Client
	}
	peers := make([]peerEntry, 0, len(m.clients))
	for name, client := range m.clients {
		peers = append(peers, peerEntry{name, client})
	}
	m.mu.RUnlock()

	var seenNames []string

	for _, p := range peers {
		// Abort the pass early if the manager is shutting down.
		if ctx.Err() != nil {
			return
		}

		err := p.client.Health(ctx)

		m.mu.Lock()
		// The peer may have been removed since the snapshot was taken;
		// only update entries that still exist.
		st, exists := m.status[p.name]
		if exists {
			if err != nil {
				st.Online = false
				st.Error = err.Error()
			} else {
				st.Online = true
				st.Error = ""
				st.LastSeen = time.Now()
				seenNames = append(seenNames, p.name)
			}
		}
		m.mu.Unlock()
	}

	// Persist outside the mutex; errors here are intentionally ignored.
	for _, name := range seenNames {
		_ = m.db.UpdateLastSeen(name)
	}
}

// GetPeer returns the HTTP client for the named peer, or an error if the
// peer is unknown.
func (m *Manager) GetPeer(name string) (*Client, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	client, ok := m.clients[name]
	if !ok {
		return nil, fmt.Errorf("peer %q not found", name)
	}
	return client, nil
}

// ListPeers returns a snapshot copy of every peer's current status.
func (m *Manager) ListPeers() []PeerStatus {
	m.mu.RLock()
	defer m.mu.RUnlock()

	result := make([]PeerStatus, 0, len(m.status))
	for _, s := range m.status {
		result = append(result, *s)
	}
	return result
}

// AddPeer registers a new peer: the API key is stored encrypted (plus a
// token hash), the peer row is written to the database, and an in-memory
// client/status entry is created. The peer starts as offline until the
// next health-check pass marks it online.
func (m *Manager) AddPeer(name, url, apiKey string) error {
	encrypted, err := m.encrypt(apiKey)
	if err != nil {
		return fmt.Errorf("failed to encrypt API key: %w", err)
	}

	peer := &Peer{
		Name:            name,
		URL:             url,
		APIKeyHash:      HashToken(apiKey),
		APIKeyEncrypted: encrypted,
		Status:          "active",
	}

	if _, err := m.db.CreatePeer(peer); err != nil {
		return fmt.Errorf("failed to store peer: %w", err)
	}

	m.mu.Lock()
	m.clients[name] = NewClient(url, apiKey, m.requestTimeout)
	m.status[name] = &PeerStatus{
		Name:   name,
		URL:    url,
		Online: false,
	}
	m.mu.Unlock()

	return nil
}

// RemovePeer deletes the peer from the database and drops its in-memory
// client and status entries.
func (m *Manager) RemovePeer(name string) error {
	if err := m.db.DeletePeer(name); err != nil {
		return fmt.Errorf("failed to delete peer: %w", err)
	}

	m.mu.Lock()
	delete(m.clients, name)
	delete(m.status, name)
	m.mu.Unlock()

	return nil
}

// ForEachPeer runs fn concurrently against every peer and collects each
// peer's payload or error string into a map keyed by peer name. It blocks
// until every goroutine has finished; cancellation is fn's responsibility
// via the passed ctx.
func (m *Manager) ForEachPeer(ctx context.Context, fn func(ctx context.Context, name string, client *Client) ([]byte, error)) map[string]Result {
	// Snapshot the clients so fn runs without holding the manager lock.
	m.mu.RLock()
	peers := make(map[string]*Client, len(m.clients))
	for n, c := range m.clients {
		peers[n] = c
	}
	m.mu.RUnlock()

	results := make(map[string]Result, len(peers))
	var mu sync.Mutex
	var wg sync.WaitGroup

	for name, client := range peers {
		wg.Add(1)
		go func(n string, c *Client) {
			defer wg.Done()

			data, err := fn(ctx, n, c)

			// results is shared across goroutines; guard with mu.
			mu.Lock()
			if err != nil {
				results[n] = Result{Error: err.Error()}
			} else {
				results[n] = Result{Data: data}
			}
			mu.Unlock()
		}(name, client)
	}

	wg.Wait()
	return results
}

// ServerName returns this node's configured cluster name.
func (m *Manager) ServerName() string {
	return m.serverName
}

// DB exposes the cluster database handle.
func (m *Manager) DB() *DB {
	return m.db
}

// encrypt seals plaintext with AES-GCM under m.encryptionKey, prepends the
// random nonce to the ciphertext, and returns the result base64-encoded.
func (m *Manager) encrypt(plaintext string) (string, error) {
	block, err := aes.NewCipher(m.encryptionKey)
	if err != nil {
		return "", err
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}

	nonce := make([]byte, gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", err
	}

	// Seal appends the ciphertext to the nonce so both travel together.
	ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
	return base64.StdEncoding.EncodeToString(ciphertext), nil
}

// decrypt reverses encrypt: base64-decode, split off the nonce, AES-GCM open.
func (m *Manager) decrypt(encoded string)
(string, error) {
	ciphertext, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return "", err
	}

	block, err := aes.NewCipher(m.encryptionKey)
	if err != nil {
		return "", err
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}

	// The nonce was prepended by encrypt; refuse inputs too short to hold it.
	nonceSize := gcm.NonceSize()
	if len(ciphertext) < nonceSize {
		return "", fmt.Errorf("ciphertext too short")
	}

	nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
	plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		return "", err
	}

	return string(plaintext), nil
}
diff --git a/internal/cluster/manager_test.go b/internal/cluster/manager_test.go
new file mode 100644
index 0000000..d406560
--- /dev/null
+++ b/internal/cluster/manager_test.go
@@ -0,0 +1,295 @@
package cluster

import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"os"
	"sync/atomic"
	"testing"
	"time"
)

// setupTestManager builds a Manager backed by a throwaway DB in a temp
// directory and returns it with a cleanup func that stops the manager,
// closes the DB, and removes the directory.
// NOTE(review): argument order to NewManager appears to be (db, name,
// healthInterval, requestTimeout, secret) — confirm against manager.go.
func setupTestManager(t *testing.T) (*Manager, *DB, func()) {
	tmpDir, err := os.MkdirTemp("", "cluster_mgr_test")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}

	db, err := NewDB(tmpDir)
	if err != nil {
		os.RemoveAll(tmpDir)
		t.Fatalf("Failed to create DB: %v", err)
	}

	mgr := NewManager(db, "test-server", 1*time.Second, 5*time.Second, "test-jwt-secret")

	cleanup := func() {
		mgr.Stop()
		db.Close()
		os.RemoveAll(tmpDir)
	}

	return mgr, db, cleanup
}

// TestManagerServerName verifies ServerName echoes the configured name.
func TestManagerServerName(t *testing.T) {
	mgr, _, cleanup := setupTestManager(t)
	defer cleanup()

	if mgr.ServerName() != "test-server" {
		t.Errorf("ServerName = %s, want test-server", mgr.ServerName())
	}
}

// TestManagerAddAndGetPeer verifies a peer added via AddPeer is retrievable.
func TestManagerAddAndGetPeer(t *testing.T) {
	mgr, _, cleanup := setupTestManager(t)
	defer cleanup()

	err := mgr.AddPeer("peer-a", "https://a.example.com", "api-key-for-a")
	if err != nil {
		t.Fatalf("AddPeer failed: %v", err)
	}

	client, err := mgr.GetPeer("peer-a")
	if err != nil {
		t.Fatalf("GetPeer failed: %v", err)
	}
	if client == nil {
		t.Fatal("GetPeer returned nil client")
	}
}

// TestManagerGetPeerNotFound verifies GetPeer errors for unknown names.
func TestManagerGetPeerNotFound(t *testing.T) {
	mgr, _, cleanup := setupTestManager(t)
	defer cleanup()

	_, err := mgr.GetPeer("nonexistent")
	if err == nil {
		t.Error("GetPeer should fail for nonexistent peer")
	}
}

// TestManagerListPeers verifies ListPeers reflects additions.
func TestManagerListPeers(t *testing.T) {
	mgr, _, cleanup := setupTestManager(t)
	defer cleanup()

	peers := mgr.ListPeers()
	if len(peers) != 0 {
		t.Errorf("Expected 0 peers, got %d", len(peers))
	}

	_ = mgr.AddPeer("peer-a", "https://a.example.com", "key-a")
	_ = mgr.AddPeer("peer-b", "https://b.example.com", "key-b")

	peers = mgr.ListPeers()
	if len(peers) != 2 {
		t.Errorf("Expected 2 peers, got %d", len(peers))
	}
}

// TestManagerRemovePeer verifies RemovePeer drops both the client and the
// status entry.
func TestManagerRemovePeer(t *testing.T) {
	mgr, _, cleanup := setupTestManager(t)
	defer cleanup()

	_ = mgr.AddPeer("remove-me", "https://rm.example.com", "key")

	err := mgr.RemovePeer("remove-me")
	if err != nil {
		t.Fatalf("RemovePeer failed: %v", err)
	}

	_, err = mgr.GetPeer("remove-me")
	if err == nil {
		t.Error("GetPeer should fail after RemovePeer")
	}

	peers := mgr.ListPeers()
	if len(peers) != 0 {
		t.Errorf("Expected 0 peers after removal, got %d", len(peers))
	}
}

// TestManagerEncryptDecryptRoundtrip verifies encrypt/decrypt are inverses
// and that the ciphertext differs from the plaintext.
func TestManagerEncryptDecryptRoundtrip(t *testing.T) {
	mgr, _, cleanup := setupTestManager(t)
	defer cleanup()

	original := "my-secret-api-key-12345"
	encrypted, err := mgr.encrypt(original)
	if err != nil {
		t.Fatalf("encrypt failed: %v", err)
	}

	if encrypted == original {
		t.Error("Encrypted should differ from original")
	}

	decrypted, err := mgr.decrypt(encrypted)
	if err != nil {
		t.Fatalf("decrypt failed: %v", err)
	}

	if decrypted != original {
		t.Errorf("Decrypted = %s, want %s", decrypted, original)
	}
}

// TestManagerStartLoadsExistingPeers verifies that a second Manager built
// on the same DB rehydrates previously added peers on Start.
func TestManagerStartLoadsExistingPeers(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "cluster_start_test")
	if err != nil {
		t.Fatalf("Failed to create temp dir:
%v", err)
	}
	defer os.RemoveAll(tmpDir)

	db, err := NewDB(tmpDir)
	if err != nil {
		t.Fatalf("Failed to create DB: %v", err)
	}

	// First manager writes a peer, then is stopped without Start having run.
	mgr1 := NewManager(db, "server-1", 30*time.Second, 5*time.Second, "test-secret")
	_ = mgr1.AddPeer("pre-existing", "https://pre.example.com", "pre-key")
	mgr1.Stop()

	// Second manager on the same DB should reload the stored peer on Start.
	mgr2 := NewManager(db, "server-1", 30*time.Second, 5*time.Second, "test-secret")
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	err = mgr2.Start(ctx)
	if err != nil {
		t.Fatalf("Start failed: %v", err)
	}
	defer mgr2.Stop()

	peers := mgr2.ListPeers()
	if len(peers) != 1 {
		t.Fatalf("Expected 1 peer after restart, got %d", len(peers))
	}
	if peers[0].Name != "pre-existing" {
		t.Errorf("Peer name = %s, want pre-existing", peers[0].Name)
	}

	_, err = mgr2.GetPeer("pre-existing")
	if err != nil {
		t.Error("Should be able to get pre-existing peer client after restart")
	}

	db.Close()
}

// TestManagerForEachPeer fans a request out to two httptest servers and
// checks a result is collected for each.
func TestManagerForEachPeer(t *testing.T) {
	server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]string{"server": "one"})
	}))
	defer server1.Close()

	server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]string{"server": "two"})
	}))
	defer server2.Close()

	mgr, _, cleanup := setupTestManager(t)
	defer cleanup()

	_ = mgr.AddPeer("server-one", server1.URL, "key1")
	_ = mgr.AddPeer("server-two", server2.URL, "key2")

	results := mgr.ForEachPeer(context.Background(), func(ctx context.Context, name string, client *Client) ([]byte, error) {
		data, _, err := client.Get(ctx, "/api/test")
		return data, err
	})

	if len(results) != 2 {
		t.Fatalf("Expected 2 results, got %d", len(results))
	}

	for name, result := range results {
		if result.Error != "" {
			t.Errorf("Peer %s returned error: %s", name, result.Error)
		}
		if result.Data == nil {
			t.Errorf("Peer %s returned nil data", name)
		}
	}
}

// TestManagerForEachPeerWithFailure verifies per-peer errors are isolated:
// one unreachable peer must not affect the healthy one's result.
func TestManagerForEachPeerWithFailure(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]string{"ok": "true"})
	}))
	defer server.Close()

	mgr, _, cleanup := setupTestManager(t)
	defer cleanup()

	_ = mgr.AddPeer("good", server.URL, "key1")
	// Port 1 is essentially guaranteed to refuse connections.
	_ = mgr.AddPeer("bad", "http://127.0.0.1:1", "key2")

	results := mgr.ForEachPeer(context.Background(), func(ctx context.Context, name string, client *Client) ([]byte, error) {
		data, _, err := client.Get(ctx, "/api/test")
		return data, err
	})

	if results["good"].Error != "" {
		t.Errorf("Good peer should not have error: %s", results["good"].Error)
	}
	if results["bad"].Error == "" {
		t.Error("Bad peer should have error")
	}
}

// TestManagerHealthChecks starts the health loop with a 100ms interval and
// counts /api/health hits on a stub server; 350ms should allow the initial
// probe plus at least one tick.
func TestManagerHealthChecks(t *testing.T) {
	var healthCalls int64

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/api/health" {
			atomic.AddInt64(&healthCalls, 1)
			w.WriteHeader(http.StatusOK)
			return
		}
		w.WriteHeader(http.StatusNotFound)
	}))
	defer server.Close()

	tmpDir, err := os.MkdirTemp("", "cluster_health_test")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(tmpDir)

	db, err := NewDB(tmpDir)
	if err != nil {
		t.Fatalf("Failed to create DB: %v", err)
	}
	defer db.Close()

	mgr := NewManager(db, "health-test", 100*time.Millisecond, 5*time.Second, "test-secret")
	_ = mgr.AddPeer("healthy-peer", server.URL, "key")

	ctx, cancel := context.WithCancel(context.Background())
	_ = mgr.Start(ctx)

	time.Sleep(350 * time.Millisecond)

	cancel()
	mgr.Stop()

	calls := atomic.LoadInt64(&healthCalls)
	if calls < 2 {
		t.Errorf("Expected at least 2 health checks, got %d", calls)
	}

	peers := mgr.ListPeers()
	found := false
	for _, p := range peers {
		if p.Name == "healthy-peer"
{
			found = true
			if !p.Online {
				t.Error("Healthy peer should be online")
			}
		}
	}
	if !found {
		t.Error("healthy-peer not found in peer list")
	}
}
diff --git a/internal/docker/discovery_test.go b/internal/docker/discovery_test.go
index f15756d..369e5fb 100644
--- a/internal/docker/discovery_test.go
+++ b/internal/docker/discovery_test.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"os"
 	"path/filepath"
+	"sort"
 	"testing"
 )
@@ -421,7 +422,11 @@ func TestExtractBindMounts(t *testing.T) {
 	if len(result) != len(tt.expected) {
 		t.Fatalf("ExtractBindMounts returned %d paths, want %d: got %v", len(result), len(tt.expected), result)
 	}
-	for i, path := range tt.expected {
+	sort.Strings(result)
+	sorted := make([]string, len(tt.expected))
+	copy(sorted, tt.expected)
+	sort.Strings(sorted)
+	for i, path := range sorted {
 		if result[i] != path {
 			t.Errorf("ExtractBindMounts[%d] = %q, want %q", i, result[i], path)
 		}
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 4307447..97dcaf6 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -9,6 +9,14 @@ import (
 	"gopkg.in/yaml.v3"
 )
+type ClusterConfig struct {
+	Enabled        bool   `yaml:"enabled"`
+	ServerName     string `yaml:"server_name"`
+	AdvertiseURL   string `yaml:"advertise_url"`
+	HealthInterval string `yaml:"health_interval"`
+	RequestTimeout string `yaml:"request_timeout"`
+}
+
 type Config struct {
 	DeploymentsPath string `yaml:"deployments_path"`
 	DockerSocket    string `yaml:"docker_socket"`
@@ -22,6 +30,7 @@ type Config struct {
 	Infrastructure InfrastructureConfig `yaml:"infrastructure"`
 	Security       SecurityConfig       `yaml:"security"`
 	Audit          AuditConfig          `yaml:"audit"`
+	Cluster        ClusterConfig        `yaml:"cluster"`
 }
 type DomainConfig struct {
@@ -335,6 +344,20 @@ func setDefaults(cfg *Config) {
 	if cfg.Audit.SensitiveFields == nil {
 		cfg.Audit.SensitiveFields = []string{"password", "token", "secret", "api_key", "authorization"}
 	}
+	// Cluster defaults
+	if cfg.Cluster.ServerName == "" {
+		hostname, err := os.Hostname()
+		if err != nil {
+			hostname = "flatrun-agent"
+		}
+		cfg.Cluster.ServerName = hostname
+	}
+	if cfg.Cluster.HealthInterval == "" {
+		cfg.Cluster.HealthInterval = "30s"
+	}
+	if cfg.Cluster.RequestTimeout == "" {
+		cfg.Cluster.RequestTimeout = "10s"
+	}
 }
 
 func Save(cfg *Config, path string) error {