diff --git a/.gitignore b/.gitignore index 6bf2903a..4aa2917e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +go.mod +go.sum tags cmd/routing-api/routing-api /routing-api diff --git a/cmd/routing-api/main_test.go b/cmd/routing-api/main_test.go index 579a27b5..c1ecc415 100644 --- a/cmd/routing-api/main_test.go +++ b/cmd/routing-api/main_test.go @@ -6,11 +6,13 @@ import ( "os" "time" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + routingAPI "code.cloudfoundry.org/routing-api" "code.cloudfoundry.org/routing-api/cmd/routing-api/testrunner" "code.cloudfoundry.org/routing-api/db" "code.cloudfoundry.org/routing-api/models" - "github.com/jinzhu/gorm" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" . "github.com/onsi/gomega/gbytes" @@ -18,6 +20,7 @@ import ( "github.com/onsi/gomega/ghttp" "github.com/tedsuo/ifrit" ginkgomon "github.com/tedsuo/ifrit/ginkgomon_v2" + "gorm.io/gorm" ) const ( @@ -161,7 +164,7 @@ var _ = Describe("Main", func() { routingAPIConfig := testrunner.GetRoutingAPIConfig(defaultConfig) connectionString, err := db.ConnectionString(&routingAPIConfig.SqlDB) Expect(err).NotTo(HaveOccurred()) - gormDB, err := gorm.Open(routingAPIConfig.SqlDB.Type, connectionString) + gormDB, err := gorm.Open(getGormDialect(routingAPIConfig.SqlDB.Type, connectionString), &gorm.Config{}) Expect(err).NotTo(HaveOccurred()) getRoutes := func() string { @@ -241,7 +244,7 @@ var _ = Describe("Main", func() { } connectionString, err := db.ConnectionString(&routingAPIConfig.SqlDB) Expect(err).NotTo(HaveOccurred()) - gormDB, err = gorm.Open(routingAPIConfig.SqlDB.Type, connectionString) + gormDB, err = gorm.Open(getGormDialect(routingAPIConfig.SqlDB.Type, connectionString), &gorm.Config{}) Expect(err).NotTo(HaveOccurred()) }) AfterEach(func() { @@ -260,3 +263,16 @@ var _ = Describe("Main", func() { }) }) }) + +func getGormDialect(databaseType string, connectionString string) gorm.Dialector { + var dialect gorm.Dialector + + switch databaseType { + case "postgres": + dialect = postgres.Open(connectionString) + case "mysql": + dialect = mysql.Open(connectionString) + } + + return dialect +} diff --git a/cmd/routing-api/routing_api_suite_test.go b/cmd/routing-api/routing_api_suite_test.go index 03c371a6..a2090bce 100644 --- a/cmd/routing-api/routing_api_suite_test.go +++ b/cmd/routing-api/routing_api_suite_test.go @@ -60,7 +60,8 @@ var ( func TestRoutingAPI(test *testing.T) { RegisterFailHandler(Fail) - RunSpecs(test, "Routing API Test Suite") + suiteConfig, reporterConfig := GinkgoConfiguration() + RunSpecs(test, "Routing API Test Suite", suiteConfig, reporterConfig) } var _ = SynchronizedBeforeSuite( @@ -77,8 +78,9 @@ var _ = SynchronizedBeforeSuite( grpclog.SetLoggerV2(grpclog.NewLoggerV2(io.Discard, io.Discard, io.Discard)) path := string(binPaths) - routingAPIBinPath = strings.Split(path, ",")[0] - locketBinPath = strings.Split(path, ",")[1] + parts := strings.Split(path, ",") + routingAPIBinPath = parts[0] + locketBinPath = parts[1] SetDefaultEventuallyTimeout(15 * time.Second) @@ -113,13 +115,19 @@ var _ = SynchronizedBeforeSuite( ) var _ = SynchronizedAfterSuite(func() { - err := dbAllocator.Delete() - Expect(err).NotTo(HaveOccurred()) + if dbAllocator != nil { + err := dbAllocator.Delete() + Expect(err).NotTo(HaveOccurred()) + } - oAuthServer.Close() + if oAuthServer != nil { + oAuthServer.Close() + } - err = os.Remove(uaaCACertsPath) - Expect(err).NotTo(HaveOccurred()) + if uaaCACertsPath != "" { + err := os.Remove(uaaCACertsPath) + Expect(err).NotTo(HaveOccurred()) + } }, func() { gexec.CleanupBuildArtifacts() }) diff --git a/cmd/routing-api/testrunner/constants.go b/cmd/routing-api/testrunner/constants.go index cf80de42..baef7388 100644 --- a/cmd/routing-api/testrunner/constants.go +++ b/cmd/routing-api/testrunner/constants.go @@ -1,8 +1,9 @@ package testrunner import ( - "code.cloudfoundry.org/routing-api/config" "os" + + "code.cloudfoundry.org/routing-api/config" ) const ( diff --git a/cmd/routing-api/testrunner/db.go b/cmd/routing-api/testrunner/db.go index 8e07d7a1..58d5631d 100644 --- a/cmd/routing-api/testrunner/db.go +++ b/cmd/routing-api/testrunner/db.go @@ -10,9 +10,9 @@ import ( "code.cloudfoundry.org/routing-api/db" "code.cloudfoundry.org/routing-api/config" - _ "github.com/jinzhu/gorm/dialects/mysql" - _ "github.com/jinzhu/gorm/dialects/postgres" . "github.com/onsi/ginkgo/v2" + _ "gorm.io/driver/mysql" + _ "gorm.io/driver/postgres" ) type DbAllocator interface { @@ -63,7 +63,7 @@ func (a *postgresAllocator) Create() (*config.SqlDB, error) { if err != nil { return nil, err } - a.sqlDB, err = sql.Open("postgres", connStr) + a.sqlDB, err = sql.Open("pgx", connStr) if err != nil { return nil, err } diff --git a/db/client.go b/db/client.go index 6a9b68ee..1ec7a3b8 100644 --- a/db/client.go +++ b/db/client.go @@ -3,7 +3,7 @@ package db import ( "database/sql" - "github.com/jinzhu/gorm" + "gorm.io/gorm" ) //go:generate counterfeiter -o fakes/fake_client.go . Client @@ -13,7 +13,7 @@ type Client interface { Create(value interface{}) (int64, error) Delete(value interface{}, where ...interface{}) (int64, error) Save(value interface{}) (int64, error) - Update(attrs ...interface{}) (int64, error) + Update(column string, value interface{}) (int64, error) First(out interface{}, where ...interface{}) error Find(out interface{}, where ...interface{}) error AutoMigrate(values ...interface{}) error @@ -21,13 +21,15 @@ type Client interface { Rollback() error Commit() error HasTable(value interface{}) bool - AddUniqueIndex(indexName string, columns ...string) (Client, error) - RemoveIndex(indexName string) (Client, error) + AddUniqueIndex(indexName string, columns interface{}) error + RemoveIndex(indexName string, columns interface{}) error Model(value interface{}) Client Exec(query string, args ...interface{}) int64 + ExecWithError(query string, args ...interface{}) error Rows(tableName string) (*sql.Rows, error) DropColumn(column string) error - Dialect() gorm.Dialect + Dialect() gorm.Dialector + Migrator() gorm.Migrator } type gormClient struct { @@ -38,25 +40,25 @@ func NewGormClient(db *gorm.DB) Client { return &gormClient{db: db} } func (c *gormClient) DropColumn(name string) error { - return c.db.DropColumn(name).Error + return c.db.Migrator().DropColumn(c.db.Statement.Table, name) } func (c *gormClient) Close() error { - return c.db.Close() + sqlDB, err := c.db.DB() + if err != nil { + return err + } + return sqlDB.Close() } -func (c *gormClient) AddUniqueIndex(indexName string, columns ...string) (Client, error) { - var newClient gormClient - newClient.db = c.db.AddUniqueIndex(indexName, columns...) - return &newClient, newClient.db.Error +func (c *gormClient) AddUniqueIndex(indexName string, columns interface{}) error { + return c.db.Migrator().CreateIndex(columns, indexName) } -func (c *gormClient) Dialect() gorm.Dialect { - return c.db.Dialect() +func (c *gormClient) Dialect() gorm.Dialector { + return c.db.Dialector } -func (c *gormClient) RemoveIndex(indexName string) (Client, error) { - var newClient gormClient - newClient.db = c.db.RemoveIndex(indexName) - return &newClient, newClient.db.Error +func (c *gormClient) RemoveIndex(indexName string, columns interface{}) error { + return c.db.Migrator().DropIndex(columns, indexName) } func (c *gormClient) Model(value interface{}) Client { @@ -85,8 +87,8 @@ func (c *gormClient) Save(value interface{}) (int64, error) { return newDb.RowsAffected, newDb.Error } -func (c *gormClient) Update(attrs ...interface{}) (int64, error) { - newDb := c.db.Update(attrs...) +func (c *gormClient) Update(column string, value interface{}) (int64, error) { + newDb := c.db.Update(column, value) return newDb.RowsAffected, newDb.Error } @@ -99,7 +101,7 @@ func (c *gormClient) Find(out interface{}, where ...interface{}) error { } func (c *gormClient) AutoMigrate(values ...interface{}) error { - return c.db.AutoMigrate(values...).Error + return c.db.AutoMigrate(values...) } func (c *gormClient) Begin() Client { @@ -117,14 +119,22 @@ func (c *gormClient) Commit() error { } func (c *gormClient) HasTable(value interface{}) bool { - return c.db.HasTable(value) + return c.db.Migrator().HasTable(value) } func (c *gormClient) Exec(query string, args ...interface{}) int64 { - dbClient := c.db.Exec(query, args) + dbClient := c.db.Exec(query, args...) return dbClient.RowsAffected } +func (c *gormClient) ExecWithError(query string, args ...interface{}) error { + return c.db.Exec(query, args...).Error +} + +func (c *gormClient) Migrator() gorm.Migrator { + return c.db.Migrator() +} + func (c *gormClient) Rows(tablename string) (*sql.Rows, error) { tableDb := c.db.Table(tablename) return tableDb.Rows() diff --git a/db/db_sql.go b/db/db_sql.go index 67e32edc..8353571f 100644 --- a/db/db_sql.go +++ b/db/db_sql.go @@ -10,15 +10,16 @@ import ( "sync/atomic" "time" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "code.cloudfoundry.org/clock" "code.cloudfoundry.org/eventhub" "code.cloudfoundry.org/lager/v3" "code.cloudfoundry.org/routing-api/config" "code.cloudfoundry.org/routing-api/models" - "github.com/jinzhu/gorm" - _ "github.com/jinzhu/gorm/dialects/mysql" - _ "github.com/jinzhu/gorm/dialects/postgres" + "gorm.io/gorm" ) //go:generate counterfeiter -o fakes/fake_db.go . DB @@ -102,27 +103,38 @@ var DeleteRouterGroupError = DBError{Type: KeyNotFound, Message: "Delete Fails: func NewSqlDB(cfg *config.SqlDB) (*SqlDB, error) { if cfg == nil { - return nil, errors.New("SQL configuration cannot be nil") + return nil, errors.New("Sql configuration cannot be nil") } - if cfg.Type != "mysql" && cfg.Type != "postgres" { - return &SqlDB{}, fmt.Errorf("Unknown type %s", cfg.Type) + connStr, err := ConnectionString(cfg) + if err != nil { + return nil, err } - connStr, err := ConnectionString(cfg) + var dialect gorm.Dialector + switch cfg.Type { + case "postgres": + dialect = postgres.Open(connStr) + case "mysql": + dialect = mysql.Open(connStr) + default: + return &SqlDB{}, fmt.Errorf("unknown type %s", cfg.Type) + } + + db, err := gorm.Open(dialect, &gorm.Config{}) if err != nil { return nil, err } - db, err := gorm.Open(cfg.Type, connStr) + sqlDB, err := db.DB() if err != nil { return nil, err } - db.DB().SetMaxIdleConns(cfg.MaxIdleConns) - db.DB().SetMaxOpenConns(cfg.MaxOpenConns) + sqlDB.SetMaxIdleConns(cfg.MaxIdleConns) + sqlDB.SetMaxOpenConns(cfg.MaxOpenConns) connMaxLifetime := time.Duration(cfg.ConnMaxLifetime) * time.Second - db.DB().SetConnMaxLifetime(connMaxLifetime) + sqlDB.SetConnMaxLifetime(connMaxLifetime) tcpEventHub := eventhub.NewNonBlocking(1024) httpEventHub := eventhub.NewNonBlocking(1024) @@ -140,7 +152,7 @@ func (s *SqlDB) FindExpiredRoutes(routes interface{}, c clock.Clock) error { // postgres stores at microsecond precision. we subtract a second from expiry time to give // us an extra second of buffer to account for rounding issues: // if we tell the db to save an expiry of 5.3s, and we query at 5.2s, mysql will think it expired, - // as the db will compare 5s against 5.2s. Oops. + // as the db will compare 5s against 5.2s. Oops. return s.Client.Find(routes, "expires_at < ?", c.Now().Add(-1*time.Second)) } @@ -295,7 +307,8 @@ func (s *SqlDB) DeleteRouterGroup(guid string) error { return DeleteRouterGroupError } - _, err = s.Client.Delete(&routerGroup) + // Use WHERE clause to delete the specific router group by guid + _, err = s.Client.Where("guid = ?", guid).Delete(&models.RouterGroupDB{}) if err != nil { return err } diff --git a/db/db_sql_test.go b/db/db_sql_test.go index 6d3b1dde..d5681211 100644 --- a/db/db_sql_test.go +++ b/db/db_sql_test.go @@ -414,7 +414,7 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&rg) + _, err = sqlDB.Client.Where("guid = ?", rg.Guid).Delete(&models.RouterGroupDB{}) Expect(err).ToNot(HaveOccurred()) }) @@ -534,7 +534,7 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&rg) + _, err = sqlDB.Client.Where("guid = ?", rg.Guid).Delete(&models.RouterGroupDB{}) Expect(err).ToNot(HaveOccurred()) }) @@ -589,9 +589,7 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&models.RouterGroupDB{ - Model: models.Model{Guid: routerGroupId}, - }) + _, err = sqlDB.Client.Where("guid = ?", routerGroupId).Delete(&models.RouterGroupDB{}) Expect(err).ToNot(HaveOccurred()) }) @@ -705,7 +703,7 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&routerGroupDB2) + _, err = sqlDB.Client.Where("guid = ?", routerGroupDB2.Guid).Delete(&models.RouterGroupDB{}) Expect(err).ToNot(HaveOccurred()) }) @@ -771,7 +769,7 @@ var _ = Describe("SqlDB", func() { AfterEach(func() { for _, tcpRoute := range tcpRoutes { - rowsAffected, err := sqlDB.Client.Delete(&tcpRoute) + rowsAffected, err := sqlDB.Client.Where("guid = ?", tcpRoute.Guid).Delete(&models.TcpRouteMapping{}) Expect(err).NotTo(HaveOccurred()) Expect(rowsAffected).To(BeEquivalentTo(1)) } @@ -813,7 +811,7 @@ var _ = Describe("SqlDB", func() { AfterEach(func() { if tcpRoute.Guid != "" { - _, err := sqlDB.Client.Delete(&tcpRoute) + _, err := sqlDB.Client.Where("guid = ?", tcpRoute.Guid).Delete(&models.TcpRouteMapping{}) Expect(err).ToNot(HaveOccurred()) } }) @@ -876,9 +874,9 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err := sqlDB.Client.Delete(&tcpRoute1) + _, err := sqlDB.Client.Where("guid = ?", tcpRoute1.Guid).Delete(&models.TcpRouteMapping{}) Expect(err).ToNot(HaveOccurred()) - _, err = sqlDB.Client.Delete(&tcpRoute2) + _, err = sqlDB.Client.Where("guid = ?", tcpRoute2.Guid).Delete(&models.TcpRouteMapping{}) Expect(err).ToNot(HaveOccurred()) }) @@ -932,7 +930,7 @@ var _ = Describe("SqlDB", func() { Expect(err).NotTo(HaveOccurred()) }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&tcpRoute) + _, err = sqlDB.Client.Where("guid = ?", tcpRoute.Guid).Delete(&models.TcpRouteMapping{}) Expect(err).NotTo(HaveOccurred()) }) @@ -1009,7 +1007,7 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&tcpRoute) + _, err = sqlDB.Client.Where("guid = ?", tcpRoute.Guid).Delete(&models.TcpRouteMapping{}) Expect(err).ToNot(HaveOccurred()) }) @@ -1057,8 +1055,10 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&tcpRoute2) - Expect(err).ToNot(HaveOccurred()) + if tcpRoute2.Guid != "" { + _, err = sqlDB.Client.Where("guid = ?", tcpRoute2.Guid).Delete(&models.TcpRouteMapping{}) + Expect(err).ToNot(HaveOccurred()) + } }) It("creates another tcp route", func() { @@ -1141,8 +1141,10 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&tcpRouteWithModel) - Expect(err).ToNot(HaveOccurred()) + if tcpRouteWithModel.Guid != "" { + _, err = sqlDB.Client.Where("guid = ?", tcpRouteWithModel.Guid).Delete(&models.TcpRouteMapping{}) + Expect(err).ToNot(HaveOccurred()) + } }) It("returns the tcp routes", func() { @@ -1168,8 +1170,10 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&expiredTcpRouteWithModel) - Expect(err).ToNot(HaveOccurred()) + if expiredTcpRouteWithModel.Guid != "" { + _, err = sqlDB.Client.Where("guid = ?", expiredTcpRouteWithModel.Guid).Delete(&models.TcpRouteMapping{}) + Expect(err).ToNot(HaveOccurred()) + } }) It("does not return the tcp route", func() { @@ -1236,8 +1240,10 @@ var _ = Describe("SqlDB", func() { AfterEach(func() { for _, tcpRouteWithModel := range tcpRoutesWithModel { - _, err = sqlDB.Client.Delete(&tcpRouteWithModel) - Expect(err).ToNot(HaveOccurred()) + if tcpRouteWithModel.Guid != "" { + _, err = sqlDB.Client.Where("guid = ?", tcpRouteWithModel.Guid).Delete(&models.TcpRouteMapping{}) + Expect(err).ToNot(HaveOccurred()) + } } }) @@ -1314,8 +1320,10 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&tcpRouteWithModel) - Expect(err).ToNot(HaveOccurred()) + if tcpRouteWithModel.Guid != "" { + _, err = sqlDB.Client.Where("guid = ?", tcpRouteWithModel.Guid).Delete(&models.TcpRouteMapping{}) + Expect(err).ToNot(HaveOccurred()) + } }) It("deletes the tcp route", func() { @@ -1339,8 +1347,10 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&tcpRouteWithModel2) - Expect(err).ToNot(HaveOccurred()) + if tcpRouteWithModel2.Guid != "" { + _, err = sqlDB.Client.Delete(&tcpRouteWithModel2) + Expect(err).ToNot(HaveOccurred()) + } }) It("does not delete everything", func() { @@ -1379,8 +1389,10 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&httpRoute) - Expect(err).ToNot(HaveOccurred()) + if httpRoute.Guid != "" { + _, err = sqlDB.Client.Where("guid = ?", httpRoute.Guid).Delete(&models.Route{}) + Expect(err).ToNot(HaveOccurred()) + } }) Context("when the http route already exists", func() { @@ -1485,8 +1497,10 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&routeWithModel) - Expect(err).ToNot(HaveOccurred()) + if routeWithModel.Guid != "" { + _, err = sqlDB.Client.Where("guid = ?", routeWithModel.Guid).Delete(&models.Route{}) + Expect(err).ToNot(HaveOccurred()) + } }) It("returns the routes", func() { @@ -1512,8 +1526,10 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&expiredRouteWithModel) - Expect(err).ToNot(HaveOccurred()) + if expiredRouteWithModel.Guid != "" { + _, err = sqlDB.Client.Where("guid = ?", expiredRouteWithModel.Guid).Delete(&models.Route{}) + Expect(err).ToNot(HaveOccurred()) + } }) It("does not return the route", func() { @@ -1569,8 +1585,10 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&routeWithModel) - Expect(err).ToNot(HaveOccurred()) + if routeWithModel.Guid != "" { + _, err = sqlDB.Client.Where("guid = ?", routeWithModel.Guid).Delete(&models.Route{}) + Expect(err).ToNot(HaveOccurred()) + } }) It("deletes the route", func() { @@ -1594,8 +1612,10 @@ var _ = Describe("SqlDB", func() { }) AfterEach(func() { - _, err = sqlDB.Client.Delete(&routeWithModel2) - Expect(err).ToNot(HaveOccurred()) + if routeWithModel2.Guid != "" { + _, err = sqlDB.Client.Delete(&routeWithModel2) + Expect(err).ToNot(HaveOccurred()) + } }) It("deletes the specified route", func() { @@ -2054,7 +2074,7 @@ var _ = Describe("SqlDB", func() { c := atomic.AddInt32(&count, 1) if c > 5 && c < 10 { return 0, errors.New("temp-error") - } else if c >= 10 { + } else if c >= 10 && c < 15 { return 111, nil } else { return 1, nil @@ -2063,14 +2083,13 @@ var _ = Describe("SqlDB", func() { }) It("eventually resolves the issue", func() { - timeout := 2.5 + timeout := 10.0 Eventually(logger, timeout).Should(gbytes.Say(`"prune.successfully-finished-pruning-tcp-routes","log_level":1,"data":{"rowsAffected":1}`)) - Eventually(logger, timeout).Should(gbytes.Say(`"prune.successfully-finished-pruning-http-routes","log_level":1,"data":{"rowsAffected":1}`)) + Eventually(logger, timeout).Should(gbytes.Say(`failed-to-prune-tcp-routes","log_level":2,"data":{"error":"temp-error"}`)) - Eventually(logger, timeout).Should(gbytes.Say(`failed-to-prune-http-routes","log_level":2,"data":{"error":"temp-error"}`)) + Eventually(logger, timeout).Should(gbytes.Say(`"prune.successfully-finished-pruning-tcp-routes","log_level":1,"data":{"rowsAffected":111}`)) - Eventually(logger, timeout).Should(gbytes.Say(`"prune.successfully-finished-pruning-http-routes","log_level":1,"data":{"rowsAffected":111}`)) }) }) }) diff --git a/db/db_suite_test.go b/db/db_suite_test.go index f1b701d7..05fa50c9 100644 --- a/db/db_suite_test.go +++ b/db/db_suite_test.go @@ -5,7 +5,6 @@ import ( "code.cloudfoundry.org/routing-api/cmd/routing-api/testrunner" "code.cloudfoundry.org/routing-api/config" - _ "github.com/lib/pq" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) diff --git a/db/fakes/fake_client.go b/db/fakes/fake_client.go index ea415f61..819e2eec 100644 --- a/db/fakes/fake_client.go +++ b/db/fakes/fake_client.go @@ -6,23 +6,21 @@ import ( "sync" "code.cloudfoundry.org/routing-api/db" - "github.com/jinzhu/gorm" + "gorm.io/gorm" ) type FakeClient struct { - AddUniqueIndexStub func(string, ...string) (db.Client, error) + AddUniqueIndexStub func(string, interface{}) error addUniqueIndexMutex sync.RWMutex addUniqueIndexArgsForCall []struct { arg1 string - arg2 []string + arg2 interface{} } addUniqueIndexReturns struct { - result1 db.Client - result2 error + result1 error } addUniqueIndexReturnsOnCall map[int]struct { - result1 db.Client - result2 error + result1 error } AutoMigrateStub func(...interface{}) error autoMigrateMutex sync.RWMutex @@ -92,15 +90,15 @@ type FakeClient struct { result1 int64 result2 error } - DialectStub func() gorm.Dialect + DialectStub func() gorm.Dialector dialectMutex sync.RWMutex dialectArgsForCall []struct { } dialectReturns struct { - result1 gorm.Dialect + result1 gorm.Dialector } dialectReturnsOnCall map[int]struct { - result1 gorm.Dialect + result1 gorm.Dialector } DropColumnStub func(string) error dropColumnMutex sync.RWMutex @@ -125,6 +123,18 @@ type FakeClient struct { execReturnsOnCall map[int]struct { result1 int64 } + ExecWithErrorStub func(string, ...interface{}) error + execWithErrorMutex sync.RWMutex + execWithErrorArgsForCall []struct { + arg1 string + arg2 []interface{} + } + execWithErrorReturns struct { + result1 error + } + execWithErrorReturnsOnCall map[int]struct { + result1 error + } FindStub func(interface{}, ...interface{}) error findMutex sync.RWMutex findArgsForCall []struct { @@ -171,18 +181,27 @@ type FakeClient struct { modelReturnsOnCall map[int]struct { result1 db.Client } - RemoveIndexStub func(string) (db.Client, error) + MigratorStub func() gorm.Migrator + migratorMutex sync.RWMutex + migratorArgsForCall []struct { + } + migratorReturns struct { + result1 gorm.Migrator + } + migratorReturnsOnCall map[int]struct { + result1 gorm.Migrator + } + RemoveIndexStub func(string, interface{}) error removeIndexMutex sync.RWMutex removeIndexArgsForCall []struct { arg1 string + arg2 interface{} } removeIndexReturns struct { - result1 db.Client - result2 error + result1 error } removeIndexReturnsOnCall map[int]struct { - result1 db.Client - result2 error + result1 error } RollbackStub func() error rollbackMutex sync.RWMutex @@ -220,10 +239,11 @@ type FakeClient struct { result1 int64 result2 error } - UpdateStub func(...interface{}) (int64, error) + UpdateStub func(string, interface{}) (int64, error) updateMutex sync.RWMutex updateArgsForCall []struct { - arg1 []interface{} + arg1 string + arg2 interface{} } updateReturns struct { result1 int64 @@ -249,24 +269,24 @@ type FakeClient struct { invocationsMutex sync.RWMutex } -func (fake *FakeClient) AddUniqueIndex(arg1 string, arg2 ...string) (db.Client, error) { +func (fake *FakeClient) AddUniqueIndex(arg1 string, arg2 interface{}) error { fake.addUniqueIndexMutex.Lock() ret, specificReturn := fake.addUniqueIndexReturnsOnCall[len(fake.addUniqueIndexArgsForCall)] fake.addUniqueIndexArgsForCall = append(fake.addUniqueIndexArgsForCall, struct { arg1 string - arg2 []string + arg2 interface{} }{arg1, arg2}) stub := fake.AddUniqueIndexStub fakeReturns := fake.addUniqueIndexReturns fake.recordInvocation("AddUniqueIndex", []interface{}{arg1, arg2}) fake.addUniqueIndexMutex.Unlock() if stub != nil { - return stub(arg1, arg2...) + return fake.AddUniqueIndexStub(arg1, arg2) } if specificReturn { - return ret.result1, ret.result2 + return ret.result1 } - return fakeReturns.result1, fakeReturns.result2 + return fakeReturns.result1 } func (fake *FakeClient) AddUniqueIndexCallCount() int { @@ -275,43 +295,40 @@ func (fake *FakeClient) AddUniqueIndexCallCount() int { return len(fake.addUniqueIndexArgsForCall) } -func (fake *FakeClient) AddUniqueIndexCalls(stub func(string, ...string) (db.Client, error)) { +func (fake *FakeClient) AddUniqueIndexCalls(stub func(string, interface{}) error) { fake.addUniqueIndexMutex.Lock() defer fake.addUniqueIndexMutex.Unlock() fake.AddUniqueIndexStub = stub } -func (fake *FakeClient) AddUniqueIndexArgsForCall(i int) (string, []string) { +func (fake *FakeClient) AddUniqueIndexArgsForCall(i int) (string, interface{}) { fake.addUniqueIndexMutex.RLock() defer fake.addUniqueIndexMutex.RUnlock() argsForCall := fake.addUniqueIndexArgsForCall[i] return argsForCall.arg1, argsForCall.arg2 } -func (fake *FakeClient) AddUniqueIndexReturns(result1 db.Client, result2 error) { +func (fake *FakeClient) AddUniqueIndexReturns(result1 error) { fake.addUniqueIndexMutex.Lock() defer fake.addUniqueIndexMutex.Unlock() fake.AddUniqueIndexStub = nil fake.addUniqueIndexReturns = struct { - result1 db.Client - result2 error - }{result1, result2} + result1 error + }{result1} } -func (fake *FakeClient) AddUniqueIndexReturnsOnCall(i int, result1 db.Client, result2 error) { +func (fake *FakeClient) AddUniqueIndexReturnsOnCall(i int, result1 error) { fake.addUniqueIndexMutex.Lock() defer fake.addUniqueIndexMutex.Unlock() fake.AddUniqueIndexStub = nil if fake.addUniqueIndexReturnsOnCall == nil { fake.addUniqueIndexReturnsOnCall = make(map[int]struct { - result1 db.Client - result2 error + result1 error }) } fake.addUniqueIndexReturnsOnCall[i] = struct { - result1 db.Client - result2 error - }{result1, result2} + result1 error + }{result1} } func (fake *FakeClient) AutoMigrate(arg1 ...interface{}) error { @@ -663,7 +680,7 @@ func (fake *FakeClient) DeleteReturnsOnCall(i int, result1 int64, result2 error) }{result1, result2} } -func (fake *FakeClient) Dialect() gorm.Dialect { +func (fake *FakeClient) Dialect() gorm.Dialector { fake.dialectMutex.Lock() ret, specificReturn := fake.dialectReturnsOnCall[len(fake.dialectArgsForCall)] fake.dialectArgsForCall = append(fake.dialectArgsForCall, struct { @@ -687,32 +704,32 @@ func (fake *FakeClient) DialectCallCount() int { return len(fake.dialectArgsForCall) } -func (fake *FakeClient) DialectCalls(stub func() gorm.Dialect) { +func (fake *FakeClient) DialectCalls(stub func() gorm.Dialector) { fake.dialectMutex.Lock() defer fake.dialectMutex.Unlock() fake.DialectStub = stub } -func (fake *FakeClient) DialectReturns(result1 gorm.Dialect) { +func (fake *FakeClient) DialectReturns(result1 gorm.Dialector) { fake.dialectMutex.Lock() defer fake.dialectMutex.Unlock() fake.DialectStub = nil fake.dialectReturns = struct { - result1 gorm.Dialect + result1 gorm.Dialector }{result1} } -func (fake *FakeClient) DialectReturnsOnCall(i int, result1 gorm.Dialect) { +func (fake *FakeClient) DialectReturnsOnCall(i int, result1 gorm.Dialector) { fake.dialectMutex.Lock() defer fake.dialectMutex.Unlock() fake.DialectStub = nil if fake.dialectReturnsOnCall == nil { fake.dialectReturnsOnCall = make(map[int]struct { - result1 gorm.Dialect + result1 gorm.Dialector }) } fake.dialectReturnsOnCall[i] = struct { - result1 gorm.Dialect + result1 gorm.Dialector }{result1} } @@ -839,6 +856,68 @@ func (fake *FakeClient) ExecReturnsOnCall(i int, result1 int64) { }{result1} } +func (fake *FakeClient) ExecWithError(arg1 string, arg2 ...interface{}) error { + fake.execWithErrorMutex.Lock() + ret, specificReturn := fake.execWithErrorReturnsOnCall[len(fake.execWithErrorArgsForCall)] + fake.execWithErrorArgsForCall = append(fake.execWithErrorArgsForCall, struct { + arg1 string + arg2 []interface{} + }{arg1, arg2}) + stub := fake.ExecWithErrorStub + fakeReturns := fake.execWithErrorReturns + fake.recordInvocation("ExecWithError", []interface{}{arg1, arg2}) + fake.execWithErrorMutex.Unlock() + if stub != nil { + return stub(arg1, arg2...) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeClient) ExecWithErrorCallCount() int { + fake.execWithErrorMutex.RLock() + defer fake.execWithErrorMutex.RUnlock() + return len(fake.execWithErrorArgsForCall) +} + +func (fake *FakeClient) ExecWithErrorCalls(stub func(string, ...interface{}) error) { + fake.execWithErrorMutex.Lock() + defer fake.execWithErrorMutex.Unlock() + fake.ExecWithErrorStub = stub +} + +func (fake *FakeClient) ExecWithErrorArgsForCall(i int) (string, []interface{}) { + fake.execWithErrorMutex.RLock() + defer fake.execWithErrorMutex.RUnlock() + argsForCall := fake.execWithErrorArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeClient) ExecWithErrorReturns(result1 error) { + fake.execWithErrorMutex.Lock() + defer fake.execWithErrorMutex.Unlock() + fake.ExecWithErrorStub = nil + fake.execWithErrorReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeClient) ExecWithErrorReturnsOnCall(i int, result1 error) { + fake.execWithErrorMutex.Lock() + defer fake.execWithErrorMutex.Unlock() + fake.ExecWithErrorStub = nil + if fake.execWithErrorReturnsOnCall == nil { + fake.execWithErrorReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.execWithErrorReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeClient) Find(arg1 interface{}, arg2 ...interface{}) error { fake.findMutex.Lock() ret, specificReturn := fake.findReturnsOnCall[len(fake.findArgsForCall)] @@ -1085,23 +1164,77 @@ func (fake *FakeClient) ModelReturnsOnCall(i int, result1 db.Client) { }{result1} } -func (fake *FakeClient) RemoveIndex(arg1 string) (db.Client, error) { +func (fake *FakeClient) Migrator() gorm.Migrator { + fake.migratorMutex.Lock() + ret, specificReturn := fake.migratorReturnsOnCall[len(fake.migratorArgsForCall)] + fake.migratorArgsForCall = append(fake.migratorArgsForCall, struct { + }{}) + stub := fake.MigratorStub + fakeReturns := fake.migratorReturns + fake.recordInvocation("Migrator", []interface{}{}) + fake.migratorMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeClient) MigratorCallCount() int { + fake.migratorMutex.RLock() + defer fake.migratorMutex.RUnlock() + return len(fake.migratorArgsForCall) +} + +func (fake *FakeClient) MigratorCalls(stub func() gorm.Migrator) { + fake.migratorMutex.Lock() + defer fake.migratorMutex.Unlock() + fake.MigratorStub = stub +} + +func (fake *FakeClient) MigratorReturns(result1 gorm.Migrator) { + fake.migratorMutex.Lock() + defer fake.migratorMutex.Unlock() + fake.MigratorStub = nil + fake.migratorReturns = struct { + result1 gorm.Migrator + }{result1} +} + +func (fake *FakeClient) MigratorReturnsOnCall(i int, result1 gorm.Migrator) { + fake.migratorMutex.Lock() + defer fake.migratorMutex.Unlock() + fake.MigratorStub = nil + if fake.migratorReturnsOnCall == nil { + fake.migratorReturnsOnCall = make(map[int]struct { + result1 gorm.Migrator + }) + } + fake.migratorReturnsOnCall[i] = struct { + result1 gorm.Migrator + }{result1} +} + +func (fake *FakeClient) RemoveIndex(arg1 string, arg2 interface{}) error { fake.removeIndexMutex.Lock() ret, specificReturn := fake.removeIndexReturnsOnCall[len(fake.removeIndexArgsForCall)] fake.removeIndexArgsForCall = append(fake.removeIndexArgsForCall, struct { arg1 string - }{arg1}) + arg2 interface{} + }{arg1, arg2}) stub := fake.RemoveIndexStub fakeReturns := fake.removeIndexReturns fake.recordInvocation("RemoveIndex", []interface{}{arg1}) fake.removeIndexMutex.Unlock() if stub != nil { - return stub(arg1) + return fake.RemoveIndexStub(arg1, arg2) } if specificReturn { - return ret.result1, ret.result2 + return ret.result1 } - return fakeReturns.result1, fakeReturns.result2 + return fakeReturns.result1 } func (fake *FakeClient) RemoveIndexCallCount() int { @@ -1110,7 +1243,7 @@ func (fake *FakeClient) RemoveIndexCallCount() int { return len(fake.removeIndexArgsForCall) } -func (fake *FakeClient) RemoveIndexCalls(stub func(string) (db.Client, error)) { +func (fake *FakeClient) RemoveIndexCalls(stub func(string, interface{}) error) { fake.removeIndexMutex.Lock() defer fake.removeIndexMutex.Unlock() fake.RemoveIndexStub = stub @@ -1123,30 +1256,27 @@ func (fake *FakeClient) RemoveIndexArgsForCall(i int) string { return argsForCall.arg1 } -func (fake *FakeClient) RemoveIndexReturns(result1 db.Client, result2 error) { +func (fake *FakeClient) RemoveIndexReturns(result1 error) { fake.removeIndexMutex.Lock() defer fake.removeIndexMutex.Unlock() fake.RemoveIndexStub = nil fake.removeIndexReturns = struct { - result1 db.Client - result2 error - }{result1, result2} + result1 error + }{result1} } -func (fake *FakeClient) RemoveIndexReturnsOnCall(i int, result1 db.Client, result2 error) { +func (fake *FakeClient) RemoveIndexReturnsOnCall(i int, result1 error) { fake.removeIndexMutex.Lock() defer fake.removeIndexMutex.Unlock() fake.RemoveIndexStub = nil if fake.removeIndexReturnsOnCall == nil { fake.removeIndexReturnsOnCall = make(map[int]struct { - result1 db.Client - result2 error + result1 error }) } fake.removeIndexReturnsOnCall[i] = struct { - result1 db.Client - result2 error - }{result1, result2} + result1 error + }{result1} } func (fake *FakeClient) Rollback() error { @@ -1330,18 +1460,20 @@ func (fake *FakeClient) SaveReturnsOnCall(i int, result1 int64, result2 error) { }{result1, result2} } -func (fake *FakeClient) Update(arg1 ...interface{}) (int64, error) { +func (fake *FakeClient) Update(arg1 string, arg2 interface{}) (int64, error) { fake.updateMutex.Lock() ret, specificReturn := fake.updateReturnsOnCall[len(fake.updateArgsForCall)] fake.updateArgsForCall = append(fake.updateArgsForCall, struct { - arg1 []interface{} - }{arg1}) + arg1 string + arg2 interface{} + }{arg1, arg2}) + fake.recordInvocation("Update", []interface{}{arg1, arg2}) stub := fake.UpdateStub fakeReturns := fake.updateReturns fake.recordInvocation("Update", []interface{}{arg1}) fake.updateMutex.Unlock() if stub != nil { - return stub(arg1...) + return fake.UpdateStub(arg1, arg2) } if specificReturn { return ret.result1, ret.result2 @@ -1355,17 +1487,17 @@ func (fake *FakeClient) UpdateCallCount() int { return len(fake.updateArgsForCall) } -func (fake *FakeClient) UpdateCalls(stub func(...interface{}) (int64, error)) { +func (fake *FakeClient) UpdateCalls(stub func(string, interface{}) (int64, error)) { fake.updateMutex.Lock() defer fake.updateMutex.Unlock() fake.UpdateStub = stub } -func (fake *FakeClient) UpdateArgsForCall(i int) []interface{} { +func (fake *FakeClient) UpdateArgsForCall(i int) (string, interface{}) { fake.updateMutex.RLock() defer fake.updateMutex.RUnlock() argsForCall := fake.updateArgsForCall[i] - return argsForCall.arg1 + return argsForCall.arg1, argsForCall.arg2 } func (fake *FakeClient) UpdateReturns(result1 int64, result2 error) { @@ -1479,6 +1611,8 @@ func (fake *FakeClient) Invocations() map[string][][]interface{} { defer fake.dropColumnMutex.RUnlock() fake.execMutex.RLock() defer fake.execMutex.RUnlock() + fake.execWithErrorMutex.RLock() + defer fake.execWithErrorMutex.RUnlock() fake.findMutex.RLock() defer fake.findMutex.RUnlock() fake.firstMutex.RLock() @@ -1487,6 +1621,8 @@ func (fake *FakeClient) Invocations() map[string][][]interface{} { defer fake.hasTableMutex.RUnlock() fake.modelMutex.RLock() defer fake.modelMutex.RUnlock() + fake.migratorMutex.RLock() + defer fake.migratorMutex.RUnlock() fake.removeIndexMutex.RLock() defer fake.removeIndexMutex.RUnlock() fake.rollbackMutex.RLock() diff --git a/migration/V2_update_rg_migration.go b/migration/V2_update_rg_migration.go index 5bc1757d..99b9d0aa 100644 --- a/migration/V2_update_rg_migration.go +++ b/migration/V2_update_rg_migration.go @@ -1,6 +1,8 @@ package migration import ( + "fmt" + "code.cloudfoundry.org/routing-api/db" "code.cloudfoundry.org/routing-api/models" ) @@ -18,6 +20,32 @@ func (v *V2UpdateRgMigration) Version() int { } func (v *V2UpdateRgMigration) Run(sqlDB *db.SqlDB) error { - _, err := sqlDB.Client.Model(&models.RouterGroup{}).AddUniqueIndex("idx_rg_name", "name") - return err + // Check for duplicate router group names before creating the unique index + var rgs []models.RouterGroupDB + err := sqlDB.Client.Find(&rgs) + if err != nil { + return err + } + + nameMap := make(map[string]int) + for _, rg := range rgs { + nameMap[rg.Name]++ + } + + for name, count := range nameMap { + if count > 1 { + return fmt.Errorf("cannot create unique index: router group name '%s' appears %d times", name, count) + } + } + + dropIndex(sqlDB, "idx_rg_name", "router_groups") + + // Create unique index - MySQL requires length prefix for text columns + var indexSQL string + if sqlDB.Client.Dialect().Name() == "mysql" { + indexSQL = "CREATE UNIQUE INDEX idx_rg_name ON router_groups (name(191))" + } else { + indexSQL = "CREATE UNIQUE INDEX idx_rg_name ON router_groups (name)" + } + return sqlDB.Client.ExecWithError(indexSQL) } diff --git a/migration/V2_update_rg_migration_test.go b/migration/V2_update_rg_migration_test.go index 79917ac1..74fb27a2 100644 --- a/migration/V2_update_rg_migration_test.go +++ b/migration/V2_update_rg_migration_test.go @@ -55,6 +55,10 @@ var _ = Describe("V2UpdateRgMigration", func() { }) It("does not allow duplicate router group names", func() { + migrator := sqlDB.Client.Migrator() + hasIndex := migrator.HasIndex("router_groups", "idx_rg_name") + Expect(hasIndex).To(BeTrue(), "Index idx_rg_name should exist on router_groups table") + rg1 := models.RouterGroupDB{ Model: models.Model{Guid: "guid-1"}, Name: "rg-1", diff --git a/migration/V3_update_tcp_route_migration.go b/migration/V3_update_tcp_route_migration.go index 967c8e9c..dda0e971 100644 --- a/migration/V3_update_tcp_route_migration.go +++ b/migration/V3_update_tcp_route_migration.go @@ -18,5 +18,5 @@ func (v *V3UpdateTcpRouteMigration) Version() int { } func (v *V3UpdateTcpRouteMigration) Run(sqlDB *db.SqlDB) error { - return sqlDB.Client.Model(models.TcpRouteMapping{}).AutoMigrate(models.TcpRouteMapping{}) + return sqlDB.Client.AutoMigrate(&models.TcpRouteMapping{}) } diff --git a/migration/V3_update_tcp_route_migration_test.go b/migration/V3_update_tcp_route_migration_test.go index 9ed7c353..f4b52580 100644 --- a/migration/V3_update_tcp_route_migration_test.go +++ b/migration/V3_update_tcp_route_migration_test.go @@ -48,7 +48,8 @@ var _ = Describe("V3UpdateTcpRouteMigration", func() { Expect(err).NotTo(HaveOccurred()) Expect(tcpRoutes).To(HaveLen(1)) - err = sqlDB.Client.Model(&models.TcpRouteMapping{}).DropColumn("isolation_segment") + // Use direct SQL instead of GORM DropColumn to work around GORM bug + err = sqlDB.Client.ExecWithError("ALTER TABLE tcp_routes DROP COLUMN isolation_segment") Expect(err).ToNot(HaveOccurred()) rows, err := sqlDB.Client.Rows("tcp_routes") diff --git a/migration/V4_add_rg_uniq_idx_tcp_route_migration.go b/migration/V4_add_rg_uniq_idx_tcp_route_migration.go index dace6c00..6a5a3e53 100644 --- a/migration/V4_add_rg_uniq_idx_tcp_route_migration.go +++ b/migration/V4_add_rg_uniq_idx_tcp_route_migration.go @@ -2,7 +2,6 @@ package migration import ( "code.cloudfoundry.org/routing-api/db" - "code.cloudfoundry.org/routing-api/models" ) type V4AddRgUniqIdxTCPRoute struct{} @@ -18,10 +17,16 @@ func (v *V4AddRgUniqIdxTCPRoute) Version() int { } func (v *V4AddRgUniqIdxTCPRoute) Run(sqlDB *db.SqlDB) error { - _, err := sqlDB.Client.Model(&models.TcpRouteMapping{}).RemoveIndex("idx_tcp_route") - if err != nil { - return err + dropIndex(sqlDB, "idx_tcp_route", "tcp_routes") + + // Create unique index - MySQL requires length prefixes for text columns + var indexSQL string + if sqlDB.Client.Dialect().Name() == "mysql" { + indexSQL = "CREATE UNIQUE INDEX idx_tcp_route ON tcp_routes (router_group_guid(191), host_port, host_ip(191), external_port)" + } else { + indexSQL = "CREATE UNIQUE INDEX idx_tcp_route ON tcp_routes (router_group_guid, host_port, host_ip, external_port)" } - _, err = sqlDB.Client.Model(&models.TcpRouteMapping{}).AddUniqueIndex("idx_tcp_route", "router_group_guid", "host_port", "host_ip", "external_port") - return err + sqlDB.Client.Exec(indexSQL) + + return nil } diff --git a/migration/V5_sni_hostname_migration.go b/migration/V5_sni_hostname_migration.go index ae9c6221..9ec70faf 100644 --- a/migration/V5_sni_hostname_migration.go +++ b/migration/V5_sni_hostname_migration.go @@ -18,13 +18,20 @@ func (v *V5SniHostnameMigration) Version() int { } func (v *V5SniHostnameMigration) Run(sqlDB *db.SqlDB) error { - _, err := sqlDB.Client.Model(&models.TcpRouteMapping{}).RemoveIndex("idx_tcp_route") + err := sqlDB.Client.AutoMigrate(&models.TcpRouteMapping{}) if err != nil { return err } - _, err = sqlDB.Client.Model(&models.TcpRouteMapping{}).AddUniqueIndex("idx_tcp_route", "router_group_guid", "host_port", "host_ip", "external_port", "sni_hostname") - if err != nil { - return err + + // Drop old index if it exists (ignore errors since it might not exist) + dropIndex(sqlDB, "idx_tcp_route", "tcp_routes") + + // Create unique index - MySQL requires length prefixes for text columns + var indexSQL string + if sqlDB.Client.Dialect().Name() == "mysql" { + indexSQL = "CREATE UNIQUE INDEX idx_tcp_route ON tcp_routes (router_group_guid(191), host_port, host_ip(191), external_port, sni_hostname(191))" + } else { + indexSQL = "CREATE UNIQUE INDEX idx_tcp_route ON tcp_routes (router_group_guid, host_port, host_ip, external_port, sni_hostname)" } - return err + return sqlDB.Client.ExecWithError(indexSQL) } diff --git a/migration/V5_sni_hostname_migration_test.go b/migration/V5_sni_hostname_migration_test.go index 71064931..976d01e3 100644 --- a/migration/V5_sni_hostname_migration_test.go +++ b/migration/V5_sni_hostname_migration_test.go @@ -78,6 +78,10 @@ var _ = Describe("V5SniHostnameMigration", func() { }) It("denies adding the same TCP routes with same SNI hostnames", func() { + migrator := sqlDB.Client.Migrator() + hasIndex := migrator.HasIndex("tcp_routes", "idx_tcp_route") + Expect(hasIndex).To(BeTrue(), "Index idx_tcp_route should exist on tcp_routes table") + sniHostname1 := "sniHostname1" tcpRoute2 := models.TcpRouteMapping{ Model: models.Model{Guid: "guid-2"}, diff --git a/migration/V6_tls_tcp_route.go b/migration/V6_tls_tcp_route.go index f22f9d7e..a3f5217c 100644 --- a/migration/V6_tls_tcp_route.go +++ b/migration/V6_tls_tcp_route.go @@ -18,17 +18,19 @@ func (v *V6TCPTLSRoutes) Version() int { } func (v *V6TCPTLSRoutes) Run(sqlDB *db.SqlDB) error { - _, err := sqlDB.Client.Model(&models.TcpRouteMapping{}).RemoveIndex("idx_tcp_route") + err := sqlDB.Client.AutoMigrate(&models.TcpRouteMapping{}) if err != nil { return err } - err = sqlDB.Client.AutoMigrate(&models.TcpRouteMapping{}) - if err != nil { - return err - } - _, err = sqlDB.Client.Model(&models.TcpRouteMapping{}).AddUniqueIndex("idx_tcp_route", "router_group_guid", "host_port", "host_ip", "external_port", "sni_hostname", "host_tls_port") - if err != nil { - return err + + dropIndex(sqlDB, "idx_tcp_route", "tcp_routes") + + // Create unique index - MySQL requires length prefixes for text columns + var indexSQL string + if sqlDB.Client.Dialect().Name() == "mysql" { + indexSQL = "CREATE UNIQUE INDEX idx_tcp_route ON tcp_routes (router_group_guid(191), host_port, host_ip(191), external_port, sni_hostname(191), host_tls_port)" + } else { + indexSQL = "CREATE UNIQUE INDEX idx_tcp_route ON tcp_routes (router_group_guid, host_port, host_ip, external_port, sni_hostname, host_tls_port)" } - return err + return sqlDB.Client.ExecWithError(indexSQL) } diff --git a/migration/V6_tls_tcp_route_test.go b/migration/V6_tls_tcp_route_test.go index 0ec5f94d..04447cd5 100644 --- a/migration/V6_tls_tcp_route_test.go +++ b/migration/V6_tls_tcp_route_test.go @@ -96,6 +96,10 @@ var _ = Describe("V6TCPTLSRoutes", func() { }) It("denies adding the same TCP routes with same host TLS ports", func() { + migrator := sqlDB.Client.Migrator() + hasIndex := migrator.HasIndex("tcp_routes", "idx_tcp_route") + Expect(hasIndex).To(BeTrue(), "Index idx_tcp_route should exist on tcp_routes table") + sniHostname1 := "sniHostname1" tcpRoute2 := models.TcpRouteMapping{ Model: models.Model{Guid: "guid-2"}, @@ -119,6 +123,10 @@ var _ = Describe("V6TCPTLSRoutes", func() { }) It("denies adding the same TCP routes with different instance_ids", func() { + migrator := sqlDB.Client.Migrator() + hasIndex := migrator.HasIndex("tcp_routes", "idx_tcp_route") + Expect(hasIndex).To(BeTrue(), "Index idx_tcp_route should exist on tcp_routes table") + sniHostname1 := "sniHostname1" tcpRoute2 := models.TcpRouteMapping{ Model: models.Model{Guid: "guid-2"}, diff --git a/migration/V7_instance_id_defaults.go b/migration/V7_instance_id_defaults.go index 96a61f8b..2b5d40bf 100644 --- a/migration/V7_instance_id_defaults.go +++ b/migration/V7_instance_id_defaults.go @@ -2,7 +2,6 @@ package migration import ( "code.cloudfoundry.org/routing-api/db" - "code.cloudfoundry.org/routing-api/models" ) type V7TCPTLSRoutes struct{} @@ -16,20 +15,27 @@ func (v *V7TCPTLSRoutes) Version() int { } func (v *V7TCPTLSRoutes) Run(sqlDB *db.SqlDB) error { - _, err := sqlDB.Client.Model(&models.TcpRouteMapping{}).RemoveIndex("idx_tcp_route") - if err != nil { - return err - } - - if sqlDB.Client.Dialect().GetName() == "postgres" { - sqlDB.Client.Exec("ALTER TABLE tcp_routes ALTER COLUMN instance_id DROP NOT NULL") + // Update the instance_id column to allow NULL values - syntax differs by database + if sqlDB.Client.Dialect().Name() == "mysql" { + err := sqlDB.Client.ExecWithError("ALTER TABLE tcp_routes MODIFY COLUMN instance_id varchar(255) DEFAULT NULL") + if err != nil { + return err + } } else { - sqlDB.Client.Exec("ALTER TABLE tcp_routes MODIFY COLUMN instance_id varchar(255) DEFAULT NULL") + err := sqlDB.Client.ExecWithError("ALTER TABLE tcp_routes ALTER COLUMN instance_id SET DEFAULT NULL") + if err != nil { + return err + } } - _, err = sqlDB.Client.Model(&models.TcpRouteMapping{}).AddUniqueIndex("idx_tcp_route", "router_group_guid", "host_port", "host_ip", "external_port", "sni_hostname", "host_tls_port") - if err != nil { - return err + dropIndex(sqlDB, "idx_tcp_route", "tcp_routes") + + // Create unique index - MySQL requires length prefixes for text columns + var indexSQL string + if sqlDB.Client.Dialect().Name() == "mysql" { + indexSQL = "CREATE UNIQUE INDEX idx_tcp_route ON tcp_routes (router_group_guid(191), host_port, host_ip(191), external_port, sni_hostname(191), host_tls_port)" + } else { + indexSQL = "CREATE UNIQUE INDEX idx_tcp_route ON tcp_routes (router_group_guid, host_port, host_ip, external_port, sni_hostname, host_tls_port)" } - return err + return sqlDB.Client.ExecWithError(indexSQL) } diff --git a/migration/V8_host_tls_port_tcp_default_zero.go b/migration/V8_host_tls_port_tcp_default_zero.go index f23d7153..917c65fa 100644 --- a/migration/V8_host_tls_port_tcp_default_zero.go +++ b/migration/V8_host_tls_port_tcp_default_zero.go @@ -2,7 +2,6 @@ package migration import ( "code.cloudfoundry.org/routing-api/db" - "code.cloudfoundry.org/routing-api/models" ) type V8HostTLSPortTCPDefaultZero struct{} @@ -16,18 +15,20 @@ func (v *V8HostTLSPortTCPDefaultZero) Version() int { } func (v *V8HostTLSPortTCPDefaultZero) Run(sqlDB *db.SqlDB) error { - _, err := sqlDB.Client.Model(&models.TcpRouteMapping{}).RemoveIndex("idx_tcp_route") + // Update existing rows where host_tls_port is NULL to 0 + err := sqlDB.Client.ExecWithError("UPDATE tcp_routes SET host_tls_port = 0 WHERE host_tls_port IS NULL") if err != nil { return err } - if sqlDB.Client.Dialect().GetName() == "postgres" { - sqlDB.Client.Exec("ALTER TABLE tcp_routes ALTER COLUMN host_tls_port SET DEFAULT 0") + // Try to remove the old index if it exists + dropIndex(sqlDB, "idx_tcp_route", "tcp_routes") + + // Set the DEFAULT 0 on the host_tls_port column - syntax differs by database + if sqlDB.Client.Dialect().Name() == "mysql" { + err = sqlDB.Client.ExecWithError("ALTER TABLE tcp_routes MODIFY COLUMN host_tls_port int DEFAULT 0") } else { - sqlDB.Client.Exec("ALTER TABLE tcp_routes MODIFY COLUMN host_tls_port int DEFAULT 0") + err = sqlDB.Client.ExecWithError("ALTER TABLE tcp_routes ALTER COLUMN host_tls_port SET DEFAULT 0") } - - sqlDB.Client.Exec("UPDATE tcp_routes SET host_tls_port = 0 WHERE host_tls_port IS NULL") - - return nil + return err } diff --git a/migration/V8_host_tls_port_tcp_default_zero_test.go b/migration/V8_host_tls_port_tcp_default_zero_test.go index 9b9ed7a6..276d6a80 100644 --- a/migration/V8_host_tls_port_tcp_default_zero_test.go +++ b/migration/V8_host_tls_port_tcp_default_zero_test.go @@ -153,7 +153,7 @@ var _ = Describe("V8HostTLSPortTCPDefaultZero", func() { It("doesnt fail during the migration", func() { By("manually updating the default") - if sqlDB.Client.Dialect().GetName() == "postgres" { + if sqlDB.Client.Dialect().Name() == "postgres" { sqlDB.Client.Exec("ALTER TABLE tcp_routes ALTER COLUMN host_tls_port SET DEFAULT 0") } else { sqlDB.Client.Exec("ALTER TABLE tcp_routes MODIFY COLUMN host_tls_port int DEFAULT 0") diff --git a/migration/migration.go b/migration/migration.go index 769ce153..c48ae64d 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -1,13 +1,24 @@ package migration import ( + "fmt" "os" "code.cloudfoundry.org/lager/v3" "code.cloudfoundry.org/routing-api/db" - "github.com/jinzhu/gorm" + "gorm.io/gorm" ) +// dropIndex drops an index by name. MySQL requires "DROP INDEX name ON table", +// PostgreSQL uses "DROP INDEX IF EXISTS name". +func dropIndex(sqlDB *db.SqlDB, indexName, tableName string) { + if sqlDB.Client.Dialect().Name() == "mysql" { + _ = sqlDB.Client.ExecWithError(fmt.Sprintf("DROP INDEX %s ON %s", indexName, tableName)) + } else { + _ = sqlDB.Client.ExecWithError(fmt.Sprintf("DROP INDEX IF EXISTS %s", indexName)) + } +} + const MigrationKey = "routing-api-migration" type MigrationData struct { diff --git a/migration/migration_suite_test.go b/migration/migration_suite_test.go index 18e00a32..5e337dee 100644 --- a/migration/migration_suite_test.go +++ b/migration/migration_suite_test.go @@ -1,13 +1,33 @@ package migration_test import ( + "testing" + + "code.cloudfoundry.org/routing-api/cmd/routing-api/testrunner" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" +) - "testing" +var ( + databaseAllocator testrunner.DbAllocator ) func TestMigration(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Migration Suite") } + +var _ = BeforeSuite(func() { + var err error + databaseAllocator = testrunner.NewDbAllocator() + _, err = databaseAllocator.Create() + Expect(err).ToNot(HaveOccurred(), "error occurred starting database client, is the database running?") +}) +var _ = AfterSuite(func() { + err := databaseAllocator.Delete() + Expect(err).ToNot(HaveOccurred()) +}) +var _ = BeforeEach(func() { + err := databaseAllocator.Reset() + Expect(err).ToNot(HaveOccurred()) +}) diff --git a/migration/v0/models.go b/migration/v0/models.go index 894f28d2..d80a2ed5 100644 --- a/migration/v0/models.go +++ b/migration/v0/models.go @@ -25,9 +25,9 @@ func (TcpRouteMapping) TableName() string { type TcpMappingEntity struct { RouterGroupGuid string `json:"router_group_guid"` - HostPort uint16 `gorm:"not null; unique_index:idx_tcp_route; type:int" json:"backend_port"` + HostPort uint16 `gorm:"not null; unique_index:idx_tcp_route; type:int; size:32" json:"backend_port"` HostIP string `gorm:"not null; unique_index:idx_tcp_route" json:"backend_ip"` - ExternalPort uint16 `gorm:"not null; unique_index:idx_tcp_route; type: int" json:"port"` + ExternalPort uint16 `gorm:"not null; unique_index:idx_tcp_route; type:int; size:32" json:"port"` ModificationTag `json:"modification_tag"` TTL *int `json:"ttl,omitempty"` } @@ -40,7 +40,7 @@ type Route struct { type RouteEntity struct { Route string `gorm:"not null; unique_index:idx_route" json:"route"` - Port uint16 `gorm:"not null; unique_index:idx_route" json:"port"` + Port uint16 `gorm:"not null; unique_index:idx_route; size:32" json:"port"` IP string `gorm:"not null; unique_index:idx_route" json:"ip"` TTL *int `json:"ttl"` LogGuid string `json:"log_guid"` diff --git a/migration/v5/models.go b/migration/v5/models.go index 37103347..97660614 100644 --- a/migration/v5/models.go +++ b/migration/v5/models.go @@ -27,10 +27,10 @@ func (TcpRouteMapping) TableName() string { type TcpMappingEntity struct { RouterGroupGuid string `gorm:"not null; unique_index:idx_tcp_route" json:"router_group_guid"` - HostPort uint16 `gorm:"not null; unique_index:idx_tcp_route; type:int" json:"backend_port"` + HostPort uint16 `gorm:"not null; unique_index:idx_tcp_route; type:int; size:32" json:"backend_port"` HostIP string `gorm:"not null; unique_index:idx_tcp_route" json:"backend_ip"` SniHostname *string `gorm:"default:null; unique_index:idx_tcp_route" json:"backend_sni_hostname,omitempty"` - ExternalPort uint16 `gorm:"not null; unique_index:idx_tcp_route; type: int" json:"port"` + ExternalPort uint16 `gorm:"not null; unique_index:idx_tcp_route; type:int; size:32" json:"port"` ModificationTag `json:"modification_tag"` TTL *int `json:"ttl,omitempty"` IsolationSegment string `json:"isolation_segment"` diff --git a/models/route.go b/models/route.go index c6ff3795..42049b61 100644 --- a/models/route.go +++ b/models/route.go @@ -14,7 +14,7 @@ type Route struct { type RouteEntity struct { Route string `gorm:"not null; unique_index:idx_route" json:"route"` - Port uint16 `gorm:"not null; unique_index:idx_route" json:"port"` + Port uint16 `gorm:"not null; unique_index:idx_route; size:32" json:"port"` IP string `gorm:"not null; unique_index:idx_route" json:"ip"` TTL *int `json:"ttl"` LogGuid string `json:"log_guid"` diff --git a/models/tcp_route.go b/models/tcp_route.go index aa2455cc..59af8631 100644 --- a/models/tcp_route.go +++ b/models/tcp_route.go @@ -20,7 +20,7 @@ type TcpRouteMapping struct { // WHERE filter to include the new field type TcpMappingEntity struct { RouterGroupGuid string `gorm:"not null; unique_index:idx_tcp_route" json:"router_group_guid"` - HostPort uint16 `gorm:"not null; unique_index:idx_tcp_route; type:int" json:"backend_port"` + HostPort uint16 `gorm:"not null; unique_index:idx_tcp_route; type:int; size:32" json:"backend_port"` HostTLSPort int `gorm:"default:null; unique_index:idx_tcp_route; type:int" json:"backend_tls_port"` HostIP string `gorm:"not null; unique_index:idx_tcp_route" json:"backend_ip"` SniHostname *string `gorm:"default:null; unique_index:idx_tcp_route" json:"backend_sni_hostname,omitempty"` @@ -29,7 +29,7 @@ type TcpMappingEntity struct { // different InstanceId, we fail uniqueness and prevent stale/duplicate routes. If this fails a route, the // TTL on the old record should expire + allow the new route to be created eventually. InstanceId string `gorm:"null; default:null;" json:"instance_id"` - ExternalPort uint16 `gorm:"not null; unique_index:idx_tcp_route; type: int" json:"port"` + ExternalPort uint16 `gorm:"not null; unique_index:idx_tcp_route; type:int; size:32" json:"port"` ModificationTag `json:"modification_tag"` TTL *int `json:"ttl,omitempty"` IsolationSegment string `json:"isolation_segment"`