From df369205ec823a296c283f7fa0ec615cb4bcfdea Mon Sep 17 00:00:00 2001 From: Mason Sharp Date: Thu, 12 Mar 2026 12:35:24 -0700 Subject: [PATCH 01/14] Support table diff/repair without spock dependency (native PG) Auto-detect whether spock is installed and branch accordingly across all repair, diff, rerun, and merkle code paths: - Add detectSpock() to table repair; conditionally call spock.repair_mode() only when spock is present, fall back to session_replication_role alone - Replace spock.xact_commit_timestamp_origin() with native PG14+ pg_xact_commit_timestamp_origin() in diff, rerun, and merkle queries - Add GetNodeOriginNames() wrapper that uses spock.node when available, falls back to pg_replication_origin for node name resolution - Add native PG alternatives for LSN queries (pg_subscription-based) - Add CheckSpockInstalled() utility; spock-diff and repset-diff now return a clear error when spock is not installed Co-Authored-By: Claude Opus 4.6 --- db/queries/queries.go | 79 +++++++++++ db/queries/templates.go | 24 ++++ internal/consistency/diff/repset_diff.go | 13 ++ internal/consistency/diff/spock_diff.go | 12 ++ internal/consistency/diff/table_diff.go | 4 +- internal/consistency/diff/table_rerun.go | 2 +- internal/consistency/mtree/merkle.go | 6 +- internal/consistency/repair/table_repair.go | 140 ++++++++++++++------ 8 files changed, 236 insertions(+), 44 deletions(-) diff --git a/db/queries/queries.go b/db/queries/queries.go index ee837a9..1c7b471 100644 --- a/db/queries/queries.go +++ b/db/queries/queries.go @@ -939,6 +939,85 @@ func GetSpockSlotLSNForNode(ctx context.Context, db DBQuerier, failedNode string return lsn, nil } +func GetNativeOriginLSNForNode(ctx context.Context, db DBQuerier, originNodeName string) (*string, error) { + sql, err := RenderSQL(SQLTemplates.GetNativeOriginLSNForNode, nil) + if err != nil { + return nil, err + } + var lsn *string + if err := db.QueryRow(ctx, sql, originNodeName).Scan(&lsn); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("failed to fetch native origin lsn: %w", err) + } + return lsn, nil +} + +func GetNativeSlotLSNForNode(ctx context.Context, db DBQuerier, failedNode string) (*string, error) { + sql, err := RenderSQL(SQLTemplates.GetNativeSlotLSNForNode, nil) + if err != nil { + return nil, err + } + var lsn *string + if err := db.QueryRow(ctx, sql, failedNode).Scan(&lsn); err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("failed to fetch native slot lsn: %w", err) + } + return lsn, nil +} + +func GetReplicationOriginNames(ctx context.Context, db DBQuerier) (map[string]string, error) { + sql, err := RenderSQL(SQLTemplates.GetReplicationOriginNames, nil) + if err != nil { + return nil, err + } + + rows, err := db.Query(ctx, sql) + if err != nil { + return nil, err + } + defer rows.Close() + + names := make(map[string]string) + for rows.Next() { + var id, name string + if err := rows.Scan(&id, &name); err != nil { + return nil, err + } + names[id] = name + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return names, nil +} + +func GetNodeOriginNames(ctx context.Context, db DBQuerier) (map[string]string, error) { + var spockAvailable bool + err := db.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'spock')").Scan(&spockAvailable) + if err != nil { + return nil, fmt.Errorf("detecting spock extension: %w", err) + } + if spockAvailable { + return GetSpockNodeNames(ctx, db) + } + return 
GetReplicationOriginNames(ctx, db) +} + +func CheckSpockInstalled(ctx context.Context, db DBQuerier) (bool, error) { + var exists bool + err := db.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'spock')").Scan(&exists) + if err != nil { + return false, fmt.Errorf("detecting spock extension: %w", err) + } + return exists, nil +} + func GetSpockRepSetInfo(ctx context.Context, db DBQuerier) ([]types.SpockRepSetInfo, error) { sql, err := RenderSQL(SQLTemplates.SpockRepSetInfo, nil) if err != nil { diff --git a/db/queries/templates.go b/db/queries/templates.go index ccf564d..46b919d 100644 --- a/db/queries/templates.go +++ b/db/queries/templates.go @@ -119,6 +119,9 @@ type Templates struct { RemoveTableFromCDCMetadata *template.Template GetSpockOriginLSNForNode *template.Template GetSpockSlotLSNForNode *template.Template + GetNativeOriginLSNForNode *template.Template + GetNativeSlotLSNForNode *template.Template + GetReplicationOriginNames *template.Template EnsureHashVersionColumn *template.Template GetHashVersion *template.Template MarkAllLeavesDirty *template.Template @@ -1527,6 +1530,27 @@ var SQLTemplates = Templates{ SET hash_version = $1, last_updated = current_timestamp WHERE schema_name = $2 AND table_name = $3 `)), + GetNativeOriginLSNForNode: template.Must(template.New("getNativeOriginLSNForNode").Parse(` + SELECT ros.remote_lsn::text + FROM pg_catalog.pg_replication_origin_status ros + JOIN pg_catalog.pg_replication_origin ro ON ro.roident = ros.local_id + JOIN pg_catalog.pg_subscription s ON ro.roname LIKE 'pg_%' || s.oid::text + WHERE s.subname LIKE '%' || $1 || '%' + AND ros.remote_lsn IS NOT NULL + LIMIT 1 + `)), + GetNativeSlotLSNForNode: template.Must(template.New("getNativeSlotLSNForNode").Parse(` + SELECT rs.confirmed_flush_lsn::text + FROM pg_catalog.pg_replication_slots rs + JOIN pg_catalog.pg_subscription s ON rs.slot_name = s.subslotname + WHERE s.subname LIKE '%' || $1 || '%' + AND rs.confirmed_flush_lsn IS NOT NULL + ORDER BY rs.confirmed_flush_lsn DESC + LIMIT 1 + `)), + GetReplicationOriginNames: template.Must(template.New("getReplicationOriginNames").Parse(` + SELECT roident::text, roname FROM pg_replication_origin; + `)), GetReplicationOriginByName: template.Must(template.New("getReplicationOriginByName").Parse(` SELECT roident FROM pg_replication_origin WHERE roname = $1 `)), diff --git a/internal/consistency/diff/repset_diff.go b/internal/consistency/diff/repset_diff.go index 8eeb94c..e0b545f 100644 --- a/internal/consistency/diff/repset_diff.go +++ b/internal/consistency/diff/repset_diff.go @@ -157,6 +157,19 @@ func (c *RepsetDiffCmd) RunChecks(skipValidation bool) error { return fmt.Errorf("could not connect to node %s: %w", nodeName, err) } + // Check if spock extension is installed (only on first node) + if len(repsetNodeNames) == 0 { + spockInstalled, err := queries.CheckSpockInstalled(c.Ctx, pool) + if err != nil { + pool.Close() + return fmt.Errorf("failed to check for spock extension on node %s: %w", nodeName, err) + } + if !spockInstalled { + pool.Close() + return fmt.Errorf("repset-diff requires the spock extension, which is not installed on this cluster") + } + } + repsetExists, err := queries.CheckRepSetExists(c.Ctx, pool, c.RepsetName) if err != nil { pool.Close() diff --git a/internal/consistency/diff/spock_diff.go b/internal/consistency/diff/spock_diff.go index d560a88..d8dcea9 100644 --- a/internal/consistency/diff/spock_diff.go +++ b/internal/consistency/diff/spock_diff.go @@ -307,6 +307,18 @@ func (t *SpockDiffTask) 
ExecuteTask() (err error) { } t.Pools = pools + // Check if spock extension is installed + for _, pool := range t.Pools { + spockInstalled, err := queries.CheckSpockInstalled(t.Ctx, pool) + if err != nil { + return fmt.Errorf("failed to check for spock extension: %w", err) + } + if !spockInstalled { + return fmt.Errorf("spock-diff requires the spock extension, which is not installed on this cluster") + } + break + } + allNodeConfigs := make(map[string]SpockNodeConfig) var nodeNames []string diff --git a/internal/consistency/diff/table_diff.go b/internal/consistency/diff/table_diff.go index 02ff81f..6326def 100644 --- a/internal/consistency/diff/table_diff.go +++ b/internal/consistency/diff/table_diff.go @@ -200,7 +200,7 @@ func (t *TableDiffTask) loadSpockNodeNames() error { return fmt.Errorf("no connection pool available to load spock node names") } - names, err := queries.GetSpockNodeNames(t.Ctx, firstPool) + names, err := queries.GetNodeOriginNames(t.Ctx, firstPool) if err != nil { t.SpockNodeNames = make(map[string]string) return err @@ -450,7 +450,7 @@ func (t *TableDiffTask) fetchRows(nodeName string, r Range) ([]types.OrderedMap, } selectCols := make([]string, 0, len(t.Cols)+2) - selectCols = append(selectCols, "pg_xact_commit_timestamp(xmin) as commit_ts", "to_json(spock.xact_commit_timestamp_origin(xmin))->>'roident' as node_origin") + selectCols = append(selectCols, "pg_xact_commit_timestamp(xmin) as commit_ts", "to_json(pg_xact_commit_timestamp_origin(xmin))->>'roident' as node_origin") for _, colName := range t.Cols { colType := colTypes[colName] diff --git a/internal/consistency/diff/table_rerun.go b/internal/consistency/diff/table_rerun.go index f3ca7a0..c79c30b 100644 --- a/internal/consistency/diff/table_rerun.go +++ b/internal/consistency/diff/table_rerun.go @@ -277,7 +277,7 @@ func fetchRowsByPkeys(ctx context.Context, pool *pgxpool.Pool, t *TableDiffTask, } selectCols := make([]string, 0, len(t.Cols)+2) - selectCols = append(selectCols, "pg_xact_commit_timestamp(t.xmin) as commit_ts", "to_json(spock.xact_commit_timestamp_origin(t.xmin))->>'roident' as node_origin") + selectCols = append(selectCols, "pg_xact_commit_timestamp(t.xmin) as commit_ts", "to_json(pg_xact_commit_timestamp_origin(t.xmin))->>'roident' as node_origin") for _, col := range t.Cols { selectCols = append(selectCols, "t."+pgx.Identifier{col}.Sanitize()) } diff --git a/internal/consistency/mtree/merkle.go b/internal/consistency/mtree/merkle.go index 2b53907..1325eed 100644 --- a/internal/consistency/mtree/merkle.go +++ b/internal/consistency/mtree/merkle.go @@ -517,7 +517,7 @@ func (m *MerkleTreeTask) loadSpockNodeNames() error { lastErr = err continue } - names, err := queries.GetSpockNodeNames(m.Ctx, pool) + names, err := queries.GetNodeOriginNames(m.Ctx, pool) pool.Close() if err != nil { lastErr = err @@ -721,7 +721,7 @@ func buildFetchRowsSQLSimple(schema, table, pk string, orderBy string, keys []an } qualifiedTable := fmt.Sprintf("%s.%s", pgx.Identifier{schema}.Sanitize(), pgx.Identifier{table}.Sanitize()) where := fmt.Sprintf("%s IN (%s)", pgx.Identifier{pk}.Sanitize(), strings.Join(placeholders, ",")) - selectCols := "pg_xact_commit_timestamp(xmin) as commit_ts, to_json(spock.xact_commit_timestamp_origin(xmin))->>'roident' as node_origin, *" + selectCols := "pg_xact_commit_timestamp(xmin) as commit_ts, to_json(pg_xact_commit_timestamp_origin(xmin))->>'roident' as node_origin, *" q := fmt.Sprintf("SELECT %s FROM %s WHERE %s ORDER BY %s", selectCols, qualifiedTable, where, orderBy) return q, args } 
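Note: the hunks above and below swap spock.xact_commit_timestamp_origin() for pg_xact_commit_timestamp_origin(), which PostgreSQL ships natively since version 14; it returns a (timestamp, roident) record and requires track_commit_timestamp=on. A minimal standalone sketch of the query pattern these hunks rely on — the DSN and table name below are placeholders, not values from this patch:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()
	// Placeholder DSN; any PG 14+ server started with track_commit_timestamp=on works.
	conn, err := pgx.Connect(ctx, "postgres://postgres:password@localhost:5432/testdb")
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	// pg_xact_commit_timestamp_origin(xid) natively returns a
	// (timestamp, roident) record, so to_json(...)->>'roident' extracts the
	// replication origin id without any spock function.
	rows, err := conn.Query(ctx, `
		SELECT pg_xact_commit_timestamp(xmin)::text AS commit_ts,
		       to_json(pg_xact_commit_timestamp_origin(xmin))->>'roident' AS node_origin
		FROM public.customers
		LIMIT 5`)
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	deref := func(s *string) string {
		if s == nil {
			return "<null>" // commit-timestamp data may have been truncated for old xids
		}
		return *s
	}
	for rows.Next() {
		var commitTS, nodeOrigin *string
		if err := rows.Scan(&commitTS, &nodeOrigin); err != nil {
			log.Fatal(err)
		}
		fmt.Printf("commit_ts=%s node_origin=%s\n", deref(commitTS), deref(nodeOrigin))
	}
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
}

The GetNodeOriginNames() wrapper added in queries.go then maps the returned roident back to a node name via spock.node when spock is present, or pg_replication_origin otherwise.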
@@ -745,7 +745,7 @@ func buildFetchRowsSQLComposite(schema, table string, pk []string, orderBy strin } qualifiedTable := fmt.Sprintf("%s.%s", pgx.Identifier{schema}.Sanitize(), pgx.Identifier{table}.Sanitize()) where := fmt.Sprintf("( %s ) IN ( %s )", strings.Join(tupleCols, ","), strings.Join(tuples, ",")) - selectCols := "pg_xact_commit_timestamp(xmin) as commit_ts, to_json(spock.xact_commit_timestamp_origin(xmin))->>'roident' as node_origin, *" + selectCols := "pg_xact_commit_timestamp(xmin) as commit_ts, to_json(pg_xact_commit_timestamp_origin(xmin))->>'roident' as node_origin, *" q := fmt.Sprintf("SELECT %s FROM %s WHERE %s ORDER BY %s", selectCols, qualifiedTable, where, orderBy) return q, args } diff --git a/internal/consistency/repair/table_repair.go b/internal/consistency/repair/table_repair.go index a2c70f9..296b820 100644 --- a/internal/consistency/repair/table_repair.go +++ b/internal/consistency/repair/table_repair.go @@ -98,6 +98,8 @@ type TableRepairTask struct { autoSelectionFailedNode string autoSelectionDetails map[string]map[string]string + spockAvailable bool + Ctx context.Context } @@ -160,13 +162,35 @@ func (t *TableRepairTask) setRole(tx pgx.Tx, nodeName string) error { return nil } -// setupTransactionMode enables spock repair mode, sets the session replication +// detectSpock checks whether the spock extension is installed on any available +// node and stores the result in t.spockAvailable. +func (t *TableRepairTask) detectSpock() { + for nodeName, pool := range t.Pools { + spockInstalled, err := queries.CheckSpockInstalled(t.Ctx, pool) + if err != nil { + logger.Warn("failed to detect spock extension on %s: %v", nodeName, err) + continue + } + t.spockAvailable = spockInstalled + logger.Info("spock extension detected: %v (checked on %s)", spockInstalled, nodeName) + return + } + logger.Warn("could not detect spock extension on any node; assuming not available") + t.spockAvailable = false +} + +// setupTransactionMode enables spock repair mode (when available), sets the session replication // role, and applies the client role for a repair transaction. -func (t *TableRepairTask) setupTransactionMode(tx pgx.Tx, nodeName string) error { - if _, err := tx.Exec(t.Ctx, "SELECT spock.repair_mode(true)"); err != nil { - return fmt.Errorf("enabling spock.repair_mode(true) on %s: %w", nodeName, err) +// Returns true if spock repair mode was activated, false otherwise. 
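+// Callers use the returned flag to decide whether repair mode must be
+// switched off again via disableSpockRepairMode before commit.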
+func (t *TableRepairTask) setupTransactionMode(tx pgx.Tx, nodeName string) (bool, error) { + spockRepairModeActive := false + if t.spockAvailable { + if _, err := tx.Exec(t.Ctx, "SELECT spock.repair_mode(true)"); err != nil { + return false, fmt.Errorf("enabling spock.repair_mode(true) on %s: %w", nodeName, err) + } + logger.Debug("spock.repair_mode(true) set on %s", nodeName) + spockRepairModeActive = true } - logger.Debug("spock.repair_mode(true) set on %s", nodeName) var err error if t.FireTriggers { @@ -175,13 +199,27 @@ func (t *TableRepairTask) setupTransactionMode(tx pgx.Tx, nodeName string) error _, err = tx.Exec(t.Ctx, "SET session_replication_role = 'replica'") } if err != nil { - return fmt.Errorf("setting session_replication_role on %s: %w", nodeName, err) + return false, fmt.Errorf("setting session_replication_role on %s: %w", nodeName, err) } logger.Debug("session_replication_role set on %s (fire_triggers: %v)", nodeName, t.FireTriggers) if err := t.setRole(tx, nodeName); err != nil { - return err + return false, err } + return spockRepairModeActive, nil +} + +// disableSpockRepairMode disables spock repair mode on the given transaction. +// When spock is not available, this is a no-op. +func (t *TableRepairTask) disableSpockRepairMode(tx pgx.Tx, nodeName string) error { + if !t.spockAvailable { + return nil + } + _, err := tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)") + if err != nil { + return fmt.Errorf("disabling spock.repair_mode(false) on %s: %w", nodeName, err) + } + logger.Debug("spock.repair_mode(false) set on %s", nodeName) return nil } @@ -715,6 +753,11 @@ func (t *TableRepairTask) Run(skipValidation bool) (err error) { } }() + // Detect whether spock extension is available before any repair path + if !t.DryRun { + t.detectSpock() + } + if t.FixNulls { return t.runFixNulls(startTime) } @@ -810,13 +853,13 @@ func (t *TableRepairTask) runFixNulls(startTime time.Time) error { continue } - if err := t.setupTransactionMode(tx, nodeName); err != nil { + spockRepairModeActive, err := t.setupTransactionMode(tx, nodeName) + if err != nil { tx.Rollback(t.Ctx) logger.Error("%v", err) repairErrors = append(repairErrors, err.Error()) continue } - spockRepairModeActive := true colTypes, _, err := t.getColTypesForNode(nodeName) if err != nil { @@ -874,11 +917,10 @@ func (t *TableRepairTask) runFixNulls(startTime time.Time) error { // Commit the initial tx (repair_mode setup) first, then use // per-batch-key transactions for origin preservation. 
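+			// Each per-batch transaction re-enables repair mode itself via
+			// setupTransactionMode, so disabling it here is safe.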
if spockRepairModeActive { - _, err = tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)") - if err != nil { + if err := t.disableSpockRepairMode(tx, nodeName); err != nil { tx.Rollback(t.Ctx) - logger.Error("disabling spock.repair_mode(false) on %s: %v", nodeName, err) - repairErrors = append(repairErrors, fmt.Sprintf("spock.repair_mode(false) failed for %s: %v", nodeName, err)) + logger.Error("%v", err) + repairErrors = append(repairErrors, err.Error()) continue } } @@ -929,11 +971,10 @@ func (t *TableRepairTask) runFixNulls(startTime time.Time) error { if !nodeFailed { if spockRepairModeActive { - _, err = tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)") - if err != nil { + if err := t.disableSpockRepairMode(tx, nodeName); err != nil { tx.Rollback(t.Ctx) - logger.Error("disabling spock.repair_mode(false) on %s: %v", nodeName, err) - repairErrors = append(repairErrors, fmt.Sprintf("spock.repair_mode(false) failed for %s: %v", nodeName, err)) + logger.Error("%v", err) + repairErrors = append(repairErrors, err.Error()) nodeFailed = true } } @@ -1349,7 +1390,7 @@ func (t *TableRepairTask) executePreserveOriginFixNulls(pool *pgxpool.Pool, node return totalUpdated, fmt.Errorf("failed to begin batch transaction for origin %s: %w", batchKey.nodeOrigin, err) } - if err := t.setupTransactionMode(batchTx, nodeName); err != nil { + if _, err := t.setupTransactionMode(batchTx, nodeName); err != nil { batchTx.Rollback(t.Ctx) return totalUpdated, fmt.Errorf("failed to setup transaction mode for origin batch: %w", err) } @@ -1393,7 +1434,7 @@ func (t *TableRepairTask) executePreserveOriginFixNulls(pool *pgxpool.Pool, node } } - if _, err := batchTx.Exec(t.Ctx, "SELECT spock.repair_mode(false)"); err != nil { + if err := t.disableSpockRepairMode(batchTx, nodeName); err != nil { batchTx.Rollback(t.Ctx) return totalUpdated, fmt.Errorf("failed to disable repair mode for origin batch: %w", err) } @@ -1574,13 +1615,13 @@ func (t *TableRepairTask) runUnidirectionalRepair(startTime time.Time) error { continue } - if err := t.setupTransactionMode(tx, nodeName); err != nil { + spockRepairModeActive, err := t.setupTransactionMode(tx, nodeName) + if err != nil { tx.Rollback(t.Ctx) logger.Error("%v", err) repairErrors = append(repairErrors, err.Error()) continue } - spockRepairModeActive := true // TODO: DROP PRIVILEGES HERE! @@ -1634,11 +1675,10 @@ func (t *TableRepairTask) runUnidirectionalRepair(startTime time.Time) error { // separate per-batch-key transactions so each batch gets its own // pg_replication_origin_xact_setup (which is per-transaction). if spockRepairModeActive { - _, err = tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)") - if err != nil { + if err := t.disableSpockRepairMode(tx, nodeName); err != nil { tx.Rollback(t.Ctx) - logger.Error("disabling spock.repair_mode(false) on %s: %v", nodeName, err) - repairErrors = append(repairErrors, fmt.Sprintf("spock.repair_mode(false) failed for %s: %v", nodeName, err)) + logger.Error("%v", err) + repairErrors = append(repairErrors, err.Error()) continue } } @@ -1691,15 +1731,22 @@ func (t *TableRepairTask) runUnidirectionalRepair(startTime time.Time) error { if spockRepairModeActive { // TODO: Need to elevate privileges here, but might be difficult // with pgx transactions and connection pooling. 
- _, err = tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)") - if err != nil { + if err := t.disableSpockRepairMode(tx, nodeName); err != nil { tx.Rollback(t.Ctx) - logger.Error("disabling spock.repair_mode(false) on %s: %v", nodeName, err) - repairErrors = append(repairErrors, fmt.Sprintf("spock.repair_mode(false) failed for %s: %v", nodeName, err)) + logger.Error("%v", err) + repairErrors = append(repairErrors, err.Error()) continue } - logger.Debug("spock.repair_mode(false) set on %s", nodeName) + err = tx.Commit(t.Ctx) + if err != nil { + logger.Error("committing transaction on node %s: %v", nodeName, err) + repairErrors = append(repairErrors, fmt.Sprintf("commit failed for %s: %v", nodeName, err)) + continue + } + logger.Debug("Transaction committed successfully on %s", nodeName) + } else if !t.PreserveOrigin || len(t.extractOriginInfoForNode(nodeName, fullUpserts[nodeName])) == 0 { + // Non-spock path: still need to commit the transaction err = tx.Commit(t.Ctx) if err != nil { logger.Error("committing transaction on node %s: %v", nodeName, err) @@ -1994,7 +2041,7 @@ func (t *TableRepairTask) performBirectionalInserts(nodeName string, inserts map } defer tx.Rollback(t.Ctx) - if err := t.setupTransactionMode(tx, nodeName); err != nil { + if _, err := t.setupTransactionMode(tx, nodeName); err != nil { return 0, err } @@ -2004,9 +2051,8 @@ func (t *TableRepairTask) performBirectionalInserts(nodeName string, inserts map } logger.Info("Executed %d insert operations on %s", insertedCount, nodeName) - _, err = tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)") - if err != nil { - return 0, fmt.Errorf("failed to disable spock.repair_mode(false) on %s: %w", nodeName, err) + if err := t.disableSpockRepairMode(tx, nodeName); err != nil { + return 0, err } if err := tx.Commit(t.Ctx); err != nil { @@ -2570,7 +2616,7 @@ func (t *TableRepairTask) executePreserveOriginUpserts(pool *pgxpool.Pool, nodeN return totalUpsertedCount, fmt.Errorf("failed to begin batch transaction for origin %s: %w", batchKey.nodeOrigin, err) } - if err := t.setupTransactionMode(batchTx, nodeName); err != nil { + if _, err := t.setupTransactionMode(batchTx, nodeName); err != nil { batchTx.Rollback(t.Ctx) return totalUpsertedCount, fmt.Errorf("failed to setup transaction mode for origin batch: %w", err) } @@ -2606,7 +2652,7 @@ func (t *TableRepairTask) executePreserveOriginUpserts(pool *pgxpool.Pool, nodeN } // Disable repair mode before commit - if _, err := batchTx.Exec(t.Ctx, "SELECT spock.repair_mode(false)"); err != nil { + if err := t.disableSpockRepairMode(batchTx, nodeName); err != nil { batchTx.Rollback(t.Ctx) return totalUpsertedCount, fmt.Errorf("failed to disable repair mode for origin batch: %w", err) } @@ -2873,8 +2919,17 @@ func calculateRepairSetsWithSourceOfTruth(task *TableRepairTask) (map[string]map func (t *TableRepairTask) fetchLSNsForNode(pool *pgxpool.Pool, failedNode, survivor string) (originLSN *uint64, slotLSN *uint64, err error) { var originStr *string - originStr, err = queries.GetSpockOriginLSNForNode(t.Ctx, pool, failedNode) + if t.spockAvailable { + originStr, err = queries.GetSpockOriginLSNForNode(t.Ctx, pool, failedNode) + } else { + originStr, err = queries.GetNativeOriginLSNForNode(t.Ctx, pool, failedNode) + } if err != nil { + if !t.spockAvailable { + // Native PG queries may fail if subscription naming doesn't match; treat as no data + logger.Warn("failed to fetch native origin lsn on %s: %v", survivor, err) + return nil, nil, nil + } return nil, nil, fmt.Errorf("failed to fetch origin lsn 
on %s: %w", survivor, err) } if originStr != nil { @@ -2885,8 +2940,17 @@ func (t *TableRepairTask) fetchLSNsForNode(pool *pgxpool.Pool, failedNode, survi } var slotStr *string - slotStr, err = queries.GetSpockSlotLSNForNode(t.Ctx, pool, failedNode) + if t.spockAvailable { + slotStr, err = queries.GetSpockSlotLSNForNode(t.Ctx, pool, failedNode) + } else { + slotStr, err = queries.GetNativeSlotLSNForNode(t.Ctx, pool, failedNode) + } if err != nil { + if !t.spockAvailable { + // Native PG queries may fail if subscription naming doesn't match; treat as no data + logger.Warn("failed to fetch native slot lsn on %s: %v", survivor, err) + return originLSN, nil, nil + } return originLSN, nil, fmt.Errorf("failed to fetch slot lsn on %s: %w", survivor, err) } if slotStr != nil { From 16a1c5c0fc05700c9059028c0e8b5ea077cfa677 Mon Sep 17 00:00:00 2001 From: Mason Sharp Date: Thu, 12 Mar 2026 14:54:36 -0700 Subject: [PATCH 02/14] Add testEnv abstraction to run integration tests on both spock and native PG Introduce a testEnv struct that encapsulates environment-specific state (pools, service names, cluster config, HasSpock flag), allowing the same test logic to run against both spock-replicated and vanilla PostgreSQL clusters with zero code duplication. Refactor 17 test functions (9 repair, 8 diff) to accept *testEnv. Add docker-compose-native.yaml with two postgres:17 containers and a TestNativePG suite that exercises all applicable shared tests (40 subtests) against standalone PG nodes. Co-Authored-By: Claude Opus 4.6 --- tests/integration/docker-compose-native.yaml | 45 ++ tests/integration/helpers_test.go | 48 +- tests/integration/main_test.go | 16 +- tests/integration/native_pg_test.go | 322 ++++++++++++ tests/integration/table_diff_test.go | 410 +++++++-------- tests/integration/table_repair_test.go | 511 ++++++++----------- tests/integration/test_env_test.go | 372 ++++++++++++++ 7 files changed, 1119 insertions(+), 605 deletions(-) create mode 100644 tests/integration/docker-compose-native.yaml create mode 100644 tests/integration/native_pg_test.go create mode 100644 tests/integration/test_env_test.go diff --git a/tests/integration/docker-compose-native.yaml b/tests/integration/docker-compose-native.yaml new file mode 100644 index 0000000..030955f --- /dev/null +++ b/tests/integration/docker-compose-native.yaml @@ -0,0 +1,45 @@ +############################################################################# +# +# ACE - Active Consistency Engine +# +# Copyright (C) 2023 - 2026, pgEdge (https://www.pgedge.com/) +# +# This software is released under the PostgreSQL License: +# https://opensource.org/license/postgresql +# +############################################################################# + +# Minimal docker-compose with vanilla PostgreSQL images (no spock extension). +# Used by TestNativePG_* tests to verify that table-diff and table-repair +# work without the spock extension. 
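+#
+# Note: track_commit_timestamp=on is required because the diff queries call
+# pg_xact_commit_timestamp()/pg_xact_commit_timestamp_origin(), and
+# wal_level=logical allows native logical replication (slots/subscriptions)
+# to be exercised against these nodes.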
+services: + native-n1: + image: postgres:17 + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: password + POSTGRES_DB: testdb + command: + - "postgres" + - "-c" + - "track_commit_timestamp=on" + - "-c" + - "wal_level=logical" + ports: + - target: 5432 + published: 7432 + native-n2: + image: postgres:17 + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: password + POSTGRES_DB: testdb + command: + - "postgres" + - "-c" + - "track_commit_timestamp=on" + - "-c" + - "wal_level=logical" + ports: + - target: 5432 + published: 7433 diff --git a/tests/integration/helpers_test.go b/tests/integration/helpers_test.go index ebef9e6..0d189d3 100644 --- a/tests/integration/helpers_test.go +++ b/tests/integration/helpers_test.go @@ -18,8 +18,6 @@ import ( "io" "log" "os" - "path/filepath" - "sort" "strings" "testing" "time" @@ -285,54 +283,12 @@ func loadDataFromCSV( } func newTestTableRepairTask(sourceOfTruthNode, qualifiedTableName, diffFilePath string) *repair.TableRepairTask { - task := repair.NewTableRepairTask() - task.ClusterName = "test_cluster" - task.DBName = dbName - task.SourceOfTruth = sourceOfTruthNode - task.QualifiedTableName = qualifiedTableName - task.DiffFilePath = diffFilePath - task.Nodes = "all" - return task + return newSpockEnv().newTableRepairTask(sourceOfTruthNode, qualifiedTableName, diffFilePath) } func repairTable(t *testing.T, qualifiedTableName, sourceOfTruthNode string) { t.Helper() - - files, err := filepath.Glob("*_diffs-*.json") - if err != nil { - t.Fatalf("Failed to find diff files: %v", err) - } - if len(files) == 0 { - log.Println("No diff file found to repair from, skipping repair.") - return - } - - sort.Slice(files, func(i, j int) bool { - fi, errI := os.Stat(files[i]) - if errI != nil { - t.Logf("Warning: could not stat file %s: %v", files[i], errI) - return false - } - fj, errJ := os.Stat(files[j]) - if errJ != nil { - t.Logf("Warning: could not stat file %s: %v", files[j], errJ) - return false - } - return fi.ModTime().After(fj.ModTime()) - }) - - latestDiffFile := files[0] - log.Printf("Using latest diff file for repair: %s", latestDiffFile) - - repairTask := newTestTableRepairTask(sourceOfTruthNode, qualifiedTableName, latestDiffFile) - - time.Sleep(2 * time.Second) - - if err := repairTask.Run(false); err != nil { - t.Fatalf("Failed to repair table: %v", err) - } - - log.Printf("Table '%s' repaired successfully using %s as source of truth.", qualifiedTableName, sourceOfTruthNode) + newSpockEnv().repairTable(t, qualifiedTableName, sourceOfTruthNode) } // getCommitTimestamp retrieves the commit timestamp for a specific row diff --git a/tests/integration/main_test.go b/tests/integration/main_test.go index 98f5175..e6374cf 100644 --- a/tests/integration/main_test.go +++ b/tests/integration/main_test.go @@ -27,7 +27,6 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/pkg/config" "github.com/pgedge/ace/pkg/types" - "github.com/stretchr/testify/require" tcLog "github.com/testcontainers/testcontainers-go/log" "github.com/testcontainers/testcontainers-go/modules/compose" "github.com/testcontainers/testcontainers-go/wait" @@ -476,18 +475,5 @@ func setupSharedCustomersTable(tableName string) error { // after other tests may have modified the table. 
func resetSharedTable(t *testing.T, tableName string) { t.Helper() - ctx := context.Background() - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - csvPath, err := filepath.Abs(defaultCsvFilePath + tableName + ".csv") - require.NoError(t, err) - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) - _, err := pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err, "enable repair_mode on %s", nodeName) - _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) - require.NoError(t, err, "truncate %s on %s", qualifiedTableName, nodeName) - _, err = pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err, "disable repair_mode on %s", nodeName) - require.NoError(t, loadDataFromCSV(ctx, pool, testSchema, tableName, csvPath), "load CSV into %s on %s", qualifiedTableName, nodeName) - } + newSpockEnv().resetSharedTable(t, tableName) } diff --git a/tests/integration/native_pg_test.go b/tests/integration/native_pg_test.go new file mode 100644 index 0000000..9ea66f7 --- /dev/null +++ b/tests/integration/native_pg_test.go @@ -0,0 +1,322 @@ +// /////////////////////////////////////////////////////////////////////////// +// +// # ACE - Active Consistency Engine +// +// Copyright (C) 2023 - 2026, pgEdge (https://www.pgedge.com/) +// +// This software is released under the PostgreSQL License: +// https://opensource.org/license/postgresql +// +// /////////////////////////////////////////////////////////////////////////// + +package integration + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/docker/go-connections/nat" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/pgedge/ace/db/queries" + "github.com/pgedge/ace/internal/consistency/diff" + "github.com/pgedge/ace/pkg/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go/modules/compose" + "github.com/testcontainers/testcontainers-go/wait" +) + +const ( + nativeUser = "postgres" + nativePassword = "password" + nativeDBName = "testdb" + nativeServiceN1 = "native-n1" + nativeServiceN2 = "native-n2" + nativeContainerPort = "5432/tcp" + nativeComposeFile = "docker-compose-native.yaml" + nativeClusterName = "native_test_cluster" +) + +// nativeClusterState holds connections and config for vanilla PG test containers. +type nativeClusterState struct { + stack compose.ComposeStack + n1Host string + n1Port string + n1Pool *pgxpool.Pool + n2Host string + n2Port string + n2Pool *pgxpool.Pool +} + +func setupNativeCluster(t *testing.T) *nativeClusterState { + t.Helper() + ctx := context.Background() + + absCompose, err := filepath.Abs(nativeComposeFile) + require.NoError(t, err, "resolve native compose file path") + + identifier := strings.ToLower(fmt.Sprintf("ace_native_test_%d", time.Now().UnixNano())) + + waitN1 := wait.ForListeningPort(nat.Port(nativeContainerPort)). + WithStartupTimeout(startupTimeout). + WithPollInterval(5 * time.Second) + waitN2 := wait.ForListeningPort(nat.Port(nativeContainerPort)). + WithStartupTimeout(startupTimeout). + WithPollInterval(5 * time.Second) + + stack, err := compose.NewDockerComposeWith( + compose.StackIdentifier(identifier), + compose.WithStackFiles(absCompose), + ) + require.NoError(t, err, "create native compose stack") + + execErr := stack. + WaitForService(nativeServiceN1, waitN1). 
+ WaitForService(nativeServiceN2, waitN2). + Up(ctx, compose.Wait(true)) + require.NoError(t, execErr, "start native compose stack") + + state := &nativeClusterState{stack: stack} + + // Get mapped host/port for n1 + n1Container, err := stack.ServiceContainer(ctx, nativeServiceN1) + require.NoError(t, err, "get native-n1 container") + n1Host, err := n1Container.Host(ctx) + require.NoError(t, err, "get native-n1 host") + cPort, err := nat.NewPort("tcp", "5432") + require.NoError(t, err) + n1MappedPort, err := n1Container.MappedPort(ctx, cPort) + require.NoError(t, err, "get native-n1 mapped port") + state.n1Host = n1Host + state.n1Port = n1MappedPort.Port() + + // Get mapped host/port for n2 + n2Container, err := stack.ServiceContainer(ctx, nativeServiceN2) + require.NoError(t, err, "get native-n2 container") + n2Host, err := n2Container.Host(ctx) + require.NoError(t, err, "get native-n2 host") + n2MappedPort, err := n2Container.MappedPort(ctx, cPort) + require.NoError(t, err, "get native-n2 mapped port") + state.n2Host = n2Host + state.n2Port = n2MappedPort.Port() + + // Connect + state.n1Pool, err = connectToNode(state.n1Host, state.n1Port, nativeUser, nativePassword, nativeDBName) + require.NoError(t, err, "connect to native-n1") + state.n2Pool, err = connectToNode(state.n2Host, state.n2Port, nativeUser, nativePassword, nativeDBName) + require.NoError(t, err, "connect to native-n2") + + // Create pgcrypto extension on both nodes + for _, pool := range []*pgxpool.Pool{state.n1Pool, state.n2Pool} { + _, err = pool.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS pgcrypto") + require.NoError(t, err, "create pgcrypto extension") + } + + log.Printf("Native PG cluster ready: n1=%s:%s, n2=%s:%s", state.n1Host, state.n1Port, state.n2Host, state.n2Port) + return state +} + +func (s *nativeClusterState) teardown(t *testing.T) { + t.Helper() + if s.n1Pool != nil { + s.n1Pool.Close() + } + if s.n2Pool != nil { + s.n2Pool.Close() + } + if s.stack != nil { + execErr := s.stack.Down( + context.Background(), + compose.RemoveOrphans(true), + compose.RemoveVolumes(true), + ) + if execErr != nil { + t.Logf("Failed to tear down native compose stack: %v", execErr) + } + } + // Clean up diff files and cluster config + files, _ := filepath.Glob("*_diffs-*.json") + for _, f := range files { + os.Remove(f) + } + os.Remove(nativeClusterName + ".json") +} + +// writeClusterConfig writes a cluster config JSON that the diff/repair +// tasks use to discover nodes and credentials. 
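+// The config is written to the working directory as <nativeClusterName>.json
+// and removed again in teardown().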
+func (s *nativeClusterState) writeClusterConfig(t *testing.T) { + t.Helper() + cfg := types.ClusterConfig{ + JSONVersion: "1.0", + ClusterName: nativeClusterName, + LogLevel: "info", + UpdateDate: time.Now().Format(time.RFC3339), + PGEdge: struct { + PGVersion int `json:"pg_version"` + AutoStart string `json:"auto_start"` + Spock types.SpockConfig `json:"spock"` + Databases []types.Database `json:"databases"` + }{ + PGVersion: 17, + AutoStart: "yes", + Databases: []types.Database{ + { + DBName: nativeDBName, + DBUser: nativeUser, + DBPassword: nativePassword, + }, + }, + }, + NodeGroups: []types.NodeGroup{ + { + Name: nativeServiceN1, + IsActive: "yes", + PublicIP: s.n1Host, + Port: s.n1Port, + }, + { + Name: nativeServiceN2, + IsActive: "yes", + PublicIP: s.n2Host, + Port: s.n2Port, + }, + }, + } + + data, err := json.MarshalIndent(cfg, "", " ") + require.NoError(t, err, "marshal native cluster config") + require.NoError(t, os.WriteFile(nativeClusterName+".json", data, 0644), "write native cluster config") +} + +// TestNativePG runs the full suite of native PostgreSQL (no spock) tests. +// It starts its own vanilla postgres:17 Docker Compose stack and runs the +// same shared test logic used by the spock tests, via testEnv abstraction. +func TestNativePG(t *testing.T) { + state := setupNativeCluster(t) + t.Cleanup(func() { state.teardown(t) }) + state.writeClusterConfig(t) + + ctx := context.Background() + + // Create the customers table on both nodes and load initial data from CSV. + for _, pool := range []*pgxpool.Pool{state.n1Pool, state.n2Pool} { + require.NoError(t, createTestTable(ctx, pool, testSchema, "customers")) + } + + env := newNativeEnv(state) + + // Load CSV data so that shared tests (e.g. NoDifferences, DataOnlyOnNode1) + // have a known baseline. 
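+	// The two vanilla nodes do not replicate to each other, so each pool is
+	// loaded independently with identical rows.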
+ csvPath, err := filepath.Abs(defaultCsvFilePath + "customers.csv") + require.NoError(t, err) + for _, pool := range env.pools() { + require.NoError(t, loadDataFromCSV(ctx, pool, testSchema, "customers", csvPath), + "load CSV into customers") + } + + // ── Native-specific tests ───────────────────────────────────────────── + + t.Run("CheckSpockInstalled_ReturnsFalse", func(t *testing.T) { + installed, err := queries.CheckSpockInstalled(ctx, state.n1Pool) + require.NoError(t, err) + assert.False(t, installed, "spock should not be installed on vanilla PG") + }) + + t.Run("SpockDiff_GracefulError", func(t *testing.T) { + task := diff.NewSpockDiffTask() + task.ClusterName = nativeClusterName + task.DBName = nativeDBName + task.Nodes = nativeServiceN1 + "," + nativeServiceN2 + task.Ctx = context.Background() + task.SkipDBUpdate = true + + err := task.RunChecks(false) + require.NoError(t, err, "spock-diff RunChecks should succeed (it just validates and connects)") + + err = task.ExecuteTask() + require.Error(t, err, "spock-diff should fail on vanilla PG") + assert.Contains(t, err.Error(), "spock extension", "error should mention spock extension") + }) + + t.Run("RepsetDiff_GracefulError", func(t *testing.T) { + task := diff.NewRepsetDiffTask() + task.ClusterName = nativeClusterName + task.DBName = nativeDBName + task.RepsetName = "default" + task.Nodes = "all" + task.SkipDBUpdate = true + + err := task.RunChecks(false) + require.Error(t, err, "repset-diff should fail on vanilla PG") + assert.Contains(t, err.Error(), "spock extension", "error should mention spock extension") + }) + + // ── Shared table-diff tests (simple PK) ─────────────────────────────── + + t.Run("TableDiffSimplePK", func(t *testing.T) { + t.Run("Customers", func(t *testing.T) { + runCustomerTableDiffTests(t, env) + }) + t.Run("MixedCaseIdentifiers", func(t *testing.T) { + testTableDiff_MixedCaseIdentifiers(t, env, false) + }) + }) + + // ── Shared table-diff tests (composite PK) ─────────────────────────── + + t.Run("TableDiffCompositePK", func(t *testing.T) { + t.Run("Customers", func(t *testing.T) { + for _, pool := range env.pools() { + err := alterTableToCompositeKey(ctx, pool, env.Schema, "customers") + require.NoError(t, err) + } + t.Cleanup(func() { + for _, pool := range env.pools() { + err := revertTableToSimpleKey(ctx, pool, env.Schema, "customers") + require.NoError(t, err) + } + }) + runCustomerTableDiffTests(t, env) + }) + t.Run("MixedCaseIdentifiers", func(t *testing.T) { + testTableDiff_MixedCaseIdentifiers(t, env, true) + }) + }) + + // ── Shared table-repair tests ───────────────────────────────────────── + + t.Run("TableRepair_UnidirectionalDefault", func(t *testing.T) { + testTableRepair_UnidirectionalDefault(t, env) + }) + t.Run("TableRepair_InsertOnly", func(t *testing.T) { + testTableRepair_InsertOnly(t, env) + }) + t.Run("TableRepair_UpsertOnly", func(t *testing.T) { + testTableRepair_UpsertOnly(t, env) + }) + t.Run("TableRepair_Bidirectional", func(t *testing.T) { + testTableRepair_Bidirectional(t, env) + }) + t.Run("TableRepair_DryRun", func(t *testing.T) { + testTableRepair_DryRun(t, env) + }) + t.Run("TableRepair_GenerateReport", func(t *testing.T) { + testTableRepair_GenerateReport(t, env) + }) + t.Run("TableRepair_FixNulls", func(t *testing.T) { + testTableRepair_FixNulls(t, env) + }) + t.Run("TableRepair_FixNulls_DryRun", func(t *testing.T) { + testTableRepair_FixNulls_DryRun(t, env) + }) + t.Run("TableRepair_FixNulls_BidirectionalUpdate", func(t *testing.T) { + 
testTableRepair_FixNulls_BidirectionalUpdate(t, env) + }) +} diff --git a/tests/integration/table_diff_test.go b/tests/integration/table_diff_test.go index b0b1a82..023b3fc 100644 --- a/tests/integration/table_diff_test.go +++ b/tests/integration/table_diff_test.go @@ -15,7 +15,6 @@ import ( "context" "fmt" "log" - "math" "os" "path/filepath" "strings" @@ -24,7 +23,6 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/internal/consistency/diff" - "github.com/pgedge/ace/pkg/types" "github.com/stretchr/testify/require" ) @@ -33,29 +31,7 @@ func newTestTableDiffTask( qualifiedTableName string, nodes []string, ) *diff.TableDiffTask { - task := diff.NewTableDiffTask() - task.ClusterName = "test_cluster" - task.DBName = dbName - task.QualifiedTableName = qualifiedTableName - task.Nodes = strings.Join(nodes, ",") - task.Output = "json" - task.BlockSize = 1000 - task.CompareUnitSize = 100 - task.ConcurrencyFactor = 1 - task.MaxDiffRows = math.MaxInt64 - - task.DiffResult = types.DiffOutput{ - NodeDiffs: make(map[string]types.DiffByNodePair), - Summary: types.DiffSummary{ - Nodes: nodes, - BlockSize: task.BlockSize, - CompareUnitSize: task.CompareUnitSize, - ConcurrencyFactor: task.ConcurrencyFactor, - DiffRowsCount: make(map[string]int), - }, - } - - return task + return newSpockEnv().newTableDiffTask(t, qualifiedTableName, nodes) } func TestTableDiffUntilFilter(t *testing.T) { @@ -127,11 +103,12 @@ func TestTableDiffUntilFilter(t *testing.T) { } func TestTableDiffSimplePK(t *testing.T) { + env := newSpockEnv() t.Run("Customers", func(t *testing.T) { - runCustomerTableDiffTests(t) + runCustomerTableDiffTests(t, env) }) t.Run("MixedCaseIdentifiers", func(t *testing.T) { - testTableDiff_MixedCaseIdentifiers(t, false) + testTableDiff_MixedCaseIdentifiers(t, env, false) }) t.Run("VariousDataTypes", func(t *testing.T) { testTableDiff_VariousDataTypes(t, false) @@ -148,23 +125,24 @@ func TestTableDiffSimplePK(t *testing.T) { } func TestTableDiffCompositePK(t *testing.T) { + env := newSpockEnv() t.Run("Customers", func(t *testing.T) { ctx := context.Background() tableName := "customers" - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, err) } t.Cleanup(func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, err) } }) - runCustomerTableDiffTests(t) + runCustomerTableDiffTests(t, env) }) t.Run("MixedCaseIdentifiers", func(t *testing.T) { - testTableDiff_MixedCaseIdentifiers(t, true) + testTableDiff_MixedCaseIdentifiers(t, env, true) }) t.Run("VariousDataTypes", func(t *testing.T) { testTableDiff_VariousDataTypes(t, true) @@ -180,22 +158,22 @@ func TestTableDiffCompositePK(t *testing.T) { }) } -func runCustomerTableDiffTests(t *testing.T) { - t.Run("NoDifferences", testTableDiff_NoDifferences) - t.Run("DataOnlyOnNode1", testTableDiff_DataOnlyOnNode1) - t.Run("DataOnlyOnNode2", testTableDiff_DataOnlyOnNode2) - t.Run("ModifiedRows", testTableDiff_ModifiedRows) - t.Run("TableFiltering", testTableDiff_TableFiltering) - t.Run("TableFilterNoRows", testTableDiff_TableFilterNoRows) - 
t.Run("MaxDiffRowsLimit", testTableDiff_MaxDiffRowsLimit) +func runCustomerTableDiffTests(t *testing.T, env *testEnv) { + t.Run("NoDifferences", func(t *testing.T) { testTableDiff_NoDifferences(t, env) }) + t.Run("DataOnlyOnNode1", func(t *testing.T) { testTableDiff_DataOnlyOnNode1(t, env) }) + t.Run("DataOnlyOnNode2", func(t *testing.T) { testTableDiff_DataOnlyOnNode2(t, env) }) + t.Run("ModifiedRows", func(t *testing.T) { testTableDiff_ModifiedRows(t, env) }) + t.Run("TableFiltering", func(t *testing.T) { testTableDiff_TableFiltering(t, env) }) + t.Run("TableFilterNoRows", func(t *testing.T) { testTableDiff_TableFilterNoRows(t, env) }) + t.Run("MaxDiffRowsLimit", func(t *testing.T) { testTableDiff_MaxDiffRowsLimit(t, env) }) } -func testTableDiff_NoDifferences(t *testing.T) { - resetSharedTable(t, "customers") +func testTableDiff_NoDifferences(t *testing.T, env *testEnv) { + env.resetSharedTable(t, "customers") tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) err := tdTask.RunChecks(false) if err != nil { @@ -228,13 +206,13 @@ func testTableDiff_NoDifferences(t *testing.T) { log.Println("TestTableDiff_NoDifferences completed.") } -func testTableDiff_DataOnlyOnNode1(t *testing.T) { +func testTableDiff_DataOnlyOnNode1(t *testing.T, env *testEnv) { ctx := context.Background() tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -242,32 +220,26 @@ func testTableDiff_DataOnlyOnNode1(t *testing.T) { }) // Truncate the table on the second node to create the diff - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) if err != nil { - t.Fatalf("Failed to begin transaction on node %s: %v", serviceN2, err) + t.Fatalf("Failed to begin transaction on node %s: %v", env.ServiceN2, err) } defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - if err != nil { - t.Fatalf("Failed to enable spock repair mode on node %s: %v", serviceN2, err) - } - _, err = tx.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", testSchema, tableName)) - if err != nil { - t.Fatalf("Failed to truncate table %s on node2: %v", qualifiedTableName, err) - } - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - if err != nil { - t.Fatalf("Failed to disable spock repair mode on node %s: %v", serviceN2, err) - } + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", env.Schema, tableName)) + if err != nil { + t.Fatalf("Failed to truncate table %s on node2: %v", qualifiedTableName, err) + } + }) if err = tx.Commit(ctx); err != nil { - t.Fatalf("Failed to commit transaction on node %s: %v", serviceN2, err) + t.Fatalf("Failed to commit transaction on node %s: %v", env.ServiceN2, err) } - log.Printf("Data loaded only into %s for table %s", serviceN1, qualifiedTableName) + log.Printf("Data loaded only into %s for table %s", env.ServiceN1, 
qualifiedTableName) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) err = tdTask.RunChecks(false) if err != nil { @@ -278,10 +250,7 @@ func testTableDiff_DataOnlyOnNode1(t *testing.T) { t.Fatalf("ExecuteTask failed: %v", err) } - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } + pairKey := env.pairKey() nodeDiffs, ok := tdTask.DiffResult.NodeDiffs[pairKey] if !ok { @@ -293,26 +262,26 @@ func testTableDiff_DataOnlyOnNode1(t *testing.T) { } var expectedDiffCount int - err = pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s", qualifiedTableName)).Scan(&expectedDiffCount) + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s", qualifiedTableName)).Scan(&expectedDiffCount) if err != nil { t.Fatalf("Failed to count rows in %s on node1: %v", qualifiedTableName, err) } - node1OnlyRows := nodeDiffs.Rows[serviceN1] + node1OnlyRows := nodeDiffs.Rows[env.ServiceN1] if len(node1OnlyRows) != expectedDiffCount { t.Errorf( "Expected %d rows only on %s, but got %d", expectedDiffCount, - serviceN1, + env.ServiceN1, len(node1OnlyRows), ) } - node2OnlyRows := nodeDiffs.Rows[serviceN2] + node2OnlyRows := nodeDiffs.Rows[env.ServiceN2] if len(node2OnlyRows) != 0 { t.Errorf( "Expected 0 rows only on %s, but got %d", - serviceN2, + env.ServiceN2, len(node2OnlyRows), ) } @@ -329,45 +298,39 @@ func testTableDiff_DataOnlyOnNode1(t *testing.T) { log.Println("TestTableDiff_DataOnlyOnNode1 completed.") } -func testTableDiff_DataOnlyOnNode2(t *testing.T) { +func testTableDiff_DataOnlyOnNode2(t *testing.T, env *testEnv) { ctx := context.Background() tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN2) + env.repairTable(t, qualifiedTableName, env.ServiceN2) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) } }) - tx, err := pgCluster.Node1Pool.Begin(ctx) + tx, err := env.N1Pool.Begin(ctx) if err != nil { - t.Fatalf("Failed to begin transaction on node %s: %v", serviceN1, err) + t.Fatalf("Failed to begin transaction on node %s: %v", env.ServiceN1, err) } defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - if err != nil { - t.Fatalf("Failed to enable spock repair mode on node %s: %v", serviceN1, err) - } - _, err = tx.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", testSchema, tableName)) - if err != nil { - t.Fatalf("Failed to truncate table %s on node1: %v", qualifiedTableName, err) - } - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - if err != nil { - t.Fatalf("Failed to disable spock repair mode on node %s: %v", serviceN1, err) - } + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", env.Schema, tableName)) + if err != nil { + t.Fatalf("Failed to truncate table %s on node1: %v", qualifiedTableName, err) + } + }) if err = tx.Commit(ctx); err != nil { - t.Fatalf("Failed to commit transaction on node %s: %v", serviceN1, err) + t.Fatalf("Failed to commit transaction on node %s: %v", env.ServiceN1, err) } - log.Printf("Data loaded only into %s for table %s", serviceN2, qualifiedTableName) + 
log.Printf("Data loaded only into %s for table %s", env.ServiceN2, qualifiedTableName) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) err = tdTask.RunChecks(false) if err != nil { @@ -378,10 +341,7 @@ func testTableDiff_DataOnlyOnNode2(t *testing.T) { t.Fatalf("ExecuteTask failed: %v", err) } - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } + pairKey := env.pairKey() nodeDiffs, ok := tdTask.DiffResult.NodeDiffs[pairKey] if !ok { @@ -393,26 +353,26 @@ func testTableDiff_DataOnlyOnNode2(t *testing.T) { } var expectedDiffCount int - err = pgCluster.Node2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s", qualifiedTableName)).Scan(&expectedDiffCount) + err = env.N2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s", qualifiedTableName)).Scan(&expectedDiffCount) if err != nil { t.Fatalf("Failed to count rows in %s on node2: %v", qualifiedTableName, err) } - node1OnlyRows := nodeDiffs.Rows[serviceN1] + node1OnlyRows := nodeDiffs.Rows[env.ServiceN1] if len(node1OnlyRows) != 0 { t.Errorf( "Expected 0 rows only on %s, but got %d", - serviceN1, + env.ServiceN1, len(node1OnlyRows), ) } - node2OnlyRows := nodeDiffs.Rows[serviceN2] + node2OnlyRows := nodeDiffs.Rows[env.ServiceN2] if len(node2OnlyRows) != expectedDiffCount { t.Errorf( "Expected %d rows only on %s, but got %d", expectedDiffCount, - serviceN2, + env.ServiceN2, len(node2OnlyRows), ) } @@ -429,13 +389,13 @@ func testTableDiff_DataOnlyOnNode2(t *testing.T) { log.Println("TestTableDiff_DataOnlyOnNode2 completed.") } -func testTableDiff_ModifiedRows(t *testing.T) { +func testTableDiff_ModifiedRows(t *testing.T, env *testEnv) { ctx := context.Background() tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -459,50 +419,44 @@ func testTableDiff_ModifiedRows(t *testing.T) { }, } - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) if err != nil { - t.Fatalf("Failed to begin transaction on node %s: %v", serviceN2, err) + t.Fatalf("Failed to begin transaction on node %s: %v", env.ServiceN2, err) } defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - if err != nil { - t.Fatalf("Failed to enable spock repair mode on node %s: %v", serviceN2, err) - } - for _, mod := range modifications { - updateSQL := fmt.Sprintf( - "UPDATE %s.%s SET %s = $1 WHERE index = $2", - testSchema, - tableName, - mod.field, - ) - _, err := tx.Exec(ctx, updateSQL, mod.value, mod.indexVal) - if err != nil { - t.Fatalf( - "Failed to update row with index %d on node %s: %v", - mod.indexVal, - serviceN2, - err, + env.withRepairModeTx(t, ctx, tx, func() { + for _, mod := range modifications { + updateSQL := fmt.Sprintf( + "UPDATE %s.%s SET %s = $1 WHERE index = $2", + env.Schema, + tableName, + mod.field, ) + _, err := tx.Exec(ctx, updateSQL, mod.value, mod.indexVal) + if err != nil { + t.Fatalf( + "Failed to update row with index %d on node %s: %v", + mod.indexVal, + env.ServiceN2, + err, + ) + } } - } - _, err = 
tx.Exec(ctx, "SELECT spock.repair_mode(false)") - if err != nil { - t.Fatalf("Failed to disable spock repair mode on node %s: %v", serviceN2, err) - } + }) if err = tx.Commit(ctx); err != nil { - t.Fatalf("Failed to commit transaction on node %s: %v", serviceN2, err) + t.Fatalf("Failed to commit transaction on node %s: %v", env.ServiceN2, err) } log.Printf( "%d rows modified on %s for table %s", len(modifications), - serviceN2, + env.ServiceN2, qualifiedTableName, ) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) err = tdTask.RunChecks(false) if err != nil { @@ -513,10 +467,7 @@ func testTableDiff_ModifiedRows(t *testing.T) { t.Fatalf("ExecuteTask failed: %v", err) } - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } + pairKey := env.pairKey() nodeDiffs, ok := tdTask.DiffResult.NodeDiffs[pairKey] if !ok { @@ -527,26 +478,26 @@ func testTableDiff_ModifiedRows(t *testing.T) { ) } - if len(nodeDiffs.Rows[serviceN1]) != len(modifications) { + if len(nodeDiffs.Rows[env.ServiceN1]) != len(modifications) { t.Errorf( "Expected %d modified rows to be reported for %s (original values), but got %d. Rows: %+v", len( modifications, ), - serviceN1, - len(nodeDiffs.Rows[serviceN1]), - nodeDiffs.Rows[serviceN1], + env.ServiceN1, + len(nodeDiffs.Rows[env.ServiceN1]), + nodeDiffs.Rows[env.ServiceN1], ) } - if len(nodeDiffs.Rows[serviceN2]) != len(modifications) { + if len(nodeDiffs.Rows[env.ServiceN2]) != len(modifications) { t.Errorf( "Expected %d modified rows to be reported for %s (modified values), but got %d. 
Rows: %+v", len( modifications, ), - serviceN2, - len(nodeDiffs.Rows[serviceN2]), - nodeDiffs.Rows[serviceN2], + env.ServiceN2, + len(nodeDiffs.Rows[env.ServiceN2]), + nodeDiffs.Rows[env.ServiceN2], ) } @@ -563,7 +514,7 @@ func testTableDiff_ModifiedRows(t *testing.T) { for _, mod := range modifications { found := false - for _, rowN2 := range nodeDiffs.Rows[serviceN2] { + for _, rowN2 := range nodeDiffs.Rows[env.ServiceN2] { if indexVal, ok := rowN2.Get("index"); ok && indexVal == int32(mod.indexVal) { actualModifiedValue, _ := rowN2.Get(mod.field) @@ -584,7 +535,7 @@ func testTableDiff_ModifiedRows(t *testing.T) { t.Errorf( "Modified row with Index %d not found in diff results for node %s", mod.indexVal, - serviceN2, + env.ServiceN2, ) } } @@ -592,10 +543,10 @@ func testTableDiff_ModifiedRows(t *testing.T) { log.Println("TestTableDiff_ModifiedRows completed.") } -func testTableDiff_MixedCaseIdentifiers(t *testing.T, compositeKey bool) { +func testTableDiff_MixedCaseIdentifiers(t *testing.T, env *testEnv, compositeKey bool) { ctx := context.Background() tableName := "CustomersMixedCase" - qualifiedTableName := fmt.Sprintf("%s.\"%s\"", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.\"%s\"", env.Schema, tableName) compositeKeyPart := "" if compositeKey { @@ -610,25 +561,22 @@ func testTableDiff_MixedCaseIdentifiers(t *testing.T, compositeKey bool) { "LastName" VARCHAR(100), "EmailAddress" VARCHAR(100), PRIMARY KEY("ID"%s) -);`, testSchema, qualifiedTableName, compositeKeyPart) +);`, env.Schema, qualifiedTableName, compositeKeyPart) - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, createMixedCaseTableSQL) if err != nil { t.Fatalf( - "Failed to create mixed-case table %s on node %s: %v", + "Failed to create mixed-case table %s: %v", qualifiedTableName, - nodeName, err, ) } _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) if err != nil { t.Fatalf( - "Failed to truncate mixed-case table %s on node %s: %v", + "Failed to truncate mixed-case table %s: %v", qualifiedTableName, - nodeName, err, ) } @@ -636,7 +584,7 @@ func testTableDiff_MixedCaseIdentifiers(t *testing.T, compositeKey bool) { log.Printf("Mixed-case table %s created on both nodes", qualifiedTableName) t.Cleanup(func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) if err != nil { t.Logf("Failed to drop test table %s: %v", qualifiedTableName, err) @@ -668,30 +616,30 @@ func testTableDiff_MixedCaseIdentifiers(t *testing.T, compositeKey bool) { qualifiedTableName, ) for _, row := range commonRows { - _, err := pgCluster.Node1Pool.Exec(ctx, insertSQL, + _, err := env.N1Pool.Exec(ctx, insertSQL, row["ID"], row["FirstName"], row["LastName"], row["EmailAddress"]) if err != nil { t.Fatalf( "Failed to insert data into mixed-case table %s on node %s: %v", qualifiedTableName, - serviceN1, err) + env.ServiceN1, err) } - _, err = pgCluster.Node2Pool.Exec(ctx, insertSQL, + _, err = env.N2Pool.Exec(ctx, insertSQL, row["ID"], row["FirstName"], row["LastName"], row["EmailAddress"]) if err != nil { t.Fatalf( "Failed to insert data into mixed-case table %s on node %s: %v", qualifiedTableName, - serviceN2, err) + env.ServiceN2, err) 
} } log.Printf("Data loaded into mixed-case table %s on both nodes", qualifiedTableName) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask( + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask( t, - fmt.Sprintf("%s.%s", testSchema, tableName), + fmt.Sprintf("%s.%s", env.Schema, tableName), nodesToCompare, ) @@ -1153,13 +1101,13 @@ CREATE TABLE IF NOT EXISTS %s.%s ( log.Println("TestTableDiff_UUIDColumn completed.") } -func testTableDiff_TableFiltering(t *testing.T) { +func testTableDiff_TableFiltering(t *testing.T, env *testEnv) { ctx := context.Background() tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -1188,45 +1136,39 @@ func testTableDiff_TableFiltering(t *testing.T) { }, } - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) if err != nil { - t.Fatalf("Failed to begin transaction on node %s: %v", serviceN2, err) + t.Fatalf("Failed to begin transaction on node %s: %v", env.ServiceN2, err) } defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - if err != nil { - t.Fatalf("Failed to enable spock repair mode on node %s: %v", serviceN2, err) - } - for _, mod := range updatesNode2 { - updateSQL := fmt.Sprintf( - "UPDATE %s.%s SET %s = $1 WHERE index = $2", - testSchema, - tableName, - mod.field, - ) - _, err := tx.Exec(ctx, updateSQL, mod.value, mod.indexVal) - if err != nil { - t.Fatalf( - "Failed to update row with index %d on node %s for filter test: %v", - mod.indexVal, - serviceN2, - err, + env.withRepairModeTx(t, ctx, tx, func() { + for _, mod := range updatesNode2 { + updateSQL := fmt.Sprintf( + "UPDATE %s.%s SET %s = $1 WHERE index = $2", + env.Schema, + tableName, + mod.field, ) + _, err := tx.Exec(ctx, updateSQL, mod.value, mod.indexVal) + if err != nil { + t.Fatalf( + "Failed to update row with index %d on node %s for filter test: %v", + mod.indexVal, + env.ServiceN2, + err, + ) + } } - } - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - if err != nil { - t.Fatalf("Failed to disable spock repair mode on node %s: %v", serviceN2, err) - } + }) if err = tx.Commit(ctx); err != nil { - t.Fatalf("Failed to commit transaction on node %s: %v", serviceN2, err) + t.Fatalf("Failed to commit transaction on node %s: %v", env.ServiceN2, err) } - log.Printf("Data modified on %s for filter test", serviceN2) + log.Printf("Data modified on %s for filter test", env.ServiceN2) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) tdTask.TableFilter = "index <= 100" err = tdTask.RunChecks(false) @@ -1238,10 +1180,7 @@ func testTableDiff_TableFiltering(t *testing.T) { t.Fatalf("ExecuteTask failed for table filtering test: %v", err) } - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } + pairKey := env.pairKey() nodeDiffs, ok := tdTask.DiffResult.NodeDiffs[pairKey] if !ok { @@ -1253,22 +1192,22 @@ func testTableDiff_TableFiltering(t *testing.T) { } 
expectedFilteredModifications := 3 - if len(nodeDiffs.Rows[serviceN1]) != expectedFilteredModifications { + if len(nodeDiffs.Rows[env.ServiceN1]) != expectedFilteredModifications { t.Errorf( "Expected %d modified rows (original) for %s due to filter, but got %d. Rows: %+v", expectedFilteredModifications, - serviceN1, - len(nodeDiffs.Rows[serviceN1]), - nodeDiffs.Rows[serviceN1], + env.ServiceN1, + len(nodeDiffs.Rows[env.ServiceN1]), + nodeDiffs.Rows[env.ServiceN1], ) } - if len(nodeDiffs.Rows[serviceN2]) != expectedFilteredModifications { + if len(nodeDiffs.Rows[env.ServiceN2]) != expectedFilteredModifications { t.Errorf( "Expected %d modified rows (updated) for %s due to filter, but got %d. Rows: %+v", expectedFilteredModifications, - serviceN2, - len(nodeDiffs.Rows[serviceN2]), - nodeDiffs.Rows[serviceN2], + env.ServiceN2, + len(nodeDiffs.Rows[env.ServiceN2]), + nodeDiffs.Rows[env.ServiceN2], ) } if tdTask.DiffResult.Summary.DiffRowsCount[pairKey] != expectedFilteredModifications { @@ -1284,7 +1223,7 @@ func testTableDiff_TableFiltering(t *testing.T) { foundIndex1 := false foundIndex2 := false foundIndex3 := false - for _, row := range nodeDiffs.Rows[serviceN2] { + for _, row := range nodeDiffs.Rows[env.ServiceN2] { indexVal, _ := row.Get("index") email, _ := row.Get("email") firstName, _ := row.Get("first_name") @@ -1315,12 +1254,12 @@ func testTableDiff_TableFiltering(t *testing.T) { log.Println("TestTableDiff_TableFiltering completed.") } -func testTableDiff_TableFilterNoRows(t *testing.T) { +func testTableDiff_TableFilterNoRows(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodesToCompare := []string{serviceN1, serviceN2} + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) tdTask.TableFilter = "index < 0" // no rows satisfy this require.NoError(t, tdTask.RunChecks(false)) @@ -1329,13 +1268,13 @@ func testTableDiff_TableFilterNoRows(t *testing.T) { require.Contains(t, err.Error(), "table filter produced no rows") } -func testTableDiff_MaxDiffRowsLimit(t *testing.T) { +func testTableDiff_MaxDiffRowsLimit(t *testing.T, env *testEnv) { ctx := context.Background() tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -1352,25 +1291,21 @@ func testTableDiff_MaxDiffRowsLimit(t *testing.T) { {indexVal: 4, email: "limit-test-4@example.com"}, } - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - updateSQL := fmt.Sprintf("UPDATE %s.%s SET email = $1 WHERE index = $2", testSchema, tableName) - for _, mod := range modifications { - _, err := tx.Exec(ctx, updateSQL, mod.email, mod.indexVal) - require.NoErrorf(t, err, "failed to update row with index %d on node %s", mod.indexVal, serviceN2) - } - - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + env.withRepairModeTx(t, ctx, tx, func() { + updateSQL := 
fmt.Sprintf("UPDATE %s.%s SET email = $1 WHERE index = $2", env.Schema, tableName) + for _, mod := range modifications { + _, err := tx.Exec(ctx, updateSQL, mod.email, mod.indexVal) + require.NoErrorf(t, err, "failed to update row with index %d on node %s", mod.indexVal, env.ServiceN2) + } + }) require.NoError(t, tx.Commit(ctx)) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) tdTask.BlockSize = 10 tdTask.CompareUnitSize = 1 tdTask.MaxDiffRows = 2 @@ -1381,10 +1316,7 @@ func testTableDiff_MaxDiffRowsLimit(t *testing.T) { err = tdTask.ExecuteTask() require.NoError(t, err) - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } + pairKey := env.pairKey() if _, ok := tdTask.DiffResult.NodeDiffs[pairKey]; !ok { t.Fatalf("expected diffs for pair %s when enforcing max diff rows", pairKey) diff --git a/tests/integration/table_repair_test.go b/tests/integration/table_repair_test.go index 96a0c62..56940f4 100644 --- a/tests/integration/table_repair_test.go +++ b/tests/integration/table_repair_test.go @@ -39,87 +39,13 @@ import ( // - 2 common rows modified on node2 (IDs 1, 2) func setupDivergence(t *testing.T, ctx context.Context, qualifiedTableName string, composite bool) { t.Helper() - log.Println("Setting up data divergence for", qualifiedTableName) - - // Truncate on both nodes - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) - _, err := pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err, "Failed to enable repair mode on %s", nodeName) - _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) - require.NoError(t, err, "Failed to truncate table on node %s", nodeName) - _, err = pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err, "Failed to disable repair mode on %s", nodeName) - } - - // Insert common rows (always populate customer_id to support composite PK later) - for i := 1; i <= 5; i++ { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _, err := pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - _, err = pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), - i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("FirstName%d", i), fmt.Sprintf("LastName%d", i), fmt.Sprintf("email%d@example.com", i)) - require.NoError(t, err) - _, err = pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) - } - } - - // Insert rows only on node1 (always include customer_id) - _, err := pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - for i := 1001; i <= 1002; i++ { - _, err := pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), - i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N1OnlyFirst%d", i), fmt.Sprintf("N1OnlyLast%d", i), fmt.Sprintf("n1.only%d@example.com", i)) - require.NoError(t, err) - } - _, err = pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) - - // Insert rows only on node2 (always include customer_id) - _, 
err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - for i := 2001; i <= 2002; i++ { - _, err := pgCluster.Node2Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), - i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N2OnlyFirst%d", i), fmt.Sprintf("N2OnlyLast%d", i), fmt.Sprintf("n2.only%d@example.com", i)) - require.NoError(t, err) - } - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) - - // Modify rows on node2 - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - for i := 1; i <= 2; i++ { - _, err := pgCluster.Node2Pool.Exec(ctx, fmt.Sprintf("UPDATE %s SET email = $1 WHERE index = $2", qualifiedTableName), - fmt.Sprintf("modified.email%d@example.com", i), i) - require.NoError(t, err) - } - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) - - log.Println("Data divergence setup complete.") + newSpockEnv().setupDivergence(t, ctx, qualifiedTableName, composite) } // runTableDiff executes a table-diff task and returns the path to the latest diff file. func runTableDiff(t *testing.T, qualifiedTableName string, nodesToCompare []string) string { t.Helper() - // Clean up any old diff files to ensure we get the correct one - files, _ := filepath.Glob("*_diffs-*.json") - for _, f := range files { - os.Remove(f) - } - - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) - err := tdTask.RunChecks(false) - require.NoError(t, err, "table-diff validation failed") - err = tdTask.ExecuteTask() - require.NoError(t, err, "table-diff execution failed") - - latestDiffFile := getLatestDiffFile(t) - require.NotEmpty(t, latestDiffFile, "No diff file was generated") - - return latestDiffFile + return newSpockEnv().runTableDiff(t, qualifiedTableName, nodesToCompare) } // getLatestDiffFile finds the most recently modified diff file. @@ -145,16 +71,7 @@ func getLatestDiffFile(t *testing.T) string { // assertNoTableDiff runs a diff and asserts that there are no differences. func assertNoTableDiff(t *testing.T, qualifiedTableName string) { t.Helper() - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) - - err := tdTask.RunChecks(false) - require.NoError(t, err, "assertNoTableDiff: validation failed") - - err = tdTask.ExecuteTask() - require.NoError(t, err, "assertNoTableDiff: execution failed") - - assert.Empty(t, tdTask.DiffResult.NodeDiffs, "Expected no differences after repair, but diffs were found") + newSpockEnv().assertNoTableDiff(t, qualifiedTableName) } // captureOutput executes a function while capturing its stdout and stderr. 
@@ -196,42 +113,7 @@ func getTableCount(t *testing.T, ctx context.Context, pool *pgxpool.Pool, qualif func setupNullDivergence(t *testing.T, ctx context.Context, qualifiedTableName string) { t.Helper() - log.Println("Setting up null divergence for", qualifiedTableName) - - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) - _, err := pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err, "Failed to enable repair mode on %s", nodeName) - _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) - require.NoError(t, err, "Failed to truncate table on node %s", nodeName) - _, err = pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err, "Failed to disable repair mode on %s", nodeName) - } - - insertSQL := fmt.Sprintf( - "INSERT INTO %s (index, customer_id, first_name, last_name, city) VALUES ($1, $2, $3, $4, $5)", - qualifiedTableName, - ) - - // Node1 rows: missing city for id 1, missing first_name for id 2 - _, err := pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - _, err = pgCluster.Node1Pool.Exec(ctx, insertSQL, 1, "CUST-1", "Michael", "Schumacher", nil) - require.NoError(t, err) - _, err = pgCluster.Node1Pool.Exec(ctx, insertSQL, 2, "CUST-2", nil, "Alonso", "Oviedo") - require.NoError(t, err) - _, err = pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) - - // Node2 rows: missing last_name for id 1, missing city for id 2 - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - _, err = pgCluster.Node2Pool.Exec(ctx, insertSQL, 1, "CUST-1", "Michael", nil, "Austria") - require.NoError(t, err) - _, err = pgCluster.Node2Pool.Exec(ctx, insertSQL, 2, "CUST-2", "Fernando", "Alonso", nil) - require.NoError(t, err) - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + newSpockEnv().setupNullDivergence(t, ctx, qualifiedTableName) } type nameCity struct { @@ -252,9 +134,9 @@ func getNameCity(t *testing.T, ctx context.Context, pool *pgxpool.Pool, qualifie return nameCity{first: first, last: last, city: city} } -func TestTableRepair_UnidirectionalDefault(t *testing.T) { +func testTableRepair_UnidirectionalDefault(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -268,14 +150,14 @@ func TestTableRepair_UnidirectionalDefault(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -287,34 +169,40 @@ func TestTableRepair_UnidirectionalDefault(t *testing.T) { tc.setup() t.Cleanup(tc.teardown) - setupDivergence(t, ctx, 
qualifiedTableName, tc.composite) + env.setupDivergence(t, ctx, qualifiedTableName, tc.composite) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) }) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) diffData, err := os.ReadFile(diffFile) require.NoError(t, err) - assert.Contains(t, string(diffData), "_spock_metadata_", "Diff file should contain spock metadata before repair") + if env.HasSpock { + assert.Contains(t, string(diffData), "_spock_metadata_", "Diff file should contain spock metadata before repair") + } - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) err = repairTask.Run(false) require.NoError(t, err, "Table repair failed") log.Println("Verifying repair for TestTableRepair_UnidirectionalDefault") - assertNoTableDiff(t, qualifiedTableName) - count1 := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName) - count2 := getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + env.assertNoTableDiff(t, qualifiedTableName) + count1 := getTableCount(t, ctx, env.N1Pool, qualifiedTableName) + count2 := getTableCount(t, ctx, env.N2Pool, qualifiedTableName) assert.Equal(t, count1, count2, "Row counts should be equal after default repair") assert.Equal(t, 7, count1, "Expected 7 rows on node1") }) } } -func TestTableRepair_InsertOnly(t *testing.T) { +func TestTableRepair_UnidirectionalDefault(t *testing.T) { + testTableRepair_UnidirectionalDefault(t, newSpockEnv()) +} + +func testTableRepair_InsertOnly(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -328,14 +216,14 @@ func TestTableRepair_InsertOnly(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -347,49 +235,50 @@ func TestTableRepair_InsertOnly(t *testing.T) { tc.setup() t.Cleanup(tc.teardown) - setupDivergence(t, ctx, qualifiedTableName, tc.composite) + env.setupDivergence(t, ctx, qualifiedTableName, tc.composite) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) }) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) repairTask.InsertOnly = true err := repairTask.Run(false) require.NoError(t, err, 
"Table repair (insert-only) failed") - count1 := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName) - count2 := getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + count1 := getTableCount(t, ctx, env.N1Pool, qualifiedTableName) + count2 := getTableCount(t, ctx, env.N2Pool, qualifiedTableName) assert.Equal(t, 7, count1) assert.Equal(t, 9, count2) - tdTask := newTestTableDiffTask(t, qualifiedTableName, []string{serviceN1, serviceN2}) + tdTask := env.newTableDiffTask(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) err = tdTask.RunChecks(false) require.NoError(t, err) err = tdTask.ExecuteTask() require.NoError(t, err) - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } + pairKey := env.pairKey() // After insert-only repair with n1 as source (DO NOTHING on conflict): // - Rows 1001, 1002 (n1-only) are inserted into n2 // - Rows 1, 2 (conflicting) are NOT overwritten on n2 (DO NOTHING), so they still differ // - n1 has 2 differing rows (its versions of rows 1, 2) // - n2 has 4 differing rows (its versions of rows 1, 2 + unique rows 2001, 2002) // - Total diff count is 4 - assert.Equal(t, 2, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[serviceN1])) - assert.Equal(t, 4, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[serviceN2])) + assert.Equal(t, 2, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[env.ServiceN1])) + assert.Equal(t, 4, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[env.ServiceN2])) assert.Equal(t, 4, tdTask.DiffResult.Summary.DiffRowsCount[pairKey]) }) } } -func TestTableRepair_UpsertOnly(t *testing.T) { +func TestTableRepair_InsertOnly(t *testing.T) { + testTableRepair_InsertOnly(t, newSpockEnv()) +} + +func testTableRepair_UpsertOnly(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -403,14 +292,14 @@ func TestTableRepair_UpsertOnly(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -422,43 +311,44 @@ func TestTableRepair_UpsertOnly(t *testing.T) { tc.setup() t.Cleanup(tc.teardown) - setupDivergence(t, ctx, qualifiedTableName, tc.composite) + env.setupDivergence(t, ctx, qualifiedTableName, tc.composite) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) }) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) repairTask.UpsertOnly = true err := 
repairTask.Run(false) require.NoError(t, err, "Table repair (upsert-only) failed") - count1 := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName) - count2 := getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + count1 := getTableCount(t, ctx, env.N1Pool, qualifiedTableName) + count2 := getTableCount(t, ctx, env.N2Pool, qualifiedTableName) assert.Equal(t, 7, count1) assert.Equal(t, 9, count2) - tdTask := newTestTableDiffTask(t, qualifiedTableName, []string{serviceN1, serviceN2}) + tdTask := env.newTableDiffTask(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) err = tdTask.RunChecks(false) require.NoError(t, err) err = tdTask.ExecuteTask() require.NoError(t, err) - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } - assert.Equal(t, 0, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[serviceN1])) - assert.Equal(t, 2, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[serviceN2])) + pairKey := env.pairKey() + assert.Equal(t, 0, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[env.ServiceN1])) + assert.Equal(t, 2, len(tdTask.DiffResult.NodeDiffs[pairKey].Rows[env.ServiceN2])) assert.Equal(t, 2, tdTask.DiffResult.Summary.DiffRowsCount[pairKey]) }) } } -func TestTableRepair_Bidirectional(t *testing.T) { +func TestTableRepair_UpsertOnly(t *testing.T) { + testTableRepair_UpsertOnly(t, newSpockEnv()) +} + +func testTableRepair_Bidirectional(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -472,14 +362,14 @@ func TestTableRepair_Bidirectional(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -492,55 +382,53 @@ func TestTableRepair_Bidirectional(t *testing.T) { t.Cleanup(tc.teardown) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) }) log.Println("Setting up data for bidirectional test") - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _, err := pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) - require.NoError(t, err) - _, err = pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) - } - _, err := pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - for i := 3001; i <= 3003; i++ { - _, err := pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name) VALUES ($1, $2, $3)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N1-Bi-%d", i)) - require.NoError(t, err) + for _, pool := range 
[]*pgxpool.Pool{env.N1Pool, env.N2Pool} { + env.withRepairMode(t, ctx, pool, func() { + _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) + require.NoError(t, err) + }) } - _, err = pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + env.withRepairMode(t, ctx, env.N1Pool, func() { + for i := 3001; i <= 3003; i++ { + _, err := env.N1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name) VALUES ($1, $2, $3)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N1-Bi-%d", i)) + require.NoError(t, err) + } + }) - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - for i := 4001; i <= 4002; i++ { - _, err := pgCluster.Node2Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name) VALUES ($1, $2, $3)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N2-Bi-%d", i)) - require.NoError(t, err) - } - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + env.withRepairMode(t, ctx, env.N2Pool, func() { + for i := 4001; i <= 4002; i++ { + _, err := env.N2Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name) VALUES ($1, $2, $3)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N2-Bi-%d", i)) + require.NoError(t, err) + } + }) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) repairTask.Bidirectional = true - err = repairTask.Run(false) + err := repairTask.Run(false) require.NoError(t, err, "Table repair (bidirectional) failed") log.Println("Verifying repair for TestTableRepair_Bidirectional") - assertNoTableDiff(t, qualifiedTableName) - count1 := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName) - count2 := getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + env.assertNoTableDiff(t, qualifiedTableName) + count1 := getTableCount(t, ctx, env.N1Pool, qualifiedTableName) + count2 := getTableCount(t, ctx, env.N2Pool, qualifiedTableName) assert.Equal(t, 5, count1, "Expected 5 rows on node1 after bidirectional repair") assert.Equal(t, 5, count2, "Expected 5 rows on node2 after bidirectional repair") }) } } -func TestTableRepair_DryRun(t *testing.T) { +func TestTableRepair_Bidirectional(t *testing.T) { + testTableRepair_Bidirectional(t, newSpockEnv()) +} + +func testTableRepair_DryRun(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -554,14 +442,14 @@ func TestTableRepair_DryRun(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + 
for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -573,14 +461,14 @@ func TestTableRepair_DryRun(t *testing.T) { tc.setup() t.Cleanup(tc.teardown) - setupDivergence(t, ctx, qualifiedTableName, tc.composite) + env.setupDivergence(t, ctx, qualifiedTableName, tc.composite) t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) }) - count1Before, count2Before := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName), getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + count1Before, count2Before := getTableCount(t, ctx, env.N1Pool, qualifiedTableName), getTableCount(t, ctx, env.N2Pool, qualifiedTableName) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) repairTask.DryRun = true output := captureOutput(t, func() { @@ -589,18 +477,22 @@ func TestTableRepair_DryRun(t *testing.T) { }) assert.Contains(t, output, "DRY RUN") - assert.Contains(t, output, fmt.Sprintf("Node %s: Would attempt to UPSERT 4 rows and DELETE 2 rows.", serviceN2)) + assert.Contains(t, output, fmt.Sprintf("Node %s: Would attempt to UPSERT 4 rows and DELETE 2 rows.", env.ServiceN2)) - count1After, count2After := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName), getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + count1After, count2After := getTableCount(t, ctx, env.N1Pool, qualifiedTableName), getTableCount(t, ctx, env.N2Pool, qualifiedTableName) assert.Equal(t, count1Before, count1After, "Node1 count should not change after dry run") assert.Equal(t, count2Before, count2After, "Node2 count should not change after dry run") }) } } -func TestTableRepair_GenerateReport(t *testing.T) { +func TestTableRepair_DryRun(t *testing.T) { + testTableRepair_DryRun(t, newSpockEnv()) +} + +func testTableRepair_GenerateReport(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() reportDir := "reports" @@ -615,14 +507,14 @@ func TestTableRepair_GenerateReport(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -636,14 +528,14 @@ func TestTableRepair_GenerateReport(t *testing.T) { t.Cleanup(func() { os.RemoveAll(reportDir) - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) }) os.RemoveAll(reportDir) - setupDivergence(t, ctx, qualifiedTableName, tc.composite) + 
env.setupDivergence(t, ctx, qualifiedTableName, tc.composite) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) repairTask.GenerateReport = true err := repairTask.Run(false) @@ -669,11 +561,15 @@ func TestTableRepair_GenerateReport(t *testing.T) { require.NoError(t, err) assert.True(t, len(data) > 0, "Report file is empty") assert.Contains(t, string(data), "\"operation_type\": \"table-repair\"") - assert.Contains(t, string(data), fmt.Sprintf("\"source_of_truth\": \"%s\"", serviceN1)) + assert.Contains(t, string(data), fmt.Sprintf("\"source_of_truth\": \"%s\"", env.ServiceN1)) }) } } +func TestTableRepair_GenerateReport(t *testing.T) { + testTableRepair_GenerateReport(t, newSpockEnv()) +} + func TestTableRepair_VariousDataTypes(t *testing.T) { ctx := context.Background() tableName := "data_type_repair" @@ -853,9 +749,9 @@ CREATE TABLE IF NOT EXISTS %s.%s ( assert.Equal(t, 0, row4Count, "Row present only on node2 should be deleted") } -func TestTableRepair_FixNulls(t *testing.T) { +func testTableRepair_FixNulls(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -869,14 +765,14 @@ func TestTableRepair_FixNulls(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -888,23 +784,23 @@ func TestTableRepair_FixNulls(t *testing.T) { tc.setup() t.Cleanup(tc.teardown) - setupNullDivergence(t, ctx, qualifiedTableName) + env.setupNullDivergence(t, ctx, qualifiedTableName) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) - repairTask := newTestTableRepairTask("", qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask("", qualifiedTableName, diffFile) repairTask.SourceOfTruth = "" repairTask.FixNulls = true err := repairTask.Run(false) require.NoError(t, err, "Table repair (fix-nulls) failed") - assertNoTableDiff(t, qualifiedTableName) + env.assertNoTableDiff(t, qualifiedTableName) - row1N1 := getNameCity(t, ctx, pgCluster.Node1Pool, qualifiedTableName, 1, "CUST-1") - row1N2 := getNameCity(t, ctx, pgCluster.Node2Pool, qualifiedTableName, 1, "CUST-1") - row2N1 := getNameCity(t, ctx, pgCluster.Node1Pool, qualifiedTableName, 2, "CUST-2") - row2N2 := getNameCity(t, ctx, pgCluster.Node2Pool, qualifiedTableName, 2, "CUST-2") + row1N1 := getNameCity(t, ctx, env.N1Pool, qualifiedTableName, 1, "CUST-1") + row1N2 := getNameCity(t, ctx, 
env.N2Pool, qualifiedTableName, 1, "CUST-1") + row2N1 := getNameCity(t, ctx, env.N1Pool, qualifiedTableName, 2, "CUST-2") + row2N2 := getNameCity(t, ctx, env.N2Pool, qualifiedTableName, 2, "CUST-2") require.NotNil(t, row1N1.city) require.NotNil(t, row1N2.last) @@ -919,9 +815,13 @@ func TestTableRepair_FixNulls(t *testing.T) { } } -func TestTableRepair_FixNulls_DryRun(t *testing.T) { +func TestTableRepair_FixNulls(t *testing.T) { + testTableRepair_FixNulls(t, newSpockEnv()) +} + +func testTableRepair_FixNulls_DryRun(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -935,14 +835,14 @@ func TestTableRepair_FixNulls_DryRun(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -954,10 +854,10 @@ func TestTableRepair_FixNulls_DryRun(t *testing.T) { tc.setup() t.Cleanup(tc.teardown) - setupNullDivergence(t, ctx, qualifiedTableName) + env.setupNullDivergence(t, ctx, qualifiedTableName) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) - repairTask := newTestTableRepairTask("", qualifiedTableName, diffFile) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) + repairTask := env.newTableRepairTask("", qualifiedTableName, diffFile) repairTask.SourceOfTruth = "" repairTask.FixNulls = true repairTask.DryRun = true @@ -971,10 +871,10 @@ func TestTableRepair_FixNulls_DryRun(t *testing.T) { assert.Contains(t, output, "Would update") // Ensure data unchanged - row1N1 := getNameCity(t, ctx, pgCluster.Node1Pool, qualifiedTableName, 1, "CUST-1") - row1N2 := getNameCity(t, ctx, pgCluster.Node2Pool, qualifiedTableName, 1, "CUST-1") - row2N1 := getNameCity(t, ctx, pgCluster.Node1Pool, qualifiedTableName, 2, "CUST-2") - row2N2 := getNameCity(t, ctx, pgCluster.Node2Pool, qualifiedTableName, 2, "CUST-2") + row1N1 := getNameCity(t, ctx, env.N1Pool, qualifiedTableName, 1, "CUST-1") + row1N2 := getNameCity(t, ctx, env.N2Pool, qualifiedTableName, 1, "CUST-1") + row2N1 := getNameCity(t, ctx, env.N1Pool, qualifiedTableName, 2, "CUST-2") + row2N2 := getNameCity(t, ctx, env.N2Pool, qualifiedTableName, 2, "CUST-2") assert.Nil(t, row1N1.city) assert.Nil(t, row1N2.last) @@ -982,7 +882,7 @@ func TestTableRepair_FixNulls_DryRun(t *testing.T) { assert.Nil(t, row2N2.city) // Cleanup: actually repair to leave table consistent for subsequent tests - fixTask := newTestTableRepairTask("", qualifiedTableName, diffFile) + fixTask := env.newTableRepairTask("", qualifiedTableName, diffFile) fixTask.SourceOfTruth = "" fixTask.FixNulls = true err := fixTask.Run(false) @@ -991,6 +891,10 @@ func TestTableRepair_FixNulls_DryRun(t *testing.T) { } } +func TestTableRepair_FixNulls_DryRun(t *testing.T) { + 
testTableRepair_FixNulls_DryRun(t, newSpockEnv()) +} + // TestTableRepair_TimestampAndTimeTypes verifies that the full diff→repair // pipeline works correctly for timestamp, timestamptz, time, and timetz columns. // This guards against pgx sending the wrong OID for timestamp-without-tz @@ -1151,7 +1055,7 @@ CREATE TABLE IF NOT EXISTS %s.%s ( assert.Equal(t, r1n1.timetz, r1n2.timetz, "row 1 col_timetz mismatch") } -// TestTableRepair_FixNulls_BidirectionalUpdate tests that when both nodes have NULLs +// testTableRepair_FixNulls_BidirectionalUpdate tests that when both nodes have NULLs // in different columns for the same row, fix-nulls performs bidirectional updates. // This verifies the behavior discussed in code review: each node updates the other // with its non-NULL values. @@ -1165,9 +1069,9 @@ CREATE TABLE IF NOT EXISTS %s.%s ( // // Node1: {id: 1, col_a: "value_a", col_b: "value_b", col_c: "value_c"} // Node2: {id: 1, col_a: "value_a", col_b: "value_b", col_c: "value_c"} -func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { +func testTableRepair_FixNulls_BidirectionalUpdate(t *testing.T, env *testEnv) { tableName := "customers" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) ctx := context.Background() testCases := []struct { @@ -1181,14 +1085,14 @@ func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { name: "composite_primary_key", composite: true, setup: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := alterTableToCompositeKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, teardown: func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - _err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + _err := revertTableToSimpleKey(ctx, pool, env.Schema, tableName) require.NoError(t, _err) } }, @@ -1203,14 +1107,11 @@ func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { log.Println("Setting up bidirectional NULL divergence for", qualifiedTableName) // Clean table on both nodes - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) - _, err := pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err, "Failed to enable repair mode on %s", nodeName) - _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) - require.NoError(t, err, "Failed to truncate table on node %s", nodeName) - _, err = pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err, "Failed to disable repair mode on %s", nodeName) + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + env.withRepairMode(t, ctx, pool, func() { + _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) + require.NoError(t, err, "Failed to truncate table") + }) } // Insert row with complementary NULLs on each node @@ -1222,38 +1123,34 @@ func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { ) // Node1 data - _, err := pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - // Row 1 on Node1: NULL first_name and city - _, err = pgCluster.Node1Pool.Exec(ctx, insertSQL, 100, 
"CUST-100", nil, "LastName100", nil, "email100@example.com") - require.NoError(t, err) - // Row 2 on Node1: NULL last_name and email - _, err = pgCluster.Node1Pool.Exec(ctx, insertSQL, 200, "CUST-200", "FirstName200", nil, "City200", nil) - require.NoError(t, err) - _, err = pgCluster.Node1Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + env.withRepairMode(t, ctx, env.N1Pool, func() { + // Row 1 on Node1: NULL first_name and city + _, err := env.N1Pool.Exec(ctx, insertSQL, 100, "CUST-100", nil, "LastName100", nil, "email100@example.com") + require.NoError(t, err) + // Row 2 on Node1: NULL last_name and email + _, err = env.N1Pool.Exec(ctx, insertSQL, 200, "CUST-200", "FirstName200", nil, "City200", nil) + require.NoError(t, err) + }) // Node2 data - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - // Row 1 on Node2: NULL last_name and email - _, err = pgCluster.Node2Pool.Exec(ctx, insertSQL, 100, "CUST-100", "FirstName100", nil, "City100", nil) - require.NoError(t, err) - // Row 2 on Node2: NULL first_name and city - _, err = pgCluster.Node2Pool.Exec(ctx, insertSQL, 200, "CUST-200", nil, "LastName200", nil, "email200@example.com") - require.NoError(t, err) - _, err = pgCluster.Node2Pool.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + env.withRepairMode(t, ctx, env.N2Pool, func() { + // Row 1 on Node2: NULL last_name and email + _, err := env.N2Pool.Exec(ctx, insertSQL, 100, "CUST-100", "FirstName100", nil, "City100", nil) + require.NoError(t, err) + // Row 2 on Node2: NULL first_name and city + _, err = env.N2Pool.Exec(ctx, insertSQL, 200, "CUST-200", nil, "LastName200", nil, "email200@example.com") + require.NoError(t, err) + }) // Run table-diff to detect the NULL differences - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) // Run fix-nulls repair - repairTask := newTestTableRepairTask("", qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask("", qualifiedTableName, diffFile) repairTask.SourceOfTruth = "" repairTask.FixNulls = true - err = repairTask.Run(false) + err := repairTask.Run(false) require.NoError(t, err, "Table repair (fix-nulls bidirectional) failed") // Verify bidirectional updates happened @@ -1276,8 +1173,8 @@ func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { } // Check Row 1 (id=100) on both nodes - row1N1 := getFullRow(pgCluster.Node1Pool, 100, "CUST-100") - row1N2 := getFullRow(pgCluster.Node2Pool, 100, "CUST-100") + row1N1 := getFullRow(env.N1Pool, 100, "CUST-100") + row1N2 := getFullRow(env.N2Pool, 100, "CUST-100") // Node1's NULLs (first_name, city) should be filled from Node2 require.NotNil(t, row1N1.firstName, "Node1 row 100 first_name should be filled from Node2") @@ -1302,8 +1199,8 @@ func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { assert.Equal(t, "email100@example.com", *row1N2.email) // Check Row 2 (id=200) on both nodes - row2N1 := getFullRow(pgCluster.Node1Pool, 200, "CUST-200") - row2N2 := getFullRow(pgCluster.Node2Pool, 200, "CUST-200") + row2N1 := getFullRow(env.N1Pool, 200, "CUST-200") + row2N2 := getFullRow(env.N2Pool, 200, "CUST-200") // Node1's NULLs (last_name, email) should be filled from Node2 require.NotNil(t, row2N1.lastName, "Node1 row 200 last_name should be filled from Node2") @@ -1328,13 +1225,17 @@ func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { 
assert.Equal(t, "email200@example.com", *row2N2.email) // Verify no diffs remain - assertNoTableDiff(t, qualifiedTableName) + env.assertNoTableDiff(t, qualifiedTableName) log.Println("Bidirectional fix-nulls test completed successfully") }) } } +func TestTableRepair_FixNulls_BidirectionalUpdate(t *testing.T) { + testTableRepair_FixNulls_BidirectionalUpdate(t, newSpockEnv()) +} + // TestTableRepair_PreserveOrigin tests that the preserve-origin flag correctly preserves // both replication origin metadata and commit timestamps during table repair recovery operations. // This test verifies the fix for maintaining original transaction metadata to prevent diff --git a/tests/integration/test_env_test.go b/tests/integration/test_env_test.go new file mode 100644 index 0000000..4d2676b --- /dev/null +++ b/tests/integration/test_env_test.go @@ -0,0 +1,372 @@ +// /////////////////////////////////////////////////////////////////////////// +// +// # ACE - Active Consistency Engine +// +// Copyright (C) 2023 - 2026, pgEdge (https://www.pgedge.com/) +// +// This software is released under the PostgreSQL License: +// https://opensource.org/license/postgresql +// +// /////////////////////////////////////////////////////////////////////////// + +package integration + +import ( + "context" + "fmt" + "log" + "math" + "os" + "path/filepath" + "sort" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/pgedge/ace/internal/consistency/diff" + "github.com/pgedge/ace/internal/consistency/repair" + "github.com/pgedge/ace/pkg/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testEnv abstracts the differences between spock and native PG test environments, +// allowing the same test logic to run against both. +type testEnv struct { + N1Pool *pgxpool.Pool + N2Pool *pgxpool.Pool + N3Pool *pgxpool.Pool // nil for native PG (only 2 nodes) + + ServiceN1 string + ServiceN2 string + ServiceN3 string // empty for native PG + + DBName string + DBUser string + DBPassword string + + ClusterName string + ClusterNodes []map[string]any + + HasSpock bool + Schema string +} + +func newSpockEnv() *testEnv { + return &testEnv{ + N1Pool: pgCluster.Node1Pool, + N2Pool: pgCluster.Node2Pool, + N3Pool: pgCluster.Node3Pool, + ServiceN1: serviceN1, + ServiceN2: serviceN2, + ServiceN3: serviceN3, + DBName: dbName, + DBUser: pgEdgeUser, + DBPassword: pgEdgePassword, + ClusterName: pgCluster.ClusterName, + ClusterNodes: pgCluster.ClusterNodes, + HasSpock: true, + Schema: testSchema, + } +} + +func newNativeEnv(state *nativeClusterState) *testEnv { + return &testEnv{ + N1Pool: state.n1Pool, + N2Pool: state.n2Pool, + N3Pool: nil, + ServiceN1: nativeServiceN1, + ServiceN2: nativeServiceN2, + ServiceN3: "", + DBName: nativeDBName, + DBUser: nativeUser, + DBPassword: nativePassword, + ClusterName: nativeClusterName, + ClusterNodes: []map[string]any{ + { + "Name": nativeServiceN1, + "PublicIP": state.n1Host, + "Port": state.n1Port, + "DBUser": nativeUser, + "DBPassword": nativePassword, + "DBName": nativeDBName, + }, + { + "Name": nativeServiceN2, + "PublicIP": state.n2Host, + "Port": state.n2Port, + "DBUser": nativeUser, + "DBPassword": nativePassword, + "DBName": nativeDBName, + }, + }, + HasSpock: false, + Schema: testSchema, + } +} + +// withRepairMode wraps a function with spock.repair_mode(true/false) when spock +// is available. On native PG (no replication), this is a no-op. 
+func (e *testEnv) withRepairMode(t *testing.T, ctx context.Context, pool *pgxpool.Pool, fn func()) { + t.Helper() + if e.HasSpock { + _, err := pool.Exec(ctx, "SELECT spock.repair_mode(true)") + require.NoError(t, err, "enable spock.repair_mode") + defer func() { + _, err := pool.Exec(ctx, "SELECT spock.repair_mode(false)") + require.NoError(t, err, "disable spock.repair_mode") + }() + } + fn() +} + +// withRepairModeTx is like withRepairMode but operates on a transaction. +func (e *testEnv) withRepairModeTx(t *testing.T, ctx context.Context, tx pgx.Tx, fn func()) { + t.Helper() + if e.HasSpock { + _, err := tx.Exec(ctx, "SELECT spock.repair_mode(true)") + require.NoError(t, err, "enable spock.repair_mode in tx") + defer func() { + _, err := tx.Exec(ctx, "SELECT spock.repair_mode(false)") + require.NoError(t, err, "disable spock.repair_mode in tx") + }() + } + fn() +} + +// pools returns the two primary pools (n1, n2) used by most tests. +func (e *testEnv) pools() []*pgxpool.Pool { + return []*pgxpool.Pool{e.N1Pool, e.N2Pool} +} + +// newTableDiffTask creates a TableDiffTask configured for this environment. +func (e *testEnv) newTableDiffTask(t *testing.T, qualifiedTableName string, nodes []string) *diff.TableDiffTask { + t.Helper() + task := diff.NewTableDiffTask() + task.ClusterName = e.ClusterName + task.DBName = e.DBName + task.QualifiedTableName = qualifiedTableName + task.Nodes = strings.Join(nodes, ",") + task.Output = "json" + task.BlockSize = 1000 + task.CompareUnitSize = 100 + task.ConcurrencyFactor = 1 + task.MaxDiffRows = math.MaxInt64 + + task.DiffResult = types.DiffOutput{ + NodeDiffs: make(map[string]types.DiffByNodePair), + Summary: types.DiffSummary{ + Nodes: nodes, + BlockSize: task.BlockSize, + CompareUnitSize: task.CompareUnitSize, + ConcurrencyFactor: task.ConcurrencyFactor, + DiffRowsCount: make(map[string]int), + }, + } + return task +} + +// newTableRepairTask creates a TableRepairTask configured for this environment. +func (e *testEnv) newTableRepairTask(sourceOfTruthNode, qualifiedTableName, diffFilePath string) *repair.TableRepairTask { + task := repair.NewTableRepairTask() + task.ClusterName = e.ClusterName + task.DBName = e.DBName + task.SourceOfTruth = sourceOfTruthNode + task.QualifiedTableName = qualifiedTableName + task.DiffFilePath = diffFilePath + task.Nodes = "all" + return task +} + +// setupDivergence prepares a table with a known set of differences between n1 and n2. 
+// - 5 common rows (index 1-5) +// - 2 rows only on n1 (index 1001, 1002) +// - 2 rows only on n2 (index 2001, 2002) +// - 2 common rows modified on n2 (index 1, 2) +func (e *testEnv) setupDivergence(t *testing.T, ctx context.Context, qualifiedTableName string, composite bool) { + t.Helper() + log.Println("Setting up data divergence for", qualifiedTableName) + + // Truncate on both nodes + for _, pool := range e.pools() { + e.withRepairMode(t, ctx, pool, func() { + _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) + require.NoError(t, err, "truncate table") + }) + } + + // Insert common rows + for i := 1; i <= 5; i++ { + for _, pool := range e.pools() { + e.withRepairMode(t, ctx, pool, func() { + _, err := pool.Exec(ctx, + fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), + i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("FirstName%d", i), fmt.Sprintf("LastName%d", i), fmt.Sprintf("email%d@example.com", i)) + require.NoError(t, err) + }) + } + } + + // Rows only on n1 + e.withRepairMode(t, ctx, e.N1Pool, func() { + for i := 1001; i <= 1002; i++ { + _, err := e.N1Pool.Exec(ctx, + fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), + i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N1OnlyFirst%d", i), fmt.Sprintf("N1OnlyLast%d", i), fmt.Sprintf("n1.only%d@example.com", i)) + require.NoError(t, err) + } + }) + + // Rows only on n2 + e.withRepairMode(t, ctx, e.N2Pool, func() { + for i := 2001; i <= 2002; i++ { + _, err := e.N2Pool.Exec(ctx, + fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), + i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N2OnlyFirst%d", i), fmt.Sprintf("N2OnlyLast%d", i), fmt.Sprintf("n2.only%d@example.com", i)) + require.NoError(t, err) + } + }) + + // Modify rows on n2 + e.withRepairMode(t, ctx, e.N2Pool, func() { + for i := 1; i <= 2; i++ { + _, err := e.N2Pool.Exec(ctx, + fmt.Sprintf("UPDATE %s SET email = $1 WHERE index = $2", qualifiedTableName), + fmt.Sprintf("modified.email%d@example.com", i), i) + require.NoError(t, err) + } + }) + + log.Println("Data divergence setup complete.") +} + +// setupNullDivergence creates divergence where the same rows exist on both nodes +// but with different columns set to NULL. 
+func (e *testEnv) setupNullDivergence(t *testing.T, ctx context.Context, qualifiedTableName string) { + t.Helper() + log.Println("Setting up null divergence for", qualifiedTableName) + + for _, pool := range e.pools() { + e.withRepairMode(t, ctx, pool, func() { + _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) + require.NoError(t, err, "truncate table") + }) + } + + insertSQL := fmt.Sprintf( + "INSERT INTO %s (index, customer_id, first_name, last_name, city) VALUES ($1, $2, $3, $4, $5)", + qualifiedTableName, + ) + + // Node1: missing city for id 1, missing first_name for id 2 + e.withRepairMode(t, ctx, e.N1Pool, func() { + _, err := e.N1Pool.Exec(ctx, insertSQL, 1, "CUST-1", "Michael", "Schumacher", nil) + require.NoError(t, err) + _, err = e.N1Pool.Exec(ctx, insertSQL, 2, "CUST-2", nil, "Alonso", "Oviedo") + require.NoError(t, err) + }) + + // Node2: missing last_name for id 1, missing city for id 2 + e.withRepairMode(t, ctx, e.N2Pool, func() { + _, err := e.N2Pool.Exec(ctx, insertSQL, 1, "CUST-1", "Michael", nil, "Austria") + require.NoError(t, err) + _, err = e.N2Pool.Exec(ctx, insertSQL, 2, "CUST-2", "Fernando", "Alonso", nil) + require.NoError(t, err) + }) +} + +// runTableDiff executes a table-diff and returns the path to the latest diff file. +func (e *testEnv) runTableDiff(t *testing.T, qualifiedTableName string, nodesToCompare []string) string { + t.Helper() + files, _ := filepath.Glob("*_diffs-*.json") + for _, f := range files { + os.Remove(f) + } + + tdTask := e.newTableDiffTask(t, qualifiedTableName, nodesToCompare) + err := tdTask.RunChecks(false) + require.NoError(t, err, "table-diff validation failed") + err = tdTask.ExecuteTask() + require.NoError(t, err, "table-diff execution failed") + + latestDiffFile := getLatestDiffFile(t) + require.NotEmpty(t, latestDiffFile, "No diff file was generated") + return latestDiffFile +} + +// assertNoTableDiff runs a diff and asserts that there are no differences. +func (e *testEnv) assertNoTableDiff(t *testing.T, qualifiedTableName string) { + t.Helper() + nodesToCompare := []string{e.ServiceN1, e.ServiceN2} + tdTask := e.newTableDiffTask(t, qualifiedTableName, nodesToCompare) + + err := tdTask.RunChecks(false) + require.NoError(t, err, "assertNoTableDiff: validation failed") + err = tdTask.ExecuteTask() + require.NoError(t, err, "assertNoTableDiff: execution failed") + + assert.Empty(t, tdTask.DiffResult.NodeDiffs, "Expected no differences after repair, but diffs were found") +} + +// repairTable finds the latest diff file and runs a repair. 
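+// "Latest" means the newest *_diffs-*.json file by modification time; when
+// no diff file exists, the repair is skipped rather than treated as a failure.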
+func (e *testEnv) repairTable(t *testing.T, qualifiedTableName, sourceOfTruthNode string) {
+	t.Helper()
+	files, err := filepath.Glob("*_diffs-*.json")
+	if err != nil {
+		t.Fatalf("Failed to find diff files: %v", err)
+	}
+	if len(files) == 0 {
+		log.Println("No diff file found to repair from, skipping repair.")
+		return
+	}
+
+	sort.Slice(files, func(i, j int) bool {
+		fi, errI := os.Stat(files[i])
+		if errI != nil {
+			t.Logf("Warning: could not stat file %s: %v", files[i], errI)
+			return false
+		}
+		fj, errJ := os.Stat(files[j])
+		if errJ != nil {
+			t.Logf("Warning: could not stat file %s: %v", files[j], errJ)
+			return false
+		}
+		return fi.ModTime().After(fj.ModTime())
+	})
+
+	latestDiffFile := files[0]
+	log.Printf("Using latest diff file for repair: %s", latestDiffFile)
+
+	repairTask := e.newTableRepairTask(sourceOfTruthNode, qualifiedTableName, latestDiffFile)
+	// Short settle pause before kicking off the repair.
+	time.Sleep(2 * time.Second)
+	if err := repairTask.Run(false); err != nil {
+		t.Fatalf("Failed to repair table: %v", err)
+	}
+	log.Printf("Table '%s' repaired successfully using %s as source of truth.", qualifiedTableName, sourceOfTruthNode)
+}
+
+// resetSharedTable truncates and reloads the given table from CSV on n1 and n2.
+func (e *testEnv) resetSharedTable(t *testing.T, tableName string) {
+	t.Helper()
+	ctx := context.Background()
+	qualifiedTableName := fmt.Sprintf("%s.%s", e.Schema, tableName)
+	csvPath, err := filepath.Abs(defaultCsvFilePath + tableName + ".csv")
+	require.NoError(t, err)
+	for _, pool := range e.pools() {
+		e.withRepairMode(t, ctx, pool, func() {
+			_, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName))
+			require.NoError(t, err, "truncate %s", qualifiedTableName)
+		})
+		require.NoError(t, loadDataFromCSV(ctx, pool, e.Schema, tableName, csvPath), "load CSV into %s", qualifiedTableName)
+	}
+}
+
+// pairKey returns the canonical node pair key for diff results.
+func (e *testEnv) pairKey() string {
+	if strings.Compare(e.ServiceN1, e.ServiceN2) > 0 {
+		return e.ServiceN2 + "/" + e.ServiceN1
+	}
+	return e.ServiceN1 + "/" + e.ServiceN2
+}

From dc694395927926678e52d05084a5d0c07b170573 Mon Sep 17 00:00:00 2001
From: Mason Sharp
Date: Thu, 12 Mar 2026 17:49:50 -0700
Subject: [PATCH 03/14] feat: make ACE schema configurable in SQL templates

Replace hardcoded `spock.` schema prefix with a template function
`{{aceSchema}}` that reads from config at render time. Defaults to
"spock" for backward compatibility but allows alternate schemas
(e.g. "ace") via the existing mtree.schema config key.

Co-Authored-By: Claude Opus 4.6
---
 db/queries/templates.go | 86 +++++++++++++++++++++++++----------------
 1 file changed, 52 insertions(+), 34 deletions(-)

diff --git a/db/queries/templates.go b/db/queries/templates.go
index 46b919d..8a93666 100644
--- a/db/queries/templates.go
+++ b/db/queries/templates.go
@@ -11,7 +11,17 @@
-import "text/template"
+import (
+	"text/template"
+
+	"github.com/pgedge/ace/pkg/config"
+)
+
+// aceTemplateFuncs provides the {{aceSchema}} function to SQL templates.
+// The function is evaluated at render time (after config is loaded), not at parse time.
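+//
+// For example, with mtree.schema set to "ace", a template body such as
+//
+//	DROP TABLE IF EXISTS {{aceSchema}}.ace_mtree_metadata CASCADE
+//
+// renders as "DROP TABLE IF EXISTS ace.ace_mtree_metadata CASCADE"; with the
+// default config, the historical "spock." prefix is preserved.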
+var aceTemplateFuncs = template.FuncMap{ + "aceSchema": func() string { return config.Cfg.MTree.Schema }, +} type Templates struct { EstimateRowCount *template.Template @@ -136,8 +146,8 @@ type Templates struct { var SQLTemplates = Templates{ // A template isn't needed for this query; just keeping the struct uniform - CreateMetadataTable: template.Must(template.New("createMetadataTable").Parse(` - CREATE TABLE IF NOT EXISTS spock.ace_mtree_metadata ( + CreateMetadataTable: template.Must(template.New("createMetadataTable").Funcs(aceTemplateFuncs).Parse(` + CREATE TABLE IF NOT EXISTS {{aceSchema}}.ace_mtree_metadata ( schema_name text, table_name text, total_rows bigint, @@ -164,8 +174,8 @@ var SQLTemplates = Templates{ ALTER PUBLICATION {{.PublicationName}} DROP TABLE {{.TableName}} `)), - RemoveTableFromCDCMetadata: template.Must(template.New("removeTableFromCDCMetadata").Parse(` - UPDATE spock.ace_cdc_metadata + RemoveTableFromCDCMetadata: template.Must(template.New("removeTableFromCDCMetadata").Funcs(aceTemplateFuncs).Parse(` + UPDATE {{aceSchema}}.ace_cdc_metadata SET tables = array_remove(tables, $1) WHERE publication_name = $2 `)), @@ -182,9 +192,9 @@ var SQLTemplates = Templates{ ) `)), - UpdateCDCMetadata: template.Must(template.New("updateCdcMetadata").Parse(` + UpdateCDCMetadata: template.Must(template.New("updateCdcMetadata").Funcs(aceTemplateFuncs).Parse(` INSERT INTO - spock.ace_cdc_metadata ( + {{aceSchema}}.ace_cdc_metadata ( publication_name, slot_name, start_lsn, @@ -223,17 +233,17 @@ var SQLTemplates = Templates{ CheckPIDExists: template.Must(template.New("checkPIDExists").Parse(` SELECT pid FROM pg_stat_activity WHERE pid = $1 `)), - DropCDCMetadataTable: template.Must(template.New("dropCDCMetadataTable").Parse(` - DROP TABLE IF EXISTS spock.ace_cdc_metadata + DropCDCMetadataTable: template.Must(template.New("dropCDCMetadataTable").Funcs(aceTemplateFuncs).Parse(` + DROP TABLE IF EXISTS {{aceSchema}}.ace_cdc_metadata `)), - GetCDCMetadata: template.Must(template.New("getCDCMetadata").Parse(` + GetCDCMetadata: template.Must(template.New("getCDCMetadata").Funcs(aceTemplateFuncs).Parse(` SELECT slot_name, start_lsn, tables FROM - spock.ace_cdc_metadata + {{aceSchema}}.ace_cdc_metadata WHERE publication_name = $1 `)), @@ -320,8 +330,8 @@ var SQLTemplates = Templates{ AND mt.node_position = b.node_position; `)), - CreateCDCMetadataTable: template.Must(template.New("createCDCMetadataTable").Parse(` - CREATE TABLE IF NOT EXISTS spock.ace_cdc_metadata ( + CreateCDCMetadataTable: template.Must(template.New("createCDCMetadataTable").Funcs(aceTemplateFuncs).Parse(` + CREATE TABLE IF NOT EXISTS {{aceSchema}}.ace_cdc_metadata ( publication_name text PRIMARY KEY, slot_name text, start_lsn text, @@ -775,9 +785,9 @@ var SQLTemplates = Templates{ VALUES (0, $1, {{.StartExpr}}, {{.EndExpr}}); `)), - CreateXORFunction: template.Must(template.New("createXORFunction").Parse(` + CreateXORFunction: template.Must(template.New("createXORFunction").Funcs(aceTemplateFuncs).Parse(` CREATE - OR REPLACE FUNCTION spock.bytea_xor(a bytea, b bytea) RETURNS bytea AS $$ + OR REPLACE FUNCTION {{aceSchema}}.bytea_xor(a bytea, b bytea) RETURNS bytea AS $$ DECLARE result bytea; len int; @@ -808,7 +818,7 @@ var SQLTemplates = Templates{ CREATE OPERATOR # ( LEFTARG = bytea, RIGHTARG = bytea, - PROCEDURE = spock.bytea_xor + PROCEDURE = {{aceSchema}}.bytea_xor ); END IF; END $$; @@ -843,9 +853,9 @@ var SQLTemplates = Templates{ AND c.relname = $2 AND a.attname = $3 `)), - UpdateMetadata: 
template.Must(template.New("updateMetadata").Parse(` + UpdateMetadata: template.Must(template.New("updateMetadata").Funcs(aceTemplateFuncs).Parse(` INSERT INTO - spock.ace_mtree_metadata ( + {{aceSchema}}.ace_mtree_metadata ( schema_name, table_name, total_rows, @@ -876,8 +886,8 @@ var SQLTemplates = Templates{ hash_version = EXCLUDED.hash_version, last_updated = EXCLUDED.last_updated `)), - DeleteMetadata: template.Must(template.New("deleteMetadata").Parse(` - DELETE FROM spock.ace_mtree_metadata WHERE schema_name = $1 AND table_name = $2 + DeleteMetadata: template.Must(template.New("deleteMetadata").Funcs(aceTemplateFuncs).Parse(` + DELETE FROM {{aceSchema}}.ace_mtree_metadata WHERE schema_name = $1 AND table_name = $2 `)), InsertBlockRanges: template.Must(template.New("insertBlockRanges").Parse(` INSERT INTO @@ -1058,11 +1068,11 @@ var SQLTemplates = Templates{ ORDER BY node_position `)), - GetRowCountEstimate: template.Must(template.New("getRowCountEstimate").Parse(` + GetRowCountEstimate: template.Must(template.New("getRowCountEstimate").Funcs(aceTemplateFuncs).Parse(` SELECT total_rows FROM - spock.ace_mtree_metadata + {{aceSchema}}.ace_mtree_metadata WHERE schema_name = $1 AND table_name = $2 @@ -1297,11 +1307,11 @@ var SQLTemplates = Templates{ mt.range_start, mt.range_end `)), - GetBlockSizeFromMetadata: template.Must(template.New("getBlockSizeFromMetadata").Parse(` + GetBlockSizeFromMetadata: template.Must(template.New("getBlockSizeFromMetadata").Funcs(aceTemplateFuncs).Parse(` SELECT block_size FROM - spock.ace_mtree_metadata + {{aceSchema}}.ace_mtree_metadata WHERE schema_name = $1 AND table_name = $2 @@ -1312,11 +1322,19 @@ var SQLTemplates = Templates{ FROM {{.MtreeTable}} `)), - DropXORFunction: template.Must(template.New("dropXORFunction").Parse(` - DROP FUNCTION IF EXISTS spock.bytea_xor(bytea, bytea) CASCADE + CompareBlocksSQL: template.Must(template.New("compareBlocksSQL").Parse(` + SELECT + * + FROM + {{.TableName}} + WHERE + {{.WhereClause}} + `)), + DropXORFunction: template.Must(template.New("dropXORFunction").Funcs(aceTemplateFuncs).Parse(` + DROP FUNCTION IF EXISTS {{aceSchema}}.bytea_xor(bytea, bytea) CASCADE `)), - DropMetadataTable: template.Must(template.New("dropMetadataTable").Parse(` - DROP TABLE IF EXISTS spock.ace_mtree_metadata CASCADE + DropMetadataTable: template.Must(template.New("dropMetadataTable").Funcs(aceTemplateFuncs).Parse(` + DROP TABLE IF EXISTS {{aceSchema}}.ace_mtree_metadata CASCADE `)), DropMtreeTable: template.Must(template.New("dropMtreeTable").Parse(` DROP TABLE IF EXISTS {{.MtreeTable}} CASCADE @@ -1509,13 +1527,13 @@ var SQLTemplates = Templates{ ORDER BY rs.confirmed_flush_lsn DESC LIMIT 1 `)), - EnsureHashVersionColumn: template.Must(template.New("ensureHashVersionColumn").Parse(` - ALTER TABLE spock.ace_mtree_metadata + EnsureHashVersionColumn: template.Must(template.New("ensureHashVersionColumn").Funcs(aceTemplateFuncs).Parse(` + ALTER TABLE {{aceSchema}}.ace_mtree_metadata ADD COLUMN IF NOT EXISTS hash_version int NOT NULL DEFAULT 1 `)), - GetHashVersion: template.Must(template.New("getHashVersion").Parse(` + GetHashVersion: template.Must(template.New("getHashVersion").Funcs(aceTemplateFuncs).Parse(` SELECT COALESCE( - (SELECT hash_version FROM spock.ace_mtree_metadata + (SELECT hash_version FROM {{aceSchema}}.ace_mtree_metadata WHERE schema_name = $1 AND table_name = $2), 1 ) @@ -1525,8 +1543,8 @@ var SQLTemplates = Templates{ SET dirty = true WHERE node_level = 0 `)), - UpdateHashVersion: 
template.Must(template.New("updateHashVersion").Parse(` - UPDATE spock.ace_mtree_metadata + UpdateHashVersion: template.Must(template.New("updateHashVersion").Funcs(aceTemplateFuncs).Parse(` + UPDATE {{aceSchema}}.ace_mtree_metadata SET hash_version = $1, last_updated = current_timestamp WHERE schema_name = $2 AND table_name = $3 `)), From b2dfaacc1d8a67181e59bb5453805d21305df947 Mon Sep 17 00:00:00 2001 From: Mason Sharp Date: Thu, 12 Mar 2026 18:19:23 -0700 Subject: [PATCH 04/14] feat: adapt merkle tree tests for dual-mode (spock + native PG) Refactor all merkle tree test functions to accept *testEnv, replacing pgCluster globals and spock.repair_mode() calls with env-aware abstractions. All 28 merkle tree subtests (Init, Build, Diff, Merge, Split, CDC, Teardown, NumericScaleInvariance) now run on both spock and native PostgreSQL. Co-Authored-By: Claude Opus 4.6 --- db/queries/templates.go | 1 + tests/integration/cdc_busy_table_test.go | 4 +- tests/integration/helpers_test.go | 5 + tests/integration/merkle_tree_test.go | 602 +++++++++++------------ tests/integration/native_pg_test.go | 35 ++ tests/integration/test_env_test.go | 26 + 6 files changed, 353 insertions(+), 320 deletions(-) diff --git a/db/queries/templates.go b/db/queries/templates.go index 8a93666..c0ce0ff 100644 --- a/db/queries/templates.go +++ b/db/queries/templates.go @@ -88,6 +88,7 @@ type Templates struct { GetBlockCountSimple *template.Template GetBlockSizeFromMetadata *template.Template GetMaxNodeLevel *template.Template + CompareBlocksSQL *template.Template DropXORFunction *template.Template DropMetadataTable *template.Template diff --git a/tests/integration/cdc_busy_table_test.go b/tests/integration/cdc_busy_table_test.go index c557eb3..8a2771d 100644 --- a/tests/integration/cdc_busy_table_test.go +++ b/tests/integration/cdc_busy_table_test.go @@ -192,7 +192,7 @@ func TestCDCFallbackToSlotLSN(t *testing.T) { err := pgCluster.Node1Pool.QueryRow(ctx, "SELECT ($1::pg_lsn + 16)::text", slotConfirmed).Scan(&bumpedStart) require.NoError(t, err) - _, err = pgCluster.Node1Pool.Exec(ctx, "UPDATE spock.ace_cdc_metadata SET start_lsn = $1 WHERE publication_name = $2", bumpedStart, config.Cfg.MTree.CDC.PublicationName) + _, err = pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("UPDATE %s.ace_cdc_metadata SET start_lsn = $1 WHERE publication_name = $2", config.Cfg.MTree.Schema), bumpedStart, config.Cfg.MTree.CDC.PublicationName) require.NoError(t, err) _, err = pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("UPDATE %s SET email = email || '.cdc_fallback' WHERE index = 2", qualifiedTableName)) @@ -263,7 +263,7 @@ func currentMetadataLSN(t *testing.T, ctx context.Context) pglogrepl.LSN { func metadataStartLSN(t *testing.T, ctx context.Context) pglogrepl.LSN { t.Helper() var lsnStr string - err := pgCluster.Node1Pool.QueryRow(ctx, "SELECT start_lsn FROM spock.ace_cdc_metadata WHERE publication_name = $1", config.Cfg.MTree.CDC.PublicationName).Scan(&lsnStr) + err := pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT start_lsn FROM %s.ace_cdc_metadata WHERE publication_name = $1", config.Cfg.MTree.Schema), config.Cfg.MTree.CDC.PublicationName).Scan(&lsnStr) require.NoError(t, err) return mustParseLSN(t, lsnStr) } diff --git a/tests/integration/helpers_test.go b/tests/integration/helpers_test.go index 0d189d3..c31402f 100644 --- a/tests/integration/helpers_test.go +++ b/tests/integration/helpers_test.go @@ -24,6 +24,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" + "github.com/pgedge/ace/internal/consistency/mtree" 
"github.com/pgedge/ace/internal/consistency/repair" "github.com/stretchr/testify/require" ) @@ -286,6 +287,10 @@ func newTestTableRepairTask(sourceOfTruthNode, qualifiedTableName, diffFilePath return newSpockEnv().newTableRepairTask(sourceOfTruthNode, qualifiedTableName, diffFilePath) } +func newTestMerkleTreeTask(t *testing.T, qualifiedTableName string, nodes []string) *mtree.MerkleTreeTask { + return newSpockEnv().newMerkleTreeTask(t, qualifiedTableName, nodes) +} + func repairTable(t *testing.T, qualifiedTableName, sourceOfTruthNode string) { t.Helper() newSpockEnv().repairTable(t, qualifiedTableName, sourceOfTruthNode) diff --git a/tests/integration/merkle_tree_test.go b/tests/integration/merkle_tree_test.go index dec1c3f..c416477 100644 --- a/tests/integration/merkle_tree_test.go +++ b/tests/integration/merkle_tree_test.go @@ -27,7 +27,6 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgxpool" - "github.com/pgedge/ace/internal/consistency/mtree" "github.com/pgedge/ace/internal/infra/cdc" "github.com/pgedge/ace/pkg/config" "github.com/pgedge/ace/pkg/types" @@ -35,24 +34,26 @@ import ( ) func TestMerkleTreeSimplePK(t *testing.T) { + env := newSpockEnv() tableName := "customers" - runMerkleTreeTests(t, tableName) + runMerkleTreeTests(t, env, tableName) } func TestMerkleTreeCompositePK(t *testing.T) { + env := newSpockEnv() tableName := "customers" ctx := context.Background() - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { err := alterTableToCompositeKey(ctx, pool, testSchema, tableName) require.NoError(t, err) } t.Cleanup(func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { err := revertTableToSimpleKey(ctx, pool, testSchema, tableName) require.NoError(t, err) } }) - runMerkleTreeTests(t, tableName) + runMerkleTreeTests(t, env, tableName) } func TestMerkleTreeUntilFilter(t *testing.T) { @@ -172,69 +173,58 @@ func TestMerkleTreeUntilFilter(t *testing.T) { require.Equal(t, 0, totalWithUntil, "with --until on frozen baseline, expected 0 diff rows") } -func runMerkleTreeTests(t *testing.T, tableName string) { +func runMerkleTreeTests(t *testing.T, env *testEnv, tableName string) { if tableName == "customers" { - resetSharedTable(t, "customers") + env.resetSharedTable(t, "customers") } t.Run("Init", func(t *testing.T) { - testMerkleTreeInit(t, tableName) + testMerkleTreeInit(t, env, tableName) }) t.Run("Build", func(t *testing.T) { - testMerkleTreeBuild(t, tableName) + testMerkleTreeBuild(t, env, tableName) }) if tableName == "customers" { t.Run("Diff_DataOnlyOnNode1", func(t *testing.T) { - testMerkleTreeDiffDataOnlyOnNode1(t, tableName) + testMerkleTreeDiffDataOnlyOnNode1(t, env, tableName) }) t.Run("Diff_ModifiedRows", func(t *testing.T) { - testMerkleTreeDiffModifiedRows(t, tableName) + testMerkleTreeDiffModifiedRows(t, env, tableName) }) t.Run("Diff_BoundaryModifications", func(t *testing.T) { - testMerkleTreeDiffBoundaryModifications(t, tableName) + testMerkleTreeDiffBoundaryModifications(t, env, tableName) }) t.Run("MergeInitialRanges", func(t *testing.T) { - testMerkleTreeMergeInitialRanges(t, tableName) + testMerkleTreeMergeInitialRanges(t, env, tableName) }) t.Run("MergeMiddleRanges", func(t *testing.T) { - testMerkleTreeMergeMiddleRanges(t, tableName) + testMerkleTreeMergeMiddleRanges(t, env, tableName) }) t.Run("MergeLastRanges", 
func(t *testing.T) { - testMerkleTreeMergeLastRanges(t, tableName) + testMerkleTreeMergeLastRanges(t, env, tableName) }) t.Run("SplitInitialRanges", func(t *testing.T) { - testMerkleTreeSplitInitialRanges(t, tableName) + testMerkleTreeSplitInitialRanges(t, env, tableName) }) t.Run("SplitMiddleRanges", func(t *testing.T) { - testMerkleTreeSplitMiddleRanges(t, tableName) + testMerkleTreeSplitMiddleRanges(t, env, tableName) }) t.Run("SplitLastRanges", func(t *testing.T) { - testMerkleTreeSplitLastRanges(t, tableName) + testMerkleTreeSplitLastRanges(t, env, tableName) }) t.Run("ContinuousCDC", func(t *testing.T) { - testMerkleTreeContinuousCDC(t, tableName) + testMerkleTreeContinuousCDC(t, env, tableName) }) } t.Run("Teardown", func(t *testing.T) { - testMerkleTreeTeardown(t, tableName) + testMerkleTreeTeardown(t, env, tableName) }) } -func newTestMerkleTreeTask(t *testing.T, qualifiedTableName string, nodes []string) *mtree.MerkleTreeTask { - t.Helper() - task := mtree.NewMerkleTreeTask() - task.ClusterName = "test_cluster" - task.DBName = dbName - task.QualifiedTableName = qualifiedTableName - task.Nodes = strings.Join(nodes, ",") - task.BlockSize = 1000 - return task -} - -func testMerkleTreeInit(t *testing.T, tableName string) { +func testMerkleTreeInit(t *testing.T, env *testEnv, tableName string) { ctx := context.Background() qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodes := []string{serviceN1, serviceN2} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes) + nodes := []string{env.ServiceN1, env.ServiceN2} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes) err := mtreeTask.MtreeInit() require.NoError(t, err, "MtreeInit should succeed") @@ -250,7 +240,7 @@ func testMerkleTreeInit(t *testing.T, tableName string) { cdcPubName := config.Cfg.MTree.CDC.PublicationName cdcSlotName := config.Cfg.MTree.CDC.SlotName - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { require.True(t, schemaExists(t, ctx, pool, aceSchema), "Schema '%s' should exist", aceSchema) require.True(t, functionExists(t, ctx, pool, "bytea_xor", aceSchema), "Function 'bytea_xor' should exist in schema '%s'", aceSchema) require.True(t, tableExists(t, ctx, pool, "ace_cdc_metadata", aceSchema), "Table 'ace_cdc_metadata' should exist in schema '%s'", aceSchema) @@ -259,11 +249,11 @@ func testMerkleTreeInit(t *testing.T, tableName string) { } } -func testMerkleTreeBuild(t *testing.T, tableName string) { +func testMerkleTreeBuild(t *testing.T, env *testEnv, tableName string) { ctx := context.Background() qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodes := []string{serviceN1} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes) + nodes := []string{env.ServiceN1} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes) err := mtreeTask.RunChecks(false) require.NoError(t, err, "RunChecks should succeed") @@ -282,7 +272,7 @@ func testMerkleTreeBuild(t *testing.T, tableName string) { aceSchema := config.Cfg.MTree.Schema mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", testSchema, tableName) - pool := pgCluster.Node1Pool + pool := env.N1Pool require.True(t, tableExists(t, ctx, pool, mtreeTableName, aceSchema), "Merkle tree table '%s' should exist", mtreeTableName) @@ -303,11 +293,11 @@ func testMerkleTreeBuild(t *testing.T, tableName string) { require.Greater(t, totalNodeCount, leafNodeCount, "Total nodes should be greater than leaf nodes") 
} -func testMerkleTreeDiffDataOnlyOnNode1(t *testing.T, tableName string) { +func testMerkleTreeDiffDataOnlyOnNode1(t *testing.T, env *testEnv, tableName string) { ctx := context.Background() qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodes := []string{serviceN1, serviceN2} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes) + nodes := []string{env.ServiceN1, env.ServiceN2} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes) err := mtreeTask.RunChecks(false) require.NoError(t, err, "RunChecks should succeed") @@ -318,7 +308,7 @@ func testMerkleTreeDiffDataOnlyOnNode1(t *testing.T, tableName string) { if err != nil { t.Logf("Warning: MtreeTeardown failed during cleanup: %v", err) } - repairTable(t, qualifiedTableName, serviceN2) + env.repairTable(t, qualifiedTableName, env.ServiceN2) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -327,55 +317,51 @@ func testMerkleTreeDiffDataOnlyOnNode1(t *testing.T, tableName string) { err = mtreeTask.BuildMtree() require.NoError(t, err, "BuildMtree should succeed") - tx, err := pgCluster.Node1Pool.Begin(ctx) + tx, err := env.N1Pool.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - if mtreeTask.SimplePrimaryKey { - updateSQL := fmt.Sprintf("UPDATE %s SET email = 'updated.on.n1@example.com' WHERE index = 1", qualifiedTableName) - _, err = tx.Exec(ctx, updateSQL) - require.NoError(t, err) - } else { - var customerID string - err := pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT customer_id FROM %s WHERE index = 1 LIMIT 1", qualifiedTableName)).Scan(&customerID) - require.NoError(t, err, "could not get customer_id for index 1") - updateSQL := fmt.Sprintf("UPDATE %s SET email = 'updated.on.n1@example.com' WHERE index = 1 AND customer_id = $1", qualifiedTableName) - _, err = tx.Exec(ctx, updateSQL, customerID) - require.NoError(t, err) - } - - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + env.withRepairModeTx(t, ctx, tx, func() { + if mtreeTask.SimplePrimaryKey { + updateSQL := fmt.Sprintf("UPDATE %s SET email = 'updated.on.n1@example.com' WHERE index = 1", qualifiedTableName) + _, err = tx.Exec(ctx, updateSQL) + require.NoError(t, err) + } else { + var customerID string + err := env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT customer_id FROM %s WHERE index = 1 LIMIT 1", qualifiedTableName)).Scan(&customerID) + require.NoError(t, err, "could not get customer_id for index 1") + updateSQL := fmt.Sprintf("UPDATE %s SET email = 'updated.on.n1@example.com' WHERE index = 1 AND customer_id = $1", qualifiedTableName) + _, err = tx.Exec(ctx, updateSQL, customerID) + require.NoError(t, err) + } + }) require.NoError(t, tx.Commit(ctx)) err = mtreeTask.DiffMtree() require.NoError(t, err, "DiffMtree should succeed") - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 + pairKey := env.ServiceN1 + "/" + env.ServiceN2 + if strings.Compare(env.ServiceN1, env.ServiceN2) > 0 { + pairKey = env.ServiceN2 + "/" + env.ServiceN1 } nodeDiffs, ok := mtreeTask.DiffResult.NodeDiffs[pairKey] require.True(t, ok, "Expected diffs for pair %s, but none found. 
Result: %+v", pairKey, mtreeTask.DiffResult) - require.GreaterOrEqual(t, len(nodeDiffs.Rows[serviceN1]), 1, "Expected at least 1 modified row on %s, got %d", serviceN1, len(nodeDiffs.Rows[serviceN1])) - require.GreaterOrEqual(t, len(nodeDiffs.Rows[serviceN2]), 1, "Expected at least 1 original row on %s, got %d", serviceN2, len(nodeDiffs.Rows[serviceN2])) + require.GreaterOrEqual(t, len(nodeDiffs.Rows[env.ServiceN1]), 1, "Expected at least 1 modified row on %s, got %d", env.ServiceN1, len(nodeDiffs.Rows[env.ServiceN1])) + require.GreaterOrEqual(t, len(nodeDiffs.Rows[env.ServiceN2]), 1, "Expected at least 1 original row on %s, got %d", env.ServiceN2, len(nodeDiffs.Rows[env.ServiceN2])) require.GreaterOrEqual(t, mtreeTask.DiffResult.Summary.DiffRowsCount[pairKey], 1, "Expected summary diff count to be at least 1") // Find the row with index=1 in the diff (block-based diff may include other rows in the same block) var diffRowN1, diffRowN2 types.OrderedMap - for _, row := range nodeDiffs.Rows[serviceN1] { + for _, row := range nodeDiffs.Rows[env.ServiceN1] { if idx, ok := row.Get("index"); ok && idx.(int32) == 1 { diffRowN1 = row break } } - for _, row := range nodeDiffs.Rows[serviceN2] { + for _, row := range nodeDiffs.Rows[env.ServiceN2] { if idx, ok := row.Get("index"); ok && idx.(int32) == 1 { diffRowN2 = row break @@ -397,23 +383,23 @@ type compositeBoundaryKey struct { CustomerID string } -func testMerkleTreeDiffBoundaryModifications(t *testing.T, tableName string) { +func testMerkleTreeDiffBoundaryModifications(t *testing.T, env *testEnv, tableName string) { ctx := context.Background() qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodes := []string{serviceN1, serviceN2} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes) + nodes := []string{env.ServiceN1, env.ServiceN2} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes) err := mtreeTask.RunChecks(false) require.NoError(t, err, "RunChecks should succeed") err = mtreeTask.MtreeInit() require.NoError(t, err, "MtreeInit should succeed") t.Cleanup(func() { - _, _ = pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("UPDATE %s SET email = regexp_replace(email, '\\\\.(cdc_hw|cdc_fallback)$', '')", qualifiedTableName)) + _, _ = env.N1Pool.Exec(ctx, fmt.Sprintf("UPDATE %s SET email = regexp_replace(email, '\\\\.(cdc_hw|cdc_fallback)$', '')", qualifiedTableName)) err := mtreeTask.MtreeTeardown() if err != nil { t.Logf("Warning: MtreeTeardown failed during cleanup: %v", err) } - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -428,7 +414,7 @@ func testMerkleTreeDiffBoundaryModifications(t *testing.T, tableName string) { var boundaryPkeys []any if mtreeTask.SimplePrimaryKey { query := fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND range_start IS NOT NULL UNION SELECT range_end FROM %s.%s WHERE node_level=0 AND range_end IS NOT NULL", aceSchema, mtreeTableName, aceSchema, mtreeTableName) - rows, err := pgCluster.Node1Pool.Query(ctx, query) + rows, err := env.N1Pool.Query(ctx, query) require.NoError(t, err) defer rows.Close() for rows.Next() { @@ -440,7 +426,7 @@ func testMerkleTreeDiffBoundaryModifications(t *testing.T, tableName string) { } else { // For composite keys, we get the text representation and parse it. 
query := fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND range_start IS NOT NULL UNION SELECT range_end::text FROM %s.%s WHERE node_level=0 AND range_end IS NOT NULL", aceSchema, mtreeTableName, aceSchema, mtreeTableName) - rows, err := pgCluster.Node1Pool.Query(ctx, query) + rows, err := env.N1Pool.Query(ctx, query) require.NoError(t, err) defer rows.Close() re := regexp.MustCompile(`^\((\d+),"?([^",]+)"?\)$`) @@ -484,59 +470,55 @@ func testMerkleTreeDiffBoundaryModifications(t *testing.T, tableName string) { } pkeysToModify = uniquePkeysList - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - if mtreeTask.SimplePrimaryKey { - for _, pkey := range pkeysToModify { - updateSQL := fmt.Sprintf("UPDATE %s SET email = 'boundary.update.%v@example.com' WHERE index = $1", qualifiedTableName, pkey) - _, err := tx.Exec(ctx, updateSQL, pkey) - require.NoError(t, err) - } - } else { - for _, pkey := range pkeysToModify { - ckey := pkey.(compositeBoundaryKey) - updateSQL := fmt.Sprintf("UPDATE %s SET email = 'boundary.update.%v@example.com' WHERE index = $1 AND customer_id = $2", qualifiedTableName, ckey.Index) - _, err := tx.Exec(ctx, updateSQL, ckey.Index, ckey.CustomerID) - require.NoError(t, err) + env.withRepairModeTx(t, ctx, tx, func() { + if mtreeTask.SimplePrimaryKey { + for _, pkey := range pkeysToModify { + updateSQL := fmt.Sprintf("UPDATE %s SET email = 'boundary.update.%v@example.com' WHERE index = $1", qualifiedTableName, pkey) + _, err := tx.Exec(ctx, updateSQL, pkey) + require.NoError(t, err) + } + } else { + for _, pkey := range pkeysToModify { + ckey := pkey.(compositeBoundaryKey) + updateSQL := fmt.Sprintf("UPDATE %s SET email = 'boundary.update.%v@example.com' WHERE index = $1 AND customer_id = $2", qualifiedTableName, ckey.Index) + _, err := tx.Exec(ctx, updateSQL, ckey.Index, ckey.CustomerID) + require.NoError(t, err) + } } - } - - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + }) require.NoError(t, tx.Commit(ctx)) err = mtreeTask.DiffMtree() require.NoError(t, err, "DiffMtree should succeed") - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 + pairKey := env.ServiceN1 + "/" + env.ServiceN2 + if strings.Compare(env.ServiceN1, env.ServiceN2) > 0 { + pairKey = env.ServiceN2 + "/" + env.ServiceN1 } nodeDiffs, ok := mtreeTask.DiffResult.NodeDiffs[pairKey] require.True(t, ok, "Expected diffs for pair %s, but none found.", pairKey) expectedDiffCount := len(pkeysToModify) - require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[serviceN1]), "Expected %d modified rows on %s", expectedDiffCount, serviceN1) - require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[serviceN2]), "Expected %d modified rows on %s", expectedDiffCount, serviceN2) + require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[env.ServiceN1]), "Expected %d modified rows on %s", expectedDiffCount, env.ServiceN1) + require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[env.ServiceN2]), "Expected %d modified rows on %s", expectedDiffCount, env.ServiceN2) require.Equal(t, expectedDiffCount, mtreeTask.DiffResult.Summary.DiffRowsCount[pairKey], "Expected summary diff count to be %d", expectedDiffCount) if mtreeTask.SimplePrimaryKey { for _, pkey := range pkeysToModify { foundN1 := false foundN2 := false - for _, row := range 
nodeDiffs.Rows[serviceN1] { + for _, row := range nodeDiffs.Rows[env.ServiceN1] { if rIndex, ok := row.Get("index"); ok && rIndex == pkey { foundN1 = true break } } - for _, row := range nodeDiffs.Rows[serviceN2] { + for _, row := range nodeDiffs.Rows[env.ServiceN2] { if rIndex, ok := row.Get("index"); ok && rIndex == pkey { foundN2 = true emailVal, _ := row.Get("email") @@ -553,7 +535,7 @@ func testMerkleTreeDiffBoundaryModifications(t *testing.T, tableName string) { ckey := pkey.(compositeBoundaryKey) foundN1 := false foundN2 := false - for _, row := range nodeDiffs.Rows[serviceN1] { + for _, row := range nodeDiffs.Rows[env.ServiceN1] { rIndex, _ := row.Get("index") rCustomerID, _ := row.Get("customer_id") if rIndex == ckey.Index && rCustomerID == ckey.CustomerID { @@ -561,7 +543,7 @@ func testMerkleTreeDiffBoundaryModifications(t *testing.T, tableName string) { break } } - for _, row := range nodeDiffs.Rows[serviceN2] { + for _, row := range nodeDiffs.Rows[env.ServiceN2] { rIndex, _ := row.Get("index") rCustomerID, _ := row.Get("customer_id") if rIndex == ckey.Index && rCustomerID == ckey.CustomerID { @@ -578,14 +560,14 @@ func testMerkleTreeDiffBoundaryModifications(t *testing.T, tableName string) { } } -func testMerkleTreeContinuousCDC(t *testing.T, tableName string) { +func testMerkleTreeContinuousCDC(t *testing.T, env *testEnv, tableName string) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() largeTableName := "customers_1M" qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodes := []string{serviceN1} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes) + nodes := []string{env.ServiceN1} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes) err := mtreeTask.RunChecks(false) require.NoError(t, err, "RunChecks should succeed") @@ -596,7 +578,7 @@ func testMerkleTreeContinuousCDC(t *testing.T, tableName string) { if err != nil { t.Logf("Warning: MtreeTeardown failed during cleanup: %v", err) } - repairTable(t, qualifiedTableName, serviceN2) + env.repairTable(t, qualifiedTableName, env.ServiceN2) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -610,7 +592,7 @@ func testMerkleTreeContinuousCDC(t *testing.T, tableName string) { wg.Add(1) go func() { defer wg.Done() - nodeInfo := pgCluster.ClusterNodes[0] + nodeInfo := env.ClusterNodes[0] cdc.ListenForChanges(ctx, nodeInfo) }() @@ -618,7 +600,7 @@ func testMerkleTreeContinuousCDC(t *testing.T, tableName string) { aceSchema := config.Cfg.MTree.Schema mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", testSchema, tableName) - pool := pgCluster.Node1Pool + pool := env.N1Pool var leafNodeCount int err = pool.QueryRow(context.Background(), fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCount) @@ -633,75 +615,71 @@ func testMerkleTreeContinuousCDC(t *testing.T, tableName string) { require.NoError(t, err) defer tx.Rollback(context.Background()) - _, err = tx.Exec(context.Background(), "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - if mtreeTask.SimplePrimaryKey { - var firstBlockStart, middleBlockStart, lastBlockStart int32 - err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), firstBlockPos).Scan(&firstBlockStart) - require.NoError(t, err) - err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = 
$1", aceSchema, mtreeTableName), middleBlockPos).Scan(&middleBlockStart) - require.NoError(t, err) - err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), lastBlockPos).Scan(&lastBlockStart) - require.NoError(t, err) - - // DELETE + INSERT in first block - _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1", qualifiedTableName), firstBlockStart) - require.NoError(t, err) - insertSQL := fmt.Sprintf("INSERT INTO %s SELECT * FROM %s WHERE index = $1", qualifiedTableName, pgx.Identifier{largeTableName}.Sanitize()) - _, err = tx.Exec(context.Background(), insertSQL, firstBlockStart) - require.NoError(t, err) + env.withRepairModeTx(t, context.Background(), tx, func() { + if mtreeTask.SimplePrimaryKey { + var firstBlockStart, middleBlockStart, lastBlockStart int32 + err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), firstBlockPos).Scan(&firstBlockStart) + require.NoError(t, err) + err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), middleBlockPos).Scan(&middleBlockStart) + require.NoError(t, err) + err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), lastBlockPos).Scan(&lastBlockStart) + require.NoError(t, err) - // UPDATE in middle block - updateSQL := fmt.Sprintf("UPDATE %s SET email = 'cdc.update.test@example.com' WHERE index = $1", qualifiedTableName) - _, err = tx.Exec(context.Background(), updateSQL, middleBlockStart) - require.NoError(t, err) + // DELETE + INSERT in first block + _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1", qualifiedTableName), firstBlockStart) + require.NoError(t, err) + insertSQL := fmt.Sprintf("INSERT INTO %s SELECT * FROM %s WHERE index = $1", qualifiedTableName, pgx.Identifier{largeTableName}.Sanitize()) + _, err = tx.Exec(context.Background(), insertSQL, firstBlockStart) + require.NoError(t, err) - // DELETE in last block - _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1", qualifiedTableName), lastBlockStart) - require.NoError(t, err) + // UPDATE in middle block + updateSQL := fmt.Sprintf("UPDATE %s SET email = 'cdc.update.test@example.com' WHERE index = $1", qualifiedTableName) + _, err = tx.Exec(context.Background(), updateSQL, middleBlockStart) + require.NoError(t, err) - } else { // Composite Primary Key - re := regexp.MustCompile(`^\((\d+),"?([^",]+)"?\)$`) - parseKey := func(keyStr string) (int32, string) { - matches := re.FindStringSubmatch(keyStr) - require.Len(t, matches, 3, "could not parse composite key: %s", keyStr) - index, err := strconv.Atoi(matches[1]) + // DELETE in last block + _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1", qualifiedTableName), lastBlockStart) require.NoError(t, err) - return int32(index), matches[2] - } - var firstKeyStr, middleKeyStr, lastKeyStr string - err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), firstBlockPos).Scan(&firstKeyStr) - require.NoError(t, err) - err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE 
node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), middleBlockPos).Scan(&middleKeyStr) - require.NoError(t, err) - err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), lastBlockPos).Scan(&lastKeyStr) - require.NoError(t, err) + } else { // Composite Primary Key + re := regexp.MustCompile(`^\((\d+),"?([^",]+)"?\)$`) + parseKey := func(keyStr string) (int32, string) { + matches := re.FindStringSubmatch(keyStr) + require.Len(t, matches, 3, "could not parse composite key: %s", keyStr) + index, err := strconv.Atoi(matches[1]) + require.NoError(t, err) + return int32(index), matches[2] + } - firstIdx, firstCustId := parseKey(firstKeyStr) - middleIdx, middleCustId := parseKey(middleKeyStr) - lastIdx, lastCustId := parseKey(lastKeyStr) + var firstKeyStr, middleKeyStr, lastKeyStr string + err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), firstBlockPos).Scan(&firstKeyStr) + require.NoError(t, err) + err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), middleBlockPos).Scan(&middleKeyStr) + require.NoError(t, err) + err = tx.QueryRow(context.Background(), fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), lastBlockPos).Scan(&lastKeyStr) + require.NoError(t, err) - // DELETE + INSERT in first block - _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1 AND customer_id = $2", qualifiedTableName), firstIdx, firstCustId) - require.NoError(t, err) - insertSQL := fmt.Sprintf("INSERT INTO %s SELECT * FROM %s WHERE index = $1 AND customer_id = $2", qualifiedTableName, pgx.Identifier{largeTableName}.Sanitize()) - _, err = tx.Exec(context.Background(), insertSQL, firstIdx, firstCustId) - require.NoError(t, err) + firstIdx, firstCustId := parseKey(firstKeyStr) + middleIdx, middleCustId := parseKey(middleKeyStr) + lastIdx, lastCustId := parseKey(lastKeyStr) - // UPDATE in middle block - updateSQL := fmt.Sprintf("UPDATE %s SET email = 'cdc.update.test.composite@example.com' WHERE index = $1 AND customer_id = $2", qualifiedTableName) - _, err = tx.Exec(context.Background(), updateSQL, middleIdx, middleCustId) - require.NoError(t, err) + // DELETE + INSERT in first block + _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1 AND customer_id = $2", qualifiedTableName), firstIdx, firstCustId) + require.NoError(t, err) + insertSQL := fmt.Sprintf("INSERT INTO %s SELECT * FROM %s WHERE index = $1 AND customer_id = $2", qualifiedTableName, pgx.Identifier{largeTableName}.Sanitize()) + _, err = tx.Exec(context.Background(), insertSQL, firstIdx, firstCustId) + require.NoError(t, err) - // DELETE in last block - _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1 AND customer_id = $2", qualifiedTableName), lastIdx, lastCustId) - require.NoError(t, err) - } + // UPDATE in middle block + updateSQL := fmt.Sprintf("UPDATE %s SET email = 'cdc.update.test.composite@example.com' WHERE index = $1 AND customer_id = $2", qualifiedTableName) + _, err = tx.Exec(context.Background(), updateSQL, middleIdx, middleCustId) + require.NoError(t, err) - _, err = tx.Exec(context.Background(), "SELECT spock.repair_mode(false)") - 
require.NoError(t, err) + // DELETE in last block + _, err = tx.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE index = $1 AND customer_id = $2", qualifiedTableName), lastIdx, lastCustId) + require.NoError(t, err) + } + }) require.NoError(t, tx.Commit(context.Background())) time.Sleep(5 * time.Second) @@ -737,8 +715,8 @@ func testMerkleTreeContinuousCDC(t *testing.T, tableName string) { cancel() wg.Wait() - nodes = []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodes) + nodes = []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodes) err = tdTask.RunChecks(false) require.NoError(t, err, "RunChecks after CDC should succeed") @@ -762,35 +740,35 @@ const ( splitLast splitCase = "last" ) -func testMerkleTreeMergeInitialRanges(t *testing.T, tableName string) { - runMerkleTreeMergeTest(t, tableName, initial) +func testMerkleTreeMergeInitialRanges(t *testing.T, env *testEnv, tableName string) { + runMerkleTreeMergeTest(t, env, tableName, initial) } -func testMerkleTreeMergeMiddleRanges(t *testing.T, tableName string) { - runMerkleTreeMergeTest(t, tableName, middle) +func testMerkleTreeMergeMiddleRanges(t *testing.T, env *testEnv, tableName string) { + runMerkleTreeMergeTest(t, env, tableName, middle) } -func testMerkleTreeMergeLastRanges(t *testing.T, tableName string) { - runMerkleTreeMergeTest(t, tableName, last) +func testMerkleTreeMergeLastRanges(t *testing.T, env *testEnv, tableName string) { + runMerkleTreeMergeTest(t, env, tableName, last) } -func testMerkleTreeSplitInitialRanges(t *testing.T, tableName string) { - runMerkleTreeSplitTest(t, tableName, splitInitial) +func testMerkleTreeSplitInitialRanges(t *testing.T, env *testEnv, tableName string) { + runMerkleTreeSplitTest(t, env, tableName, splitInitial) } -func testMerkleTreeSplitMiddleRanges(t *testing.T, tableName string) { - runMerkleTreeSplitTest(t, tableName, splitMiddle) +func testMerkleTreeSplitMiddleRanges(t *testing.T, env *testEnv, tableName string) { + runMerkleTreeSplitTest(t, env, tableName, splitMiddle) } -func testMerkleTreeSplitLastRanges(t *testing.T, tableName string) { - runMerkleTreeSplitTest(t, tableName, splitLast) +func testMerkleTreeSplitLastRanges(t *testing.T, env *testEnv, tableName string) { + runMerkleTreeSplitTest(t, env, tableName, splitLast) } -func runMerkleTreeMergeTest(t *testing.T, tableName string, mc mergeCase) { +func runMerkleTreeMergeTest(t *testing.T, env *testEnv, tableName string, mc mergeCase) { ctx := context.Background() qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodes := []string{serviceN1, serviceN2} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes) + nodes := []string{env.ServiceN1, env.ServiceN2} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes) err := mtreeTask.RunChecks(false) require.NoError(t, err, "RunChecks should succeed") @@ -801,7 +779,7 @@ func runMerkleTreeMergeTest(t *testing.T, tableName string, mc mergeCase) { if err != nil { t.Logf("Warning: MtreeTeardown failed during cleanup: %v", err) } - repairTable(t, qualifiedTableName, serviceN2) + env.repairTable(t, qualifiedTableName, env.ServiceN2) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -814,7 +792,7 @@ func runMerkleTreeMergeTest(t *testing.T, tableName string, mc mergeCase) { mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", testSchema, tableName) var leafNodeCountBefore int - err = pgCluster.Node1Pool.QueryRow(ctx, 
fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCountBefore) + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCountBefore) require.NoError(t, err) var startPos, endPos int @@ -837,20 +815,20 @@ func runMerkleTreeMergeTest(t *testing.T, tableName string, mc mergeCase) { if mtreeTask.SimplePrimaryKey { var startKey int32 - err = pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), startPos).Scan(&startKey) + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_start FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), startPos).Scan(&startKey) require.NoError(t, err) startRange = []any{startKey} if endPos != -1 { var endKey int32 - err = pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_end FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), endPos).Scan(&endKey) + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_end FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), endPos).Scan(&endKey) require.NoError(t, err) endRange = []any{endKey} } } else { re := regexp.MustCompile(`^\((\d+),"?([^",]+)"?\)$`) var startStr string - err = pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), startPos).Scan(&startStr) + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_start::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), startPos).Scan(&startStr) require.NoError(t, err) startMatches := re.FindStringSubmatch(startStr) require.Len(t, startMatches, 3, "should parse composite key from string") @@ -859,7 +837,7 @@ func runMerkleTreeMergeTest(t *testing.T, tableName string, mc mergeCase) { if endPos != -1 { var endStr string - err = pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_end::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), endPos).Scan(&endStr) + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT range_end::text FROM %s.%s WHERE node_level = 0 AND node_position = $1", aceSchema, mtreeTableName), endPos).Scan(&endStr) require.NoError(t, err) endMatches := re.FindStringSubmatch(endStr) require.Len(t, endMatches, 3, "should parse composite key from string") @@ -868,47 +846,43 @@ func runMerkleTreeMergeTest(t *testing.T, tableName string, mc mergeCase) { } } - tx, err := pgCluster.Node1Pool.Begin(ctx) + tx, err := env.N1Pool.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - var cmdTag pgconn.CommandTag - if endPos == -1 { // Deletion to the end - if mtreeTask.SimplePrimaryKey { - deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE index >= $1", qualifiedTableName) - cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0]) - } else { - deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE (index, customer_id) >= ($1, $2)", qualifiedTableName) - cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0], startRange[1]) - } - } else { // Deletion within a range - if mtreeTask.SimplePrimaryKey { - deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE index >= $1 AND index <= $2", qualifiedTableName) - cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0], endRange[0]) - } else { - 
deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE (index, customer_id) >= ($1, $2) AND (index, customer_id) <= ($3, $4)", qualifiedTableName) - cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0], startRange[1], endRange[0], endRange[1]) + env.withRepairModeTx(t, ctx, tx, func() { + var cmdTag pgconn.CommandTag + if endPos == -1 { // Deletion to the end + if mtreeTask.SimplePrimaryKey { + deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE index >= $1", qualifiedTableName) + cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0]) + } else { + deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE (index, customer_id) >= ($1, $2)", qualifiedTableName) + cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0], startRange[1]) + } + } else { // Deletion within a range + if mtreeTask.SimplePrimaryKey { + deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE index >= $1 AND index <= $2", qualifiedTableName) + cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0], endRange[0]) + } else { + deleteSQL := fmt.Sprintf("DELETE FROM %s WHERE (index, customer_id) >= ($1, $2) AND (index, customer_id) <= ($3, $4)", qualifiedTableName) + cmdTag, err = tx.Exec(ctx, deleteSQL, startRange[0], startRange[1], endRange[0], endRange[1]) + } } - } - require.NoError(t, err) - deletedCount = cmdTag.RowsAffected() - - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + require.NoError(t, err) + deletedCount = cmdTag.RowsAffected() + }) require.NoError(t, tx.Commit(ctx)) require.Greater(t, deletedCount, int64(0), "should have deleted some rows") - mtreeTask.Nodes = serviceN1 + mtreeTask.Nodes = env.ServiceN1 mtreeTask.Rebalance = true err = mtreeTask.UpdateMtree(true) require.NoError(t, err, "UpdateMtree with rebalance should succeed") var leafNodeCountAfter int - err = pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCountAfter) + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCountAfter) require.NoError(t, err) require.Less(t, leafNodeCountAfter, leafNodeCountBefore, "Number of leaf nodes should decrease after merge") @@ -919,25 +893,25 @@ func runMerkleTreeMergeTest(t *testing.T, tableName string, mc mergeCase) { err = mtreeTask.DiffMtree() require.NoError(t, err, "DiffMtree should succeed") - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 + pairKey := env.ServiceN1 + "/" + env.ServiceN2 + if strings.Compare(env.ServiceN1, env.ServiceN2) > 0 { + pairKey = env.ServiceN2 + "/" + env.ServiceN1 } nodeDiffs, ok := mtreeTask.DiffResult.NodeDiffs[pairKey] require.True(t, ok, "Expected diffs for pair %s, but none found.", pairKey) - require.Equal(t, 0, len(nodeDiffs.Rows[serviceN1]), "Expected 0 extra rows on %s", serviceN1) - require.Equal(t, int(deletedCount), len(nodeDiffs.Rows[serviceN2]), "Expected %d missing rows on %s", deletedCount, serviceN1) + require.Equal(t, 0, len(nodeDiffs.Rows[env.ServiceN1]), "Expected 0 extra rows on %s", env.ServiceN1) + require.Equal(t, int(deletedCount), len(nodeDiffs.Rows[env.ServiceN2]), "Expected %d missing rows on %s", deletedCount, env.ServiceN1) require.Equal(t, int(deletedCount), mtreeTask.DiffResult.Summary.DiffRowsCount[pairKey], "Expected summary diff count to be %d", deletedCount) } -func runMerkleTreeSplitTest(t *testing.T, tableName string, sc splitCase) { +func runMerkleTreeSplitTest(t *testing.T, env *testEnv, 
tableName string, sc splitCase) { ctx := context.Background() qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) largeTableName := "customers_1M" - nodes := []string{serviceN1, serviceN2} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes) + nodes := []string{env.ServiceN1, env.ServiceN2} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes) err := mtreeTask.RunChecks(false) require.NoError(t, err, "RunChecks should succeed") @@ -956,35 +930,32 @@ func runMerkleTreeSplitTest(t *testing.T, tableName string, sc splitCase) { // No deletes needed for last split } if deleteRange != nil { - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - _, err = tx.Exec(ctx, "DROP TABLE IF EXISTS "+tempTableName) - require.NoError(t, err) - - createSQL := fmt.Sprintf("CREATE TABLE %s AS SELECT * FROM %s WHERE index >= %d AND index <= %d", tempTableName, qualifiedTableName, deleteRange[0], deleteRange[1]) - _, err = tx.Exec(ctx, createSQL) - require.NoError(t, err) + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec(ctx, "DROP TABLE IF EXISTS "+tempTableName) + require.NoError(t, err) - _, err = tx.Exec(ctx, fmt.Sprintf("DELETE FROM %s WHERE index >= %d AND index <= %d", qualifiedTableName, deleteRange[0], deleteRange[1])) - require.NoError(t, err) + createSQL := fmt.Sprintf("CREATE TABLE %s AS SELECT * FROM %s WHERE index >= %d AND index <= %d", tempTableName, qualifiedTableName, deleteRange[0], deleteRange[1]) + _, err = tx.Exec(ctx, createSQL) + require.NoError(t, err) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + _, err = tx.Exec(ctx, fmt.Sprintf("DELETE FROM %s WHERE index >= %d AND index <= %d", qualifiedTableName, deleteRange[0], deleteRange[1])) + require.NoError(t, err) + }) require.NoError(t, tx.Commit(ctx)) } t.Cleanup(func() { - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) } - _, err = pgCluster.Node2Pool.Exec(ctx, "DROP TABLE IF EXISTS "+tempTableName) + _, err = env.N2Pool.Exec(ctx, "DROP TABLE IF EXISTS "+tempTableName) if err != nil { t.Logf("Warning: failed to drop temp table %s during cleanup: %v", tempTableName, err) } @@ -998,7 +969,7 @@ func runMerkleTreeSplitTest(t *testing.T, tableName string, sc splitCase) { err = mtreeTask.MtreeInit() require.NoError(t, err, "MtreeInit should succeed") - mtreeTask.Nodes = serviceN2 + mtreeTask.Nodes = env.ServiceN2 mtreeTask.RunChecks(false) err = mtreeTask.BuildMtree() require.NoError(t, err, "BuildMtree should succeed") @@ -1007,7 +978,7 @@ func runMerkleTreeSplitTest(t *testing.T, tableName string, sc splitCase) { mtreeTableName := fmt.Sprintf("ace_mtree_%s_%s", testSchema, tableName) var leafNodeCountBefore int - err = pgCluster.Node2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCountBefore) + err = env.N2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCountBefore) require.NoError(t, err) var q string @@ -1016,28 +987,24 @@ func runMerkleTreeSplitTest(t *testing.T, tableName string, sc splitCase) { q = fmt.Sprintf("INSERT INTO %s SELECT * FROM %s order by index", qualifiedTableName, tempTableName) case 
splitLast: var maxIndex int32 - err := pgCluster.Node2Pool.QueryRow(ctx, "SELECT max(index) FROM "+qualifiedTableName).Scan(&maxIndex) + err := env.N2Pool.QueryRow(ctx, "SELECT max(index) FROM "+qualifiedTableName).Scan(&maxIndex) require.NoError(t, err) insertStartIndex := maxIndex + 1 q = fmt.Sprintf("INSERT INTO %s SELECT * FROM %s WHERE index >= %d order by index LIMIT %d", qualifiedTableName, pgx.Identifier{largeTableName}.Sanitize(), insertStartIndex, rowsToInsert) } - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - _, err = tx.Exec(ctx, q) - require.NoError(t, err) - if sc != splitLast { - _, err = tx.Exec(ctx, "DROP TABLE "+tempTableName) + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec(ctx, q) require.NoError(t, err) - } - - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + if sc != splitLast { + _, err = tx.Exec(ctx, "DROP TABLE "+tempTableName) + require.NoError(t, err) + } + }) require.NoError(t, tx.Commit(ctx)) mtreeTask.Rebalance = false @@ -1045,7 +1012,7 @@ func runMerkleTreeSplitTest(t *testing.T, tableName string, sc splitCase) { require.NoError(t, err, "UpdateMtree should succeed") var leafNodeCountAfter int - pool, err := connectToNode(pgCluster.Node2Host, pgCluster.Node2Port, pgEdgeUser, pgEdgePassword, dbName) + pool, err := connectToNode(env.N2Host, env.N2Port, env.DBUser, env.DBPassword, env.DBName) require.NoError(t, err) err = pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s.%s WHERE node_level = 0", aceSchema, mtreeTableName)).Scan(&leafNodeCountAfter) @@ -1060,11 +1027,11 @@ func runMerkleTreeSplitTest(t *testing.T, tableName string, sc splitCase) { // require.Empty(t, mtreeTask.DiffResult.NodeDiffs, "Merkle trees should be in sync after identical splits") } -func testMerkleTreeDiffModifiedRows(t *testing.T, tableName string) { +func testMerkleTreeDiffModifiedRows(t *testing.T, env *testEnv, tableName string) { ctx := context.Background() qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodes := []string{serviceN1, serviceN2} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes) + nodes := []string{env.ServiceN1, env.ServiceN2} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes) err := mtreeTask.RunChecks(false) require.NoError(t, err, "RunChecks should succeed") @@ -1075,7 +1042,7 @@ func testMerkleTreeDiffModifiedRows(t *testing.T, tableName string) { if err != nil { t.Logf("Warning: MtreeTeardown failed during cleanup: %v", err) } - repairTable(t, qualifiedTableName, serviceN1) + env.repairTable(t, qualifiedTableName, env.ServiceN1) files, _ := filepath.Glob("*_diffs-*.json") for _, f := range files { os.Remove(f) @@ -1101,44 +1068,40 @@ func testMerkleTreeDiffModifiedRows(t *testing.T, tableName string) { }, } - tx, err := pgCluster.Node2Pool.Begin(ctx) + tx, err := env.N2Pool.Begin(ctx) require.NoError(t, err) defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - for _, mod := range modifications { - updateSQL := fmt.Sprintf("UPDATE %s SET %s = $1 WHERE index = $2", qualifiedTableName, mod.field) - _, err = tx.Exec(ctx, updateSQL, mod.value, mod.indexVal) - require.NoError(t, err) - } - - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + env.withRepairModeTx(t, ctx, tx, func() { + for _, mod := range 
modifications { + updateSQL := fmt.Sprintf("UPDATE %s SET %s = $1 WHERE index = $2", qualifiedTableName, mod.field) + _, err = tx.Exec(ctx, updateSQL, mod.value, mod.indexVal) + require.NoError(t, err) + } + }) require.NoError(t, tx.Commit(ctx)) err = mtreeTask.DiffMtree() require.NoError(t, err, "DiffMtree should succeed") - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 + pairKey := env.ServiceN1 + "/" + env.ServiceN2 + if strings.Compare(env.ServiceN1, env.ServiceN2) > 0 { + pairKey = env.ServiceN2 + "/" + env.ServiceN1 } nodeDiffs, ok := mtreeTask.DiffResult.NodeDiffs[pairKey] require.True(t, ok, "Expected diffs for pair %s, but none found.", pairKey) expectedDiffCount := len(modifications) - require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[serviceN1]), "Expected %d modified rows on %s", expectedDiffCount, serviceN1) - require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[serviceN2]), "Expected %d modified rows on %s", expectedDiffCount, serviceN2) + require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[env.ServiceN1]), "Expected %d modified rows on %s", expectedDiffCount, env.ServiceN1) + require.Equal(t, expectedDiffCount, len(nodeDiffs.Rows[env.ServiceN2]), "Expected %d modified rows on %s", expectedDiffCount, env.ServiceN2) require.Equal(t, expectedDiffCount, mtreeTask.DiffResult.Summary.DiffRowsCount[pairKey], "Expected summary diff count to be %d", expectedDiffCount) for _, mod := range modifications { foundN1 := false foundN2 := false - for _, row := range nodeDiffs.Rows[serviceN1] { + for _, row := range nodeDiffs.Rows[env.ServiceN1] { if rIndex, ok := row.Get("index"); ok && rIndex.(int32) == int32(mod.indexVal) { foundN1 = true fieldVal, _ := row.Get(mod.field) @@ -1146,7 +1109,7 @@ func testMerkleTreeDiffModifiedRows(t *testing.T, tableName string) { break } } - for _, row := range nodeDiffs.Rows[serviceN2] { + for _, row := range nodeDiffs.Rows[env.ServiceN2] { if rIndex, ok := row.Get("index"); ok && rIndex.(int32) == int32(mod.indexVal) { foundN2 = true fieldVal, _ := row.Get(mod.field) @@ -1159,11 +1122,11 @@ func testMerkleTreeDiffModifiedRows(t *testing.T, tableName string) { } } -func testMerkleTreeTeardown(t *testing.T, tableName string) { +func testMerkleTreeTeardown(t *testing.T, env *testEnv, tableName string) { ctx := context.Background() qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) - nodes := []string{serviceN1, serviceN2} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTableName, nodes) + nodes := []string{env.ServiceN1, env.ServiceN2} + mtreeTask := env.newMerkleTreeTask(t, qualifiedTableName, nodes) err := mtreeTask.MtreeInit() require.NoError(t, err, "MtreeInit should succeed") @@ -1175,7 +1138,7 @@ func testMerkleTreeTeardown(t *testing.T, tableName string) { cdcPubName := config.Cfg.MTree.CDC.PublicationName cdcSlotName := config.Cfg.MTree.CDC.SlotName - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { require.False(t, tableExists(t, ctx, pool, "ace_cdc_metadata", aceSchema), "Table 'ace_cdc_metadata' should NOT exist after teardown") require.False(t, publicationExists(t, ctx, pool, cdcPubName), "Publication '%s' should NOT exist after teardown", cdcPubName) require.False(t, replicationSlotExists(t, ctx, pool, cdcSlotName), "Replication slot '%s' should NOT exist after teardown", cdcSlotName) @@ -1189,6 +1152,11 @@ func testMerkleTreeTeardown(t 
*testing.T, tableName string) { // numeric values with different PostgreSQL scales (e.g., 3000.00 vs 3000.0) // produce identical merkle tree hashes and are NOT reported as differences. func TestMerkleTreeNumericScaleInvariance(t *testing.T) { + env := newSpockEnv() + testMerkleTreeNumericScaleInvariance(t, env) +} + +func testMerkleTreeNumericScaleInvariance(t *testing.T, env *testEnv) { ctx := context.Background() numericTable := "ace_numeric_scale_test" qualifiedTable := fmt.Sprintf("%s.%s", testSchema, numericTable) @@ -1200,12 +1168,12 @@ func TestMerkleTreeNumericScaleInvariance(t *testing.T) { amount NUMERIC, label TEXT )`, qualifiedTable) - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, createSQL) require.NoError(t, err, "failed to create numeric test table") } t.Cleanup(func() { - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", qualifiedTable)) } }) @@ -1229,39 +1197,37 @@ func TestMerkleTreeNumericScaleInvariance(t *testing.T) { {5, "0.00", "0.0", "scale-diff-zero"}, } - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { tx, err := pool.Begin(ctx) require.NoError(t, err) defer func(tx pgx.Tx) { _ = tx.Rollback(ctx) }(tx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - for _, r := range rows { - amt := r.n1Amt - if pool == pgCluster.Node2Pool { - amt = r.n2Amt + env.withRepairModeTx(t, ctx, tx, func() { + for _, r := range rows { + amt := r.n1Amt + if pool == env.N2Pool { + amt = r.n2Amt + } + insertSQL := fmt.Sprintf("INSERT INTO %s (id, amount, label) VALUES ($1, %s, $2) ON CONFLICT (id) DO UPDATE SET amount = %s, label = $2", qualifiedTable, amt, amt) + _, err := tx.Exec(ctx, insertSQL, r.id, r.label) + require.NoError(t, err) } - insertSQL := fmt.Sprintf("INSERT INTO %s (id, amount, label) VALUES ($1, %s, $2) ON CONFLICT (id) DO UPDATE SET amount = %s, label = $2", qualifiedTable, amt, amt) - _, err := tx.Exec(ctx, insertSQL, r.id, r.label) - require.NoError(t, err) - } - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) + }) require.NoError(t, tx.Commit(ctx)) } // Verify the scales really are different in storage by checking ::text var n1Text, n2Text string - err := pgCluster.Node1Pool.QueryRow(ctx, fmt.Sprintf("SELECT amount::text FROM %s WHERE id = 1", qualifiedTable)).Scan(&n1Text) + err := env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT amount::text FROM %s WHERE id = 1", qualifiedTable)).Scan(&n1Text) require.NoError(t, err) - err = pgCluster.Node2Pool.QueryRow(ctx, fmt.Sprintf("SELECT amount::text FROM %s WHERE id = 1", qualifiedTable)).Scan(&n2Text) + err = env.N2Pool.QueryRow(ctx, fmt.Sprintf("SELECT amount::text FROM %s WHERE id = 1", qualifiedTable)).Scan(&n2Text) require.NoError(t, err) t.Logf("Node1 amount::text = %q, Node2 amount::text = %q", n1Text, n2Text) // They should differ in scale representation require.NotEqual(t, n1Text, n2Text, "precondition: numeric ::text should differ between nodes (different scales)") // Now run merkle tree init + build + diff - nodes := []string{serviceN1, serviceN2} - mtreeTask := newTestMerkleTreeTask(t, qualifiedTable, nodes) + nodes := []string{env.ServiceN1, env.ServiceN2} + mtreeTask := 
env.newMerkleTreeTask(t, qualifiedTable, nodes) mtreeTask.BlockSize = 1000 // must be >= 1000 for RunChecks; table has 5 rows so 1 block err = mtreeTask.RunChecks(false) @@ -1282,16 +1248,16 @@ func TestMerkleTreeNumericScaleInvariance(t *testing.T) { // The key assertion: there should be ZERO differences because trim_scale // normalizes numeric values before hashing. - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 + pairKey := env.ServiceN1 + "/" + env.ServiceN2 + if strings.Compare(env.ServiceN1, env.ServiceN2) > 0 { + pairKey = env.ServiceN2 + "/" + env.ServiceN1 } if nodeDiffs, ok := mtreeTask.DiffResult.NodeDiffs[pairKey]; ok { - n1Rows := len(nodeDiffs.Rows[serviceN1]) - n2Rows := len(nodeDiffs.Rows[serviceN2]) - require.Equal(t, 0, n1Rows, "Expected 0 diff rows on %s, got %d (numeric scale should not cause diffs)", serviceN1, n1Rows) - require.Equal(t, 0, n2Rows, "Expected 0 diff rows on %s, got %d (numeric scale should not cause diffs)", serviceN2, n2Rows) + n1Rows := len(nodeDiffs.Rows[env.ServiceN1]) + n2Rows := len(nodeDiffs.Rows[env.ServiceN2]) + require.Equal(t, 0, n1Rows, "Expected 0 diff rows on %s, got %d (numeric scale should not cause diffs)", env.ServiceN1, n1Rows) + require.Equal(t, 0, n2Rows, "Expected 0 diff rows on %s, got %d (numeric scale should not cause diffs)", env.ServiceN2, n2Rows) } // If pairKey is not in NodeDiffs at all, that means zero diffs — which is correct. @@ -1312,9 +1278,9 @@ func functionExists(t *testing.T, ctx context.Context, pool *pgxpool.Pool, funct var exists bool err := pool.QueryRow(ctx, ` SELECT EXISTS ( - SELECT 1 - FROM pg_proc p - JOIN pg_namespace n ON p.pronamespace = n.oid + SELECT 1 + FROM pg_proc p + JOIN pg_namespace n ON p.pronamespace = n.oid WHERE p.proname = $1 AND n.nspname = $2 )`, functionName, schemaName).Scan(&exists) require.NoError(t, err) @@ -1326,8 +1292,8 @@ func tableExists(t *testing.T, ctx context.Context, pool *pgxpool.Pool, tableNam var exists bool err := pool.QueryRow(ctx, ` SELECT EXISTS ( - SELECT 1 - FROM information_schema.tables + SELECT 1 + FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 )`, schemaName, tableName).Scan(&exists) require.NoError(t, err) diff --git a/tests/integration/native_pg_test.go b/tests/integration/native_pg_test.go index 9ea66f7..0cf5a74 100644 --- a/tests/integration/native_pg_test.go +++ b/tests/integration/native_pg_test.go @@ -221,6 +221,17 @@ func TestNativePG(t *testing.T) { "load CSV into customers") } + // Create and load customers_1M table (needed by merkle tree CDC and split tests). 
+ for _, pool := range env.pools() { + require.NoError(t, createTestTable(ctx, pool, testSchema, "customers_1M")) + } + csv1MPath, err := filepath.Abs(defaultCsvFilePath + "customers_1M.csv") + require.NoError(t, err) + for _, pool := range env.pools() { + require.NoError(t, loadDataFromCSV(ctx, pool, testSchema, "customers_1M", csv1MPath), + "load CSV into customers_1M") + } + // ── Native-specific tests ───────────────────────────────────────────── t.Run("CheckSpockInstalled_ReturnsFalse", func(t *testing.T) { @@ -319,4 +330,28 @@ func TestNativePG(t *testing.T) { t.Run("TableRepair_FixNulls_BidirectionalUpdate", func(t *testing.T) { testTableRepair_FixNulls_BidirectionalUpdate(t, env) }) + + // ── Merkle tree tests ──────────────────────────────────────────────── + + t.Run("MerkleTreeSimplePK", func(t *testing.T) { + runMerkleTreeTests(t, env, "customers") + }) + + t.Run("MerkleTreeCompositePK", func(t *testing.T) { + for _, pool := range env.pools() { + err := alterTableToCompositeKey(ctx, pool, env.Schema, "customers") + require.NoError(t, err) + } + t.Cleanup(func() { + for _, pool := range env.pools() { + err := revertTableToSimpleKey(ctx, pool, env.Schema, "customers") + require.NoError(t, err) + } + }) + runMerkleTreeTests(t, env, "customers") + }) + + t.Run("MerkleTreeNumericScaleInvariance", func(t *testing.T) { + testMerkleTreeNumericScaleInvariance(t, env) + }) } diff --git a/tests/integration/test_env_test.go b/tests/integration/test_env_test.go index 4d2676b..962d730 100644 --- a/tests/integration/test_env_test.go +++ b/tests/integration/test_env_test.go @@ -26,6 +26,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/pgedge/ace/internal/consistency/diff" + "github.com/pgedge/ace/internal/consistency/mtree" "github.com/pgedge/ace/internal/consistency/repair" "github.com/pgedge/ace/pkg/types" "github.com/stretchr/testify/assert" @@ -39,6 +40,11 @@ type testEnv struct { N2Pool *pgxpool.Pool N3Pool *pgxpool.Pool // nil for native PG (only 2 nodes) + N1Host string + N1Port string + N2Host string + N2Port string + ServiceN1 string ServiceN2 string ServiceN3 string // empty for native PG @@ -59,6 +65,10 @@ func newSpockEnv() *testEnv { N1Pool: pgCluster.Node1Pool, N2Pool: pgCluster.Node2Pool, N3Pool: pgCluster.Node3Pool, + N1Host: pgCluster.Node1Host, + N1Port: pgCluster.Node1Port, + N2Host: pgCluster.Node2Host, + N2Port: pgCluster.Node2Port, ServiceN1: serviceN1, ServiceN2: serviceN2, ServiceN3: serviceN3, @@ -77,6 +87,10 @@ func newNativeEnv(state *nativeClusterState) *testEnv { N1Pool: state.n1Pool, N2Pool: state.n2Pool, N3Pool: nil, + N1Host: state.n1Host, + N1Port: state.n1Port, + N2Host: state.n2Host, + N2Port: state.n2Port, ServiceN1: nativeServiceN1, ServiceN2: nativeServiceN2, ServiceN3: "", @@ -370,3 +384,15 @@ func (e *testEnv) pairKey() string { } return e.ServiceN1 + "/" + e.ServiceN2 } + +// newMerkleTreeTask creates a MerkleTreeTask configured for this environment. 
+func (e *testEnv) newMerkleTreeTask(t *testing.T, qualifiedTableName string, nodes []string) *mtree.MerkleTreeTask { + t.Helper() + task := mtree.NewMerkleTreeTask() + task.ClusterName = e.ClusterName + task.DBName = e.DBName + task.QualifiedTableName = qualifiedTableName + task.Nodes = strings.Join(nodes, ",") + task.BlockSize = 1000 + return task +} From fc2da360cd6251927bc4124630423224c6a24c8a Mon Sep 17 00:00:00 2001 From: Mason Sharp Date: Wed, 8 Apr 2026 13:43:37 -0700 Subject: [PATCH 05/14] fix: per-node spock detection and word-boundary matching for native PG queries - Use regex word boundaries (\m/\M) instead of LIKE substring matching to prevent node name collisions (e.g., n1 vs n10) - Check spock extension on every node instead of just the first - Make spock detection per-node and lazy, supporting mixed clusters and fixing early-access bug in recovery mode Co-Authored-By: Claude Opus 4.6 (1M context) --- db/queries/templates.go | 4 +- internal/consistency/diff/repset_diff.go | 20 +++++---- internal/consistency/diff/spock_diff.go | 9 ++--- internal/consistency/repair/table_repair.go | 45 +++++++++++---------- 4 files changed, 38 insertions(+), 40 deletions(-) diff --git a/db/queries/templates.go b/db/queries/templates.go index c0ce0ff..00f4d4e 100644 --- a/db/queries/templates.go +++ b/db/queries/templates.go @@ -1554,7 +1554,7 @@ var SQLTemplates = Templates{ FROM pg_catalog.pg_replication_origin_status ros JOIN pg_catalog.pg_replication_origin ro ON ro.roident = ros.local_id JOIN pg_catalog.pg_subscription s ON ro.roname LIKE 'pg_%' || s.oid::text - WHERE s.subname LIKE '%' || $1 || '%' + WHERE s.subname ~ ('\m' || $1 || '\M') AND ros.remote_lsn IS NOT NULL LIMIT 1 `)), @@ -1562,7 +1562,7 @@ var SQLTemplates = Templates{ SELECT rs.confirmed_flush_lsn::text FROM pg_catalog.pg_replication_slots rs JOIN pg_catalog.pg_subscription s ON rs.slot_name = s.subslotname - WHERE s.subname LIKE '%' || $1 || '%' + WHERE s.subname ~ ('\m' || $1 || '\M') AND rs.confirmed_flush_lsn IS NOT NULL ORDER BY rs.confirmed_flush_lsn DESC LIMIT 1 diff --git a/internal/consistency/diff/repset_diff.go b/internal/consistency/diff/repset_diff.go index e0b545f..85e82c7 100644 --- a/internal/consistency/diff/repset_diff.go +++ b/internal/consistency/diff/repset_diff.go @@ -157,17 +157,15 @@ func (c *RepsetDiffCmd) RunChecks(skipValidation bool) error { return fmt.Errorf("could not connect to node %s: %w", nodeName, err) } - // Check if spock extension is installed (only on first node) - if len(repsetNodeNames) == 0 { - spockInstalled, err := queries.CheckSpockInstalled(c.Ctx, pool) - if err != nil { - pool.Close() - return fmt.Errorf("failed to check for spock extension on node %s: %w", nodeName, err) - } - if !spockInstalled { - pool.Close() - return fmt.Errorf("repset-diff requires the spock extension, which is not installed on this cluster") - } + // Check if spock extension is installed + spockInstalled, err := queries.CheckSpockInstalled(c.Ctx, pool) + if err != nil { + pool.Close() + return fmt.Errorf("failed to check for spock extension on node %s: %w", nodeName, err) + } + if !spockInstalled { + pool.Close() + return fmt.Errorf("repset-diff requires the spock extension, which is not installed on node %s", nodeName) } repsetExists, err := queries.CheckRepSetExists(c.Ctx, pool, c.RepsetName) diff --git a/internal/consistency/diff/spock_diff.go b/internal/consistency/diff/spock_diff.go index d8dcea9..7056734 100644 --- a/internal/consistency/diff/spock_diff.go +++ 
b/internal/consistency/diff/spock_diff.go @@ -307,16 +307,15 @@ func (t *SpockDiffTask) ExecuteTask() (err error) { } t.Pools = pools - // Check if spock extension is installed - for _, pool := range t.Pools { + // Check that spock is installed on every selected node + for name, pool := range t.Pools { spockInstalled, err := queries.CheckSpockInstalled(t.Ctx, pool) if err != nil { - return fmt.Errorf("failed to check for spock extension: %w", err) + return fmt.Errorf("failed to check for spock extension on node %s: %w", name, err) } if !spockInstalled { - return fmt.Errorf("spock-diff requires the spock extension, which is not installed on this cluster") + return fmt.Errorf("spock-diff requires the spock extension, which is not installed on node %s", name) } - break } allNodeConfigs := make(map[string]SpockNodeConfig) diff --git a/internal/consistency/repair/table_repair.go b/internal/consistency/repair/table_repair.go index 296b820..76ab624 100644 --- a/internal/consistency/repair/table_repair.go +++ b/internal/consistency/repair/table_repair.go @@ -98,7 +98,7 @@ type TableRepairTask struct { autoSelectionFailedNode string autoSelectionDetails map[string]map[string]string - spockAvailable bool + spockPerNode map[string]bool Ctx context.Context } @@ -162,21 +162,26 @@ func (t *TableRepairTask) setRole(tx pgx.Tx, nodeName string) error { return nil } -// detectSpock checks whether the spock extension is installed on any available -// node and stores the result in t.spockAvailable. -func (t *TableRepairTask) detectSpock() { - for nodeName, pool := range t.Pools { +// isSpockAvailable returns whether spock is installed on the given node, +// detecting lazily on first check per node. +func (t *TableRepairTask) isSpockAvailable(nodeName string) bool { + if t.spockPerNode == nil { + t.spockPerNode = make(map[string]bool) + } + if _, checked := t.spockPerNode[nodeName]; !checked { + pool := t.Pools[nodeName] + if pool == nil { + return false + } spockInstalled, err := queries.CheckSpockInstalled(t.Ctx, pool) if err != nil { logger.Warn("failed to detect spock extension on %s: %v", nodeName, err) - continue + return false } - t.spockAvailable = spockInstalled - logger.Info("spock extension detected: %v (checked on %s)", spockInstalled, nodeName) - return + t.spockPerNode[nodeName] = spockInstalled + logger.Info("spock extension on %s: %v", nodeName, spockInstalled) } - logger.Warn("could not detect spock extension on any node; assuming not available") - t.spockAvailable = false + return t.spockPerNode[nodeName] } // setupTransactionMode enables spock repair mode (when available), sets the session replication @@ -184,7 +189,7 @@ func (t *TableRepairTask) detectSpock() { // Returns true if spock repair mode was activated, false otherwise. func (t *TableRepairTask) setupTransactionMode(tx pgx.Tx, nodeName string) (bool, error) { spockRepairModeActive := false - if t.spockAvailable { + if t.isSpockAvailable(nodeName) { if _, err := tx.Exec(t.Ctx, "SELECT spock.repair_mode(true)"); err != nil { return false, fmt.Errorf("enabling spock.repair_mode(true) on %s: %w", nodeName, err) } @@ -212,7 +217,7 @@ func (t *TableRepairTask) setupTransactionMode(tx pgx.Tx, nodeName string) (bool // disableSpockRepairMode disables spock repair mode on the given transaction. // When spock is not available, this is a no-op. 
func (t *TableRepairTask) disableSpockRepairMode(tx pgx.Tx, nodeName string) error { - if !t.spockAvailable { + if !t.isSpockAvailable(nodeName) { return nil } _, err := tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)") @@ -753,11 +758,6 @@ func (t *TableRepairTask) Run(skipValidation bool) (err error) { } }() - // Detect whether spock extension is available before any repair path - if !t.DryRun { - t.detectSpock() - } - if t.FixNulls { return t.runFixNulls(startTime) } @@ -2918,14 +2918,15 @@ func calculateRepairSetsWithSourceOfTruth(task *TableRepairTask) (map[string]map } func (t *TableRepairTask) fetchLSNsForNode(pool *pgxpool.Pool, failedNode, survivor string) (originLSN *uint64, slotLSN *uint64, err error) { + spock := t.isSpockAvailable(survivor) var originStr *string - if t.spockAvailable { + if spock { originStr, err = queries.GetSpockOriginLSNForNode(t.Ctx, pool, failedNode) } else { originStr, err = queries.GetNativeOriginLSNForNode(t.Ctx, pool, failedNode) } if err != nil { - if !t.spockAvailable { + if !spock { // Native PG queries may fail if subscription naming doesn't match; treat as no data logger.Warn("failed to fetch native origin lsn on %s: %v", survivor, err) return nil, nil, nil @@ -2940,13 +2941,13 @@ func (t *TableRepairTask) fetchLSNsForNode(pool *pgxpool.Pool, failedNode, survi } var slotStr *string - if t.spockAvailable { + if spock { slotStr, err = queries.GetSpockSlotLSNForNode(t.Ctx, pool, failedNode) } else { slotStr, err = queries.GetNativeSlotLSNForNode(t.Ctx, pool, failedNode) } if err != nil { - if !t.spockAvailable { + if !spock { // Native PG queries may fail if subscription naming doesn't match; treat as no data logger.Warn("failed to fetch native slot lsn on %s: %v", survivor, err) return originLSN, nil, nil From 387e7a950cbe8f6ca184fa13317e6ad5d73a3bc8 Mon Sep 17 00:00:00 2001 From: Mason Sharp Date: Wed, 8 Apr 2026 14:04:22 -0700 Subject: [PATCH 06/14] fix: pass TaskStore to child table-diff tasks in repset-diff Without this, each child table-diff opens its own SQLite connection instead of sharing the parent's handle. Matches CloneForSchedule(). Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/consistency/diff/repset_diff.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/consistency/diff/repset_diff.go b/internal/consistency/diff/repset_diff.go index 85e82c7..eb97dc6 100644 --- a/internal/consistency/diff/repset_diff.go +++ b/internal/consistency/diff/repset_diff.go @@ -396,6 +396,7 @@ func RepsetDiff(task *RepsetDiffCmd) (err error) { tdTask.QuietMode = task.Quiet tdTask.Ctx = task.Ctx tdTask.SkipDBUpdate = task.SkipDBUpdate + tdTask.TaskStore = task.TaskStore tdTask.TaskStorePath = task.TaskStorePath if err := tdTask.Validate(); err != nil { From 0fd62a9d1329103e80c142b8b5fd11f0e3e93949 Mon Sep 17 00:00:00 2001 From: Mason Sharp Date: Wed, 8 Apr 2026 14:53:16 -0700 Subject: [PATCH 07/14] refactor: rename SpockNodeNames to NodeOriginNames for dual-mode clarity The field and methods are used on both Spock and native PG paths, so the "Spock" prefix was misleading. Also fixes --against-origin error messages to list available origin IDs/names instead of referencing "spock node id". 
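For illustration only (hypothetical origin ids and names; the list order follows Go map
iteration and is therefore unspecified), a cluster with origins 1 -> n1 and 2 -> n2 would
now report a typo'd node name as:

    unable to resolve --against-origin "n9"; available origins: 1 (n1), 2 (n2)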
Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/consistency/diff/table_diff.go | 48 ++++++++++++++---------- internal/consistency/diff/table_rerun.go | 4 +- internal/consistency/mtree/merkle.go | 18 ++++----- 3 files changed, 40 insertions(+), 30 deletions(-) diff --git a/internal/consistency/diff/table_diff.go b/internal/consistency/diff/table_diff.go index 6326def..c1479f9 100644 --- a/internal/consistency/diff/table_diff.go +++ b/internal/consistency/diff/table_diff.go @@ -102,7 +102,7 @@ type TableDiffTask struct { blockHashSQLCache map[hashBoundsKey]string blockHashSQLMu sync.Mutex - SpockNodeNames map[string]string + NodeOriginNames map[string]string CompareUnitSize int MaxDiffRows int64 @@ -184,8 +184,8 @@ func (t *TableDiffTask) incrementDiffRowsLocked(delta int) bool { return false } -func (t *TableDiffTask) loadSpockNodeNames() error { - if t.SpockNodeNames != nil { +func (t *TableDiffTask) loadNodeOriginNames() error { + if t.NodeOriginNames != nil { return nil } @@ -196,17 +196,17 @@ func (t *TableDiffTask) loadSpockNodeNames() error { } if firstPool == nil { - t.SpockNodeNames = make(map[string]string) - return fmt.Errorf("no connection pool available to load spock node names") + t.NodeOriginNames = make(map[string]string) + return fmt.Errorf("no connection pool available to load node origin names") } names, err := queries.GetNodeOriginNames(t.Ctx, firstPool) if err != nil { - t.SpockNodeNames = make(map[string]string) + t.NodeOriginNames = make(map[string]string) return err } - t.SpockNodeNames = names + t.NodeOriginNames = names return nil } @@ -214,26 +214,36 @@ func (t *TableDiffTask) resolveAgainstOrigin() error { if strings.TrimSpace(t.AgainstOrigin) == "" { return nil } - if len(t.SpockNodeNames) == 0 { - return fmt.Errorf("unable to resolve --against-origin: spock node names not available") + if len(t.NodeOriginNames) == 0 { + return fmt.Errorf("unable to resolve --against-origin: no node origin names available") } orig := strings.TrimSpace(t.AgainstOrigin) - // direct match on id - if _, ok := t.SpockNodeNames[orig]; ok { + // direct match on origin id + if _, ok := t.NodeOriginNames[orig]; ok { t.resolvedAgainstOrigin = orig return nil } - // match on name - for id, name := range t.SpockNodeNames { + // match on origin name + for id, name := range t.NodeOriginNames { if name == orig { t.resolvedAgainstOrigin = id return nil } } - return fmt.Errorf("unable to resolve against-origin %q to a spock node id", t.AgainstOrigin) + // build a list of available origins for the error message + available := make([]string, 0, len(t.NodeOriginNames)) + for id, name := range t.NodeOriginNames { + if id != name { + available = append(available, fmt.Sprintf("%s (%s)", id, name)) + } else { + available = append(available, id) + } + } + + return fmt.Errorf("unable to resolve --against-origin %q; available origins: %s", t.AgainstOrigin, strings.Join(available, ", ")) } func (t *TableDiffTask) buildEffectiveFilter() (string, error) { @@ -270,7 +280,7 @@ func (t *TableDiffTask) buildEffectiveFilter() (string, error) { } func (t *TableDiffTask) withSpockMetadata(row map[string]any) map[string]any { - row["node_origin"] = utils.TranslateNodeOrigin(row["node_origin"], t.SpockNodeNames) + row["node_origin"] = utils.TranslateNodeOrigin(row["node_origin"], t.NodeOriginNames) return utils.AddSpockMetadata(row) } @@ -1289,8 +1299,8 @@ func (t *TableDiffTask) ExecuteTask() (err error) { } } - if err := t.loadSpockNodeNames(); err != nil { - logger.Warn("table-diff: unable to load spock 
node names; using raw node_origin values: %v", err) + if err := t.loadNodeOriginNames(); err != nil { + logger.Warn("table-diff: unable to load node origin names; using raw node_origin values: %v", err) } if err := t.resolveAgainstOrigin(); err != nil { @@ -1346,8 +1356,8 @@ func (t *TableDiffTask) ExecuteTask() (err error) { DiffRowsCount: make(map[string]int), AgainstOrigin: t.AgainstOrigin, AgainstOriginResolved: func() string { - if t.resolvedAgainstOrigin != "" && t.SpockNodeNames != nil { - if name, ok := t.SpockNodeNames[t.resolvedAgainstOrigin]; ok { + if t.resolvedAgainstOrigin != "" && t.NodeOriginNames != nil { + if name, ok := t.NodeOriginNames[t.resolvedAgainstOrigin]; ok { return name } } diff --git a/internal/consistency/diff/table_rerun.go b/internal/consistency/diff/table_rerun.go index c79c30b..ca4b1ac 100644 --- a/internal/consistency/diff/table_rerun.go +++ b/internal/consistency/diff/table_rerun.go @@ -79,8 +79,8 @@ func (t *TableDiffTask) ExecuteRerunTask() error { } }() - if err := t.loadSpockNodeNames(); err != nil { - logger.Warn("table-diff rerun: unable to load spock node names; using raw node_origin values: %v", err) + if err := t.loadNodeOriginNames(); err != nil { + logger.Warn("table-diff rerun: unable to load node origin names; using raw node_origin values: %v", err) } // Collect all unique primary keys from the original diff report diff --git a/internal/consistency/mtree/merkle.go b/internal/consistency/mtree/merkle.go index 1325eed..0736ce1 100644 --- a/internal/consistency/mtree/merkle.go +++ b/internal/consistency/mtree/merkle.go @@ -99,7 +99,7 @@ type MerkleTreeTask struct { diffMutex sync.Mutex diffRowKeySets map[string]map[string]map[string]struct{} StartTime time.Time - SpockNodeNames map[string]string + NodeOriginNames map[string]string Ctx context.Context } @@ -505,8 +505,8 @@ func (m *MerkleTreeTask) processWorkItem(work CompareRangesWorkItem, pool1, pool return nil } -func (m *MerkleTreeTask) loadSpockNodeNames() error { - if m.SpockNodeNames != nil { +func (m *MerkleTreeTask) loadNodeOriginNames() error { + if m.NodeOriginNames != nil { return nil } @@ -523,15 +523,15 @@ func (m *MerkleTreeTask) loadSpockNodeNames() error { lastErr = err continue } - m.SpockNodeNames = names + m.NodeOriginNames = names return nil } - m.SpockNodeNames = make(map[string]string) + m.NodeOriginNames = make(map[string]string) if lastErr != nil { return lastErr } - return fmt.Errorf("no nodes available to load spock node names") + return fmt.Errorf("no nodes available to load node origin names") } func (m *MerkleTreeTask) appendDiffs(nodePairKey string, work CompareRangesWorkItem, pr1, pr2 []types.OrderedMap) error { @@ -603,7 +603,7 @@ func (m *MerkleTreeTask) addRowToDiff(nodePairKey, nodeName string, row types.Or } rowMap := utils.OrderedMapToMap(row) - rowMap["node_origin"] = utils.TranslateNodeOrigin(rowMap["node_origin"], m.SpockNodeNames) + rowMap["node_origin"] = utils.TranslateNodeOrigin(rowMap["node_origin"], m.NodeOriginNames) rowWithMeta := utils.AddSpockMetadata(rowMap) orderedRow := utils.MapToOrderedMap(rowWithMeta, m.Cols) @@ -1908,8 +1908,8 @@ func (m *MerkleTreeTask) DiffMtree() (err error) { if err = m.UpdateMtree(true); err != nil { return fmt.Errorf("failed to update merkle tree before diff: %w", err) } - if err := m.loadSpockNodeNames(); err != nil { - logger.Warn("mtree diff: unable to load spock node names; using raw node_origin values: %v", err) + if err := m.loadNodeOriginNames(); err != nil { + logger.Warn("mtree diff: unable to load node 
origin names; using raw node_origin values: %v", err) } nodePairs := getNodePairs(m.ClusterNodes) mtreeTableIdentifier := pgx.Identifier{m.aceSchema(), fmt.Sprintf("ace_mtree_%s_%s", m.Schema, m.Table)} From 8e3dc78fed393a2dc1205690272a3c101df66d6c Mon Sep 17 00:00:00 2001 From: Mason Sharp Date: Wed, 8 Apr 2026 14:57:42 -0700 Subject: [PATCH 08/14] fix: store pool before fetchLSNsForNode in autoSelectSourceOfTruth isSpockAvailable reads from t.Pools[nodeName], but the pool was stored after fetchLSNsForNode returned. This caused Spock clusters to silently use native PG LSN queries during recovery-mode auto-selection. Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/consistency/repair/table_repair.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/internal/consistency/repair/table_repair.go b/internal/consistency/repair/table_repair.go index 76ab624..b2730f9 100644 --- a/internal/consistency/repair/table_repair.go +++ b/internal/consistency/repair/table_repair.go @@ -2995,17 +2995,19 @@ func (t *TableRepairTask) autoSelectSourceOfTruth(failedNode string, involved ma logger.Warn("recovery-mode: failed to connect to %s for LSN probe: %v", nodeName, err) continue } - originLSN, slotLSN, err := t.fetchLSNsForNode(pool, failedNode, nodeName) - if err != nil { - logger.Warn("recovery-mode: failed to fetch LSNs on %s: %v", nodeName, err) - pool.Close() - continue - } + // Store pool before fetchLSNsForNode so isSpockAvailable can find it if t.Pools[nodeName] == nil { t.Pools[nodeName] = pool } else { pool.Close() + pool = t.Pools[nodeName] + } + + originLSN, slotLSN, err := t.fetchLSNsForNode(pool, failedNode, nodeName) + if err != nil { + logger.Warn("recovery-mode: failed to fetch LSNs on %s: %v", nodeName, err) + continue } lsnDetails[nodeName] = map[string]string{ From e91c3960b4bfce2f2bfa7cfdd43baf45f55a6435 Mon Sep 17 00:00:00 2001 From: Mason Sharp Date: Wed, 8 Apr 2026 15:03:14 -0700 Subject: [PATCH 09/14] fix: use config.Get() in aceTemplateFuncs for thread-safe SIGHUP reload config.Cfg is only safe during single-threaded startup. The aceSchema template function is evaluated at render time from concurrent goroutines, so it must use config.Get() which holds a read lock. Co-Authored-By: Claude Opus 4.6 (1M context) --- db/queries/templates.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/db/queries/templates.go b/db/queries/templates.go index 00f4d4e..81d7f37 100644 --- a/db/queries/templates.go +++ b/db/queries/templates.go @@ -20,7 +20,7 @@ import ( // aceTemplateFuncs provides the {{aceSchema}} function to SQL templates. // The function is evaluated at render time (after config is loaded), not at parse time. var aceTemplateFuncs = template.FuncMap{ - "aceSchema": func() string { return config.Cfg.MTree.Schema }, + "aceSchema": func() string { return config.Get().MTree.Schema }, } type Templates struct { From d04ee5e50a51860e1132efe098973b576cfdb963 Mon Sep 17 00:00:00 2001 From: Mason Sharp Date: Wed, 8 Apr 2026 15:55:50 -0700 Subject: [PATCH 10/14] feat: port 6 data-type tests to run on both spock and native PG Refactored testTableDiff_VariousDataTypes, testTableDiff_UUIDColumn, testTableDiff_ByteaColumnSizeCheck, testTableRepair_VariousDataTypes, testTableRepair_TimestampAndTimeTypes, and testTableRepair_LargeBigintPK to accept *testEnv. Repset calls are now conditional on HasSpock, and repair_mode toggling goes through env.withRepairModeTx. All 6 now run in TestNativePG alongside the existing shared tests. Also adds TestNativePG to CI workflow.
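For reviewers porting further tests: withRepairModeTx is not defined in this patch, but
its call sites imply roughly the following shape. This is a sketch only; the real helper
in tests/integration/test_env_test.go is authoritative and may differ in detail.

    // Sketch inferred from call sites like env.withRepairModeTx(t, ctx, tx, func() { ... }).
    // Assumed imports: context, testing, github.com/jackc/pgx/v5,
    // github.com/stretchr/testify/require.
    func (e *testEnv) withRepairModeTx(t *testing.T, ctx context.Context, tx pgx.Tx, fn func()) {
        t.Helper()
        if e.HasSpock {
            // On spock clusters, suppress replication of the repair writes.
            _, err := tx.Exec(ctx, "SELECT spock.repair_mode(true)")
            require.NoError(t, err)
        }
        fn() // the caller's writes run here, inside the same transaction
        if e.HasSpock {
            // Turn repair mode back off before the caller commits.
            _, err := tx.Exec(ctx, "SELECT spock.repair_mode(false)")
            require.NoError(t, err)
        }
    }

On native PG the callback runs with no extra wrapping, which is why the ported tests can
share one code path for both cluster flavors.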
Co-Authored-By: Claude Opus 4.6 (1M context) --- .github/workflows/test.yml | 3 + tests/integration/native_pg_test.go | 27 +++ tests/integration/table_diff_test.go | 310 +++++++++++++------------ tests/integration/table_repair_test.go | 223 +++++++++--------- 4 files changed, 311 insertions(+), 252 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d108d82..dc08b33 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -61,5 +61,8 @@ jobs: - name: Run Merkle Tree --until filter tests run: go test -count=1 -v ./tests/integration -run 'TestMerkleTreeUntilFilter' + - name: Run native PG (no spock) tests + run: go test -count=1 -v ./tests/integration -run 'TestNativePG' + - name: Run timestamp comparison tests run: go test -count=1 -v ./tests/integration -run 'TestCompareTimestampsExact|TestPostgreSQLMicrosecondPrecision|TestOldVsNewComparison' diff --git a/tests/integration/native_pg_test.go b/tests/integration/native_pg_test.go index 0cf5a74..a0aacfd 100644 --- a/tests/integration/native_pg_test.go +++ b/tests/integration/native_pg_test.go @@ -278,6 +278,15 @@ func TestNativePG(t *testing.T) { t.Run("MixedCaseIdentifiers", func(t *testing.T) { testTableDiff_MixedCaseIdentifiers(t, env, false) }) + t.Run("VariousDataTypes", func(t *testing.T) { + testTableDiff_VariousDataTypes(t, env, false) + }) + t.Run("UUIDColumn", func(t *testing.T) { + testTableDiff_UUIDColumn(t, env, false) + }) + t.Run("ByteaColumnSizeCheck", func(t *testing.T) { + testTableDiff_ByteaColumnSizeCheck(t, env, false) + }) }) // ── Shared table-diff tests (composite PK) ─────────────────────────── @@ -299,6 +308,15 @@ func TestNativePG(t *testing.T) { t.Run("MixedCaseIdentifiers", func(t *testing.T) { testTableDiff_MixedCaseIdentifiers(t, env, true) }) + t.Run("VariousDataTypes", func(t *testing.T) { + testTableDiff_VariousDataTypes(t, env, true) + }) + t.Run("UUIDColumn", func(t *testing.T) { + testTableDiff_UUIDColumn(t, env, true) + }) + t.Run("ByteaColumnSizeCheck", func(t *testing.T) { + testTableDiff_ByteaColumnSizeCheck(t, env, true) + }) }) // ── Shared table-repair tests ───────────────────────────────────────── @@ -330,6 +348,15 @@ func TestNativePG(t *testing.T) { t.Run("TableRepair_FixNulls_BidirectionalUpdate", func(t *testing.T) { testTableRepair_FixNulls_BidirectionalUpdate(t, env) }) + t.Run("TableRepair_VariousDataTypes", func(t *testing.T) { + testTableRepair_VariousDataTypes(t, env) + }) + t.Run("TableRepair_TimestampAndTimeTypes", func(t *testing.T) { + testTableRepair_TimestampAndTimeTypes(t, env) + }) + t.Run("TableRepair_LargeBigintPK", func(t *testing.T) { + testTableRepair_LargeBigintPK(t, env) + }) // ── Merkle tree tests ──────────────────────────────────────────────── diff --git a/tests/integration/table_diff_test.go b/tests/integration/table_diff_test.go index 023b3fc..d0bdcb8 100644 --- a/tests/integration/table_diff_test.go +++ b/tests/integration/table_diff_test.go @@ -111,13 +111,13 @@ func TestTableDiffSimplePK(t *testing.T) { testTableDiff_MixedCaseIdentifiers(t, env, false) }) t.Run("VariousDataTypes", func(t *testing.T) { - testTableDiff_VariousDataTypes(t, false) + testTableDiff_VariousDataTypes(t, newSpockEnv(), false) }) t.Run("UUIDColumn", func(t *testing.T) { - testTableDiff_UUIDColumn(t, false) + testTableDiff_UUIDColumn(t, newSpockEnv(), false) }) t.Run("ByteaColumnSizeCheck", func(t *testing.T) { - testTableDiff_ByteaColumnSizeCheck(t, false) + testTableDiff_ByteaColumnSizeCheck(t, newSpockEnv(), false) }) 
t.Run("WithSpockMetadata", func(t *testing.T) { testTableDiff_WithSpockMetadata(t, false) @@ -145,13 +145,13 @@ func TestTableDiffCompositePK(t *testing.T) { testTableDiff_MixedCaseIdentifiers(t, env, true) }) t.Run("VariousDataTypes", func(t *testing.T) { - testTableDiff_VariousDataTypes(t, true) + testTableDiff_VariousDataTypes(t, newSpockEnv(), true) }) t.Run("UUIDColumn", func(t *testing.T) { - testTableDiff_UUIDColumn(t, true) + testTableDiff_UUIDColumn(t, newSpockEnv(), true) }) t.Run("ByteaColumnSizeCheck", func(t *testing.T) { - testTableDiff_ByteaColumnSizeCheck(t, true) + testTableDiff_ByteaColumnSizeCheck(t, newSpockEnv(), true) }) t.Run("WithSpockMetadata", func(t *testing.T) { testTableDiff_WithSpockMetadata(t, true) @@ -674,10 +674,10 @@ func testTableDiff_MixedCaseIdentifiers(t *testing.T, env *testEnv, compositeKey log.Println("TestTableDiff_MixedCaseIdentifiers completed.") } -func testTableDiff_VariousDataTypes(t *testing.T, compositeKey bool) { +func testTableDiff_VariousDataTypes(t *testing.T, env *testEnv, compositeKey bool) { ctx := context.Background() tableName := "data_type_test_table" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) compositeKeyPart := "" if compositeKey { @@ -706,33 +706,38 @@ CREATE TABLE IF NOT EXISTS %s.%s ( col_bytea BYTEA, col_int_array INT[], PRIMARY KEY(id%s) -);`, testSchema, testSchema, tableName, compositeKeyPart) +);`, env.Schema, env.Schema, tableName, compositeKeyPart) - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) + nodeNames := []string{env.ServiceN1, env.ServiceN2} + for i, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + nodeName := nodeNames[i] _, err := pool.Exec(ctx, createDataTypeTableSQL) if err != nil { t.Fatalf("Failed to create data_type_test_table on node %s: %v", nodeName, err) } - _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", testSchema, tableName)) + _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", env.Schema, tableName)) if err != nil { t.Fatalf("Failed to truncate data_type_test_table on node %s: %v", nodeName, err) } - addToRepSetSQL := fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName) - _, err = pool.Exec(ctx, addToRepSetSQL) - if err != nil { - t.Fatalf("Failed to add table to replication set on n1: %v", err) + if env.HasSpock { + addToRepSetSQL := fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName) + _, err = pool.Exec(ctx, addToRepSetSQL) + if err != nil { + t.Fatalf("Failed to add table to replication set on %s: %v", nodeName, err) + } } } log.Printf("Table %s created on both nodes", qualifiedTableName) t.Cleanup(func() { - removeFromRepSetSQL := fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName) - _, err := pgCluster.Node1Pool.Exec(ctx, removeFromRepSetSQL) - if err != nil { - t.Logf("cleanup: failed to remove table from replication set: %v", err) + if env.HasSpock { + removeFromRepSetSQL := fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName) + _, err := env.N1Pool.Exec(ctx, removeFromRepSetSQL) + if err != nil { + t.Logf("cleanup: failed to remove table from replication set: %v", err) + } } - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := 
pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) if err != nil { t.Logf("Failed to drop test table %s: %v", qualifiedTableName, err) @@ -787,55 +792,49 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - if err != nil { - t.Fatalf("Failed to enable spock repair mode: %v", err) - } - _, err = tx.Exec( - ctx, - fmt.Sprintf(insertSQLTemplate, testSchema, tableName), - data["id"], - data["col_smallint"], - data["col_integer"], - data["col_bigint"], - data["col_numeric"], - data["col_real"], - data["col_double"], - data["col_varchar"], - data["col_text"], - data["col_char"], - data["col_boolean"], - data["col_date"], - data["col_timestamp"], - data["col_timestamptz"], - data["col_jsonb"], - data["col_json"], - data["col_bytea"], - data["col_int_array"], - ) - if err != nil { - t.Fatalf("Failed to insert row id %v: %v", data["id"], err) - } - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - if err != nil { - t.Fatalf("Failed to disable spock repair mode: %v", err) - } + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec( + ctx, + fmt.Sprintf(insertSQLTemplate, env.Schema, tableName), + data["id"], + data["col_smallint"], + data["col_integer"], + data["col_bigint"], + data["col_numeric"], + data["col_real"], + data["col_double"], + data["col_varchar"], + data["col_text"], + data["col_char"], + data["col_boolean"], + data["col_date"], + data["col_timestamp"], + data["col_timestamptz"], + data["col_jsonb"], + data["col_json"], + data["col_bytea"], + data["col_int_array"], + ) + if err != nil { + t.Fatalf("Failed to insert row id %v: %v", data["id"], err) + } + }) if err = tx.Commit(ctx); err != nil { t.Fatalf("Failed to commit transaction: %v", err) } } - insertRow(pgCluster.Node1Pool, row1) - insertRow(pgCluster.Node2Pool, row1) - insertRow(pgCluster.Node1Pool, row2Node1Only) - insertRow(pgCluster.Node2Pool, row3Node2Only) - insertRow(pgCluster.Node1Pool, row4Base) - insertRow(pgCluster.Node2Pool, row4Node2Modified) + insertRow(env.N1Pool, row1) + insertRow(env.N2Pool, row1) + insertRow(env.N1Pool, row2Node1Only) + insertRow(env.N2Pool, row3Node2Only) + insertRow(env.N1Pool, row4Base) + insertRow(env.N2Pool, row4Node2Modified) log.Printf("Data loaded into %s with variations", qualifiedTableName) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) err := tdTask.RunChecks(false) if err != nil { @@ -846,10 +845,7 @@ CREATE TABLE IF NOT EXISTS %s.%s ( t.Fatalf("ExecuteTask failed for data type table: %v", err) } - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } + pairKey := env.pairKey() nodeDiffs, ok := tdTask.DiffResult.NodeDiffs[pairKey] if !ok { @@ -860,20 +856,20 @@ CREATE TABLE IF NOT EXISTS %s.%s ( ) } - if len(nodeDiffs.Rows[serviceN1]) != 2 { + if len(nodeDiffs.Rows[env.ServiceN1]) != 2 { t.Errorf( "Expected 2 rows in diffs for %s, got %d. Rows: %+v", - serviceN1, - len(nodeDiffs.Rows[serviceN1]), - nodeDiffs.Rows[serviceN1], + env.ServiceN1, + len(nodeDiffs.Rows[env.ServiceN1]), + nodeDiffs.Rows[env.ServiceN1], ) } - if len(nodeDiffs.Rows[serviceN2]) != 2 { + if len(nodeDiffs.Rows[env.ServiceN2]) != 2 { t.Errorf( "Expected 2 rows in diffs for %s, got %d. 
Rows: %+v", - serviceN2, - len(nodeDiffs.Rows[serviceN2]), - nodeDiffs.Rows[serviceN2], + env.ServiceN2, + len(nodeDiffs.Rows[env.ServiceN2]), + nodeDiffs.Rows[env.ServiceN2], ) } @@ -889,37 +885,37 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } foundRow2N1Only := false - for _, r := range nodeDiffs.Rows[serviceN1] { + for _, r := range nodeDiffs.Rows[env.ServiceN1] { if id, ok := r.Get("id"); ok && id == int32(2) { foundRow2N1Only = true break } } if !foundRow2N1Only { - t.Errorf("Row with id=2 (N1 only) not found in %s diffs", serviceN1) + t.Errorf("Row with id=2 (N1 only) not found in %s diffs", env.ServiceN1) } foundRow3N2Only := false - for _, r := range nodeDiffs.Rows[serviceN2] { + for _, r := range nodeDiffs.Rows[env.ServiceN2] { if id, ok := r.Get("id"); ok && id == int32(3) { foundRow3N2Only = true break } } if !foundRow3N2Only { - t.Errorf("Row with id=3 (N2 only) not found in %s diffs", serviceN2) + t.Errorf("Row with id=3 (N2 only) not found in %s diffs", env.ServiceN2) } foundRow4OriginalN1 := false foundRow4ModifiedN2 := false - for _, r := range nodeDiffs.Rows[serviceN1] { + for _, r := range nodeDiffs.Rows[env.ServiceN1] { id, _ := r.Get("id") varchar, _ := r.Get("col_varchar") if id == int32(4) && varchar == "original_varchar_row4" { foundRow4OriginalN1 = true } } - for _, r := range nodeDiffs.Rows[serviceN2] { + for _, r := range nodeDiffs.Rows[env.ServiceN2] { id, _ := r.Get("id") varchar, _ := r.Get("col_varchar") if id == int32(4) && varchar == "MODIFIED_varchar_row4" { @@ -927,19 +923,19 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } } if !foundRow4OriginalN1 { - t.Errorf("Original version of row id=4 not found in %s diffs", serviceN1) + t.Errorf("Original version of row id=4 not found in %s diffs", env.ServiceN1) } if !foundRow4ModifiedN2 { - t.Errorf("Modified version of row id=4 not found in %s diffs", serviceN2) + t.Errorf("Modified version of row id=4 not found in %s diffs", env.ServiceN2) } log.Println("TestTableDiff_VariousDataTypes completed.") } -func testTableDiff_UUIDColumn(t *testing.T, compositeKey bool) { +func testTableDiff_UUIDColumn(t *testing.T, env *testEnv, compositeKey bool) { ctx := context.Background() tableName := "uuid_test_table" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) compositeKeyPart := "" if compositeKey { @@ -953,33 +949,38 @@ CREATE TABLE IF NOT EXISTS %s.%s ( name TEXT, col_uuid UUID, PRIMARY KEY(id%s) -);`, testSchema, testSchema, tableName, compositeKeyPart) +);`, env.Schema, env.Schema, tableName, compositeKeyPart) - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) + nodeNames := []string{env.ServiceN1, env.ServiceN2} + for i, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + nodeName := nodeNames[i] _, err := pool.Exec(ctx, createTableSQL) if err != nil { t.Fatalf("Failed to create uuid_test_table on node %s: %v", nodeName, err) } - _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", testSchema, tableName)) + _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s.%s CASCADE", env.Schema, tableName)) if err != nil { t.Fatalf("Failed to truncate uuid_test_table on node %s: %v", nodeName, err) } - addToRepSetSQL := fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName) - _, err = pool.Exec(ctx, addToRepSetSQL) - if err != nil { - t.Fatalf("Failed to add table to replication set on %s: %v", nodeName, err) + if 
env.HasSpock { + addToRepSetSQL := fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName) + _, err = pool.Exec(ctx, addToRepSetSQL) + if err != nil { + t.Fatalf("Failed to add table to replication set on %s: %v", nodeName, err) + } } } log.Printf("Table %s created on both nodes", qualifiedTableName) t.Cleanup(func() { - removeFromRepSetSQL := fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName) - _, err := pgCluster.Node1Pool.Exec(ctx, removeFromRepSetSQL) - if err != nil { - t.Logf("cleanup: failed to remove table from replication set: %v", err) + if env.HasSpock { + removeFromRepSetSQL := fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName) + _, err := env.N1Pool.Exec(ctx, removeFromRepSetSQL) + if err != nil { + t.Logf("cleanup: failed to remove table from replication set: %v", err) + } } - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) if err != nil { t.Logf("Failed to drop test table %s: %v", qualifiedTableName, err) @@ -997,7 +998,7 @@ CREATE TABLE IF NOT EXISTS %s.%s ( row2UUIDNode1 := "550e8400-e29b-41d4-a716-446655440002" row2UUIDNode2 := "660e8400-e29b-41d4-a716-446655440002" - insertSQL := fmt.Sprintf("INSERT INTO %s.%s (id, name, col_uuid) VALUES ($1, $2, $3)", testSchema, tableName) + insertSQL := fmt.Sprintf("INSERT INTO %s.%s (id, name, col_uuid) VALUES ($1, $2, $3)", env.Schema, tableName) insertRow := func(pool *pgxpool.Pool, id int, name string, uuid string) { tx, err := pool.Begin(ctx) @@ -1006,35 +1007,29 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } defer tx.Rollback(ctx) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - if err != nil { - t.Fatalf("Failed to enable spock repair mode: %v", err) - } - _, err = tx.Exec(ctx, insertSQL, id, name, uuid) - if err != nil { - t.Fatalf("Failed to insert row id %d: %v", id, err) - } - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - if err != nil { - t.Fatalf("Failed to disable spock repair mode: %v", err) - } + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec(ctx, insertSQL, id, name, uuid) + if err != nil { + t.Fatalf("Failed to insert row id %d: %v", id, err) + } + }) if err = tx.Commit(ctx); err != nil { t.Fatalf("Failed to commit transaction: %v", err) } } // Insert same row on both nodes - insertRow(pgCluster.Node1Pool, 1, "same", row1UUID) - insertRow(pgCluster.Node2Pool, 1, "same", row1UUID) + insertRow(env.N1Pool, 1, "same", row1UUID) + insertRow(env.N2Pool, 1, "same", row1UUID) // Insert row with different UUID on each node - insertRow(pgCluster.Node1Pool, 2, "different", row2UUIDNode1) - insertRow(pgCluster.Node2Pool, 2, "different", row2UUIDNode2) + insertRow(env.N1Pool, 2, "different", row2UUIDNode1) + insertRow(env.N2Pool, 2, "different", row2UUIDNode2) log.Printf("Data loaded into %s with UUID variations", qualifiedTableName) - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) err := tdTask.RunChecks(false) if err != nil { @@ -1045,10 +1040,7 @@ CREATE TABLE IF NOT EXISTS %s.%s ( t.Fatalf("ExecuteTask failed for UUID table: %v", err) } - pairKey := serviceN1 + "/" + serviceN2 - if strings.Compare(serviceN1, 
serviceN2) > 0 { - pairKey = serviceN2 + "/" + serviceN1 - } + pairKey := env.pairKey() nodeDiffs, ok := tdTask.DiffResult.NodeDiffs[pairKey] if !ok { @@ -1056,17 +1048,17 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } // Should have 1 diff (row 2 with different UUIDs) - if len(nodeDiffs.Rows[serviceN1]) != 1 { + if len(nodeDiffs.Rows[env.ServiceN1]) != 1 { t.Errorf("Expected 1 row in diffs for %s, got %d. Rows: %+v", - serviceN1, len(nodeDiffs.Rows[serviceN1]), nodeDiffs.Rows[serviceN1]) + env.ServiceN1, len(nodeDiffs.Rows[env.ServiceN1]), nodeDiffs.Rows[env.ServiceN1]) } - if len(nodeDiffs.Rows[serviceN2]) != 1 { + if len(nodeDiffs.Rows[env.ServiceN2]) != 1 { t.Errorf("Expected 1 row in diffs for %s, got %d. Rows: %+v", - serviceN2, len(nodeDiffs.Rows[serviceN2]), nodeDiffs.Rows[serviceN2]) + env.ServiceN2, len(nodeDiffs.Rows[env.ServiceN2]), nodeDiffs.Rows[env.ServiceN2]) } // Verify the UUID is formatted correctly (as a string in standard UUID format) - for _, row := range nodeDiffs.Rows[serviceN1] { + for _, row := range nodeDiffs.Rows[env.ServiceN1] { uuid, ok := row.Get("col_uuid") if !ok { t.Errorf("col_uuid not found in diff row") @@ -1082,7 +1074,7 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } } - for _, row := range nodeDiffs.Rows[serviceN2] { + for _, row := range nodeDiffs.Rows[env.ServiceN2] { uuid, ok := row.Get("col_uuid") if !ok { t.Errorf("col_uuid not found in diff row") @@ -1343,10 +1335,10 @@ func testTableDiff_MaxDiffRowsLimit(t *testing.T, env *testEnv) { log.Println("TestTableDiff_MaxDiffRowsLimit completed.") } -func testTableDiff_ByteaColumnSizeCheck(t *testing.T, compositeKey bool) { +func testTableDiff_ByteaColumnSizeCheck(t *testing.T, env *testEnv, compositeKey bool) { ctx := context.Background() tableName := "bytea_size_test" - qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) + qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) compositeKeyPart := "" if compositeKey { @@ -1362,24 +1354,28 @@ CREATE TABLE IF NOT EXISTS %s ( );`, qualifiedTableName, compositeKeyPart) // Create table on both nodes and add cleanup to drop it - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, createTableSQL) if err != nil { t.Fatalf("Failed to create test table %s: %v", qualifiedTableName, err) } - addToRepSetSQL := fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName) - _, err = pool.Exec(ctx, addToRepSetSQL) - if err != nil { - t.Fatalf("Failed to add table to replication set on n1: %v", err) + if env.HasSpock { + addToRepSetSQL := fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName) + _, err = pool.Exec(ctx, addToRepSetSQL) + if err != nil { + t.Fatalf("Failed to add table to replication set: %v", err) + } } } t.Cleanup(func() { - removeFromRepSetSQL := fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName) - _, err := pgCluster.Node1Pool.Exec(ctx, removeFromRepSetSQL) - if err != nil { - t.Logf("cleanup: failed to remove table from replication set: %v", err) + if env.HasSpock { + removeFromRepSetSQL := fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName) + _, err := env.N1Pool.Exec(ctx, removeFromRepSetSQL) + if err != nil { + t.Logf("cleanup: failed to remove table from replication set: %v", err) + } } - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range 
[]*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) if err != nil { t.Logf("Failed to drop test table %s: %v", qualifiedTableName, err) @@ -1394,7 +1390,7 @@ CREATE TABLE IF NOT EXISTS %s ( // --- Test Case 1: Data < 1MB (should pass) --- t.Run("DataUnder1MB", func(t *testing.T) { // Truncate before run - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s", qualifiedTableName)) if err != nil { t.Fatalf("Failed to truncate table %s: %v", qualifiedTableName, err) @@ -1402,15 +1398,24 @@ CREATE TABLE IF NOT EXISTS %s ( } smallData := make([]byte, 500*1024) // 500 KB - _, err := pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, name, data) VALUES (1, 'small', $1)", qualifiedTableName), smallData) + _, err := env.N1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, name, data) VALUES (1, 'small', $1)", qualifiedTableName), smallData) if err != nil { t.Fatalf("Failed to insert small data: %v", err) } - time.Sleep(5 * time.Second) + if env.HasSpock { + // Wait for replication to propagate + time.Sleep(5 * time.Second) + } else { + // On native PG there is no replication; insert on n2 directly + _, err = env.N2Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, name, data) VALUES (1, 'small', $1)", qualifiedTableName), smallData) + if err != nil { + t.Fatalf("Failed to insert small data on n2: %v", err) + } + } - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) err = tdTask.RunChecks(false) if err != nil { @@ -1421,22 +1426,31 @@ CREATE TABLE IF NOT EXISTS %s ( // --- Test Case 2: Data > 1MB (should fail) --- t.Run("DataOver1MB", func(t *testing.T) { // Truncate before run - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s", qualifiedTableName)) if err != nil { t.Fatalf("Failed to truncate table %s: %v", qualifiedTableName, err) } } largeData := make([]byte, 1024*1024+1) // > 1 MB - _, err := pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, name, data) VALUES (1, 'large', $1)", qualifiedTableName), largeData) + _, err := env.N1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, name, data) VALUES (1, 'large', $1)", qualifiedTableName), largeData) if err != nil { t.Fatalf("Failed to insert large data: %v", err) } - time.Sleep(5 * time.Second) + if env.HasSpock { + // Wait for replication to propagate + time.Sleep(5 * time.Second) + } else { + // On native PG there is no replication; insert on n2 directly + _, err = env.N2Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (id, name, data) VALUES (1, 'large', $1)", qualifiedTableName), largeData) + if err != nil { + t.Fatalf("Failed to insert large data on n2: %v", err) + } + } - nodesToCompare := []string{serviceN1, serviceN2} - tdTask := newTestTableDiffTask(t, qualifiedTableName, nodesToCompare) + nodesToCompare := []string{env.ServiceN1, env.ServiceN2} + tdTask := env.newTableDiffTask(t, qualifiedTableName, nodesToCompare) err = tdTask.RunChecks(false) if err == nil { diff --git a/tests/integration/table_repair_test.go b/tests/integration/table_repair_test.go 
index 56940f4..bb59d37 100644 --- a/tests/integration/table_repair_test.go +++ b/tests/integration/table_repair_test.go @@ -570,7 +570,7 @@ func TestTableRepair_GenerateReport(t *testing.T) { testTableRepair_GenerateReport(t, newSpockEnv()) } -func TestTableRepair_VariousDataTypes(t *testing.T) { +func testTableRepair_VariousDataTypes(t *testing.T, env *testEnv) { ctx := context.Background() tableName := "data_type_repair" qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) @@ -601,19 +601,23 @@ CREATE TABLE IF NOT EXISTS %s.%s ( col_text_array TEXT[] );`, testSchema, testSchema, tableName) - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) + for i, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + nodeName := env.ClusterNodes[i]["Name"].(string) _, err := pool.Exec(ctx, createDataTypeTableSQL) require.NoErrorf(t, err, "Failed to create data type table on %s", nodeName) _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) require.NoErrorf(t, err, "Failed to truncate data type table on %s", nodeName) - _, err = pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName)) - require.NoErrorf(t, err, "Failed to add table to repset on %s", nodeName) + if env.HasSpock { + _, err = pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName)) + require.NoErrorf(t, err, "Failed to add table to repset on %s", nodeName) + } } t.Cleanup(func() { - _, _ = pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName)) - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + if env.HasSpock { + _, _ = env.N1Pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName)) + } + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, _ = pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) } files, _ := filepath.Glob("*_diffs-*.json") @@ -627,23 +631,20 @@ CREATE TABLE IF NOT EXISTS %s.%s ( require.NoError(t, err) defer func() { _ = tx.Rollback(ctx) }() - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - _, err = tx.Exec( - ctx, - fmt.Sprintf(`INSERT INTO %s (id, col_smallint, col_integer, col_bigint, col_numeric, col_real, col_double, col_varchar, col_text, col_char, col_boolean, col_date, col_timestamp, col_timestamptz, col_interval, col_jsonb, col_json, col_bytea, col_int_array, col_text_array) + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec( + ctx, + fmt.Sprintf(`INSERT INTO %s (id, col_smallint, col_integer, col_bigint, col_numeric, col_real, col_double, col_varchar, col_text, col_char, col_boolean, col_date, col_timestamp, col_timestamptz, col_interval, col_jsonb, col_json, col_bytea, col_int_array, col_text_array) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20)`, qualifiedTableName), - data["id"], data["col_smallint"], data["col_integer"], data["col_bigint"], - data["col_numeric"], data["col_real"], data["col_double"], data["col_varchar"], - data["col_text"], data["col_char"], data["col_boolean"], data["col_date"], - data["col_timestamp"], data["col_timestamptz"], data["col_interval"], data["col_jsonb"], data["col_json"], - data["col_bytea"], data["col_int_array"], data["col_text_array"], - ) - require.NoError(t, err) + 
data["id"], data["col_smallint"], data["col_integer"], data["col_bigint"], + data["col_numeric"], data["col_real"], data["col_double"], data["col_varchar"], + data["col_text"], data["col_char"], data["col_boolean"], data["col_date"], + data["col_timestamp"], data["col_timestamptz"], data["col_interval"], data["col_jsonb"], data["col_json"], + data["col_bytea"], data["col_int_array"], data["col_text_array"], + ) + require.NoError(t, err) + }) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) require.NoError(t, tx.Commit(ctx)) } @@ -692,20 +693,20 @@ CREATE TABLE IF NOT EXISTS %s.%s ( "col_numeric": "444.44", "col_varchar": "only_on_n2", } - insertRow(pgCluster.Node1Pool, row1) - insertRow(pgCluster.Node2Pool, row1) - insertRow(pgCluster.Node1Pool, row2OnlyN1) - insertRow(pgCluster.Node1Pool, row3Base) - insertRow(pgCluster.Node2Pool, row3ModifiedN2) - insertRow(pgCluster.Node2Pool, row4OnlyN2) + insertRow(env.N1Pool, row1) + insertRow(env.N2Pool, row1) + insertRow(env.N1Pool, row2OnlyN1) + insertRow(env.N1Pool, row3Base) + insertRow(env.N2Pool, row3ModifiedN2) + insertRow(env.N2Pool, row4OnlyN2) - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) err := repairTask.Run(false) require.NoError(t, err, "Table repair for various data types failed") - assertNoTableDiff(t, qualifiedTableName) + env.assertNoTableDiff(t, qualifiedTableName) checkRow3 := func(pool *pgxpool.Pool) map[string]any { var ( @@ -733,8 +734,8 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } } - row3N1 := checkRow3(pgCluster.Node1Pool) - row3N2 := checkRow3(pgCluster.Node2Pool) + row3N1 := checkRow3(env.N1Pool) + row3N2 := checkRow3(env.N2Pool) assert.Equal(t, row3N1["col_numeric"], row3N2["col_numeric"]) assert.Equal(t, row3N1["col_jsonb"], row3N2["col_jsonb"]) @@ -744,11 +745,15 @@ CREATE TABLE IF NOT EXISTS %s.%s ( assert.True(t, bytes.Equal(row3N1["col_bytea"].([]byte), row3N2["col_bytea"].([]byte))) var row4Count int - err = pgCluster.Node2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s WHERE id = 4", qualifiedTableName)).Scan(&row4Count) + err = env.N2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s WHERE id = 4", qualifiedTableName)).Scan(&row4Count) require.NoError(t, err) assert.Equal(t, 0, row4Count, "Row present only on node2 should be deleted") } +func TestTableRepair_VariousDataTypes(t *testing.T) { + testTableRepair_VariousDataTypes(t, newSpockEnv()) +} + func testTableRepair_FixNulls(t *testing.T, env *testEnv) { tableName := "customers" qualifiedTableName := fmt.Sprintf("%s.%s", env.Schema, tableName) @@ -895,12 +900,12 @@ func TestTableRepair_FixNulls_DryRun(t *testing.T) { testTableRepair_FixNulls_DryRun(t, newSpockEnv()) } -// TestTableRepair_TimestampAndTimeTypes verifies that the full diff→repair +// testTableRepair_TimestampAndTimeTypes verifies that the full diff→repair // pipeline works correctly for timestamp, timestamptz, time, and timetz columns. // This guards against pgx sending the wrong OID for timestamp-without-tz // (which would cause PostgreSQL to apply a session-timezone shift) and ensures // that pgtype.Time values round-trip through diff JSON and back correctly. 
-func TestTableRepair_TimestampAndTimeTypes(t *testing.T) { +func testTableRepair_TimestampAndTimeTypes(t *testing.T, env *testEnv) { ctx := context.Background() tableName := "temporal_type_repair" qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) @@ -915,19 +920,23 @@ CREATE TABLE IF NOT EXISTS %s.%s ( col_timetz TIME WITH TIME ZONE );`, testSchema, testSchema, tableName) - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) + for i, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + nodeName := env.ClusterNodes[i]["Name"].(string) _, err := pool.Exec(ctx, createSQL) require.NoErrorf(t, err, "Failed to create temporal table on %s", nodeName) _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) require.NoErrorf(t, err, "Failed to truncate temporal table on %s", nodeName) - _, err = pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName)) - require.NoErrorf(t, err, "Failed to add temporal table to repset on %s", nodeName) + if env.HasSpock { + _, err = pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName)) + require.NoErrorf(t, err, "Failed to add temporal table to repset on %s", nodeName) + } } t.Cleanup(func() { - _, _ = pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName)) - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + if env.HasSpock { + _, _ = env.N1Pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName)) + } + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, _ = pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) } files, _ := filepath.Glob("*_diffs-*.json") @@ -941,18 +950,15 @@ CREATE TABLE IF NOT EXISTS %s.%s ( require.NoError(t, err) defer func() { _ = tx.Rollback(ctx) }() - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - _, err = tx.Exec(ctx, fmt.Sprintf( - `INSERT INTO %s (id, col_ts, col_tstz, col_time, col_timetz) VALUES ($1, $2, $3, $4::time, $5::timetz)`, - qualifiedTableName), - id, ts, tstz, timeStr, timetzStr, - ) - require.NoError(t, err) + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec(ctx, fmt.Sprintf( + `INSERT INTO %s (id, col_ts, col_tstz, col_time, col_timetz) VALUES ($1, $2, $3, $4::time, $5::timetz)`, + qualifiedTableName), + id, ts, tstz, timeStr, timetzStr, + ) + require.NoError(t, err) + }) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) require.NoError(t, tx.Commit(ctx)) } @@ -960,11 +966,11 @@ CREATE TABLE IF NOT EXISTS %s.%s ( refTSTZ := time.Date(2024, 6, 15, 18, 45, 30, 0, time.UTC) // Row 1: identical on both nodes (baseline) - insertRow(pgCluster.Node1Pool, 1, refTS, refTSTZ, "08:15:30.123456", "14:30:00-05") - insertRow(pgCluster.Node2Pool, 1, refTS, refTSTZ, "08:15:30.123456", "14:30:00-05") + insertRow(env.N1Pool, 1, refTS, refTSTZ, "08:15:30.123456", "14:30:00-05") + insertRow(env.N2Pool, 1, refTS, refTSTZ, "08:15:30.123456", "14:30:00-05") // Row 2: only on node1 (missing on node2) - insertRow(pgCluster.Node1Pool, 2, + insertRow(env.N1Pool, 2, refTS.Add(1*time.Hour), refTSTZ.Add(1*time.Hour), "12:00:00", @@ -972,7 +978,7 @@ CREATE TABLE IF NOT EXISTS %s.%s ( ) // Row 3: only on node2 (should be deleted when n1 is source of truth) - 
insertRow(pgCluster.Node2Pool, 3, + insertRow(env.N2Pool, 3, refTS.Add(2*time.Hour), refTSTZ.Add(2*time.Hour), "23:59:59.999999", @@ -980,13 +986,13 @@ CREATE TABLE IF NOT EXISTS %s.%s ( ) // Row 4: different values on each node (row mismatch) - insertRow(pgCluster.Node1Pool, 4, + insertRow(env.N1Pool, 4, refTS.Add(3*time.Hour), refTSTZ.Add(3*time.Hour), "06:30:00", "16:45:00-07", ) - insertRow(pgCluster.Node2Pool, 4, + insertRow(env.N2Pool, 4, refTS.Add(99*time.Hour), // deliberately different refTSTZ.Add(99*time.Hour), "22:00:00.500000", @@ -994,24 +1000,24 @@ CREATE TABLE IF NOT EXISTS %s.%s ( ) // ----- Diff ----- - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) // ----- Repair (node1 = source of truth) ----- - repairTask := newTestTableRepairTask(serviceN1, qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) err := repairTask.Run(false) require.NoError(t, err, "Repair for temporal types failed") // ----- Verify: no diffs remain ----- - assertNoTableDiff(t, qualifiedTableName) + env.assertNoTableDiff(t, qualifiedTableName) // ----- Verify: row counts match ----- - countN1 := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName) - countN2 := getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + countN1 := getTableCount(t, ctx, env.N1Pool, qualifiedTableName) + countN2 := getTableCount(t, ctx, env.N2Pool, qualifiedTableName) assert.Equal(t, countN1, countN2, "Row counts should match after repair") // ----- Verify: row 3 was deleted from node2 ----- var row3Count int - err = pgCluster.Node2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s WHERE id = 3", qualifiedTableName)).Scan(&row3Count) + err = env.N2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s WHERE id = 3", qualifiedTableName)).Scan(&row3Count) require.NoError(t, err) assert.Equal(t, 0, row3Count, "Row 3 (only on node2) should be deleted") @@ -1031,30 +1037,34 @@ CREATE TABLE IF NOT EXISTS %s.%s ( return r } - r2n1 := readRow(pgCluster.Node1Pool, 2) - r2n2 := readRow(pgCluster.Node2Pool, 2) + r2n1 := readRow(env.N1Pool, 2) + r2n2 := readRow(env.N2Pool, 2) assert.True(t, r2n1.ts.Equal(r2n2.ts), "row 2 col_ts: n1=%v n2=%v", r2n1.ts, r2n2.ts) assert.True(t, r2n1.tstz.Equal(r2n2.tstz), "row 2 col_tstz: n1=%v n2=%v", r2n1.tstz, r2n2.tstz) assert.Equal(t, r2n1.timeV, r2n2.timeV, "row 2 col_time mismatch") assert.Equal(t, r2n1.timetz, r2n2.timetz, "row 2 col_timetz mismatch") // ----- Verify: row 4 was repaired to match node1 ----- - r4n1 := readRow(pgCluster.Node1Pool, 4) - r4n2 := readRow(pgCluster.Node2Pool, 4) + r4n1 := readRow(env.N1Pool, 4) + r4n2 := readRow(env.N2Pool, 4) assert.True(t, r4n1.ts.Equal(r4n2.ts), "row 4 col_ts: n1=%v n2=%v", r4n1.ts, r4n2.ts) assert.True(t, r4n1.tstz.Equal(r4n2.tstz), "row 4 col_tstz: n1=%v n2=%v", r4n1.tstz, r4n2.tstz) assert.Equal(t, r4n1.timeV, r4n2.timeV, "row 4 col_time mismatch") assert.Equal(t, r4n1.timetz, r4n2.timetz, "row 4 col_timetz mismatch") // ----- Verify: row 1 (baseline) is still identical ----- - r1n1 := readRow(pgCluster.Node1Pool, 1) - r1n2 := readRow(pgCluster.Node2Pool, 1) + r1n1 := readRow(env.N1Pool, 1) + r1n2 := readRow(env.N2Pool, 1) assert.True(t, r1n1.ts.Equal(r1n2.ts), "row 1 col_ts: n1=%v n2=%v", r1n1.ts, r1n2.ts) assert.True(t, r1n1.tstz.Equal(r1n2.tstz), "row 1 col_tstz: n1=%v n2=%v", r1n1.tstz, r1n2.tstz) assert.Equal(t, r1n1.timeV, r1n2.timeV, 
"row 1 col_time mismatch") assert.Equal(t, r1n1.timetz, r1n2.timetz, "row 1 col_timetz mismatch") } +func TestTableRepair_TimestampAndTimeTypes(t *testing.T) { + testTableRepair_TimestampAndTimeTypes(t, newSpockEnv()) +} + // testTableRepair_FixNulls_BidirectionalUpdate tests that when both nodes have NULLs // in different columns for the same row, fix-nulls performs bidirectional updates. // This verifies the behavior discussed in code review: each node updates the other @@ -2070,7 +2080,7 @@ func TestTableRepair_MixedOps_PreserveOrigin(t *testing.T) { log.Println(" - UPDATE: 2 modified rows corrected with preserved origin/timestamp") } -// TestTableRepair_LargeBigintPK verifies that table repair correctly handles +// testTableRepair_LargeBigintPK verifies that table repair correctly handles // bigint primary keys whose values exceed float64's exact integer range (2^53). // Before the fix, JSON deserialization converted these to float64, silently // truncating the PK values and causing: @@ -2080,7 +2090,7 @@ func TestTableRepair_MixedOps_PreserveOrigin(t *testing.T) { // // The test uses PKs that are adjacent integers above 2^53, which all collapse // to the same float64 value without the json.Number fix. -func TestTableRepair_LargeBigintPK(t *testing.T) { +func testTableRepair_LargeBigintPK(t *testing.T, env *testEnv) { ctx := context.Background() tableName := "bigint_pk_repair" qualifiedTableName := fmt.Sprintf("%s.%s", testSchema, tableName) @@ -2103,19 +2113,23 @@ CREATE TABLE IF NOT EXISTS %s.%s ( amount NUMERIC(20, 4) );`, testSchema, testSchema, tableName) - for i, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { - nodeName := pgCluster.ClusterNodes[i]["Name"].(string) + for i, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { + nodeName := env.ClusterNodes[i]["Name"].(string) _, err := pool.Exec(ctx, createTableSQL) require.NoErrorf(t, err, "Failed to create table on %s", nodeName) _, err = pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) require.NoErrorf(t, err, "Failed to truncate table on %s", nodeName) - _, err = pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName)) - require.NoErrorf(t, err, "Failed to add table to repset on %s", nodeName) + if env.HasSpock { + _, err = pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_add_table('default', '%s');`, qualifiedTableName)) + require.NoErrorf(t, err, "Failed to add table to repset on %s", nodeName) + } } t.Cleanup(func() { - _, _ = pgCluster.Node1Pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName)) - for _, pool := range []*pgxpool.Pool{pgCluster.Node1Pool, pgCluster.Node2Pool} { + if env.HasSpock { + _, _ = env.N1Pool.Exec(ctx, fmt.Sprintf(`SELECT spock.repset_remove_table('default', '%s');`, qualifiedTableName)) + } + for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { _, _ = pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName)) } files, _ := filepath.Glob("*_diffs-*.json") @@ -2129,36 +2143,33 @@ CREATE TABLE IF NOT EXISTS %s.%s ( require.NoError(t, err) defer func() { _ = tx.Rollback(ctx) }() - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(true)") - require.NoError(t, err) - - _, err = tx.Exec(ctx, - fmt.Sprintf("INSERT INTO %s (id, data, amount) VALUES ($1, $2, $3::numeric)", qualifiedTableName), - id, data, amount) - require.NoError(t, err) + env.withRepairModeTx(t, ctx, tx, func() { + _, err = tx.Exec(ctx, + fmt.Sprintf("INSERT INTO 
%s (id, data, amount) VALUES ($1, $2, $3::numeric)", qualifiedTableName), + id, data, amount) + require.NoError(t, err) + }) - _, err = tx.Exec(ctx, "SELECT spock.repair_mode(false)") - require.NoError(t, err) require.NoError(t, tx.Commit(ctx)) } // ---- Set up divergence ---- // Common row on both nodes (same data) - insertRow(pgCluster.Node1Pool, collisionPKs[0], "common_row", "1000000.1234") - insertRow(pgCluster.Node2Pool, collisionPKs[0], "common_row", "1000000.1234") + insertRow(env.N1Pool, collisionPKs[0], "common_row", "1000000.1234") + insertRow(env.N2Pool, collisionPKs[0], "common_row", "1000000.1234") // Rows only on node1 (should be deleted when node2 is source of truth) - insertRow(pgCluster.Node1Pool, collisionPKs[1], "n1_only_289", "2000000.5678") + insertRow(env.N1Pool, collisionPKs[1], "n1_only_289", "2000000.5678") // Rows only on node2 (should be inserted into node1) - insertRow(pgCluster.Node2Pool, collisionPKs[2], "n2_only_290", "3000000.9012") + insertRow(env.N2Pool, collisionPKs[2], "n2_only_290", "3000000.9012") // Modified row: same PK on both nodes, different data - insertRow(pgCluster.Node1Pool, collisionPKs[3], "old_data_291", "4000000.0001") - insertRow(pgCluster.Node2Pool, collisionPKs[3], "new_data_291", "4000000.9999") + insertRow(env.N1Pool, collisionPKs[3], "old_data_291", "4000000.0001") + insertRow(env.N2Pool, collisionPKs[3], "new_data_291", "4000000.9999") // ---- Run diff ---- - diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) + diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) // Verify diff found the expected number of differences: // - collisionPKs[1]: node1-only → 1 diff @@ -2180,19 +2191,19 @@ CREATE TABLE IF NOT EXISTS %s.%s ( assert.Equal(t, 3, totalDiffs, "Expected exactly 3 differences before repair") // ---- Run repair (node2 is source of truth) ---- - repairTask := newTestTableRepairTask(serviceN2, qualifiedTableName, diffFile) + repairTask := env.newTableRepairTask(env.ServiceN2, qualifiedTableName, diffFile) err = repairTask.Run(false) require.NoError(t, err, "Table repair failed") // ---- Verify repair ---- // 1. Tables should now match (zero diffs) - assertNoTableDiff(t, qualifiedTableName) + env.assertNoTableDiff(t, qualifiedTableName) // 2. Row counts should match: 3 rows (common + n2_only_290 + modified_291) // collisionPKs[1] (n1_only_289) should have been deleted from node1 - count1 := getTableCount(t, ctx, pgCluster.Node1Pool, qualifiedTableName) - count2 := getTableCount(t, ctx, pgCluster.Node2Pool, qualifiedTableName) + count1 := getTableCount(t, ctx, env.N1Pool, qualifiedTableName) + count2 := getTableCount(t, ctx, env.N2Pool, qualifiedTableName) assert.Equal(t, count1, count2, "Row counts should match after repair") assert.Equal(t, 3, count1, "Expected 3 rows after repair") @@ -2209,26 +2220,30 @@ CREATE TABLE IF NOT EXISTS %s.%s ( } // Common row should be unchanged - verifyRow(pgCluster.Node1Pool, collisionPKs[0], "common_row", "1000000.1234") + verifyRow(env.N1Pool, collisionPKs[0], "common_row", "1000000.1234") // Node1-only row should have been deleted var deletedCount int - err = pgCluster.Node1Pool.QueryRow(ctx, + err = env.N1Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s WHERE id = $1", qualifiedTableName), collisionPKs[1]). 
Scan(&deletedCount) require.NoError(t, err) assert.Equal(t, 0, deletedCount, "Node1-only row (PK %d) should have been deleted", collisionPKs[1]) // Node2-only row should have been inserted into node1 - verifyRow(pgCluster.Node1Pool, collisionPKs[2], "n2_only_290", "3000000.9012") + verifyRow(env.N1Pool, collisionPKs[2], "n2_only_290", "3000000.9012") // Modified row should have node2's version on node1 - verifyRow(pgCluster.Node1Pool, collisionPKs[3], "new_data_291", "4000000.9999") + verifyRow(env.N1Pool, collisionPKs[3], "new_data_291", "4000000.9999") - log.Println("TestTableRepair_LargeBigintPK PASSED") + log.Println("testTableRepair_LargeBigintPK PASSED") log.Println(" - 4 adjacent PKs above 2^53 that collide under float64") log.Println(" - DELETE: node1-only row correctly removed") log.Println(" - INSERT: node2-only row correctly added to node1") log.Println(" - UPDATE: modified row correctly updated on node1") log.Println(" - All PKs preserved with exact precision (no float64 truncation)") } + +func TestTableRepair_LargeBigintPK(t *testing.T) { + testTableRepair_LargeBigintPK(t, newSpockEnv()) +} From 4dfd7c2335787b73d7977691facb6e9a7ecd3af0 Mon Sep 17 00:00:00 2001 From: Mason Sharp Date: Wed, 8 Apr 2026 17:00:13 -0700 Subject: [PATCH 11/14] fix: return error from isSpockAvailable and remove unused composite param isSpockAvailable now returns (bool, error) so callers don't silently fall back to native PG mode on detection failures. Also removes the unused composite parameter from setupDivergence. Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/consistency/repair/table_repair.go | 32 ++++++++++++++------- tests/integration/advanced_repair_test.go | 6 ++-- tests/integration/table_repair_test.go | 14 ++++----- tests/integration/test_env_test.go | 2 +- 4 files changed, 32 insertions(+), 22 deletions(-) diff --git a/internal/consistency/repair/table_repair.go b/internal/consistency/repair/table_repair.go index b2730f9..f908dc9 100644 --- a/internal/consistency/repair/table_repair.go +++ b/internal/consistency/repair/table_repair.go @@ -163,25 +163,25 @@ func (t *TableRepairTask) setRole(tx pgx.Tx, nodeName string) error { } // isSpockAvailable returns whether spock is installed on the given node, -// detecting lazily on first check per node. -func (t *TableRepairTask) isSpockAvailable(nodeName string) bool { +// detecting lazily on first check per node. Returns an error if detection +// fails so callers don't silently fall back to the wrong repair mode. +func (t *TableRepairTask) isSpockAvailable(nodeName string) (bool, error) { if t.spockPerNode == nil { t.spockPerNode = make(map[string]bool) } if _, checked := t.spockPerNode[nodeName]; !checked { pool := t.Pools[nodeName] if pool == nil { - return false + return false, fmt.Errorf("no connection pool for node %s", nodeName) } spockInstalled, err := queries.CheckSpockInstalled(t.Ctx, pool) if err != nil { - logger.Warn("failed to detect spock extension on %s: %v", nodeName, err) - return false + return false, fmt.Errorf("failed to detect spock extension on %s: %w", nodeName, err) } t.spockPerNode[nodeName] = spockInstalled logger.Info("spock extension on %s: %v", nodeName, spockInstalled) } - return t.spockPerNode[nodeName] + return t.spockPerNode[nodeName], nil } // setupTransactionMode enables spock repair mode (when available), sets the session replication // role ('local' when FireTriggers is set, otherwise 'replica'). // Returns true if spock repair mode was activated, false otherwise. 
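+//
+// Sketch of the intended pairing, simplified from the repair paths in this
+// file (error handling elided; actual caller code differs):
+//
+//	spockActive, err := t.setupTransactionMode(tx, nodeName)
+//	// ... apply repair statements on tx ...
+//	if spockActive {
+//		err = t.disableSpockRepairMode(tx, nodeName)
+//	}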
func (t *TableRepairTask) setupTransactionMode(tx pgx.Tx, nodeName string) (bool, error) { spockRepairModeActive := false - if t.isSpockAvailable(nodeName) { + spock, err := t.isSpockAvailable(nodeName) + if err != nil { + return false, err + } + if spock { if _, err := tx.Exec(t.Ctx, "SELECT spock.repair_mode(true)"); err != nil { return false, fmt.Errorf("enabling spock.repair_mode(true) on %s: %w", nodeName, err) } @@ -197,7 +201,6 @@ func (t *TableRepairTask) setupTransactionMode(tx pgx.Tx, nodeName string) (bool spockRepairModeActive = true } - var err error if t.FireTriggers { _, err = tx.Exec(t.Ctx, "SET session_replication_role = 'local'") } else { @@ -217,10 +220,14 @@ func (t *TableRepairTask) setupTransactionMode(tx pgx.Tx, nodeName string) (bool // disableSpockRepairMode disables spock repair mode on the given transaction. // When spock is not available, this is a no-op. func (t *TableRepairTask) disableSpockRepairMode(tx pgx.Tx, nodeName string) error { - if !t.isSpockAvailable(nodeName) { + spock, err := t.isSpockAvailable(nodeName) + if err != nil { + return err + } + if !spock { return nil } - _, err := tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)") + _, err = tx.Exec(t.Ctx, "SELECT spock.repair_mode(false)") if err != nil { return fmt.Errorf("disabling spock.repair_mode(false) on %s: %w", nodeName, err) } @@ -2918,7 +2925,10 @@ func calculateRepairSetsWithSourceOfTruth(task *TableRepairTask) (map[string]map } func (t *TableRepairTask) fetchLSNsForNode(pool *pgxpool.Pool, failedNode, survivor string) (originLSN *uint64, slotLSN *uint64, err error) { - spock := t.isSpockAvailable(survivor) + spock, err := t.isSpockAvailable(survivor) + if err != nil { + return nil, nil, fmt.Errorf("detecting spock on %s: %w", survivor, err) + } var originStr *string if spock { originStr, err = queries.GetSpockOriginLSNForNode(t.Ctx, pool, failedNode) diff --git a/tests/integration/advanced_repair_test.go b/tests/integration/advanced_repair_test.go index 38c6b4f..f29ad56 100644 --- a/tests/integration/advanced_repair_test.go +++ b/tests/integration/advanced_repair_test.go @@ -30,7 +30,7 @@ func TestAdvancedRepairPlan_MixedSelectorsAndActions(t *testing.T) { ctx := context.Background() qualifiedTableName := "public.customers" - setupDivergence(t, ctx, qualifiedTableName, false) + setupDivergence(t, ctx, qualifiedTableName) diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) plan := ` @@ -87,7 +87,7 @@ func TestAdvancedRepairPlan_DeleteAndWhenPredicate(t *testing.T) { ctx := context.Background() qualifiedTableName := "public.customers" - setupDivergence(t, ctx, qualifiedTableName, false) + setupDivergence(t, ctx, qualifiedTableName) diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) plan := ` @@ -140,7 +140,7 @@ func TestAdvancedRepairPlan_StaleRepairsSkipped(t *testing.T) { ctx := context.Background() qualifiedTableName := "public.customers" - setupDivergence(t, ctx, qualifiedTableName, false) + setupDivergence(t, ctx, qualifiedTableName) diffFile := runTableDiff(t, qualifiedTableName, []string{serviceN1, serviceN2}) beforeLogs := listStaleSkipLogs(t) diff --git a/tests/integration/table_repair_test.go b/tests/integration/table_repair_test.go index bb59d37..f73ef4a 100644 --- a/tests/integration/table_repair_test.go +++ b/tests/integration/table_repair_test.go @@ -37,9 +37,9 @@ import ( // - 2 rows only on node1 (IDs 1001, 1002) // - 2 rows only on node2 (IDs 2001, 2002) // - 2 common rows modified on node2 (IDs 1, 2) -func 
setupDivergence(t *testing.T, ctx context.Context, qualifiedTableName string, composite bool) { +func setupDivergence(t *testing.T, ctx context.Context, qualifiedTableName string) { t.Helper() - newSpockEnv().setupDivergence(t, ctx, qualifiedTableName, composite) + newSpockEnv().setupDivergence(t, ctx, qualifiedTableName) } // runTableDiff executes a table-diff task and returns the path to the latest diff file. @@ -169,7 +169,7 @@ func testTableRepair_UnidirectionalDefault(t *testing.T, env *testEnv) { tc.setup() t.Cleanup(tc.teardown) - env.setupDivergence(t, ctx, qualifiedTableName, tc.composite) + env.setupDivergence(t, ctx, qualifiedTableName) t.Cleanup(func() { env.repairTable(t, qualifiedTableName, env.ServiceN1) }) @@ -235,7 +235,7 @@ func testTableRepair_InsertOnly(t *testing.T, env *testEnv) { tc.setup() t.Cleanup(tc.teardown) - env.setupDivergence(t, ctx, qualifiedTableName, tc.composite) + env.setupDivergence(t, ctx, qualifiedTableName) t.Cleanup(func() { env.repairTable(t, qualifiedTableName, env.ServiceN1) }) @@ -311,7 +311,7 @@ func testTableRepair_UpsertOnly(t *testing.T, env *testEnv) { tc.setup() t.Cleanup(tc.teardown) - env.setupDivergence(t, ctx, qualifiedTableName, tc.composite) + env.setupDivergence(t, ctx, qualifiedTableName) t.Cleanup(func() { env.repairTable(t, qualifiedTableName, env.ServiceN1) }) @@ -461,7 +461,7 @@ func testTableRepair_DryRun(t *testing.T, env *testEnv) { tc.setup() t.Cleanup(tc.teardown) - env.setupDivergence(t, ctx, qualifiedTableName, tc.composite) + env.setupDivergence(t, ctx, qualifiedTableName) t.Cleanup(func() { env.repairTable(t, qualifiedTableName, env.ServiceN1) }) @@ -532,7 +532,7 @@ func testTableRepair_GenerateReport(t *testing.T, env *testEnv) { }) os.RemoveAll(reportDir) - env.setupDivergence(t, ctx, qualifiedTableName, tc.composite) + env.setupDivergence(t, ctx, qualifiedTableName) diffFile := env.runTableDiff(t, qualifiedTableName, []string{env.ServiceN1, env.ServiceN2}) repairTask := env.newTableRepairTask(env.ServiceN1, qualifiedTableName, diffFile) diff --git a/tests/integration/test_env_test.go b/tests/integration/test_env_test.go index 962d730..f79e8a8 100644 --- a/tests/integration/test_env_test.go +++ b/tests/integration/test_env_test.go @@ -199,7 +199,7 @@ func (e *testEnv) newTableRepairTask(sourceOfTruthNode, qualifiedTableName, diff // - 2 rows only on n1 (index 1001, 1002) // - 2 rows only on n2 (index 2001, 2002) // - 2 common rows modified on n2 (index 1, 2) -func (e *testEnv) setupDivergence(t *testing.T, ctx context.Context, qualifiedTableName string, composite bool) { +func (e *testEnv) setupDivergence(t *testing.T, ctx context.Context, qualifiedTableName string) { t.Helper() log.Println("Setting up data divergence for", qualifiedTableName) From 4bfc1f471723f21bdaf8a75adceb163820bf2e7a Mon Sep 17 00:00:00 2001 From: Mason Sharp Date: Mon, 13 Apr 2026 19:37:55 -0700 Subject: [PATCH 12/14] fix: sanitize aceSchema identifier to prevent SQL breakage with non-simple names Use pgx.Identifier.Sanitize() to quote the configurable ACE schema name in SQL templates, consistent with how all other identifiers are handled in queries.go. 
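Illustration of the quoting behavior (the schema names here are
hypothetical; Sanitize always double-quotes the identifier and doubles
any embedded quote characters):

    pgx.Identifier{"ace_mtree"}.Sanitize()  // => "ace_mtree"
    pgx.Identifier{`my"schema`}.Sanitize()  // => "my""schema"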
Co-Authored-By: Claude Opus 4.6 (1M context) --- db/queries/templates.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/db/queries/templates.go b/db/queries/templates.go index 81d7f37..c067413 100644 --- a/db/queries/templates.go +++ b/db/queries/templates.go @@ -14,13 +14,14 @@ package queries import ( "text/template" + "github.com/jackc/pgx/v5" "github.com/pgedge/ace/pkg/config" ) // aceTemplateFuncs provides the {{aceSchema}} function to SQL templates. // The function is evaluated at render time (after config is loaded), not at parse time. var aceTemplateFuncs = template.FuncMap{ - "aceSchema": func() string { return config.Get().MTree.Schema }, + "aceSchema": func() string { return pgx.Identifier{config.Get().MTree.Schema}.Sanitize() }, } type Templates struct { From 24b9b1f69a94c5805a504fda83e9cfa78338ede4 Mon Sep 17 00:00:00 2001 From: Mason Sharp Date: Tue, 14 Apr 2026 11:32:25 -0700 Subject: [PATCH 13/14] fix: pin spock.repair_mode to a single pooled connection in test helper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit withRepairMode previously used pool.Exec for enable, fn body, and disable — each call could hit a different backend session, making repair_mode ineffective. Now acquires a dedicated connection via pool.Acquire and passes it to the callback. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/integration/table_repair_test.go | 28 ++++++------ tests/integration/test_env_test.go | 59 ++++++++++++++------------ 2 files changed, 47 insertions(+), 40 deletions(-) diff --git a/tests/integration/table_repair_test.go b/tests/integration/table_repair_test.go index f73ef4a..11c7954 100644 --- a/tests/integration/table_repair_test.go +++ b/tests/integration/table_repair_test.go @@ -387,21 +387,21 @@ func testTableRepair_Bidirectional(t *testing.T, env *testEnv) { log.Println("Setting up data for bidirectional test") for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { - env.withRepairMode(t, ctx, pool, func() { - _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) + env.withRepairMode(t, ctx, pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) require.NoError(t, err) }) } - env.withRepairMode(t, ctx, env.N1Pool, func() { + env.withRepairMode(t, ctx, env.N1Pool, func(conn *pgxpool.Conn) { for i := 3001; i <= 3003; i++ { - _, err := env.N1Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name) VALUES ($1, $2, $3)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N1-Bi-%d", i)) + _, err := conn.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name) VALUES ($1, $2, $3)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N1-Bi-%d", i)) require.NoError(t, err) } }) - env.withRepairMode(t, ctx, env.N2Pool, func() { + env.withRepairMode(t, ctx, env.N2Pool, func(conn *pgxpool.Conn) { for i := 4001; i <= 4002; i++ { - _, err := env.N2Pool.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name) VALUES ($1, $2, $3)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N2-Bi-%d", i)) + _, err := conn.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name) VALUES ($1, $2, $3)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N2-Bi-%d", i)) require.NoError(t, err) } }) @@ -1118,8 +1118,8 @@ func testTableRepair_FixNulls_BidirectionalUpdate(t *testing.T, env *testEnv) { // Clean table on 
both nodes for _, pool := range []*pgxpool.Pool{env.N1Pool, env.N2Pool} { - env.withRepairMode(t, ctx, pool, func() { - _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) + env.withRepairMode(t, ctx, pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) require.NoError(t, err, "Failed to truncate table") }) } @@ -1133,22 +1133,22 @@ func testTableRepair_FixNulls_BidirectionalUpdate(t *testing.T, env *testEnv) { ) // Node1 data - env.withRepairMode(t, ctx, env.N1Pool, func() { + env.withRepairMode(t, ctx, env.N1Pool, func(conn *pgxpool.Conn) { // Row 1 on Node1: NULL first_name and city - _, err := env.N1Pool.Exec(ctx, insertSQL, 100, "CUST-100", nil, "LastName100", nil, "email100@example.com") + _, err := conn.Exec(ctx, insertSQL, 100, "CUST-100", nil, "LastName100", nil, "email100@example.com") require.NoError(t, err) // Row 2 on Node1: NULL last_name and email - _, err = env.N1Pool.Exec(ctx, insertSQL, 200, "CUST-200", "FirstName200", nil, "City200", nil) + _, err = conn.Exec(ctx, insertSQL, 200, "CUST-200", "FirstName200", nil, "City200", nil) require.NoError(t, err) }) // Node2 data - env.withRepairMode(t, ctx, env.N2Pool, func() { + env.withRepairMode(t, ctx, env.N2Pool, func(conn *pgxpool.Conn) { // Row 1 on Node2: NULL last_name and email - _, err := env.N2Pool.Exec(ctx, insertSQL, 100, "CUST-100", "FirstName100", nil, "City100", nil) + _, err := conn.Exec(ctx, insertSQL, 100, "CUST-100", "FirstName100", nil, "City100", nil) require.NoError(t, err) // Row 2 on Node2: NULL first_name and city - _, err = env.N2Pool.Exec(ctx, insertSQL, 200, "CUST-200", nil, "LastName200", nil, "email200@example.com") + _, err = conn.Exec(ctx, insertSQL, 200, "CUST-200", nil, "LastName200", nil, "email200@example.com") require.NoError(t, err) }) diff --git a/tests/integration/test_env_test.go b/tests/integration/test_env_test.go index f79e8a8..3e81fe9 100644 --- a/tests/integration/test_env_test.go +++ b/tests/integration/test_env_test.go @@ -121,19 +121,26 @@ func newNativeEnv(state *nativeClusterState) *testEnv { } } -// withRepairMode wraps a function with spock.repair_mode(true/false) when spock -// is available. On native PG (no replication), this is a no-op. -func (e *testEnv) withRepairMode(t *testing.T, ctx context.Context, pool *pgxpool.Pool, fn func()) { +// withRepairMode acquires a single connection from the pool, enables +// spock.repair_mode on it (when available), runs fn, then disables repair mode +// and releases the connection. On native PG this simply pins a connection for +// the duration of fn. All work inside fn must use the provided conn so that +// repair_mode is in effect. +func (e *testEnv) withRepairMode(t *testing.T, ctx context.Context, pool *pgxpool.Pool, fn func(conn *pgxpool.Conn)) { t.Helper() + conn, err := pool.Acquire(ctx) + require.NoError(t, err, "acquire connection for repair mode") + defer conn.Release() + if e.HasSpock { - _, err := pool.Exec(ctx, "SELECT spock.repair_mode(true)") + _, err := conn.Exec(ctx, "SELECT spock.repair_mode(true)") require.NoError(t, err, "enable spock.repair_mode") defer func() { - _, err := pool.Exec(ctx, "SELECT spock.repair_mode(false)") + _, err := conn.Exec(ctx, "SELECT spock.repair_mode(false)") require.NoError(t, err, "disable spock.repair_mode") }() } - fn() + fn(conn) } // withRepairModeTx is like withRepairMode but operates on a transaction. 
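+//
+// Usage sketch (the INSERT text is illustrative); statements inside the
+// callback must run on the same tx that was passed in:
+//
+//	env.withRepairModeTx(t, ctx, tx, func() {
+//		_, err := tx.Exec(ctx, "INSERT INTO ...")
+//		require.NoError(t, err)
+//	})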
@@ -205,8 +212,8 @@ func (e *testEnv) setupDivergence(t *testing.T, ctx context.Context, qualifiedTa // Truncate on both nodes for _, pool := range e.pools() { - e.withRepairMode(t, ctx, pool, func() { - _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) + e.withRepairMode(t, ctx, pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) require.NoError(t, err, "truncate table") }) } @@ -214,8 +221,8 @@ func (e *testEnv) setupDivergence(t *testing.T, ctx context.Context, qualifiedTa // Insert common rows for i := 1; i <= 5; i++ { for _, pool := range e.pools() { - e.withRepairMode(t, ctx, pool, func() { - _, err := pool.Exec(ctx, + e.withRepairMode(t, ctx, pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("FirstName%d", i), fmt.Sprintf("LastName%d", i), fmt.Sprintf("email%d@example.com", i)) require.NoError(t, err) @@ -224,9 +231,9 @@ func (e *testEnv) setupDivergence(t *testing.T, ctx context.Context, qualifiedTa } // Rows only on n1 - e.withRepairMode(t, ctx, e.N1Pool, func() { + e.withRepairMode(t, ctx, e.N1Pool, func(conn *pgxpool.Conn) { for i := 1001; i <= 1002; i++ { - _, err := e.N1Pool.Exec(ctx, + _, err := conn.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N1OnlyFirst%d", i), fmt.Sprintf("N1OnlyLast%d", i), fmt.Sprintf("n1.only%d@example.com", i)) require.NoError(t, err) @@ -234,9 +241,9 @@ func (e *testEnv) setupDivergence(t *testing.T, ctx context.Context, qualifiedTa }) // Rows only on n2 - e.withRepairMode(t, ctx, e.N2Pool, func() { + e.withRepairMode(t, ctx, e.N2Pool, func(conn *pgxpool.Conn) { for i := 2001; i <= 2002; i++ { - _, err := e.N2Pool.Exec(ctx, + _, err := conn.Exec(ctx, fmt.Sprintf("INSERT INTO %s (index, customer_id, first_name, last_name, email) VALUES ($1, $2, $3, $4, $5)", qualifiedTableName), i, fmt.Sprintf("CUST-%d", i), fmt.Sprintf("N2OnlyFirst%d", i), fmt.Sprintf("N2OnlyLast%d", i), fmt.Sprintf("n2.only%d@example.com", i)) require.NoError(t, err) @@ -244,9 +251,9 @@ func (e *testEnv) setupDivergence(t *testing.T, ctx context.Context, qualifiedTa }) // Modify rows on n2 - e.withRepairMode(t, ctx, e.N2Pool, func() { + e.withRepairMode(t, ctx, e.N2Pool, func(conn *pgxpool.Conn) { for i := 1; i <= 2; i++ { - _, err := e.N2Pool.Exec(ctx, + _, err := conn.Exec(ctx, fmt.Sprintf("UPDATE %s SET email = $1 WHERE index = $2", qualifiedTableName), fmt.Sprintf("modified.email%d@example.com", i), i) require.NoError(t, err) @@ -263,8 +270,8 @@ func (e *testEnv) setupNullDivergence(t *testing.T, ctx context.Context, qualifi log.Println("Setting up null divergence for", qualifiedTableName) for _, pool := range e.pools() { - e.withRepairMode(t, ctx, pool, func() { - _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) + e.withRepairMode(t, ctx, pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) require.NoError(t, err, "truncate table") }) } @@ -275,18 +282,18 @@ func (e *testEnv) setupNullDivergence(t *testing.T, ctx context.Context, qualifi ) // Node1: missing city for id 1, missing first_name for id 2 - e.withRepairMode(t, ctx, e.N1Pool, func() { - _, err 
:= e.N1Pool.Exec(ctx, insertSQL, 1, "CUST-1", "Michael", "Schumacher", nil) + e.withRepairMode(t, ctx, e.N1Pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, insertSQL, 1, "CUST-1", "Michael", "Schumacher", nil) require.NoError(t, err) - _, err = e.N1Pool.Exec(ctx, insertSQL, 2, "CUST-2", nil, "Alonso", "Oviedo") + _, err = conn.Exec(ctx, insertSQL, 2, "CUST-2", nil, "Alonso", "Oviedo") require.NoError(t, err) }) // Node2: missing last_name for id 1, missing city for id 2 - e.withRepairMode(t, ctx, e.N2Pool, func() { - _, err := e.N2Pool.Exec(ctx, insertSQL, 1, "CUST-1", "Michael", nil, "Austria") + e.withRepairMode(t, ctx, e.N2Pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, insertSQL, 1, "CUST-1", "Michael", nil, "Austria") require.NoError(t, err) - _, err = e.N2Pool.Exec(ctx, insertSQL, 2, "CUST-2", "Fernando", "Alonso", nil) + _, err = conn.Exec(ctx, insertSQL, 2, "CUST-2", "Fernando", "Alonso", nil) require.NoError(t, err) }) } @@ -369,8 +376,8 @@ func (e *testEnv) resetSharedTable(t *testing.T, tableName string) { csvPath, err := filepath.Abs(defaultCsvFilePath + tableName + ".csv") require.NoError(t, err) for _, pool := range e.pools() { - e.withRepairMode(t, ctx, pool, func() { - _, err := pool.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) + e.withRepairMode(t, ctx, pool, func(conn *pgxpool.Conn) { + _, err := conn.Exec(ctx, fmt.Sprintf("TRUNCATE TABLE %s CASCADE", qualifiedTableName)) require.NoError(t, err, "truncate %s", qualifiedTableName) }) require.NoError(t, loadDataFromCSV(ctx, pool, e.Schema, tableName, csvPath), "load CSV into %s", qualifiedTableName) From aea4a4d0c1f91c59c1c1d7a80c5640b7d466ca1a Mon Sep 17 00:00:00 2001 From: Mason Sharp Date: Wed, 15 Apr 2026 17:08:14 -0700 Subject: [PATCH 14/14] fix: native PG origin resolution for --against-origin and use pg_xact_commit_timestamp_origin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add GetNativeNodeOriginNames query that joins pg_replication_origin with pg_subscription to map roident → subscription name (the native PG equivalent of GetSpockNodeNames) - Replace spock.xact_commit_timestamp_origin with the standard PG function pg_xact_commit_timestamp_origin in buildEffectiveFilter (they're identical — spock's impl just calls TransactionIdGetCommitTsData) - Remove hardcoded host ports from docker-compose-native.yaml to avoid conflicts with local services - Add unit tests for resolveAgainstOrigin and buildEffectiveFilter - Add integration tests for native PG origin name resolution and origin-tracked replication with repair Co-Authored-By: Claude Opus 4.6 (1M context) --- db/queries/queries.go | 33 ++- db/queries/templates.go | 10 + internal/consistency/diff/table_diff.go | 2 +- .../diff/table_diff_origin_test.go | 135 ++++++++++ internal/consistency/repair/table_repair.go | 3 +- tests/integration/docker-compose-native.yaml | 6 +- tests/integration/native_pg_test.go | 245 ++++++++++++++++++ 7 files changed, 427 insertions(+), 7 deletions(-) create mode 100644 internal/consistency/diff/table_diff_origin_test.go diff --git a/db/queries/queries.go b/db/queries/queries.go index 1c7b471..8dfe37c 100644 --- a/db/queries/queries.go +++ b/db/queries/queries.go @@ -997,6 +997,37 @@ func GetReplicationOriginNames(ctx context.Context, db DBQuerier) (map[string]st return names, nil } +// GetNativeNodeOriginNames maps replication origin IDs to subscription names +// for native PG logical replication (no spock). 
This is the native PG +// equivalent of GetSpockNodeNames. +func GetNativeNodeOriginNames(ctx context.Context, db DBQuerier) (map[string]string, error) { + sql, err := RenderSQL(SQLTemplates.GetNativeNodeOriginNames, nil) + if err != nil { + return nil, err + } + + rows, err := db.Query(ctx, sql) + if err != nil { + return nil, err + } + defer rows.Close() + + names := make(map[string]string) + for rows.Next() { + var id, name string + if err := rows.Scan(&id, &name); err != nil { + return nil, err + } + names[id] = name + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return names, nil +} + func GetNodeOriginNames(ctx context.Context, db DBQuerier) (map[string]string, error) { var spockAvailable bool err := db.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'spock')").Scan(&spockAvailable) @@ -1006,7 +1037,7 @@ func GetNodeOriginNames(ctx context.Context, db DBQuerier) (map[string]string, e if spockAvailable { return GetSpockNodeNames(ctx, db) } - return GetReplicationOriginNames(ctx, db) + return GetNativeNodeOriginNames(ctx, db) } func CheckSpockInstalled(ctx context.Context, db DBQuerier) (bool, error) { diff --git a/db/queries/templates.go b/db/queries/templates.go index c067413..6a0ffb8 100644 --- a/db/queries/templates.go +++ b/db/queries/templates.go @@ -134,6 +134,7 @@ type Templates struct { GetNativeOriginLSNForNode *template.Template GetNativeSlotLSNForNode *template.Template GetReplicationOriginNames *template.Template + GetNativeNodeOriginNames *template.Template EnsureHashVersionColumn *template.Template GetHashVersion *template.Template MarkAllLeavesDirty *template.Template @@ -1571,6 +1572,15 @@ var SQLTemplates = Templates{ GetReplicationOriginNames: template.Must(template.New("getReplicationOriginNames").Parse(` SELECT roident::text, roname FROM pg_replication_origin; `)), + // GetNativeNodeOriginNames maps pg_replication_origin entries to their + // corresponding pg_subscription names. This provides the native PG + // equivalent of GetSpockNodeNames — mapping origin IDs (used by + // pg_xact_commit_timestamp_origin) to human-readable node identifiers. 
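+	// Native PG names each subscription's origin 'pg_' || suboid, so a
+	// subscription with oid 16384 and subname 'sub_n1' would produce a row
+	// like ('1', 'sub_n1'); the roident value is illustrative.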
+ GetNativeNodeOriginNames: template.Must(template.New("getNativeNodeOriginNames").Parse(` + SELECT ro.roident::text, s.subname + FROM pg_catalog.pg_replication_origin ro + JOIN pg_catalog.pg_subscription s ON ro.roname = 'pg_' || s.oid::text + `)), GetReplicationOriginByName: template.Must(template.New("getReplicationOriginByName").Parse(` SELECT roident FROM pg_replication_origin WHERE roname = $1 `)), diff --git a/internal/consistency/diff/table_diff.go b/internal/consistency/diff/table_diff.go index c1479f9..bd93add 100644 --- a/internal/consistency/diff/table_diff.go +++ b/internal/consistency/diff/table_diff.go @@ -266,7 +266,7 @@ func (t *TableDiffTask) buildEffectiveFilter() (string, error) { if err != nil { return "", fmt.Errorf("resolved against-origin %q is not a valid numeric node ID", t.resolvedAgainstOrigin) } - parts = append(parts, fmt.Sprintf("(to_json(spock.xact_commit_timestamp_origin(xmin))->>'roident' = '%d')", nodeID)) + parts = append(parts, fmt.Sprintf("(to_json(pg_xact_commit_timestamp_origin(xmin))->>'roident' = '%d')", nodeID)) } if f := queries.CommitTimestampFilter(t.untilTime); f != "" { diff --git a/internal/consistency/diff/table_diff_origin_test.go b/internal/consistency/diff/table_diff_origin_test.go new file mode 100644 index 0000000..bafd7aa --- /dev/null +++ b/internal/consistency/diff/table_diff_origin_test.go @@ -0,0 +1,135 @@ +// /////////////////////////////////////////////////////////////////////////// +// +// # ACE - Active Consistency Engine +// +// Copyright (C) 2023 - 2026, pgEdge (https://www.pgedge.com/) +// +// This software is released under the PostgreSQL License: +// https://opensource.org/license/postgresql +// +// /////////////////////////////////////////////////////////////////////////// + +package diff + +import ( + "strings" + "testing" +) + +func TestResolveAgainstOrigin_EmptyInput(t *testing.T) { + task := &TableDiffTask{AgainstOrigin: ""} + if err := task.resolveAgainstOrigin(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if task.resolvedAgainstOrigin != "" { + t.Fatalf("expected empty resolvedAgainstOrigin, got %q", task.resolvedAgainstOrigin) + } +} + +func TestResolveAgainstOrigin_NoNodeOriginNames(t *testing.T) { + task := &TableDiffTask{AgainstOrigin: "n1"} + err := task.resolveAgainstOrigin() + if err == nil { + t.Fatal("expected error when NodeOriginNames is empty") + } + if !strings.Contains(err.Error(), "no node origin names available") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestResolveAgainstOrigin_MatchByID(t *testing.T) { + task := &TableDiffTask{ + AgainstOrigin: "3", + NodeOriginNames: map[string]string{"3": "n1", "4": "n2"}, + } + if err := task.resolveAgainstOrigin(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if task.resolvedAgainstOrigin != "3" { + t.Fatalf("expected resolvedAgainstOrigin=3, got %q", task.resolvedAgainstOrigin) + } +} + +func TestResolveAgainstOrigin_MatchByName_SpockNodeName(t *testing.T) { + task := &TableDiffTask{ + AgainstOrigin: "n1", + NodeOriginNames: map[string]string{"3": "n1", "4": "n2"}, + } + if err := task.resolveAgainstOrigin(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if task.resolvedAgainstOrigin != "3" { + t.Fatalf("expected resolvedAgainstOrigin=3, got %q", task.resolvedAgainstOrigin) + } +} + +func TestResolveAgainstOrigin_MatchByName_SubscriptionName(t *testing.T) { + // Native PG: NodeOriginNames maps roident -> subscription name + task := &TableDiffTask{ + AgainstOrigin: "sub_n1_to_n2", + NodeOriginNames: 
map[string]string{"5": "sub_n1_to_n2", "6": "sub_n3_to_n2"}, + } + if err := task.resolveAgainstOrigin(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if task.resolvedAgainstOrigin != "5" { + t.Fatalf("expected resolvedAgainstOrigin=5, got %q", task.resolvedAgainstOrigin) + } +} + +func TestResolveAgainstOrigin_NoMatch(t *testing.T) { + task := &TableDiffTask{ + AgainstOrigin: "nonexistent", + NodeOriginNames: map[string]string{"3": "n1", "4": "n2"}, + } + err := task.resolveAgainstOrigin() + if err == nil { + t.Fatal("expected error for unresolvable origin") + } + if !strings.Contains(err.Error(), "nonexistent") { + t.Fatalf("error should mention the unresolved name: %v", err) + } +} + +func TestBuildEffectiveFilter_AgainstOrigin(t *testing.T) { + task := &TableDiffTask{ + resolvedAgainstOrigin: "3", + } + filter, err := task.buildEffectiveFilter() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(filter, "pg_xact_commit_timestamp_origin") { + t.Fatalf("expected pg_xact_commit_timestamp_origin in filter, got: %s", filter) + } + if !strings.Contains(filter, "'3'") { + t.Fatalf("expected roident=3 in filter, got: %s", filter) + } + if strings.Contains(filter, "spock") { + t.Fatalf("filter should not reference spock: %s", filter) + } +} + +func TestBuildEffectiveFilter_NonNumericOrigin(t *testing.T) { + task := &TableDiffTask{ + resolvedAgainstOrigin: "not_a_number", + } + _, err := task.buildEffectiveFilter() + if err == nil { + t.Fatal("expected error for non-numeric resolved origin") + } + if !strings.Contains(err.Error(), "not a valid numeric node ID") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestBuildEffectiveFilter_Empty(t *testing.T) { + task := &TableDiffTask{} + filter, err := task.buildEffectiveFilter() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filter != "" { + t.Fatalf("expected empty filter, got: %s", filter) + } +} diff --git a/internal/consistency/repair/table_repair.go b/internal/consistency/repair/table_repair.go index f908dc9..635514f 100644 --- a/internal/consistency/repair/table_repair.go +++ b/internal/consistency/repair/table_repair.go @@ -1753,7 +1753,8 @@ func (t *TableRepairTask) runUnidirectionalRepair(startTime time.Time) error { } logger.Debug("Transaction committed successfully on %s", nodeName) } else if !t.PreserveOrigin || len(t.extractOriginInfoForNode(nodeName, fullUpserts[nodeName])) == 0 { - // Non-spock path: still need to commit the transaction + // Non-spock path: commit unless the preserve-origin branch (line ~1692) + // already committed this tx and ran upserts in separate transactions. 
 			err = tx.Commit(t.Ctx)
 			if err != nil {
 				logger.Error("committing transaction on node %s: %v", nodeName, err)
diff --git a/tests/integration/docker-compose-native.yaml b/tests/integration/docker-compose-native.yaml
index 030955f..be60f2b 100644
--- a/tests/integration/docker-compose-native.yaml
+++ b/tests/integration/docker-compose-native.yaml
@@ -26,8 +26,7 @@ services:
       - "-c"
       - "wal_level=logical"
     ports:
-      - target: 5432
-        published: 7432
+      - "5432"
   native-n2:
     image: postgres:17
     environment:
@@ -41,5 +40,4 @@
       - "-c"
       - "wal_level=logical"
     ports:
-      - target: 5432
-        published: 7433
+      - "5432"
diff --git a/tests/integration/native_pg_test.go b/tests/integration/native_pg_test.go
index a0aacfd..abf93ed 100644
--- a/tests/integration/native_pg_test.go
+++ b/tests/integration/native_pg_test.go
@@ -240,6 +240,65 @@ func TestNativePG(t *testing.T) {
 		assert.False(t, installed, "spock should not be installed on vanilla PG")
 	})
 
+	t.Run("GetNodeOriginNames_NativeSubscription", func(t *testing.T) {
+		// Set up a real publication on n1 and subscription on n2 so that
+		// pg_replication_origin gets populated with a subscription-linked entry.
+		subName := "test_origin_sub"
+		pubName := "test_origin_pub"
+
+		// Create a test table and publication on n1.
+		_, err := state.n1Pool.Exec(ctx,
+			"CREATE TABLE IF NOT EXISTS public.origin_test (id int PRIMARY KEY, val text)")
+		require.NoError(t, err)
+		_, err = state.n1Pool.Exec(ctx,
+			fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE public.origin_test", pubName))
+		require.NoError(t, err)
+
+		// Create the same table on n2 (subscription target).
+		_, err = state.n2Pool.Exec(ctx,
+			"CREATE TABLE IF NOT EXISTS public.origin_test (id int PRIMARY KEY, val text)")
+		require.NoError(t, err)
+
+		// Create a subscription on n2 pointing at n1 via Docker-internal hostname.
+		connStr := fmt.Sprintf(
+			"host=%s port=5432 dbname=%s user=%s password=%s",
+			nativeServiceN1, nativeDBName, nativeUser, nativePassword)
+		_, err = state.n2Pool.Exec(ctx,
+			fmt.Sprintf("CREATE SUBSCRIPTION %s CONNECTION '%s' PUBLICATION %s",
+				subName, connStr, pubName))
+		require.NoError(t, err)
+
+		t.Cleanup(func() {
+			state.n2Pool.Exec(ctx, fmt.Sprintf("DROP SUBSCRIPTION IF EXISTS %s", subName))
+			state.n1Pool.Exec(ctx, fmt.Sprintf("DROP PUBLICATION IF EXISTS %s", pubName))
+			state.n1Pool.Exec(ctx, "DROP TABLE IF EXISTS public.origin_test")
+			state.n2Pool.Exec(ctx, "DROP TABLE IF EXISTS public.origin_test")
+		})
+
+		// GetNodeOriginNames on n2 should route to GetNativeNodeOriginNames
+		// (spock not installed) and return the subscription name as the value.
+		names, err := queries.GetNodeOriginNames(ctx, state.n2Pool)
+		require.NoError(t, err, "GetNodeOriginNames should succeed on native PG with subscriptions")
+		require.NotEmpty(t, names, "should have at least one origin mapping")
+
+		// Verify that the subscription name appears as a value in the map.
+		found := false
+		for _, name := range names {
+			if name == subName {
+				found = true
+				break
+			}
+		}
+		assert.True(t, found,
+			"expected subscription name %q in origin names map, got: %v", subName, names)
+
+		// The key should be a numeric roident (parseable as int).
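+		// (roident comes from pg_replication_origin.roident, PostgreSQL's
+		// 16-bit RepOriginId, so every key should parse as a small integer.)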
+		for id := range names {
+			_, parseErr := fmt.Sscanf(id, "%d", new(int))
+			assert.NoError(t, parseErr, "origin ID %q should be numeric", id)
+		}
+	})
+
 	t.Run("SpockDiff_GracefulError", func(t *testing.T) {
 		task := diff.NewSpockDiffTask()
 		task.ClusterName = nativeClusterName
@@ -381,4 +440,190 @@
 	t.Run("MerkleTreeNumericScaleInvariance", func(t *testing.T) {
 		testMerkleTreeNumericScaleInvariance(t, env)
 	})
+
+	// ── Native PG preserve-origin test ───────────────────────────────────
+	// This test verifies the full diff → preserve-origin repair → verify
+	// cycle on native PG with real logical replication, including that
+	// GetNodeOriginNames returns subscription names and that repaired rows
+	// retain their original replication origin.
+	t.Run("TableRepair_PreserveOrigin_NativePG", func(t *testing.T) {
+		testNativePreserveOrigin(t, state, env)
+	})
 }
+
+// getNativeReplicationOrigin retrieves the replication origin for a row on
+// native PG (no spock). Uses pg_xact_commit_timestamp_origin to get the
+// roident, then resolves it via pg_replication_origin.roname.
+func getNativeReplicationOrigin(t *testing.T, ctx context.Context, pool *pgxpool.Pool, qualifiedTableName string, id int) string {
+	t.Helper()
+
+	var roidentStr *string
+	query := fmt.Sprintf(
+		`SELECT (pg_xact_commit_timestamp_origin(xmin)).roident::text FROM %s WHERE id = $1`,
+		qualifiedTableName)
+	err := pool.QueryRow(ctx, query, id).Scan(&roidentStr)
+	if err != nil || roidentStr == nil || *roidentStr == "" || *roidentStr == "0" {
+		return ""
+	}
+
+	var originName string
+	err = pool.QueryRow(ctx,
+		"SELECT roname FROM pg_replication_origin WHERE roident::text = $1", *roidentStr).Scan(&originName)
+	if err == nil && originName != "" {
+		return originName
+	}
+
+	return *roidentStr
+}
+
+// testNativePreserveOrigin verifies origin tracking on native PG with real
+// logical replication:
+// 1. Set up logical replication (publication on n1, subscription on n2)
+// 2. Insert data on n1, wait for streaming replication to n2
+// 3. Verify GetNodeOriginNames maps roident → subscription name
+// 4. Verify replicated rows on n2 have origin tracked via pg_xact_commit_timestamp_origin
+// 5. Verify the origin resolves to the subscription's pg_replication_origin entry
+// 6. Delete rows on n2, run diff + repair, verify rows restored
+//
+// Note: preserve-origin cannot fully restore origins in a 2-node setup because
+// the source-of-truth node (n1) has origin="local" for rows it wrote. A 3-node
+// setup (like the spock PreserveOrigin test) is needed for full origin preservation.
+func testNativePreserveOrigin(t *testing.T, state *nativeClusterState, env *testEnv) {
+	ctx := context.Background()
+	tableName := "native_preserve_origin_test"
+	qualifiedTableName := fmt.Sprintf("public.%s", tableName)
+	subName := "preserve_origin_sub"
+	pubName := "preserve_origin_pub"
+
+	// Create table on both nodes.
+	for _, pool := range []*pgxpool.Pool{state.n1Pool, state.n2Pool} {
+		_, err := pool.Exec(ctx, fmt.Sprintf(
+			"CREATE TABLE IF NOT EXISTS %s (id INT PRIMARY KEY, data TEXT)", qualifiedTableName))
+		require.NoError(t, err)
+	}
+
+	// Create publication on n1 and subscription on n2.
+	// Use copy_data=false so the initial table sync doesn't use a transient
+	// replication origin (which PG deletes after sync, leaving rows with a
+	// defunct roident). With copy_data=false, data inserted after the
+	// subscription starts streaming uses the subscription's main origin.
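+	// (For reference, this leans on PG's origin-naming convention: the
+	// subscription's main apply origin is named "pg_<suboid>", while
+	// transient tablesync origins are "pg_<suboid>_<relid>"; with
+	// copy_data=false only the main origin ever stamps these rows.)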
+	_, err := state.n1Pool.Exec(ctx, fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s", pubName, qualifiedTableName))
+	require.NoError(t, err)
+
+	connStr := fmt.Sprintf("host=%s port=5432 dbname=%s user=%s password=%s",
+		nativeServiceN1, nativeDBName, nativeUser, nativePassword)
+	_, err = state.n2Pool.Exec(ctx, fmt.Sprintf(
+		"CREATE SUBSCRIPTION %s CONNECTION '%s' PUBLICATION %s WITH (copy_data = false)",
+		subName, connStr, pubName))
+	require.NoError(t, err)
+
+	t.Cleanup(func() {
+		state.n2Pool.Exec(ctx, fmt.Sprintf("DROP SUBSCRIPTION IF EXISTS %s", subName))
+		state.n1Pool.Exec(ctx, fmt.Sprintf("DROP PUBLICATION IF EXISTS %s", pubName))
+		for _, pool := range []*pgxpool.Pool{state.n1Pool, state.n2Pool} {
+			pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s CASCADE", qualifiedTableName))
+		}
+	})
+
+	// Brief pause for the subscription to connect and start streaming.
+	time.Sleep(2 * time.Second)
+
+	// Insert data on n1 AFTER the subscription is streaming.
+	sampleIDs := []int{1, 2, 3, 4, 5}
+	for _, id := range sampleIDs {
+		_, err := state.n1Pool.Exec(ctx,
+			fmt.Sprintf("INSERT INTO %s (id, data) VALUES ($1, $2)", qualifiedTableName),
+			id, fmt.Sprintf("row_%d", id))
+		require.NoError(t, err)
+	}
+
+	assertEventually(t, 30*time.Second, func() error {
+		var count int
+		if err := state.n2Pool.QueryRow(ctx,
+			fmt.Sprintf("SELECT count(*) FROM %s", qualifiedTableName)).Scan(&count); err != nil {
+			return err
+		}
+		if count < len(sampleIDs) {
+			return fmt.Errorf("expected %d rows on n2, got %d", len(sampleIDs), count)
+		}
+		return nil
+	})
+	log.Println("Replication complete: all rows present on n2 (via streaming)")
+
+	// --- Verify GetNodeOriginNames maps roident → subscription name ---
+	names, err := queries.GetNodeOriginNames(ctx, state.n2Pool)
+	require.NoError(t, err)
+	found := false
+	var subRoident string
+	for id, name := range names {
+		if name == subName {
+			found = true
+			subRoident = id
+			break
+		}
+	}
+	require.True(t, found, "GetNodeOriginNames should contain subscription %q, got: %v", subName, names)
+	log.Printf("GetNodeOriginNames on n2: %v (subscription roident=%s)", names, subRoident)
+
+	// --- Verify replicated rows on n2 have non-local origin ---
+	for _, id := range sampleIDs {
+		origin := getNativeReplicationOrigin(t, ctx, state.n2Pool, qualifiedTableName, id)
+		require.NotEmpty(t, origin,
+			"Row %d on n2 should have a replication origin (was replicated from n1)", id)
+		log.Printf("Row %d on n2: origin=%s", id, origin)
+	}
+
+	// --- Verify rows on n1 have "local" origin ---
+	for _, id := range sampleIDs {
+		origin := getNativeReplicationOrigin(t, ctx, state.n1Pool, qualifiedTableName, id)
+		assert.Empty(t, origin,
+			"Row %d on n1 should have local origin (roident=0), got %q", id, origin)
+	}
+	log.Println("Origin tracking verified: n2 rows have subscription origin, n1 rows are local")
+
+	// --- Simulate data loss on n2 and verify basic repair works ---
+	log.Println("Simulating data loss on n2...")
+	tx, err := state.n2Pool.Begin(ctx)
+	require.NoError(t, err)
+	// SET LOCAL keeps the replica role scoped to this transaction so the
+	// pooled session isn't left in replica mode after commit.
+	_, err = tx.Exec(ctx, "SET LOCAL session_replication_role = 'replica'")
+	require.NoError(t, err)
+	for _, id := range sampleIDs {
+		_, err = tx.Exec(ctx, fmt.Sprintf("DELETE FROM %s WHERE id = $1", qualifiedTableName), id)
+		require.NoError(t, err)
+	}
+	require.NoError(t, tx.Commit(ctx))
+
+	// Run table-diff.
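+	// (table-diff compares the two nodes row by row and writes a diff file;
+	// that file drives the repair below, with n1 as the source of truth.)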
+	diffTask := env.newTableDiffTask(t, qualifiedTableName, []string{nativeServiceN1, nativeServiceN2})
+	require.NoError(t, diffTask.RunChecks(false))
+	require.NoError(t, diffTask.ExecuteTask())
+	diffFile := getLatestDiffFile(t)
+	require.NotEmpty(t, diffFile)
+
+	// Run repair (recovery mode).
+	repairTask := env.newTableRepairTask(nativeServiceN1, qualifiedTableName, diffFile)
+	repairTask.RecoveryMode = true
+
+	err = repairTask.Run(false)
+	require.NoError(t, err)
+	if repairTask.TaskStatus == "FAILED" {
+		t.Fatalf("Repair failed: %s", repairTask.TaskContext)
+	}
+
+	// Verify all rows are restored.
+	var count int
+	err = state.n2Pool.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s", qualifiedTableName)).Scan(&count)
+	require.NoError(t, err)
+	require.Equal(t, len(sampleIDs), count, "All rows should be restored after repair")
+
+	// Verify row content matches.
+	for _, id := range sampleIDs {
+		var data string
+		err := state.n2Pool.QueryRow(ctx,
+			fmt.Sprintf("SELECT data FROM %s WHERE id = $1", qualifiedTableName), id).Scan(&data)
+		require.NoError(t, err)
+		assert.Equal(t, fmt.Sprintf("row_%d", id), data, "Row %d data mismatch", id)
+	}
+	log.Println("Native PG origin tracking and repair verified successfully")
+}