diff --git a/drivers/config.go b/drivers/config.go new file mode 100644 index 0000000..f8bbdce --- /dev/null +++ b/drivers/config.go @@ -0,0 +1,76 @@ +package drivers + +import ( + "context" + "fmt" + "log/slog" + "net" + "net/url" + "strings" + + awsConfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/rds/auth" +) + +type DatabaseConfiguration struct { + Connection string `json:"connection"` + Address string `json:"addr"` + Database string `json:"database"` + Username string `json:"username"` + Secret string `json:"secret"` + MaxConcurrentSessions int `json:"max_concurrent_sessions"` + EnableRDSIAMAuth bool `json:"enable_rds_iam_auth"` +} + +func (s DatabaseConfiguration) defaultPostgreSQLConnectionString() string { + if s.Connection != "" { + return s.Connection + } + + return fmt.Sprintf("postgresql://%s:%s@%s/%s", s.Username, url.QueryEscape(s.Secret), s.Address, s.Database) +} + +func (s DatabaseConfiguration) RDSIAMAuthConnectionString() string { + slog.Info("Loading RDS Configuration With IAM Auth") + + if cfg, err := awsConfig.LoadDefaultConfig(context.TODO()); err != nil { + slog.Error("AWS Config Loading Error", slog.String("err", err.Error())) + } else { + host := s.Address + + if hostCName, err := net.LookupCNAME(s.Address); err != nil { + slog.Warn("Error looking up CNAME for DB host. Using original address.", slog.String("err", err.Error())) + } else { + host = hostCName + } + + endpoint := strings.TrimSuffix(host, ".") + ":5432" + + slog.Info("Requesting RDS IAM Auth Token") + + if authenticationToken, err := auth.BuildAuthToken(context.TODO(), endpoint, cfg.Region, s.Username, cfg.Credentials); err != nil { + slog.Error("RDS IAM Auth Token Request Error", slog.String("err", err.Error())) + } else { + slog.Info("RDS IAM Auth Token Created") + return fmt.Sprintf("postgresql://%s:%s@%s/%s", s.Username, url.QueryEscape(authenticationToken), endpoint, s.Database) + } + } + + return s.defaultPostgreSQLConnectionString() +} + +func (s DatabaseConfiguration) PostgreSQLConnectionString() string { + if s.EnableRDSIAMAuth { + return s.RDSIAMAuthConnectionString() + } + + return s.defaultPostgreSQLConnectionString() +} + +func (s DatabaseConfiguration) Neo4jConnectionString() string { + if s.Connection == "" { + return fmt.Sprintf("neo4j://%s:%s@%s/%s", s.Username, s.Secret, s.Address, s.Database) + } + + return s.Connection +} diff --git a/drivers/pg/pg.go b/drivers/pg/pg.go index 1e2d0a2..047c1de 100644 --- a/drivers/pg/pg.go +++ b/drivers/pg/pg.go @@ -10,6 +10,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/specterops/dawgs" "github.com/specterops/dawgs/cypher/models/pgsql" + "github.com/specterops/dawgs/drivers" "github.com/specterops/dawgs/graph" ) @@ -50,15 +51,12 @@ func afterPooledConnectionRelease(conn *pgx.Conn) bool { return true } -func NewPool(connectionString string) (*pgxpool.Pool, error) { - if connectionString == "" { - return nil, fmt.Errorf("graph connection requires a connection url to be set") - } +func NewPool(cfg drivers.DatabaseConfiguration) (*pgxpool.Pool, error) { poolCtx, done := context.WithTimeout(context.Background(), poolInitConnectionTimeout) defer done() - poolCfg, err := pgxpool.ParseConfig(connectionString) + poolCfg, err := pgxpool.ParseConfig(cfg.PostgreSQLConnectionString()) if err != nil { return nil, err } @@ -73,6 +71,21 @@ func NewPool(connectionString string) (*pgxpool.Pool, error) { poolCfg.AfterConnect = afterPooledConnectionEstablished poolCfg.AfterRelease = afterPooledConnectionRelease + if cfg.EnableRDSIAMAuth { + // Only enable the BeforeConnect handler if RDS IAM Auth is enabled + poolCfg.BeforeConnect = func(ctx context.Context, connCfg *pgx.ConnConfig) error { + slog.Debug("New Connection RDS IAM Auth") + + if newPoolCfg, err := pgxpool.ParseConfig(cfg.PostgreSQLConnectionString()); err != nil { + return err + } else { + connCfg.Password = newPoolCfg.ConnConfig.Password + } + + return nil + } + } + pool, err := pgxpool.NewWithConfig(poolCtx, poolCfg) if err != nil { return nil, err diff --git a/go.mod b/go.mod index e294cab..f3a1173 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,8 @@ require ( cuelang.org/go v0.15.3 github.com/RoaringBitmap/roaring/v2 v2.14.4 github.com/antlr4-go/antlr/v4 v4.13.1 + github.com/aws/aws-sdk-go-v2/config v1.31.13 + github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.10 github.com/axiomhq/hyperloglog v0.2.6 github.com/bits-and-blooms/bitset v1.24.4 github.com/cespare/xxhash/v2 v2.3.0 @@ -17,6 +19,18 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2 v1.39.3 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.17 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 // indirect + github.com/aws/smithy-go v1.23.1 // indirect github.com/cockroachdb/apd/v3 v3.2.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-metro v0.0.0-20250106013310-edb8663e5e33 // indirect diff --git a/go.sum b/go.sum index e9c170f..a01871e 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,34 @@ github.com/RoaringBitmap/roaring/v2 v2.14.4 h1:4aKySrrg9G/5oRtJ3TrZLObVqxgQ9f1zn github.com/RoaringBitmap/roaring/v2 v2.14.4/go.mod h1:oMvV6omPWr+2ifRdeZvVJyaz+aoEUopyv5iH0u/+wbY= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= +github.com/aws/aws-sdk-go-v2 v1.39.3 h1:h7xSsanJ4EQJXG5iuW4UqgP7qBopLpj84mpkNx3wPjM= +github.com/aws/aws-sdk-go-v2 v1.39.3/go.mod h1:yWSxrnioGUZ4WVv9TgMrNUeLV3PFESn/v+6T/Su8gnM= +github.com/aws/aws-sdk-go-v2/config v1.31.13 h1:wcqQB3B0PgRPUF5ZE/QL1JVOyB0mbPevHFoAMpemR9k= +github.com/aws/aws-sdk-go-v2/config v1.31.13/go.mod h1:ySB5D5ybwqGbT6c3GszZ+u+3KvrlYCUQNo62+hkKOFk= +github.com/aws/aws-sdk-go-v2/credentials v1.18.17 h1:skpEwzN/+H8cdrrtT8y+rvWJGiWWv0DeNAe+4VTf+Vs= +github.com/aws/aws-sdk-go-v2/credentials v1.18.17/go.mod h1:Ed+nXsaYa5uBINovJhcAWkALvXw2ZLk36opcuiSZfJM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 h1:UuGVOX48oP4vgQ36oiKmW9RuSeT8jlgQgBFQD+HUiHY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10/go.mod h1:vM/Ini41PzvudT4YkQyE/+WiQJiQ6jzeDyU8pQKwCac= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.10 h1:xfgjONWMae6+y//dlhVukwt9N+I++FPuiwcQt7DI7Qg= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.10/go.mod h1:FO6aarJTHA2N3S8F2A4wKfnX9Jr6MPerJFaqoLgTctU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 h1:mj/bdWleWEh81DtpdHKkw41IrS+r3uw1J/VQtbwYYp8= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10/go.mod h1:7+oEMxAZWP8gZCyjcm9VicI0M61Sx4DJtcGfKYv2yKQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 h1:wh+/mn57yhUrFtLIxyFPh2RgxgQz/u+Yrf7hiHGHqKY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10/go.mod h1:7zirD+ryp5gitJJ2m1BBux56ai8RIRDykXZrJSp540w= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 h1:xtuxji5CS0JknaXoACOunXOYOQzgfTvGAc9s2QdCJA4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2/go.mod h1:zxwi0DIR0rcRcgdbl7E2MSOvxDyyXGBlScvBkARFaLQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 h1:DRND0dkCKtJzCj4Xl4OpVbXZgfttY5q712H9Zj7qc/0= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10/go.mod h1:tGGNmJKOTernmR2+VJ0fCzQRurcPZj9ut60Zu5Fi6us= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 h1:fspVFg6qMx0svs40YgRmE7LZXh9VRZvTT35PfdQR6FM= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.7/go.mod h1:BQTKL3uMECaLaUV3Zc2L4Qybv8C6BIXjuu1dOPyxTQs= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 h1:scVnW+NLXasGOhy7HhkdT9AGb6kjgW7fJ5xYkUaqHs0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2/go.mod h1:FRNCY3zTEWZXBKm2h5UBUPvCVDOecTad9KhynDyGBc0= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 h1:VEO5dqFkMsl8QZ2yHsFDJAIZLAkEbaYDB+xdKi0Feic= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.7/go.mod h1:L1xxV3zAdB+qVrVW/pBIrIAnHFWHo6FBbFe4xOGsG/o= +github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M= +github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/axiomhq/hyperloglog v0.2.6 h1:sRhvvF3RIXWQgAXaTphLp4yJiX4S0IN3MWTaAgZoRJw= github.com/axiomhq/hyperloglog v0.2.6/go.mod h1:YjX/dQqCR/7QYX0g8mu8UZAjpIenz1FKM71UEsjFoTo= github.com/bits-and-blooms/bitset v1.24.4 h1:95H15Og1clikBrKr/DuzMXkQzECs1M6hhoGXLwLQOZE=