Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions drivers/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package drivers

import (
"context"
"fmt"
"log/slog"
"net"
"net/url"
"strings"

awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
)

type DatabaseConfiguration struct {
Connection string `json:"connection"`
Address string `json:"addr"`
Database string `json:"database"`
Username string `json:"username"`
Secret string `json:"secret"`
MaxConcurrentSessions int `json:"max_concurrent_sessions"`
EnableRDSIAMAuth bool `json:"enable_rds_iam_auth"`
}

func (s DatabaseConfiguration) defaultPostgreSQLConnectionString() string {
if s.Connection != "" {
return s.Connection
}

return fmt.Sprintf("postgresql://%s:%s@%s/%s", s.Username, url.QueryEscape(s.Secret), s.Address, s.Database)
}

func (s DatabaseConfiguration) RDSIAMAuthConnectionString() string {
slog.Info("Loading RDS Configuration With IAM Auth")

if cfg, err := awsConfig.LoadDefaultConfig(context.TODO()); err != nil {
slog.Error("AWS Config Loading Error", slog.String("err", err.Error()))
} else {
host := s.Address

if hostCName, err := net.LookupCNAME(s.Address); err != nil {
slog.Warn("Error looking up CNAME for DB host. Using original address.", slog.String("err", err.Error()))
} else {
host = hostCName
}

endpoint := strings.TrimSuffix(host, ".") + ":5432"

slog.Info("Requesting RDS IAM Auth Token")

if authenticationToken, err := auth.BuildAuthToken(context.TODO(), endpoint, cfg.Region, s.Username, cfg.Credentials); err != nil {
slog.Error("RDS IAM Auth Token Request Error", slog.String("err", err.Error()))
} else {
slog.Info("RDS IAM Auth Token Created")
return fmt.Sprintf("postgresql://%s:%s@%s/%s", s.Username, url.QueryEscape(authenticationToken), endpoint, s.Database)
}
}

return s.defaultPostgreSQLConnectionString()
}

func (s DatabaseConfiguration) PostgreSQLConnectionString() string {
if s.EnableRDSIAMAuth {
return s.RDSIAMAuthConnectionString()
}

return s.defaultPostgreSQLConnectionString()
}

func (s DatabaseConfiguration) Neo4jConnectionString() string {
if s.Connection == "" {
return fmt.Sprintf("neo4j://%s:%s@%s/%s", s.Username, s.Secret, s.Address, s.Database)
}

return s.Connection
}
23 changes: 18 additions & 5 deletions drivers/pg/pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
"github.com/specterops/dawgs"
"github.com/specterops/dawgs/cypher/models/pgsql"
"github.com/specterops/dawgs/drivers"
"github.com/specterops/dawgs/graph"
)

Expand Down Expand Up @@ -50,15 +51,12 @@ func afterPooledConnectionRelease(conn *pgx.Conn) bool {
return true
}

func NewPool(connectionString string) (*pgxpool.Pool, error) {
if connectionString == "" {
return nil, fmt.Errorf("graph connection requires a connection url to be set")
}
func NewPool(cfg drivers.DatabaseConfiguration) (*pgxpool.Pool, error) {

poolCtx, done := context.WithTimeout(context.Background(), poolInitConnectionTimeout)
defer done()

poolCfg, err := pgxpool.ParseConfig(connectionString)
poolCfg, err := pgxpool.ParseConfig(cfg.PostgreSQLConnectionString())
if err != nil {
return nil, err
}
Expand All @@ -73,6 +71,21 @@ func NewPool(connectionString string) (*pgxpool.Pool, error) {
poolCfg.AfterConnect = afterPooledConnectionEstablished
poolCfg.AfterRelease = afterPooledConnectionRelease

if cfg.EnableRDSIAMAuth {
// Only enable the BeforeConnect handler if RDS IAM Auth is enabled
poolCfg.BeforeConnect = func(ctx context.Context, connCfg *pgx.ConnConfig) error {
slog.Debug("New Connection RDS IAM Auth")

if newPoolCfg, err := pgxpool.ParseConfig(cfg.PostgreSQLConnectionString()); err != nil {
return err
} else {
connCfg.Password = newPoolCfg.ConnConfig.Password
}

return nil
}
}

pool, err := pgxpool.NewWithConfig(poolCtx, poolCfg)
if err != nil {
return nil, err
Expand Down
14 changes: 14 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ require (
cuelang.org/go v0.15.3
github.com/RoaringBitmap/roaring/v2 v2.14.4
github.com/antlr4-go/antlr/v4 v4.13.1
github.com/aws/aws-sdk-go-v2/config v1.31.13
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.10
github.com/axiomhq/hyperloglog v0.2.6
github.com/bits-and-blooms/bitset v1.24.4
github.com/cespare/xxhash/v2 v2.3.0
Expand All @@ -17,6 +19,18 @@ require (
)

require (
github.com/aws/aws-sdk-go-v2 v1.39.3 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.18.17 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 // indirect
github.com/aws/smithy-go v1.23.1 // indirect
github.com/cockroachdb/apd/v3 v3.2.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-metro v0.0.0-20250106013310-edb8663e5e33 // indirect
Expand Down
28 changes: 28 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,34 @@ github.com/RoaringBitmap/roaring/v2 v2.14.4 h1:4aKySrrg9G/5oRtJ3TrZLObVqxgQ9f1zn
github.com/RoaringBitmap/roaring/v2 v2.14.4/go.mod h1:oMvV6omPWr+2ifRdeZvVJyaz+aoEUopyv5iH0u/+wbY=
github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ=
github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw=
github.com/aws/aws-sdk-go-v2 v1.39.3 h1:h7xSsanJ4EQJXG5iuW4UqgP7qBopLpj84mpkNx3wPjM=
github.com/aws/aws-sdk-go-v2 v1.39.3/go.mod h1:yWSxrnioGUZ4WVv9TgMrNUeLV3PFESn/v+6T/Su8gnM=
github.com/aws/aws-sdk-go-v2/config v1.31.13 h1:wcqQB3B0PgRPUF5ZE/QL1JVOyB0mbPevHFoAMpemR9k=
github.com/aws/aws-sdk-go-v2/config v1.31.13/go.mod h1:ySB5D5ybwqGbT6c3GszZ+u+3KvrlYCUQNo62+hkKOFk=
github.com/aws/aws-sdk-go-v2/credentials v1.18.17 h1:skpEwzN/+H8cdrrtT8y+rvWJGiWWv0DeNAe+4VTf+Vs=
github.com/aws/aws-sdk-go-v2/credentials v1.18.17/go.mod h1:Ed+nXsaYa5uBINovJhcAWkALvXw2ZLk36opcuiSZfJM=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 h1:UuGVOX48oP4vgQ36oiKmW9RuSeT8jlgQgBFQD+HUiHY=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10/go.mod h1:vM/Ini41PzvudT4YkQyE/+WiQJiQ6jzeDyU8pQKwCac=
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.10 h1:xfgjONWMae6+y//dlhVukwt9N+I++FPuiwcQt7DI7Qg=
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.10/go.mod h1:FO6aarJTHA2N3S8F2A4wKfnX9Jr6MPerJFaqoLgTctU=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 h1:mj/bdWleWEh81DtpdHKkw41IrS+r3uw1J/VQtbwYYp8=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10/go.mod h1:7+oEMxAZWP8gZCyjcm9VicI0M61Sx4DJtcGfKYv2yKQ=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 h1:wh+/mn57yhUrFtLIxyFPh2RgxgQz/u+Yrf7hiHGHqKY=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10/go.mod h1:7zirD+ryp5gitJJ2m1BBux56ai8RIRDykXZrJSp540w=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 h1:xtuxji5CS0JknaXoACOunXOYOQzgfTvGAc9s2QdCJA4=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2/go.mod h1:zxwi0DIR0rcRcgdbl7E2MSOvxDyyXGBlScvBkARFaLQ=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 h1:DRND0dkCKtJzCj4Xl4OpVbXZgfttY5q712H9Zj7qc/0=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10/go.mod h1:tGGNmJKOTernmR2+VJ0fCzQRurcPZj9ut60Zu5Fi6us=
github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 h1:fspVFg6qMx0svs40YgRmE7LZXh9VRZvTT35PfdQR6FM=
github.com/aws/aws-sdk-go-v2/service/sso v1.29.7/go.mod h1:BQTKL3uMECaLaUV3Zc2L4Qybv8C6BIXjuu1dOPyxTQs=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 h1:scVnW+NLXasGOhy7HhkdT9AGb6kjgW7fJ5xYkUaqHs0=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2/go.mod h1:FRNCY3zTEWZXBKm2h5UBUPvCVDOecTad9KhynDyGBc0=
github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 h1:VEO5dqFkMsl8QZ2yHsFDJAIZLAkEbaYDB+xdKi0Feic=
github.com/aws/aws-sdk-go-v2/service/sts v1.38.7/go.mod h1:L1xxV3zAdB+qVrVW/pBIrIAnHFWHo6FBbFe4xOGsG/o=
github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M=
github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/axiomhq/hyperloglog v0.2.6 h1:sRhvvF3RIXWQgAXaTphLp4yJiX4S0IN3MWTaAgZoRJw=
github.com/axiomhq/hyperloglog v0.2.6/go.mod h1:YjX/dQqCR/7QYX0g8mu8UZAjpIenz1FKM71UEsjFoTo=
github.com/bits-and-blooms/bitset v1.24.4 h1:95H15Og1clikBrKr/DuzMXkQzECs1M6hhoGXLwLQOZE=
Expand Down