Skip to content

Commit 366a8db

Browse files
authored
Quote schema name in operations to support special characters and spaces (#1175)
Came up in #1170: we don't quote schemas in SQL operations, and that can lead to problems in use of spaces, special characters, and uppercase characters. It's not the end of the world given use of the above can be considered a sizable anti-pattern anyway, but use of quoting is good for general correctness. Fixes #1170.
1 parent d1a55b1 commit 366a8db

12 files changed

Lines changed: 145 additions & 22 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212
- `riverlog.Middleware` now supports `MiddlewareConfig.MaxTotalBytes` (default 8 MB) to cap total persisted `river:log` history per job. When the cap is exceeded, oldest log entries are dropped first while retaining the newest entry. Values over 64 MB are clamped to 64 MB. [PR #1157](https://github.com/riverqueue/river/pull/1157).
1313
- Improved `riverlog` performance and reduced memory amplification when appending to large persisted `river:log` histories. [PR #1157](https://github.com/riverqueue/river/pull/1157).
1414
- Reduced snooze-path memory amplification by setting `snoozes` in metadata updates before marshaling, avoiding an extra full-payload JSON rewrite. [PR #1159](https://github.com/riverqueue/river/pull/1159).
15+
- Schema names are now quoted in SQL operations, enabling the use of spaces and other odd characters. [PR #1175](https://github.com/riverqueue/river/pull/1175).
1516

1617
### Fixed
1718

riverdriver/riverdatabasesql/river_database_sql_driver.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/riverqueue/river/riverdriver/riverdatabasesql/internal/dbsqlc"
2525
"github.com/riverqueue/river/rivershared/sqlctemplate"
2626
"github.com/riverqueue/river/rivershared/uniquestates"
27+
"github.com/riverqueue/river/rivershared/util/dbutil"
2728
"github.com/riverqueue/river/rivershared/util/ptrutil"
2829
"github.com/riverqueue/river/rivershared/util/savepointutil"
2930
"github.com/riverqueue/river/rivershared/util/sliceutil"
@@ -143,7 +144,7 @@ func (e *Executor) Exec(ctx context.Context, sql string, args ...any) error {
143144
func (e *Executor) IndexDropIfExists(ctx context.Context, params *riverdriver.IndexDropIfExistsParams) error {
144145
var maybeSchema string
145146
if params.Schema != "" {
146-
maybeSchema = params.Schema + "."
147+
maybeSchema = dbutil.SafeIdentifier(params.Schema) + "."
147148
}
148149

149150
_, err := e.dbtx.ExecContext(ctx, "DROP INDEX CONCURRENTLY IF EXISTS "+maybeSchema+params.Index)
@@ -164,7 +165,7 @@ func (e *Executor) IndexExists(ctx context.Context, params *riverdriver.IndexExi
164165
func (e *Executor) IndexReindex(ctx context.Context, params *riverdriver.IndexReindexParams) error {
165166
var maybeSchema string
166167
if params.Schema != "" {
167-
maybeSchema = params.Schema + "."
168+
maybeSchema = dbutil.SafeIdentifier(params.Schema) + "."
168169
}
169170

170171
_, err := e.dbtx.ExecContext(ctx, "REINDEX INDEX CONCURRENTLY "+maybeSchema+params.Index)
@@ -972,12 +973,12 @@ func (e *Executor) QueryRow(ctx context.Context, sql string, args ...any) riverd
972973
}
973974

974975
func (e *Executor) SchemaCreate(ctx context.Context, params *riverdriver.SchemaCreateParams) error {
975-
_, err := e.dbtx.ExecContext(ctx, "CREATE SCHEMA "+params.Schema)
976+
_, err := e.dbtx.ExecContext(ctx, "CREATE SCHEMA "+dbutil.SafeIdentifier(params.Schema))
976977
return interpretError(err)
977978
}
978979

979980
func (e *Executor) SchemaDrop(ctx context.Context, params *riverdriver.SchemaDropParams) error {
980-
_, err := e.dbtx.ExecContext(ctx, "DROP SCHEMA "+params.Schema+" CASCADE")
981+
_, err := e.dbtx.ExecContext(ctx, "DROP SCHEMA "+dbutil.SafeIdentifier(params.Schema)+" CASCADE")
981982
return interpretError(err)
982983
}
983984

@@ -996,7 +997,7 @@ func (e *Executor) TableExists(ctx context.Context, params *riverdriver.TableExi
996997
// Different from other operations because the schemaAndTable name is a parameter.
997998
schemaAndTable := params.Table
998999
if params.Schema != "" {
999-
schemaAndTable = params.Schema + "." + schemaAndTable
1000+
schemaAndTable = dbutil.SafeIdentifier(params.Schema) + "." + schemaAndTable
10001001
}
10011002

10021003
exists, err := dbsqlc.New().TableExists(ctx, e.dbtx, schemaAndTable)
@@ -1006,7 +1007,7 @@ func (e *Executor) TableExists(ctx context.Context, params *riverdriver.TableExi
10061007
func (e *Executor) TableTruncate(ctx context.Context, params *riverdriver.TableTruncateParams) error {
10071008
var maybeSchema string
10081009
if params.Schema != "" {
1009-
maybeSchema = params.Schema + "."
1010+
maybeSchema = dbutil.SafeIdentifier(params.Schema) + "."
10101011
}
10111012

10121013
// Uses raw SQL so we can truncate multiple tables at once.
@@ -1240,7 +1241,7 @@ func queueFromInternal(internal *dbsqlc.RiverQueue) *rivertype.Queue {
12401241

12411242
func schemaTemplateParam(ctx context.Context, schema string) context.Context {
12421243
if schema != "" {
1243-
schema += "."
1244+
schema = dbutil.SafeIdentifier(schema) + "."
12441245
}
12451246

12461247
return sqlctemplate.WithReplacements(ctx, map[string]sqlctemplate.Replacement{

riverdriver/riverdatabasesql/river_database_sql_driver_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,6 @@ func TestSchemaTemplateParam(t *testing.T) {
113113
nil,
114114
)
115115
require.NoError(t, err)
116-
require.Equal(t, "SELECT 1 FROM custom_schema.river_job", updatedSQL)
116+
require.Equal(t, `SELECT 1 FROM "custom_schema".river_job`, updatedSQL)
117117
})
118118
}

riverdriver/riverdrivertest/riverdrivertest.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ func Exercise[TTx any](ctx context.Context, t *testing.T,
3636
exerciseSQLFragments(ctx, t, executorWithTx)
3737
exerciseExecutorTx(ctx, t, driverWithSchema, executorWithTx)
3838
exerciseSchemaIntrospection(ctx, t, driverWithSchema, executorWithTx)
39+
exerciseSchemaName(ctx, t, driverWithSchema)
3940
exerciseJobInsert(ctx, t, driverWithSchema, executorWithTx)
4041
exerciseJobRead(ctx, t, executorWithTx)
4142
exerciseJobUpdate(ctx, t, executorWithTx)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package riverdrivertest
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
9+
"github.com/riverqueue/river/riverdbtest"
10+
"github.com/riverqueue/river/riverdriver"
11+
"github.com/riverqueue/river/rivermigrate"
12+
"github.com/riverqueue/river/rivershared/riversharedtest"
13+
"github.com/riverqueue/river/rivershared/util/randutil"
14+
"github.com/riverqueue/river/rivertype"
15+
)
16+
17+
func exerciseSchemaName[TTx any](ctx context.Context, t *testing.T,
18+
driverWithSchema func(ctx context.Context, t *testing.T, opts *riverdbtest.TestSchemaOpts) (riverdriver.Driver[TTx], string),
19+
) {
20+
t.Helper()
21+
22+
t.Run("SchemaNameWithSpace", func(t *testing.T) {
23+
t.Parallel()
24+
25+
driver, _ := driverWithSchema(ctx, t, nil)
26+
27+
// In SQLite schemas are files assigned to particular names, so this
28+
// check isn't relevant in the same way.
29+
if driver.DatabaseName() != databaseNamePostgres {
30+
t.Skip("Skipping; schema names with spaces only relevant for Postgres")
31+
}
32+
33+
// Schemas should get cleaned up, but still need some randomness in case
34+
// multiple tests are running in parallel.
35+
schema := "river test schema " + randutil.Hex(8)
36+
37+
exec := driver.GetExecutor()
38+
39+
require.NoError(t, exec.SchemaCreate(ctx, &riverdriver.SchemaCreateParams{Schema: schema}))
40+
t.Cleanup(func() {
41+
require.NoError(t, exec.SchemaDrop(ctx, &riverdriver.SchemaDropParams{Schema: schema}))
42+
})
43+
44+
migrator, err := rivermigrate.New(driver, &rivermigrate.Config{
45+
Logger: riversharedtest.Logger(t),
46+
Schema: schema,
47+
})
48+
require.NoError(t, err)
49+
50+
_, err = migrator.Migrate(ctx, rivermigrate.DirectionUp, nil)
51+
require.NoError(t, err)
52+
53+
// Insert a job and verify it can be read back.
54+
insertedJobs, err := exec.JobInsertFastMany(ctx, &riverdriver.JobInsertFastManyParams{
55+
Jobs: []*riverdriver.JobInsertFastParams{{
56+
EncodedArgs: []byte(`{}`),
57+
Kind: "test_kind",
58+
MaxAttempts: 25,
59+
Metadata: []byte(`{}`),
60+
Priority: 1,
61+
Queue: "default",
62+
State: rivertype.JobStateAvailable,
63+
}},
64+
Schema: schema,
65+
})
66+
require.NoError(t, err)
67+
require.Len(t, insertedJobs, 1)
68+
69+
fetchedJob, err := exec.JobGetByID(ctx, &riverdriver.JobGetByIDParams{
70+
ID: insertedJobs[0].Job.ID,
71+
Schema: schema,
72+
})
73+
require.NoError(t, err)
74+
require.Equal(t, insertedJobs[0].Job.ID, fetchedJob.ID)
75+
require.Equal(t, "test_kind", fetchedJob.Kind)
76+
77+
// Verify the table exists in the schema.
78+
exists, err := exec.TableExists(ctx, &riverdriver.TableExistsParams{
79+
Schema: schema,
80+
Table: "river_job",
81+
})
82+
require.NoError(t, err)
83+
require.True(t, exists)
84+
85+
// Migrate back down.
86+
_, err = migrator.Migrate(ctx, rivermigrate.DirectionDown, &rivermigrate.MigrateOpts{
87+
TargetVersion: -1,
88+
})
89+
require.NoError(t, err)
90+
91+
// Verify tables are gone.
92+
exists, err = exec.TableExists(ctx, &riverdriver.TableExistsParams{
93+
Schema: schema,
94+
Table: "river_job",
95+
})
96+
require.NoError(t, err)
97+
require.False(t, exists)
98+
})
99+
}

riverdriver/riverpgxv5/river_pgx_v5_driver.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"github.com/riverqueue/river/riverdriver/riverpgxv5/internal/dbsqlc"
2929
"github.com/riverqueue/river/rivershared/sqlctemplate"
3030
"github.com/riverqueue/river/rivershared/uniquestates"
31+
"github.com/riverqueue/river/rivershared/util/dbutil"
3132
"github.com/riverqueue/river/rivershared/util/ptrutil"
3233
"github.com/riverqueue/river/rivershared/util/sliceutil"
3334
"github.com/riverqueue/river/rivertype"
@@ -151,7 +152,7 @@ func (e *Executor) Exec(ctx context.Context, sql string, args ...any) error {
151152
func (e *Executor) IndexDropIfExists(ctx context.Context, params *riverdriver.IndexDropIfExistsParams) error {
152153
var maybeSchema string
153154
if params.Schema != "" {
154-
maybeSchema = params.Schema + "."
155+
maybeSchema = dbutil.SafeIdentifier(params.Schema) + "."
155156
}
156157

157158
_, err := e.dbtx.Exec(ctx, "DROP INDEX CONCURRENTLY IF EXISTS "+maybeSchema+params.Index)
@@ -172,7 +173,7 @@ func (e *Executor) IndexExists(ctx context.Context, params *riverdriver.IndexExi
172173
func (e *Executor) IndexReindex(ctx context.Context, params *riverdriver.IndexReindexParams) error {
173174
var maybeSchema string
174175
if params.Schema != "" {
175-
maybeSchema = params.Schema + "."
176+
maybeSchema = dbutil.SafeIdentifier(params.Schema) + "."
176177
}
177178

178179
_, err := e.dbtx.Exec(ctx, "REINDEX INDEX CONCURRENTLY "+maybeSchema+params.Index)
@@ -957,12 +958,12 @@ func (e *Executor) QueryRow(ctx context.Context, sql string, args ...any) riverd
957958
}
958959

959960
func (e *Executor) SchemaCreate(ctx context.Context, params *riverdriver.SchemaCreateParams) error {
960-
_, err := e.dbtx.Exec(ctx, "CREATE SCHEMA "+params.Schema)
961+
_, err := e.dbtx.Exec(ctx, "CREATE SCHEMA "+dbutil.SafeIdentifier(params.Schema))
961962
return interpretError(err)
962963
}
963964

964965
func (e *Executor) SchemaDrop(ctx context.Context, params *riverdriver.SchemaDropParams) error {
965-
_, err := e.dbtx.Exec(ctx, "DROP SCHEMA "+params.Schema+" CASCADE")
966+
_, err := e.dbtx.Exec(ctx, "DROP SCHEMA "+dbutil.SafeIdentifier(params.Schema)+" CASCADE")
966967
return interpretError(err)
967968
}
968969

@@ -981,7 +982,7 @@ func (e *Executor) TableExists(ctx context.Context, params *riverdriver.TableExi
981982
// Different from other operations because the schemaAndTable name is a parameter.
982983
schemaAndTable := params.Table
983984
if params.Schema != "" {
984-
schemaAndTable = params.Schema + "." + schemaAndTable
985+
schemaAndTable = dbutil.SafeIdentifier(params.Schema) + "." + schemaAndTable
985986
}
986987

987988
exists, err := dbsqlc.New().TableExists(ctx, e.dbtx, schemaAndTable)
@@ -991,7 +992,7 @@ func (e *Executor) TableExists(ctx context.Context, params *riverdriver.TableExi
991992
func (e *Executor) TableTruncate(ctx context.Context, params *riverdriver.TableTruncateParams) error {
992993
var maybeSchema string
993994
if params.Schema != "" {
994-
maybeSchema = params.Schema + "."
995+
maybeSchema = dbutil.SafeIdentifier(params.Schema) + "."
995996
}
996997

997998
// Uses raw SQL so we can truncate multiple tables at once.
@@ -1303,7 +1304,7 @@ func schemaCopyFrom(ctx context.Context, schema string) context.Context {
13031304

13041305
func schemaTemplateParam(ctx context.Context, schema string) context.Context {
13051306
if schema != "" {
1306-
schema += "."
1307+
schema = dbutil.SafeIdentifier(schema) + "."
13071308
}
13081309

13091310
return sqlctemplate.WithReplacements(ctx, map[string]sqlctemplate.Replacement{

riverdriver/riverpgxv5/river_pgx_v5_driver_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ func TestSchemaTemplateParam(t *testing.T) {
236236
nil,
237237
)
238238
require.NoError(t, err)
239-
require.Equal(t, "SELECT 1 FROM custom_schema.river_job", updatedSQL)
239+
require.Equal(t, `SELECT 1 FROM "custom_schema".river_job`, updatedSQL)
240240
})
241241
}
242242

riverdriver/riversqlite/river_sqlite_driver.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ func (e *Executor) Exec(ctx context.Context, sql string, args ...any) error {
199199
func (e *Executor) IndexDropIfExists(ctx context.Context, params *riverdriver.IndexDropIfExistsParams) error {
200200
var maybeSchema string
201201
if params.Schema != "" {
202-
maybeSchema = params.Schema + "."
202+
maybeSchema = dbutil.SafeIdentifier(params.Schema) + "."
203203
}
204204

205205
_, err := e.dbtx.ExecContext(ctx, "DROP INDEX IF EXISTS "+maybeSchema+params.Index)
@@ -214,7 +214,7 @@ func (e *Executor) IndexExists(ctx context.Context, params *riverdriver.IndexExi
214214
func (e *Executor) IndexReindex(ctx context.Context, params *riverdriver.IndexReindexParams) error {
215215
var maybeSchema string
216216
if params.Schema != "" {
217-
maybeSchema = params.Schema + "."
217+
maybeSchema = dbutil.SafeIdentifier(params.Schema) + "."
218218
}
219219

220220
_, err := e.dbtx.ExecContext(ctx, "REINDEX "+maybeSchema+params.Index)
@@ -1452,7 +1452,7 @@ func (e *Executor) TableExists(ctx context.Context, params *riverdriver.TableExi
14521452
func (e *Executor) TableTruncate(ctx context.Context, params *riverdriver.TableTruncateParams) error {
14531453
var maybeSchema string
14541454
if params.Schema != "" {
1455-
maybeSchema = params.Schema + "."
1455+
maybeSchema = dbutil.SafeIdentifier(params.Schema) + "."
14561456
}
14571457

14581458
// SQLite doesn't have a `TRUNCATE` command, but `DELETE FROM` is optimized
@@ -1672,7 +1672,7 @@ func migrationFromInternal(internal *dbsqlc.RiverMigration) *riverdriver.Migrati
16721672

16731673
func schemaTemplateParam(ctx context.Context, schema string) context.Context {
16741674
if schema != "" {
1675-
schema += "."
1675+
schema = dbutil.SafeIdentifier(schema) + "."
16761676
}
16771677

16781678
return sqlctemplate.WithReplacements(ctx, map[string]sqlctemplate.Replacement{

riverdriver/riversqlite/river_sqlite_driver_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,6 @@ func TestSchemaTemplateParam(t *testing.T) {
9292
nil,
9393
)
9494
require.NoError(t, err)
95-
require.Equal(t, "SELECT 1 FROM custom_schema.river_job", updatedSQL)
95+
require.Equal(t, `SELECT 1 FROM "custom_schema".river_job`, updatedSQL)
9696
})
9797
}

rivermigrate/river_migrate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Ex
547547

548548
var schema string
549549
if m.schema != "" {
550-
schema = m.schema + "."
550+
schema = dbutil.SafeIdentifier(m.schema) + "."
551551
}
552552
schemaReplacement := map[string]sqlctemplate.Replacement{
553553
"schema": {Value: schema},

0 commit comments

Comments
 (0)