Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,8 @@ jobs:
- name: Run catastrophic single-node failure recovery test
run: go test -count=1 -v ./tests/integration -run 'TestCatastrophicSingleNodeFailure'

- name: Run native PG (no spock) tests
run: go test -count=1 -v ./tests/integration -run 'TestNativePG'

- name: Run timestamp comparison tests
run: go test -count=1 -v ./tests/integration -run 'TestCompareTimestampsExact|TestPostgreSQLMicrosecondPrecision|TestOldVsNewComparison'
110 changes: 110 additions & 0 deletions db/queries/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,116 @@ func GetSpockSlotLSNForNode(ctx context.Context, db DBQuerier, failedNode string
return lsn, nil
}

func GetNativeOriginLSNForNode(ctx context.Context, db DBQuerier, originNodeName string) (*string, error) {
sql, err := RenderSQL(SQLTemplates.GetNativeOriginLSNForNode, nil)
if err != nil {
return nil, err
}
var lsn *string
if err := db.QueryRow(ctx, sql, originNodeName).Scan(&lsn); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("failed to fetch native origin lsn: %w", err)
}
return lsn, nil
}

func GetNativeSlotLSNForNode(ctx context.Context, db DBQuerier, failedNode string) (*string, error) {
sql, err := RenderSQL(SQLTemplates.GetNativeSlotLSNForNode, nil)
if err != nil {
return nil, err
}
var lsn *string
if err := db.QueryRow(ctx, sql, failedNode).Scan(&lsn); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("failed to fetch native slot lsn: %w", err)
}
return lsn, nil
}

func GetReplicationOriginNames(ctx context.Context, db DBQuerier) (map[string]string, error) {
sql, err := RenderSQL(SQLTemplates.GetReplicationOriginNames, nil)
if err != nil {
return nil, err
}

rows, err := db.Query(ctx, sql)
if err != nil {
return nil, err
}
defer rows.Close()

names := make(map[string]string)
for rows.Next() {
var id, name string
if err := rows.Scan(&id, &name); err != nil {
return nil, err
}
names[id] = name
}

if err := rows.Err(); err != nil {
return nil, err
}

return names, nil
}

// GetNativeNodeOriginNames maps replication origin IDs to subscription names
// for native PG logical replication (no spock). This is the native PG
// equivalent of GetSpockNodeNames.
func GetNativeNodeOriginNames(ctx context.Context, db DBQuerier) (map[string]string, error) {
sql, err := RenderSQL(SQLTemplates.GetNativeNodeOriginNames, nil)
if err != nil {
return nil, err
}

rows, err := db.Query(ctx, sql)
if err != nil {
return nil, err
}
defer rows.Close()

names := make(map[string]string)
for rows.Next() {
var id, name string
if err := rows.Scan(&id, &name); err != nil {
return nil, err
}
names[id] = name
}

if err := rows.Err(); err != nil {
return nil, err
}

return names, nil
}

func GetNodeOriginNames(ctx context.Context, db DBQuerier) (map[string]string, error) {
var spockAvailable bool
err := db.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'spock')").Scan(&spockAvailable)
if err != nil {
return nil, fmt.Errorf("detecting spock extension: %w", err)
}
if spockAvailable {
return GetSpockNodeNames(ctx, db)
}
return GetNativeNodeOriginNames(ctx, db)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

func CheckSpockInstalled(ctx context.Context, db DBQuerier) (bool, error) {
var exists bool
err := db.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'spock')").Scan(&exists)
if err != nil {
return false, fmt.Errorf("detecting spock extension: %w", err)
}
return exists, nil
}

func GetSpockRepSetInfo(ctx context.Context, db DBQuerier) ([]types.SpockRepSetInfo, error) {
sql, err := RenderSQL(SQLTemplates.SpockRepSetInfo, nil)
if err != nil {
Expand Down
122 changes: 88 additions & 34 deletions db/queries/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,18 @@

package queries

import "text/template"
import (
"text/template"

"github.com/jackc/pgx/v5"
"github.com/pgedge/ace/pkg/config"
)

// aceTemplateFuncs provides the {{aceSchema}} function to SQL templates.
// The function is evaluated at render time (after config is loaded), not at parse time.
var aceTemplateFuncs = template.FuncMap{
"aceSchema": func() string { return pgx.Identifier{config.Get().MTree.Schema}.Sanitize() },
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

type Templates struct {
EstimateRowCount *template.Template
Expand Down Expand Up @@ -78,6 +89,7 @@ type Templates struct {
GetBlockCountSimple *template.Template
GetBlockSizeFromMetadata *template.Template
GetMaxNodeLevel *template.Template
CompareBlocksSQL *template.Template

DropXORFunction *template.Template
DropMetadataTable *template.Template
Expand Down Expand Up @@ -119,6 +131,10 @@ type Templates struct {
RemoveTableFromCDCMetadata *template.Template
GetSpockOriginLSNForNode *template.Template
GetSpockSlotLSNForNode *template.Template
GetNativeOriginLSNForNode *template.Template
GetNativeSlotLSNForNode *template.Template
GetReplicationOriginNames *template.Template
GetNativeNodeOriginNames *template.Template
EnsureHashVersionColumn *template.Template
GetHashVersion *template.Template
MarkAllLeavesDirty *template.Template
Expand All @@ -133,8 +149,8 @@ type Templates struct {

var SQLTemplates = Templates{
// A template isn't needed for this query; just keeping the struct uniform
CreateMetadataTable: template.Must(template.New("createMetadataTable").Parse(`
CREATE TABLE IF NOT EXISTS spock.ace_mtree_metadata (
CreateMetadataTable: template.Must(template.New("createMetadataTable").Funcs(aceTemplateFuncs).Parse(`
CREATE TABLE IF NOT EXISTS {{aceSchema}}.ace_mtree_metadata (
schema_name text,
table_name text,
total_rows bigint,
Expand All @@ -161,8 +177,8 @@ var SQLTemplates = Templates{
ALTER PUBLICATION {{.PublicationName}} DROP TABLE {{.TableName}}
`)),

RemoveTableFromCDCMetadata: template.Must(template.New("removeTableFromCDCMetadata").Parse(`
UPDATE spock.ace_cdc_metadata
RemoveTableFromCDCMetadata: template.Must(template.New("removeTableFromCDCMetadata").Funcs(aceTemplateFuncs).Parse(`
UPDATE {{aceSchema}}.ace_cdc_metadata
SET tables = array_remove(tables, $1)
WHERE publication_name = $2
`)),
Expand All @@ -179,9 +195,9 @@ var SQLTemplates = Templates{
)
`)),

UpdateCDCMetadata: template.Must(template.New("updateCdcMetadata").Parse(`
UpdateCDCMetadata: template.Must(template.New("updateCdcMetadata").Funcs(aceTemplateFuncs).Parse(`
INSERT INTO
spock.ace_cdc_metadata (
{{aceSchema}}.ace_cdc_metadata (
publication_name,
slot_name,
start_lsn,
Expand Down Expand Up @@ -220,17 +236,17 @@ var SQLTemplates = Templates{
CheckPIDExists: template.Must(template.New("checkPIDExists").Parse(`
SELECT pid FROM pg_stat_activity WHERE pid = $1
`)),
DropCDCMetadataTable: template.Must(template.New("dropCDCMetadataTable").Parse(`
DROP TABLE IF EXISTS spock.ace_cdc_metadata
DropCDCMetadataTable: template.Must(template.New("dropCDCMetadataTable").Funcs(aceTemplateFuncs).Parse(`
DROP TABLE IF EXISTS {{aceSchema}}.ace_cdc_metadata
`)),

GetCDCMetadata: template.Must(template.New("getCDCMetadata").Parse(`
GetCDCMetadata: template.Must(template.New("getCDCMetadata").Funcs(aceTemplateFuncs).Parse(`
SELECT
slot_name,
start_lsn,
tables
FROM
spock.ace_cdc_metadata
{{aceSchema}}.ace_cdc_metadata
WHERE
publication_name = $1
`)),
Expand Down Expand Up @@ -317,8 +333,8 @@ var SQLTemplates = Templates{
AND mt.node_position = b.node_position;
`)),

CreateCDCMetadataTable: template.Must(template.New("createCDCMetadataTable").Parse(`
CREATE TABLE IF NOT EXISTS spock.ace_cdc_metadata (
CreateCDCMetadataTable: template.Must(template.New("createCDCMetadataTable").Funcs(aceTemplateFuncs).Parse(`
CREATE TABLE IF NOT EXISTS {{aceSchema}}.ace_cdc_metadata (
publication_name text PRIMARY KEY,
slot_name text,
start_lsn text,
Expand Down Expand Up @@ -772,9 +788,9 @@ var SQLTemplates = Templates{
VALUES
(0, $1, {{.StartExpr}}, {{.EndExpr}});
`)),
CreateXORFunction: template.Must(template.New("createXORFunction").Parse(`
CreateXORFunction: template.Must(template.New("createXORFunction").Funcs(aceTemplateFuncs).Parse(`
CREATE
OR REPLACE FUNCTION spock.bytea_xor(a bytea, b bytea) RETURNS bytea AS $$
OR REPLACE FUNCTION {{aceSchema}}.bytea_xor(a bytea, b bytea) RETURNS bytea AS $$
DECLARE
result bytea;
len int;
Expand Down Expand Up @@ -805,7 +821,7 @@ var SQLTemplates = Templates{
CREATE OPERATOR # (
LEFTARG = bytea,
RIGHTARG = bytea,
PROCEDURE = spock.bytea_xor
PROCEDURE = {{aceSchema}}.bytea_xor
);
END IF;
END $$;
Expand Down Expand Up @@ -840,9 +856,9 @@ var SQLTemplates = Templates{
AND c.relname = $2
AND a.attname = $3
`)),
UpdateMetadata: template.Must(template.New("updateMetadata").Parse(`
UpdateMetadata: template.Must(template.New("updateMetadata").Funcs(aceTemplateFuncs).Parse(`
INSERT INTO
spock.ace_mtree_metadata (
{{aceSchema}}.ace_mtree_metadata (
schema_name,
table_name,
total_rows,
Expand Down Expand Up @@ -873,8 +889,8 @@ var SQLTemplates = Templates{
hash_version = EXCLUDED.hash_version,
last_updated = EXCLUDED.last_updated
`)),
DeleteMetadata: template.Must(template.New("deleteMetadata").Parse(`
DELETE FROM spock.ace_mtree_metadata WHERE schema_name = $1 AND table_name = $2
DeleteMetadata: template.Must(template.New("deleteMetadata").Funcs(aceTemplateFuncs).Parse(`
DELETE FROM {{aceSchema}}.ace_mtree_metadata WHERE schema_name = $1 AND table_name = $2
`)),
InsertBlockRanges: template.Must(template.New("insertBlockRanges").Parse(`
INSERT INTO
Expand Down Expand Up @@ -1055,11 +1071,11 @@ var SQLTemplates = Templates{
ORDER BY
node_position
`)),
GetRowCountEstimate: template.Must(template.New("getRowCountEstimate").Parse(`
GetRowCountEstimate: template.Must(template.New("getRowCountEstimate").Funcs(aceTemplateFuncs).Parse(`
SELECT
total_rows
FROM
spock.ace_mtree_metadata
{{aceSchema}}.ace_mtree_metadata
WHERE
schema_name = $1
AND table_name = $2
Expand Down Expand Up @@ -1294,11 +1310,11 @@ var SQLTemplates = Templates{
mt.range_start,
mt.range_end
`)),
GetBlockSizeFromMetadata: template.Must(template.New("getBlockSizeFromMetadata").Parse(`
GetBlockSizeFromMetadata: template.Must(template.New("getBlockSizeFromMetadata").Funcs(aceTemplateFuncs).Parse(`
SELECT
block_size
FROM
spock.ace_mtree_metadata
{{aceSchema}}.ace_mtree_metadata
WHERE
schema_name = $1
AND table_name = $2
Expand All @@ -1309,11 +1325,19 @@ var SQLTemplates = Templates{
FROM
{{.MtreeTable}}
`)),
DropXORFunction: template.Must(template.New("dropXORFunction").Parse(`
DROP FUNCTION IF EXISTS spock.bytea_xor(bytea, bytea) CASCADE
CompareBlocksSQL: template.Must(template.New("compareBlocksSQL").Parse(`
SELECT
*
FROM
{{.TableName}}
WHERE
{{.WhereClause}}
`)),
Comment thread
coderabbitai[bot] marked this conversation as resolved.
DropXORFunction: template.Must(template.New("dropXORFunction").Funcs(aceTemplateFuncs).Parse(`
DROP FUNCTION IF EXISTS {{aceSchema}}.bytea_xor(bytea, bytea) CASCADE
`)),
DropMetadataTable: template.Must(template.New("dropMetadataTable").Parse(`
DROP TABLE IF EXISTS spock.ace_mtree_metadata CASCADE
DropMetadataTable: template.Must(template.New("dropMetadataTable").Funcs(aceTemplateFuncs).Parse(`
DROP TABLE IF EXISTS {{aceSchema}}.ace_mtree_metadata CASCADE
`)),
DropMtreeTable: template.Must(template.New("dropMtreeTable").Parse(`
DROP TABLE IF EXISTS {{.MtreeTable}} CASCADE
Expand Down Expand Up @@ -1506,13 +1530,13 @@ var SQLTemplates = Templates{
ORDER BY rs.confirmed_flush_lsn DESC
LIMIT 1
`)),
EnsureHashVersionColumn: template.Must(template.New("ensureHashVersionColumn").Parse(`
ALTER TABLE spock.ace_mtree_metadata
EnsureHashVersionColumn: template.Must(template.New("ensureHashVersionColumn").Funcs(aceTemplateFuncs).Parse(`
ALTER TABLE {{aceSchema}}.ace_mtree_metadata
ADD COLUMN IF NOT EXISTS hash_version int NOT NULL DEFAULT 1
`)),
GetHashVersion: template.Must(template.New("getHashVersion").Parse(`
GetHashVersion: template.Must(template.New("getHashVersion").Funcs(aceTemplateFuncs).Parse(`
SELECT COALESCE(
(SELECT hash_version FROM spock.ace_mtree_metadata
(SELECT hash_version FROM {{aceSchema}}.ace_mtree_metadata
WHERE schema_name = $1 AND table_name = $2),
1
)
Expand All @@ -1522,11 +1546,41 @@ var SQLTemplates = Templates{
SET dirty = true
WHERE node_level = 0
`)),
UpdateHashVersion: template.Must(template.New("updateHashVersion").Parse(`
UPDATE spock.ace_mtree_metadata
UpdateHashVersion: template.Must(template.New("updateHashVersion").Funcs(aceTemplateFuncs).Parse(`
UPDATE {{aceSchema}}.ace_mtree_metadata
SET hash_version = $1, last_updated = current_timestamp
WHERE schema_name = $2 AND table_name = $3
`)),
GetNativeOriginLSNForNode: template.Must(template.New("getNativeOriginLSNForNode").Parse(`
SELECT ros.remote_lsn::text
FROM pg_catalog.pg_replication_origin_status ros
JOIN pg_catalog.pg_replication_origin ro ON ro.roident = ros.local_id
JOIN pg_catalog.pg_subscription s ON ro.roname LIKE 'pg_%' || s.oid::text
WHERE s.subname ~ ('\m' || $1 || '\M')
AND ros.remote_lsn IS NOT NULL
Comment on lines +1559 to +1560
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

$1 (a user-supplied node name) is interpolated raw into a regex fragment. A node name containing |, ., or * silently
matches wrong subscriptions. The equivalent Spock templates use = $1 exact match. Fix: use = $1 here too.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

technically valid but low-risk in practice: node
names are simple identifiers (n1, postgres-n1). And the code already handles this gracefully at
table_repair.go:2940-2942

LIMIT 1
`)),
GetNativeSlotLSNForNode: template.Must(template.New("getNativeSlotLSNForNode").Parse(`
SELECT rs.confirmed_flush_lsn::text
FROM pg_catalog.pg_replication_slots rs
JOIN pg_catalog.pg_subscription s ON rs.slot_name = s.subslotname
WHERE s.subname ~ ('\m' || $1 || '\M')
AND rs.confirmed_flush_lsn IS NOT NULL
ORDER BY rs.confirmed_flush_lsn DESC
LIMIT 1
`)),
Comment thread
coderabbitai[bot] marked this conversation as resolved.
GetReplicationOriginNames: template.Must(template.New("getReplicationOriginNames").Parse(`
SELECT roident::text, roname FROM pg_replication_origin;
`)),
// GetNativeNodeOriginNames maps pg_replication_origin entries to their
// corresponding pg_subscription names. This provides the native PG
// equivalent of GetSpockNodeNames — mapping origin IDs (used by
// pg_xact_commit_timestamp_origin) to human-readable node identifiers.
GetNativeNodeOriginNames: template.Must(template.New("getNativeNodeOriginNames").Parse(`
SELECT ro.roident::text, s.subname
FROM pg_catalog.pg_replication_origin ro
JOIN pg_catalog.pg_subscription s ON ro.roname = 'pg_' || s.oid::text
`)),
GetReplicationOriginByName: template.Must(template.New("getReplicationOriginByName").Parse(`
SELECT roident FROM pg_replication_origin WHERE roname = $1
`)),
Expand Down
Loading
Loading