diff --git a/internal/builders.go b/internal/builders.go index c091558..0549d91 100644 --- a/internal/builders.go +++ b/internal/builders.go @@ -2,13 +2,14 @@ package internal import ( "fmt" + "sort" + "strings" + "github.com/rayakame/sqlc-gen-better-python/internal/core" "github.com/rayakame/sqlc-gen-better-python/internal/inflection" "github.com/sqlc-dev/plugin-sdk-go/metadata" "github.com/sqlc-dev/plugin-sdk-go/plugin" "github.com/sqlc-dev/plugin-sdk-go/sdk" - "sort" - "strings" ) func (gen *PythonGenerator) buildTable(schema *plugin.Schema, table *plugin.Table) core.Table { @@ -47,8 +48,7 @@ func (gen *PythonGenerator) buildTables() []core.Table { continue } for _, table := range schema.Tables { - t := gen.buildTable(schema, table) - tables = append(tables, t) + tables = append(tables, gen.buildTable(schema, table)) } } if len(tables) > 0 { @@ -59,7 +59,7 @@ func (gen *PythonGenerator) buildTables() []core.Table { func (gen *PythonGenerator) makePythonType(col *plugin.Column) core.PyType { columnType := sdk.DataType(col.Type) - strType := gen.typeConversionFunc(gen.req, col, gen.config) + strType := gen.typeConversionFunc(gen.req, col.Type, gen.config) for _, override := range gen.config.Overrides { if override.PyTypeName == "" { continue @@ -122,6 +122,16 @@ func (gen *PythonGenerator) buildEnums() []core.Enum { Name: core.SnakeToCamel(enumName, gen.config), Comment: enum.Comment, } + enumType := core.PyType{ + SqlType: enumName, + Type: e.Name, + DefaultType: "", + IsList: false, + IsNullable: false, + IsEnum: true, + IsOverride: false, + Override: nil, + } seen := make(map[string]struct{}, len(enum.Vals)) for i, v := range enum.Vals { @@ -132,7 +142,7 @@ func (gen *PythonGenerator) buildEnums() []core.Enum { e.Constants = append(e.Constants, core.Constant{ Name: core.SnakeToCamel(enumName+"_"+value, gen.config), Value: v, - Type: e.Name, + Type: enumType, }) seen[value] = struct{}{} } diff --git a/internal/codegen/builders/docstrings.go b/internal/codegen/builders/docstrings.go index 19976e1..44608c6 100644 --- a/internal/codegen/builders/docstrings.go +++ b/internal/codegen/builders/docstrings.go @@ -2,35 +2,26 @@ package builders import ( "fmt" + "github.com/rayakame/sqlc-gen-better-python/internal/core" "github.com/sqlc-dev/plugin-sdk-go/metadata" ) -var docstringConfig *string -var docstringConfigEmitSQL *bool -var docstringConfigDriver core.SQLDriverType = core.SQLDriverAsyncpg - -func SetDocstringConfig(c *string, b *bool, d core.SQLDriverType) { - docstringConfig = c - docstringConfigEmitSQL = b - docstringConfigDriver = d -} - func (b *IndentStringBuilder) WriteQueryResultsIterDocstring() { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteIndentedLine(2, `"""`+"Initialize iteration support.") b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { b.WriteIndentedLine(2, "Returns") b.WriteIndentedLine(2, "-------") b.WriteIndentedLine(2, "QueryResults[T]") b.WriteIndentedLine(3, "Self as an iterator.") - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(3, "Self as an iterator.") - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(2, "Self as an iterator.") } @@ -38,20 +29,20 @@ func (b *IndentStringBuilder) WriteQueryResultsIterDocstring() { } func (b *IndentStringBuilder) WriteQueryResultsAiterDocstring() { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteIndentedLine(2, `"""`+"Initialize iteration support for `async for`.") b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { b.WriteIndentedLine(2, "Returns") b.WriteIndentedLine(2, "-------") b.WriteIndentedLine(2, "QueryResults[T]") b.WriteIndentedLine(3, "Self as an asynchronous iterator.") - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(3, "Self as an asynchronous iterator.") - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(2, "Self as an asynchronous iterator.") } @@ -59,12 +50,12 @@ func (b *IndentStringBuilder) WriteQueryResultsAiterDocstring() { } func (b *IndentStringBuilder) WriteQueryResultsNextDocstringSqlite() { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteIndentedLine(2, `"""Yield the next item in the query result using a sqlite3 cursor.`) b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { b.WriteIndentedLine(2, "Returns") b.WriteIndentedLine(2, "-------") b.WriteIndentedLine(2, "T") @@ -74,13 +65,13 @@ func (b *IndentStringBuilder) WriteQueryResultsNextDocstringSqlite() { b.WriteIndentedLine(2, "------") b.WriteIndentedLine(2, "StopIteration") b.WriteIndentedLine(3, "When no more records are available.") - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(3, "The next decoded result of type `T`.") b.NewLine() b.WriteIndentedLine(2, "Raises:") b.WriteIndentedLine(3, "StopIteration: When no more records are available.") - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(2, "The next decoded result of type `T`.") b.NewLine() @@ -91,12 +82,12 @@ func (b *IndentStringBuilder) WriteQueryResultsNextDocstringSqlite() { } func (b *IndentStringBuilder) WriteQueryResultsAnextDocstringAiosqlite() { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteIndentedLine(2, `"""Yield the next item in the query result using an aiosqlite cursor.`) b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { b.WriteIndentedLine(2, "Returns") b.WriteIndentedLine(2, "-------") b.WriteIndentedLine(2, "T") @@ -106,13 +97,13 @@ func (b *IndentStringBuilder) WriteQueryResultsAnextDocstringAiosqlite() { b.WriteIndentedLine(2, "------") b.WriteIndentedLine(2, "StopAsyncIteration") b.WriteIndentedLine(3, "When no more records are available.") - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(3, "The next decoded result of type `T`.") b.NewLine() b.WriteIndentedLine(2, "Raises:") b.WriteIndentedLine(3, "StopAsyncIteration: When no more records are available.") - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(2, "The next decoded result of type `T`.") b.NewLine() @@ -123,12 +114,12 @@ func (b *IndentStringBuilder) WriteQueryResultsAnextDocstringAiosqlite() { } func (b *IndentStringBuilder) WriteQueryResultsAnextDocstringAsyncpg() { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteIndentedLine(2, `"""Yield the next item in the query result using an asyncpg cursor.`) b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { b.WriteIndentedLine(2, "Returns") b.WriteIndentedLine(2, "-------") b.WriteIndentedLine(2, "T") @@ -138,13 +129,13 @@ func (b *IndentStringBuilder) WriteQueryResultsAnextDocstringAsyncpg() { b.WriteIndentedLine(2, "------") b.WriteIndentedLine(2, "StopAsyncIteration") b.WriteIndentedLine(3, "When no more records are available.") - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(3, "The next decoded result of type `T`.") b.NewLine() b.WriteIndentedLine(2, "Raises:") b.WriteIndentedLine(3, "StopAsyncIteration: When no more records are available.") - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(2, "The next decoded result of type `T`.") b.NewLine() @@ -155,20 +146,20 @@ func (b *IndentStringBuilder) WriteQueryResultsAnextDocstringAsyncpg() { } func (b *IndentStringBuilder) WriteQueryResultsAwaitDocstring() { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteIndentedLine(2, `"""`+"Allow `await` on the object to return all rows as a fully decoded sequence.") b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { b.WriteIndentedLine(2, "Returns") b.WriteIndentedLine(2, "-------") b.WriteIndentedLine(2, "collections.abc.Sequence[T]") b.WriteIndentedLine(3, "A sequence of decoded objects of type `T`.") - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(3, "A sequence of decoded objects of type `T`.") - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(2, "A sequence of decoded objects of type `T`.") } @@ -176,20 +167,20 @@ func (b *IndentStringBuilder) WriteQueryResultsAwaitDocstring() { } func (b *IndentStringBuilder) WriteQueryResultsCallDocstring() { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteIndentedLine(2, `"""`+"Allow calling the object to return all rows as a fully decoded sequence.") b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { b.WriteIndentedLine(2, "Returns") b.WriteIndentedLine(2, "-------") b.WriteIndentedLine(2, "collections.abc.Sequence[T]") b.WriteIndentedLine(3, "A sequence of decoded objects of type `T`.") - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(3, "A sequence of decoded objects of type `T`.") - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(2, "A sequence of decoded objects of type `T`.") } @@ -197,13 +188,13 @@ func (b *IndentStringBuilder) WriteQueryResultsCallDocstring() { } func (b *IndentStringBuilder) WriteQueryResultsInitDocstring(docstringConnType string, docstringDriverReturnType string) { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteIndentedString(2, fmt.Sprintf(`"""Initialize the QueryResults instance.`)) - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { b.WriteLine(`"""`) - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { b.NNewLine(2) b.WriteIndentedLine(2, "Args:") b.WriteIndentedLine(3, "conn:") @@ -215,7 +206,7 @@ func (b *IndentStringBuilder) WriteQueryResultsInitDocstring(docstringConnType s b.WriteIndentedLine(3, "*args:") b.WriteIndentedLine(4, "Arguments that should be sent when executing the sql query.") b.WriteIndentedLine(2, `"""`) - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { b.NNewLine(2) b.WriteIndentedLine(2, "Arguments:") b.WriteIndentedLine(2, fmt.Sprintf("conn -- The connection object of type `%s` used to execute queries.", docstringConnType)) @@ -227,11 +218,11 @@ func (b *IndentStringBuilder) WriteQueryResultsInitDocstring(docstringConnType s } func (b *IndentStringBuilder) WriteQueryResultsClassDocstring(docstringConnType string, docstringDriverReturnType string) { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteIndentedString(1, `"""Helper class that allows both iteration and normal fetching of data from the db.`) - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { b.NewLine() b.NewLine() b.WriteIndentedLine(1, "Parameters") @@ -253,20 +244,20 @@ func (b *IndentStringBuilder) WriteQueryResultsClassDocstring(docstringConnType } func (b *IndentStringBuilder) WriteQueryClassConnDocstring(docstringConnType string) { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteIndentedLine(2, `"""Connection object used to make queries.`) b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { b.WriteIndentedLine(2, "Returns") b.WriteIndentedLine(2, "-------") b.WriteIndentedLine(2, docstringConnType) b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(3, fmt.Sprintf("Connection object of type `%s` used to make queries.", docstringConnType)) - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { b.WriteIndentedLine(2, "Returns:") b.WriteIndentedLine(2, fmt.Sprintf("%s -- Connection object used to make queries.", docstringConnType)) } @@ -274,11 +265,11 @@ func (b *IndentStringBuilder) WriteQueryClassConnDocstring(docstringConnType str } func (b *IndentStringBuilder) WriteQueryClassDocstring(sourceName string, docstringConnType string) { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteIndentedString(1, fmt.Sprintf(`"""Queries from file %s.`, sourceName)) - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { b.NewLine() b.NewLine() b.WriteIndentedLine(1, "Parameters") @@ -294,19 +285,19 @@ func (b *IndentStringBuilder) WriteQueryClassDocstring(sourceName string, docstr } func (b *IndentStringBuilder) WriteQueryClassInitDocstring(lvl int, docstringConnType string) { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteIndentedString(lvl, fmt.Sprintf(`"""Initialize the instance using the connection.`)) - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { b.WriteLine(`"""`) - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { b.NNewLine(2) b.WriteIndentedLine(lvl, "Args:") b.WriteIndentedLine(lvl+1, "conn:") b.WriteIndentedLine(lvl+2, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) b.WriteIndentedLine(lvl, `"""`) - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { b.NNewLine(2) b.WriteIndentedLine(lvl, "Arguments:") b.WriteIndentedLine(lvl, fmt.Sprintf("conn -- Connection object of type `%s` used to execute queries.", docstringConnType)) @@ -315,11 +306,11 @@ func (b *IndentStringBuilder) WriteQueryClassInitDocstring(lvl int, docstringCon } func (b *IndentStringBuilder) WriteModelClassDocstring(table *core.Table) { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteIndentedLine(1, `"""`+fmt.Sprintf("Model representing %s.", table.Name)) - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { b.NewLine() b.WriteIndentedLine(1, "Attributes") b.WriteIndentedLine(1, "----------") @@ -334,7 +325,7 @@ func (b *IndentStringBuilder) WriteModelClassDocstring(table *core.Table) { b.WriteIndentedLine(1, fmt.Sprintf("%s : %s", col.Name, type_)) } b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { b.NewLine() b.WriteIndentedLine(1, "Attributes:") for _, col := range table.Columns { @@ -347,7 +338,7 @@ func (b *IndentStringBuilder) WriteModelClassDocstring(table *core.Table) { } b.WriteIndentedLine(2, fmt.Sprintf("%s: %s", col.Name, type_)) } - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { b.NewLine() b.WriteIndentedLine(1, "Attributes:") for _, col := range table.Columns { @@ -366,39 +357,41 @@ func (b *IndentStringBuilder) WriteModelClassDocstring(table *core.Table) { } func (b *IndentStringBuilder) WriteModelFileModuleDocstring() { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteLine(`"""Module containing models."""`) } func (b *IndentStringBuilder) WriteQueryFileModuleDocstring(sourceName string) { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteLine(fmt.Sprintf(`"""Module containing queries from file %s."""`, sourceName)) } func (b *IndentStringBuilder) WriteInitFileModuleDocstring() { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } b.WriteLine(`"""Package containing queries and models automatically generated using sqlc-gen-better-python."""`) } func (b *IndentStringBuilder) writeQueryFunctionSQL(lvl int, query *core.Query) { - if *docstringConfigEmitSQL { - b.WriteIndentedLine(lvl, "```sql") - for _, line := range core.SplitLines(query.SQL) { - b.WriteIndentedLine(lvl, line) - } - b.WriteIndentedLine(lvl, "```") - b.NewLine() + if b.docstringOmitSQL { + return } + + b.WriteIndentedLine(lvl, "```sql") + for _, line := range core.SplitLines(query.SQL) { + b.WriteIndentedLine(lvl, line) + } + b.WriteIndentedLine(lvl, "```") + b.NewLine() } func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Query, docstringConnType string, queryArgs []core.FunctionArg, returnType core.PyType) { - if *docstringConfig == core.DocstringConventionNone { + if b.docstringConvention == core.DocstringConventionNone { return } @@ -410,7 +403,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.WriteIndentedLine(lvl, `"""`) return } - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { b.WriteIndentedLine(lvl, "Parameters") b.WriteIndentedLine(lvl, "----------") if docstringConnType != "" { @@ -421,7 +414,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.WriteIndentedLine(lvl, fmt.Sprintf("%s : %s", arg.Name, arg.Type)) } b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { b.WriteIndentedLine(lvl, "Args:") if docstringConnType != "" { b.WriteIndentedLine(lvl+1, "conn:") @@ -430,7 +423,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q for _, arg := range queryArgs { b.WriteIndentedLine(lvl+1, fmt.Sprintf("%s: %s.", arg.Name, arg.Type)) } - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { b.WriteIndentedLine(lvl, "Arguments:") if docstringConnType != "" { b.WriteIndentedLine(lvl, fmt.Sprintf("conn -- Connection object of type `%s` used to execute the query.", docstringConnType)) @@ -444,7 +437,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.WriteIndentedLine(lvl, `"""`+fmt.Sprintf("Execute SQL query with `name: %s %s` and return the number of affected rows.", query.MethodName, query.Cmd)) b.NewLine() b.writeQueryFunctionSQL(lvl, query) - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Parameters") b.WriteIndentedLine(lvl, "----------") @@ -460,13 +453,13 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.WriteIndentedLine(lvl, "Returns") b.WriteIndentedLine(lvl, "-------") b.WriteIndentedLine(lvl, returnType.Type) - if docstringConfigDriver == core.SQLDriverAioSQLite { + if b.docstringDriver == core.SQLDriverAioSQLite { b.WriteIndentedLine(lvl+1, "The number of affected rows. This will be -1 for queries like `CREATE TABLE`.") } else { b.WriteIndentedLine(lvl+1, "The number of affected rows. This will be 0 for queries like `CREATE TABLE`.") } b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Args:") if docstringConnType != "" { @@ -479,12 +472,12 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.NewLine() } b.WriteIndentedLine(lvl, "Returns:") - if docstringConfigDriver == core.SQLDriverAioSQLite { + if b.docstringDriver == core.SQLDriverAioSQLite { b.WriteIndentedLine(lvl+1, fmt.Sprintf("The number (`%s`) of affected rows. This will be -1 for queries like `CREATE TABLE`.", returnType.Type)) } else { b.WriteIndentedLine(lvl+1, fmt.Sprintf("The number (`%s`) of affected rows. This will be 0 for queries like `CREATE TABLE`.", returnType.Type)) } - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Arguments:") if docstringConnType != "" { @@ -496,7 +489,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.NewLine() } b.WriteIndentedLine(lvl, "Returns:") - if docstringConfigDriver == core.SQLDriverAioSQLite { + if b.docstringDriver == core.SQLDriverAioSQLite { b.WriteIndentedLine(lvl+1, fmt.Sprintf("%s -- The number of affected rows. This will be -1 for queries like `CREATE TABLE`.", returnType.Type)) } else { b.WriteIndentedLine(lvl+1, fmt.Sprintf("%s -- The number of affected rows. This will be 0 for queries like `CREATE TABLE`.", returnType.Type)) @@ -506,7 +499,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q } else if query.Cmd == metadata.CmdCopyFrom { b.WriteIndentedLine(lvl, `"""`+fmt.Sprintf("Execute COPY FROM query to insert rows into a table with `name: %s %s` and return the number of affected rows.", query.MethodName, query.Cmd)) b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Parameters") b.WriteIndentedLine(lvl, "----------") @@ -525,7 +518,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.WriteIndentedLine(lvl, returnType.Type) b.WriteIndentedLine(lvl+1, "The number of affected rows.") b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Args:") if docstringConnType != "" { @@ -540,7 +533,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q } b.WriteIndentedLine(lvl, "Returns:") b.WriteIndentedLine(lvl+1, fmt.Sprintf("The number (`%s`) of affected rows.", returnType.Type)) - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Arguments:") if docstringConnType != "" { @@ -559,7 +552,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.WriteIndentedLine(lvl, `"""`+fmt.Sprintf("Execute and return the result of SQL query with `name: %s %s`.", query.MethodName, query.Cmd)) b.NewLine() b.writeQueryFunctionSQL(lvl, query) - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Parameters") b.WriteIndentedLine(lvl, "----------") @@ -577,7 +570,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.WriteIndentedLine(lvl, returnType.Type) b.WriteIndentedLine(lvl+1, "The result returned when executing the query.") b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Args:") if docstringConnType != "" { @@ -591,7 +584,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q } b.WriteIndentedLine(lvl, "Returns:") b.WriteIndentedLine(lvl+1, fmt.Sprintf("The result of type `%s` returned when executing the query.", returnType.Type)) - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Arguments:") if docstringConnType != "" { @@ -610,7 +603,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.WriteIndentedLine(lvl, `"""`+fmt.Sprintf("Execute SQL query with `name: %s %s` and return the id of the last affected row.", query.MethodName, query.Cmd)) b.NewLine() b.writeQueryFunctionSQL(lvl, query) - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Parameters") b.WriteIndentedLine(lvl, "----------") @@ -628,7 +621,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.WriteIndentedLine(lvl, returnType.Type) b.WriteIndentedLine(lvl+1, "The id of the last affected row. Will be `None` if no rows are affected.") b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Args:") if docstringConnType != "" { @@ -642,7 +635,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q } b.WriteIndentedLine(lvl, "Returns:") b.WriteIndentedLine(lvl+1, fmt.Sprintf("The id (`%s`) of the last affected row. Will be `None` if no rows are affected.", returnType.Type)) - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Arguments:") if docstringConnType != "" { @@ -661,7 +654,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.WriteIndentedLine(lvl, `"""`+fmt.Sprintf("Fetch one from the db using the SQL query with `name: %s %s`.", query.MethodName, query.Cmd)) b.NewLine() b.writeQueryFunctionSQL(lvl, query) - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Parameters") b.WriteIndentedLine(lvl, "----------") @@ -680,7 +673,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.WriteIndentedLine(lvl+1, "Result fetched from the db. Will be `None` if not found.") b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Args:") if docstringConnType != "" { @@ -694,7 +687,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q } b.WriteIndentedLine(lvl, "Returns:") b.WriteIndentedLine(lvl+1, fmt.Sprintf("Result of type `%s` fetched from the db. Will be `None` if not found.", returnType.Type)) - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Arguments:") if docstringConnType != "" { @@ -713,7 +706,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.WriteIndentedLine(lvl, `"""`+fmt.Sprintf("Fetch many from the db using the SQL query with `name: %s %s`.", query.MethodName, query.Cmd)) b.NewLine() b.writeQueryFunctionSQL(lvl, query) - if *docstringConfig == core.DocstringConventionNumpy { + if b.docstringConvention == core.DocstringConventionNumpy { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Parameters") b.WriteIndentedLine(lvl, "----------") @@ -731,7 +724,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q b.WriteIndentedLine(lvl, fmt.Sprintf("QueryResults[%s]", returnType.Type)) b.WriteIndentedLine(lvl+1, "Helper class that allows both iteration and normal fetching of data from the db.") b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { + } else if b.docstringConvention == core.DocstringConventionGoogle { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Args:") if docstringConnType != "" { @@ -745,7 +738,7 @@ func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Q } b.WriteIndentedLine(lvl, "Returns:") b.WriteIndentedLine(lvl+1, fmt.Sprintf("Helper class of type `QueryResults[%s]` that allows both iteration and normal fetching of data from the db.", returnType.Type)) - } else if *docstringConfig == core.DocstringConventionPEP257 { + } else if b.docstringConvention == core.DocstringConventionPEP257 { if len(queryArgs) != 0 || docstringConnType != "" { b.WriteIndentedLine(lvl, "Arguments:") if docstringConnType != "" { diff --git a/internal/codegen/builders/string.go b/internal/codegen/builders/string.go index 4b17299..4530a97 100644 --- a/internal/codegen/builders/string.go +++ b/internal/codegen/builders/string.go @@ -2,9 +2,10 @@ package builders import ( "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/core" "os" "strings" + + "github.com/rayakame/sqlc-gen-better-python/internal/core" ) type IndentStringBuilder struct { @@ -12,12 +13,19 @@ type IndentStringBuilder struct { indentChar string charsPerIndentLevel int + + docstringOmitSQL bool + docstringConvention core.DocstringConvention + docstringDriver core.SQLDriver } -func NewIndentStringBuilder(indentChar string, charsPerIndentLevel int) *IndentStringBuilder { +func NewIndentStringBuilder(indentChar string, charsPerIndentLevel int, docstringConvention core.DocstringConvention, docstringOmitSQL bool, docstringDriver core.SQLDriver) *IndentStringBuilder { return &IndentStringBuilder{ indentChar: indentChar, charsPerIndentLevel: charsPerIndentLevel, + docstringConvention: docstringConvention, + docstringDriver: docstringDriver, + docstringOmitSQL: docstringOmitSQL, } } @@ -68,3 +76,7 @@ func (b *IndentStringBuilder) NNewLine(n int) { b.WriteString("\n") } } + +func (b *IndentStringBuilder) Bytes() []byte { + return []byte(b.String()) +} diff --git a/internal/codegen/common.go b/internal/codegen/common.go index e302e4f..2a76f97 100644 --- a/internal/codegen/common.go +++ b/internal/codegen/common.go @@ -2,6 +2,7 @@ package codegen import ( "fmt" + "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" "github.com/rayakame/sqlc-gen-better-python/internal/codegen/drivers" "github.com/rayakame/sqlc-gen-better-python/internal/core" @@ -58,7 +59,6 @@ func NewDriver(conf *core.Config) (*Driver, error) { default: return nil, fmt.Errorf("unsupported driver: %s", conf.SqlDriver.String()) } - builders.SetDocstringConfig(conf.EmitDocstrings, conf.EmitDocstringsSQL, conf.SqlDriver) return &Driver{ buildPyQueryFunc: buildPyQueryFunc, @@ -70,6 +70,16 @@ func NewDriver(conf *core.Config) (*Driver, error) { }, nil } +func (dr *Driver) GetStringBuilder() *builders.IndentStringBuilder { + return builders.NewIndentStringBuilder( + dr.conf.IndentChar, + dr.conf.CharsPerIndentLevel, + dr.conf.EmitDocstrings, + dr.conf.OmitDocstringsSQL, + dr.conf.SqlDriver, + ) +} + func (dr *Driver) supportedCMD(command string) error { cmds := dr.acceptedDriverCMDs() for _, cmd := range cmds { diff --git a/internal/codegen/drivers/aiosqlite.go b/internal/codegen/drivers/aiosqlite.go index fe634fa..6bd6f2d 100644 --- a/internal/codegen/drivers/aiosqlite.go +++ b/internal/codegen/drivers/aiosqlite.go @@ -2,14 +2,15 @@ package drivers import ( "fmt" + "strconv" + "strings" + "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" "github.com/rayakame/sqlc-gen-better-python/internal/core" "github.com/rayakame/sqlc-gen-better-python/internal/typeConversion" "github.com/rayakame/sqlc-gen-better-python/internal/types" "github.com/sqlc-dev/plugin-sdk-go/metadata" "github.com/sqlc-dev/plugin-sdk-go/plugin" - "strconv" - "strings" ) const AioSQLiteConn = "aiosqlite.Connection" @@ -43,11 +44,11 @@ func AioSQLiteBuildTypeConvFunc(queries []core.Query, body *builders.IndentStrin toConvert := make(map[string]bool) for _, query := range queries { for sqlType, _ := range typeConversion.SqliteGetConversions() { - name := types.SqliteTypeToPython(&plugin.GenerateRequest{}, &plugin.Column{Type: &plugin.Identifier{ + name := types.SqliteTypeToPython(&plugin.GenerateRequest{}, &plugin.Identifier{ Catalog: "", Schema: "", Name: sqlType, - }}, conf) + }, conf) if queryValueUses(name, query.Ret) { toConvert[name] = true } diff --git a/internal/codegen/drivers/sqlite3.go b/internal/codegen/drivers/sqlite3.go index e43f78c..9ee1a18 100644 --- a/internal/codegen/drivers/sqlite3.go +++ b/internal/codegen/drivers/sqlite3.go @@ -2,14 +2,15 @@ package drivers import ( "fmt" + "strconv" + "strings" + "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" "github.com/rayakame/sqlc-gen-better-python/internal/core" "github.com/rayakame/sqlc-gen-better-python/internal/typeConversion" "github.com/rayakame/sqlc-gen-better-python/internal/types" "github.com/sqlc-dev/plugin-sdk-go/metadata" "github.com/sqlc-dev/plugin-sdk-go/plugin" - "strconv" - "strings" ) const Sqlite3Result = "sqlite3.Row" @@ -44,11 +45,11 @@ func SQLite3BuildTypeConvFunc(queries []core.Query, body *builders.IndentStringB toConvert := make(map[string]bool) for _, query := range queries { for sqlType, _ := range typeConversion.SqliteGetConversions() { - name := types.SqliteTypeToPython(&plugin.GenerateRequest{}, &plugin.Column{Type: &plugin.Identifier{ + name := types.SqliteTypeToPython(&plugin.GenerateRequest{}, &plugin.Identifier{ Catalog: "", Schema: "", Name: sqlType, - }}, conf) + }, conf) if queryValueUses(name, query.Ret) { toConvert[name] = true } diff --git a/internal/codegen/enums.go b/internal/codegen/enums.go new file mode 100644 index 0000000..0d1f538 --- /dev/null +++ b/internal/codegen/enums.go @@ -0,0 +1,73 @@ +package codegen + +import ( + "fmt" + + "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" + "github.com/rayakame/sqlc-gen-better-python/internal/core" + "github.com/sqlc-dev/plugin-sdk-go/plugin" +) + +func (dr *Driver) BuildPyEnumsFile(imp *core.Importer, enums []core.Enum) (*plugin.File, error) { + fileName, fileContent, err := dr.buildPyEnums(imp, enums) + if err != nil { + return nil, err + } + return &plugin.File{ + Name: core.SQLToPyFileName(fileName), + Contents: fileContent, + }, nil +} + +func (dr *Driver) buildPyEnum(enum *core.Enum, body *builders.IndentStringBuilder) { + body.WriteLine(fmt.Sprintf("class %s(enum.StrEnum):", enum.Name)) + for _, constant := range enum.Constants { + body.WriteIndentedLine(1, constant.Name+" = \""+constant.Value+"\"") + } +} + +func (dr *Driver) buildPyEnums(imp *core.Importer, enums []core.Enum) (string, []byte, error) { + fileName := "enums.sql" + body := dr.GetStringBuilder() + body.WriteSqlcHeader() + body.WriteModelFileModuleDocstring() + body.WriteImportAnnotations() + body.WriteLine("__all__: collections.abc.Sequence[str] = (") + for _, table := range enums { + body.WriteIndentedLine(1, fmt.Sprintf("\"%s\",", table.Name)) + } + body.WriteLine(")") + body.WriteString("\n") + std, tye, pkg := imp.Imports(fileName) + for _, imp := range std { + body.WriteLine(imp) + } + if len(tye) != 0 { + if len(std) != 0 { + body.NewLine() + } + if !dr.conf.OmitTypecheckingBlock { + body.WriteLine("if typing.TYPE_CHECKING:") + for _, imp := range tye { + body.WriteIndentedLine(1, imp) + } + } else { + for _, imp := range tye { + body.WriteLine(imp) + } + } + } + for i, imp := range pkg { + if i == 0 { + body.NewLine() + } + body.WriteLine(imp) + } + + for _, enum := range enums { + body.WriteString("\n") + body.WriteString("\n") + dr.buildPyEnum(&enum, body) + } + return fileName, body.Bytes(), nil +} diff --git a/internal/codegen/init.go b/internal/codegen/init.go index 937106d..510917e 100644 --- a/internal/codegen/init.go +++ b/internal/codegen/init.go @@ -1,12 +1,11 @@ package codegen import ( - "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" "github.com/sqlc-dev/plugin-sdk-go/plugin" ) func (dr *Driver) BuildInitFile() *plugin.File { - body := builders.NewIndentStringBuilder(dr.conf.IndentChar, dr.conf.CharsPerIndentLevel) + body := dr.GetStringBuilder() body.WriteSqlcHeader() body.WriteInitFileModuleDocstring() return &plugin.File{ diff --git a/internal/codegen/queries.go b/internal/codegen/queries.go index 6471ce5..6ca3a25 100644 --- a/internal/codegen/queries.go +++ b/internal/codegen/queries.go @@ -2,13 +2,14 @@ package codegen import ( "fmt" + "sort" + "strings" + "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" "github.com/rayakame/sqlc-gen-better-python/internal/codegen/drivers" "github.com/rayakame/sqlc-gen-better-python/internal/core" "github.com/sqlc-dev/plugin-sdk-go/metadata" "github.com/sqlc-dev/plugin-sdk-go/plugin" - "sort" - "strings" ) func (dr *Driver) prepareFunctionHeader(query *core.Query, body *builders.IndentStringBuilder) ([]core.FunctionArg, string, []string) { @@ -119,7 +120,7 @@ func (dr *Driver) buildClassTemplate(sourceName string, body *builders.IndentStr } func (dr *Driver) buildPyQueriesFile(imp *core.Importer, queries []core.Query, sourceName string) ([]byte, error) { - body := builders.NewIndentStringBuilder(imp.C.IndentChar, imp.C.CharsPerIndentLevel) + body := dr.GetStringBuilder() body.WriteSqlcHeader() body.WriteQueryFileModuleDocstring(sourceName) body.WriteImportAnnotations() @@ -130,8 +131,8 @@ func (dr *Driver) buildPyQueriesFile(imp *core.Importer, queries []core.Query, s } allNames := make([]string, 0) - funcBody := builders.NewIndentStringBuilder(imp.C.IndentChar, imp.C.CharsPerIndentLevel) - pyTableBody := builders.NewIndentStringBuilder(imp.C.IndentChar, imp.C.CharsPerIndentLevel) + funcBody := dr.GetStringBuilder() + pyTableBody := dr.GetStringBuilder() for _, query := range queries { if !dr.conf.EmitClasses { allNames = append(allNames, query.FuncName) diff --git a/internal/codegen/tables.go b/internal/codegen/tables.go index 7cd8312..92ad5e6 100644 --- a/internal/codegen/tables.go +++ b/internal/codegen/tables.go @@ -2,6 +2,7 @@ package codegen import ( "fmt" + "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" "github.com/rayakame/sqlc-gen-better-python/internal/core" "github.com/sqlc-dev/plugin-sdk-go/plugin" @@ -18,7 +19,7 @@ func (dr *Driver) BuildPyTablesFile(imp *core.Importer, tables []core.Table) (*p }, nil } -func BuildPyTabel(modelType string, table *core.Table, body *builders.IndentStringBuilder) { +func BuildPyTabel(modelType core.ModelType, table *core.Table, body *builders.IndentStringBuilder) { if modelType == core.ModelTypeDataclass { body.WriteLine("@dataclasses.dataclass()") } else if modelType == core.ModelTypeAttrs { @@ -44,7 +45,7 @@ func BuildPyTabel(modelType string, table *core.Table, body *builders.IndentStri func (dr *Driver) buildPyTables(imp *core.Importer, tables []core.Table) (string, []byte, error) { fileName := "models.sql" - body := builders.NewIndentStringBuilder(imp.C.IndentChar, imp.C.CharsPerIndentLevel) + body := dr.GetStringBuilder() body.WriteSqlcHeader() body.WriteModelFileModuleDocstring() body.WriteImportAnnotations() diff --git a/internal/core/config.go b/internal/core/config.go index 3481888..22bfaf7 100644 --- a/internal/core/config.go +++ b/internal/core/config.go @@ -3,28 +3,30 @@ package core import ( "encoding/json" "fmt" + + "github.com/rayakame/sqlc-gen-better-python/internal/utils" "github.com/sqlc-dev/plugin-sdk-go/plugin" ) const PluginVersion = "v0.4.5" type Config struct { - Package string `json:"package" yaml:"package"` - SqlDriver SQLDriverType `json:"sql_driver" yaml:"sql_driver"` - ModelType string `json:"model_type" yaml:"model_type"` - Initialisms *[]string `json:"initialisms,omitempty" yaml:"initialisms,omitempty"` - EmitExactTableNames bool `json:"emit_exact_table_names" yaml:"emit_exact_table_names"` - EmitClasses bool `json:"emit_classes" yaml:"emit_classes"` - InflectionExcludeTableNames []string `json:"inflection_exclude_table_names,omitempty" yaml:"inflection_exclude_table_names,omitempty"` - OmitUnusedModels bool `json:"omit_unused_models" yaml:"omit_unused_models"` - OmitTypecheckingBlock bool `json:"omit_typechecking_block" yaml:"omit_typechecking_block"` - QueryParameterLimit *int32 `json:"query_parameter_limit,omitempty" yaml:"query_parameter_limit"` - OmitKwargsLimit *int32 `json:"omit_kwargs_limit,omitempty" yaml:"omit_kwargs_limit"` - EmitInitFile *bool `json:"emit_init_file" yaml:"emit_init_file"` - EmitDocstrings *string `json:"docstrings" yaml:"docstrings"` - EmitDocstringsSQL *bool `json:"docstrings_emit_sql" yaml:"docstrings_emit_sql"` - Speedups bool `json:"speedups" yaml:"speedups"` - Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` + Package string `json:"package" yaml:"package"` + SqlDriver SQLDriver `json:"sql_driver" yaml:"sql_driver"` + ModelType ModelType `json:"model_type" yaml:"model_type"` + Initialisms *[]string `json:"initialisms,omitempty" yaml:"initialisms,omitempty"` + EmitExactTableNames bool `json:"emit_exact_table_names" yaml:"emit_exact_table_names"` + EmitClasses bool `json:"emit_classes" yaml:"emit_classes"` + InflectionExcludeTableNames []string `json:"inflection_exclude_table_names,omitempty" yaml:"inflection_exclude_table_names,omitempty"` + OmitUnusedModels bool `json:"omit_unused_models" yaml:"omit_unused_models"` + OmitTypecheckingBlock bool `json:"omit_typechecking_block" yaml:"omit_typechecking_block"` + QueryParameterLimit *int32 `json:"query_parameter_limit,omitempty" yaml:"query_parameter_limit"` + OmitKwargsLimit *int32 `json:"omit_kwargs_limit,omitempty" yaml:"omit_kwargs_limit"` + EmitInitFile *bool `json:"emit_init_file" yaml:"emit_init_file"` + EmitDocstrings DocstringConvention `json:"docstrings" yaml:"docstrings"` + OmitDocstringsSQL bool `json:"docstrings_emit_sql" yaml:"docstrings_emit_sql"` + Speedups bool `json:"speedups" yaml:"speedups"` + Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` Debug bool `json:"debug" yaml:"debug"` @@ -35,7 +37,19 @@ type Config struct { Async bool } -func ParseConfig(req *plugin.GenerateRequest) (*Config, error) { +func NewConfig(req *plugin.GenerateRequest) (*Config, error) { + config, err := parseConfig(req) + if err != nil { + return nil, err + } + err = validateConf(config, req.Settings.Engine) + if err != nil { + return nil, err + } + return config, nil +} + +func parseConfig(req *plugin.GenerateRequest) (*Config, error) { var config Config if len(req.PluginOptions) == 0 { return &config, nil @@ -43,14 +57,7 @@ func ParseConfig(req *plugin.GenerateRequest) (*Config, error) { if err := json.Unmarshal(req.PluginOptions, &config); err != nil { return nil, fmt.Errorf("unmarshalling plugin options: %w", err) } - if config.SqlDriver == "" { - return nil, fmt.Errorf("invalid options: driver must not be empty") - } - val, err := isDriverAsync(config.SqlDriver) - if err != nil { - return nil, fmt.Errorf("invalid options: %s", err) - } - config.Async = val + config.Async = config.SqlDriver.Async() for i := range config.Overrides { if err := config.Overrides[i].parse(req); err != nil { @@ -62,30 +69,22 @@ func ParseConfig(req *plugin.GenerateRequest) (*Config, error) { config.ModelType = ModelTypeDataclass } if config.QueryParameterLimit == nil { - config.QueryParameterLimit = new(int32) - *config.QueryParameterLimit = 1 + config.QueryParameterLimit = utils.ToPtr(int32(1)) } if config.OmitKwargsLimit == nil { config.OmitKwargsLimit = new(int32) - *config.OmitKwargsLimit = 0 } if config.Initialisms == nil { - config.Initialisms = new([]string) - *config.Initialisms = []string{"id"} + config.Initialisms = utils.ToPtr([]string{"id"}) } if config.IndentChar == "" { config.IndentChar = " " } - if config.CharsPerIndentLevel == 0 { + if config.CharsPerIndentLevel <= 0 { config.CharsPerIndentLevel = 4 } - if config.EmitDocstrings == nil { - config.EmitDocstrings = new(string) - *config.EmitDocstrings = DocstringConventionNone - } - if config.EmitDocstringsSQL == nil { - config.EmitDocstringsSQL = new(bool) - *config.EmitDocstringsSQL = true + if config.EmitDocstrings == "" { + config.EmitDocstrings = DocstringConventionNone } config.InitialismsMap = map[string]struct{}{} @@ -94,7 +93,7 @@ func ParseConfig(req *plugin.GenerateRequest) (*Config, error) { } return &config, nil } -func ValidateConf(conf *Config, engine string) error { +func validateConf(conf *Config, engine string) error { if *conf.QueryParameterLimit < 0 { return fmt.Errorf("invalid options: query parameter limit must not be negative") } @@ -110,16 +109,16 @@ func ValidateConf(conf *Config, engine string) error { return fmt.Errorf("invalid options: package must not be empty") } - if err := isDriverValid(conf.SqlDriver, engine); err != nil { - return err + if err := conf.SqlDriver.Validate(engine); err != nil { + return fmt.Errorf("invalid options: unknown model type: %e", err) } - if err := isModelTypeValid(conf.ModelType); err != nil { - return fmt.Errorf("invalid options: %s", err) + if !conf.ModelType.Valid() { + return fmt.Errorf("invalid options: unknown model type: %s", conf.ModelType) } - if err := isDocstringValid(conf.EmitDocstrings); err != nil { - return fmt.Errorf("invalid options: %s", err) + if !conf.EmitDocstrings.Valid() { + return fmt.Errorf("invalid options: unknown docstring convention: %s", conf.EmitDocstrings) } return nil diff --git a/internal/core/enums.go b/internal/core/enums.go index caf2bf8..56909ee 100644 --- a/internal/core/enums.go +++ b/internal/core/enums.go @@ -2,31 +2,35 @@ package core import "fmt" -type SQLDriverType string +type ( + SQLDriver string + DocstringConvention string + ModelType string +) -func (dr *SQLDriverType) String() string { - return string(*dr) +func (dr SQLDriver) String() string { + return string(dr) } const ( - SQLDriverSQLite SQLDriverType = "sqlite3" - SQLDriverAioSQLite SQLDriverType = "aiosqlite" - SQLDriverAsyncpg SQLDriverType = "asyncpg" + SQLDriverSQLite SQLDriver = "sqlite3" + SQLDriverAioSQLite SQLDriver = "aiosqlite" + SQLDriverAsyncpg SQLDriver = "asyncpg" ) const ( - ModelTypeDataclass = "dataclass" - ModelTypeAttrs = "attrs" - ModelTypeMsgspec = "msgspec" + ModelTypeDataclass ModelType = "dataclass" + ModelTypeAttrs ModelType = "attrs" + ModelTypeMsgspec ModelType = "msgspec" ) -var asyncDrivers = map[SQLDriverType]bool{ +var asyncDrivers = map[SQLDriver]bool{ SQLDriverSQLite: false, SQLDriverAioSQLite: true, SQLDriverAsyncpg: true, } -var driversEngine = map[SQLDriverType]string{ +var driversEngine = map[SQLDriver]string{ SQLDriverSQLite: "sqlite", SQLDriverAioSQLite: "sqlite", SQLDriverAsyncpg: "postgresql", @@ -39,51 +43,45 @@ var validModelTypes = map[string]struct{}{ } const ( - DocstringConventionNone = "none" - DocstringConventionGoogle = "google" - DocstringConventionNumpy = "numpy" - DocstringConventionPEP257 = "pep257" + DocstringConventionNone DocstringConvention = "none" + DocstringConventionGoogle DocstringConvention = "google" + DocstringConventionNumpy DocstringConvention = "numpy" + DocstringConventionPEP257 DocstringConvention = "pep257" ) -var validDocstringConventions = map[string]struct{}{ - DocstringConventionNone: {}, - DocstringConventionGoogle: {}, - DocstringConventionNumpy: {}, - DocstringConventionPEP257: {}, -} - -func isDriverAsync(sqlDriver SQLDriverType) (bool, error) { - val, found := asyncDrivers[sqlDriver] +func (dr SQLDriver) Async() bool { + val, found := asyncDrivers[dr] if !found { - return false, fmt.Errorf("unknown SQL driver: %s", sqlDriver) + return false } - return val, nil + return val } -func isDriverValid(sqlDriver SQLDriverType, engine string) error { - val, found := driversEngine[sqlDriver] +func (dr SQLDriver) Validate(engine string) error { + val, found := driversEngine[dr] if !found { - return fmt.Errorf("unknown SQL driver: %s", sqlDriver) + return fmt.Errorf("unknown SQL driver: %s", dr) } if val != engine { - return fmt.Errorf("SQL driver %s does not support %s", sqlDriver, engine) + return fmt.Errorf("SQL driver %s does not support %s", dr, engine) } return nil } -func isModelTypeValid(modelType string) error { - if _, found := validModelTypes[modelType]; !found { - return fmt.Errorf("unknown model type: %s", modelType) +func (modelType ModelType) Valid() bool { + switch modelType { + case ModelTypeDataclass, ModelTypeMsgspec, ModelTypeAttrs: + return true + default: + return false } - return nil } -func isDocstringValid(ds *string) error { - if ds == nil { - return nil - } - if _, found := validDocstringConventions[*ds]; !found { - return fmt.Errorf("unknown docstring convention: %s", ds) +func (ds DocstringConvention) Valid() bool { + switch ds { + case DocstringConventionNone, DocstringConventionNumpy, DocstringConventionGoogle, DocstringConventionPEP257: + return true + default: + return false } - return nil } diff --git a/internal/core/importer.go b/internal/core/importer.go index 3e2ec67..45dac3d 100644 --- a/internal/core/importer.go +++ b/internal/core/importer.go @@ -2,10 +2,11 @@ package core import ( "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/typeConversion" - "github.com/sqlc-dev/plugin-sdk-go/metadata" "sort" "strings" + + "github.com/rayakame/sqlc-gen-better-python/internal/typeConversion" + "github.com/sqlc-dev/plugin-sdk-go/metadata" ) type importSpec struct { @@ -39,6 +40,9 @@ func (i *Importer) Imports(fileName string) ([]string, []string, []string) { if fileName == "models.sql" { return i.modelImports() } + if fileName == "enums.sql" { + return i.enumImports() + } return i.queryImports(fileName) } @@ -49,7 +53,15 @@ func TableUses(name string, s Table) (bool, PyType) { } } return false, PyType{} +} +func enumUses(name string, e Enum) (bool, PyType) { + for _, constant := range e.Constants { + if constant.Type.Type == name { + return true, constant.Type + } + } + return false, PyType{} } func (i *Importer) getModelImportSpec() (string, importSpec, error) { @@ -78,6 +90,33 @@ func (i *Importer) splitTypeChecking(pks map[string]importSpec) (map[string]impo return normalImports, typeChecking } +func (i *Importer) enumImportSpecs() (map[string]importSpec, map[string]importSpec, map[string]importSpec) { + modelUses := func(name string) (bool, bool) { + for _, enum := range i.Enums { + if val, _ := enumUses(name, enum); val { + return true, true + } + } + return false, false + } + + std := stdImports(modelUses) + for _, override := range i.C.Overrides { + if val1, val2 := modelUses(override.PyTypeName); val1 { + std[override.PyTypeName] = importSpec{Module: override.PyImportPath, Name: override.PyPackageName, TypeChecking: val2} + } + } + std, typeChecking := i.splitTypeChecking(std) + if len(typeChecking) != 0 { + std["typing"] = importSpec{Module: "typing"} + } + std["enum"] = importSpec{Module: "enum"} + + pkg := make(map[string]importSpec) + + return std, typeChecking, pkg +} + func (i *Importer) modelImportSpecs() (map[string]importSpec, map[string]importSpec, map[string]importSpec) { modelUses := func(name string) (bool, bool) { for _, table := range i.Tables { @@ -103,7 +142,7 @@ func (i *Importer) modelImportSpecs() (map[string]importSpec, map[string]importS std[modelName] = modelImport } if len(i.Enums) > 0 { - std["enum"] = importSpec{Module: fmt.Sprintf("from %s import enums", i.C.Package)} + std["enum"] = importSpec{Module: i.C.Package, Name: "enums"} } pkg := make(map[string]importSpec) @@ -341,6 +380,23 @@ func (i *Importer) queryImports(fileName string) ([]string, []string, []string) return importLines, typeLines, packageLines } +func (i *Importer) enumImports() ([]string, []string, []string) { + std, typeCheck, pkg := i.enumImportSpecs() + importLines := make([]string, 0) + typeLines := make([]string, 0) + packageLines := make([]string, 0) + if len(std) != 0 { + importLines = append(importLines, buildImportBlock(std)...) + } + if len(typeCheck) != 0 { + typeLines = append(typeLines, buildImportBlock(typeCheck)...) + } + if len(pkg) != 0 { + packageLines = append(packageLines, buildImportBlock(pkg)...) + } + return importLines, typeLines, packageLines +} + func (i *Importer) modelImports() ([]string, []string, []string) { std, typeCheck, pkg := i.modelImportSpecs() importLines := make([]string, 0) diff --git a/internal/core/models.go b/internal/core/models.go index 22bbbcf..43d2630 100644 --- a/internal/core/models.go +++ b/internal/core/models.go @@ -1,10 +1,11 @@ package core import ( + "strings" + "github.com/rayakame/sqlc-gen-better-python/internal/typeConversion" "github.com/sqlc-dev/plugin-sdk-go/metadata" "github.com/sqlc-dev/plugin-sdk-go/plugin" - "strings" ) type Table struct { @@ -37,7 +38,7 @@ func (p *PyType) DoOverride() bool { type Constant struct { Name string - Type string + Type PyType Value string } diff --git a/internal/gen.go b/internal/generator.go similarity index 83% rename from internal/gen.go rename to internal/generator.go index fda62d2..61c05a6 100644 --- a/internal/gen.go +++ b/internal/generator.go @@ -1,15 +1,14 @@ package internal import ( - "context" - "encoding/json" "fmt" + "strings" + "github.com/rayakame/sqlc-gen-better-python/internal/codegen" "github.com/rayakame/sqlc-gen-better-python/internal/core" "github.com/rayakame/sqlc-gen-better-python/internal/log" "github.com/rayakame/sqlc-gen-better-python/internal/types" "github.com/sqlc-dev/plugin-sdk-go/plugin" - "strings" ) type PythonGenerator struct { @@ -21,13 +20,11 @@ type PythonGenerator struct { } func NewPythonGenerator(req *plugin.GenerateRequest) (*PythonGenerator, error) { - config, err := core.ParseConfig(req) + config, err := core.NewConfig(req) if err != nil { return nil, err } - if err = core.ValidateConf(config, req.Settings.Engine); err != nil { - return nil, err - } + var typeConversionFunc types.TypeConversionFunc switch req.Settings.Engine { case "postgresql": @@ -54,6 +51,7 @@ func NewPythonGenerator(req *plugin.GenerateRequest) (*PythonGenerator, error) { func (gen *PythonGenerator) Run() (*plugin.GenerateResponse, error) { outputFiles := make([]*plugin.File, 0) log.GlobalLogger.LogByte(gen.req.PluginOptions) + enums := gen.buildEnums() tables := gen.buildTables() queries, err := gen.buildQueries(tables) @@ -61,16 +59,11 @@ func (gen *PythonGenerator) Run() (*plugin.GenerateResponse, error) { return nil, err } - jsonData, _ := json.Marshal(gen.req) - log.GlobalLogger.LogByte(jsonData) - jsonData, _ = json.Marshal(gen.config) - log.GlobalLogger.LogByte(jsonData) - jsonData, _ = json.Marshal(enums) - log.GlobalLogger.LogByte(jsonData) - jsonData, _ = json.Marshal(tables) - log.GlobalLogger.LogByte(jsonData) - jsonData, _ = json.Marshal(queries) - log.GlobalLogger.LogByte(jsonData) + log.GlobalLogger.LogAny(gen.req) + log.GlobalLogger.LogAny(gen.config) + log.GlobalLogger.LogAny(enums) + log.GlobalLogger.LogAny(tables) + log.GlobalLogger.LogAny(queries) if gen.config.OmitUnusedModels { enums, tables = filterUnusedStructs(enums, tables, queries) @@ -94,11 +87,16 @@ func (gen *PythonGenerator) Run() (*plugin.GenerateResponse, error) { } else { outputFiles = append(outputFiles, files...) } + if file, err := gen.sqlDriver.BuildPyEnumsFile(&importer, enums); err != nil { + return nil, err + } else { + outputFiles = append(outputFiles, file) + } + if *gen.config.EmitInitFile { outputFiles = append(outputFiles, gen.sqlDriver.BuildInitFile()) } - jsonData, _ = json.Marshal(outputFiles) - log.GlobalLogger.LogByte(jsonData) + log.GlobalLogger.LogAny(outputFiles) if gen.config.Debug { fileName, fileContent := log.GlobalLogger.Print() outputFiles = append(outputFiles, &plugin.File{ @@ -109,14 +107,6 @@ func (gen *PythonGenerator) Run() (*plugin.GenerateResponse, error) { return &plugin.GenerateResponse{Files: outputFiles}, nil } -func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) { - pythonGenerator, err := NewPythonGenerator(req) - if err != nil { - return nil, err - } - return pythonGenerator.Run() -} - func (gen *PythonGenerator) validate(enums []core.Enum, structs []core.Table) error { enumNames := make(map[string]struct{}) for _, enum := range enums { diff --git a/internal/handler.go b/internal/handler.go new file mode 100644 index 0000000..c6da716 --- /dev/null +++ b/internal/handler.go @@ -0,0 +1,15 @@ +package internal + +import ( + "context" + + "github.com/sqlc-dev/plugin-sdk-go/plugin" +) + +func Handler(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) { + pythonGenerator, err := NewPythonGenerator(req) + if err != nil { + return nil, err + } + return pythonGenerator.Run() +} diff --git a/internal/log/logger.go b/internal/log/logger.go index b672da1..6897a3f 100644 --- a/internal/log/logger.go +++ b/internal/log/logger.go @@ -1,5 +1,10 @@ package log +import ( + "encoding/json" + "fmt" +) + type Logger struct { messages []string } @@ -11,6 +16,15 @@ func (logger *Logger) LogByte(message []byte) { logger.messages = append(logger.messages, string(message)) } +func (logger *Logger) LogAny(message any) { + jsonData, err := json.Marshal(message) + if err != nil { + logger.Log(fmt.Sprintf("Error while trying to log any: %e", err)) + } else { + logger.LogByte(jsonData) + } +} + func (logger *Logger) Print() (string, []byte) { var loggedMessages string for _, message := range logger.messages { diff --git a/internal/types/common.go b/internal/types/common.go index 0267933..05aef59 100644 --- a/internal/types/common.go +++ b/internal/types/common.go @@ -5,4 +5,4 @@ import ( "github.com/sqlc-dev/plugin-sdk-go/plugin" ) -type TypeConversionFunc func(req *plugin.GenerateRequest, col *plugin.Column, conf *core.Config) string +type TypeConversionFunc func(req *plugin.GenerateRequest, typ *plugin.Identifier, conf *core.Config) string diff --git a/internal/types/postgresql.go b/internal/types/postgresql.go index 61e09a3..3d07e86 100644 --- a/internal/types/postgresql.go +++ b/internal/types/postgresql.go @@ -2,6 +2,7 @@ package types import ( "fmt" + "github.com/rayakame/sqlc-gen-better-python/internal/core" "github.com/rayakame/sqlc-gen-better-python/internal/log" @@ -9,8 +10,8 @@ import ( "github.com/sqlc-dev/plugin-sdk-go/sdk" ) -func PostgresTypeToPython(req *plugin.GenerateRequest, col *plugin.Column, conf *core.Config) string { - columnType := sdk.DataType(col.Type) +func PostgresTypeToPython(req *plugin.GenerateRequest, typ *plugin.Identifier, conf *core.Config) string { + columnType := sdk.DataType(typ) switch columnType { case "serial", "serial4", "pg_catalog.serial4", "bigserial", "serial8", "pg_catalog.serial8", "smallserial", "serial2", "pg_catalog.serial2", "integer", "int", "int4", "pg_catalog.int4", "bigint", "int8", "pg_catalog.int8", "smallint", "int2", "pg_catalog.int2": @@ -54,9 +55,9 @@ func PostgresTypeToPython(req *plugin.GenerateRequest, col *plugin.Column, conf for _, enum := range schema.Enums { if columnType == enum.Name { if schema.Name == req.Catalog.DefaultSchema { - return "models." + core.ModelName(enum.Name, "", conf) + return "enums." + core.ModelName(enum.Name, "", conf) } - return "models." + core.ModelName(enum.Name, schema.Name, conf) + return "enums." + core.ModelName(enum.Name, schema.Name, conf) } } } diff --git a/internal/types/sqlite.go b/internal/types/sqlite.go index a15ed2a..984eaca 100644 --- a/internal/types/sqlite.go +++ b/internal/types/sqlite.go @@ -2,16 +2,17 @@ package types import ( "fmt" + "strings" + "github.com/rayakame/sqlc-gen-better-python/internal/core" "github.com/rayakame/sqlc-gen-better-python/internal/log" - "strings" "github.com/sqlc-dev/plugin-sdk-go/plugin" "github.com/sqlc-dev/plugin-sdk-go/sdk" ) -func SqliteTypeToPython(_ *plugin.GenerateRequest, col *plugin.Column, _ *core.Config) string { - columnType := strings.ToLower(sdk.DataType(col.Type)) +func SqliteTypeToPython(_ *plugin.GenerateRequest, typ *plugin.Identifier, _ *core.Config) string { + columnType := strings.ToLower(sdk.DataType(typ)) switch columnType { case "int", "integer", "tinyint", "smallint", "mediumint", "bigint", "unsignedbigint", "int2", "int8", "bigserial": diff --git a/internal/utils/common.go b/internal/utils/common.go new file mode 100644 index 0000000..e747d24 --- /dev/null +++ b/internal/utils/common.go @@ -0,0 +1,5 @@ +package utils + +func ToPtr[T any](t T) *T { + return &t +} diff --git a/plugin/main.go b/plugin/main.go index 3abf2c9..015065e 100644 --- a/plugin/main.go +++ b/plugin/main.go @@ -6,5 +6,5 @@ import ( ) func main() { - codegen.Run(python.Generate) + codegen.Run(python.Handler) } diff --git a/sqlc.yaml b/sqlc.yaml index a8e10ce..937e3fc 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: python wasm: url: file://sqlc-gen-better-python.wasm - sha256: ee7bd0c07b784b80ea8c5853d9a6a04c51a7abbfd2663f470e9a5d2f623b967e + sha256: 4b9794407f2b48fc9711d2f90636b291e7973689e4ffe7bf9248d1ba37d40d59 sql: - schema: test/schema.sql queries: test/queries.sql diff --git a/test/driver_aiosqlite/sqlc-gen-better-python.wasm b/test/driver_aiosqlite/sqlc-gen-better-python.wasm index 56caaa4..59159bf 100644 Binary files a/test/driver_aiosqlite/sqlc-gen-better-python.wasm and b/test/driver_aiosqlite/sqlc-gen-better-python.wasm differ diff --git a/test/driver_aiosqlite/sqlc.yaml b/test/driver_aiosqlite/sqlc.yaml index 157a046..ad7ddc1 100644 --- a/test/driver_aiosqlite/sqlc.yaml +++ b/test/driver_aiosqlite/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: python wasm: url: file://sqlc-gen-better-python.wasm - sha256: ee7bd0c07b784b80ea8c5853d9a6a04c51a7abbfd2663f470e9a5d2f623b967e + sha256: 4b9794407f2b48fc9711d2f90636b291e7973689e4ffe7bf9248d1ba37d40d59 sql: - schema: schema.sql queries: queries.sql diff --git a/test/driver_asyncpg/sqlc-gen-better-python.wasm b/test/driver_asyncpg/sqlc-gen-better-python.wasm index 56caaa4..59159bf 100644 Binary files a/test/driver_asyncpg/sqlc-gen-better-python.wasm and b/test/driver_asyncpg/sqlc-gen-better-python.wasm differ diff --git a/test/driver_asyncpg/sqlc.yaml b/test/driver_asyncpg/sqlc.yaml index 90e1aee..90ada3c 100644 --- a/test/driver_asyncpg/sqlc.yaml +++ b/test/driver_asyncpg/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: python wasm: url: file://sqlc-gen-better-python.wasm - sha256: ee7bd0c07b784b80ea8c5853d9a6a04c51a7abbfd2663f470e9a5d2f623b967e + sha256: 4b9794407f2b48fc9711d2f90636b291e7973689e4ffe7bf9248d1ba37d40d59 sql: - schema: schema.sql queries: queries.sql diff --git a/test/driver_sqlite3/sqlc-gen-better-python.wasm b/test/driver_sqlite3/sqlc-gen-better-python.wasm index 56caaa4..59159bf 100644 Binary files a/test/driver_sqlite3/sqlc-gen-better-python.wasm and b/test/driver_sqlite3/sqlc-gen-better-python.wasm differ diff --git a/test/driver_sqlite3/sqlc.yaml b/test/driver_sqlite3/sqlc.yaml index 25560b7..130e55b 100644 --- a/test/driver_sqlite3/sqlc.yaml +++ b/test/driver_sqlite3/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: python wasm: url: file://sqlc-gen-better-python.wasm - sha256: ee7bd0c07b784b80ea8c5853d9a6a04c51a7abbfd2663f470e9a5d2f623b967e + sha256: 4b9794407f2b48fc9711d2f90636b291e7973689e4ffe7bf9248d1ba37d40d59 sql: - schema: schema.sql queries: queries.sql diff --git a/test/queries.sql b/test/queries.sql index edfe618..b22c986 100644 --- a/test/queries.sql +++ b/test/queries.sql @@ -38,3 +38,7 @@ SELECT * FROM test_postgres_types; -- name: TTTT :one SELECT serial_test FROM test_postgres_types LIMIT 1; + +-- name: TestEnum :exec +INSERT INTO test_enum (id, m) +VALUES ($1, $2); diff --git a/test/schema.sql b/test/schema.sql index b249f21..41f0ac9 100644 --- a/test/schema.sql +++ b/test/schema.sql @@ -6,10 +6,21 @@ CREATE TABLE test_postgres_types timestamp_test timestamp NOT NULL ); +CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy'); + CREATE TABLE test_inner_postgres_types ( /* ───────────── Integer family ───────────── */ table_id int NOT NULL, /* ───────────── Boolean ───────────── */ bool_test boolean NOT NULL -); \ No newline at end of file +); + +CREATE TABLE test_enum +( + /* ───────────── Integer family ───────────── */ + id int PRIMARY KEY NOT NULL, + m mood NOT NULL +); + +