diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 15ea6f0..50c3a2a 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -55,5 +55,8 @@ jobs:
       - name: Run catastrophic single-node failure recovery test
         run: go test -count=1 -v ./tests/integration -run 'TestCatastrophicSingleNodeFailure'
 
+      - name: Run native PG (no spock) tests
+        run: go test -count=1 -v ./tests/integration -run 'TestNativePG'
+
       - name: Run timestamp comparison tests
         run: go test -count=1 -v ./tests/integration -run 'TestCompareTimestampsExact|TestPostgreSQLMicrosecondPrecision|TestOldVsNewComparison'
diff --git a/db/queries/queries.go b/db/queries/queries.go
index e3edfd6..6a5ba90 100644
--- a/db/queries/queries.go
+++ b/db/queries/queries.go
@@ -928,6 +928,116 @@ func GetSpockSlotLSNForNode(ctx context.Context, db DBQuerier, failedNode string
     return lsn, nil
 }
 
+func GetNativeOriginLSNForNode(ctx context.Context, db DBQuerier, originNodeName string) (*string, error) {
+    sql, err := RenderSQL(SQLTemplates.GetNativeOriginLSNForNode, nil)
+    if err != nil {
+        return nil, err
+    }
+    var lsn *string
+    if err := db.QueryRow(ctx, sql, originNodeName).Scan(&lsn); err != nil {
+        if errors.Is(err, pgx.ErrNoRows) {
+            return nil, nil
+        }
+        return nil, fmt.Errorf("failed to fetch native origin lsn: %w", err)
+    }
+    return lsn, nil
+}
+
+func GetNativeSlotLSNForNode(ctx context.Context, db DBQuerier, failedNode string) (*string, error) {
+    sql, err := RenderSQL(SQLTemplates.GetNativeSlotLSNForNode, nil)
+    if err != nil {
+        return nil, err
+    }
+    var lsn *string
+    if err := db.QueryRow(ctx, sql, failedNode).Scan(&lsn); err != nil {
+        if errors.Is(err, pgx.ErrNoRows) {
+            return nil, nil
+        }
+        return nil, fmt.Errorf("failed to fetch native slot lsn: %w", err)
+    }
+    return lsn, nil
+}
+
+func GetReplicationOriginNames(ctx context.Context, db DBQuerier) (map[string]string, error) {
+    sql, err := RenderSQL(SQLTemplates.GetReplicationOriginNames, nil)
+    if err != nil {
+        return nil, err
+    }
+
+    rows, err := db.Query(ctx, sql)
+    if err != nil {
+        return nil, err
+    }
+    defer rows.Close()
+
+    names := make(map[string]string)
+    for rows.Next() {
+        var id, name string
+        if err := rows.Scan(&id, &name); err != nil {
+            return nil, err
+        }
+        names[id] = name
+    }
+
+    if err := rows.Err(); err != nil {
+        return nil, err
+    }
+
+    return names, nil
+}
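> A usage sketch for the two native LSN probes above, assuming the package
> import path `github.com/pgedge/ace/db/queries` (inferred from the module's
> other imports) and any `DBQuerier` such as a `*pgxpool.Pool`. This is
> illustrative only; the real consumer is `fetchLSNsForNode` later in this
> diff.

```go
// Hypothetical caller, not part of this change: probes both native-PG LSN
// sources for a failed node and logs whatever is available.
func probeNativeLSNs(ctx context.Context, db queries.DBQuerier, failedNode string) {
    originLSN, err := queries.GetNativeOriginLSNForNode(ctx, db, failedNode)
    if err != nil {
        log.Printf("origin LSN probe failed: %v", err)
    } else if originLSN != nil {
        log.Printf("remote_lsn for %s: %s", failedNode, *originLSN)
    }

    slotLSN, err := queries.GetNativeSlotLSNForNode(ctx, db, failedNode)
    if err != nil {
        log.Printf("slot LSN probe failed: %v", err)
    } else if slotLSN != nil {
        log.Printf("confirmed_flush_lsn for %s: %s", failedNode, *slotLSN)
    }
}
```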
+// GetNativeNodeOriginNames maps replication origin IDs to subscription names
+// for native PG logical replication (no spock); it is the counterpart of
+// GetSpockNodeNames.
+func GetNativeNodeOriginNames(ctx context.Context, db DBQuerier) (map[string]string, error) {
+    sql, err := RenderSQL(SQLTemplates.GetNativeNodeOriginNames, nil)
+    if err != nil {
+        return nil, err
+    }
+
+    rows, err := db.Query(ctx, sql)
+    if err != nil {
+        return nil, err
+    }
+    defer rows.Close()
+
+    names := make(map[string]string)
+    for rows.Next() {
+        var id, name string
+        if err := rows.Scan(&id, &name); err != nil {
+            return nil, err
+        }
+        names[id] = name
+    }
+
+    if err := rows.Err(); err != nil {
+        return nil, err
+    }
+
+    return names, nil
+}
+
+func GetNodeOriginNames(ctx context.Context, db DBQuerier) (map[string]string, error) {
+    var spockAvailable bool
+    err := db.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'spock')").Scan(&spockAvailable)
+    if err != nil {
+        return nil, fmt.Errorf("detecting spock extension: %w", err)
+    }
+    if spockAvailable {
+        return GetSpockNodeNames(ctx, db)
+    }
+    return GetNativeNodeOriginNames(ctx, db)
+}
+
+func CheckSpockInstalled(ctx context.Context, db DBQuerier) (bool, error) {
+    var exists bool
+    err := db.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'spock')").Scan(&exists)
+    if err != nil {
+        return false, fmt.Errorf("detecting spock extension: %w", err)
+    }
+    return exists, nil
+}
+
 func GetSpockRepSetInfo(ctx context.Context, db DBQuerier) ([]types.SpockRepSetInfo, error) {
     sql, err := RenderSQL(SQLTemplates.SpockRepSetInfo, nil)
     if err != nil {
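> `GetNodeOriginNames` gives callers one entry point whether or not spock is
> installed. A minimal sketch of that call pattern, assuming an
> already-reachable database (connection string and import path are
> placeholders):

```go
package main

import (
    "context"
    "fmt"
    "log"

    "github.com/jackc/pgx/v5/pgxpool"
    "github.com/pgedge/ace/db/queries" // assumed import path
)

func main() {
    ctx := context.Background()
    pool, err := pgxpool.New(ctx, "postgres://postgres:password@localhost:5432/testdb")
    if err != nil {
        log.Fatal(err)
    }
    defer pool.Close()

    // With spock installed this delegates to GetSpockNodeNames; otherwise it
    // falls back to the native pg_replication_origin/pg_subscription mapping.
    names, err := queries.GetNodeOriginNames(ctx, pool)
    if err != nil {
        log.Fatal(err)
    }
    for roident, name := range names {
        fmt.Printf("origin %s -> %s\n", roident, name)
    }
}
```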
diff --git a/db/queries/templates.go b/db/queries/templates.go
index ccf564d..6a0ffb8 100644
--- a/db/queries/templates.go
+++ b/db/queries/templates.go
@@ -11,7 +11,18 @@
 package queries
 
-import "text/template"
+import (
+    "text/template"
+
+    "github.com/jackc/pgx/v5"
+    "github.com/pgedge/ace/pkg/config"
+)
+
+// aceTemplateFuncs provides the {{aceSchema}} function to SQL templates.
+// The function is evaluated at render time (after config is loaded), not at
+// parse time.
+var aceTemplateFuncs = template.FuncMap{
+    "aceSchema": func() string { return pgx.Identifier{config.Get().MTree.Schema}.Sanitize() },
+}
 
 type Templates struct {
     EstimateRowCount *template.Template
@@ -78,6 +89,7 @@ type Templates struct {
     GetBlockCountSimple      *template.Template
     GetBlockSizeFromMetadata *template.Template
     GetMaxNodeLevel          *template.Template
+    CompareBlocksSQL         *template.Template
 
     DropXORFunction   *template.Template
     DropMetadataTable *template.Template
@@ -119,6 +131,10 @@ type Templates struct {
     RemoveTableFromCDCMetadata *template.Template
     GetSpockOriginLSNForNode   *template.Template
     GetSpockSlotLSNForNode     *template.Template
+    GetNativeOriginLSNForNode  *template.Template
+    GetNativeSlotLSNForNode    *template.Template
+    GetReplicationOriginNames  *template.Template
+    GetNativeNodeOriginNames   *template.Template
     EnsureHashVersionColumn    *template.Template
     GetHashVersion             *template.Template
     MarkAllLeavesDirty         *template.Template
@@ -133,8 +149,8 @@ type Templates struct {
 var SQLTemplates = Templates{
     // A template isn't needed for this query; just keeping the struct uniform
-    CreateMetadataTable: template.Must(template.New("createMetadataTable").Parse(`
-        CREATE TABLE IF NOT EXISTS spock.ace_mtree_metadata (
+    CreateMetadataTable: template.Must(template.New("createMetadataTable").Funcs(aceTemplateFuncs).Parse(`
+        CREATE TABLE IF NOT EXISTS {{aceSchema}}.ace_mtree_metadata (
             schema_name text,
             table_name text,
             total_rows bigint,
@@ -161,8 +177,8 @@ var SQLTemplates = Templates{
         ALTER PUBLICATION {{.PublicationName}}
         DROP TABLE {{.TableName}}
     `)),
-    RemoveTableFromCDCMetadata: template.Must(template.New("removeTableFromCDCMetadata").Parse(`
-        UPDATE spock.ace_cdc_metadata
+    RemoveTableFromCDCMetadata: template.Must(template.New("removeTableFromCDCMetadata").Funcs(aceTemplateFuncs).Parse(`
+        UPDATE {{aceSchema}}.ace_cdc_metadata
         SET tables = array_remove(tables, $1)
         WHERE publication_name = $2
     `)),
@@ -179,9 +195,9 @@ var SQLTemplates = Templates{
         )
     `)),
-    UpdateCDCMetadata: template.Must(template.New("updateCdcMetadata").Parse(`
+    UpdateCDCMetadata: template.Must(template.New("updateCdcMetadata").Funcs(aceTemplateFuncs).Parse(`
         INSERT INTO
-            spock.ace_cdc_metadata (
+            {{aceSchema}}.ace_cdc_metadata (
                 publication_name,
                 slot_name,
                 start_lsn,
@@ -220,17 +236,17 @@ var SQLTemplates = Templates{
     CheckPIDExists: template.Must(template.New("checkPIDExists").Parse(`
         SELECT pid FROM pg_stat_activity WHERE pid = $1
     `)),
-    DropCDCMetadataTable: template.Must(template.New("dropCDCMetadataTable").Parse(`
-        DROP TABLE IF EXISTS spock.ace_cdc_metadata
+    DropCDCMetadataTable: template.Must(template.New("dropCDCMetadataTable").Funcs(aceTemplateFuncs).Parse(`
+        DROP TABLE IF EXISTS {{aceSchema}}.ace_cdc_metadata
     `)),
-    GetCDCMetadata: template.Must(template.New("getCDCMetadata").Parse(`
+    GetCDCMetadata: template.Must(template.New("getCDCMetadata").Funcs(aceTemplateFuncs).Parse(`
         SELECT
             slot_name,
             start_lsn,
             tables
         FROM
-            spock.ace_cdc_metadata
+            {{aceSchema}}.ace_cdc_metadata
         WHERE
             publication_name = $1
     `)),
@@ -317,8 +333,8 @@ var SQLTemplates = Templates{
         AND mt.node_position = b.node_position;
     `)),
-    CreateCDCMetadataTable: template.Must(template.New("createCDCMetadataTable").Parse(`
-        CREATE TABLE IF NOT EXISTS spock.ace_cdc_metadata (
+    CreateCDCMetadataTable: template.Must(template.New("createCDCMetadataTable").Funcs(aceTemplateFuncs).Parse(`
+        CREATE TABLE IF NOT EXISTS {{aceSchema}}.ace_cdc_metadata (
             publication_name text PRIMARY KEY,
             slot_name text,
             start_lsn text,
@@ -772,9 +788,9 @@ var SQLTemplates = Templates{
         VALUES (0, $1, {{.StartExpr}}, {{.EndExpr}});
     `)),
-    CreateXORFunction: template.Must(template.New("createXORFunction").Parse(`
+    CreateXORFunction: template.Must(template.New("createXORFunction").Funcs(aceTemplateFuncs).Parse(`
         CREATE
-        OR REPLACE FUNCTION spock.bytea_xor(a bytea, b bytea) RETURNS bytea AS $$
+        OR REPLACE FUNCTION {{aceSchema}}.bytea_xor(a bytea, b bytea) RETURNS bytea AS $$
         DECLARE
             result bytea;
             len int;
@@ -805,7 +821,7 @@ var SQLTemplates = Templates{
             CREATE OPERATOR # (
                 LEFTARG = bytea,
                 RIGHTARG = bytea,
-                PROCEDURE = spock.bytea_xor
+                PROCEDURE = {{aceSchema}}.bytea_xor
             );
             END IF;
         END $$;
@@ -840,9 +856,9 @@ var SQLTemplates = Templates{
         AND c.relname = $2
         AND a.attname = $3
     `)),
-    UpdateMetadata: template.Must(template.New("updateMetadata").Parse(`
+    UpdateMetadata: template.Must(template.New("updateMetadata").Funcs(aceTemplateFuncs).Parse(`
         INSERT INTO
-            spock.ace_mtree_metadata (
+            {{aceSchema}}.ace_mtree_metadata (
                 schema_name,
                 table_name,
                 total_rows,
@@ -873,8 +889,8 @@ var SQLTemplates = Templates{
             hash_version = EXCLUDED.hash_version,
             last_updated = EXCLUDED.last_updated
     `)),
-    DeleteMetadata: template.Must(template.New("deleteMetadata").Parse(`
-        DELETE FROM spock.ace_mtree_metadata WHERE schema_name = $1 AND table_name = $2
+    DeleteMetadata: template.Must(template.New("deleteMetadata").Funcs(aceTemplateFuncs).Parse(`
+        DELETE FROM {{aceSchema}}.ace_mtree_metadata WHERE schema_name = $1 AND table_name = $2
     `)),
     InsertBlockRanges: template.Must(template.New("insertBlockRanges").Parse(`
         INSERT INTO
@@ -1055,11 +1071,11 @@ var SQLTemplates = Templates{
         ORDER BY node_position
     `)),
-    GetRowCountEstimate: template.Must(template.New("getRowCountEstimate").Parse(`
+    GetRowCountEstimate: template.Must(template.New("getRowCountEstimate").Funcs(aceTemplateFuncs).Parse(`
         SELECT
             total_rows
         FROM
-            spock.ace_mtree_metadata
+            {{aceSchema}}.ace_mtree_metadata
         WHERE
             schema_name = $1
             AND table_name = $2
@@ -1294,11 +1310,11 @@ var SQLTemplates = Templates{
             mt.range_start,
             mt.range_end
     `)),
-    GetBlockSizeFromMetadata: template.Must(template.New("getBlockSizeFromMetadata").Parse(`
+    GetBlockSizeFromMetadata: template.Must(template.New("getBlockSizeFromMetadata").Funcs(aceTemplateFuncs).Parse(`
         SELECT
             block_size
         FROM
-            spock.ace_mtree_metadata
+            {{aceSchema}}.ace_mtree_metadata
         WHERE
             schema_name = $1
             AND table_name = $2
@@ -1309,11 +1325,19 @@ var SQLTemplates = Templates{
         FROM {{.MtreeTable}}
     `)),
+    CompareBlocksSQL: template.Must(template.New("compareBlocksSQL").Parse(`
+        SELECT
+            *
+        FROM
+            {{.TableName}}
+        WHERE
+            {{.WhereClause}}
+    `)),
-    DropXORFunction: template.Must(template.New("dropXORFunction").Parse(`
-        DROP FUNCTION IF EXISTS spock.bytea_xor(bytea, bytea) CASCADE
+    DropXORFunction: template.Must(template.New("dropXORFunction").Funcs(aceTemplateFuncs).Parse(`
+        DROP FUNCTION IF EXISTS {{aceSchema}}.bytea_xor(bytea, bytea) CASCADE
     `)),
-    DropMetadataTable: template.Must(template.New("dropMetadataTable").Parse(`
-        DROP TABLE IF EXISTS spock.ace_mtree_metadata CASCADE
+    DropMetadataTable: template.Must(template.New("dropMetadataTable").Funcs(aceTemplateFuncs).Parse(`
+        DROP TABLE IF EXISTS {{aceSchema}}.ace_mtree_metadata CASCADE
     `)),
     DropMtreeTable: template.Must(template.New("dropMtreeTable").Parse(`
         DROP TABLE IF EXISTS {{.MtreeTable}} CASCADE
@@ -1506,13 +1530,13 @@ var SQLTemplates = Templates{
         ORDER BY rs.confirmed_flush_lsn DESC
         LIMIT 1
     `)),
-    EnsureHashVersionColumn: template.Must(template.New("ensureHashVersionColumn").Parse(`
-        ALTER TABLE spock.ace_mtree_metadata
+    EnsureHashVersionColumn: template.Must(template.New("ensureHashVersionColumn").Funcs(aceTemplateFuncs).Parse(`
+        ALTER TABLE {{aceSchema}}.ace_mtree_metadata
         ADD COLUMN IF NOT EXISTS hash_version int NOT NULL DEFAULT 1
     `)),
-    GetHashVersion: template.Must(template.New("getHashVersion").Parse(`
+    GetHashVersion: template.Must(template.New("getHashVersion").Funcs(aceTemplateFuncs).Parse(`
         SELECT COALESCE(
-            (SELECT hash_version FROM spock.ace_mtree_metadata
+            (SELECT hash_version FROM {{aceSchema}}.ace_mtree_metadata
             WHERE schema_name = $1 AND table_name = $2),
             1
         )
@@ -1522,11 +1546,41 @@ var SQLTemplates = Templates{
         SET dirty = true
         WHERE node_level = 0
     `)),
-    UpdateHashVersion: template.Must(template.New("updateHashVersion").Parse(`
-        UPDATE spock.ace_mtree_metadata
+    UpdateHashVersion: template.Must(template.New("updateHashVersion").Funcs(aceTemplateFuncs).Parse(`
+        UPDATE {{aceSchema}}.ace_mtree_metadata
         SET hash_version = $1, last_updated = current_timestamp
         WHERE schema_name = $2 AND table_name = $3
     `)),
+    GetNativeOriginLSNForNode: template.Must(template.New("getNativeOriginLSNForNode").Parse(`
+        SELECT ros.remote_lsn::text
+        FROM pg_catalog.pg_replication_origin_status ros
+        JOIN pg_catalog.pg_replication_origin ro ON ro.roident = ros.local_id
+        JOIN pg_catalog.pg_subscription s ON ro.roname LIKE 'pg_%' || s.oid::text
+        WHERE s.subname ~ ('\m' || $1 || '\M')
+        AND ros.remote_lsn IS NOT NULL
+        LIMIT 1
+    `)),
+    GetNativeSlotLSNForNode: template.Must(template.New("getNativeSlotLSNForNode").Parse(`
+        SELECT rs.confirmed_flush_lsn::text
+        FROM pg_catalog.pg_replication_slots rs
+        JOIN pg_catalog.pg_subscription s ON rs.slot_name = s.subslotname
+        WHERE s.subname ~ ('\m' || $1 || '\M')
+        AND rs.confirmed_flush_lsn IS NOT NULL
+        ORDER BY rs.confirmed_flush_lsn DESC
+        LIMIT 1
+    `)),
+    GetReplicationOriginNames: template.Must(template.New("getReplicationOriginNames").Parse(`
+        SELECT roident::text, roname FROM pg_replication_origin;
+    `)),
+    // GetNativeNodeOriginNames maps pg_replication_origin entries to their
+    // corresponding pg_subscription names. This provides the native PG
+    // equivalent of GetSpockNodeNames — mapping origin IDs (used by
+    // pg_xact_commit_timestamp_origin) to human-readable node identifiers.
+    GetNativeNodeOriginNames: template.Must(template.New("getNativeNodeOriginNames").Parse(`
+        SELECT ro.roident::text, s.subname
+        FROM pg_catalog.pg_replication_origin ro
+        JOIN pg_catalog.pg_subscription s ON ro.roname = 'pg_' || s.oid::text
+    `)),
     GetReplicationOriginByName: template.Must(template.New("getReplicationOriginByName").Parse(`
         SELECT roident FROM pg_replication_origin WHERE roname = $1
     `)),
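> With every metadata query routed through `{{aceSchema}}`, the schema is
> resolved when a template is rendered rather than hard-coded to `spock`. A
> self-contained sketch of the mechanism; the `FuncMap` below hard-codes a
> schema name because the real `aceTemplateFuncs` reads `config.Get()`, which
> requires a loaded config:

```go
package main

import (
    "os"
    "text/template"

    "github.com/jackc/pgx/v5"
)

func main() {
    // Stand-in for aceTemplateFuncs with a fixed schema ("ace" is an example).
    funcs := template.FuncMap{
        "aceSchema": func() string { return pgx.Identifier{"ace"}.Sanitize() },
    }
    tmpl := template.Must(template.New("demo").Funcs(funcs).Parse(
        `DROP TABLE IF EXISTS {{aceSchema}}.ace_mtree_metadata CASCADE`))
    // Prints: DROP TABLE IF EXISTS "ace".ace_mtree_metadata CASCADE
    _ = tmpl.Execute(os.Stdout, nil)
}
```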
diff --git a/internal/consistency/diff/repset_diff.go b/internal/consistency/diff/repset_diff.go
index 8eeb94c..eb97dc6 100644
--- a/internal/consistency/diff/repset_diff.go
+++ b/internal/consistency/diff/repset_diff.go
@@ -157,6 +157,17 @@ func (c *RepsetDiffCmd) RunChecks(skipValidation bool) error {
             return fmt.Errorf("could not connect to node %s: %w", nodeName, err)
         }
 
+        // Check if spock extension is installed
+        spockInstalled, err := queries.CheckSpockInstalled(c.Ctx, pool)
+        if err != nil {
+            pool.Close()
+            return fmt.Errorf("failed to check for spock extension on node %s: %w", nodeName, err)
+        }
+        if !spockInstalled {
+            pool.Close()
+            return fmt.Errorf("repset-diff requires the spock extension, which is not installed on node %s", nodeName)
+        }
+
         repsetExists, err := queries.CheckRepSetExists(c.Ctx, pool, c.RepsetName)
         if err != nil {
             pool.Close()
@@ -385,6 +396,7 @@ func RepsetDiff(task *RepsetDiffCmd) (err error) {
     tdTask.QuietMode = task.Quiet
     tdTask.Ctx = task.Ctx
     tdTask.SkipDBUpdate = task.SkipDBUpdate
+    tdTask.TaskStore = task.TaskStore
     tdTask.TaskStorePath = task.TaskStorePath
 
     if err := tdTask.Validate(); err != nil {
diff --git a/internal/consistency/diff/spock_diff.go b/internal/consistency/diff/spock_diff.go
index d560a88..7056734 100644
--- a/internal/consistency/diff/spock_diff.go
+++ b/internal/consistency/diff/spock_diff.go
@@ -307,6 +307,17 @@ func (t *SpockDiffTask) ExecuteTask() (err error) {
     }
     t.Pools = pools
 
+    // Check that spock is installed on every selected node
+    for name, pool := range t.Pools {
+        spockInstalled, err := queries.CheckSpockInstalled(t.Ctx, pool)
+        if err != nil {
+            return fmt.Errorf("failed to check for spock extension on node %s: %w", name, err)
+        }
+        if !spockInstalled {
+            return fmt.Errorf("spock-diff requires the spock extension, which is not installed on node %s", name)
+        }
+    }
+
     allNodeConfigs := make(map[string]SpockNodeConfig)
     var nodeNames []string
diff --git a/internal/consistency/diff/table_diff.go b/internal/consistency/diff/table_diff.go
index a3f05f0..e0a7d49 100644
--- a/internal/consistency/diff/table_diff.go
+++ b/internal/consistency/diff/table_diff.go
@@ -102,7 +102,7 @@ type TableDiffTask struct {
     blockHashSQLCache map[hashBoundsKey]string
     blockHashSQLMu    sync.Mutex
 
-    SpockNodeNames map[string]string
+    NodeOriginNames map[string]string
 
     CompareUnitSize int
     MaxDiffRows     int64
@@ -184,8 +184,8 @@ func (t *TableDiffTask) incrementDiffRowsLocked(delta int) bool {
     return false
 }
 
-func (t *TableDiffTask) loadSpockNodeNames() error {
-    if t.SpockNodeNames != nil {
+func (t *TableDiffTask) loadNodeOriginNames() error {
+    if t.NodeOriginNames != nil {
         return nil
     }
 
@@ -196,17 +196,17 @@ func (t *TableDiffTask) loadNodeOriginNames() error {
     }
 
     if firstPool == nil {
-        t.SpockNodeNames = make(map[string]string)
-        return fmt.Errorf("no connection pool available to load spock node names")
+        t.NodeOriginNames = make(map[string]string)
+        return fmt.Errorf("no connection pool available to load node origin names")
     }
 
-    names, err := queries.GetSpockNodeNames(t.Ctx, firstPool)
+    names, err := queries.GetNodeOriginNames(t.Ctx, firstPool)
     if err != nil {
-        t.SpockNodeNames = make(map[string]string)
+        t.NodeOriginNames = make(map[string]string)
         return err
     }
 
-    t.SpockNodeNames = names
+    t.NodeOriginNames = names
     return nil
 }
 
@@ -214,26 +214,36 @@ func (t *TableDiffTask) resolveAgainstOrigin() error {
     if strings.TrimSpace(t.AgainstOrigin) == "" {
         return nil
     }
-    if len(t.SpockNodeNames) == 0 {
-        return fmt.Errorf("unable to resolve --against-origin: spock node names not available")
+    if len(t.NodeOriginNames) == 0 {
+        return fmt.Errorf("unable to resolve --against-origin: no node origin names available")
     }
 
     orig := strings.TrimSpace(t.AgainstOrigin)
 
-    // direct match on id
-    if _, ok := t.SpockNodeNames[orig]; ok {
+    // direct match on origin id
+    if _, ok := t.NodeOriginNames[orig]; ok {
         t.resolvedAgainstOrigin = orig
         return nil
     }
 
-    // match on name
-    for id, name := range t.SpockNodeNames {
+    // match on origin name
+    for id, name := range t.NodeOriginNames {
         if name == orig {
             t.resolvedAgainstOrigin = id
             return nil
         }
     }
 
-    return fmt.Errorf("unable to resolve against-origin %q to a spock node id", t.AgainstOrigin)
+    // build a list of available origins for the error message
+    available := make([]string, 0, len(t.NodeOriginNames))
+    for id, name := range t.NodeOriginNames {
+        if id != name {
+            available = append(available, fmt.Sprintf("%s (%s)", id, name))
+        } else {
+            available = append(available, id)
+        }
+    }
+
+    return fmt.Errorf("unable to resolve --against-origin %q; available origins: %s", t.AgainstOrigin, strings.Join(available, ", "))
 }
 
 func (t *TableDiffTask) buildEffectiveFilter() (string, error) {
@@ -256,7 +266,7 @@ func (t *TableDiffTask) buildEffectiveFilter() (string, error) {
         if err != nil {
             return "", fmt.Errorf("resolved against-origin %q is not a valid numeric node ID", t.resolvedAgainstOrigin)
         }
-        parts = append(parts, fmt.Sprintf("(to_json(spock.xact_commit_timestamp_origin(xmin))->>'roident' = '%d')", nodeID))
+        parts = append(parts, fmt.Sprintf("(to_json(pg_xact_commit_timestamp_origin(xmin))->>'roident' = '%d')", nodeID))
     }
 
     if t.untilTime != nil {
@@ -270,7 +280,7 @@ func (t *TableDiffTask) buildEffectiveFilter() (string, error) {
 }
 
 func (t *TableDiffTask) withSpockMetadata(row map[string]any) map[string]any {
-    row["node_origin"] = utils.TranslateNodeOrigin(row["node_origin"], t.SpockNodeNames)
+    row["node_origin"] = utils.TranslateNodeOrigin(row["node_origin"], t.NodeOriginNames)
     return utils.AddSpockMetadata(row)
 }
 
@@ -450,7 +460,7 @@ func (t *TableDiffTask) fetchRows(nodeName string, r Range) ([]types.OrderedMap,
     }
 
     selectCols := make([]string, 0, len(t.Cols)+2)
-    selectCols = append(selectCols, "pg_xact_commit_timestamp(xmin) as commit_ts", "to_json(spock.xact_commit_timestamp_origin(xmin))->>'roident' as node_origin")
+    selectCols = append(selectCols, "pg_xact_commit_timestamp(xmin) as commit_ts", "to_json(pg_xact_commit_timestamp_origin(xmin))->>'roident' as node_origin")
 
     for _, colName := range t.Cols {
         colType := colTypes[colName]
@@ -1287,8 +1297,8 @@ func (t *TableDiffTask) ExecuteTask() (err error) {
         }
     }
 
-    if err := t.loadSpockNodeNames(); err != nil {
-        logger.Warn("table-diff: unable to load spock node names; using raw node_origin values: %v", err)
+    if err := t.loadNodeOriginNames(); err != nil {
+        logger.Warn("table-diff: unable to load node origin names; using raw node_origin values: %v", err)
     }
 
     if err := t.resolveAgainstOrigin(); err != nil {
@@ -1344,8 +1354,8 @@ func (t *TableDiffTask) ExecuteTask() (err error) {
         DiffRowsCount: make(map[string]int),
         AgainstOrigin: t.AgainstOrigin,
         AgainstOriginResolved: func() string {
-            if t.resolvedAgainstOrigin != "" && t.SpockNodeNames != nil {
-                if name, ok := t.SpockNodeNames[t.resolvedAgainstOrigin]; ok {
+            if t.resolvedAgainstOrigin != "" && t.NodeOriginNames != nil {
+                if name, ok := t.NodeOriginNames[t.resolvedAgainstOrigin]; ok {
                     return name
                 }
             }
diff --git a/internal/consistency/diff/table_diff_origin_test.go b/internal/consistency/diff/table_diff_origin_test.go
new file mode 100644
index 0000000..bafd7aa
--- /dev/null
+++ b/internal/consistency/diff/table_diff_origin_test.go
@@ -0,0 +1,135 @@
+// ///////////////////////////////////////////////////////////////////////////
+//
+// # ACE - Active Consistency Engine
+//
+// Copyright (C) 2023 - 2026, pgEdge (https://www.pgedge.com/)
+//
+// This software is released under the PostgreSQL License:
+// https://opensource.org/license/postgresql
+//
+// ///////////////////////////////////////////////////////////////////////////
+
+package diff
+
+import (
+    "strings"
+    "testing"
+)
+
+func TestResolveAgainstOrigin_EmptyInput(t *testing.T) {
+    task := &TableDiffTask{AgainstOrigin: ""}
+    if err := task.resolveAgainstOrigin(); err != nil {
+        t.Fatalf("unexpected error: %v", err)
+    }
+    if task.resolvedAgainstOrigin != "" {
+        t.Fatalf("expected empty resolvedAgainstOrigin, got %q", task.resolvedAgainstOrigin)
+    }
+}
+
+func TestResolveAgainstOrigin_NoNodeOriginNames(t *testing.T) {
+    task := &TableDiffTask{AgainstOrigin: "n1"}
+    err := task.resolveAgainstOrigin()
+    if err == nil {
+        t.Fatal("expected error when NodeOriginNames is empty")
+    }
+    if !strings.Contains(err.Error(), "no node origin names available") {
+        t.Fatalf("unexpected error: %v", err)
+    }
+}
+
+func TestResolveAgainstOrigin_MatchByID(t *testing.T) {
+    task := &TableDiffTask{
+        AgainstOrigin:   "3",
+        NodeOriginNames: map[string]string{"3": "n1", "4": "n2"},
+    }
+    if err := task.resolveAgainstOrigin(); err != nil {
+        t.Fatalf("unexpected error: %v", err)
+    }
+    if task.resolvedAgainstOrigin != "3" {
+        t.Fatalf("expected resolvedAgainstOrigin=3, got %q", task.resolvedAgainstOrigin)
+    }
+}
+
+func TestResolveAgainstOrigin_MatchByName_SpockNodeName(t *testing.T) {
+    task := &TableDiffTask{
+        AgainstOrigin:   "n1",
+        NodeOriginNames: map[string]string{"3": "n1", "4": "n2"},
+    }
+    if err := task.resolveAgainstOrigin(); err != nil {
+        t.Fatalf("unexpected error: %v", err)
+    }
+    if task.resolvedAgainstOrigin != "3" {
+        t.Fatalf("expected resolvedAgainstOrigin=3, got %q", task.resolvedAgainstOrigin)
+    }
+}
+
+func TestResolveAgainstOrigin_MatchByName_SubscriptionName(t *testing.T) {
+    // Native PG: NodeOriginNames maps roident -> subscription name
+    task := &TableDiffTask{
+        AgainstOrigin:   "sub_n1_to_n2",
+        NodeOriginNames: map[string]string{"5": "sub_n1_to_n2", "6": "sub_n3_to_n2"},
+    }
+    if err := task.resolveAgainstOrigin(); err != nil {
+        t.Fatalf("unexpected error: %v", err)
+    }
+    if task.resolvedAgainstOrigin != "5" {
+        t.Fatalf("expected resolvedAgainstOrigin=5, got %q", task.resolvedAgainstOrigin)
+    }
+}
+
+func TestResolveAgainstOrigin_NoMatch(t *testing.T) {
+    task := &TableDiffTask{
+        AgainstOrigin:   "nonexistent",
+        NodeOriginNames: map[string]string{"3": "n1", "4": "n2"},
+    }
+    err := task.resolveAgainstOrigin()
+    if err == nil {
+        t.Fatal("expected error for unresolvable origin")
+    }
+    if !strings.Contains(err.Error(), "nonexistent") {
+        t.Fatalf("error should mention the unresolved name: %v", err)
+    }
+}
+
+func TestBuildEffectiveFilter_AgainstOrigin(t *testing.T) {
+    task := &TableDiffTask{
+        resolvedAgainstOrigin: "3",
+    }
+    filter, err := task.buildEffectiveFilter()
+    if err != nil {
+        t.Fatalf("unexpected error: %v", err)
+    }
+    if !strings.Contains(filter, "pg_xact_commit_timestamp_origin") {
+        t.Fatalf("expected pg_xact_commit_timestamp_origin in filter, got: %s", filter)
+    }
+    if !strings.Contains(filter, "'3'") {
+        t.Fatalf("expected roident=3 in filter, got: %s", filter)
+    }
+    if strings.Contains(filter, "spock") {
+        t.Fatalf("filter should not reference spock: %s", filter)
+    }
+}
+
+func TestBuildEffectiveFilter_NonNumericOrigin(t *testing.T) {
+    task := &TableDiffTask{
+        resolvedAgainstOrigin: "not_a_number",
+    }
+    _, err := task.buildEffectiveFilter()
+    if err == nil {
+        t.Fatal("expected error for non-numeric resolved origin")
+    }
+    if !strings.Contains(err.Error(), "not a valid numeric node ID") {
+        t.Fatalf("unexpected error: %v", err)
+    }
+}
+
+func TestBuildEffectiveFilter_Empty(t *testing.T) {
+    task := &TableDiffTask{}
+    filter, err := task.buildEffectiveFilter()
+    if err != nil {
+        t.Fatalf("unexpected error: %v", err)
+    }
+    if filter != "" {
+        t.Fatalf("expected empty filter, got: %s", filter)
+    }
+}
diff --git a/internal/consistency/diff/table_rerun.go b/internal/consistency/diff/table_rerun.go
index f3ca7a0..ca4b1ac 100644
--- a/internal/consistency/diff/table_rerun.go
+++ b/internal/consistency/diff/table_rerun.go
@@ -79,8 +79,8 @@ func (t *TableDiffTask) ExecuteRerunTask() error {
         }
     }()
 
-    if err := t.loadSpockNodeNames(); err != nil {
-        logger.Warn("table-diff rerun: unable to load spock node names; using raw node_origin values: %v", err)
+    if err := t.loadNodeOriginNames(); err != nil {
+        logger.Warn("table-diff rerun: unable to load node origin names; using raw node_origin values: %v", err)
     }
 
     // Collect all unique primary keys from the original diff report
@@ -277,7 +277,7 @@ func fetchRowsByPkeys(ctx context.Context, pool *pgxpool.Pool, t *TableDiffTask,
     }
 
     selectCols := make([]string, 0, len(t.Cols)+2)
-    selectCols = append(selectCols, "pg_xact_commit_timestamp(t.xmin) as commit_ts", "to_json(spock.xact_commit_timestamp_origin(t.xmin))->>'roident' as node_origin")
+    selectCols = append(selectCols, "pg_xact_commit_timestamp(t.xmin) as commit_ts", "to_json(pg_xact_commit_timestamp_origin(t.xmin))->>'roident' as node_origin")
     for _, col := range t.Cols {
         selectCols = append(selectCols, "t."+pgx.Identifier{col}.Sanitize())
     }
diff --git a/internal/consistency/mtree/merkle.go b/internal/consistency/mtree/merkle.go
index c80d4b1..b3e1e4f 100644
--- a/internal/consistency/mtree/merkle.go
+++ b/internal/consistency/mtree/merkle.go
@@ -96,7 +96,7 @@ type MerkleTreeTask struct {
     diffMutex      sync.Mutex
     diffRowKeySets map[string]map[string]map[string]struct{}
     StartTime      time.Time
-    SpockNodeNames map[string]string
+    NodeOriginNames map[string]string
 
     Ctx context.Context
 }
@@ -498,8 +498,8 @@ func (m *MerkleTreeTask) processWorkItem(work CompareRangesWorkItem, pool1, pool
     return nil
 }
 
-func (m *MerkleTreeTask) loadSpockNodeNames() error {
-    if m.SpockNodeNames != nil {
+func (m *MerkleTreeTask) loadNodeOriginNames() error {
+    if m.NodeOriginNames != nil {
         return nil
     }
 
@@ -510,21 +510,21 @@ func (m *MerkleTreeTask) loadNodeOriginNames() error {
             lastErr = err
             continue
         }
-        names, err := queries.GetSpockNodeNames(m.Ctx, pool)
+        names, err := queries.GetNodeOriginNames(m.Ctx, pool)
         pool.Close()
         if err != nil {
             lastErr = err
             continue
         }
-        m.SpockNodeNames = names
+        m.NodeOriginNames = names
         return nil
     }
 
-    m.SpockNodeNames = make(map[string]string)
+    m.NodeOriginNames = make(map[string]string)
     if lastErr != nil {
         return lastErr
     }
-    return fmt.Errorf("no nodes available to load spock node names")
+    return fmt.Errorf("no nodes available to load node origin names")
 }
 
 func (m *MerkleTreeTask) appendDiffs(nodePairKey string, work CompareRangesWorkItem, pr1, pr2 []types.OrderedMap) error {
@@ -596,7 +596,7 @@ func (m *MerkleTreeTask) addRowToDiff(nodePairKey, nodeName string, row types.Or
     }
 
     rowMap := utils.OrderedMapToMap(row)
-    rowMap["node_origin"] = utils.TranslateNodeOrigin(rowMap["node_origin"], m.SpockNodeNames)
+    rowMap["node_origin"] = utils.TranslateNodeOrigin(rowMap["node_origin"], m.NodeOriginNames)
     rowWithMeta := utils.AddSpockMetadata(rowMap)
 
     orderedRow := utils.MapToOrderedMap(rowWithMeta, m.Cols)
@@ -714,7 +714,7 @@ func buildFetchRowsSQLSimple(schema, table, pk string, orderBy string, keys []an
     }
     qualifiedTable := fmt.Sprintf("%s.%s", pgx.Identifier{schema}.Sanitize(), pgx.Identifier{table}.Sanitize())
     where := fmt.Sprintf("%s IN (%s)", pgx.Identifier{pk}.Sanitize(), strings.Join(placeholders, ","))
-    selectCols := "pg_xact_commit_timestamp(xmin) as commit_ts, to_json(spock.xact_commit_timestamp_origin(xmin))->>'roident' as node_origin, *"
+    selectCols := "pg_xact_commit_timestamp(xmin) as commit_ts, to_json(pg_xact_commit_timestamp_origin(xmin))->>'roident' as node_origin, *"
     q := fmt.Sprintf("SELECT %s FROM %s WHERE %s ORDER BY %s", selectCols, qualifiedTable, where, orderBy)
     return q, args
 }
@@ -738,7 +738,7 @@ func buildFetchRowsSQLComposite(schema, table string, pk []string, orderBy strin
     }
     qualifiedTable := fmt.Sprintf("%s.%s", pgx.Identifier{schema}.Sanitize(), pgx.Identifier{table}.Sanitize())
     where := fmt.Sprintf("( %s ) IN ( %s )", strings.Join(tupleCols, ","), strings.Join(tuples, ","))
-    selectCols := "pg_xact_commit_timestamp(xmin) as commit_ts, to_json(spock.xact_commit_timestamp_origin(xmin))->>'roident' as node_origin, *"
+    selectCols := "pg_xact_commit_timestamp(xmin) as commit_ts, to_json(pg_xact_commit_timestamp_origin(xmin))->>'roident' as node_origin, *"
     q := fmt.Sprintf("SELECT %s FROM %s WHERE %s ORDER BY %s", selectCols, qualifiedTable, where, orderBy)
     return q, args
 }
@@ -1891,8 +1891,8 @@ func (m *MerkleTreeTask) DiffMtree() (err error) {
     if err = m.UpdateMtree(true); err != nil {
         return fmt.Errorf("failed to update merkle tree before diff: %w", err)
     }
-    if err := m.loadSpockNodeNames(); err != nil {
-        logger.Warn("mtree diff: unable to load spock node names; using raw node_origin values: %v", err)
+    if err := m.loadNodeOriginNames(); err != nil {
+        logger.Warn("mtree diff: unable to load node origin names; using raw node_origin values: %v", err)
     }
     nodePairs := getNodePairs(m.ClusterNodes)
     mtreeTableIdentifier := pgx.Identifier{m.aceSchema(), fmt.Sprintf("ace_mtree_%s_%s", m.Schema, m.Table)}
diff --git a/internal/consistency/repair/table_repair.go b/internal/consistency/repair/table_repair.go
index a2c70f9..635514f 100644
--- a/internal/consistency/repair/table_repair.go
+++ b/internal/consistency/repair/table_repair.go
@@ -98,6 +98,8 @@ type TableRepairTask struct {
     autoSelectionFailedNode string
     autoSelectionDetails    map[string]map[string]string
 
+    spockPerNode map[string]bool
+
     Ctx context.Context
 }
 
@@ -160,28 +162,76 @@ func (t *TableRepairTask) setRole(tx pgx.Tx, nodeName string) error {
     return nil
 }
 
-// setupTransactionMode enables spock repair mode, sets the session replication
+// isSpockAvailable returns whether spock is installed on the given node,
+// detecting lazily on first check per node. Returns an error if detection
+// fails so callers don't silently fall back to the wrong repair mode.
+func (t *TableRepairTask) isSpockAvailable(nodeName string) (bool, error) {
+    if t.spockPerNode == nil {
+        t.spockPerNode = make(map[string]bool)
+    }
+    if _, checked := t.spockPerNode[nodeName]; !checked {
+        pool := t.Pools[nodeName]
+        if pool == nil {
+            return false, fmt.Errorf("no connection pool for node %s", nodeName)
+        }
+        spockInstalled, err := queries.CheckSpockInstalled(t.Ctx, pool)
+        if err != nil {
+            return false, fmt.Errorf("failed to detect spock extension on %s: %w", nodeName, err)
+        }
+        t.spockPerNode[nodeName] = spockInstalled
+        logger.Info("spock extension on %s: %v", nodeName, spockInstalled)
+    }
+    return t.spockPerNode[nodeName], nil
+}
+
+// setupTransactionMode enables spock repair mode (when available), sets the session replication
 // role, and applies the client role for a repair transaction.
-func (t *TableRepairTask) setupTransactionMode(tx pgx.Tx, nodeName string) error {
-    if _, err := tx.Exec(t.Ctx, "SELECT spock.repair_mode(true)"); err != nil {
-        return fmt.Errorf("enabling spock.repair_mode(true) on %s: %w", nodeName, err)
+// Returns true if spock repair mode was activated, false otherwise.
+func (t *TableRepairTask) setupTransactionMode(tx pgx.Tx, nodeName string) (bool, error) {
+    spockRepairModeActive := false
+    spock, err := t.isSpockAvailable(nodeName)
+    if err != nil {
+        return false, err
+    }
+    if spock {
+        if _, err := tx.Exec(t.Ctx, "SELECT spock.repair_mode(true)"); err != nil {
+            return false, fmt.Errorf("enabling spock.repair_mode(true) on %s: %w", nodeName, err)
+        }
+        logger.Debug("spock.repair_mode(true) set on %s", nodeName)
+        spockRepairModeActive = true
     }
-    logger.Debug("spock.repair_mode(true) set on %s", nodeName)
 
-    var err error
     if t.FireTriggers {
         _, err = tx.Exec(t.Ctx, "SET session_replication_role = 'local'")
     } else {
         _, err = tx.Exec(t.Ctx, "SET session_replication_role = 'replica'")
     }
     if err != nil {
-        return fmt.Errorf("setting session_replication_role on %s: %w", nodeName, err)
+        return false, fmt.Errorf("setting session_replication_role on %s: %w", nodeName, err)
     }
     logger.Debug("session_replication_role set on %s (fire_triggers: %v)", nodeName, t.FireTriggers)
 
     if err := t.setRole(tx, nodeName); err != nil {
+        return false, err
+    }
+    return spockRepairModeActive, nil
+}
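> Taken together, a repair transaction now follows a begin →
> `setupTransactionMode` → DML → `disableSpockRepairMode` → commit shape,
> where both spock steps degrade to no-ops on vanilla PG. A condensed sketch
> of that flow; the hypothetical `repairInTx` wrapper does not exist in this
> diff, and the real call sites below also track `spockRepairModeActive` and
> accumulate per-node errors:

```go
// Sketch only: condensed from the runFixNulls/runUnidirectionalRepair call
// sites in this diff.
func (t *TableRepairTask) repairInTx(nodeName string, apply func(tx pgx.Tx) error) error {
    tx, err := t.Pools[nodeName].Begin(t.Ctx)
    if err != nil {
        return err
    }
    defer tx.Rollback(t.Ctx) // harmless after a successful commit

    // Runs spock.repair_mode(true) only when spock is installed on the node.
    if _, err := t.setupTransactionMode(tx, nodeName); err != nil {
        return err
    }
    if err := apply(tx); err != nil {
        return err
    }
    // No-op on nodes without spock.
    if err := t.disableSpockRepairMode(tx, nodeName); err != nil {
        return err
    }
    return tx.Commit(t.Ctx)
}
```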
+
+// disableSpockRepairMode disables spock repair mode on the given transaction.
+// When spock is not available, this is a no-op.
+func (t *TableRepairTask) disableSpockRepairMode(tx pgx.Tx, nodeName string) error {
+    spock, err := t.isSpockAvailable(nodeName)
+    if err != nil {
         return err
     }
+    if !spock {
+        return nil
+    }
+    _, err = tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)")
+    if err != nil {
+        return fmt.Errorf("disabling spock.repair_mode(false) on %s: %w", nodeName, err)
+    }
+    logger.Debug("spock.repair_mode(false) set on %s", nodeName)
     return nil
 }
@@ -810,13 +860,13 @@ func (t *TableRepairTask) runFixNulls(startTime time.Time) error {
             continue
         }
 
-        if err := t.setupTransactionMode(tx, nodeName); err != nil {
+        spockRepairModeActive, err := t.setupTransactionMode(tx, nodeName)
+        if err != nil {
             tx.Rollback(t.Ctx)
             logger.Error("%v", err)
             repairErrors = append(repairErrors, err.Error())
             continue
         }
-        spockRepairModeActive := true
 
         colTypes, _, err := t.getColTypesForNode(nodeName)
         if err != nil {
@@ -874,11 +924,10 @@ func (t *TableRepairTask) runFixNulls(startTime time.Time) error {
         // Commit the initial tx (repair_mode setup) first, then use
         // per-batch-key transactions for origin preservation.
         if spockRepairModeActive {
-            _, err = tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)")
-            if err != nil {
+            if err := t.disableSpockRepairMode(tx, nodeName); err != nil {
                 tx.Rollback(t.Ctx)
-                logger.Error("disabling spock.repair_mode(false) on %s: %v", nodeName, err)
-                repairErrors = append(repairErrors, fmt.Sprintf("spock.repair_mode(false) failed for %s: %v", nodeName, err))
+                logger.Error("%v", err)
+                repairErrors = append(repairErrors, err.Error())
                 continue
             }
         }
@@ -929,11 +978,10 @@ func (t *TableRepairTask) runFixNulls(startTime time.Time) error {
 
         if !nodeFailed {
             if spockRepairModeActive {
-                _, err = tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)")
-                if err != nil {
+                if err := t.disableSpockRepairMode(tx, nodeName); err != nil {
                     tx.Rollback(t.Ctx)
-                    logger.Error("disabling spock.repair_mode(false) on %s: %v", nodeName, err)
-                    repairErrors = append(repairErrors, fmt.Sprintf("spock.repair_mode(false) failed for %s: %v", nodeName, err))
+                    logger.Error("%v", err)
+                    repairErrors = append(repairErrors, err.Error())
                     nodeFailed = true
                 }
             }
@@ -1349,7 +1397,7 @@ func (t *TableRepairTask) executePreserveOriginFixNulls(pool *pgxpool.Pool, node
             return totalUpdated, fmt.Errorf("failed to begin batch transaction for origin %s: %w", batchKey.nodeOrigin, err)
         }
 
-        if err := t.setupTransactionMode(batchTx, nodeName); err != nil {
+        if _, err := t.setupTransactionMode(batchTx, nodeName); err != nil {
             batchTx.Rollback(t.Ctx)
             return totalUpdated, fmt.Errorf("failed to setup transaction mode for origin batch: %w", err)
         }
@@ -1393,7 +1441,7 @@ func (t *TableRepairTask) executePreserveOriginFixNulls(pool *pgxpool.Pool, node
             }
         }
 
-        if _, err := batchTx.Exec(t.Ctx, "SELECT spock.repair_mode(false)"); err != nil {
+        if err := t.disableSpockRepairMode(batchTx, nodeName); err != nil {
             batchTx.Rollback(t.Ctx)
             return totalUpdated, fmt.Errorf("failed to disable repair mode for origin batch: %w", err)
         }
@@ -1574,13 +1622,13 @@ func (t *TableRepairTask) runUnidirectionalRepair(startTime time.Time) error {
             continue
         }
 
-        if err := t.setupTransactionMode(tx, nodeName); err != nil {
+        spockRepairModeActive, err := t.setupTransactionMode(tx, nodeName)
+        if err != nil {
             tx.Rollback(t.Ctx)
             logger.Error("%v", err)
             repairErrors = append(repairErrors, err.Error())
             continue
         }
-        spockRepairModeActive := true
 
         // TODO: DROP PRIVILEGES HERE!
@@ -1634,11 +1682,10 @@ func (t *TableRepairTask) runUnidirectionalRepair(startTime time.Time) error {
         // separate per-batch-key transactions so each batch gets its own
         // pg_replication_origin_xact_setup (which is per-transaction).
         if spockRepairModeActive {
-            _, err = tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)")
-            if err != nil {
+            if err := t.disableSpockRepairMode(tx, nodeName); err != nil {
                 tx.Rollback(t.Ctx)
-                logger.Error("disabling spock.repair_mode(false) on %s: %v", nodeName, err)
-                repairErrors = append(repairErrors, fmt.Sprintf("spock.repair_mode(false) failed for %s: %v", nodeName, err))
+                logger.Error("%v", err)
+                repairErrors = append(repairErrors, err.Error())
                 continue
             }
         }
@@ -1691,15 +1738,23 @@ func (t *TableRepairTask) runUnidirectionalRepair(startTime time.Time) error {
         if spockRepairModeActive {
             // TODO: Need to elevate privileges here, but might be difficult
             // with pgx transactions and connection pooling.
-            _, err = tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)")
-            if err != nil {
+            if err := t.disableSpockRepairMode(tx, nodeName); err != nil {
                 tx.Rollback(t.Ctx)
-                logger.Error("disabling spock.repair_mode(false) on %s: %v", nodeName, err)
-                repairErrors = append(repairErrors, fmt.Sprintf("spock.repair_mode(false) failed for %s: %v", nodeName, err))
+                logger.Error("%v", err)
+                repairErrors = append(repairErrors, err.Error())
                 continue
             }
-            logger.Debug("spock.repair_mode(false) set on %s", nodeName)
 
+            err = tx.Commit(t.Ctx)
+            if err != nil {
+                logger.Error("committing transaction on node %s: %v", nodeName, err)
+                repairErrors = append(repairErrors, fmt.Sprintf("commit failed for %s: %v", nodeName, err))
+                continue
+            }
+            logger.Debug("Transaction committed successfully on %s", nodeName)
+        } else if !t.PreserveOrigin || len(t.extractOriginInfoForNode(nodeName, fullUpserts[nodeName])) == 0 {
+            // Non-spock path: commit unless the preserve-origin branch (line ~1692)
+            // already committed this tx and ran upserts in separate transactions.
             err = tx.Commit(t.Ctx)
             if err != nil {
                 logger.Error("committing transaction on node %s: %v", nodeName, err)
@@ -1994,7 +2049,7 @@ func (t *TableRepairTask) performBirectionalInserts(nodeName string, inserts map
     }
     defer tx.Rollback(t.Ctx)
 
-    if err := t.setupTransactionMode(tx, nodeName); err != nil {
+    if _, err := t.setupTransactionMode(tx, nodeName); err != nil {
         return 0, err
     }
 
@@ -2004,9 +2059,8 @@ func (t *TableRepairTask) performBirectionalInserts(nodeName string, inserts map
     }
     logger.Info("Executed %d insert operations on %s", insertedCount, nodeName)
 
-    _, err = tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)")
-    if err != nil {
-        return 0, fmt.Errorf("failed to disable spock.repair_mode(false) on %s: %w", nodeName, err)
+    if err := t.disableSpockRepairMode(tx, nodeName); err != nil {
+        return 0, err
     }
 
     if err := tx.Commit(t.Ctx); err != nil {
@@ -2570,7 +2624,7 @@ func (t *TableRepairTask) executePreserveOriginUpserts(pool *pgxpool.Pool, nodeN
             return totalUpsertedCount, fmt.Errorf("failed to begin batch transaction for origin %s: %w", batchKey.nodeOrigin, err)
         }
 
-        if err := t.setupTransactionMode(batchTx, nodeName); err != nil {
+        if _, err := t.setupTransactionMode(batchTx, nodeName); err != nil {
             batchTx.Rollback(t.Ctx)
             return totalUpsertedCount, fmt.Errorf("failed to setup transaction mode for origin batch: %w", err)
         }
@@ -2606,7 +2660,7 @@ func (t *TableRepairTask) executePreserveOriginUpserts(pool *pgxpool.Pool, nodeN
         }
 
         // Disable repair mode before commit
-        if _, err := batchTx.Exec(t.Ctx, "SELECT spock.repair_mode(false)"); err != nil {
+        if err := t.disableSpockRepairMode(batchTx, nodeName); err != nil {
             batchTx.Rollback(t.Ctx)
             return totalUpsertedCount, fmt.Errorf("failed to disable repair mode for origin batch: %w", err)
         }
@@ -2872,9 +2926,22 @@ func calculateRepairSetsWithSourceOfTruth(task *TableRepairTask) (map[string]map
 }
 
 func (t *TableRepairTask) fetchLSNsForNode(pool *pgxpool.Pool, failedNode, survivor string) (originLSN *uint64, slotLSN *uint64, err error) {
+    spock, err := t.isSpockAvailable(survivor)
+    if err != nil {
+        return nil, nil, fmt.Errorf("detecting spock on %s: %w", survivor, err)
+    }
     var originStr *string
-    originStr, err = queries.GetSpockOriginLSNForNode(t.Ctx, pool, failedNode)
+    if spock {
+        originStr, err = queries.GetSpockOriginLSNForNode(t.Ctx, pool, failedNode)
+    } else {
+        originStr, err = queries.GetNativeOriginLSNForNode(t.Ctx, pool, failedNode)
+    }
     if err != nil {
+        if !spock {
+            // Native PG queries may fail if subscription naming doesn't match; treat as no data
+            logger.Warn("failed to fetch native origin lsn on %s: %v", survivor, err)
+            return nil, nil, nil
+        }
         return nil, nil, fmt.Errorf("failed to fetch origin lsn on %s: %w", survivor, err)
     }
     if originStr != nil {
@@ -2885,8 +2952,17 @@ func (t *TableRepairTask) fetchLSNsForNode(pool *pgxpool.Pool, failedNode, survi
     }
 
     var slotStr *string
-    slotStr, err = queries.GetSpockSlotLSNForNode(t.Ctx, pool, failedNode)
+    if spock {
+        slotStr, err = queries.GetSpockSlotLSNForNode(t.Ctx, pool, failedNode)
+    } else {
+        slotStr, err = queries.GetNativeSlotLSNForNode(t.Ctx, pool, failedNode)
+    }
     if err != nil {
+        if !spock {
+            // Native PG queries may fail if subscription naming doesn't match; treat as no data
+            logger.Warn("failed to fetch native slot lsn on %s: %v", survivor, err)
+            return originLSN, nil, nil
+        }
         return originLSN, nil, fmt.Errorf("failed to fetch slot lsn on %s: %w", survivor, err)
     }
     if slotStr != nil {
@@ -2930,17 +3006,19 @@ func (t *TableRepairTask) autoSelectSourceOfTruth(failedNode string, involved ma
             logger.Warn("recovery-mode: failed to connect to %s for LSN probe: %v", nodeName, err)
             continue
         }
-        originLSN, slotLSN, err := t.fetchLSNsForNode(pool, failedNode, nodeName)
-        if err != nil {
-            logger.Warn("recovery-mode: failed to fetch LSNs on %s: %v", nodeName, err)
-            pool.Close()
-            continue
-        }
+        // Store pool before fetchLSNsForNode so isSpockAvailable can find it
         if t.Pools[nodeName] == nil {
             t.Pools[nodeName] = pool
         } else {
             pool.Close()
+            pool = t.Pools[nodeName]
+        }
+
+        originLSN, slotLSN, err := t.fetchLSNsForNode(pool, failedNode, nodeName)
+        if err != nil {
+            logger.Warn("recovery-mode: failed to fetch LSNs on %s: %v", nodeName, err)
+            continue
         }
 
         lsnDetails[nodeName] = map[string]string{
diff --git a/tests/integration/advanced_repair_test.go b/tests/integration/advanced_repair_test.go
index 38c6b4f..f29ad56 100644
--- a/tests/integration/advanced_repair_test.go
+++ b/tests/integration/advanced_repair_test.go
@@ -30,7 +30,7 @@ func TestAdvancedRepairPlan_MixedSelectorsAndActions(t *testing.T) {
     ctx := context.Background()
     qualifiedTableName := "public.customers"
 
-    setupDivergence(t, ctx, qualifiedTableName, false)
+    setupDivergence(t, ctx, qualifiedTableName)
     diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2})
 
     plan := `
@@ -87,7 +87,7 @@ func TestAdvancedRepairPlan_DeleteAndWhenPredicate(t *testing.T) {
     ctx := context.Background()
     qualifiedTableName := "public.customers"
 
-    setupDivergence(t, ctx, qualifiedTableName, false)
+    setupDivergence(t, ctx, qualifiedTableName)
     diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2})
 
     plan := `
@@ -140,7 +140,7 @@ func TestAdvancedRepairPlan_StaleRepairsSkipped(t *testing.T) {
     ctx := context.Background()
     qualifiedTableName := "public.customers"
 
-    setupDivergence(t, ctx, qualifiedTableName, false)
+    setupDivergence(t, ctx, qualifiedTableName)
     diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2})
 
     beforeLogs := listStaleSkipLogs(t)
diff --git a/tests/integration/cdc_busy_table_test.go b/tests/integration/cdc_busy_table_test.go
index c557eb3..8a2771d 100644
--- a/tests/integration/cdc_busy_table_test.go
+++ b/tests/integration/cdc_busy_table_test.go
@@ -192,7 +192,7 @@ func TestCDCFallbackToSlotLSN(t *testing.T) {
     err := pgCluster.Node1Pool.QueryRow(ctx, "SELECT ($1::pg_lsn + 16)::text", slotConfirmed).Scan(&bumpedStart)
     require.NoError(t, err)
 
-    _, err = pgCluster.Node1Pool.Exec(ctx, "UPDATE spock.ace_cdc_metadata SET start_lsn = $1 WHERE publication_name = $2", bumpedStart, config.Cfg.MTree.CDC.PublicationName)
+    _, err = pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("UPDATE %s.ace_cdc_metadata SET start_lsn = $1 WHERE publication_name = $2", config.Cfg.MTree.Schema), bumpedStart, config.Cfg.MTree.CDC.PublicationName)
     require.NoError(t, err)
 
     _, err = pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("UPDATE %s SET email = email || '.cdc_fallback' WHERE index = 2", qualifiedTableName))
@@ -263,7 +263,7 @@ func currentMetadataLSN(t *testing.T, ctx context.Context) pglogrepl.LSN {
 func metadataStartLSN(t *testing.T, ctx context.Context) pglogrepl.LSN {
     t.Helper()
     var lsnStr string
-    err := pgCluster.Node1Pool.QueryRow(ctx, "SELECT start_lsn FROM spock.ace_cdc_metadata WHERE publication_name = $1", config.Cfg.MTree.CDC.PublicationName).Scan(&lsnStr)
+    err := pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT start_lsn FROM %s.ace_cdc_metadata WHERE publication_name = $1", config.Cfg.MTree.Schema), config.Cfg.MTree.CDC.PublicationName).Scan(&lsnStr)
     require.NoError(t, err)
     return mustParseLSN(t, lsnStr)
 }
diff --git a/tests/integration/docker-compose-native.yaml b/tests/integration/docker-compose-native.yaml
new file mode 100644
index 0000000..be60f2b
--- /dev/null
+++ b/tests/integration/docker-compose-native.yaml
@@ -0,0 +1,43 @@
+#############################################################################
+#
+#  ACE - Active Consistency Engine
+#
+#  Copyright (C) 2023 - 2026, pgEdge (https://www.pgedge.com/)
+#
+#  This software is released under the PostgreSQL License:
+#  https://opensource.org/license/postgresql
+#
+#############################################################################
+
+# Minimal docker-compose with vanilla PostgreSQL images (no spock extension).
+# Used by TestNativePG_* tests to verify that table-diff and table-repair
+# work without the spock extension.
+services:
+  native-n1:
+    image: postgres:17
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_PASSWORD: password
+      POSTGRES_DB: testdb
+    command:
+      - "postgres"
+      - "-c"
+      - "track_commit_timestamp=on"
+      - "-c"
+      - "wal_level=logical"
+    ports:
+      - "5432"
+  native-n2:
+    image: postgres:17
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_PASSWORD: password
+      POSTGRES_DB: testdb
+    command:
+      - "postgres"
+      - "-c"
+      - "track_commit_timestamp=on"
+      - "-c"
+      - "wal_level=logical"
+    ports:
+      - "5432"
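> The compose file gives the native tests two vanilla `postgres:17` nodes with
> `track_commit_timestamp` and logical WAL enabled. A sketch of how a
> `TestNativePG` test might bring the stack up via the testcontainers-go
> compose module that `main_test.go` already imports; the option names here
> are recalled from that module's documentation, not taken from this diff:

```go
// Hypothetical scaffolding; the real TestNativePG setup is not shown here.
func startNativeStack(t *testing.T) {
    t.Helper()
    ctx := context.Background()

    stack, err := compose.NewDockerCompose("docker-compose-native.yaml")
    if err != nil {
        t.Fatalf("loading compose file: %v", err)
    }
    t.Cleanup(func() {
        _ = stack.Down(ctx, compose.RemoveOrphans(true))
    })

    // Wait(true) blocks until the services are up before tests proceed.
    if err := stack.Up(ctx, compose.Wait(true)); err != nil {
        t.Fatalf("starting native PG stack: %v", err)
    }
}
```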
diff --git a/tests/integration/helpers_test.go b/tests/integration/helpers_test.go
index ebef9e6..c31402f 100644
--- a/tests/integration/helpers_test.go
+++ b/tests/integration/helpers_test.go
@@ -18,14 +18,13 @@ import (
     "io"
     "log"
     "os"
-    "path/filepath"
-    "sort"
     "strings"
     "testing"
     "time"
 
     "github.com/jackc/pgx/v5"
     "github.com/jackc/pgx/v5/pgxpool"
+    "github.com/pgedge/ace/internal/consistency/mtree"
     "github.com/pgedge/ace/internal/consistency/repair"
     "github.com/stretchr/testify/require"
 )
@@ -285,54 +284,16 @@ func loadDataFromCSV(
 }
 
 func newTestTableRepairTask(sourceOfTruthNode, qualifiedTableName, diffFilePath string) *repair.TableRepairTask {
-    task := repair.NewTableRepairTask()
-    task.ClusterName = "test_cluster"
-    task.DBName = dbName
-    task.SourceOfTruth = sourceOfTruthNode
-    task.QualifiedTableName = qualifiedTableName
-    task.DiffFilePath = diffFilePath
-    task.Nodes = "all"
-    return task
+    return newSpockEnv().newTableRepairTask(sourceOfTruthNode, qualifiedTableName, diffFilePath)
+}
+
+func newTestMerkleTreeTask(t *testing.T, qualifiedTableName string, nodes []string) *mtree.MerkleTreeTask {
+    return newSpockEnv().newMerkleTreeTask(t, qualifiedTableName, nodes)
 }
 
 func repairTable(t *testing.T, qualifiedTableName, sourceOfTruthNode string) {
     t.Helper()
-
-    files, err := filepath.Glob("*_diffs-*.json")
-    if err != nil {
-        t.Fatalf("Failed to find diff files: %v", err)
-    }
-    if len(files) == 0 {
-        log.Println("No diff file found to repair from, skipping repair.")
-        return
-    }
-
-    sort.Slice(files, func(i, j int) bool {
-        fi, errI := os.Stat(files[i])
-        if errI != nil {
-            t.Logf("Warning: could not stat file %s: %v", files[i], errI)
-            return false
-        }
-        fj, errJ := os.Stat(files[j])
-        if errJ != nil {
-            t.Logf("Warning: could not stat file %s: %v", files[j], errJ)
-            return false
-        }
-        return fi.ModTime().After(fj.ModTime())
-    })
-
-    latestDiffFile := files[0]
-    log.Printf("Using latest diff file for repair: %s", latestDiffFile)
-
-    repairTask := newTestTableRepairTask(sourceOfTruthNode, qualifiedTableName, latestDiffFile)
-
-    time.Sleep(2 * time.Second)
-
-    if err := repairTask.Run(false); err != nil {
-        t.Fatalf("Failed to repair table: %v", err)
-    }
-
-    log.Printf("Table '%s' repaired successfully using %s as source of truth.", qualifiedTableName, sourceOfTruthNode)
+    newSpockEnv().repairTable(t, qualifiedTableName, sourceOfTruthNode)
 }
 
 // getCommitTimestamp retrieves the commit timestamp for a specific row
diff --git a/tests/integration/main_test.go b/tests/integration/main_test.go
index 98f5175..e6374cf 100644
--- a/tests/integration/main_test.go
+++ b/tests/integration/main_test.go
@@ -27,7 +27,6 @@ import (
     "github.com/jackc/pgx/v5/pgxpool"
     "github.com/pgedge/ace/pkg/config"
     "github.com/pgedge/ace/pkg/types"
-    "github.com/stretchr/testify/require"
     tcLog "github.com/testcontainers/testcontainers-go/log"
     "github.com/testcontainers/testcontainers-go/modules/compose"
     "github.com/testcontainers/testcontainers-go/wait"
@@ -476,18 +475,5 @@ func setupSharedCustomersTable(tableName string) error {
 // after other tests may have modified the table.
 func resetSharedTable(t *testing.T, tableName string) {
     t.Helper()
-    ctx := context.Background()
-    qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName)
-    csvPath, err := filepath.Abs(defaultCsvFilePath + tableName + ".csv")
-    require.NoError(t, err)
-    for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} {
-        nodeName := pgCluster.ClusterNodes[i]["Name"].(string)
-        _, err := pool.Exec(ctx, "SELECT spock.repair_mode(true)")
-        require.NoError(t, err, "enable repair_mode on %s", nodeName)
-        _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName))
-        require.NoError(t, err, "truncate %s on %s", qualifiedTableName, nodeName)
-        _, err = pool.Exec(ctx, "SELECT spock.repair_mode(false)")
-        require.NoError(t, err, "disable repair_mode on %s", nodeName)
-        require.NoError(t, loadDataFromCSV(ctx, pool, testSchema, tableName, csvPath), "load CSV into %s on %s", qualifiedTableName, nodeName)
-    }
+    newSpockEnv().resetSharedTable(t, tableName)
 }
diff --git a/tests/integration/merkle_tree_test.go b/tests/integration/merkle_tree_test.go
index 003a6fc..97c6a88 100644
--- a/tests/integration/merkle_tree_test.go
+++ b/tests/integration/merkle_tree_test.go
@@ -27,7 +27,6 @@ import (
     "github.com/jackc/pgx/v5"
     "github.com/jackc/pgx/v5/pgconn"
     "github.com/jackc/pgx/v5/pgxpool"
-    "github.com/pgedge/ace/internal/consistency/mtree"
     "github.com/pgedge/ace/internal/infra/cdc"
     "github.com/pgedge/ace/pkg/config"
     "github.com/pgedge/ace/pkg/types"
@@ -35,89 +34,80 @@ import (
 )
 
 func TestMerkleTreeSimplePK(t *testing.T) {
+    env := newSpockEnv()
     tableName := "customers"
-    runMerkleTreeTests(t, tableName)
+    runMerkleTreeTests(t, env, tableName)
 }
 
 func TestMerkleTreeCompositePK(t *testing.T) {
+    env := newSpockEnv()
     tableName := "customers"
     ctx := context.Background()
 
-    for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} {
+    for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} {
         err := alterTableToCompositeKey(ctx, pool, testSchema, tableName)
         require.NoError(t, err)
     }
 
     t.Cleanup(func() {
-        for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} {
+        for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} {
             err := revertTableToSimpleKey(ctx, pool, testSchema, tableName)
             require.NoError(t, err)
         }
     })
 
-    runMerkleTreeTests(t, tableName)
+    runMerkleTreeTests(t, env, tableName)
 }
 
-func runMerkleTreeTests(t *testing.T, tableName string) {
+func runMerkleTreeTests(t *testing.T, env *testEnv, tableName string) {
     if tableName == "customers" {
-        resetSharedTable(t, "customers")
+        env.resetSharedTable(t, "customers")
     }
     t.Run("Init", func(t *testing.T) {
-        testMerkleTreeInit(t, tableName)
+        testMerkleTreeInit(t, env, tableName)
     })
     t.Run("Build", func(t *testing.T) {
-        testMerkleTreeBuild(t, tableName)
+        testMerkleTreeBuild(t, env, tableName)
     })
     if tableName == "customers" {
         t.Run("Diff_DataOnlyOnNode1", func(t *testing.T) {
-            testMerkleTreeDiffDataOnlyOnNode1(t, tableName)
+            testMerkleTreeDiffDataOnlyOnNode1(t, env, tableName)
        })
         t.Run("Diff_ModifiedRows", func(t *testing.T) {
-            testMerkleTreeDiffModifiedRows(t, tableName)
+            testMerkleTreeDiffModifiedRows(t, env, tableName)
        })
         t.Run("Diff_BoundaryModifications", func(t *testing.T) {
-            testMerkleTreeDiffBoundaryModifications(t, tableName)
+            testMerkleTreeDiffBoundaryModifications(t, env, tableName)
        })
         t.Run("MergeInitialRanges", func(t *testing.T) {
-            testMerkleTreeMergeInitialRanges(t, tableName)
+            testMerkleTreeMergeInitialRanges(t, env, tableName)
        })
         t.Run("MergeMiddleRanges", func(t *testing.T) {
-            testMerkleTreeMergeMiddleRanges(t, tableName)
+            testMerkleTreeMergeMiddleRanges(t, env, tableName)
        })
         t.Run("MergeLastRanges", func(t *testing.T) {
-            testMerkleTreeMergeLastRanges(t, tableName)
+            testMerkleTreeMergeLastRanges(t, env, tableName)
        })
         t.Run("SplitInitialRanges", func(t *testing.T) {
-            testMerkleTreeSplitInitialRanges(t, tableName)
+            testMerkleTreeSplitInitialRanges(t, env, tableName)
        })
         t.Run("SplitMiddleRanges", func(t *testing.T) {
-            testMerkleTreeSplitMiddleRanges(t, tableName)
+            testMerkleTreeSplitMiddleRanges(t, env, tableName)
        })
         t.Run("SplitLastRanges", func(t *testing.T) {
-            testMerkleTreeSplitLastRanges(t, tableName)
+            testMerkleTreeSplitLastRanges(t, env, tableName)
        })
         t.Run("ContinuousCDC", func(t *testing.T) {
-            testMerkleTreeContinuousCDC(t, tableName)
+            testMerkleTreeContinuousCDC(t, env, tableName)
        })
     }
     t.Run("Teardown", func(t *testing.T) {
-        testMerkleTreeTeardown(t, tableName)
+        testMerkleTreeTeardown(t, env, tableName)
     })
 }
 
-func newTestMerkleTreeTask(t *testing.T, qualifiedTableName string, nodes []string) *mtree.MerkleTreeTask {
-    t.Helper()
-    task := mtree.NewMerkleTreeTask()
-    task.ClusterName = "test_cluster"
-    task.DBName = dbName
-    task.QualifiedTableName = qualifiedTableName
-    task.Nodes = strings.Join(nodes, ",")
-    task.BlockSize = 1000
-    return task
-}
-
-func testMerkleTreeInit(t *testing.T, tableName string) {
+func testMerkleTreeInit(t *testing.T, env *testEnv, tableName string) {
     ctx := context.Background()
     qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName)
-    nodes := []string{serviceN1, serviceN2}
-    mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes)
+    nodes := []string{env.ServiceN1, env.ServiceN2}
+    mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes)
 
     err := mtreeTask.MtreeInit()
     require.NoError(t, err, "MtreeInit should succeed")
@@ -133,7 +123,7 @@ func testMerkleTreeInit(t *testing.T, tableName string) {
     cdcPubName := config.Cfg.MTree.CDC.PublicationName
     cdcSlotName := config.Cfg.MTree.CDC.SlotName
 
-    for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} {
+    for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} {
         require.True(t, schemaExists(t, ctx, pool, aceSchema), "Schema '%s' should exist", aceSchema)
         require.True(t, functionExists(t, ctx, pool, "bytea_xor", aceSchema), "Function 'bytea_xor' should exist in schema '%s'", aceSchema)
         require.True(t, tableExists(t, ctx, pool, "ace_cdc_metadata", aceSchema), "Table 'ace_cdc_metadata' should exist in schema '%s'", aceSchema)
@@ -142,11 +132,11 @@ func testMerkleTreeInit(t *testing.T, tableName string) {
     }
 }
 
-func testMerkleTreeBuild(t *testing.T, tableName string) {
+func testMerkleTreeBuild(t *testing.T, env *testEnv, tableName string) {
     ctx := context.Background()
     qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName)
-    nodes := []string{serviceN1}
-    mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes)
+    nodes := []string{env.ServiceN1}
+    mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes)
 
     err := mtreeTask.RunChecks(false)
     require.NoError(t, err, "RunChecks should succeed")
@@ -165,7 +155,7 @@ func testMerkleTreeBuild(t *testing.T, tableName string) {
     aceSchema := config.Cfg.MTree.Schema
     mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", testSchema, tableName)
 
-    pool := pgCluster.Node1Pool
+    pool := env.N1Pool
 
     require.True(t, tableExists(t, ctx, pool, mtreeTableName, aceSchema), "Merkle tree table '%s' should exist", mtreeTableName)
 
@@ -186,11 +176,11 @@ func testMerkleTreeBuild(t *testing.T, tableName string) {
     require.Greater(t, totalNodeCount, leafNodeCount, "Total nodes should be greater than leaf nodes")
 }
 
-func testMerkleTreeDiffDataOnlyOnNode1(t *testing.T, tableName string) {
+func testMerkleTreeDiffDataOnlyOnNode1(t *testing.T, env *testEnv, tableName string) {
     ctx := context.Background()
     qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName)
-    nodes := []string{serviceN1, serviceN2}
-    mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes)
+    nodes := []string{env.ServiceN1, env.ServiceN2}
+    mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes)
 
     err := mtreeTask.RunChecks(false)
     require.NoError(t, err, "RunChecks should succeed")
@@ -201,7 +191,7 @@ func testMerkleTreeDiffDataOnlyOnNode1(t *testing.T, tableName string) {
         if err != nil {
             t.Logf("Warning: MtreeTeardown failed during cleanup: %v", err)
         }
-        repairTable(t, qualifiedTableName, serviceN2)
+        env.repairTable(t, qualifiedTableName, env.ServiceN2)
         files, _ := filepath.Glob("*_diffs-*.json")
         for _, f := range files {
             os.Remove(f)
@@ -210,55 +200,51 @@ func testMerkleTreeDiffDataOnlyOnNode1(t *testing.T, tableName string) {
     err = mtreeTask.BuildMtree()
     require.NoError(t, err, "BuildMtree should succeed")
 
-    tx, err := pgCluster.Node1Pool.Begin(ctx)
+    tx, err := env.N1Pool.Begin(ctx)
     require.NoError(t, err)
     defer tx.Rollback(ctx)
 
-    _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)")
-    require.NoError(t, err)
-
-    if mtreeTask.SimplePrimaryKey {
-        updateSQL := fmt.Sprintf("UPDATE %s SET email = 'updated.on.n1@example.com' WHERE index = 1", qualifiedTableName)
-        _, err = tx.Exec(ctx, updateSQL)
-        require.NoError(t, err)
-    } else {
-        var customerID string
-        err := pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT customer_id FROM %s WHERE index = 1 LIMIT 1", qualifiedTableName)).Scan(&customerID)
-        require.NoError(t, err, "could not get customer_id for index 1")
-        updateSQL := fmt.Sprintf("UPDATE %s SET email = 'updated.on.n1@example.com' WHERE index = 1 AND customer_id = $1", qualifiedTableName)
-        _, err = tx.Exec(ctx, updateSQL, customerID)
-        require.NoError(t, err)
-    }
-
-    _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)")
-    require.NoError(t, err)
+    env.withRepairModeTx(t, ctx, tx, func() {
+        if mtreeTask.SimplePrimaryKey {
+            updateSQL := fmt.Sprintf("UPDATE %s SET email = 'updated.on.n1@example.com' WHERE index = 1", qualifiedTableName)
+            _, err = tx.Exec(ctx, updateSQL)
+            require.NoError(t, err)
+        } else {
+            var customerID string
+            err := env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT customer_id FROM %s WHERE index = 1 LIMIT 1", qualifiedTableName)).Scan(&customerID)
+            require.NoError(t, err, "could not get customer_id for index 1")
+            updateSQL := fmt.Sprintf("UPDATE %s SET email = 'updated.on.n1@example.com' WHERE index = 1 AND customer_id = $1", qualifiedTableName)
+            _, err = tx.Exec(ctx, updateSQL, customerID)
+            require.NoError(t, err)
+        }
+    })
 
     require.NoError(t, tx.Commit(ctx))
 
     err = mtreeTask.DiffMtree()
     require.NoError(t, err, "DiffMtree should succeed")
 
-    pairKey := serviceN1 + "/" + serviceN2
-    if strings.Compare(serviceN1, serviceN2) > 0 {
-        pairKey = serviceN2 + "/" + serviceN1
+    pairKey := env.ServiceN1 + "/" + env.ServiceN2
+    if strings.Compare(env.ServiceN1, env.ServiceN2) > 0 {
+        pairKey = env.ServiceN2 + "/" + env.ServiceN1
     }
     nodeDiffs, ok := mtreeTask.DiffResult.NodeDiffs[pairKey]
     require.True(t, ok, "Expected diffs for pair %s, but none found. Result: %+v", pairKey, mtreeTask.DiffResult)
 
-    require.GreaterOrEqual(t, len(nodeDiffs.Rows[serviceN1]), 1, "Expected at least 1 modified row on %s, got %d", serviceN1, len(nodeDiffs.Rows[serviceN1]))
-    require.GreaterOrEqual(t, len(nodeDiffs.Rows[serviceN2]), 1, "Expected at least 1 original row on %s, got %d", serviceN2, len(nodeDiffs.Rows[serviceN2]))
+    require.GreaterOrEqual(t, len(nodeDiffs.Rows[env.ServiceN1]), 1, "Expected at least 1 modified row on %s, got %d", env.ServiceN1, len(nodeDiffs.Rows[env.ServiceN1]))
+    require.GreaterOrEqual(t, len(nodeDiffs.Rows[env.ServiceN2]), 1, "Expected at least 1 original row on %s, got %d", env.ServiceN2, len(nodeDiffs.Rows[env.ServiceN2]))
     require.GreaterOrEqual(t, mtreeTask.DiffResult.Summary.DiffRowsCount[pairKey], 1, "Expected summary diff count to be at least 1")
 
     // Find the row with index=1 in the diff (block-based diff may include other rows in the same block)
     var diffRowN1, diffRowN2 types.OrderedMap
-    for _, row := range nodeDiffs.Rows[serviceN1] {
+    for _, row := range nodeDiffs.Rows[env.ServiceN1] {
         if idx, ok := row.Get("index"); ok && idx.(int32) == 1 {
             diffRowN1 = row
             break
         }
     }
-    for _, row := range nodeDiffs.Rows[serviceN2] {
+    for _, row := range nodeDiffs.Rows[env.ServiceN2] {
         if idx, ok := row.Get("index"); ok && idx.(int32) == 1 {
             diffRowN2 = row
             break
@@ -280,23 +266,23 @@ type compositeBoundaryKey struct {
     CustomerID string
 }
 
-func testMerkleTreeDiffBoundaryModifications(t *testing.T, tableName string) {
+func testMerkleTreeDiffBoundaryModifications(t *testing.T, env *testEnv, tableName string) {
     ctx := context.Background()
     qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName)
-    nodes := []string{serviceN1, serviceN2}
-    mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes)
+    nodes := []string{env.ServiceN1, env.ServiceN2}
+    mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes)
 
     err := mtreeTask.RunChecks(false)
     require.NoError(t, err, "RunChecks should succeed")
     err = mtreeTask.MtreeInit()
     require.NoError(t, err, "MtreeInit should succeed")
     t.Cleanup(func() {
-        _, _ = pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("UPDATE %s SET email = regexp_replace(email, '\\\\.(cdc_hw|cdc_fallback)$', '')", qualifiedTableName))
+        _, _ = env.N1Pool.Exec(ctx, fmt.Sprintf("UPDATE %s SET email = regexp_replace(email, '\\\\.(cdc_hw|cdc_fallback)$', '')", qualifiedTableName))
         err := mtreeTask.MtreeTeardown()
         if err != nil {
             t.Logf("Warning: MtreeTeardown failed during cleanup: %v", err)
         }
-        repairTable(t, qualifiedTableName, serviceN1)
+        env.repairTable(t, qualifiedTableName, env.ServiceN1)
         files, _ := 
filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -311,7 +297,7 @@ func testMerkleTreeDiffBoundaryModifications(t *testing.T, tableName string) { var boundaryPkeys []any if mtreeTask.SimplePrimaryKey { query := fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND range_start IS NOT NULL UNION SELECT range_end FROM %s.%s WHERE node_level=0 AND range_end IS NOT NULL", aceSchema, mtreeTableName, aceSchema, mtreeTableName) - rows, err := pgCluster.Node1Pool.Query(ctx, query) + rows, err := env.N1Pool.Query(ctx, query) require.NoError(t, err) defer rows.Close() for rows.Next() { @@ -323,7 +309,7 @@ func testMerkleTreeDiffBoundaryModifications(t *testing.T, tableName string) { } else { // For composite keys, we get the text representation and parse it. query := fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND range_start IS NOT NULL UNION SELECT range_end::text FROM %s.%s WHERE node_level=0 AND range_end IS NOT NULL", aceSchema, mtreeTableName, aceSchema, mtreeTableName) - rows, err := pgCluster.Node1Pool.Query(ctx, query) + rows, err := env.N1Pool.Query(ctx, query) require.NoError(t, err) defer rows.Close() re := regexp.MustCompile(`^\((\d+),"?([^",]+)"?\)$`) @@ -367,59 +353,55 @@ func testMerkleTreeDiffBoundaryModifications(t *testing.T, tableName string) { } pkeysToModify = uniquePkeysList - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - if mtreeTask.SimplePrimaryKey { - for _, pkey := range pkeysToModify { - updateSQL := fmt.Sprintf("UPDATE %s SET email = 'boundary.update.%v@example.com' WHERE index = $1", qualifiedTableName, pkey) - _, err := tx.Exec(ctx, updateSQL, pkey) - require.NoError(t, err) - } - } else { - for _, pkey := range pkeysToModify { - ckey := pkey.(compositeBoundaryKey) - updateSQL := fmt.Sprintf("UPDATE %s SET email = 'boundary.update.%v@example.com' WHERE index = $1 AND customer_id = $2", qualifiedTableName, ckey.Index) - _, err := tx.Exec(ctx, updateSQL, ckey.Index, ckey.CustomerID) - require.NoError(t, err) + env.withRepairModeTx(t, ctx, tx, func() { + if mtreeTask.SimplePrimaryKey { + for _, pkey := range pkeysToModify { + updateSQL := fmt.Sprintf("UPDATE %s SET email = 'boundary.update.%v@example.com' WHERE index = $1", qualifiedTableName, pkey) + _, err := tx.Exec(ctx, updateSQL, pkey) + require.NoError(t, err) + } + } else { + for _, pkey := range pkeysToModify { + ckey := pkey.(compositeBoundaryKey) + updateSQL := fmt.Sprintf("UPDATE %s SET email = 'boundary.update.%v@example.com' WHERE index = $1 AND customer_id = $2", qualifiedTableName, ckey.Index) + _, err := tx.Exec(ctx, updateSQL, ckey.Index, ckey.CustomerID) + require.NoError(t, err) + } } - } - - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + }) require.NoError(t, tx.Commit(ctx)) err = mtreeTask.DiffMtree() require.NoError(t, err, "DiffMtree should succeed") - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 + pairKey := env.ServiceN1 + "/" + env.ServiceN2 + if strings.Compare(env.ServiceN1, env.ServiceN2) > 0 { + pairKey = env.ServiceN2 + "/" + env.ServiceN1 } nodeDiffs, ok := mtreeTask.DiffResult.NodeDiffs[pairKey] require.True(t, ok, "Expected diffs for pair %s, but none found.", pairKey) expectedDiffCount := len(pkeysToModify) - require.Equal(t, 
expectedDiffCount, len(nodeDiffs.Rows[serviceN1]), "Expected %d modified rows on %s", expectedDiffCount, serviceN1) - require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[serviceN2]), "Expected %d modified rows on %s", expectedDiffCount, serviceN2) + require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[env.ServiceN1]), "Expected %d modified rows on %s", expectedDiffCount, env.ServiceN1) + require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[env.ServiceN2]), "Expected %d modified rows on %s", expectedDiffCount, env.ServiceN2) require.Equal(t, expectedDiffCount, mtreeTask.DiffResult.Summary.DiffRowsCount[pairKey], "Expected summary diff count to be %d", expectedDiffCount) if mtreeTask.SimplePrimaryKey { for _, pkey := range pkeysToModify { foundN1 := false foundN2 := false - for _, row := range nodeDiffs.Rows[serviceN1] { + for _, row := range nodeDiffs.Rows[env.ServiceN1] { if rIndex, ok := row.Get("index"); ok && rIndex == pkey { foundN1 = true break } } - for _, row := range nodeDiffs.Rows[serviceN2] { + for _, row := range nodeDiffs.Rows[env.ServiceN2] { if rIndex, ok := row.Get("index"); ok && rIndex == pkey { foundN2 = true emailVal, _ := row.Get("email") @@ -436,7 +418,7 @@ func testMerkleTreeDiffBoundaryModifications(t *testing.T, tableName string) { ckey := pkey.(compositeBoundaryKey) foundN1 := false foundN2 := false - for _, row := range nodeDiffs.Rows[serviceN1] { + for _, row := range nodeDiffs.Rows[env.ServiceN1] { rIndex, _ := row.Get("index") rCustomerID, _ := row.Get("customer_id") if rIndex == ckey.Index && rCustomerID == ckey.CustomerID { @@ -444,7 +426,7 @@ func testMerkleTreeDiffBoundaryModifications(t *testing.T, tableName string) { break } } - for _, row := range nodeDiffs.Rows[serviceN2] { + for _, row := range nodeDiffs.Rows[env.ServiceN2] { rIndex, _ := row.Get("index") rCustomerID, _ := row.Get("customer_id") if rIndex == ckey.Index && rCustomerID == ckey.CustomerID { @@ -461,14 +443,14 @@ func testMerkleTreeDiffBoundaryModifications(t *testing.T, tableName string) { } } -func testMerkleTreeContinuousCDC(t *testing.T, tableName string) { +func testMerkleTreeContinuousCDC(t *testing.T, env *testEnv, tableName string) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() largeTableName := "customers_1M" qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodes := []string{serviceN1} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes) + nodes := []string{env.ServiceN1} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes) err := mtreeTask.RunChecks(false) require.NoError(t, err, "RunChecks should succeed") @@ -479,7 +461,7 @@ func testMerkleTreeContinuousCDC(t *testing.T, tableName string) { if err != nil { t.Logf("Warning: MtreeTeardown failed during cleanup: %v", err) } - repairTable(t, qualifiedTableName, serviceN2) + env.repairTable(t, qualifiedTableName, env.ServiceN2) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -493,7 +475,7 @@ func testMerkleTreeContinuousCDC(t *testing.T, tableName string) { wg.Add(1) go func() { defer wg.Done() - nodeInfo := pgCluster.ClusterNodes[0] + nodeInfo := env.ClusterNodes[0] cdc.ListenForChanges(ctx, nodeInfo) }() @@ -501,7 +483,7 @@ func testMerkleTreeContinuousCDC(t *testing.T, tableName string) { aceSchema := config.Cfg.MTree.Schema mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", testSchema, tableName) - pool := pgCluster.Node1Pool + pool := env.N1Pool var leafNodeCount int err = pool.QueryRow(context.Background(), 
fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCount) @@ -516,75 +498,71 @@ func testMerkleTreeContinuousCDC(t *testing.T, tableName string) { require.NoError(t, err) defer tx.Rollback(context.Background()) - _, err = tx.Exec(context.Background(), "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - if mtreeTask.SimplePrimaryKey { - var firstBlockStart, middleBlockStart, lastBlockStart int32 - err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), firstBlockPos).Scan(&firstBlockStart) - require.NoError(t, err) - err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), middleBlockPos).Scan(&middleBlockStart) - require.NoError(t, err) - err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), lastBlockPos).Scan(&lastBlockStart) - require.NoError(t, err) - - // DELETE + INSERT in first block - _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1", qualifiedTableName), firstBlockStart) - require.NoError(t, err) - insertSQL := fmt.Sprintf("INSERT INTO %s SELECT * FROM %s WHERE index = $1", qualifiedTableName, pgx.Identifier{largeTableName}.Sanitize()) - _, err = tx.Exec(context.Background(), insertSQL, firstBlockStart) - require.NoError(t, err) + env.withRepairModeTx(t, context.Background(), tx, func() { + if mtreeTask.SimplePrimaryKey { + var firstBlockStart, middleBlockStart, lastBlockStart int32 + err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), firstBlockPos).Scan(&firstBlockStart) + require.NoError(t, err) + err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), middleBlockPos).Scan(&middleBlockStart) + require.NoError(t, err) + err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), lastBlockPos).Scan(&lastBlockStart) + require.NoError(t, err) - // UPDATE in middle block - updateSQL := fmt.Sprintf("UPDATE %s SET email = 'cdc.update.test@example.com' WHERE index = $1", qualifiedTableName) - _, err = tx.Exec(context.Background(), updateSQL, middleBlockStart) - require.NoError(t, err) + // DELETE + INSERT in first block + _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1", qualifiedTableName), firstBlockStart) + require.NoError(t, err) + insertSQL := fmt.Sprintf("INSERT INTO %s SELECT * FROM %s WHERE index = $1", qualifiedTableName, pgx.Identifier{largeTableName}.Sanitize()) + _, err = tx.Exec(context.Background(), insertSQL, firstBlockStart) + require.NoError(t, err) - // DELETE in last block - _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1", qualifiedTableName), lastBlockStart) - require.NoError(t, err) + // UPDATE in middle block + updateSQL := fmt.Sprintf("UPDATE %s SET email = 'cdc.update.test@example.com' WHERE index = $1", qualifiedTableName) + _, err = tx.Exec(context.Background(), updateSQL, middleBlockStart) + require.NoError(t, err) - } else { // Composite Primary Key - re := 
regexp.MustCompile(`^\((\d+),"?([^",]+)"?\)$`) - parseKey := func(keyStr string) (int32, string) { - matches := re.FindStringSubmatch(keyStr) - require.Len(t, matches, 3, "could not parse composite key: %s", keyStr) - index, err := strconv.Atoi(matches[1]) + // DELETE in last block + _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1", qualifiedTableName), lastBlockStart) require.NoError(t, err) - return int32(index), matches[2] - } - var firstKeyStr, middleKeyStr, lastKeyStr string - err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), firstBlockPos).Scan(&firstKeyStr) - require.NoError(t, err) - err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), middleBlockPos).Scan(&middleKeyStr) - require.NoError(t, err) - err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), lastBlockPos).Scan(&lastKeyStr) - require.NoError(t, err) + } else { // Composite Primary Key + re := regexp.MustCompile(`^\((\d+),"?([^",]+)"?\)$`) + parseKey := func(keyStr string) (int32, string) { + matches := re.FindStringSubmatch(keyStr) + require.Len(t, matches, 3, "could not parse composite key: %s", keyStr) + index, err := strconv.Atoi(matches[1]) + require.NoError(t, err) + return int32(index), matches[2] + } - firstIdx, firstCustId := parseKey(firstKeyStr) - middleIdx, middleCustId := parseKey(middleKeyStr) - lastIdx, lastCustId := parseKey(lastKeyStr) + var firstKeyStr, middleKeyStr, lastKeyStr string + err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), firstBlockPos).Scan(&firstKeyStr) + require.NoError(t, err) + err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), middleBlockPos).Scan(&middleKeyStr) + require.NoError(t, err) + err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), lastBlockPos).Scan(&lastKeyStr) + require.NoError(t, err) - // DELETE + INSERT in first block - _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1 AND customer_id = $2", qualifiedTableName), firstIdx, firstCustId) - require.NoError(t, err) - insertSQL := fmt.Sprintf("INSERT INTO %s SELECT * FROM %s WHERE index = $1 AND customer_id = $2", qualifiedTableName, pgx.Identifier{largeTableName}.Sanitize()) - _, err = tx.Exec(context.Background(), insertSQL, firstIdx, firstCustId) - require.NoError(t, err) + firstIdx, firstCustId := parseKey(firstKeyStr) + middleIdx, middleCustId := parseKey(middleKeyStr) + lastIdx, lastCustId := parseKey(lastKeyStr) - // UPDATE in middle block - updateSQL := fmt.Sprintf("UPDATE %s SET email = 'cdc.update.test.composite@example.com' WHERE index = $1 AND customer_id = $2", qualifiedTableName) - _, err = tx.Exec(context.Background(), updateSQL, middleIdx, middleCustId) - require.NoError(t, err) + // DELETE + INSERT in first block + _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1 AND customer_id = $2", qualifiedTableName), firstIdx, firstCustId) + require.NoError(t, 
err) + insertSQL := fmt.Sprintf("INSERT INTO %s SELECT * FROM %s WHERE index = $1 AND customer_id = $2", qualifiedTableName, pgx.Identifier{largeTableName}.Sanitize()) + _, err = tx.Exec(context.Background(), insertSQL, firstIdx, firstCustId) + require.NoError(t, err) - // DELETE in last block - _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1 AND customer_id = $2", qualifiedTableName), lastIdx, lastCustId) - require.NoError(t, err) - } + // UPDATE in middle block + updateSQL := fmt.Sprintf("UPDATE %s SET email = 'cdc.update.test.composite@example.com' WHERE index = $1 AND customer_id = $2", qualifiedTableName) + _, err = tx.Exec(context.Background(), updateSQL, middleIdx, middleCustId) + require.NoError(t, err) - _, err = tx.Exec(context.Background(), "SELECT spock.repair_mode(false)") - require.NoError(t, err) + // DELETE in last block + _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1 AND customer_id = $2", qualifiedTableName), lastIdx, lastCustId) + require.NoError(t, err) + } + }) require.NoError(t, tx.Commit(context.Background())) time.Sleep(5 * time.Second) @@ -620,8 +598,8 @@ func testMerkleTreeContinuousCDC(t *testing.T, tableName string) { cancel() wg.Wait() - nodes = []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodes) + nodes = []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodes) err = tdTask.RunChecks(false) require.NoError(t, err, "RunChecks after CDC should succeed") @@ -645,35 +623,35 @@ const ( splitLast splitCase = "last" ) -func testMerkleTreeMergeInitialRanges(t *testing.T, tableName string) { - runMerkleTreeMergeTest(t, tableName, initial) +func testMerkleTreeMergeInitialRanges(t *testing.T, env *testEnv, tableName string) { + runMerkleTreeMergeTest(t, env, tableName, initial) } -func testMerkleTreeMergeMiddleRanges(t *testing.T, tableName string) { - runMerkleTreeMergeTest(t, tableName, middle) +func testMerkleTreeMergeMiddleRanges(t *testing.T, env *testEnv, tableName string) { + runMerkleTreeMergeTest(t, env, tableName, middle) } -func testMerkleTreeMergeLastRanges(t *testing.T, tableName string) { - runMerkleTreeMergeTest(t, tableName, last) +func testMerkleTreeMergeLastRanges(t *testing.T, env *testEnv, tableName string) { + runMerkleTreeMergeTest(t, env, tableName, last) } -func testMerkleTreeSplitInitialRanges(t *testing.T, tableName string) { - runMerkleTreeSplitTest(t, tableName, splitInitial) +func testMerkleTreeSplitInitialRanges(t *testing.T, env *testEnv, tableName string) { + runMerkleTreeSplitTest(t, env, tableName, splitInitial) } -func testMerkleTreeSplitMiddleRanges(t *testing.T, tableName string) { - runMerkleTreeSplitTest(t, tableName, splitMiddle) +func testMerkleTreeSplitMiddleRanges(t *testing.T, env *testEnv, tableName string) { + runMerkleTreeSplitTest(t, env, tableName, splitMiddle) } -func testMerkleTreeSplitLastRanges(t *testing.T, tableName string) { - runMerkleTreeSplitTest(t, tableName, splitLast) +func testMerkleTreeSplitLastRanges(t *testing.T, env *testEnv, tableName string) { + runMerkleTreeSplitTest(t, env, tableName, splitLast) } -func runMerkleTreeMergeTest(t *testing.T, tableName string, mc mergeCase) { +func runMerkleTreeMergeTest(t *testing.T, env *testEnv, tableName string, mc mergeCase) { ctx := context.Background() qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodes := []string{serviceN1, serviceN2} - mtreeTask := newTestMerkleTreeTask(t, 
qualifiedTableName, nodes) + nodes := []string{env.ServiceN1, env.ServiceN2} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes) err := mtreeTask.RunChecks(false) require.NoError(t, err, "RunChecks should succeed") @@ -684,7 +662,7 @@ func runMerkleTreeMergeTest(t *testing.T, tableName string, mc mergeCase) { if err != nil { t.Logf("Warning: MtreeTeardown failed during cleanup: %v", err) } - repairTable(t, qualifiedTableName, serviceN2) + env.repairTable(t, qualifiedTableName, env.ServiceN2) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -697,7 +675,7 @@ func runMerkleTreeMergeTest(t *testing.T, tableName string, mc mergeCase) { mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", testSchema, tableName) var leafNodeCountBefore int - err = pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCountBefore) + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCountBefore) require.NoError(t, err) var startPos, endPos int @@ -720,20 +698,20 @@ func runMerkleTreeMergeTest(t *testing.T, tableName string, mc mergeCase) { if mtreeTask.SimplePrimaryKey { var startKey int32 - err = pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), startPos).Scan(&startKey) + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), startPos).Scan(&startKey) require.NoError(t, err) startRange = []any{startKey} if endPos != -1 { var endKey int32 - err = pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_end FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), endPos).Scan(&endKey) + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_end FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), endPos).Scan(&endKey) require.NoError(t, err) endRange = []any{endKey} } } else { re := regexp.MustCompile(`^\((\d+),"?([^",]+)"?\)$`) var startStr string - err = pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), startPos).Scan(&startStr) + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), startPos).Scan(&startStr) require.NoError(t, err) startMatches := re.FindStringSubmatch(startStr) require.Len(t, startMatches, 3, "should parse composite key from string") @@ -742,7 +720,7 @@ func runMerkleTreeMergeTest(t *testing.T, tableName string, mc mergeCase) { if endPos != -1 { var endStr string - err = pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_end::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), endPos).Scan(&endStr) + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_end::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), endPos).Scan(&endStr) require.NoError(t, err) endMatches := re.FindStringSubmatch(endStr) require.Len(t, endMatches, 3, "should parse composite key from string") @@ -751,47 +729,43 @@ func runMerkleTreeMergeTest(t *testing.T, tableName string, mc mergeCase) { } } - tx, err := pgCluster.Node1Pool.Begin(ctx) + tx, err := 
env.N1Pool.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - var cmdTag pgconn.CommandTag - if endPos == -1 { // Deletion to the end - if mtreeTask.SimplePrimaryKey { - deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE index >= $1", qualifiedTableName) - cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0]) - } else { - deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE (index, customer_id) >= ($1, $2)", qualifiedTableName) - cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0], startRange[1]) - } - } else { // Deletion within a range - if mtreeTask.SimplePrimaryKey { - deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE index >= $1 AND index <= $2", qualifiedTableName) - cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0], endRange[0]) - } else { - deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE (index, customer_id) >= ($1, $2) AND (index, customer_id) <= ($3, $4)", qualifiedTableName) - cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0], startRange[1], endRange[0], endRange[1]) + env.withRepairModeTx(t, ctx, tx, func() { + var cmdTag pgconn.CommandTag + if endPos == -1 { // Deletion to the end + if mtreeTask.SimplePrimaryKey { + deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE index >= $1", qualifiedTableName) + cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0]) + } else { + deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE (index, customer_id) >= ($1, $2)", qualifiedTableName) + cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0], startRange[1]) + } + } else { // Deletion within a range + if mtreeTask.SimplePrimaryKey { + deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE index >= $1 AND index <= $2", qualifiedTableName) + cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0], endRange[0]) + } else { + deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE (index, customer_id) >= ($1, $2) AND (index, customer_id) <= ($3, $4)", qualifiedTableName) + cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0], startRange[1], endRange[0], endRange[1]) + } } - } - require.NoError(t, err) - deletedCount = cmdTag.RowsAffected() - - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + require.NoError(t, err) + deletedCount = cmdTag.RowsAffected() + }) require.NoError(t, tx.Commit(ctx)) require.Greater(t, deletedCount, int64(0), "should have deleted some rows") - mtreeTask.Nodes = serviceN1 + mtreeTask.Nodes = env.ServiceN1 mtreeTask.Rebalance = true err = mtreeTask.UpdateMtree(true) require.NoError(t, err, "UpdateMtree with rebalance should succeed") var leafNodeCountAfter int - err = pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCountAfter) + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCountAfter) require.NoError(t, err) require.Less(t, leafNodeCountAfter, leafNodeCountBefore, "Number of leaf nodes should decrease after merge") @@ -802,25 +776,25 @@ func runMerkleTreeMergeTest(t *testing.T, tableName string, mc mergeCase) { err = mtreeTask.DiffMtree() require.NoError(t, err, "DiffMtree should succeed") - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 + pairKey := env.ServiceN1 + "/" + env.ServiceN2 + if strings.Compare(env.ServiceN1, env.ServiceN2) > 0 { + pairKey = env.ServiceN2 + "/" + env.ServiceN1 } nodeDiffs, ok := 
mtreeTask.DiffResult.NodeDiffs[pairKey] require.True(t, ok, "Expected diffs for pair %s, but none found.", pairKey) - require.Equal(t, 0, len(nodeDiffs.Rows[serviceN1]), "Expected 0 extra rows on %s", serviceN1) - require.Equal(t, int(deletedCount), len(nodeDiffs.Rows[serviceN2]), "Expected %d missing rows on %s", deletedCount, serviceN1) + require.Equal(t, 0, len(nodeDiffs.Rows[env.ServiceN1]), "Expected 0 extra rows on %s", env.ServiceN1) + require.Equal(t, int(deletedCount), len(nodeDiffs.Rows[env.ServiceN2]), "Expected %d rows on %s to be missing on %s", deletedCount, env.ServiceN2, env.ServiceN1) require.Equal(t, int(deletedCount), mtreeTask.DiffResult.Summary.DiffRowsCount[pairKey], "Expected summary diff count to be %d", deletedCount) } -func runMerkleTreeSplitTest(t *testing.T, tableName string, sc splitCase) { +func runMerkleTreeSplitTest(t *testing.T, env *testEnv, tableName string, sc splitCase) { ctx := context.Background() qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) largeTableName := "customers_1M" - nodes := []string{serviceN1, serviceN2} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes) + nodes := []string{env.ServiceN1, env.ServiceN2} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes) err := mtreeTask.RunChecks(false) require.NoError(t, err, "RunChecks should succeed") @@ -839,35 +813,32 @@ func runMerkleTreeSplitTest(t *testing.T, tableName string, sc splitCase) { // No deletes needed for last split } if deleteRange != nil { - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - _, err = tx.Exec(ctx, "DROP TABLE IF EXISTS "+tempTableName) - require.NoError(t, err) - - createSQL := fmt.Sprintf("CREATE TABLE %s AS SELECT * FROM %s WHERE index >= %d AND index <= %d", tempTableName, qualifiedTableName, deleteRange[0], deleteRange[1]) - _, err = tx.Exec(ctx, createSQL) - require.NoError(t, err) + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec(ctx, "DROP TABLE IF EXISTS "+tempTableName) + require.NoError(t, err) - _, err = tx.Exec(ctx, fmt.Sprintf("DELETE FROM %s WHERE index >= %d AND index <= %d", qualifiedTableName, deleteRange[0], deleteRange[1])) - require.NoError(t, err) + createSQL := fmt.Sprintf("CREATE TABLE %s AS SELECT * FROM %s WHERE index >= %d AND index <= %d", tempTableName, qualifiedTableName, deleteRange[0], deleteRange[1]) + _, err = tx.Exec(ctx, createSQL) + require.NoError(t, err) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + _, err = tx.Exec(ctx, fmt.Sprintf("DELETE FROM %s WHERE index >= %d AND index <= %d", qualifiedTableName, deleteRange[0], deleteRange[1])) + require.NoError(t, err) + }) require.NoError(t, tx.Commit(ctx)) } t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) } - _, err = pgCluster.Node2Pool.Exec(ctx, "DROP TABLE IF EXISTS "+tempTableName) + _, err = env.N2Pool.Exec(ctx, "DROP TABLE IF EXISTS "+tempTableName) if err != nil { t.Logf("Warning: failed to drop temp table %s during cleanup: %v", tempTableName, err) } @@ -881,7 +852,7 @@ func runMerkleTreeSplitTest(t *testing.T, tableName string, sc splitCase) { err = mtreeTask.MtreeInit() require.NoError(t, err, "MtreeInit should succeed") - mtreeTask.Nodes = serviceN2 + mtreeTask.Nodes = env.ServiceN2
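+ // Point the task at n2 only: the deletions above happened on n2, and
+ // re-running RunChecks presumably re-resolves connections so that
+ // BuildMtree below constructs the tree from n2's current state.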
mtreeTask.RunChecks(false) err = mtreeTask.BuildMtree() require.NoError(t, err, "BuildMtree should succeed") @@ -890,7 +861,7 @@ func runMerkleTreeSplitTest(t *testing.T, tableName string, sc splitCase) { mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", testSchema, tableName) var leafNodeCountBefore int - err = pgCluster.Node2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCountBefore) + err = env.N2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCountBefore) require.NoError(t, err) var q string @@ -899,28 +870,24 @@ func runMerkleTreeSplitTest(t *testing.T, tableName string, sc splitCase) { q = fmt.Sprintf("INSERT INTO %s SELECT * FROM %s order by index", qualifiedTableName, tempTableName) case splitLast: var maxIndex int32 - err := pgCluster.Node2Pool.QueryRow(ctx, "SELECT max(index) FROM "+qualifiedTableName).Scan(&maxIndex) + err := env.N2Pool.QueryRow(ctx, "SELECT max(index) FROM "+qualifiedTableName).Scan(&maxIndex) require.NoError(t, err) insertStartIndex := maxIndex + 1 q = fmt.Sprintf("INSERT INTO %s SELECT * FROM %s WHERE index >= %d order by index LIMIT %d", qualifiedTableName, pgx.Identifier{largeTableName}.Sanitize(), insertStartIndex, rowsToInsert) } - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - _, err = tx.Exec(ctx, q) - require.NoError(t, err) - if sc != splitLast { - _, err = tx.Exec(ctx, "DROP TABLE "+tempTableName) + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec(ctx, q) require.NoError(t, err) - } - - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + if sc != splitLast { + _, err = tx.Exec(ctx, "DROP TABLE "+tempTableName) + require.NoError(t, err) + } + }) require.NoError(t, tx.Commit(ctx)) mtreeTask.Rebalance = false @@ -928,7 +895,7 @@ func runMerkleTreeSplitTest(t *testing.T, tableName string, sc splitCase) { require.NoError(t, err, "UpdateMtree should succeed") var leafNodeCountAfter int - pool, err := connectToNode(pgCluster.Node2Host, pgCluster.Node2Port, pgEdgeUser, pgEdgePassword, dbName) + pool, err := connectToNode(env.N2Host, env.N2Port, env.DBUser, env.DBPassword, env.DBName) require.NoError(t, err) err = pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCountAfter) @@ -943,11 +910,11 @@ func runMerkleTreeSplitTest(t *testing.T, tableName string, sc splitCase) { // require.Empty(t, mtreeTask.DiffResult.NodeDiffs, "Merkle trees should be in sync after identical splits") } -func testMerkleTreeDiffModifiedRows(t *testing.T, tableName string) { +func testMerkleTreeDiffModifiedRows(t *testing.T, env *testEnv, tableName string) { ctx := context.Background() qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodes := []string{serviceN1, serviceN2} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes) + nodes := []string{env.ServiceN1, env.ServiceN2} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes) err := mtreeTask.RunChecks(false) require.NoError(t, err, "RunChecks should succeed") @@ -958,7 +925,7 @@ func testMerkleTreeDiffModifiedRows(t *testing.T, tableName string) { if err != nil { t.Logf("Warning: MtreeTeardown failed during cleanup: %v", err) } - repairTable(t, 
qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -984,44 +951,40 @@ func testMerkleTreeDiffModifiedRows(t *testing.T, tableName string) { }, } - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - for _, mod := range modifications { - updateSQL := fmt.Sprintf("UPDATE %s SET %s = $1 WHERE index = $2", qualifiedTableName, mod.field) - _, err = tx.Exec(ctx, updateSQL, mod.value, mod.indexVal) - require.NoError(t, err) - } - - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + env.withRepairModeTx(t, ctx, tx, func() { + for _, mod := range modifications { + updateSQL := fmt.Sprintf("UPDATE %s SET %s = $1 WHERE index = $2", qualifiedTableName, mod.field) + _, err = tx.Exec(ctx, updateSQL, mod.value, mod.indexVal) + require.NoError(t, err) + } + }) require.NoError(t, tx.Commit(ctx)) err = mtreeTask.DiffMtree() require.NoError(t, err, "DiffMtree should succeed") - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 + pairKey := env.ServiceN1 + "/" + env.ServiceN2 + if strings.Compare(env.ServiceN1, env.ServiceN2) > 0 { + pairKey = env.ServiceN2 + "/" + env.ServiceN1 } nodeDiffs, ok := mtreeTask.DiffResult.NodeDiffs[pairKey] require.True(t, ok, "Expected diffs for pair %s, but none found.", pairKey) expectedDiffCount := len(modifications) - require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[serviceN1]), "Expected %d modified rows on %s", expectedDiffCount, serviceN1) - require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[serviceN2]), "Expected %d modified rows on %s", expectedDiffCount, serviceN2) + require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[env.ServiceN1]), "Expected %d modified rows on %s", expectedDiffCount, env.ServiceN1) + require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[env.ServiceN2]), "Expected %d modified rows on %s", expectedDiffCount, env.ServiceN2) require.Equal(t, expectedDiffCount, mtreeTask.DiffResult.Summary.DiffRowsCount[pairKey], "Expected summary diff count to be %d", expectedDiffCount) for _, mod := range modifications { foundN1 := false foundN2 := false - for _, row := range nodeDiffs.Rows[serviceN1] { + for _, row := range nodeDiffs.Rows[env.ServiceN1] { if rIndex, ok := row.Get("index"); ok && rIndex.(int32) == int32(mod.indexVal) { foundN1 = true fieldVal, _ := row.Get(mod.field) @@ -1029,7 +992,7 @@ func testMerkleTreeDiffModifiedRows(t *testing.T, tableName string) { break } } - for _, row := range nodeDiffs.Rows[serviceN2] { + for _, row := range nodeDiffs.Rows[env.ServiceN2] { if rIndex, ok := row.Get("index"); ok && rIndex.(int32) == int32(mod.indexVal) { foundN2 = true fieldVal, _ := row.Get(mod.field) @@ -1042,11 +1005,11 @@ func testMerkleTreeDiffModifiedRows(t *testing.T, tableName string) { } } -func testMerkleTreeTeardown(t *testing.T, tableName string) { +func testMerkleTreeTeardown(t *testing.T, env *testEnv, tableName string) { ctx := context.Background() qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodes := []string{serviceN1, serviceN2} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes) + nodes := []string{env.ServiceN1, env.ServiceN2} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes) err := 
mtreeTask.MtreeInit() require.NoError(t, err, "MtreeInit should succeed") @@ -1058,7 +1021,7 @@ func testMerkleTreeTeardown(t *testing.T, tableName string) { cdcPubName := config.Cfg.MTree.CDC.PublicationName cdcSlotName := config.Cfg.MTree.CDC.SlotName - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { require.False(t, tableExists(t, ctx, pool, "ace_cdc_metadata", aceSchema), "Table 'ace_cdc_metadata' should NOT exist after teardown") require.False(t, publicationExists(t, ctx, pool, cdcPubName), "Publication '%s' should NOT exist after teardown", cdcPubName) require.False(t, replicationSlotExists(t, ctx, pool, cdcSlotName), "Replication slot '%s' should NOT exist after teardown", cdcSlotName) @@ -1072,6 +1035,11 @@ func testMerkleTreeTeardown(t *testing.T, tableName string) { // numeric values with different PostgreSQL scales (e.g., 3000.00 vs 3000.0) // produce identical merkle tree hashes and are NOT reported as differences. func TestMerkleTreeNumericScaleInvariance(t *testing.T) { + env := newSpockEnv() + testMerkleTreeNumericScaleInvariance(t, env) +} + +func testMerkleTreeNumericScaleInvariance(t *testing.T, env *testEnv) { ctx := context.Background() numericTable := "ace_numeric_scale_test" qualifiedTable := fmt.Sprintf("%s.%s", testSchema, numericTable) @@ -1083,12 +1051,12 @@ func TestMerkleTreeNumericScaleInvariance(t *testing.T) { amount NUMERIC, label TEXT )`, qualifiedTable) - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, createSQL) require.NoError(t, err, "failed to create numeric test table") } t.Cleanup(func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", qualifiedTable)) } }) @@ -1112,39 +1080,37 @@ func TestMerkleTreeNumericScaleInvariance(t *testing.T) { {5, "0.00", "0.0", "scale-diff-zero"}, } - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { tx, err := pool.Begin(ctx) require.NoError(t, err) defer func(tx pgx.Tx) { _ = tx.Rollback(ctx) }(tx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - for _, r := range rows { - amt := r.n1Amt - if pool == pgCluster.Node2Pool { - amt = r.n2Amt + env.withRepairModeTx(t, ctx, tx, func() { + for _, r := range rows { + amt := r.n1Amt + if pool == env.N2Pool { + amt = r.n2Amt + } + insertSQL := fmt.Sprintf("INSERT INTO %s (id, amount, label) VALUES ($1, %s, $2) ON CONFLICT (id) DO UPDATE SET amount = %s, label = $2", qualifiedTable, amt, amt) + _, err := tx.Exec(ctx, insertSQL, r.id, r.label) + require.NoError(t, err) } - insertSQL := fmt.Sprintf("INSERT INTO %s (id, amount, label) VALUES ($1, %s, $2) ON CONFLICT (id) DO UPDATE SET amount = %s, label = $2", qualifiedTable, amt, amt) - _, err := tx.Exec(ctx, insertSQL, r.id, r.label) - require.NoError(t, err) - } - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + }) require.NoError(t, tx.Commit(ctx)) } // Verify the scales really are different in storage by checking ::text var n1Text, n2Text string - err := pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT amount::text FROM %s WHERE id = 1", qualifiedTable)).Scan(&n1Text) + err := 
env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT amount::text FROM %s WHERE id = 1", qualifiedTable)).Scan(&n1Text) require.NoError(t, err) - err = pgCluster.Node2Pool.QueryRow(ctx, fmt.Sprintf("SELECT amount::text FROM %s WHERE id = 1", qualifiedTable)).Scan(&n2Text) + err = env.N2Pool.QueryRow(ctx, fmt.Sprintf("SELECT amount::text FROM %s WHERE id = 1", qualifiedTable)).Scan(&n2Text) require.NoError(t, err) t.Logf("Node1 amount::text = %q, Node2 amount::text = %q", n1Text, n2Text) // They should differ in scale representation require.NotEqual(t, n1Text, n2Text, "precondition: numeric ::text should differ between nodes (different scales)") // Now run merkle tree init + build + diff - nodes := []string{serviceN1, serviceN2} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTable, nodes) + nodes := []string{env.ServiceN1, env.ServiceN2} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTable, nodes) mtreeTask.BlockSize = 1000 // must be >= 1000 for RunChecks; table has 5 rows so 1 block err = mtreeTask.RunChecks(false) @@ -1165,16 +1131,16 @@ func TestMerkleTreeNumericScaleInvariance(t *testing.T) { // The key assertion: there should be ZERO differences because trim_scale // normalizes numeric values before hashing. - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 + pairKey := env.ServiceN1 + "/" + env.ServiceN2 + if strings.Compare(env.ServiceN1, env.ServiceN2) > 0 { + pairKey = env.ServiceN2 + "/" + env.ServiceN1 } if nodeDiffs, ok := mtreeTask.DiffResult.NodeDiffs[pairKey]; ok { - n1Rows := len(nodeDiffs.Rows[serviceN1]) - n2Rows := len(nodeDiffs.Rows[serviceN2]) - require.Equal(t, 0, n1Rows, "Expected 0 diff rows on %s, got %d (numeric scale should not cause diffs)", serviceN1, n1Rows) - require.Equal(t, 0, n2Rows, "Expected 0 diff rows on %s, got %d (numeric scale should not cause diffs)", serviceN2, n2Rows) + n1Rows := len(nodeDiffs.Rows[env.ServiceN1]) + n2Rows := len(nodeDiffs.Rows[env.ServiceN2]) + require.Equal(t, 0, n1Rows, "Expected 0 diff rows on %s, got %d (numeric scale should not cause diffs)", env.ServiceN1, n1Rows) + require.Equal(t, 0, n2Rows, "Expected 0 diff rows on %s, got %d (numeric scale should not cause diffs)", env.ServiceN2, n2Rows) } // If pairKey is not in NodeDiffs at all, that means zero diffs — which is correct. 
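The `withRepairModeTx` helper that the refactored tests call throughout is part of the shared testEnv harness and is not shown in this diff. A minimal sketch of its presumed shape, inferred from the call sites (the `spockAvailable` field is hypothetical, introduced here only for illustration): toggle `spock.repair_mode` around the callback on spock-backed clusters, and run the callback directly on native PG, which has no repair mode.

func (env *testEnv) withRepairModeTx(t *testing.T, ctx context.Context, tx pgx.Tx, fn func()) {
	t.Helper()
	// spockAvailable is an assumed field: native PG environments skip
	// repair mode entirely and just run the callback in the transaction.
	if env.spockAvailable {
		_, err := tx.Exec(ctx, "SELECT spock.repair_mode(true)")
		require.NoError(t, err)
	}
	fn()
	if env.spockAvailable {
		_, err := tx.Exec(ctx, "SELECT spock.repair_mode(false)")
		require.NoError(t, err)
	}
}

On a spock cluster this matches the inlined repair_mode toggles the diff removes; on the native stack the callback simply runs inside the transaction unchanged.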
@@ -1195,9 +1161,9 @@ func functionExists(t *testing.T, ctx context.Context, pool *pgxpool.Pool, funct var exists bool err := pool.QueryRow(ctx, ` SELECT EXISTS ( - SELECT 1 - FROM pg_proc p - JOIN pg_namespace n ON p.pronamespace = n.oid + SELECT 1 + FROM pg_proc p + JOIN pg_namespace n ON p.pronamespace = n.oid WHERE p.proname = $1 AND n.nspname = $2 )`, functionName, schemaName).Scan(&exists) require.NoError(t, err) @@ -1209,8 +1175,8 @@ func tableExists(t *testing.T, ctx context.Context, pool *pgxpool.Pool, tableNam var exists bool err := pool.QueryRow(ctx, ` SELECT EXISTS ( - SELECT 1 - FROM information_schema.tables + SELECT 1 + FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 )`, schemaName, tableName).Scan(&exists) require.NoError(t, err) diff --git a/tests/integration/native_pg_test.go b/tests/integration/native_pg_test.go new file mode 100644 index 0000000..abf93ed --- /dev/null +++ b/tests/integration/native_pg_test.go @@ -0,0 +1,629 @@ +// /////////////////////////////////////////////////////////////////////////// +// +// # ACE - Active Consistency Engine +// +// Copyright (C) 2023 - 2026, pgEdge (https://www.pgedge.com/) +// +// This software is released under the PostgreSQL License: +// https://opensource.org/license/postgresql +// +// /////////////////////////////////////////////////////////////////////////// + +package integration + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/docker/go-connections/nat" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/pgedge/ace/db/queries" + "github.com/pgedge/ace/internal/consistency/diff" + "github.com/pgedge/ace/pkg/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go/modules/compose" + "github.com/testcontainers/testcontainers-go/wait" +) + +const ( + nativeUser = "postgres" + nativePassword = "password" + nativeDBName = "testdb" + nativeServiceN1 = "native-n1" + nativeServiceN2 = "native-n2" + nativeContainerPort = "5432/tcp" + nativeComposeFile = "docker-compose-native.yaml" + nativeClusterName = "native_test_cluster" +) + +// nativeClusterState holds connections and config for vanilla PG test containers. +type nativeClusterState struct { + stack compose.ComposeStack + n1Host string + n1Port string + n1Pool *pgxpool.Pool + n2Host string + n2Port string + n2Pool *pgxpool.Pool +} + +func setupNativeCluster(t *testing.T) *nativeClusterState { + t.Helper() + ctx := context.Background() + + absCompose, err := filepath.Abs(nativeComposeFile) + require.NoError(t, err, "resolve native compose file path") + + identifier := strings.ToLower(fmt.Sprintf("ace_native_test_%d", time.Now().UnixNano())) + + waitN1 := wait.ForListeningPort(nat.Port(nativeContainerPort)). + WithStartupTimeout(startupTimeout). + WithPollInterval(5 * time.Second) + waitN2 := wait.ForListeningPort(nat.Port(nativeContainerPort)). + WithStartupTimeout(startupTimeout). + WithPollInterval(5 * time.Second) + + stack, err := compose.NewDockerComposeWith( + compose.StackIdentifier(identifier), + compose.WithStackFiles(absCompose), + ) + require.NoError(t, err, "create native compose stack") + + execErr := stack. + WaitForService(nativeServiceN1, waitN1). + WaitForService(nativeServiceN2, waitN2). 
+ Up(ctx, compose.Wait(true)) + require.NoError(t, execErr, "start native compose stack") + + state := &nativeClusterState{stack: stack} + + // Get mapped host/port for n1 + n1Container, err := stack.ServiceContainer(ctx, nativeServiceN1) + require.NoError(t, err, "get native-n1 container") + n1Host, err := n1Container.Host(ctx) + require.NoError(t, err, "get native-n1 host") + cPort, err := nat.NewPort("tcp", "5432") + require.NoError(t, err) + n1MappedPort, err := n1Container.MappedPort(ctx, cPort) + require.NoError(t, err, "get native-n1 mapped port") + state.n1Host = n1Host + state.n1Port = n1MappedPort.Port() + + // Get mapped host/port for n2 + n2Container, err := stack.ServiceContainer(ctx, nativeServiceN2) + require.NoError(t, err, "get native-n2 container") + n2Host, err := n2Container.Host(ctx) + require.NoError(t, err, "get native-n2 host") + n2MappedPort, err := n2Container.MappedPort(ctx, cPort) + require.NoError(t, err, "get native-n2 mapped port") + state.n2Host = n2Host + state.n2Port = n2MappedPort.Port() + + // Connect + state.n1Pool, err = connectToNode(state.n1Host, state.n1Port, nativeUser, nativePassword, nativeDBName) + require.NoError(t, err, "connect to native-n1") + state.n2Pool, err = connectToNode(state.n2Host, state.n2Port, nativeUser, nativePassword, nativeDBName) + require.NoError(t, err, "connect to native-n2") + + // Create pgcrypto extension on both nodes + for _, pool := range []*pgxpool.Pool{state.n1Pool, state.n2Pool} { + _, err = pool.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS pgcrypto") + require.NoError(t, err, "create pgcrypto extension") + } + + log.Printf("Native PG cluster ready: n1=%s:%s, n2=%s:%s", state.n1Host, state.n1Port, state.n2Host, state.n2Port) + return state +} + +func (s *nativeClusterState) teardown(t *testing.T) { + t.Helper() + if s.n1Pool != nil { + s.n1Pool.Close() + } + if s.n2Pool != nil { + s.n2Pool.Close() + } + if s.stack != nil { + execErr := s.stack.Down( + context.Background(), + compose.RemoveOrphans(true), + compose.RemoveVolumes(true), + ) + if execErr != nil { + t.Logf("Failed to tear down native compose stack: %v", execErr) + } + } + // Clean up diff files and cluster config + files, _ := filepath.Glob("*_diffs-*.json") + for _, f := range files { + os.Remove(f) + } + os.Remove(nativeClusterName + ".json") +} + +// writeClusterConfig writes a cluster config JSON that the diff/repair +// tasks use to discover nodes and credentials. +func (s *nativeClusterState) writeClusterConfig(t *testing.T) { + t.Helper() + cfg := types.ClusterConfig{ + JSONVersion: "1.0", + ClusterName: nativeClusterName, + LogLevel: "info", + UpdateDate: time.Now().Format(time.RFC3339), + PGEdge: struct { + PGVersion int `json:"pg_version"` + AutoStart string `json:"auto_start"` + Spock types.SpockConfig `json:"spock"` + Databases []types.Database `json:"databases"` + }{ + PGVersion: 17, + AutoStart: "yes", + Databases: []types.Database{ + { + DBName: nativeDBName, + DBUser: nativeUser, + DBPassword: nativePassword, + }, + }, + }, + NodeGroups: []types.NodeGroup{ + { + Name: nativeServiceN1, + IsActive: "yes", + PublicIP: s.n1Host, + Port: s.n1Port, + }, + { + Name: nativeServiceN2, + IsActive: "yes", + PublicIP: s.n2Host, + Port: s.n2Port, + }, + }, + } + + data, err := json.MarshalIndent(cfg, "", " ") + require.NoError(t, err, "marshal native cluster config") + require.NoError(t, os.WriteFile(nativeClusterName+".json", data, 0644), "write native cluster config") +} + +// TestNativePG runs the full suite of native PostgreSQL (no spock) tests. 
+// It starts its own vanilla postgres:17 Docker Compose stack and runs the +// same shared test logic used by the spock tests, via testEnv abstraction. +func TestNativePG(t *testing.T) { + state := setupNativeCluster(t) + t.Cleanup(func() { state.teardown(t) }) + state.writeClusterConfig(t) + + ctx := context.Background() + + // Create the customers table on both nodes and load initial data from CSV. + for _, pool := range []*pgxpool.Pool{state.n1Pool, state.n2Pool} { + require.NoError(t, createTestTable(ctx, pool, testSchema, "customers")) + } + + env := newNativeEnv(state) + + // Load CSV data so that shared tests (e.g. NoDifferences, DataOnlyOnNode1) + // have a known baseline. + csvPath, err := filepath.Abs(defaultCsvFilePath + "customers.csv") + require.NoError(t, err) + for _, pool := range env.pools() { + require.NoError(t, loadDataFromCSV(ctx, pool, testSchema, "customers", csvPath), + "load CSV into customers") + } + + // Create and load customers_1M table (needed by merkle tree CDC and split tests). + for _, pool := range env.pools() { + require.NoError(t, createTestTable(ctx, pool, testSchema, "customers_1M")) + } + csv1MPath, err := filepath.Abs(defaultCsvFilePath + "customers_1M.csv") + require.NoError(t, err) + for _, pool := range env.pools() { + require.NoError(t, loadDataFromCSV(ctx, pool, testSchema, "customers_1M", csv1MPath), + "load CSV into customers_1M") + } + + // ── Native-specific tests ───────────────────────────────────────────── + + t.Run("CheckSpockInstalled_ReturnsFalse", func(t *testing.T) { + installed, err := queries.CheckSpockInstalled(ctx, state.n1Pool) + require.NoError(t, err) + assert.False(t, installed, "spock should not be installed on vanilla PG") + }) + + t.Run("GetNodeOriginNames_NativeSubscription", func(t *testing.T) { + // Set up a real publication on n1 and subscription on n2 so that + // pg_replication_origin gets populated with a subscription-linked entry. + subName := "test_origin_sub" + pubName := "test_origin_pub" + + // Create a test table and publication on n1. + _, err := state.n1Pool.Exec(ctx, + "CREATE TABLE IF NOT EXISTS public.origin_test (id int PRIMARY KEY, val text)") + require.NoError(t, err) + _, err = state.n1Pool.Exec(ctx, + fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE public.origin_test", pubName)) + require.NoError(t, err) + + // Create the same table on n2 (subscription target). + _, err = state.n2Pool.Exec(ctx, + "CREATE TABLE IF NOT EXISTS public.origin_test (id int PRIMARY KEY, val text)") + require.NoError(t, err) + + // Create a subscription on n2 pointing at n1 via Docker-internal hostname. + connStr := fmt.Sprintf( + "host=%s port=5432 dbname=%s user=%s password=%s", + nativeServiceN1, nativeDBName, nativeUser, nativePassword) + _, err = state.n2Pool.Exec(ctx, + fmt.Sprintf("CREATE SUBSCRIPTION %s CONNECTION '%s' PUBLICATION %s", + subName, connStr, pubName)) + require.NoError(t, err) + + t.Cleanup(func() { + state.n2Pool.Exec(ctx, fmt.Sprintf("DROP SUBSCRIPTION IF EXISTS %s", subName)) + state.n1Pool.Exec(ctx, fmt.Sprintf("DROP PUBLICATION IF EXISTS %s", pubName)) + state.n1Pool.Exec(ctx, "DROP TABLE IF EXISTS public.origin_test") + state.n2Pool.Exec(ctx, "DROP TABLE IF EXISTS public.origin_test") + }) + + // GetNodeOriginNames on n2 should route to GetNativeNodeOriginNames + // (spock not installed) and return the subscription name as the value. 
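+ //
+ // Illustrative only; the real SQL lives in the query templates and is
+ // not shown in this diff. Native logical replication names a
+ // subscription's replication origin "pg_<subscription oid>", so the
+ // roident → subname mapping can presumably be derived along these lines:
+ //
+ //   SELECT ro.roident::text, s.subname
+ //   FROM pg_replication_origin ro
+ //   JOIN pg_subscription s ON ro.roname = 'pg_' || s.oid;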
+ names, err := queries.GetNodeOriginNames(ctx, state.n2Pool) + require.NoError(t, err, "GetNodeOriginNames should succeed on native PG with subscriptions") + require.NotEmpty(t, names, "should have at least one origin mapping") + + // Verify that the subscription name appears as a value in the map. + found := false + for _, name := range names { + if name == subName { + found = true + break + } + } + assert.True(t, found, + "expected subscription name %q in origin names map, got: %v", subName, names) + + // The key should be a numeric roident (parseable as int). + for id := range names { + _, parseErr := fmt.Sscanf(id, "%d", new(int)) + assert.NoError(t, parseErr, "origin ID %q should be numeric", id) + } + }) + + t.Run("SpockDiff_GracefulError", func(t *testing.T) { + task := diff.NewSpockDiffTask() + task.ClusterName = nativeClusterName + task.DBName = nativeDBName + task.Nodes = nativeServiceN1 + "," + nativeServiceN2 + task.Ctx = context.Background() + task.SkipDBUpdate = true + + err := task.RunChecks(false) + require.NoError(t, err, "spock-diff RunChecks should succeed (it just validates and connects)") + + err = task.ExecuteTask() + require.Error(t, err, "spock-diff should fail on vanilla PG") + assert.Contains(t, err.Error(), "spock extension", "error should mention spock extension") + }) + + t.Run("RepsetDiff_GracefulError", func(t *testing.T) { + task := diff.NewRepsetDiffTask() + task.ClusterName = nativeClusterName + task.DBName = nativeDBName + task.RepsetName = "default" + task.Nodes = "all" + task.SkipDBUpdate = true + + err := task.RunChecks(false) + require.Error(t, err, "repset-diff should fail on vanilla PG") + assert.Contains(t, err.Error(), "spock extension", "error should mention spock extension") + }) + + // ── Shared table-diff tests (simple PK) ─────────────────────────────── + + t.Run("TableDiffSimplePK", func(t *testing.T) { + t.Run("Customers", func(t *testing.T) { + runCustomerTableDiffTests(t, env) + }) + t.Run("MixedCaseIdentifiers", func(t *testing.T) { + testTableDiff_MixedCaseIdentifiers(t, env, false) + }) + t.Run("VariousDataTypes", func(t *testing.T) { + testTableDiff_VariousDataTypes(t, env, false) + }) + t.Run("UUIDColumn", func(t *testing.T) { + testTableDiff_UUIDColumn(t, env, false) + }) + t.Run("ByteaColumnSizeCheck", func(t *testing.T) { + testTableDiff_ByteaColumnSizeCheck(t, env, false) + }) + }) + + // ── Shared table-diff tests (composite PK) ─────────────────────────── + + t.Run("TableDiffCompositePK", func(t *testing.T) { + t.Run("Customers", func(t *testing.T) { + for _, pool := range env.pools() { + err := alterTableToCompositeKey(ctx, pool, env.Schema, "customers") + require.NoError(t, err) + } + t.Cleanup(func() { + for _, pool := range env.pools() { + err := revertTableToSimpleKey(ctx, pool, env.Schema, "customers") + require.NoError(t, err) + } + }) + runCustomerTableDiffTests(t, env) + }) + t.Run("MixedCaseIdentifiers", func(t *testing.T) { + testTableDiff_MixedCaseIdentifiers(t, env, true) + }) + t.Run("VariousDataTypes", func(t *testing.T) { + testTableDiff_VariousDataTypes(t, env, true) + }) + t.Run("UUIDColumn", func(t *testing.T) { + testTableDiff_UUIDColumn(t, env, true) + }) + t.Run("ByteaColumnSizeCheck", func(t *testing.T) { + testTableDiff_ByteaColumnSizeCheck(t, env, true) + }) + }) + + // ── Shared table-repair tests ───────────────────────────────────────── + + t.Run("TableRepair_UnidirectionalDefault", func(t *testing.T) { + testTableRepair_UnidirectionalDefault(t, env) + }) + t.Run("TableRepair_InsertOnly", func(t 
*testing.T) { + testTableRepair_InsertOnly(t, env) + }) + t.Run("TableRepair_UpsertOnly", func(t *testing.T) { + testTableRepair_UpsertOnly(t, env) + }) + t.Run("TableRepair_Bidirectional", func(t *testing.T) { + testTableRepair_Bidirectional(t, env) + }) + t.Run("TableRepair_DryRun", func(t *testing.T) { + testTableRepair_DryRun(t, env) + }) + t.Run("TableRepair_GenerateReport", func(t *testing.T) { + testTableRepair_GenerateReport(t, env) + }) + t.Run("TableRepair_FixNulls", func(t *testing.T) { + testTableRepair_FixNulls(t, env) + }) + t.Run("TableRepair_FixNulls_DryRun", func(t *testing.T) { + testTableRepair_FixNulls_DryRun(t, env) + }) + t.Run("TableRepair_FixNulls_BidirectionalUpdate", func(t *testing.T) { + testTableRepair_FixNulls_BidirectionalUpdate(t, env) + }) + t.Run("TableRepair_VariousDataTypes", func(t *testing.T) { + testTableRepair_VariousDataTypes(t, env) + }) + t.Run("TableRepair_TimestampAndTimeTypes", func(t *testing.T) { + testTableRepair_TimestampAndTimeTypes(t, env) + }) + t.Run("TableRepair_LargeBigintPK", func(t *testing.T) { + testTableRepair_LargeBigintPK(t, env) + }) + + // ── Merkle tree tests ──────────────────────────────────────────────── + + t.Run("MerkleTreeSimplePK", func(t *testing.T) { + runMerkleTreeTests(t, env, "customers") + }) + + t.Run("MerkleTreeCompositePK", func(t *testing.T) { + for _, pool := range env.pools() { + err := alterTableToCompositeKey(ctx, pool, env.Schema, "customers") + require.NoError(t, err) + } + t.Cleanup(func() { + for _, pool := range env.pools() { + err := revertTableToSimpleKey(ctx, pool, env.Schema, "customers") + require.NoError(t, err) + } + }) + runMerkleTreeTests(t, env, "customers") + }) + + t.Run("MerkleTreeNumericScaleInvariance", func(t *testing.T) { + testMerkleTreeNumericScaleInvariance(t, env) + }) + + // ── Native PG preserve-origin test ─────────────────────────────────── + // This test verifies the full diff → preserve-origin repair → verify + // cycle on native PG with real logical replication, including that + // GetNodeOriginNames returns subscription names and that repaired rows + // retain their original replication origin. + + t.Run("TableRepair_PreserveOrigin_NativePG", func(t *testing.T) { + testNativePreserveOrigin(t, state, env) + }) +} + +// getNativeReplicationOrigin retrieves the replication origin for a row on +// native PG (no spock). Uses pg_xact_commit_timestamp_origin to get the +// roident, then resolves it via pg_replication_origin.roname. +func getNativeReplicationOrigin(t *testing.T, ctx context.Context, pool *pgxpool.Pool, qualifiedTableName string, id int) string { + t.Helper() + + var roidentStr *string + query := fmt.Sprintf( + `SELECT (pg_xact_commit_timestamp_origin(xmin)).roident::text FROM %s WHERE id = $1`, + qualifiedTableName) + err := pool.QueryRow(ctx, query, id).Scan(&roidentStr) + if err != nil || roidentStr == nil || *roidentStr == "" || *roidentStr == "0" { + return "" + } + + var originName string + err = pool.QueryRow(ctx, + "SELECT roname FROM pg_replication_origin WHERE roident::text = $1", *roidentStr).Scan(&originName) + if err == nil && originName != "" { + return originName + } + + return *roidentStr +} + +// testNativePreserveOrigin verifies origin tracking on native PG with real +// logical replication: +// 1. Set up logical replication (publication on n1, subscription on n2) +// 2. Insert data on n1, wait for streaming replication to n2 +// 3. Verify GetNodeOriginNames maps roident → subscription name +// 4. 
Verify replicated rows on n2 have origin tracked via pg_xact_commit_timestamp_origin +// 5. Verify the origin resolves to the subscription's pg_replication_origin entry +// 6. Delete rows on n2, run diff + repair, verify rows restored +// +// Note: preserve-origin cannot fully restore origins in a 2-node setup because +// the source-of-truth node (n1) has origin="local" for rows it wrote. A 3-node +// setup (like the spock PreserveOrigin test) is needed for full origin preservation. +func testNativePreserveOrigin(t *testing.T, state *nativeClusterState, env *testEnv) { + ctx := context.Background() + tableName := "native_preserve_origin_test" + qualifiedTableName := fmt.Sprintf("public.%s", tableName) + subName := "preserve_origin_sub" + pubName := "preserve_origin_pub" + + // Create table on both nodes. + for _, pool := range []*pgxpool.Pool{state.n1Pool, state.n2Pool} { + _, err := pool.Exec(ctx, fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %s (id INT PRIMARY KEY, data TEXT)", qualifiedTableName)) + require.NoError(t, err) + } + + // Create publication on n1 and subscription on n2. + // Use copy_data=false so the initial table sync doesn't use a transient + // replication origin (which PG deletes after sync, leaving rows with a + // defunct roident). With copy_data=false, data inserted after the + // subscription starts streaming uses the subscription's main origin. + _, err := state.n1Pool.Exec(ctx, fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s", pubName, qualifiedTableName)) + require.NoError(t, err) + + connStr := fmt.Sprintf("host=%s port=5432 dbname=%s user=%s password=%s", + nativeServiceN1, nativeDBName, nativeUser, nativePassword) + _, err = state.n2Pool.Exec(ctx, fmt.Sprintf( + "CREATE SUBSCRIPTION %s CONNECTION '%s' PUBLICATION %s WITH (copy_data = false)", + subName, connStr, pubName)) + require.NoError(t, err) + + t.Cleanup(func() { + state.n2Pool.Exec(ctx, fmt.Sprintf("DROP SUBSCRIPTION IF EXISTS %s", subName)) + state.n1Pool.Exec(ctx, fmt.Sprintf("DROP PUBLICATION IF EXISTS %s", pubName)) + for _, pool := range []*pgxpool.Pool{state.n1Pool, state.n2Pool} { + pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) + } + }) + + // Brief pause for the subscription to connect and start streaming. + time.Sleep(2 * time.Second) + + // Insert data on n1 AFTER subscription is streaming. 
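+ // Rows committed before CREATE SUBSCRIPTION are skipped entirely
+ // (copy_data=false); rows committed afterwards are retained by the slot
+ // and applied on n2 under the subscription's replication origin. A
+ // sturdier alternative to the fixed sleep above would be to poll for the
+ // apply worker, e.g. (illustrative sketch):
+ //
+ //   SELECT pid IS NOT NULL FROM pg_stat_subscription
+ //   WHERE subname = 'preserve_origin_sub';
+ //
+ // The origin checks below also assume track_commit_timestamp is enabled
+ // on both nodes, since pg_xact_commit_timestamp_origin() returns NULLs
+ // without it.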
+ sampleIDs := []int{1, 2, 3, 4, 5} + for _, id := range sampleIDs { + _, err := state.n1Pool.Exec(ctx, + fmt.Sprintf("INSERT INTO %s (id, data) VALUES ($1, $2)", qualifiedTableName), + id, fmt.Sprintf("row_%d", id)) + require.NoError(t, err) + } + + assertEventually(t, 30*time.Second, func() error { + var count int + if err := state.n2Pool.QueryRow(ctx, + fmt.Sprintf("SELECT count(*) FROM %s", qualifiedTableName)).Scan(&count); err != nil { + return err + } + if count < len(sampleIDs) { + return fmt.Errorf("expected %d rows on n2, got %d", len(sampleIDs), count) + } + return nil + }) + log.Println("Replication complete: all rows present on n2 (via streaming)") + + // --- Verify GetNodeOriginNames maps roident → subscription name --- + names, err := queries.GetNodeOriginNames(ctx, state.n2Pool) + require.NoError(t, err) + found := false + var subRoident string + for id, name := range names { + if name == subName { + found = true + subRoident = id + break + } + } + require.True(t, found, "GetNodeOriginNames should contain subscription %q, got: %v", subName, names) + log.Printf("GetNodeOriginNames on n2: %v (subscription roident=%s)", names, subRoident) + + // --- Verify replicated rows on n2 have non-local origin --- + for _, id := range sampleIDs { + origin := getNativeReplicationOrigin(t, ctx, state.n2Pool, qualifiedTableName, id) + require.NotEmpty(t, origin, + "Row %d on n2 should have a replication origin (was replicated from n1)", id) + log.Printf("Row %d on n2: origin=%s", id, origin) + } + + // --- Verify rows on n1 are "local" origin --- + for _, id := range sampleIDs { + origin := getNativeReplicationOrigin(t, ctx, state.n1Pool, qualifiedTableName, id) + assert.Empty(t, origin, + "Row %d on n1 should have local origin (roident=0), got %q", id, origin) + } + log.Println("Origin tracking verified: n2 rows have subscription origin, n1 rows are local") + + // --- Simulate data loss on n2 and verify basic repair works --- + log.Println("Simulating data loss on n2...") + tx, err := state.n2Pool.Begin(ctx) + require.NoError(t, err) + _, err = tx.Exec(ctx, "SET session_replication_role = 'replica'") + require.NoError(t, err) + for _, id := range sampleIDs { + _, err = tx.Exec(ctx, fmt.Sprintf("DELETE FROM %s WHERE id = $1", qualifiedTableName), id) + require.NoError(t, err) + } + require.NoError(t, tx.Commit(ctx)) + + // Run table-diff. + diffTask := env.newTableDiffTask(t, qualifiedTableName, []string{nativeServiceN1, nativeServiceN2}) + require.NoError(t, diffTask.RunChecks(false)) + require.NoError(t, diffTask.ExecuteTask()) + diffFile := getLatestDiffFile(t) + require.NotEmpty(t, diffFile) + + // Run repair (recovery mode). + repairTask := env.newTableRepairTask(nativeServiceN1, qualifiedTableName, diffFile) + repairTask.RecoveryMode = true + + err = repairTask.Run(false) + require.NoError(t, err) + if repairTask.TaskStatus == "FAILED" { + t.Fatalf("Repair failed: %s", repairTask.TaskContext) + } + + // Verify all rows are restored. + var count int + err = state.n2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s", qualifiedTableName)).Scan(&count) + require.NoError(t, err) + require.Equal(t, len(sampleIDs), count, "All rows should be restored after repair") + + // Verify row content matches. 
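+ // Origins are not re-checked after the repair: as noted in the function
+ // comment, a 2-node setup cannot fully preserve them, so only the
+ // restored row content is asserted here.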
+ for _, id := range sampleIDs { + var data string + err := state.n2Pool.QueryRow(ctx, + fmt.Sprintf("SELECT data FROM %s WHERE id = $1", qualifiedTableName), id).Scan(&data) + require.NoError(t, err) + assert.Equal(t, fmt.Sprintf("row_%d", id), data, "Row %d data mismatch", id) + } + log.Println("Native PG origin tracking and repair verified successfully") +} diff --git a/tests/integration/table_diff_test.go b/tests/integration/table_diff_test.go index 500f47b..a39fd2a 100644 --- a/tests/integration/table_diff_test.go +++ b/tests/integration/table_diff_test.go @@ -15,7 +15,6 @@ import ( "context" "fmt" "log" - "math" "os" "path/filepath" "strings" @@ -24,7 +23,6 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/internal/consistency/diff" - "github.com/pgedge/ace/pkg/types" "github.com/stretchr/testify/require" ) @@ -33,46 +31,25 @@ func newTestTableDiffTask( qualifiedTableName string, nodes []string, ) *diff.TableDiffTask { - task := diff.NewTableDiffTask() - task.ClusterName = "test_cluster" - task.DBName = dbName - task.QualifiedTableName = qualifiedTableName - task.Nodes = strings.Join(nodes, ",") - task.Output = "json" - task.BlockSize = 1000 - task.CompareUnitSize = 100 - task.ConcurrencyFactor = 1 - task.MaxDiffRows = math.MaxInt64 - - task.DiffResult = types.DiffOutput{ - NodeDiffs: make(map[string]types.DiffByNodePair), - Summary: types.DiffSummary{ - Nodes: nodes, - BlockSize: task.BlockSize, - CompareUnitSize: task.CompareUnitSize, - ConcurrencyFactor: task.ConcurrencyFactor, - DiffRowsCount: make(map[string]int), - }, - } - - return task + return newSpockEnv().newTableDiffTask(t, qualifiedTableName, nodes) } func TestTableDiffSimplePK(t *testing.T) { + env := newSpockEnv() t.Run("Customers", func(t *testing.T) { - runCustomerTableDiffTests(t) + runCustomerTableDiffTests(t, env) }) t.Run("MixedCaseIdentifiers", func(t *testing.T) { - testTableDiff_MixedCaseIdentifiers(t, false) + testTableDiff_MixedCaseIdentifiers(t, env, false) }) t.Run("VariousDataTypes", func(t *testing.T) { - testTableDiff_VariousDataTypes(t, false) + testTableDiff_VariousDataTypes(t, newSpockEnv(), false) }) t.Run("UUIDColumn", func(t *testing.T) { - testTableDiff_UUIDColumn(t, false) + testTableDiff_UUIDColumn(t, newSpockEnv(), false) }) t.Run("ByteaColumnSizeCheck", func(t *testing.T) { - testTableDiff_ByteaColumnSizeCheck(t, false) + testTableDiff_ByteaColumnSizeCheck(t, newSpockEnv(), false) }) t.Run("WithSpockMetadata", func(t *testing.T) { testTableDiff_WithSpockMetadata(t, false) @@ -80,54 +57,55 @@ func TestTableDiffSimplePK(t *testing.T) { } func TestTableDiffCompositePK(t *testing.T) { + env := newSpockEnv() t.Run("Customers", func(t *testing.T) { ctx := context.Background() tableName := "customers" - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, err) } t.Cleanup(func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, err) } }) - runCustomerTableDiffTests(t) + runCustomerTableDiffTests(t, env) }) t.Run("MixedCaseIdentifiers", func(t *testing.T) { - 
testTableDiff_MixedCaseIdentifiers(t, true) + testTableDiff_MixedCaseIdentifiers(t, env, true) }) t.Run("VariousDataTypes", func(t *testing.T) { - testTableDiff_VariousDataTypes(t, true) + testTableDiff_VariousDataTypes(t, newSpockEnv(), true) }) t.Run("UUIDColumn", func(t *testing.T) { - testTableDiff_UUIDColumn(t, true) + testTableDiff_UUIDColumn(t, newSpockEnv(), true) }) t.Run("ByteaColumnSizeCheck", func(t *testing.T) { - testTableDiff_ByteaColumnSizeCheck(t, true) + testTableDiff_ByteaColumnSizeCheck(t, newSpockEnv(), true) }) t.Run("WithSpockMetadata", func(t *testing.T) { testTableDiff_WithSpockMetadata(t, true) }) } -func runCustomerTableDiffTests(t *testing.T) { - t.Run("NoDifferences", testTableDiff_NoDifferences) - t.Run("DataOnlyOnNode1", testTableDiff_DataOnlyOnNode1) - t.Run("DataOnlyOnNode2", testTableDiff_DataOnlyOnNode2) - t.Run("ModifiedRows", testTableDiff_ModifiedRows) - t.Run("TableFiltering", testTableDiff_TableFiltering) - t.Run("TableFilterNoRows", testTableDiff_TableFilterNoRows) - t.Run("MaxDiffRowsLimit", testTableDiff_MaxDiffRowsLimit) +func runCustomerTableDiffTests(t *testing.T, env *testEnv) { + t.Run("NoDifferences", func(t *testing.T) { testTableDiff_NoDifferences(t, env) }) + t.Run("DataOnlyOnNode1", func(t *testing.T) { testTableDiff_DataOnlyOnNode1(t, env) }) + t.Run("DataOnlyOnNode2", func(t *testing.T) { testTableDiff_DataOnlyOnNode2(t, env) }) + t.Run("ModifiedRows", func(t *testing.T) { testTableDiff_ModifiedRows(t, env) }) + t.Run("TableFiltering", func(t *testing.T) { testTableDiff_TableFiltering(t, env) }) + t.Run("TableFilterNoRows", func(t *testing.T) { testTableDiff_TableFilterNoRows(t, env) }) + t.Run("MaxDiffRowsLimit", func(t *testing.T) { testTableDiff_MaxDiffRowsLimit(t, env) }) } -func testTableDiff_NoDifferences(t *testing.T) { - resetSharedTable(t, "customers") +func testTableDiff_NoDifferences(t *testing.T, env *testEnv) { + env.resetSharedTable(t, "customers") tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) err := tdTask.RunChecks(false) if err != nil { @@ -160,13 +138,13 @@ func testTableDiff_NoDifferences(t *testing.T) { log.Println("TestTableDiff_NoDifferences completed.") } -func testTableDiff_DataOnlyOnNode1(t *testing.T) { +func testTableDiff_DataOnlyOnNode1(t *testing.T, env *testEnv) { ctx := context.Background() tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -174,32 +152,26 @@ func testTableDiff_DataOnlyOnNode1(t *testing.T) { }) // Truncate the table on the second node to create the diff - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) if err != nil { - t.Fatalf("Failed to begin transaction on node %s: %v", serviceN2, err) + t.Fatalf("Failed to begin transaction on node %s: %v", env.ServiceN2, err) } defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - if err != nil { - t.Fatalf("Failed 
to enable spock repair mode on node %s: %v", serviceN2, err)
- }
- _, err = tx.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", testSchema, tableName))
- if err != nil {
- t.Fatalf("Failed to truncate table %s on node2: %v", qualifiedTableName, err)
- }
- _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)")
- if err != nil {
- t.Fatalf("Failed to disable spock repair mode on node %s: %v", serviceN2, err)
- }
+ env.withRepairModeTx(t, ctx, tx, func() {
+ _, err = tx.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", env.Schema, tableName))
+ if err != nil {
+ t.Fatalf("Failed to truncate table %s on node2: %v", qualifiedTableName, err)
+ }
+ })
 if err = tx.Commit(ctx); err != nil {
- t.Fatalf("Failed to commit transaction on node %s: %v", serviceN2, err)
+ t.Fatalf("Failed to commit transaction on node %s: %v", env.ServiceN2, err)
 }
- log.Printf("Data loaded only into %s for table %s", serviceN1, qualifiedTableName)
+ log.Printf("Data now present only on %s for table %s", env.ServiceN1, qualifiedTableName)
- nodesToCompare := []string{serviceN1, serviceN2}
- tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare)
+ nodesToCompare := []string{env.ServiceN1, env.ServiceN2}
+ tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare)
 err = tdTask.RunChecks(false)
 if err != nil {
@@ -210,10 +182,7 @@ func testTableDiff_DataOnlyOnNode1(t *testing.T) {
 t.Fatalf("ExecuteTask failed: %v", err)
 }
- pairKey := serviceN1 + "/" + serviceN2
- if strings.Compare(serviceN1, serviceN2) > 0 {
- pairKey = serviceN2 + "/" + serviceN1
- }
+ pairKey := env.pairKey()
 nodeDiffs, ok := tdTask.DiffResult.NodeDiffs[pairKey]
 if !ok {
@@ -225,26 +194,26 @@ func testTableDiff_DataOnlyOnNode1(t *testing.T) {
 }
 var expectedDiffCount int
- err = pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s", qualifiedTableName)).Scan(&expectedDiffCount)
+ err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s", qualifiedTableName)).Scan(&expectedDiffCount)
 if err != nil {
 t.Fatalf("Failed to count rows in %s on node1: %v", qualifiedTableName, err)
 }
- node1OnlyRows := nodeDiffs.Rows[serviceN1]
+ node1OnlyRows := nodeDiffs.Rows[env.ServiceN1]
 if len(node1OnlyRows) != expectedDiffCount {
 t.Errorf(
 "Expected %d rows only on %s, but got %d",
 expectedDiffCount,
- serviceN1,
+ env.ServiceN1,
 len(node1OnlyRows),
 )
 }
- node2OnlyRows := nodeDiffs.Rows[serviceN2]
+ node2OnlyRows := nodeDiffs.Rows[env.ServiceN2]
 if len(node2OnlyRows) != 0 {
 t.Errorf(
 "Expected 0 rows only on %s, but got %d",
- serviceN2,
+ env.ServiceN2,
 len(node2OnlyRows),
 )
 }
@@ -261,45 +230,39 @@ func testTableDiff_DataOnlyOnNode1(t *testing.T) {
 log.Println("TestTableDiff_DataOnlyOnNode1 completed.")
 }
-func testTableDiff_DataOnlyOnNode2(t *testing.T) {
+func testTableDiff_DataOnlyOnNode2(t *testing.T, env *testEnv) {
 ctx := context.Background()
 tableName := "customers"
- qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName)
+ qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName)
 t.Cleanup(func() {
- repairTable(t, qualifiedTableName, serviceN2)
+ env.repairTable(t, qualifiedTableName, env.ServiceN2)
 files, _ := filepath.Glob("*_diffs-*.json")
 for _, f := range files {
 os.Remove(f)
 }
 })
- tx, err := pgCluster.Node1Pool.Begin(ctx)
+ tx, err := env.N1Pool.Begin(ctx)
 if err != nil {
- t.Fatalf("Failed to begin transaction on node %s: %v", serviceN1, err)
+ t.Fatalf("Failed to begin transaction on node %s: %v", env.ServiceN1, err)
 }
 defer tx.Rollback(ctx)
- _, err = tx.Exec(ctx, "SELECT 
spock.repair_mode(true)")
- if err != nil {
- t.Fatalf("Failed to enable spock repair mode on node %s: %v", serviceN1, err)
- }
- _, err = tx.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", testSchema, tableName))
- if err != nil {
- t.Fatalf("Failed to truncate table %s on node1: %v", qualifiedTableName, err)
- }
- _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)")
- if err != nil {
- t.Fatalf("Failed to disable spock repair mode on node %s: %v", serviceN1, err)
- }
+ env.withRepairModeTx(t, ctx, tx, func() {
+ _, err = tx.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", env.Schema, tableName))
+ if err != nil {
+ t.Fatalf("Failed to truncate table %s on node1: %v", qualifiedTableName, err)
+ }
+ })
 if err = tx.Commit(ctx); err != nil {
- t.Fatalf("Failed to commit transaction on node %s: %v", serviceN1, err)
+ t.Fatalf("Failed to commit transaction on node %s: %v", env.ServiceN1, err)
 }
- log.Printf("Data loaded only into %s for table %s", serviceN2, qualifiedTableName)
+ log.Printf("Data now present only on %s for table %s", env.ServiceN2, qualifiedTableName)
- nodesToCompare := []string{serviceN1, serviceN2}
- tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare)
+ nodesToCompare := []string{env.ServiceN1, env.ServiceN2}
+ tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare)
 err = tdTask.RunChecks(false)
 if err != nil {
@@ -310,10 +273,7 @@ func testTableDiff_DataOnlyOnNode2(t *testing.T) {
 t.Fatalf("ExecuteTask failed: %v", err)
 }
- pairKey := serviceN1 + "/" + serviceN2
- if strings.Compare(serviceN1, serviceN2) > 0 {
- pairKey = serviceN2 + "/" + serviceN1
- }
+ pairKey := env.pairKey()
 nodeDiffs, ok := tdTask.DiffResult.NodeDiffs[pairKey]
 if !ok {
@@ -325,26 +285,26 @@ func testTableDiff_DataOnlyOnNode2(t *testing.T) {
 }
 var expectedDiffCount int
- err = pgCluster.Node2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s", qualifiedTableName)).Scan(&expectedDiffCount)
+ err = env.N2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s", qualifiedTableName)).Scan(&expectedDiffCount)
 if err != nil {
 t.Fatalf("Failed to count rows in %s on node2: %v", qualifiedTableName, err)
 }
- node1OnlyRows := nodeDiffs.Rows[serviceN1]
+ node1OnlyRows := nodeDiffs.Rows[env.ServiceN1]
 if len(node1OnlyRows) != 0 {
 t.Errorf(
 "Expected 0 rows only on %s, but got %d",
- serviceN1,
+ env.ServiceN1,
 len(node1OnlyRows),
 )
 }
- node2OnlyRows := nodeDiffs.Rows[serviceN2]
+ node2OnlyRows := nodeDiffs.Rows[env.ServiceN2]
 if len(node2OnlyRows) != expectedDiffCount {
 t.Errorf(
 "Expected %d rows only on %s, but got %d",
 expectedDiffCount,
- serviceN2,
+ env.ServiceN2,
 len(node2OnlyRows),
 )
 }
@@ -361,13 +321,13 @@ func testTableDiff_DataOnlyOnNode2(t *testing.T) {
 log.Println("TestTableDiff_DataOnlyOnNode2 completed.")
 }
-func testTableDiff_ModifiedRows(t *testing.T) {
+func testTableDiff_ModifiedRows(t *testing.T, env *testEnv) {
 ctx := context.Background()
 tableName := "customers"
- qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName)
+ qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName)
 t.Cleanup(func() {
- repairTable(t, qualifiedTableName, serviceN1)
+ env.repairTable(t, qualifiedTableName, env.ServiceN1)
 files, _ := filepath.Glob("*_diffs-*.json")
 for _, f := range files {
 os.Remove(f)
@@ -391,50 +351,44 @@ func testTableDiff_ModifiedRows(t *testing.T) {
 },
 }
- tx, err := pgCluster.Node2Pool.Begin(ctx)
+ tx, err := env.N2Pool.Begin(ctx)
 if err != nil {
- t.Fatalf("Failed to begin transaction on node %s: %v", serviceN2, err)
+ 
t.Fatalf("Failed to begin transaction on node %s: %v", env.ServiceN2, err) } defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - if err != nil { - t.Fatalf("Failed to enable spock repair mode on node %s: %v", serviceN2, err) - } - for _, mod := range modifications { - updateSQL := fmt.Sprintf( - "UPDATE %s.%s SET %s = $1 WHERE index = $2", - testSchema, - tableName, - mod.field, - ) - _, err := tx.Exec(ctx, updateSQL, mod.value, mod.indexVal) - if err != nil { - t.Fatalf( - "Failed to update row with index %d on node %s: %v", - mod.indexVal, - serviceN2, - err, + env.withRepairModeTx(t, ctx, tx, func() { + for _, mod := range modifications { + updateSQL := fmt.Sprintf( + "UPDATE %s.%s SET %s = $1 WHERE index = $2", + env.Schema, + tableName, + mod.field, ) + _, err := tx.Exec(ctx, updateSQL, mod.value, mod.indexVal) + if err != nil { + t.Fatalf( + "Failed to update row with index %d on node %s: %v", + mod.indexVal, + env.ServiceN2, + err, + ) + } } - } - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - if err != nil { - t.Fatalf("Failed to disable spock repair mode on node %s: %v", serviceN2, err) - } + }) if err = tx.Commit(ctx); err != nil { - t.Fatalf("Failed to commit transaction on node %s: %v", serviceN2, err) + t.Fatalf("Failed to commit transaction on node %s: %v", env.ServiceN2, err) } log.Printf( "%d rows modified on %s for table %s", len(modifications), - serviceN2, + env.ServiceN2, qualifiedTableName, ) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) err = tdTask.RunChecks(false) if err != nil { @@ -445,10 +399,7 @@ func testTableDiff_ModifiedRows(t *testing.T) { t.Fatalf("ExecuteTask failed: %v", err) } - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } + pairKey := env.pairKey() nodeDiffs, ok := tdTask.DiffResult.NodeDiffs[pairKey] if !ok { @@ -459,26 +410,26 @@ func testTableDiff_ModifiedRows(t *testing.T) { ) } - if len(nodeDiffs.Rows[serviceN1]) != len(modifications) { + if len(nodeDiffs.Rows[env.ServiceN1]) != len(modifications) { t.Errorf( "Expected %d modified rows to be reported for %s (original values), but got %d. Rows: %+v", len( modifications, ), - serviceN1, - len(nodeDiffs.Rows[serviceN1]), - nodeDiffs.Rows[serviceN1], + env.ServiceN1, + len(nodeDiffs.Rows[env.ServiceN1]), + nodeDiffs.Rows[env.ServiceN1], ) } - if len(nodeDiffs.Rows[serviceN2]) != len(modifications) { + if len(nodeDiffs.Rows[env.ServiceN2]) != len(modifications) { t.Errorf( "Expected %d modified rows to be reported for %s (modified values), but got %d. 
Rows: %+v", len( modifications, ), - serviceN2, - len(nodeDiffs.Rows[serviceN2]), - nodeDiffs.Rows[serviceN2], + env.ServiceN2, + len(nodeDiffs.Rows[env.ServiceN2]), + nodeDiffs.Rows[env.ServiceN2], ) } @@ -495,7 +446,7 @@ func testTableDiff_ModifiedRows(t *testing.T) { for _, mod := range modifications { found := false - for _, rowN2 := range nodeDiffs.Rows[serviceN2] { + for _, rowN2 := range nodeDiffs.Rows[env.ServiceN2] { if indexVal, ok := rowN2.Get("index"); ok && indexVal == int32(mod.indexVal) { actualModifiedValue, _ := rowN2.Get(mod.field) @@ -516,7 +467,7 @@ func testTableDiff_ModifiedRows(t *testing.T) { t.Errorf( "Modified row with Index %d not found in diff results for node %s", mod.indexVal, - serviceN2, + env.ServiceN2, ) } } @@ -524,10 +475,10 @@ func testTableDiff_ModifiedRows(t *testing.T) { log.Println("TestTableDiff_ModifiedRows completed.") } -func testTableDiff_MixedCaseIdentifiers(t *testing.T, compositeKey bool) { +func testTableDiff_MixedCaseIdentifiers(t *testing.T, env *testEnv, compositeKey bool) { ctx := context.Background() tableName := "CustomersMixedCase" - qualifiedTableName := fmt.Sprintf("%s.\"%s\"", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.\"%s\"", env.Schema, tableName) compositeKeyPart := "" if compositeKey { @@ -542,25 +493,22 @@ func testTableDiff_MixedCaseIdentifiers(t *testing.T, compositeKey bool) { "LastName" VARCHAR(100), "EmailAddress" VARCHAR(100), PRIMARY KEY("ID"%s) -);`, testSchema, qualifiedTableName, compositeKeyPart) +);`, env.Schema, qualifiedTableName, compositeKeyPart) - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, createMixedCaseTableSQL) if err != nil { t.Fatalf( - "Failed to create mixed-case table %s on node %s: %v", + "Failed to create mixed-case table %s: %v", qualifiedTableName, - nodeName, err, ) } _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) if err != nil { t.Fatalf( - "Failed to truncate mixed-case table %s on node %s: %v", + "Failed to truncate mixed-case table %s: %v", qualifiedTableName, - nodeName, err, ) } @@ -568,7 +516,7 @@ func testTableDiff_MixedCaseIdentifiers(t *testing.T, compositeKey bool) { log.Printf("Mixed-case table %s created on both nodes", qualifiedTableName) t.Cleanup(func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) if err != nil { t.Logf("Failed to drop test table %s: %v", qualifiedTableName, err) @@ -600,30 +548,30 @@ func testTableDiff_MixedCaseIdentifiers(t *testing.T, compositeKey bool) { qualifiedTableName, ) for _, row := range commonRows { - _, err := pgCluster.Node1Pool.Exec(ctx, insertSQL, + _, err := env.N1Pool.Exec(ctx, insertSQL, row["ID"], row["FirstName"], row["LastName"], row["EmailAddress"]) if err != nil { t.Fatalf( "Failed to insert data into mixed-case table %s on node %s: %v", qualifiedTableName, - serviceN1, err) + env.ServiceN1, err) } - _, err = pgCluster.Node2Pool.Exec(ctx, insertSQL, + _, err = env.N2Pool.Exec(ctx, insertSQL, row["ID"], row["FirstName"], row["LastName"], row["EmailAddress"]) if err != nil { t.Fatalf( "Failed to insert data into mixed-case table %s on node %s: %v", qualifiedTableName, - serviceN2, err) + env.ServiceN2, err) 
} } log.Printf("Data loaded into mixed-case table %s on both nodes", qualifiedTableName) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask( + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask( t, - fmt.Sprintf("%s.%s", testSchema, tableName), + fmt.Sprintf("%s.%s", env.Schema, tableName), nodesToCompare, ) @@ -658,10 +606,10 @@ func testTableDiff_MixedCaseIdentifiers(t *testing.T, compositeKey bool) { log.Println("TestTableDiff_MixedCaseIdentifiers completed.") } -func testTableDiff_VariousDataTypes(t *testing.T, compositeKey bool) { +func testTableDiff_VariousDataTypes(t *testing.T, env *testEnv, compositeKey bool) { ctx := context.Background() tableName := "data_type_test_table" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) compositeKeyPart := "" if compositeKey { @@ -690,33 +638,38 @@ CREATE TABLE IF NOT EXISTS %s.%s ( col_bytea BYTEA, col_int_array INT[], PRIMARY KEY(id%s) -);`, testSchema, testSchema, tableName, compositeKeyPart) +);`, env.Schema, env.Schema, tableName, compositeKeyPart) - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) + nodeNames := []string{env.ServiceN1, env.ServiceN2} + for i, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + nodeName := nodeNames[i] _, err := pool.Exec(ctx, createDataTypeTableSQL) if err != nil { t.Fatalf("Failed to create data_type_test_table on node %s: %v", nodeName, err) } - _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", testSchema, tableName)) + _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", env.Schema, tableName)) if err != nil { t.Fatalf("Failed to truncate data_type_test_table on node %s: %v", nodeName, err) } - addToRepSetSQL := fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName) - _, err = pool.Exec(ctx, addToRepSetSQL) - if err != nil { - t.Fatalf("Failed to add table to replication set on n1: %v", err) + if env.HasSpock { + addToRepSetSQL := fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName) + _, err = pool.Exec(ctx, addToRepSetSQL) + if err != nil { + t.Fatalf("Failed to add table to replication set on %s: %v", nodeName, err) + } } } log.Printf("Table %s created on both nodes", qualifiedTableName) t.Cleanup(func() { - removeFromRepSetSQL := fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName) - _, err := pgCluster.Node1Pool.Exec(ctx, removeFromRepSetSQL) - if err != nil { - t.Logf("cleanup: failed to remove table from replication set: %v", err) + if env.HasSpock { + removeFromRepSetSQL := fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName) + _, err := env.N1Pool.Exec(ctx, removeFromRepSetSQL) + if err != nil { + t.Logf("cleanup: failed to remove table from replication set: %v", err) + } } - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) if err != nil { t.Logf("Failed to drop test table %s: %v", qualifiedTableName, err) @@ -771,55 +724,49 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - if err != nil { - t.Fatalf("Failed to enable spock repair 
mode: %v", err) - } - _, err = tx.Exec( - ctx, - fmt.Sprintf(insertSQLTemplate, testSchema, tableName), - data["id"], - data["col_smallint"], - data["col_integer"], - data["col_bigint"], - data["col_numeric"], - data["col_real"], - data["col_double"], - data["col_varchar"], - data["col_text"], - data["col_char"], - data["col_boolean"], - data["col_date"], - data["col_timestamp"], - data["col_timestamptz"], - data["col_jsonb"], - data["col_json"], - data["col_bytea"], - data["col_int_array"], - ) - if err != nil { - t.Fatalf("Failed to insert row id %v: %v", data["id"], err) - } - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - if err != nil { - t.Fatalf("Failed to disable spock repair mode: %v", err) - } + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec( + ctx, + fmt.Sprintf(insertSQLTemplate, env.Schema, tableName), + data["id"], + data["col_smallint"], + data["col_integer"], + data["col_bigint"], + data["col_numeric"], + data["col_real"], + data["col_double"], + data["col_varchar"], + data["col_text"], + data["col_char"], + data["col_boolean"], + data["col_date"], + data["col_timestamp"], + data["col_timestamptz"], + data["col_jsonb"], + data["col_json"], + data["col_bytea"], + data["col_int_array"], + ) + if err != nil { + t.Fatalf("Failed to insert row id %v: %v", data["id"], err) + } + }) if err = tx.Commit(ctx); err != nil { t.Fatalf("Failed to commit transaction: %v", err) } } - insertRow(pgCluster.Node1Pool, row1) - insertRow(pgCluster.Node2Pool, row1) - insertRow(pgCluster.Node1Pool, row2Node1Only) - insertRow(pgCluster.Node2Pool, row3Node2Only) - insertRow(pgCluster.Node1Pool, row4Base) - insertRow(pgCluster.Node2Pool, row4Node2Modified) + insertRow(env.N1Pool, row1) + insertRow(env.N2Pool, row1) + insertRow(env.N1Pool, row2Node1Only) + insertRow(env.N2Pool, row3Node2Only) + insertRow(env.N1Pool, row4Base) + insertRow(env.N2Pool, row4Node2Modified) log.Printf("Data loaded into %s with variations", qualifiedTableName) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) err := tdTask.RunChecks(false) if err != nil { @@ -830,10 +777,7 @@ CREATE TABLE IF NOT EXISTS %s.%s ( t.Fatalf("ExecuteTask failed for data type table: %v", err) } - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } + pairKey := env.pairKey() nodeDiffs, ok := tdTask.DiffResult.NodeDiffs[pairKey] if !ok { @@ -844,20 +788,20 @@ CREATE TABLE IF NOT EXISTS %s.%s ( ) } - if len(nodeDiffs.Rows[serviceN1]) != 2 { + if len(nodeDiffs.Rows[env.ServiceN1]) != 2 { t.Errorf( "Expected 2 rows in diffs for %s, got %d. Rows: %+v", - serviceN1, - len(nodeDiffs.Rows[serviceN1]), - nodeDiffs.Rows[serviceN1], + env.ServiceN1, + len(nodeDiffs.Rows[env.ServiceN1]), + nodeDiffs.Rows[env.ServiceN1], ) } - if len(nodeDiffs.Rows[serviceN2]) != 2 { + if len(nodeDiffs.Rows[env.ServiceN2]) != 2 { t.Errorf( "Expected 2 rows in diffs for %s, got %d. 
Rows: %+v", - serviceN2, - len(nodeDiffs.Rows[serviceN2]), - nodeDiffs.Rows[serviceN2], + env.ServiceN2, + len(nodeDiffs.Rows[env.ServiceN2]), + nodeDiffs.Rows[env.ServiceN2], ) } @@ -873,37 +817,37 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } foundRow2N1Only := false - for _, r := range nodeDiffs.Rows[serviceN1] { + for _, r := range nodeDiffs.Rows[env.ServiceN1] { if id, ok := r.Get("id"); ok && id == int32(2) { foundRow2N1Only = true break } } if !foundRow2N1Only { - t.Errorf("Row with id=2 (N1 only) not found in %s diffs", serviceN1) + t.Errorf("Row with id=2 (N1 only) not found in %s diffs", env.ServiceN1) } foundRow3N2Only := false - for _, r := range nodeDiffs.Rows[serviceN2] { + for _, r := range nodeDiffs.Rows[env.ServiceN2] { if id, ok := r.Get("id"); ok && id == int32(3) { foundRow3N2Only = true break } } if !foundRow3N2Only { - t.Errorf("Row with id=3 (N2 only) not found in %s diffs", serviceN2) + t.Errorf("Row with id=3 (N2 only) not found in %s diffs", env.ServiceN2) } foundRow4OriginalN1 := false foundRow4ModifiedN2 := false - for _, r := range nodeDiffs.Rows[serviceN1] { + for _, r := range nodeDiffs.Rows[env.ServiceN1] { id, _ := r.Get("id") varchar, _ := r.Get("col_varchar") if id == int32(4) && varchar == "original_varchar_row4" { foundRow4OriginalN1 = true } } - for _, r := range nodeDiffs.Rows[serviceN2] { + for _, r := range nodeDiffs.Rows[env.ServiceN2] { id, _ := r.Get("id") varchar, _ := r.Get("col_varchar") if id == int32(4) && varchar == "MODIFIED_varchar_row4" { @@ -911,19 +855,19 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } } if !foundRow4OriginalN1 { - t.Errorf("Original version of row id=4 not found in %s diffs", serviceN1) + t.Errorf("Original version of row id=4 not found in %s diffs", env.ServiceN1) } if !foundRow4ModifiedN2 { - t.Errorf("Modified version of row id=4 not found in %s diffs", serviceN2) + t.Errorf("Modified version of row id=4 not found in %s diffs", env.ServiceN2) } log.Println("TestTableDiff_VariousDataTypes completed.") } -func testTableDiff_UUIDColumn(t *testing.T, compositeKey bool) { +func testTableDiff_UUIDColumn(t *testing.T, env *testEnv, compositeKey bool) { ctx := context.Background() tableName := "uuid_test_table" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) compositeKeyPart := "" if compositeKey { @@ -937,33 +881,38 @@ CREATE TABLE IF NOT EXISTS %s.%s ( name TEXT, col_uuid UUID, PRIMARY KEY(id%s) -);`, testSchema, testSchema, tableName, compositeKeyPart) +);`, env.Schema, env.Schema, tableName, compositeKeyPart) - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) + nodeNames := []string{env.ServiceN1, env.ServiceN2} + for i, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + nodeName := nodeNames[i] _, err := pool.Exec(ctx, createTableSQL) if err != nil { t.Fatalf("Failed to create uuid_test_table on node %s: %v", nodeName, err) } - _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", testSchema, tableName)) + _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", env.Schema, tableName)) if err != nil { t.Fatalf("Failed to truncate uuid_test_table on node %s: %v", nodeName, err) } - addToRepSetSQL := fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName) - _, err = pool.Exec(ctx, addToRepSetSQL) - if err != nil { - t.Fatalf("Failed to add table to replication set on %s: %v", nodeName, err) + if 
env.HasSpock { + addToRepSetSQL := fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName) + _, err = pool.Exec(ctx, addToRepSetSQL) + if err != nil { + t.Fatalf("Failed to add table to replication set on %s: %v", nodeName, err) + } } } log.Printf("Table %s created on both nodes", qualifiedTableName) t.Cleanup(func() { - removeFromRepSetSQL := fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName) - _, err := pgCluster.Node1Pool.Exec(ctx, removeFromRepSetSQL) - if err != nil { - t.Logf("cleanup: failed to remove table from replication set: %v", err) + if env.HasSpock { + removeFromRepSetSQL := fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName) + _, err := env.N1Pool.Exec(ctx, removeFromRepSetSQL) + if err != nil { + t.Logf("cleanup: failed to remove table from replication set: %v", err) + } } - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) if err != nil { t.Logf("Failed to drop test table %s: %v", qualifiedTableName, err) @@ -981,7 +930,7 @@ CREATE TABLE IF NOT EXISTS %s.%s ( row2UUIDNode1 := "550e8400-e29b-41d4-a716-446655440002" row2UUIDNode2 := "660e8400-e29b-41d4-a716-446655440002" - insertSQL := fmt.Sprintf("INSERT INTO %s.%s (id, name, col_uuid) VALUES ($1, $2, $3)", testSchema, tableName) + insertSQL := fmt.Sprintf("INSERT INTO %s.%s (id, name, col_uuid) VALUES ($1, $2, $3)", env.Schema, tableName) insertRow := func(pool *pgxpool.Pool, id int, name string, uuid string) { tx, err := pool.Begin(ctx) @@ -990,35 +939,29 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - if err != nil { - t.Fatalf("Failed to enable spock repair mode: %v", err) - } - _, err = tx.Exec(ctx, insertSQL, id, name, uuid) - if err != nil { - t.Fatalf("Failed to insert row id %d: %v", id, err) - } - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - if err != nil { - t.Fatalf("Failed to disable spock repair mode: %v", err) - } + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec(ctx, insertSQL, id, name, uuid) + if err != nil { + t.Fatalf("Failed to insert row id %d: %v", id, err) + } + }) if err = tx.Commit(ctx); err != nil { t.Fatalf("Failed to commit transaction: %v", err) } } // Insert same row on both nodes - insertRow(pgCluster.Node1Pool, 1, "same", row1UUID) - insertRow(pgCluster.Node2Pool, 1, "same", row1UUID) + insertRow(env.N1Pool, 1, "same", row1UUID) + insertRow(env.N2Pool, 1, "same", row1UUID) // Insert row with different UUID on each node - insertRow(pgCluster.Node1Pool, 2, "different", row2UUIDNode1) - insertRow(pgCluster.Node2Pool, 2, "different", row2UUIDNode2) + insertRow(env.N1Pool, 2, "different", row2UUIDNode1) + insertRow(env.N2Pool, 2, "different", row2UUIDNode2) log.Printf("Data loaded into %s with UUID variations", qualifiedTableName) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) err := tdTask.RunChecks(false) if err != nil { @@ -1029,10 +972,7 @@ CREATE TABLE IF NOT EXISTS %s.%s ( t.Fatalf("ExecuteTask failed for UUID table: %v", err) } - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, 
serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } + pairKey := env.pairKey() nodeDiffs, ok := tdTask.DiffResult.NodeDiffs[pairKey] if !ok { @@ -1040,17 +980,17 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } // Should have 1 diff (row 2 with different UUIDs) - if len(nodeDiffs.Rows[serviceN1]) != 1 { + if len(nodeDiffs.Rows[env.ServiceN1]) != 1 { t.Errorf("Expected 1 row in diffs for %s, got %d. Rows: %+v", - serviceN1, len(nodeDiffs.Rows[serviceN1]), nodeDiffs.Rows[serviceN1]) + env.ServiceN1, len(nodeDiffs.Rows[env.ServiceN1]), nodeDiffs.Rows[env.ServiceN1]) } - if len(nodeDiffs.Rows[serviceN2]) != 1 { + if len(nodeDiffs.Rows[env.ServiceN2]) != 1 { t.Errorf("Expected 1 row in diffs for %s, got %d. Rows: %+v", - serviceN2, len(nodeDiffs.Rows[serviceN2]), nodeDiffs.Rows[serviceN2]) + env.ServiceN2, len(nodeDiffs.Rows[env.ServiceN2]), nodeDiffs.Rows[env.ServiceN2]) } // Verify the UUID is formatted correctly (as a string in standard UUID format) - for _, row := range nodeDiffs.Rows[serviceN1] { + for _, row := range nodeDiffs.Rows[env.ServiceN1] { uuid, ok := row.Get("col_uuid") if !ok { t.Errorf("col_uuid not found in diff row") @@ -1066,7 +1006,7 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } } - for _, row := range nodeDiffs.Rows[serviceN2] { + for _, row := range nodeDiffs.Rows[env.ServiceN2] { uuid, ok := row.Get("col_uuid") if !ok { t.Errorf("col_uuid not found in diff row") @@ -1085,13 +1025,13 @@ CREATE TABLE IF NOT EXISTS %s.%s ( log.Println("TestTableDiff_UUIDColumn completed.") } -func testTableDiff_TableFiltering(t *testing.T) { +func testTableDiff_TableFiltering(t *testing.T, env *testEnv) { ctx := context.Background() tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -1120,45 +1060,39 @@ func testTableDiff_TableFiltering(t *testing.T) { }, } - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) if err != nil { - t.Fatalf("Failed to begin transaction on node %s: %v", serviceN2, err) + t.Fatalf("Failed to begin transaction on node %s: %v", env.ServiceN2, err) } defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - if err != nil { - t.Fatalf("Failed to enable spock repair mode on node %s: %v", serviceN2, err) - } - for _, mod := range updatesNode2 { - updateSQL := fmt.Sprintf( - "UPDATE %s.%s SET %s = $1 WHERE index = $2", - testSchema, - tableName, - mod.field, - ) - _, err := tx.Exec(ctx, updateSQL, mod.value, mod.indexVal) - if err != nil { - t.Fatalf( - "Failed to update row with index %d on node %s for filter test: %v", - mod.indexVal, - serviceN2, - err, + env.withRepairModeTx(t, ctx, tx, func() { + for _, mod := range updatesNode2 { + updateSQL := fmt.Sprintf( + "UPDATE %s.%s SET %s = $1 WHERE index = $2", + env.Schema, + tableName, + mod.field, ) + _, err := tx.Exec(ctx, updateSQL, mod.value, mod.indexVal) + if err != nil { + t.Fatalf( + "Failed to update row with index %d on node %s for filter test: %v", + mod.indexVal, + env.ServiceN2, + err, + ) + } } - } - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - if err != nil { - t.Fatalf("Failed to disable spock repair mode on node %s: %v", serviceN2, err) - } + }) if err = tx.Commit(ctx); err != nil { - t.Fatalf("Failed to commit transaction on node 
%s: %v", serviceN2, err) + t.Fatalf("Failed to commit transaction on node %s: %v", env.ServiceN2, err) } - log.Printf("Data modified on %s for filter test", serviceN2) + log.Printf("Data modified on %s for filter test", env.ServiceN2) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) tdTask.TableFilter = "index <= 100" err = tdTask.RunChecks(false) @@ -1170,10 +1104,7 @@ func testTableDiff_TableFiltering(t *testing.T) { t.Fatalf("ExecuteTask failed for table filtering test: %v", err) } - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } + pairKey := env.pairKey() nodeDiffs, ok := tdTask.DiffResult.NodeDiffs[pairKey] if !ok { @@ -1185,22 +1116,22 @@ func testTableDiff_TableFiltering(t *testing.T) { } expectedFilteredModifications := 3 - if len(nodeDiffs.Rows[serviceN1]) != expectedFilteredModifications { + if len(nodeDiffs.Rows[env.ServiceN1]) != expectedFilteredModifications { t.Errorf( "Expected %d modified rows (original) for %s due to filter, but got %d. Rows: %+v", expectedFilteredModifications, - serviceN1, - len(nodeDiffs.Rows[serviceN1]), - nodeDiffs.Rows[serviceN1], + env.ServiceN1, + len(nodeDiffs.Rows[env.ServiceN1]), + nodeDiffs.Rows[env.ServiceN1], ) } - if len(nodeDiffs.Rows[serviceN2]) != expectedFilteredModifications { + if len(nodeDiffs.Rows[env.ServiceN2]) != expectedFilteredModifications { t.Errorf( "Expected %d modified rows (updated) for %s due to filter, but got %d. Rows: %+v", expectedFilteredModifications, - serviceN2, - len(nodeDiffs.Rows[serviceN2]), - nodeDiffs.Rows[serviceN2], + env.ServiceN2, + len(nodeDiffs.Rows[env.ServiceN2]), + nodeDiffs.Rows[env.ServiceN2], ) } if tdTask.DiffResult.Summary.DiffRowsCount[pairKey] != expectedFilteredModifications { @@ -1216,7 +1147,7 @@ func testTableDiff_TableFiltering(t *testing.T) { foundIndex1 := false foundIndex2 := false foundIndex3 := false - for _, row := range nodeDiffs.Rows[serviceN2] { + for _, row := range nodeDiffs.Rows[env.ServiceN2] { indexVal, _ := row.Get("index") email, _ := row.Get("email") firstName, _ := row.Get("first_name") @@ -1247,12 +1178,12 @@ func testTableDiff_TableFiltering(t *testing.T) { log.Println("TestTableDiff_TableFiltering completed.") } -func testTableDiff_TableFilterNoRows(t *testing.T) { +func testTableDiff_TableFilterNoRows(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodesToCompare := []string{serviceN1, serviceN2} + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) tdTask.TableFilter = "index < 0" // no rows satisfy this require.NoError(t, tdTask.RunChecks(false)) @@ -1261,13 +1192,13 @@ func testTableDiff_TableFilterNoRows(t *testing.T) { require.Contains(t, err.Error(), "table filter produced no rows") } -func testTableDiff_MaxDiffRowsLimit(t *testing.T) { +func testTableDiff_MaxDiffRowsLimit(t *testing.T, env *testEnv) { ctx := context.Background() tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) 
t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -1284,25 +1215,21 @@ func testTableDiff_MaxDiffRowsLimit(t *testing.T) { {indexVal: 4, email: "limit-test-4@example.com"}, } - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - updateSQL := fmt.Sprintf("UPDATE %s.%s SET email = $1 WHERE index = $2", testSchema, tableName) - for _, mod := range modifications { - _, err := tx.Exec(ctx, updateSQL, mod.email, mod.indexVal) - require.NoErrorf(t, err, "failed to update row with index %d on node %s", mod.indexVal, serviceN2) - } - - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + env.withRepairModeTx(t, ctx, tx, func() { + updateSQL := fmt.Sprintf("UPDATE %s.%s SET email = $1 WHERE index = $2", env.Schema, tableName) + for _, mod := range modifications { + _, err := tx.Exec(ctx, updateSQL, mod.email, mod.indexVal) + require.NoErrorf(t, err, "failed to update row with index %d on node %s", mod.indexVal, env.ServiceN2) + } + }) require.NoError(t, tx.Commit(ctx)) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) tdTask.BlockSize = 10 tdTask.CompareUnitSize = 1 tdTask.MaxDiffRows = 2 @@ -1313,10 +1240,7 @@ func testTableDiff_MaxDiffRowsLimit(t *testing.T) { err = tdTask.ExecuteTask() require.NoError(t, err) - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } + pairKey := env.pairKey() if _, ok := tdTask.DiffResult.NodeDiffs[pairKey]; !ok { t.Fatalf("expected diffs for pair %s when enforcing max diff rows", pairKey) @@ -1343,10 +1267,10 @@ func testTableDiff_MaxDiffRowsLimit(t *testing.T) { log.Println("TestTableDiff_MaxDiffRowsLimit completed.") } -func testTableDiff_ByteaColumnSizeCheck(t *testing.T, compositeKey bool) { +func testTableDiff_ByteaColumnSizeCheck(t *testing.T, env *testEnv, compositeKey bool) { ctx := context.Background() tableName := "bytea_size_test" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) compositeKeyPart := "" if compositeKey { @@ -1362,24 +1286,28 @@ CREATE TABLE IF NOT EXISTS %s ( );`, qualifiedTableName, compositeKeyPart) // Create table on both nodes and add cleanup to drop it - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, createTableSQL) if err != nil { t.Fatalf("Failed to create test table %s: %v", qualifiedTableName, err) } - addToRepSetSQL := fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName) - _, err = pool.Exec(ctx, addToRepSetSQL) - if err != nil { - t.Fatalf("Failed to add table to replication set on n1: %v", err) + if env.HasSpock { + addToRepSetSQL := fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName) + _, err = pool.Exec(ctx, addToRepSetSQL) + if err != nil { + t.Fatalf("Failed to add table to replication set: %v", err) + } } } t.Cleanup(func() { - 
removeFromRepSetSQL := fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName) - _, err := pgCluster.Node1Pool.Exec(ctx, removeFromRepSetSQL) - if err != nil { - t.Logf("cleanup: failed to remove table from replication set: %v", err) + if env.HasSpock { + removeFromRepSetSQL := fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName) + _, err := env.N1Pool.Exec(ctx, removeFromRepSetSQL) + if err != nil { + t.Logf("cleanup: failed to remove table from replication set: %v", err) + } } - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) if err != nil { t.Logf("Failed to drop test table %s: %v", qualifiedTableName, err) @@ -1394,7 +1322,7 @@ CREATE TABLE IF NOT EXISTS %s ( // --- Test Case 1: Data < 1MB (should pass) --- t.Run("DataUnder1MB", func(t *testing.T) { // Truncate before run - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s", qualifiedTableName)) if err != nil { t.Fatalf("Failed to truncate table %s: %v", qualifiedTableName, err) @@ -1402,15 +1330,24 @@ CREATE TABLE IF NOT EXISTS %s ( } smallData := make([]byte, 500*1024) // 500 KB - _, err := pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, name, data) VALUES (1, 'small', $1)", qualifiedTableName), smallData) + _, err := env.N1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, name, data) VALUES (1, 'small', $1)", qualifiedTableName), smallData) if err != nil { t.Fatalf("Failed to insert small data: %v", err) } - time.Sleep(5 * time.Second) + if env.HasSpock { + // Wait for replication to propagate + time.Sleep(5 * time.Second) + } else { + // On native PG there is no replication; insert on n2 directly + _, err = env.N2Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, name, data) VALUES (1, 'small', $1)", qualifiedTableName), smallData) + if err != nil { + t.Fatalf("Failed to insert small data on n2: %v", err) + } + } - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) err = tdTask.RunChecks(false) if err != nil { @@ -1421,22 +1358,31 @@ CREATE TABLE IF NOT EXISTS %s ( // --- Test Case 2: Data > 1MB (should fail) --- t.Run("DataOver1MB", func(t *testing.T) { // Truncate before run - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s", qualifiedTableName)) if err != nil { t.Fatalf("Failed to truncate table %s: %v", qualifiedTableName, err) } } largeData := make([]byte, 1024*1024+1) // > 1 MB - _, err := pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, name, data) VALUES (1, 'large', $1)", qualifiedTableName), largeData) + _, err := env.N1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, name, data) VALUES (1, 'large', $1)", qualifiedTableName), largeData) if err != nil { t.Fatalf("Failed to insert large data: %v", err) } - time.Sleep(5 * time.Second) + if env.HasSpock { + // Wait for replication to propagate + time.Sleep(5 * time.Second) + } else { + // On native PG there 
is no replication; insert on n2 directly + _, err = env.N2Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, name, data) VALUES (1, 'large', $1)", qualifiedTableName), largeData) + if err != nil { + t.Fatalf("Failed to insert large data on n2: %v", err) + } + } - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) err = tdTask.RunChecks(false) if err == nil { diff --git a/tests/integration/table_repair_test.go b/tests/integration/table_repair_test.go index 96a0c62..11c7954 100644 --- a/tests/integration/table_repair_test.go +++ b/tests/integration/table_repair_test.go @@ -37,89 +37,15 @@ import ( // - 2 rows only on node1 (IDs 1001, 1002) // - 2 rows only on node2 (IDs 2001, 2002) // - 2 common rows modified on node2 (IDs 1, 2) -func setupDivergence(t *testing.T, ctx context.Context, qualifiedTableName string, composite bool) { +func setupDivergence(t *testing.T, ctx context.Context, qualifiedTableName string) { t.Helper() - log.Println("Setting up data divergence for", qualifiedTableName) - - // Truncate on both nodes - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) - _, err := pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err, "Failed to enable repair mode on %s", nodeName) - _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) - require.NoError(t, err, "Failed to truncate table on node %s", nodeName) - _, err = pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err, "Failed to disable repair mode on %s", nodeName) - } - - // Insert common rows (always populate customer_id to support composite PK later) - for i := 1; i <= 5; i++ { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _, err := pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - _, err = pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), - i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("FirstName%d", i), fmt.Sprintf("LastName%d", i), fmt.Sprintf("email%d@example.com", i)) - require.NoError(t, err) - _, err = pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) - } - } - - // Insert rows only on node1 (always include customer_id) - _, err := pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - for i := 1001; i <= 1002; i++ { - _, err := pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), - i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N1OnlyFirst%d", i), fmt.Sprintf("N1OnlyLast%d", i), fmt.Sprintf("n1.only%d@example.com", i)) - require.NoError(t, err) - } - _, err = pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) - - // Insert rows only on node2 (always include customer_id) - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - for i := 2001; i <= 2002; i++ { - _, err := pgCluster.Node2Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), - i, fmt.Sprintf("CUST-%d", i), 
fmt.Sprintf("N2OnlyFirst%d", i), fmt.Sprintf("N2OnlyLast%d", i), fmt.Sprintf("n2.only%d@example.com", i)) - require.NoError(t, err) - } - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) - - // Modify rows on node2 - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - for i := 1; i <= 2; i++ { - _, err := pgCluster.Node2Pool.Exec(ctx, fmt.Sprintf("UPDATE %s SET email = $1 WHERE index = $2", qualifiedTableName), - fmt.Sprintf("modified.email%d@example.com", i), i) - require.NoError(t, err) - } - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) - - log.Println("Data divergence setup complete.") + newSpockEnv().setupDivergence(t, ctx, qualifiedTableName) } // runTableDiff executes a table-diff task and returns the path to the latest diff file. func runTableDiff(t *testing.T, qualifiedTableName string, nodesToCompare []string) string { t.Helper() - // Clean up any old diff files to ensure we get the correct one - files, _ := filepath.Glob("*_diffs-*.json") - for _, f := range files { - os.Remove(f) - } - - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) - err := tdTask.RunChecks(false) - require.NoError(t, err, "table-diff validation failed") - err = tdTask.ExecuteTask() - require.NoError(t, err, "table-diff execution failed") - - latestDiffFile := getLatestDiffFile(t) - require.NotEmpty(t, latestDiffFile, "No diff file was generated") - - return latestDiffFile + return newSpockEnv().runTableDiff(t, qualifiedTableName, nodesToCompare) } // getLatestDiffFile finds the most recently modified diff file. @@ -145,16 +71,7 @@ func getLatestDiffFile(t *testing.T) string { // assertNoTableDiff runs a diff and asserts that there are no differences. func assertNoTableDiff(t *testing.T, qualifiedTableName string) { t.Helper() - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) - - err := tdTask.RunChecks(false) - require.NoError(t, err, "assertNoTableDiff: validation failed") - - err = tdTask.ExecuteTask() - require.NoError(t, err, "assertNoTableDiff: execution failed") - - assert.Empty(t, tdTask.DiffResult.NodeDiffs, "Expected no differences after repair, but diffs were found") + newSpockEnv().assertNoTableDiff(t, qualifiedTableName) } // captureOutput executes a function while capturing its stdout and stderr. 
@@ -196,42 +113,7 @@ func getTableCount(t *testing.T, ctx context.Context, pool *pgxpool.Pool, qualif func setupNullDivergence(t *testing.T, ctx context.Context, qualifiedTableName string) { t.Helper() - log.Println("Setting up null divergence for", qualifiedTableName) - - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) - _, err := pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err, "Failed to enable repair mode on %s", nodeName) - _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) - require.NoError(t, err, "Failed to truncate table on node %s", nodeName) - _, err = pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err, "Failed to disable repair mode on %s", nodeName) - } - - insertSQL := fmt.Sprintf( - "INSERT INTO %s (index, customer_id, first_name, last_name, city) VALUES ($1, $2, $3, $4, $5)", - qualifiedTableName, - ) - - // Node1 rows: missing city for id 1, missing first_name for id 2 - _, err := pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - _, err = pgCluster.Node1Pool.Exec(ctx, insertSQL, 1, "CUST-1", "Michael", "Schumacher", nil) - require.NoError(t, err) - _, err = pgCluster.Node1Pool.Exec(ctx, insertSQL, 2, "CUST-2", nil, "Alonso", "Oviedo") - require.NoError(t, err) - _, err = pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) - - // Node2 rows: missing last_name for id 1, missing city for id 2 - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - _, err = pgCluster.Node2Pool.Exec(ctx, insertSQL, 1, "CUST-1", "Michael", nil, "Austria") - require.NoError(t, err) - _, err = pgCluster.Node2Pool.Exec(ctx, insertSQL, 2, "CUST-2", "Fernando", "Alonso", nil) - require.NoError(t, err) - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + newSpockEnv().setupNullDivergence(t, ctx, qualifiedTableName) } type nameCity struct { @@ -252,9 +134,9 @@ func getNameCity(t *testing.T, ctx context.Context, pool *pgxpool.Pool, qualifie return nameCity{first: first, last: last, city: city} } -func TestTableRepair_UnidirectionalDefault(t *testing.T) { +func testTableRepair_UnidirectionalDefault(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -268,14 +150,14 @@ func TestTableRepair_UnidirectionalDefault(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -287,34 +169,40 @@ func TestTableRepair_UnidirectionalDefault(t *testing.T) { tc.setup() t.Cleanup(tc.teardown) - setupDivergence(t, ctx, 
qualifiedTableName, tc.composite) + env.setupDivergence(t, ctx, qualifiedTableName) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) }) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) diffData, err := os.ReadFile(diffFile) require.NoError(t, err) - assert.Contains(t, string(diffData), "_spock_metadata_", "Diff file should contain spock metadata before repair") + if env.HasSpock { + assert.Contains(t, string(diffData), "_spock_metadata_", "Diff file should contain spock metadata before repair") + } - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) err = repairTask.Run(false) require.NoError(t, err, "Table repair failed") log.Println("Verifying repair for TestTableRepair_UnidirectionalDefault") - assertNoTableDiff(t, qualifiedTableName) - count1 := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName) - count2 := getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + env.assertNoTableDiff(t, qualifiedTableName) + count1 := getTableCount(t, ctx, env.N1Pool, qualifiedTableName) + count2 := getTableCount(t, ctx, env.N2Pool, qualifiedTableName) assert.Equal(t, count1, count2, "Row counts should be equal after default repair") assert.Equal(t, 7, count1, "Expected 7 rows on node1") }) } } -func TestTableRepair_InsertOnly(t *testing.T) { +func TestTableRepair_UnidirectionalDefault(t *testing.T) { + testTableRepair_UnidirectionalDefault(t, newSpockEnv()) +} + +func testTableRepair_InsertOnly(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -328,14 +216,14 @@ func TestTableRepair_InsertOnly(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -347,49 +235,50 @@ func TestTableRepair_InsertOnly(t *testing.T) { tc.setup() t.Cleanup(tc.teardown) - setupDivergence(t, ctx, qualifiedTableName, tc.composite) + env.setupDivergence(t, ctx, qualifiedTableName) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) }) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) repairTask.InsertOnly = true err := repairTask.Run(false) require.NoError(t, err, "Table repair (insert-only) 
failed") - count1 := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName) - count2 := getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + count1 := getTableCount(t, ctx, env.N1Pool, qualifiedTableName) + count2 := getTableCount(t, ctx, env.N2Pool, qualifiedTableName) assert.Equal(t, 7, count1) assert.Equal(t, 9, count2) - tdTask := newTestTableDiffTask(t, qualifiedTableName, []string{serviceN1, serviceN2}) + tdTask := env.newTableDiffTask(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) err = tdTask.RunChecks(false) require.NoError(t, err) err = tdTask.ExecuteTask() require.NoError(t, err) - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } + pairKey := env.pairKey() // After insert-only repair with n1 as source (DO NOTHING on conflict): // - Rows 1001, 1002 (n1-only) are inserted into n2 // - Rows 1, 2 (conflicting) are NOT overwritten on n2 (DO NOTHING), so they still differ // - n1 has 2 differing rows (its versions of rows 1, 2) // - n2 has 4 differing rows (its versions of rows 1, 2 + unique rows 2001, 2002) // - Total diff count is 4 - assert.Equal(t, 2, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[serviceN1])) - assert.Equal(t, 4, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[serviceN2])) + assert.Equal(t, 2, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[env.ServiceN1])) + assert.Equal(t, 4, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[env.ServiceN2])) assert.Equal(t, 4, tdTask.DiffResult.Summary.DiffRowsCount[pairKey]) }) } } -func TestTableRepair_UpsertOnly(t *testing.T) { +func TestTableRepair_InsertOnly(t *testing.T) { + testTableRepair_InsertOnly(t, newSpockEnv()) +} + +func testTableRepair_UpsertOnly(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -403,14 +292,14 @@ func TestTableRepair_UpsertOnly(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -422,43 +311,44 @@ func TestTableRepair_UpsertOnly(t *testing.T) { tc.setup() t.Cleanup(tc.teardown) - setupDivergence(t, ctx, qualifiedTableName, tc.composite) + env.setupDivergence(t, ctx, qualifiedTableName) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) }) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) repairTask.UpsertOnly = true err := repairTask.Run(false) require.NoError(t, err, "Table 
repair (upsert-only) failed") - count1 := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName) - count2 := getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + count1 := getTableCount(t, ctx, env.N1Pool, qualifiedTableName) + count2 := getTableCount(t, ctx, env.N2Pool, qualifiedTableName) assert.Equal(t, 7, count1) assert.Equal(t, 9, count2) - tdTask := newTestTableDiffTask(t, qualifiedTableName, []string{serviceN1, serviceN2}) + tdTask := env.newTableDiffTask(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) err = tdTask.RunChecks(false) require.NoError(t, err) err = tdTask.ExecuteTask() require.NoError(t, err) - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } - assert.Equal(t, 0, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[serviceN1])) - assert.Equal(t, 2, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[serviceN2])) + pairKey := env.pairKey() + assert.Equal(t, 0, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[env.ServiceN1])) + assert.Equal(t, 2, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[env.ServiceN2])) assert.Equal(t, 2, tdTask.DiffResult.Summary.DiffRowsCount[pairKey]) }) } } -func TestTableRepair_Bidirectional(t *testing.T) { +func TestTableRepair_UpsertOnly(t *testing.T) { + testTableRepair_UpsertOnly(t, newSpockEnv()) +} + +func testTableRepair_Bidirectional(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -472,14 +362,14 @@ func TestTableRepair_Bidirectional(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -492,55 +382,53 @@ func TestTableRepair_Bidirectional(t *testing.T) { t.Cleanup(tc.teardown) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) }) log.Println("Setting up data for bidirectional test") - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _, err := pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) - require.NoError(t, err) - _, err = pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + env.withRepairMode(t, ctx, pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) + require.NoError(t, err) + }) } - _, err := pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - for i := 3001; i <= 3003; i++ { - _, err := pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, 
customer_id, first_name) VALUES ($1, $2, $3)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N1-Bi-%d", i)) - require.NoError(t, err) - } - _, err = pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + env.withRepairMode(t, ctx, env.N1Pool, func(conn *pgxpool.Conn) { + for i := 3001; i <= 3003; i++ { + _, err := conn.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name) VALUES ($1, $2, $3)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N1-Bi-%d", i)) + require.NoError(t, err) + } + }) - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - for i := 4001; i <= 4002; i++ { - _, err := pgCluster.Node2Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name) VALUES ($1, $2, $3)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N2-Bi-%d", i)) - require.NoError(t, err) - } - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + env.withRepairMode(t, ctx, env.N2Pool, func(conn *pgxpool.Conn) { + for i := 4001; i <= 4002; i++ { + _, err := conn.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name) VALUES ($1, $2, $3)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N2-Bi-%d", i)) + require.NoError(t, err) + } + }) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) repairTask.Bidirectional = true - err = repairTask.Run(false) + err := repairTask.Run(false) require.NoError(t, err, "Table repair (bidirectional) failed") log.Println("Verifying repair for TestTableRepair_Bidirectional") - assertNoTableDiff(t, qualifiedTableName) - count1 := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName) - count2 := getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + env.assertNoTableDiff(t, qualifiedTableName) + count1 := getTableCount(t, ctx, env.N1Pool, qualifiedTableName) + count2 := getTableCount(t, ctx, env.N2Pool, qualifiedTableName) assert.Equal(t, 5, count1, "Expected 5 rows on node1 after bidirectional repair") assert.Equal(t, 5, count2, "Expected 5 rows on node2 after bidirectional repair") }) } } -func TestTableRepair_DryRun(t *testing.T) { +func TestTableRepair_Bidirectional(t *testing.T) { + testTableRepair_Bidirectional(t, newSpockEnv()) +} + +func testTableRepair_DryRun(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -554,14 +442,14 @@ func TestTableRepair_DryRun(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range 
[]*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -573,14 +461,14 @@ func TestTableRepair_DryRun(t *testing.T) { tc.setup() t.Cleanup(tc.teardown) - setupDivergence(t, ctx, qualifiedTableName, tc.composite) + env.setupDivergence(t, ctx, qualifiedTableName) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) }) - count1Before, count2Before := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName), getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + count1Before, count2Before := getTableCount(t, ctx, env.N1Pool, qualifiedTableName), getTableCount(t, ctx, env.N2Pool, qualifiedTableName) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) repairTask.DryRun = true output := captureOutput(t, func() { @@ -589,18 +477,22 @@ func TestTableRepair_DryRun(t *testing.T) { }) assert.Contains(t, output, "DRY RUN") - assert.Contains(t, output, fmt.Sprintf("Node %s: Would attempt to UPSERT 4 rows and DELETE 2 rows.", serviceN2)) + assert.Contains(t, output, fmt.Sprintf("Node %s: Would attempt to UPSERT 4 rows and DELETE 2 rows.", env.ServiceN2)) - count1After, count2After := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName), getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + count1After, count2After := getTableCount(t, ctx, env.N1Pool, qualifiedTableName), getTableCount(t, ctx, env.N2Pool, qualifiedTableName) assert.Equal(t, count1Before, count1After, "Node1 count should not change after dry run") assert.Equal(t, count2Before, count2After, "Node2 count should not change after dry run") }) } } -func TestTableRepair_GenerateReport(t *testing.T) { +func TestTableRepair_DryRun(t *testing.T) { + testTableRepair_DryRun(t, newSpockEnv()) +} + +func testTableRepair_GenerateReport(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() reportDir := "reports" @@ -615,14 +507,14 @@ func TestTableRepair_GenerateReport(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -636,14 +528,14 @@ func TestTableRepair_GenerateReport(t *testing.T) { t.Cleanup(func() { os.RemoveAll(reportDir) - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) }) os.RemoveAll(reportDir) - setupDivergence(t, ctx, qualifiedTableName, tc.composite) + env.setupDivergence(t, ctx, 
qualifiedTableName) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) repairTask.GenerateReport = true err := repairTask.Run(false) @@ -669,12 +561,16 @@ func TestTableRepair_GenerateReport(t *testing.T) { require.NoError(t, err) assert.True(t, len(data) > 0, "Report file is empty") assert.Contains(t, string(data), "\"operation_type\": \"table-repair\"") - assert.Contains(t, string(data), fmt.Sprintf("\"source_of_truth\": \"%s\"", serviceN1)) + assert.Contains(t, string(data), fmt.Sprintf("\"source_of_truth\": \"%s\"", env.ServiceN1)) }) } } -func TestTableRepair_VariousDataTypes(t *testing.T) { +func TestTableRepair_GenerateReport(t *testing.T) { + testTableRepair_GenerateReport(t, newSpockEnv()) +} + +func testTableRepair_VariousDataTypes(t *testing.T, env *testEnv) { ctx := context.Background() tableName := "data_type_repair" qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) @@ -705,19 +601,23 @@ CREATE TABLE IF NOT EXISTS %s.%s ( col_text_array TEXT[] );`, testSchema, testSchema, tableName) - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) + for i, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + nodeName := env.ClusterNodes[i]["Name"].(string) _, err := pool.Exec(ctx, createDataTypeTableSQL) require.NoErrorf(t, err, "Failed to create data type table on %s", nodeName) _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) require.NoErrorf(t, err, "Failed to truncate data type table on %s", nodeName) - _, err = pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName)) - require.NoErrorf(t, err, "Failed to add table to repset on %s", nodeName) + if env.HasSpock { + _, err = pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName)) + require.NoErrorf(t, err, "Failed to add table to repset on %s", nodeName) + } } t.Cleanup(func() { - _, _ = pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName)) - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + if env.HasSpock { + _, _ = env.N1Pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName)) + } + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, _ = pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) } files, _ := filepath.Glob("*_diffs-*.json") @@ -731,23 +631,20 @@ CREATE TABLE IF NOT EXISTS %s.%s ( require.NoError(t, err) defer func() { _ = tx.Rollback(ctx) }() - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - _, err = tx.Exec( - ctx, - fmt.Sprintf(`INSERT INTO %s (id, col_smallint, col_integer, col_bigint, col_numeric, col_real, col_double, col_varchar, col_text, col_char, col_boolean, col_date, col_timestamp, col_timestamptz, col_interval, col_jsonb, col_json, col_bytea, col_int_array, col_text_array) + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec( + ctx, + fmt.Sprintf(`INSERT INTO %s (id, col_smallint, col_integer, col_bigint, col_numeric, col_real, col_double, col_varchar, col_text, col_char, 
col_boolean, col_date, col_timestamp, col_timestamptz, col_interval, col_jsonb, col_json, col_bytea, col_int_array, col_text_array) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20)`, qualifiedTableName), - data["id"], data["col_smallint"], data["col_integer"], data["col_bigint"], - data["col_numeric"], data["col_real"], data["col_double"], data["col_varchar"], - data["col_text"], data["col_char"], data["col_boolean"], data["col_date"], - data["col_timestamp"], data["col_timestamptz"], data["col_interval"], data["col_jsonb"], data["col_json"], - data["col_bytea"], data["col_int_array"], data["col_text_array"], - ) - require.NoError(t, err) + data["id"], data["col_smallint"], data["col_integer"], data["col_bigint"], + data["col_numeric"], data["col_real"], data["col_double"], data["col_varchar"], + data["col_text"], data["col_char"], data["col_boolean"], data["col_date"], + data["col_timestamp"], data["col_timestamptz"], data["col_interval"], data["col_jsonb"], data["col_json"], + data["col_bytea"], data["col_int_array"], data["col_text_array"], + ) + require.NoError(t, err) + }) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) require.NoError(t, tx.Commit(ctx)) } @@ -796,20 +693,20 @@ CREATE TABLE IF NOT EXISTS %s.%s ( "col_numeric": "444.44", "col_varchar": "only_on_n2", } - insertRow(pgCluster.Node1Pool, row1) - insertRow(pgCluster.Node2Pool, row1) - insertRow(pgCluster.Node1Pool, row2OnlyN1) - insertRow(pgCluster.Node1Pool, row3Base) - insertRow(pgCluster.Node2Pool, row3ModifiedN2) - insertRow(pgCluster.Node2Pool, row4OnlyN2) + insertRow(env.N1Pool, row1) + insertRow(env.N2Pool, row1) + insertRow(env.N1Pool, row2OnlyN1) + insertRow(env.N1Pool, row3Base) + insertRow(env.N2Pool, row3ModifiedN2) + insertRow(env.N2Pool, row4OnlyN2) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) err := repairTask.Run(false) require.NoError(t, err, "Table repair for various data types failed") - assertNoTableDiff(t, qualifiedTableName) + env.assertNoTableDiff(t, qualifiedTableName) checkRow3 := func(pool *pgxpool.Pool) map[string]any { var ( @@ -837,8 +734,8 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } } - row3N1 := checkRow3(pgCluster.Node1Pool) - row3N2 := checkRow3(pgCluster.Node2Pool) + row3N1 := checkRow3(env.N1Pool) + row3N2 := checkRow3(env.N2Pool) assert.Equal(t, row3N1["col_numeric"], row3N2["col_numeric"]) assert.Equal(t, row3N1["col_jsonb"], row3N2["col_jsonb"]) @@ -848,14 +745,18 @@ CREATE TABLE IF NOT EXISTS %s.%s ( assert.True(t, bytes.Equal(row3N1["col_bytea"].([]byte), row3N2["col_bytea"].([]byte))) var row4Count int - err = pgCluster.Node2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s WHERE id = 4", qualifiedTableName)).Scan(&row4Count) + err = env.N2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s WHERE id = 4", qualifiedTableName)).Scan(&row4Count) require.NoError(t, err) assert.Equal(t, 0, row4Count, "Row present only on node2 should be deleted") } -func TestTableRepair_FixNulls(t *testing.T) { +func TestTableRepair_VariousDataTypes(t *testing.T) { + testTableRepair_VariousDataTypes(t, newSpockEnv()) +} + +func testTableRepair_FixNulls(t *testing.T, env *testEnv) { tableName := "customers" - 
qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -869,14 +770,14 @@ func TestTableRepair_FixNulls(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -888,23 +789,23 @@ func TestTableRepair_FixNulls(t *testing.T) { tc.setup() t.Cleanup(tc.teardown) - setupNullDivergence(t, ctx, qualifiedTableName) + env.setupNullDivergence(t, ctx, qualifiedTableName) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) - repairTask := newTestTableRepairTask("", qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask("", qualifiedTableName, diffFile) repairTask.SourceOfTruth = "" repairTask.FixNulls = true err := repairTask.Run(false) require.NoError(t, err, "Table repair (fix-nulls) failed") - assertNoTableDiff(t, qualifiedTableName) + env.assertNoTableDiff(t, qualifiedTableName) - row1N1 := getNameCity(t, ctx, pgCluster.Node1Pool, qualifiedTableName, 1, "CUST-1") - row1N2 := getNameCity(t, ctx, pgCluster.Node2Pool, qualifiedTableName, 1, "CUST-1") - row2N1 := getNameCity(t, ctx, pgCluster.Node1Pool, qualifiedTableName, 2, "CUST-2") - row2N2 := getNameCity(t, ctx, pgCluster.Node2Pool, qualifiedTableName, 2, "CUST-2") + row1N1 := getNameCity(t, ctx, env.N1Pool, qualifiedTableName, 1, "CUST-1") + row1N2 := getNameCity(t, ctx, env.N2Pool, qualifiedTableName, 1, "CUST-1") + row2N1 := getNameCity(t, ctx, env.N1Pool, qualifiedTableName, 2, "CUST-2") + row2N2 := getNameCity(t, ctx, env.N2Pool, qualifiedTableName, 2, "CUST-2") require.NotNil(t, row1N1.city) require.NotNil(t, row1N2.last) @@ -919,9 +820,13 @@ func TestTableRepair_FixNulls(t *testing.T) { } } -func TestTableRepair_FixNulls_DryRun(t *testing.T) { +func TestTableRepair_FixNulls(t *testing.T) { + testTableRepair_FixNulls(t, newSpockEnv()) +} + +func testTableRepair_FixNulls_DryRun(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -935,14 +840,14 @@ func TestTableRepair_FixNulls_DryRun(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, 
testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -954,10 +859,10 @@ func TestTableRepair_FixNulls_DryRun(t *testing.T) { tc.setup() t.Cleanup(tc.teardown) - setupNullDivergence(t, ctx, qualifiedTableName) + env.setupNullDivergence(t, ctx, qualifiedTableName) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) - repairTask := newTestTableRepairTask("", qualifiedTableName, diffFile) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) + repairTask := env.newTableRepairTask("", qualifiedTableName, diffFile) repairTask.SourceOfTruth = "" repairTask.FixNulls = true repairTask.DryRun = true @@ -971,10 +876,10 @@ func TestTableRepair_FixNulls_DryRun(t *testing.T) { assert.Contains(t, output, "Would update") // Ensure data unchanged - row1N1 := getNameCity(t, ctx, pgCluster.Node1Pool, qualifiedTableName, 1, "CUST-1") - row1N2 := getNameCity(t, ctx, pgCluster.Node2Pool, qualifiedTableName, 1, "CUST-1") - row2N1 := getNameCity(t, ctx, pgCluster.Node1Pool, qualifiedTableName, 2, "CUST-2") - row2N2 := getNameCity(t, ctx, pgCluster.Node2Pool, qualifiedTableName, 2, "CUST-2") + row1N1 := getNameCity(t, ctx, env.N1Pool, qualifiedTableName, 1, "CUST-1") + row1N2 := getNameCity(t, ctx, env.N2Pool, qualifiedTableName, 1, "CUST-1") + row2N1 := getNameCity(t, ctx, env.N1Pool, qualifiedTableName, 2, "CUST-2") + row2N2 := getNameCity(t, ctx, env.N2Pool, qualifiedTableName, 2, "CUST-2") assert.Nil(t, row1N1.city) assert.Nil(t, row1N2.last) @@ -982,7 +887,7 @@ func TestTableRepair_FixNulls_DryRun(t *testing.T) { assert.Nil(t, row2N2.city) // Cleanup: actually repair to leave table consistent for subsequent tests - fixTask := newTestTableRepairTask("", qualifiedTableName, diffFile) + fixTask := env.newTableRepairTask("", qualifiedTableName, diffFile) fixTask.SourceOfTruth = "" fixTask.FixNulls = true err := fixTask.Run(false) @@ -991,12 +896,16 @@ func TestTableRepair_FixNulls_DryRun(t *testing.T) { } } -// TestTableRepair_TimestampAndTimeTypes verifies that the full diff→repair +func TestTableRepair_FixNulls_DryRun(t *testing.T) { + testTableRepair_FixNulls_DryRun(t, newSpockEnv()) +} + +// testTableRepair_TimestampAndTimeTypes verifies that the full diff→repair // pipeline works correctly for timestamp, timestamptz, time, and timetz columns. // This guards against pgx sending the wrong OID for timestamp-without-tz // (which would cause PostgreSQL to apply a session-timezone shift) and ensures // that pgtype.Time values round-trip through diff JSON and back correctly. 
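As background for the doc comment above, a minimal illustration of the OID pitfall may help; the table and variable names here are hypothetical, not repository code. pgx v5 binds a bare `time.Time` as `timestamptz` (OID 1184) by default, and against a `TIMESTAMP WITHOUT TIME ZONE` column the server then converts through the session time zone, shifting the stored wall-clock value. Wrapping the value in `pgtype.Timestamp` pins the `timestamp` OID (1114) instead:

```go
// Hypothetical snippet (imports: time, github.com/jackc/pgx/v5/pgtype;
// pool is a *pgxpool.Pool). temporal_demo is a made-up table with a
// TIMESTAMP WITHOUT TIME ZONE column named col_ts.
ts := time.Date(2024, 6, 15, 9, 30, 0, 0, time.UTC)
_, err := pool.Exec(ctx,
	"INSERT INTO temporal_demo (id, col_ts) VALUES ($1, $2)",
	1, pgtype.Timestamp{Time: ts, Valid: true}, // binds OID 1114, not 1184
)
if err != nil {
	t.Fatalf("insert failed: %v", err)
}
```

This mirrors the failure mode the test below guards against: if the wrong OID is sent, the inserted value no longer round-trips unchanged through the diff/repair pipeline.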
-func TestTableRepair_TimestampAndTimeTypes(t *testing.T) { +func testTableRepair_TimestampAndTimeTypes(t *testing.T, env *testEnv) { ctx := context.Background() tableName := "temporal_type_repair" qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) @@ -1011,19 +920,23 @@ CREATE TABLE IF NOT EXISTS %s.%s ( col_timetz TIME WITH TIME ZONE );`, testSchema, testSchema, tableName) - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) + for i, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + nodeName := env.ClusterNodes[i]["Name"].(string) _, err := pool.Exec(ctx, createSQL) require.NoErrorf(t, err, "Failed to create temporal table on %s", nodeName) _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) require.NoErrorf(t, err, "Failed to truncate temporal table on %s", nodeName) - _, err = pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName)) - require.NoErrorf(t, err, "Failed to add temporal table to repset on %s", nodeName) + if env.HasSpock { + _, err = pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName)) + require.NoErrorf(t, err, "Failed to add temporal table to repset on %s", nodeName) + } } t.Cleanup(func() { - _, _ = pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName)) - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + if env.HasSpock { + _, _ = env.N1Pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName)) + } + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, _ = pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) } files, _ := filepath.Glob("*_diffs-*.json") @@ -1037,18 +950,15 @@ CREATE TABLE IF NOT EXISTS %s.%s ( require.NoError(t, err) defer func() { _ = tx.Rollback(ctx) }() - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - _, err = tx.Exec(ctx, fmt.Sprintf( - `INSERT INTO %s (id, col_ts, col_tstz, col_time, col_timetz) VALUES ($1, $2, $3, $4::time, $5::timetz)`, - qualifiedTableName), - id, ts, tstz, timeStr, timetzStr, - ) - require.NoError(t, err) + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec(ctx, fmt.Sprintf( + `INSERT INTO %s (id, col_ts, col_tstz, col_time, col_timetz) VALUES ($1, $2, $3, $4::time, $5::timetz)`, + qualifiedTableName), + id, ts, tstz, timeStr, timetzStr, + ) + require.NoError(t, err) + }) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) require.NoError(t, tx.Commit(ctx)) } @@ -1056,11 +966,11 @@ CREATE TABLE IF NOT EXISTS %s.%s ( refTSTZ := time.Date(2024, 6, 15, 18, 45, 30, 0, time.UTC) // Row 1: identical on both nodes (baseline) - insertRow(pgCluster.Node1Pool, 1, refTS, refTSTZ, "08:15:30.123456", "14:30:00-05") - insertRow(pgCluster.Node2Pool, 1, refTS, refTSTZ, "08:15:30.123456", "14:30:00-05") + insertRow(env.N1Pool, 1, refTS, refTSTZ, "08:15:30.123456", "14:30:00-05") + insertRow(env.N2Pool, 1, refTS, refTSTZ, "08:15:30.123456", "14:30:00-05") // Row 2: only on node1 (missing on node2) - insertRow(pgCluster.Node1Pool, 2, + insertRow(env.N1Pool, 2, refTS.Add(1*time.Hour), refTSTZ.Add(1*time.Hour), "12:00:00", @@ -1068,7 +978,7 @@ CREATE TABLE IF NOT EXISTS %s.%s ( ) // Row 3: only on node2 (should be deleted when n1 is source of truth) - 
insertRow(pgCluster.Node2Pool, 3, + insertRow(env.N2Pool, 3, refTS.Add(2*time.Hour), refTSTZ.Add(2*time.Hour), "23:59:59.999999", @@ -1076,13 +986,13 @@ CREATE TABLE IF NOT EXISTS %s.%s ( ) // Row 4: different values on each node (row mismatch) - insertRow(pgCluster.Node1Pool, 4, + insertRow(env.N1Pool, 4, refTS.Add(3*time.Hour), refTSTZ.Add(3*time.Hour), "06:30:00", "16:45:00-07", ) - insertRow(pgCluster.Node2Pool, 4, + insertRow(env.N2Pool, 4, refTS.Add(99*time.Hour), // deliberately different refTSTZ.Add(99*time.Hour), "22:00:00.500000", @@ -1090,24 +1000,24 @@ CREATE TABLE IF NOT EXISTS %s.%s ( ) // ----- Diff ----- - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) // ----- Repair (node1 = source of truth) ----- - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) err := repairTask.Run(false) require.NoError(t, err, "Repair for temporal types failed") // ----- Verify: no diffs remain ----- - assertNoTableDiff(t, qualifiedTableName) + env.assertNoTableDiff(t, qualifiedTableName) // ----- Verify: row counts match ----- - countN1 := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName) - countN2 := getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + countN1 := getTableCount(t, ctx, env.N1Pool, qualifiedTableName) + countN2 := getTableCount(t, ctx, env.N2Pool, qualifiedTableName) assert.Equal(t, countN1, countN2, "Row counts should match after repair") // ----- Verify: row 3 was deleted from node2 ----- var row3Count int - err = pgCluster.Node2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s WHERE id = 3", qualifiedTableName)).Scan(&row3Count) + err = env.N2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s WHERE id = 3", qualifiedTableName)).Scan(&row3Count) require.NoError(t, err) assert.Equal(t, 0, row3Count, "Row 3 (only on node2) should be deleted") @@ -1127,31 +1037,35 @@ CREATE TABLE IF NOT EXISTS %s.%s ( return r } - r2n1 := readRow(pgCluster.Node1Pool, 2) - r2n2 := readRow(pgCluster.Node2Pool, 2) + r2n1 := readRow(env.N1Pool, 2) + r2n2 := readRow(env.N2Pool, 2) assert.True(t, r2n1.ts.Equal(r2n2.ts), "row 2 col_ts: n1=%v n2=%v", r2n1.ts, r2n2.ts) assert.True(t, r2n1.tstz.Equal(r2n2.tstz), "row 2 col_tstz: n1=%v n2=%v", r2n1.tstz, r2n2.tstz) assert.Equal(t, r2n1.timeV, r2n2.timeV, "row 2 col_time mismatch") assert.Equal(t, r2n1.timetz, r2n2.timetz, "row 2 col_timetz mismatch") // ----- Verify: row 4 was repaired to match node1 ----- - r4n1 := readRow(pgCluster.Node1Pool, 4) - r4n2 := readRow(pgCluster.Node2Pool, 4) + r4n1 := readRow(env.N1Pool, 4) + r4n2 := readRow(env.N2Pool, 4) assert.True(t, r4n1.ts.Equal(r4n2.ts), "row 4 col_ts: n1=%v n2=%v", r4n1.ts, r4n2.ts) assert.True(t, r4n1.tstz.Equal(r4n2.tstz), "row 4 col_tstz: n1=%v n2=%v", r4n1.tstz, r4n2.tstz) assert.Equal(t, r4n1.timeV, r4n2.timeV, "row 4 col_time mismatch") assert.Equal(t, r4n1.timetz, r4n2.timetz, "row 4 col_timetz mismatch") // ----- Verify: row 1 (baseline) is still identical ----- - r1n1 := readRow(pgCluster.Node1Pool, 1) - r1n2 := readRow(pgCluster.Node2Pool, 1) + r1n1 := readRow(env.N1Pool, 1) + r1n2 := readRow(env.N2Pool, 1) assert.True(t, r1n1.ts.Equal(r1n2.ts), "row 1 col_ts: n1=%v n2=%v", r1n1.ts, r1n2.ts) assert.True(t, r1n1.tstz.Equal(r1n2.tstz), "row 1 col_tstz: n1=%v n2=%v", r1n1.tstz, r1n2.tstz) assert.Equal(t, r1n1.timeV, r1n2.timeV, 
"row 1 col_time mismatch") assert.Equal(t, r1n1.timetz, r1n2.timetz, "row 1 col_timetz mismatch") } -// TestTableRepair_FixNulls_BidirectionalUpdate tests that when both nodes have NULLs +func TestTableRepair_TimestampAndTimeTypes(t *testing.T) { + testTableRepair_TimestampAndTimeTypes(t, newSpockEnv()) +} + +// testTableRepair_FixNulls_BidirectionalUpdate tests that when both nodes have NULLs // in different columns for the same row, fix-nulls performs bidirectional updates. // This verifies the behavior discussed in code review: each node updates the other // with its non-NULL values. @@ -1165,9 +1079,9 @@ CREATE TABLE IF NOT EXISTS %s.%s ( // // Node1: {id: 1, col_a: "value_a", col_b: "value_b", col_c: "value_c"} // Node2: {id: 1, col_a: "value_a", col_b: "value_b", col_c: "value_c"} -func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { +func testTableRepair_FixNulls_BidirectionalUpdate(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -1181,14 +1095,14 @@ func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -1203,14 +1117,11 @@ func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { log.Println("Setting up bidirectional NULL divergence for", qualifiedTableName) // Clean table on both nodes - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) - _, err := pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err, "Failed to enable repair mode on %s", nodeName) - _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) - require.NoError(t, err, "Failed to truncate table on node %s", nodeName) - _, err = pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err, "Failed to disable repair mode on %s", nodeName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + env.withRepairMode(t, ctx, pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) + require.NoError(t, err, "Failed to truncate table") + }) } // Insert row with complementary NULLs on each node @@ -1222,38 +1133,34 @@ func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { ) // Node1 data - _, err := pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - // Row 1 on Node1: NULL first_name and city - _, err = pgCluster.Node1Pool.Exec(ctx, insertSQL, 100, "CUST-100", nil, "LastName100", nil, "email100@example.com") - require.NoError(t, err) - // Row 2 on Node1: NULL last_name and email - _, err = pgCluster.Node1Pool.Exec(ctx, 
insertSQL, 200, "CUST-200", "FirstName200", nil, "City200", nil) - require.NoError(t, err) - _, err = pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + env.withRepairMode(t, ctx, env.N1Pool, func(conn *pgxpool.Conn) { + // Row 1 on Node1: NULL first_name and city + _, err := conn.Exec(ctx, insertSQL, 100, "CUST-100", nil, "LastName100", nil, "email100@example.com") + require.NoError(t, err) + // Row 2 on Node1: NULL last_name and email + _, err = conn.Exec(ctx, insertSQL, 200, "CUST-200", "FirstName200", nil, "City200", nil) + require.NoError(t, err) + }) // Node2 data - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - // Row 1 on Node2: NULL last_name and email - _, err = pgCluster.Node2Pool.Exec(ctx, insertSQL, 100, "CUST-100", "FirstName100", nil, "City100", nil) - require.NoError(t, err) - // Row 2 on Node2: NULL first_name and city - _, err = pgCluster.Node2Pool.Exec(ctx, insertSQL, 200, "CUST-200", nil, "LastName200", nil, "email200@example.com") - require.NoError(t, err) - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + env.withRepairMode(t, ctx, env.N2Pool, func(conn *pgxpool.Conn) { + // Row 1 on Node2: NULL last_name and email + _, err := conn.Exec(ctx, insertSQL, 100, "CUST-100", "FirstName100", nil, "City100", nil) + require.NoError(t, err) + // Row 2 on Node2: NULL first_name and city + _, err = conn.Exec(ctx, insertSQL, 200, "CUST-200", nil, "LastName200", nil, "email200@example.com") + require.NoError(t, err) + }) // Run table-diff to detect the NULL differences - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) // Run fix-nulls repair - repairTask := newTestTableRepairTask("", qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask("", qualifiedTableName, diffFile) repairTask.SourceOfTruth = "" repairTask.FixNulls = true - err = repairTask.Run(false) + err := repairTask.Run(false) require.NoError(t, err, "Table repair (fix-nulls bidirectional) failed") // Verify bidirectional updates happened @@ -1276,8 +1183,8 @@ func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { } // Check Row 1 (id=100) on both nodes - row1N1 := getFullRow(pgCluster.Node1Pool, 100, "CUST-100") - row1N2 := getFullRow(pgCluster.Node2Pool, 100, "CUST-100") + row1N1 := getFullRow(env.N1Pool, 100, "CUST-100") + row1N2 := getFullRow(env.N2Pool, 100, "CUST-100") // Node1's NULLs (first_name, city) should be filled from Node2 require.NotNil(t, row1N1.firstName, "Node1 row 100 first_name should be filled from Node2") @@ -1302,8 +1209,8 @@ func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { assert.Equal(t, "email100@example.com", *row1N2.email) // Check Row 2 (id=200) on both nodes - row2N1 := getFullRow(pgCluster.Node1Pool, 200, "CUST-200") - row2N2 := getFullRow(pgCluster.Node2Pool, 200, "CUST-200") + row2N1 := getFullRow(env.N1Pool, 200, "CUST-200") + row2N2 := getFullRow(env.N2Pool, 200, "CUST-200") // Node1's NULLs (last_name, email) should be filled from Node2 require.NotNil(t, row2N1.lastName, "Node1 row 200 last_name should be filled from Node2") @@ -1328,13 +1235,17 @@ func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { assert.Equal(t, "email200@example.com", *row2N2.email) // Verify no diffs remain - assertNoTableDiff(t, qualifiedTableName) + env.assertNoTableDiff(t, 
qualifiedTableName) log.Println("Bidirectional fix-nulls test completed successfully") }) } } +func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { + testTableRepair_FixNulls_BidirectionalUpdate(t, newSpockEnv()) +} + // TestTableRepair_PreserveOrigin tests that the preserve-origin flag correctly preserves // both replication origin metadata and commit timestamps during table repair recovery operations. // This test verifies the fix for maintaining original transaction metadata to prevent @@ -2169,7 +2080,7 @@ func TestTableRepair_MixedOps_PreserveOrigin(t *testing.T) { log.Println(" - UPDATE: 2 modified rows corrected with preserved origin/timestamp") } -// TestTableRepair_LargeBigintPK verifies that table repair correctly handles +// testTableRepair_LargeBigintPK verifies that table repair correctly handles // bigint primary keys whose values exceed float64's exact integer range (2^53). // Before the fix, JSON deserialization converted these to float64, silently // truncating the PK values and causing: @@ -2179,7 +2090,7 @@ func TestTableRepair_MixedOps_PreserveOrigin(t *testing.T) { // // The test uses PKs that are adjacent integers above 2^53, which all collapse // to the same float64 value without the json.Number fix. -func TestTableRepair_LargeBigintPK(t *testing.T) { +func testTableRepair_LargeBigintPK(t *testing.T, env *testEnv) { ctx := context.Background() tableName := "bigint_pk_repair" qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) @@ -2202,19 +2113,23 @@ CREATE TABLE IF NOT EXISTS %s.%s ( amount NUMERIC(20, 4) );`, testSchema, testSchema, tableName) - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) + for i, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + nodeName := env.ClusterNodes[i]["Name"].(string) _, err := pool.Exec(ctx, createTableSQL) require.NoErrorf(t, err, "Failed to create table on %s", nodeName) _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) require.NoErrorf(t, err, "Failed to truncate table on %s", nodeName) - _, err = pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName)) - require.NoErrorf(t, err, "Failed to add table to repset on %s", nodeName) + if env.HasSpock { + _, err = pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName)) + require.NoErrorf(t, err, "Failed to add table to repset on %s", nodeName) + } } t.Cleanup(func() { - _, _ = pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName)) - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + if env.HasSpock { + _, _ = env.N1Pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName)) + } + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, _ = pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) } files, _ := filepath.Glob("*_diffs-*.json") @@ -2228,36 +2143,33 @@ CREATE TABLE IF NOT EXISTS %s.%s ( require.NoError(t, err) defer func() { _ = tx.Rollback(ctx) }() - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - _, err = tx.Exec(ctx, - fmt.Sprintf("INSERT INTO %s (id, data, amount) VALUES ($1, $2, $3::numeric)", qualifiedTableName), - id, data, amount) - require.NoError(t, err) + env.withRepairModeTx(t, ctx, tx, func() { + _, err = 
tx.Exec(ctx, + fmt.Sprintf("INSERT INTO %s (id, data, amount) VALUES ($1, $2, $3::numeric)", qualifiedTableName), + id, data, amount) + require.NoError(t, err) + }) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) require.NoError(t, tx.Commit(ctx)) } // ---- Set up divergence ---- // Common row on both nodes (same data) - insertRow(pgCluster.Node1Pool, collisionPKs[0], "common_row", "1000000.1234") - insertRow(pgCluster.Node2Pool, collisionPKs[0], "common_row", "1000000.1234") + insertRow(env.N1Pool, collisionPKs[0], "common_row", "1000000.1234") + insertRow(env.N2Pool, collisionPKs[0], "common_row", "1000000.1234") // Rows only on node1 (should be deleted when node2 is source of truth) - insertRow(pgCluster.Node1Pool, collisionPKs[1], "n1_only_289", "2000000.5678") + insertRow(env.N1Pool, collisionPKs[1], "n1_only_289", "2000000.5678") // Rows only on node2 (should be inserted into node1) - insertRow(pgCluster.Node2Pool, collisionPKs[2], "n2_only_290", "3000000.9012") + insertRow(env.N2Pool, collisionPKs[2], "n2_only_290", "3000000.9012") // Modified row: same PK on both nodes, different data - insertRow(pgCluster.Node1Pool, collisionPKs[3], "old_data_291", "4000000.0001") - insertRow(pgCluster.Node2Pool, collisionPKs[3], "new_data_291", "4000000.9999") + insertRow(env.N1Pool, collisionPKs[3], "old_data_291", "4000000.0001") + insertRow(env.N2Pool, collisionPKs[3], "new_data_291", "4000000.9999") // ---- Run diff ---- - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) // Verify diff found the expected number of differences: // - collisionPKs[1]: node1-only → 1 diff @@ -2279,19 +2191,19 @@ CREATE TABLE IF NOT EXISTS %s.%s ( assert.Equal(t, 3, totalDiffs, "Expected exactly 3 differences before repair") // ---- Run repair (node2 is source of truth) ---- - repairTask := newTestTableRepairTask(serviceN2, qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask(env.ServiceN2, qualifiedTableName, diffFile) err = repairTask.Run(false) require.NoError(t, err, "Table repair failed") // ---- Verify repair ---- // 1. Tables should now match (zero diffs) - assertNoTableDiff(t, qualifiedTableName) + env.assertNoTableDiff(t, qualifiedTableName) // 2. Row counts should match: 3 rows (common + n2_only_290 + modified_291) // collisionPKs[1] (n1_only_289) should have been deleted from node1 - count1 := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName) - count2 := getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + count1 := getTableCount(t, ctx, env.N1Pool, qualifiedTableName) + count2 := getTableCount(t, ctx, env.N2Pool, qualifiedTableName) assert.Equal(t, count1, count2, "Row counts should match after repair") assert.Equal(t, 3, count1, "Expected 3 rows after repair") @@ -2308,26 +2220,30 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } // Common row should be unchanged - verifyRow(pgCluster.Node1Pool, collisionPKs[0], "common_row", "1000000.1234") + verifyRow(env.N1Pool, collisionPKs[0], "common_row", "1000000.1234") // Node1-only row should have been deleted var deletedCount int - err = pgCluster.Node1Pool.QueryRow(ctx, + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s WHERE id = $1", qualifiedTableName), collisionPKs[1]). 
Scan(&deletedCount) require.NoError(t, err) assert.Equal(t, 0, deletedCount, "Node1-only row (PK %d) should have been deleted", collisionPKs[1]) // Node2-only row should have been inserted into node1 - verifyRow(pgCluster.Node1Pool, collisionPKs[2], "n2_only_290", "3000000.9012") + verifyRow(env.N1Pool, collisionPKs[2], "n2_only_290", "3000000.9012") // Modified row should have node2's version on node1 - verifyRow(pgCluster.Node1Pool, collisionPKs[3], "new_data_291", "4000000.9999") + verifyRow(env.N1Pool, collisionPKs[3], "new_data_291", "4000000.9999") - log.Println("TestTableRepair_LargeBigintPK PASSED") + log.Println("testTableRepair_LargeBigintPK PASSED") log.Println(" - 4 adjacent PKs above 2^53 that collide under float64") log.Println(" - DELETE: node1-only row correctly removed") log.Println(" - INSERT: node2-only row correctly added to node1") log.Println(" - UPDATE: modified row correctly updated on node1") log.Println(" - All PKs preserved with exact precision (no float64 truncation)") } + +func TestTableRepair_LargeBigintPK(t *testing.T) { + testTableRepair_LargeBigintPK(t, newSpockEnv()) +} diff --git a/tests/integration/test_env_test.go b/tests/integration/test_env_test.go new file mode 100644 index 0000000..3e81fe9 --- /dev/null +++ b/tests/integration/test_env_test.go @@ -0,0 +1,405 @@ +// /////////////////////////////////////////////////////////////////////////// +// +// # ACE - Active Consistency Engine +// +// Copyright (C) 2023 - 2026, pgEdge (https://www.pgedge.com/) +// +// This software is released under the PostgreSQL License: +// https://opensource.org/license/postgresql +// +// /////////////////////////////////////////////////////////////////////////// + +package integration + +import ( + "context" + "fmt" + "log" + "math" + "os" + "path/filepath" + "sort" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/pgedge/ace/internal/consistency/diff" + "github.com/pgedge/ace/internal/consistency/mtree" + "github.com/pgedge/ace/internal/consistency/repair" + "github.com/pgedge/ace/pkg/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testEnv abstracts the differences between spock and native PG test environments, +// allowing the same test logic to run against both. 
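+//
+// Tests are written once against an unexported helper that takes a *testEnv,
+// then wired up through thin exported wrappers. The spock wrapper below is
+// taken from the converted tests in this change; the native wrapper and its
+// nativeState handle are an illustrative sketch only:
+//
+//	func TestTableRepair_LargeBigintPK(t *testing.T) {
+//		testTableRepair_LargeBigintPK(t, newSpockEnv())
+//	}
+//
+//	func TestNativePG_TableRepair_LargeBigintPK(t *testing.T) {
+//		testTableRepair_LargeBigintPK(t, newNativeEnv(nativeState)) // nativeState is hypothetical
+//	}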
+type testEnv struct { + N1Pool *pgxpool.Pool + N2Pool *pgxpool.Pool + N3Pool *pgxpool.Pool // nil for native PG (only 2 nodes) + + N1Host string + N1Port string + N2Host string + N2Port string + + ServiceN1 string + ServiceN2 string + ServiceN3 string // empty for native PG + + DBName string + DBUser string + DBPassword string + + ClusterName string + ClusterNodes []map[string]any + + HasSpock bool + Schema string +} + +func newSpockEnv() *testEnv { + return &testEnv{ + N1Pool: pgCluster.Node1Pool, + N2Pool: pgCluster.Node2Pool, + N3Pool: pgCluster.Node3Pool, + N1Host: pgCluster.Node1Host, + N1Port: pgCluster.Node1Port, + N2Host: pgCluster.Node2Host, + N2Port: pgCluster.Node2Port, + ServiceN1: serviceN1, + ServiceN2: serviceN2, + ServiceN3: serviceN3, + DBName: dbName, + DBUser: pgEdgeUser, + DBPassword: pgEdgePassword, + ClusterName: pgCluster.ClusterName, + ClusterNodes: pgCluster.ClusterNodes, + HasSpock: true, + Schema: testSchema, + } +} + +func newNativeEnv(state *nativeClusterState) *testEnv { + return &testEnv{ + N1Pool: state.n1Pool, + N2Pool: state.n2Pool, + N3Pool: nil, + N1Host: state.n1Host, + N1Port: state.n1Port, + N2Host: state.n2Host, + N2Port: state.n2Port, + ServiceN1: nativeServiceN1, + ServiceN2: nativeServiceN2, + ServiceN3: "", + DBName: nativeDBName, + DBUser: nativeUser, + DBPassword: nativePassword, + ClusterName: nativeClusterName, + ClusterNodes: []map[string]any{ + { + "Name": nativeServiceN1, + "PublicIP": state.n1Host, + "Port": state.n1Port, + "DBUser": nativeUser, + "DBPassword": nativePassword, + "DBName": nativeDBName, + }, + { + "Name": nativeServiceN2, + "PublicIP": state.n2Host, + "Port": state.n2Port, + "DBUser": nativeUser, + "DBPassword": nativePassword, + "DBName": nativeDBName, + }, + }, + HasSpock: false, + Schema: testSchema, + } +} + +// withRepairMode acquires a single connection from the pool, enables +// spock.repair_mode on it (when available), runs fn, then disables repair mode +// and releases the connection. On native PG this simply pins a connection for +// the duration of fn. All work inside fn must use the provided conn so that +// repair_mode is in effect. +func (e *testEnv) withRepairMode(t *testing.T, ctx context.Context, pool *pgxpool.Pool, fn func(conn *pgxpool.Conn)) { + t.Helper() + conn, err := pool.Acquire(ctx) + require.NoError(t, err, "acquire connection for repair mode") + defer conn.Release() + + if e.HasSpock { + _, err := conn.Exec(ctx, "SELECT spock.repair_mode(true)") + require.NoError(t, err, "enable spock.repair_mode") + defer func() { + _, err := conn.Exec(ctx, "SELECT spock.repair_mode(false)") + require.NoError(t, err, "disable spock.repair_mode") + }() + } + fn(conn) +} + +// withRepairModeTx is like withRepairMode but operates on a transaction. +func (e *testEnv) withRepairModeTx(t *testing.T, ctx context.Context, tx pgx.Tx, fn func()) { + t.Helper() + if e.HasSpock { + _, err := tx.Exec(ctx, "SELECT spock.repair_mode(true)") + require.NoError(t, err, "enable spock.repair_mode in tx") + defer func() { + _, err := tx.Exec(ctx, "SELECT spock.repair_mode(false)") + require.NoError(t, err, "disable spock.repair_mode in tx") + }() + } + fn() +} + +// pools returns the two primary pools (n1, n2) used by most tests. +func (e *testEnv) pools() []*pgxpool.Pool { + return []*pgxpool.Pool{e.N1Pool, e.N2Pool} +} + +// newTableDiffTask creates a TableDiffTask configured for this environment. 
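+// The defaults set in the body below (JSON output, block size 1000,
+// compare-unit size 100, concurrency factor 1, and MaxDiffRows at
+// math.MaxInt64 so results are effectively unbounded) are shared by every
+// test that goes through this helper. A sketch of direct use, mirroring
+// runTableDiff (the table name is illustrative):
+//
+//	task := env.newTableDiffTask(t, "public.customers", []string{env.ServiceN1, env.ServiceN2})
+//	require.NoError(t, task.RunChecks(false))
+//	require.NoError(t, task.ExecuteTask())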
+func (e *testEnv) newTableDiffTask(t *testing.T, qualifiedTableName string, nodes []string) *diff.TableDiffTask { + t.Helper() + task := diff.NewTableDiffTask() + task.ClusterName = e.ClusterName + task.DBName = e.DBName + task.QualifiedTableName = qualifiedTableName + task.Nodes = strings.Join(nodes, ",") + task.Output = "json" + task.BlockSize = 1000 + task.CompareUnitSize = 100 + task.ConcurrencyFactor = 1 + task.MaxDiffRows = math.MaxInt64 + + task.DiffResult = types.DiffOutput{ + NodeDiffs: make(map[string]types.DiffByNodePair), + Summary: types.DiffSummary{ + Nodes: nodes, + BlockSize: task.BlockSize, + CompareUnitSize: task.CompareUnitSize, + ConcurrencyFactor: task.ConcurrencyFactor, + DiffRowsCount: make(map[string]int), + }, + } + return task +} + +// newTableRepairTask creates a TableRepairTask configured for this environment. +func (e *testEnv) newTableRepairTask(sourceOfTruthNode, qualifiedTableName, diffFilePath string) *repair.TableRepairTask { + task := repair.NewTableRepairTask() + task.ClusterName = e.ClusterName + task.DBName = e.DBName + task.SourceOfTruth = sourceOfTruthNode + task.QualifiedTableName = qualifiedTableName + task.DiffFilePath = diffFilePath + task.Nodes = "all" + return task +} + +// setupDivergence prepares a table with a known set of differences between n1 and n2. +// - 5 common rows (index 1-5) +// - 2 rows only on n1 (index 1001, 1002) +// - 2 rows only on n2 (index 2001, 2002) +// - 2 common rows modified on n2 (index 1, 2) +func (e *testEnv) setupDivergence(t *testing.T, ctx context.Context, qualifiedTableName string) { + t.Helper() + log.Println("Setting up data divergence for", qualifiedTableName) + + // Truncate on both nodes + for _, pool := range e.pools() { + e.withRepairMode(t, ctx, pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) + require.NoError(t, err, "truncate table") + }) + } + + // Insert common rows + for i := 1; i <= 5; i++ { + for _, pool := range e.pools() { + e.withRepairMode(t, ctx, pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, + fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), + i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("FirstName%d", i), fmt.Sprintf("LastName%d", i), fmt.Sprintf("email%d@example.com", i)) + require.NoError(t, err) + }) + } + } + + // Rows only on n1 + e.withRepairMode(t, ctx, e.N1Pool, func(conn *pgxpool.Conn) { + for i := 1001; i <= 1002; i++ { + _, err := conn.Exec(ctx, + fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), + i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N1OnlyFirst%d", i), fmt.Sprintf("N1OnlyLast%d", i), fmt.Sprintf("n1.only%d@example.com", i)) + require.NoError(t, err) + } + }) + + // Rows only on n2 + e.withRepairMode(t, ctx, e.N2Pool, func(conn *pgxpool.Conn) { + for i := 2001; i <= 2002; i++ { + _, err := conn.Exec(ctx, + fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), + i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N2OnlyFirst%d", i), fmt.Sprintf("N2OnlyLast%d", i), fmt.Sprintf("n2.only%d@example.com", i)) + require.NoError(t, err) + } + }) + + // Modify rows on n2 + e.withRepairMode(t, ctx, e.N2Pool, func(conn *pgxpool.Conn) { + for i := 1; i <= 2; i++ { + _, err := conn.Exec(ctx, + fmt.Sprintf("UPDATE %s SET email = $1 WHERE index = $2", qualifiedTableName), + 
fmt.Sprintf("modified.email%d@example.com", i), i) + require.NoError(t, err) + } + }) + + log.Println("Data divergence setup complete.") +} + +// setupNullDivergence creates divergence where the same rows exist on both nodes +// but with different columns set to NULL. +func (e *testEnv) setupNullDivergence(t *testing.T, ctx context.Context, qualifiedTableName string) { + t.Helper() + log.Println("Setting up null divergence for", qualifiedTableName) + + for _, pool := range e.pools() { + e.withRepairMode(t, ctx, pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) + require.NoError(t, err, "truncate table") + }) + } + + insertSQL := fmt.Sprintf( + "INSERT INTO %s (index, customer_id, first_name, last_name, city) VALUES ($1, $2, $3, $4, $5)", + qualifiedTableName, + ) + + // Node1: missing city for id 1, missing first_name for id 2 + e.withRepairMode(t, ctx, e.N1Pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, insertSQL, 1, "CUST-1", "Michael", "Schumacher", nil) + require.NoError(t, err) + _, err = conn.Exec(ctx, insertSQL, 2, "CUST-2", nil, "Alonso", "Oviedo") + require.NoError(t, err) + }) + + // Node2: missing last_name for id 1, missing city for id 2 + e.withRepairMode(t, ctx, e.N2Pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, insertSQL, 1, "CUST-1", "Michael", nil, "Austria") + require.NoError(t, err) + _, err = conn.Exec(ctx, insertSQL, 2, "CUST-2", "Fernando", "Alonso", nil) + require.NoError(t, err) + }) +} + +// runTableDiff executes a table-diff and returns the path to the latest diff file. +func (e *testEnv) runTableDiff(t *testing.T, qualifiedTableName string, nodesToCompare []string) string { + t.Helper() + files, _ := filepath.Glob("*_diffs-*.json") + for _, f := range files { + os.Remove(f) + } + + tdTask := e.newTableDiffTask(t, qualifiedTableName, nodesToCompare) + err := tdTask.RunChecks(false) + require.NoError(t, err, "table-diff validation failed") + err = tdTask.ExecuteTask() + require.NoError(t, err, "table-diff execution failed") + + latestDiffFile := getLatestDiffFile(t) + require.NotEmpty(t, latestDiffFile, "No diff file was generated") + return latestDiffFile +} + +// assertNoTableDiff runs a diff and asserts that there are no differences. +func (e *testEnv) assertNoTableDiff(t *testing.T, qualifiedTableName string) { + t.Helper() + nodesToCompare := []string{e.ServiceN1, e.ServiceN2} + tdTask := e.newTableDiffTask(t, qualifiedTableName, nodesToCompare) + + err := tdTask.RunChecks(false) + require.NoError(t, err, "assertNoTableDiff: validation failed") + err = tdTask.ExecuteTask() + require.NoError(t, err, "assertNoTableDiff: execution failed") + + assert.Empty(t, tdTask.DiffResult.NodeDiffs, "Expected no differences after repair, but diffs were found") +} + +// repairTable finds the latest diff file and runs a repair. 
+func (e *testEnv) repairTable(t *testing.T, qualifiedTableName, sourceOfTruthNode string) { + t.Helper() + files, err := filepath.Glob("*_diffs-*.json") + if err != nil { + t.Fatalf("Failed to find diff files: %v", err) + } + if len(files) == 0 { + log.Println("No diff file found to repair from, skipping repair.") + return + } + + sort.Slice(files, func(i, j int) bool { + fi, errI := os.Stat(files[i]) + if errI != nil { + t.Logf("Warning: could not stat file %s: %v", files[i], errI) + return false + } + fj, errJ := os.Stat(files[j]) + if errJ != nil { + t.Logf("Warning: could not stat file %s: %v", files[j], errJ) + return false + } + return fi.ModTime().After(fj.ModTime()) + }) + + latestDiffFile := files[0] + log.Printf("Using latest diff file for repair: %s", latestDiffFile) + + repairTask := e.newTableRepairTask(sourceOfTruthNode, qualifiedTableName, latestDiffFile) + time.Sleep(2 * time.Second) + if err := repairTask.Run(false); err != nil { + t.Fatalf("Failed to repair table: %v", err) + } + log.Printf("Table '%s' repaired successfully using %s as source of truth.", qualifiedTableName, sourceOfTruthNode) +} + +// resetSharedTable truncates and reloads the given table from CSV on n1 and n2. +func (e *testEnv) resetSharedTable(t *testing.T, tableName string) { + t.Helper() + ctx := context.Background() + qualifiedTableName := fmt.Sprintf("%s.%s", e.Schema, tableName) + csvPath, err := filepath.Abs(defaultCsvFilePath + tableName + ".csv") + require.NoError(t, err) + for _, pool := range e.pools() { + e.withRepairMode(t, ctx, pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) + require.NoError(t, err, "truncate %s", qualifiedTableName) + }) + require.NoError(t, loadDataFromCSV(ctx, pool, e.Schema, tableName, csvPath), "load CSV into %s", qualifiedTableName) + } +} + +// pairKey returns the canonical node pair key for diff results. +func (e *testEnv) pairKey() string { + if strings.Compare(e.ServiceN1, e.ServiceN2) > 0 { + return e.ServiceN2 + "/" + e.ServiceN1 + } + return e.ServiceN1 + "/" + e.ServiceN2 +} + +// newMerkleTreeTask creates a MerkleTreeTask configured for this environment. +func (e *testEnv) newMerkleTreeTask(t *testing.T, qualifiedTableName string, nodes []string) *mtree.MerkleTreeTask { + t.Helper() + task := mtree.NewMerkleTreeTask() + task.ClusterName = e.ClusterName + task.DBName = e.DBName + task.QualifiedTableName = qualifiedTableName + task.Nodes = strings.Join(nodes, ",") + task.BlockSize = 1000 + return task +}
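+
+// Illustrative sketch of how pairKey lines up with diff output; it assumes
+// DiffResult.NodeDiffs is keyed by the lexicographically ordered "n1/n2"
+// pair name, which is the key pairKey constructs:
+//
+//	tdTask := env.newTableDiffTask(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2})
+//	require.NoError(t, tdTask.RunChecks(false))
+//	require.NoError(t, tdTask.ExecuteTask())
+//	pairDiffs := tdTask.DiffResult.NodeDiffs[env.pairKey()]
+//	_ = pairDiffs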