diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..2739289 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,164 @@ +version: "2" +linters: + # Default set of linters. + # The value can be: + # - `standard`: https://golangci-lint.run/docs/linters/#enabled-by-default + # - `all`: enables all linters by default. + # - `none`: disables all linters by default. + # - `fast`: enables only linters considered as "fast" (`golangci-lint help linters --json | jq '[ .[] | select(.fast==true) ] | map(.name)'`). + # Default: standard + default: none + enable: + - asciicheck + - bidichk + - bodyclose + - containedctx + # - contextcheck + - copyloopvar + # - cyclop + - decorder + - depguard + - dogsled + - dupl + - durationcheck + - embeddedstructfieldcheck + - errcheck + - errchkjson + - errname + - errorlint + - exhaustive + - exptostd + - fatcontext + - forbidigo + - forcetypeassert + - funcorder + - ginkgolinter + - gochecksumtype + - goconst + - gocritic + - godoclint + - godot + - godox + - gosec + - gosmopolitan + - govet + - grouper + - iface + - importas + - inamedparam + - ineffassign + - interfacebloat + - intrange + - iotamixing + - loggercheck + - makezero + - mirror + - misspell + - mnd + - musttag + - nakedret + - nestif + - nilerr + - nilnesserr + - nilnil + - nlreturn + - noctx + - nolintlint + - nonamedreturns + - nosprintfhostport + - paralleltest + - perfsprint + - prealloc + - predeclared + - reassign + - recvcheck + - rowserrcheck + - sqlclosecheck + - staticcheck + - tagalign + - tagliatelle + - tparallel + - unconvert + - unparam + - unused + - usestdlibvars + - wastedassign + - whitespace + - exhaustruct + settings: + exhaustruct: + # List of regular expressions to match type names that should be excluded from processing. + # Anonymous structs can be matched by '' alias. + # Has precedence over `include`. + # Each regular expression must match the full type name, including package path. + # For example, to match type `net/http.Cookie` regular expression should be `.*/http\.Cookie`, + # but not `http\.Cookie`. + # Default: [] + exclude: [] + allow-empty: true + tagliatelle: + case: + rules: + json: snake + yaml: snake + nonamedreturns: + report-error-in-defer: true + errcheck: + check-type-assertions: true + check-blank: true + disable-default-exclusions: true + exhaustive: + default-signifies-exhaustive: true + depguard: + # Rules to apply. + # + # Variables: + # - File Variables + # Use an exclamation mark `!` to negate a variable. + # Example: `!$test` matches any file that is not a go test file. + # + # `$all` - matches all go files + # `$test` - matches all go test files + # + # - Package Variables + # + # `$gostd` - matches all of go's standard library (Pulled from `GOROOT`) + # + # Default (applies if no custom rules are defined): Only allow $gostd in all files. + rules: + main: + list-mode: strict + # List of file globs that will match this list of settings to compare against. + # By default, if a path is relative, it is relative to the directory where the golangci-lint command is executed. + # The placeholder '${base-path}' is substituted with a path relative to the mode defined with `run.relative-path-mode`. + # The placeholder '${config-path}' is substituted with a path relative to the configuration file. + # Default: $all + files: + - "$all" + - "!$test" + # List of allowed packages. + # Entries can be a variable (starting with $), a string prefix, or an exact match (if ending with $). + # Default: [] + allow: + - "$gostd" + - "github.com/sqlc-dev/plugin-sdk-go" + - "github.com/jinzhu/inflection" + - "github.com/rayakame/sqlc-gen-better-python" + - "golang.org/x/text/cases" + - "golang.org/x/text/language" +formatters: + # Enable specific formatter. + # Default: [] (uses standard Go formatting) + enable: + - gci + - swaggo + - gofmt + - gofumpt + - goimports + - golines + + settings: + golines: + max-len: 130 + chain-split-dots: false + diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..eb1d199 --- /dev/null +++ b/Makefile @@ -0,0 +1,25 @@ +-include .env +export + +.PHONY: pipelines tests fmt lint lint-fix changelog +.DEFAULT_GOAL := pipelines + +pipelines: + make lint-fix + make fmt + make lint + +tests: + go test -shuffle=on ./... + +fmt: + golangci-lint fmt + +lint: + golangci-lint run + +lint-fix: + golangci-lint run --no-config --default=none --fix -E godot,intrange,misspell,nlreturn,perfsprint,tagalign + +changelog: + go tool changie new diff --git a/internal/builders.go b/internal/builders.go deleted file mode 100644 index c091558..0000000 --- a/internal/builders.go +++ /dev/null @@ -1,439 +0,0 @@ -package internal - -import ( - "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/core" - "github.com/rayakame/sqlc-gen-better-python/internal/inflection" - "github.com/sqlc-dev/plugin-sdk-go/metadata" - "github.com/sqlc-dev/plugin-sdk-go/plugin" - "github.com/sqlc-dev/plugin-sdk-go/sdk" - "sort" - "strings" -) - -func (gen *PythonGenerator) buildTable(schema *plugin.Schema, table *plugin.Table) core.Table { - var tableName string - if schema.Name == gen.req.Catalog.DefaultSchema { - tableName = table.Rel.Name - } else { - tableName = schema.Name + "_" + table.Rel.Name - } - structName := tableName - if !gen.config.EmitExactTableNames { - structName = inflection.Singular(inflection.SingularParams{ - Name: structName, - Exclusions: gen.config.InflectionExcludeTableNames, - }) - } - t := core.Table{ - Table: &plugin.Identifier{Schema: schema.Name, Name: table.Rel.Name}, - Name: core.SnakeToCamel(structName, gen.config), - Comment: table.Comment, - } - for i, column := range table.Columns { - t.Columns = append(t.Columns, core.Column{ - Name: core.ColumnName(column, i), - Type: gen.makePythonType(column), - Comment: column.Comment, - }) - } - return t -} - -func (gen *PythonGenerator) buildTables() []core.Table { - tables := make([]core.Table, 0) - for _, schema := range gen.req.Catalog.Schemas { - if schema.Name == "pg_catalog" || schema.Name == "information_schema" { - continue - } - for _, table := range schema.Tables { - t := gen.buildTable(schema, table) - tables = append(tables, t) - } - } - if len(tables) > 0 { - sort.Slice(tables, func(i, j int) bool { return tables[i].Name < tables[j].Name }) - } - return tables -} - -func (gen *PythonGenerator) makePythonType(col *plugin.Column) core.PyType { - columnType := sdk.DataType(col.Type) - strType := gen.typeConversionFunc(gen.req, col, gen.config) - for _, override := range gen.config.Overrides { - if override.PyTypeName == "" { - continue - } - cname := col.Name - if col.OriginalName != "" { - cname = col.OriginalName - } - sameTable := override.Matches(col.Table, gen.req.Catalog.DefaultSchema) - if override.Column != "" && override.ColumnName.MatchString(cname) && sameTable { - return core.PyType{ - SqlType: columnType, - Type: override.PyTypeName, - DefaultType: strType, - IsNullable: !col.NotNull, - IsList: col.GetIsArray() || col.GetIsSqlcSlice(), - IsEnum: false, - IsOverride: true, - Override: &override, - } - } - if override.DBType != "" && override.DBType == columnType { - return core.PyType{ - SqlType: columnType, - Type: override.PyTypeName, - DefaultType: strType, - IsNullable: !col.NotNull, - IsList: col.GetIsArray() || col.GetIsSqlcSlice(), - IsEnum: false, - IsOverride: true, - Override: &override, - } - } - } - return core.PyType{ - SqlType: columnType, - Type: strType, - DefaultType: strType, - IsNullable: !col.NotNull, - IsList: col.GetIsArray() || col.GetIsSqlcSlice(), - IsEnum: false, - } -} - -func (gen *PythonGenerator) buildEnums() []core.Enum { - var enums []core.Enum - for _, schema := range gen.req.Catalog.Schemas { - if schema.Name == "pg_catalog" || schema.Name == "information_schema" { - continue - } - for _, enum := range schema.Enums { - var enumName string - if schema.Name == gen.req.Catalog.DefaultSchema { - enumName = enum.Name - } else { - enumName = schema.Name + "_" + enum.Name - } - - e := core.Enum{ - Name: core.SnakeToCamel(enumName, gen.config), - Comment: enum.Comment, - } - - seen := make(map[string]struct{}, len(enum.Vals)) - for i, v := range enum.Vals { - value := core.EnumReplace(v) - if _, found := seen[value]; found || value == "" { - value = fmt.Sprintf("value_%d", i) - } - e.Constants = append(e.Constants, core.Constant{ - Name: core.SnakeToCamel(enumName+"_"+value, gen.config), - Value: v, - Type: e.Name, - }) - seen[value] = struct{}{} - } - enums = append(enums, e) - } - } - if len(enums) > 0 { - sort.Slice(enums, func(i, j int) bool { return enums[i].Name < enums[j].Name }) - } - return enums -} - -type goColumn struct { - id int - *plugin.Column - embed *goEmbed -} - -type goEmbed struct { - modelType string - modelName string - fields []core.Column -} - -var cmdReturnsData = map[string]struct{}{ - metadata.CmdBatchMany: {}, - metadata.CmdBatchOne: {}, - metadata.CmdMany: {}, - metadata.CmdOne: {}, -} - -func putOutColumns(query *plugin.Query) bool { - _, found := cmdReturnsData[query.Cmd] - return found -} - -// look through all the structs and attempt to find a matching one to embed -// We need the name of the struct and its field names. -func newGoEmbed(embed *plugin.Identifier, structs []core.Table, defaultSchema string) *goEmbed { - if embed == nil { - return nil - } - - for _, s := range structs { - embedSchema := defaultSchema - if embed.Schema != "" { - embedSchema = embed.Schema - } - - // compare the other attributes - if embed.Catalog != s.Table.Catalog || embed.Name != s.Table.Name || embedSchema != s.Table.Schema { - continue - } - - fields := make([]core.Column, len(s.Columns)) - for i, f := range s.Columns { - fields[i] = f - } - return &goEmbed{ - modelType: s.Name, - modelName: s.Name, - fields: fields, - } - } - - return nil -} - -func (gen *PythonGenerator) buildQueries(tables []core.Table) ([]core.Query, error) { - qs := make([]core.Query, 0, len(gen.req.Queries)) - for _, query := range gen.req.Queries { - if query.Name == "" { - continue - } - if query.Cmd == "" { - continue - } - - constantName := core.UpperSnakeCase(query.Name) - - comments := query.Comments - - gq := core.Query{ - Cmd: query.Cmd, - ConstantName: constantName, - FuncName: strings.ToLower(constantName), - FieldName: sdk.LowerTitle(query.Name) + "Stmt", - MethodName: query.Name, - SourceName: query.Filename, - SQL: query.Text, - Comments: comments, - Table: query.InsertIntoTable, - } - - //qpl := int(*gen.config.QueryParameterLimit) TODO maybe? - - //if len(query.Params) == 1 && qpl != 0 { - if query.Cmd == metadata.CmdCopyFrom { - var cols []goColumn - for _, p := range query.Params { - cols = append(cols, goColumn{ - id: int(p.Number), - Column: p.Column, - }) - } - s, err := gen.columnsToStruct(gq.MethodName+"Params", cols, true) - if err != nil { - return nil, err - } - gq.Args = []core.QueryValue{{ - Emit: true, - Name: "params", - Table: s, - Typ: core.PyType{ - Type: gq.MethodName + "Params", - }, - }} - } else { - if len(query.Params) == 1 { - p := query.Params[0] - gq.Args = []core.QueryValue{{ - Name: core.Escape(core.ParamName(p)), - DBName: p.Column.GetName(), - Typ: gen.makePythonType(p.Column), - Column: p.Column, - }} - } else if len(query.Params) >= 1 { - var values []core.QueryValue - for _, p := range query.Params { - values = append(values, core.QueryValue{ - Name: core.Escape(core.ParamName(p)), - DBName: p.Column.GetName(), - Typ: gen.makePythonType(p.Column), - Column: p.Column, - }) - } - gq.Args = values - - // if query params is 2, and query params limit is 4 AND this is a copyfrom, we still want to emit the query's model - // otherwise we end up with a copyfrom using a struct without the struct definition - //if len(query.Params) <= qpl && query.Cmd != ":copyfrom" { - // gq.Args.Emit = false - //} - } - } - - if len(query.Columns) == 1 && query.Columns[0].EmbedTable == nil { - c := query.Columns[0] - name := core.ColumnName(c, 0) - name = strings.Replace(name, "$", "_", -1) - gq.Ret = core.QueryValue{ - Name: core.Escape(name), - DBName: name, - Typ: gen.makePythonType(c), - } - } else if putOutColumns(query) { - var gs *core.Table - var emit bool - - for _, s := range tables { - if len(s.Columns) != len(query.Columns) { - continue - } - same := true - for i, f := range s.Columns { - c := query.Columns[i] - sameName := f.Name == core.ColumnName(c, i) - sameType := f.Type.Type == gen.makePythonType(c).Type - sameTable := sdk.SameTableName(c.Table, s.Table, gen.req.Catalog.DefaultSchema) - if !sameName || !sameType || !sameTable { - same = false - } - } - if same { - gs = &s - break - } - } - - if gs == nil { - var columns []goColumn - for i, c := range query.Columns { - columns = append(columns, goColumn{ - id: i, - Column: c, - embed: newGoEmbed(c.EmbedTable, tables, gen.req.Catalog.DefaultSchema), - }) - } - var err error - gs, err = gen.columnsToStruct(gq.MethodName+"Row", columns, true) - if err != nil { - return nil, err - } - emit = true - } - gq.Ret = core.QueryValue{ - Emit: emit, - Name: "i", - Table: gs, - } - } - - qs = append(qs, gq) - } - sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName }) - return qs, nil -} - -func (gen *PythonGenerator) columnsToStruct(name string, columns []goColumn, useID bool) (*core.Table, error) { - gs := core.Table{ - Name: name, - } - seen := map[string][]int{} - suffixes := map[int]int{} - for i, c := range columns { - colName := core.ColumnName(c.Column, i) - - // override col/tag with expected model name - if c.embed != nil { - colName = c.embed.modelName - } - - fieldName := core.SnakeToCamel(colName, gen.config) - baseFieldName := fieldName - // Track suffixes by the ID of the column, so that columns referring to the same numbered parameter can be - // reused. - suffix := 0 - if o, ok := suffixes[c.id]; ok && useID { - suffix = o - } else if v := len(seen[fieldName]); v > 0 && !c.IsNamedParam { - suffix = v + 1 - } - suffixes[c.id] = suffix - if suffix > 0 { - fieldName = fmt.Sprintf("%s_%d", fieldName, suffix) - } - - f := core.Column{ - Name: inflection.Singular(inflection.SingularParams{ - Name: core.ColumnName(c.Column, i), - Exclusions: gen.config.InflectionExcludeTableNames, - }), - DBName: colName, - Column: c.Column, - } - - if c.embed == nil { - f.Type = gen.makePythonType(c.Column) - } else { - f.Type = core.PyType{ - SqlType: c.embed.modelType, - Type: "models." + c.embed.modelType, - IsList: false, - IsNullable: false, - IsEnum: false, - } - f.EmbedFields = c.embed.fields - } - - gs.Columns = append(gs.Columns, f) - if _, found := seen[baseFieldName]; !found { - seen[baseFieldName] = []int{i} - } else { - seen[baseFieldName] = append(seen[baseFieldName], i) - } - } - - // If a field does not have a known type, but another - // field with the same name has a known type, assign - // the known type to the field without a known type - /*for i, field := range gs.Columns { - if len(seen[field.Name]) > 1 && field.Type.Type == "interface{}" { - for _, j := range seen[field.Name] { - if i == j { - continue - } - otherField := gs.Fields[j] - if otherField.Type != field.Type { - field.Type = otherField.Type - } - gs.Fields[i] = field - } - } - }*/ - - err := checkIncompatibleFieldTypes(gs.Columns) - if err != nil { - return nil, err - } - - return &gs, nil -} - -func checkIncompatibleFieldTypes(fields []core.Column) error { - fieldTypes := map[string]string{} - for _, field := range fields { - if fieldType, found := fieldTypes[field.Name]; !found { - fieldTypes[field.Name] = field.Type.Type - } else if field.Type.Type != fieldType { - return fmt.Errorf("named param %s has incompatible types: %s, %s", field.Name, field.Type.Type, fieldType) - } - } - return nil -} diff --git a/internal/codegen/builders/docstrings.go b/internal/codegen/builders/docstrings.go deleted file mode 100644 index 19976e1..0000000 --- a/internal/codegen/builders/docstrings.go +++ /dev/null @@ -1,764 +0,0 @@ -package builders - -import ( - "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/core" - "github.com/sqlc-dev/plugin-sdk-go/metadata" -) - -var docstringConfig *string -var docstringConfigEmitSQL *bool -var docstringConfigDriver core.SQLDriverType = core.SQLDriverAsyncpg - -func SetDocstringConfig(c *string, b *bool, d core.SQLDriverType) { - docstringConfig = c - docstringConfigEmitSQL = b - docstringConfigDriver = d -} - -func (b *IndentStringBuilder) WriteQueryResultsIterDocstring() { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteIndentedLine(2, `"""`+"Initialize iteration support.") - b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { - b.WriteIndentedLine(2, "Returns") - b.WriteIndentedLine(2, "-------") - b.WriteIndentedLine(2, "QueryResults[T]") - b.WriteIndentedLine(3, "Self as an iterator.") - } else if *docstringConfig == core.DocstringConventionGoogle { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(3, "Self as an iterator.") - } else if *docstringConfig == core.DocstringConventionPEP257 { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(2, "Self as an iterator.") - } - b.WriteIndentedLine(2, `"""`) -} - -func (b *IndentStringBuilder) WriteQueryResultsAiterDocstring() { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteIndentedLine(2, `"""`+"Initialize iteration support for `async for`.") - b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { - b.WriteIndentedLine(2, "Returns") - b.WriteIndentedLine(2, "-------") - b.WriteIndentedLine(2, "QueryResults[T]") - b.WriteIndentedLine(3, "Self as an asynchronous iterator.") - } else if *docstringConfig == core.DocstringConventionGoogle { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(3, "Self as an asynchronous iterator.") - } else if *docstringConfig == core.DocstringConventionPEP257 { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(2, "Self as an asynchronous iterator.") - } - b.WriteIndentedLine(2, `"""`) -} - -func (b *IndentStringBuilder) WriteQueryResultsNextDocstringSqlite() { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteIndentedLine(2, `"""Yield the next item in the query result using a sqlite3 cursor.`) - b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { - b.WriteIndentedLine(2, "Returns") - b.WriteIndentedLine(2, "-------") - b.WriteIndentedLine(2, "T") - b.WriteIndentedLine(3, "The next decoded result.") - b.NewLine() - b.WriteIndentedLine(2, "Raises") - b.WriteIndentedLine(2, "------") - b.WriteIndentedLine(2, "StopIteration") - b.WriteIndentedLine(3, "When no more records are available.") - } else if *docstringConfig == core.DocstringConventionGoogle { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(3, "The next decoded result of type `T`.") - b.NewLine() - b.WriteIndentedLine(2, "Raises:") - b.WriteIndentedLine(3, "StopIteration: When no more records are available.") - } else if *docstringConfig == core.DocstringConventionPEP257 { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(2, "The next decoded result of type `T`.") - b.NewLine() - b.WriteIndentedLine(2, "Raises:") - b.WriteIndentedLine(2, "StopIteration -- When no more records are available.") - } - b.WriteIndentedLine(2, `"""`) -} - -func (b *IndentStringBuilder) WriteQueryResultsAnextDocstringAiosqlite() { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteIndentedLine(2, `"""Yield the next item in the query result using an aiosqlite cursor.`) - b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { - b.WriteIndentedLine(2, "Returns") - b.WriteIndentedLine(2, "-------") - b.WriteIndentedLine(2, "T") - b.WriteIndentedLine(3, "The next decoded result.") - b.NewLine() - b.WriteIndentedLine(2, "Raises") - b.WriteIndentedLine(2, "------") - b.WriteIndentedLine(2, "StopAsyncIteration") - b.WriteIndentedLine(3, "When no more records are available.") - } else if *docstringConfig == core.DocstringConventionGoogle { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(3, "The next decoded result of type `T`.") - b.NewLine() - b.WriteIndentedLine(2, "Raises:") - b.WriteIndentedLine(3, "StopAsyncIteration: When no more records are available.") - } else if *docstringConfig == core.DocstringConventionPEP257 { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(2, "The next decoded result of type `T`.") - b.NewLine() - b.WriteIndentedLine(2, "Raises:") - b.WriteIndentedLine(2, "StopAsyncIteration -- When no more records are available.") - } - b.WriteIndentedLine(2, `"""`) -} - -func (b *IndentStringBuilder) WriteQueryResultsAnextDocstringAsyncpg() { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteIndentedLine(2, `"""Yield the next item in the query result using an asyncpg cursor.`) - b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { - b.WriteIndentedLine(2, "Returns") - b.WriteIndentedLine(2, "-------") - b.WriteIndentedLine(2, "T") - b.WriteIndentedLine(3, "The next decoded result.") - b.NewLine() - b.WriteIndentedLine(2, "Raises") - b.WriteIndentedLine(2, "------") - b.WriteIndentedLine(2, "StopAsyncIteration") - b.WriteIndentedLine(3, "When no more records are available.") - } else if *docstringConfig == core.DocstringConventionGoogle { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(3, "The next decoded result of type `T`.") - b.NewLine() - b.WriteIndentedLine(2, "Raises:") - b.WriteIndentedLine(3, "StopAsyncIteration: When no more records are available.") - } else if *docstringConfig == core.DocstringConventionPEP257 { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(2, "The next decoded result of type `T`.") - b.NewLine() - b.WriteIndentedLine(2, "Raises:") - b.WriteIndentedLine(2, "StopAsyncIteration -- When no more records are available.") - } - b.WriteIndentedLine(2, `"""`) -} - -func (b *IndentStringBuilder) WriteQueryResultsAwaitDocstring() { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteIndentedLine(2, `"""`+"Allow `await` on the object to return all rows as a fully decoded sequence.") - b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { - b.WriteIndentedLine(2, "Returns") - b.WriteIndentedLine(2, "-------") - b.WriteIndentedLine(2, "collections.abc.Sequence[T]") - b.WriteIndentedLine(3, "A sequence of decoded objects of type `T`.") - } else if *docstringConfig == core.DocstringConventionGoogle { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(3, "A sequence of decoded objects of type `T`.") - } else if *docstringConfig == core.DocstringConventionPEP257 { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(2, "A sequence of decoded objects of type `T`.") - } - b.WriteIndentedLine(2, `"""`) -} - -func (b *IndentStringBuilder) WriteQueryResultsCallDocstring() { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteIndentedLine(2, `"""`+"Allow calling the object to return all rows as a fully decoded sequence.") - b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { - b.WriteIndentedLine(2, "Returns") - b.WriteIndentedLine(2, "-------") - b.WriteIndentedLine(2, "collections.abc.Sequence[T]") - b.WriteIndentedLine(3, "A sequence of decoded objects of type `T`.") - } else if *docstringConfig == core.DocstringConventionGoogle { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(3, "A sequence of decoded objects of type `T`.") - } else if *docstringConfig == core.DocstringConventionPEP257 { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(2, "A sequence of decoded objects of type `T`.") - } - b.WriteIndentedLine(2, `"""`) -} - -func (b *IndentStringBuilder) WriteQueryResultsInitDocstring(docstringConnType string, docstringDriverReturnType string) { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteIndentedString(2, fmt.Sprintf(`"""Initialize the QueryResults instance.`)) - if *docstringConfig == core.DocstringConventionNumpy { - b.WriteLine(`"""`) - } else if *docstringConfig == core.DocstringConventionGoogle { - b.NNewLine(2) - b.WriteIndentedLine(2, "Args:") - b.WriteIndentedLine(3, "conn:") - b.WriteIndentedLine(4, fmt.Sprintf("The connection object of type `%s` used to execute queries.", docstringConnType)) - b.WriteIndentedLine(3, "sql:") - b.WriteIndentedLine(4, "The SQL statement that will be executed when fetching/iterating.") - b.WriteIndentedLine(3, "decode_hook:") - b.WriteIndentedLine(4, fmt.Sprintf("A callback that turns an `%s` object into `T` that will be returned.", docstringDriverReturnType)) - b.WriteIndentedLine(3, "*args:") - b.WriteIndentedLine(4, "Arguments that should be sent when executing the sql query.") - b.WriteIndentedLine(2, `"""`) - } else if *docstringConfig == core.DocstringConventionPEP257 { - b.NNewLine(2) - b.WriteIndentedLine(2, "Arguments:") - b.WriteIndentedLine(2, fmt.Sprintf("conn -- The connection object of type `%s` used to execute queries.", docstringConnType)) - b.WriteIndentedLine(2, "sql -- The SQL statement that will be executed when fetching/iterating.") - b.WriteIndentedLine(2, fmt.Sprintf("decode_hook -- A callback that turns an `%s` object into `T` that will be returned.", docstringDriverReturnType)) - b.WriteIndentedLine(2, "*args -- Arguments that should be sent when executing the sql query.") - b.WriteIndentedLine(2, `"""`) - } -} - -func (b *IndentStringBuilder) WriteQueryResultsClassDocstring(docstringConnType string, docstringDriverReturnType string) { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteIndentedString(1, `"""Helper class that allows both iteration and normal fetching of data from the db.`) - if *docstringConfig == core.DocstringConventionNumpy { - b.NewLine() - b.NewLine() - b.WriteIndentedLine(1, "Parameters") - b.WriteIndentedLine(1, "----------") - b.WriteIndentedLine(1, "conn") - b.WriteIndentedLine(2, fmt.Sprintf("The connection object of type `%s` used to execute queries.", docstringConnType)) - b.WriteIndentedLine(1, "sql") - b.WriteIndentedLine(2, "The SQL statement that will be executed when fetching/iterating.") - b.WriteIndentedLine(1, "decode_hook") - b.WriteIndentedLine(2, fmt.Sprintf("A callback that turns an `%s` object into `T` that will be returned.", docstringDriverReturnType)) - b.WriteIndentedLine(1, "*args") - b.WriteIndentedLine(2, "Arguments that should be sent when executing the sql query.") - b.NewLine() - b.WriteIndentedLine(1, `"""`) - } else { - b.WriteLine(`"""`) - } - b.NewLine() -} - -func (b *IndentStringBuilder) WriteQueryClassConnDocstring(docstringConnType string) { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteIndentedLine(2, `"""Connection object used to make queries.`) - b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { - b.WriteIndentedLine(2, "Returns") - b.WriteIndentedLine(2, "-------") - b.WriteIndentedLine(2, docstringConnType) - b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(3, fmt.Sprintf("Connection object of type `%s` used to make queries.", docstringConnType)) - } else if *docstringConfig == core.DocstringConventionPEP257 { - b.WriteIndentedLine(2, "Returns:") - b.WriteIndentedLine(2, fmt.Sprintf("%s -- Connection object used to make queries.", docstringConnType)) - } - b.WriteIndentedLine(2, `"""`) -} - -func (b *IndentStringBuilder) WriteQueryClassDocstring(sourceName string, docstringConnType string) { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteIndentedString(1, fmt.Sprintf(`"""Queries from file %s.`, sourceName)) - if *docstringConfig == core.DocstringConventionNumpy { - b.NewLine() - b.NewLine() - b.WriteIndentedLine(1, "Parameters") - b.WriteIndentedLine(1, "----------") - b.WriteIndentedLine(1, fmt.Sprintf("conn : %s", docstringConnType)) - b.WriteIndentedLine(2, "The connection object used to execute queries.") - b.NewLine() - b.WriteIndentedLine(1, `"""`) - } else { - b.WriteLine(`"""`) - } - b.NewLine() -} - -func (b *IndentStringBuilder) WriteQueryClassInitDocstring(lvl int, docstringConnType string) { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteIndentedString(lvl, fmt.Sprintf(`"""Initialize the instance using the connection.`)) - if *docstringConfig == core.DocstringConventionNumpy { - b.WriteLine(`"""`) - } else if *docstringConfig == core.DocstringConventionGoogle { - b.NNewLine(2) - b.WriteIndentedLine(lvl, "Args:") - b.WriteIndentedLine(lvl+1, "conn:") - b.WriteIndentedLine(lvl+2, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - b.WriteIndentedLine(lvl, `"""`) - } else if *docstringConfig == core.DocstringConventionPEP257 { - b.NNewLine(2) - b.WriteIndentedLine(lvl, "Arguments:") - b.WriteIndentedLine(lvl, fmt.Sprintf("conn -- Connection object of type `%s` used to execute queries.", docstringConnType)) - b.WriteIndentedLine(lvl, `"""`) - } -} - -func (b *IndentStringBuilder) WriteModelClassDocstring(table *core.Table) { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteIndentedLine(1, `"""`+fmt.Sprintf("Model representing %s.", table.Name)) - if *docstringConfig == core.DocstringConventionNumpy { - b.NewLine() - b.WriteIndentedLine(1, "Attributes") - b.WriteIndentedLine(1, "----------") - for _, col := range table.Columns { - type_ := col.Type.Type - if col.Type.IsList { - type_ = "collections.abc.Sequence[" + type_ + "]" - } - if col.Type.IsNullable { - type_ = type_ + " | None" - } - b.WriteIndentedLine(1, fmt.Sprintf("%s : %s", col.Name, type_)) - } - b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { - b.NewLine() - b.WriteIndentedLine(1, "Attributes:") - for _, col := range table.Columns { - type_ := col.Type.Type - if col.Type.IsList { - type_ = "collections.abc.Sequence[" + type_ + "]" - } - if col.Type.IsNullable { - type_ = type_ + " | None" - } - b.WriteIndentedLine(2, fmt.Sprintf("%s: %s", col.Name, type_)) - } - } else if *docstringConfig == core.DocstringConventionPEP257 { - b.NewLine() - b.WriteIndentedLine(1, "Attributes:") - for _, col := range table.Columns { - type_ := col.Type.Type - if col.Type.IsList { - type_ = "collections.abc.Sequence[" + type_ + "]" - } - if col.Type.IsNullable { - type_ = type_ + " | None" - } - b.WriteIndentedLine(1, fmt.Sprintf("%s -- %s", col.Name, type_)) - } - } - b.WriteIndentedLine(1, `"""`) - b.NewLine() -} - -func (b *IndentStringBuilder) WriteModelFileModuleDocstring() { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteLine(`"""Module containing models."""`) -} - -func (b *IndentStringBuilder) WriteQueryFileModuleDocstring(sourceName string) { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteLine(fmt.Sprintf(`"""Module containing queries from file %s."""`, sourceName)) -} - -func (b *IndentStringBuilder) WriteInitFileModuleDocstring() { - if *docstringConfig == core.DocstringConventionNone { - return - } - b.WriteLine(`"""Package containing queries and models automatically generated using sqlc-gen-better-python."""`) -} - -func (b *IndentStringBuilder) writeQueryFunctionSQL(lvl int, query *core.Query) { - if *docstringConfigEmitSQL { - b.WriteIndentedLine(lvl, "```sql") - for _, line := range core.SplitLines(query.SQL) { - b.WriteIndentedLine(lvl, line) - } - b.WriteIndentedLine(lvl, "```") - b.NewLine() - } -} - -func (b *IndentStringBuilder) WriteQueryFunctionDocstring(lvl int, query *core.Query, docstringConnType string, queryArgs []core.FunctionArg, returnType core.PyType) { - if *docstringConfig == core.DocstringConventionNone { - return - } - - if query.Cmd == metadata.CmdExec { - b.WriteIndentedLine(lvl, `"""`+fmt.Sprintf("Execute SQL query with `name: %s %s`.", query.MethodName, query.Cmd)) - b.NewLine() - b.writeQueryFunctionSQL(lvl, query) - if len(queryArgs) == 0 && docstringConnType == "" { - b.WriteIndentedLine(lvl, `"""`) - return - } - if *docstringConfig == core.DocstringConventionNumpy { - b.WriteIndentedLine(lvl, "Parameters") - b.WriteIndentedLine(lvl, "----------") - if docstringConnType != "" { - b.WriteIndentedLine(lvl, fmt.Sprintf("conn : %s", docstringConnType)) - b.WriteIndentedLine(lvl+1, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl, fmt.Sprintf("%s : %s", arg.Name, arg.Type)) - } - b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { - b.WriteIndentedLine(lvl, "Args:") - if docstringConnType != "" { - b.WriteIndentedLine(lvl+1, "conn:") - b.WriteIndentedLine(lvl+2, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl+1, fmt.Sprintf("%s: %s.", arg.Name, arg.Type)) - } - } else if *docstringConfig == core.DocstringConventionPEP257 { - b.WriteIndentedLine(lvl, "Arguments:") - if docstringConnType != "" { - b.WriteIndentedLine(lvl, fmt.Sprintf("conn -- Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl, fmt.Sprintf("%s -- %s.", arg.Name, arg.Type)) - } - } - b.WriteIndentedLine(lvl, `"""`) - } else if query.Cmd == metadata.CmdExecRows { - b.WriteIndentedLine(lvl, `"""`+fmt.Sprintf("Execute SQL query with `name: %s %s` and return the number of affected rows.", query.MethodName, query.Cmd)) - b.NewLine() - b.writeQueryFunctionSQL(lvl, query) - if *docstringConfig == core.DocstringConventionNumpy { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Parameters") - b.WriteIndentedLine(lvl, "----------") - if docstringConnType != "" { - b.WriteIndentedLine(lvl, fmt.Sprintf("conn : %s", docstringConnType)) - b.WriteIndentedLine(lvl+1, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl, fmt.Sprintf("%s : %s", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns") - b.WriteIndentedLine(lvl, "-------") - b.WriteIndentedLine(lvl, returnType.Type) - if docstringConfigDriver == core.SQLDriverAioSQLite { - b.WriteIndentedLine(lvl+1, "The number of affected rows. This will be -1 for queries like `CREATE TABLE`.") - } else { - b.WriteIndentedLine(lvl+1, "The number of affected rows. This will be 0 for queries like `CREATE TABLE`.") - } - b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Args:") - if docstringConnType != "" { - b.WriteIndentedLine(lvl+1, "conn:") - b.WriteIndentedLine(lvl+2, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl+1, fmt.Sprintf("%s: %s.", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns:") - if docstringConfigDriver == core.SQLDriverAioSQLite { - b.WriteIndentedLine(lvl+1, fmt.Sprintf("The number (`%s`) of affected rows. This will be -1 for queries like `CREATE TABLE`.", returnType.Type)) - } else { - b.WriteIndentedLine(lvl+1, fmt.Sprintf("The number (`%s`) of affected rows. This will be 0 for queries like `CREATE TABLE`.", returnType.Type)) - } - } else if *docstringConfig == core.DocstringConventionPEP257 { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Arguments:") - if docstringConnType != "" { - b.WriteIndentedLine(lvl, fmt.Sprintf("conn -- Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl, fmt.Sprintf("%s -- %s.", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns:") - if docstringConfigDriver == core.SQLDriverAioSQLite { - b.WriteIndentedLine(lvl+1, fmt.Sprintf("%s -- The number of affected rows. This will be -1 for queries like `CREATE TABLE`.", returnType.Type)) - } else { - b.WriteIndentedLine(lvl+1, fmt.Sprintf("%s -- The number of affected rows. This will be 0 for queries like `CREATE TABLE`.", returnType.Type)) - } - } - b.WriteIndentedLine(lvl, `"""`) - } else if query.Cmd == metadata.CmdCopyFrom { - b.WriteIndentedLine(lvl, `"""`+fmt.Sprintf("Execute COPY FROM query to insert rows into a table with `name: %s %s` and return the number of affected rows.", query.MethodName, query.Cmd)) - b.NewLine() - if *docstringConfig == core.DocstringConventionNumpy { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Parameters") - b.WriteIndentedLine(lvl, "----------") - if docstringConnType != "" { - b.WriteIndentedLine(lvl, fmt.Sprintf("conn : %s", docstringConnType)) - b.WriteIndentedLine(lvl+1, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl, fmt.Sprintf("%s : %s", arg.Name, arg.Type)) - b.WriteIndentedLine(lvl+1, "A list of params for rows that should be inserted.") - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns") - b.WriteIndentedLine(lvl, "-------") - b.WriteIndentedLine(lvl, returnType.Type) - b.WriteIndentedLine(lvl+1, "The number of affected rows.") - b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Args:") - if docstringConnType != "" { - b.WriteIndentedLine(lvl+1, "conn:") - b.WriteIndentedLine(lvl+2, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl+1, fmt.Sprintf("%s: %s.", arg.Name, arg.Type)) - b.WriteIndentedLine(lvl+2, "A list of params for rows that should be inserted.") - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns:") - b.WriteIndentedLine(lvl+1, fmt.Sprintf("The number (`%s`) of affected rows.", returnType.Type)) - } else if *docstringConfig == core.DocstringConventionPEP257 { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Arguments:") - if docstringConnType != "" { - b.WriteIndentedLine(lvl, fmt.Sprintf("conn -- Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl, fmt.Sprintf("%s -- %s. A list of params for rows that should be inserted.", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns:") - b.WriteIndentedLine(lvl, fmt.Sprintf("%s -- The number of affected rows.", returnType.Type)) - } - b.WriteIndentedLine(lvl, `"""`) - } else if query.Cmd == metadata.CmdExecResult { - b.WriteIndentedLine(lvl, `"""`+fmt.Sprintf("Execute and return the result of SQL query with `name: %s %s`.", query.MethodName, query.Cmd)) - b.NewLine() - b.writeQueryFunctionSQL(lvl, query) - if *docstringConfig == core.DocstringConventionNumpy { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Parameters") - b.WriteIndentedLine(lvl, "----------") - if docstringConnType != "" { - b.WriteIndentedLine(lvl, fmt.Sprintf("conn : %s", docstringConnType)) - b.WriteIndentedLine(lvl+1, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl, fmt.Sprintf("%s : %s", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns") - b.WriteIndentedLine(lvl, "-------") - b.WriteIndentedLine(lvl, returnType.Type) - b.WriteIndentedLine(lvl+1, "The result returned when executing the query.") - b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Args:") - if docstringConnType != "" { - b.WriteIndentedLine(lvl+1, "conn:") - b.WriteIndentedLine(lvl+2, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl+1, fmt.Sprintf("%s: %s.", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns:") - b.WriteIndentedLine(lvl+1, fmt.Sprintf("The result of type `%s` returned when executing the query.", returnType.Type)) - } else if *docstringConfig == core.DocstringConventionPEP257 { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Arguments:") - if docstringConnType != "" { - b.WriteIndentedLine(lvl, fmt.Sprintf("conn -- Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl, fmt.Sprintf("%s -- %s.", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns:") - b.WriteIndentedLine(lvl, fmt.Sprintf("%s -- The result returned when executing the query.", returnType.Type)) - } - b.WriteIndentedLine(lvl, `"""`) - } else if query.Cmd == metadata.CmdExecLastId { - b.WriteIndentedLine(lvl, `"""`+fmt.Sprintf("Execute SQL query with `name: %s %s` and return the id of the last affected row.", query.MethodName, query.Cmd)) - b.NewLine() - b.writeQueryFunctionSQL(lvl, query) - if *docstringConfig == core.DocstringConventionNumpy { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Parameters") - b.WriteIndentedLine(lvl, "----------") - if docstringConnType != "" { - b.WriteIndentedLine(lvl, fmt.Sprintf("conn : %s", docstringConnType)) - b.WriteIndentedLine(lvl+1, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl, fmt.Sprintf("%s : %s", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns") - b.WriteIndentedLine(lvl, "-------") - b.WriteIndentedLine(lvl, returnType.Type) - b.WriteIndentedLine(lvl+1, "The id of the last affected row. Will be `None` if no rows are affected.") - b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Args:") - if docstringConnType != "" { - b.WriteIndentedLine(lvl+1, "conn:") - b.WriteIndentedLine(lvl+2, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl+1, fmt.Sprintf("%s: %s.", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns:") - b.WriteIndentedLine(lvl+1, fmt.Sprintf("The id (`%s`) of the last affected row. Will be `None` if no rows are affected.", returnType.Type)) - } else if *docstringConfig == core.DocstringConventionPEP257 { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Arguments:") - if docstringConnType != "" { - b.WriteIndentedLine(lvl, fmt.Sprintf("conn -- Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl, fmt.Sprintf("%s -- %s.", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns:") - b.WriteIndentedLine(lvl, fmt.Sprintf("%s -- The id of the last affected row. Will be `None` if no rows are affected.", returnType.Type)) - } - b.WriteIndentedLine(lvl, `"""`) - } else if query.Cmd == metadata.CmdOne { - b.WriteIndentedLine(lvl, `"""`+fmt.Sprintf("Fetch one from the db using the SQL query with `name: %s %s`.", query.MethodName, query.Cmd)) - b.NewLine() - b.writeQueryFunctionSQL(lvl, query) - if *docstringConfig == core.DocstringConventionNumpy { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Parameters") - b.WriteIndentedLine(lvl, "----------") - if docstringConnType != "" { - b.WriteIndentedLine(lvl, fmt.Sprintf("conn : %s", docstringConnType)) - b.WriteIndentedLine(lvl+1, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl, fmt.Sprintf("%s : %s", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns") - b.WriteIndentedLine(lvl, "-------") - b.WriteIndentedLine(lvl, returnType.Type) - b.WriteIndentedLine(lvl+1, "Result fetched from the db. Will be `None` if not found.") - b.NewLine() - - } else if *docstringConfig == core.DocstringConventionGoogle { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Args:") - if docstringConnType != "" { - b.WriteIndentedLine(lvl+1, "conn:") - b.WriteIndentedLine(lvl+2, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl+1, fmt.Sprintf("%s: %s.", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns:") - b.WriteIndentedLine(lvl+1, fmt.Sprintf("Result of type `%s` fetched from the db. Will be `None` if not found.", returnType.Type)) - } else if *docstringConfig == core.DocstringConventionPEP257 { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Arguments:") - if docstringConnType != "" { - b.WriteIndentedLine(lvl, fmt.Sprintf("conn -- Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl, fmt.Sprintf("%s -- %s.", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns:") - b.WriteIndentedLine(lvl, fmt.Sprintf("%s -- Result fetched from the db. Will be `None` if not found.", returnType.Type)) - } - b.WriteIndentedLine(lvl, `"""`) - } else if query.Cmd == metadata.CmdMany { - b.WriteIndentedLine(lvl, `"""`+fmt.Sprintf("Fetch many from the db using the SQL query with `name: %s %s`.", query.MethodName, query.Cmd)) - b.NewLine() - b.writeQueryFunctionSQL(lvl, query) - if *docstringConfig == core.DocstringConventionNumpy { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Parameters") - b.WriteIndentedLine(lvl, "----------") - if docstringConnType != "" { - b.WriteIndentedLine(lvl, fmt.Sprintf("conn : %s", docstringConnType)) - b.WriteIndentedLine(lvl+1, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl, fmt.Sprintf("%s : %s", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns") - b.WriteIndentedLine(lvl, "-------") - b.WriteIndentedLine(lvl, fmt.Sprintf("QueryResults[%s]", returnType.Type)) - b.WriteIndentedLine(lvl+1, "Helper class that allows both iteration and normal fetching of data from the db.") - b.NewLine() - } else if *docstringConfig == core.DocstringConventionGoogle { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Args:") - if docstringConnType != "" { - b.WriteIndentedLine(lvl+1, "conn:") - b.WriteIndentedLine(lvl+2, fmt.Sprintf("Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl+1, fmt.Sprintf("%s: %s.", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns:") - b.WriteIndentedLine(lvl+1, fmt.Sprintf("Helper class of type `QueryResults[%s]` that allows both iteration and normal fetching of data from the db.", returnType.Type)) - } else if *docstringConfig == core.DocstringConventionPEP257 { - if len(queryArgs) != 0 || docstringConnType != "" { - b.WriteIndentedLine(lvl, "Arguments:") - if docstringConnType != "" { - b.WriteIndentedLine(lvl, fmt.Sprintf("conn -- Connection object of type `%s` used to execute the query.", docstringConnType)) - } - for _, arg := range queryArgs { - b.WriteIndentedLine(lvl, fmt.Sprintf("%s -- %s.", arg.Name, arg.Type)) - } - b.NewLine() - } - b.WriteIndentedLine(lvl, "Returns:") - b.WriteIndentedLine(lvl, fmt.Sprintf("QueryResults[%s] -- Helper class that allows both iteration and normal fetching of data from the db.", returnType.Type)) - } - b.WriteIndentedLine(lvl, `"""`) - } -} diff --git a/internal/codegen/builders/query_results.go b/internal/codegen/builders/query_results.go deleted file mode 100644 index 4af9324..0000000 --- a/internal/codegen/builders/query_results.go +++ /dev/null @@ -1,84 +0,0 @@ -package builders - -import "fmt" - -func (b *IndentStringBuilder) WriteSyncQueryResultsClassHeader(connType string, initFields []string, driverReturnType string) { - b.WriteLine(`T = typing.TypeVar("T")`) - b.NNewLine(2) - b.WriteLine("class QueryResults(typing.Generic[T]):") - b.WriteQueryResultsClassDocstring(connType, driverReturnType) - b.WriteIndentedLine(1, `__slots__ = ("_args", "_conn", "_cursor", "_decode_hook", "_iterator", "_sql")`) - b.NewLine() - b.WriteIndentedLine(1, "def __init__(") - b.WriteIndentedLine(2, "self,") - b.WriteIndentedLine(2, fmt.Sprintf("conn: %s,", connType)) - b.WriteIndentedLine(2, "sql: str,") - b.WriteIndentedLine(2, fmt.Sprintf("decode_hook: collections.abc.Callable[[%s], T],", driverReturnType)) - b.WriteIndentedLine(2, "*args: QueryResultsArgsType,") - b.WriteIndentedLine(1, ") -> None:") - b.WriteQueryResultsInitDocstring(connType, driverReturnType) - b.WriteIndentedLine(2, "self._conn = conn") - b.WriteIndentedLine(2, "self._sql = sql") - b.WriteIndentedLine(2, "self._decode_hook = decode_hook") - b.WriteIndentedLine(2, "self._args = args") - for _, line := range initFields { - b.WriteIndentedLine(2, line) - } - b.NewLine() - b.WriteIndentedLine(1, "def __iter__(self) -> QueryResults[T]:") - b.WriteQueryResultsIterDocstring() - b.WriteIndentedLine(2, "return self") - b.NewLine() -} - -func (b *IndentStringBuilder) WriteAsyncQueryResultsClassHeader(connType string, initFields []string, driverReturnType string) { - b.WriteLine(`T = typing.TypeVar("T")`) - b.NNewLine(2) - b.WriteLine("class QueryResults(typing.Generic[T]):") - b.WriteQueryResultsClassDocstring(connType, driverReturnType) - b.WriteIndentedLine(1, `__slots__ = ("_args", "_conn", "_cursor", "_decode_hook", "_iterator", "_sql")`) - b.NewLine() - b.WriteIndentedLine(1, "def __init__(") - b.WriteIndentedLine(2, "self,") - b.WriteIndentedLine(2, fmt.Sprintf("conn: %s,", connType)) - b.WriteIndentedLine(2, "sql: str,") - b.WriteIndentedLine(2, fmt.Sprintf("decode_hook: collections.abc.Callable[[%s], T],", driverReturnType)) - b.WriteIndentedLine(2, "*args: QueryResultsArgsType,") - b.WriteIndentedLine(1, ") -> None:") - b.WriteQueryResultsInitDocstring(connType, driverReturnType) - b.WriteIndentedLine(2, "self._conn = conn") - b.WriteIndentedLine(2, "self._sql = sql") - b.WriteIndentedLine(2, "self._decode_hook = decode_hook") - b.WriteIndentedLine(2, "self._args = args") - for _, line := range initFields { - b.WriteIndentedLine(2, line) - } - b.NewLine() - b.WriteIndentedLine(1, "def __aiter__(self) -> QueryResults[T]:") - b.WriteQueryResultsAiterDocstring() - b.WriteIndentedLine(2, "return self") - b.NewLine() -} - -func (b *IndentStringBuilder) WriteQueryResultsCallFunction(wrapperLines []string) { - b.WriteIndentedLine(1, "def __call__(") - b.WriteIndentedLine(2, "self,") - b.WriteIndentedLine(1, ") -> collections.abc.Sequence[T]:") - b.WriteQueryResultsCallDocstring() - for _, line := range wrapperLines { - b.WriteIndentedLine(2, line) - } -} - -func (b *IndentStringBuilder) WriteQueryResultsAwaitFunction(wrapperLines []string) { - b.WriteIndentedLine(1, "def __await__(") - b.WriteIndentedLine(2, "self,") - b.WriteIndentedLine(1, ") -> collections.abc.Generator[None, None, collections.abc.Sequence[T]]:") - b.WriteQueryResultsAwaitDocstring() - b.WriteIndentedLine(2, "async def _wrapper() -> collections.abc.Sequence[T]:") - for _, line := range wrapperLines { - b.WriteIndentedLine(3, line) - } - b.WriteIndentedLine(2, "return _wrapper().__await__()") - -} diff --git a/internal/codegen/builders/string.go b/internal/codegen/builders/string.go index 4b17299..9ebf02b 100644 --- a/internal/codegen/builders/string.go +++ b/internal/codegen/builders/string.go @@ -1,38 +1,43 @@ package builders import ( - "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/core" "os" "strings" + + "github.com/rayakame/sqlc-gen-better-python/internal/config" + "github.com/rayakame/sqlc-gen-better-python/internal/log" ) type IndentStringBuilder struct { - strings.Builder + builder strings.Builder indentChar string charsPerIndentLevel int + + docstringOmitSQL bool + docstringConvention config.DocstringConvention + docstringDriver config.SQLDriver } -func NewIndentStringBuilder(indentChar string, charsPerIndentLevel int) *IndentStringBuilder { +func NewIndentStringBuilder( + indentChar string, + charsPerIndentLevel int, + docstringConvention config.DocstringConvention, + docstringOmitSQL bool, + docstringDriver config.SQLDriver, +) *IndentStringBuilder { return &IndentStringBuilder{ + builder: strings.Builder{}, indentChar: indentChar, charsPerIndentLevel: charsPerIndentLevel, + docstringConvention: docstringConvention, + docstringDriver: docstringDriver, + docstringOmitSQL: docstringOmitSQL, } } -func (b *IndentStringBuilder) WriteQueryFunctionArgs(args []core.FunctionArg, conf *core.Config) { - for i, arg := range args { - if i == 0 && len(args) > int(*conf.OmitKwargsLimit) { - b.WriteString(", *") - } - b.WriteString(fmt.Sprintf(", %s", arg.FunctionFormat)) - } -} - -func (b *IndentStringBuilder) WriteIndentedString(level int, txt string) int { - count, _ := b.WriteString(strings.Repeat(b.indentChar, level*b.charsPerIndentLevel) + txt) - return count +func (b *IndentStringBuilder) WriteIndentedString(level int, txt string) { + b.WriteString(strings.Repeat(b.indentChar, level*b.charsPerIndentLevel) + txt) } func (b *IndentStringBuilder) WriteSqlcHeader() { @@ -41,7 +46,14 @@ func (b *IndentStringBuilder) WriteSqlcHeader() { b.WriteString("# Code generated by sqlc. DO NOT EDIT.\n") b.WriteString("# versions:\n") b.WriteString("# sqlc " + sqlcVersion + "\n") - b.WriteString("# sqlc-gen-better-python " + core.PluginVersion + "\n") + b.WriteString("# sqlc-gen-better-python " + config.PluginVersion + "\n") +} + +func (b *IndentStringBuilder) WriteString(txt string) { + _, err := b.builder.WriteString(txt) + if err != nil { + log.L().LogErr("Error while trying to write string", err) + } } func (b *IndentStringBuilder) WriteLine(txt string) { @@ -68,3 +80,7 @@ func (b *IndentStringBuilder) NNewLine(n int) { b.WriteString("\n") } } + +func (b *IndentStringBuilder) Bytes() []byte { + return []byte(b.builder.String()) +} diff --git a/internal/codegen/common.go b/internal/codegen/common.go deleted file mode 100644 index e302e4f..0000000 --- a/internal/codegen/common.go +++ /dev/null @@ -1,81 +0,0 @@ -package codegen - -import ( - "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" - "github.com/rayakame/sqlc-gen-better-python/internal/codegen/drivers" - "github.com/rayakame/sqlc-gen-better-python/internal/core" -) - -type TypeBuildPyQueryFunc func(*core.Query, *builders.IndentStringBuilder, []core.FunctionArg, core.PyType, *core.Config) error -type TypeAcceptedDriverCMDs func() []string -type TypeDriverTypeCheckingHook func() []string -type TypeDriverBuildQueryResults func(*builders.IndentStringBuilder) string - -func defaultDriverTypeCheckingHook() []string { - return nil -} -func defaultDriverBuildQueryResults(_ *builders.IndentStringBuilder) string { - return "" -} - -type Driver struct { - conf *core.Config - - connType string - buildPyQueryFunc TypeBuildPyQueryFunc - acceptedDriverCMDs TypeAcceptedDriverCMDs - - driverTypeCheckingHook TypeDriverTypeCheckingHook - driverBuildQueryResults TypeDriverBuildQueryResults - - //BuildPyQueriesFiles(*core.Importer, []core.Query) ([]*plugin.File, error) -} - -func NewDriver(conf *core.Config) (*Driver, error) { - var buildPyQueryFunc TypeBuildPyQueryFunc - var acceptedDriverCMDs TypeAcceptedDriverCMDs - var connType string - var driverTypeCheckingHook TypeDriverTypeCheckingHook = defaultDriverTypeCheckingHook - var driverBuildQueryResults TypeDriverBuildQueryResults = defaultDriverBuildQueryResults - switch conf.SqlDriver { - case core.SQLDriverAioSQLite: - buildPyQueryFunc = drivers.AioSQLiteBuildPyQueryFunc - acceptedDriverCMDs = drivers.AioSQLiteAcceptedDriverCMDs - connType = drivers.AioSQLiteConn - driverBuildQueryResults = drivers.AiosqliteBuildQueryResults - case core.SQLDriverSQLite: - buildPyQueryFunc = drivers.SQLite3BuildPyQueryFunc - acceptedDriverCMDs = drivers.SQLite3AcceptedDriverCMDs - connType = drivers.SQLite3Conn - driverBuildQueryResults = drivers.SQLite3BuildQueryResults - case core.SQLDriverAsyncpg: - buildPyQueryFunc = drivers.AsyncpgBuildPyQueryFunc - acceptedDriverCMDs = drivers.AsyncpgAcceptedDriverCMDs - connType = drivers.AsyncpgConn - driverTypeCheckingHook = drivers.AsyncpgTypeCheckingHook - driverBuildQueryResults = drivers.AsyncpgBuildQueryResults - default: - return nil, fmt.Errorf("unsupported driver: %s", conf.SqlDriver.String()) - } - builders.SetDocstringConfig(conf.EmitDocstrings, conf.EmitDocstringsSQL, conf.SqlDriver) - - return &Driver{ - buildPyQueryFunc: buildPyQueryFunc, - acceptedDriverCMDs: acceptedDriverCMDs, - conf: conf, - connType: connType, - driverTypeCheckingHook: driverTypeCheckingHook, - driverBuildQueryResults: driverBuildQueryResults, - }, nil -} - -func (dr *Driver) supportedCMD(command string) error { - cmds := dr.acceptedDriverCMDs() - for _, cmd := range cmds { - if cmd == command { - return nil - } - } - return fmt.Errorf("unsupported command for selected driver: %s", command) -} diff --git a/internal/codegen/drivers/aiosqlite.go b/internal/codegen/drivers/aiosqlite.go deleted file mode 100644 index fe634fa..0000000 --- a/internal/codegen/drivers/aiosqlite.go +++ /dev/null @@ -1,316 +0,0 @@ -package drivers - -import ( - "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" - "github.com/rayakame/sqlc-gen-better-python/internal/core" - "github.com/rayakame/sqlc-gen-better-python/internal/typeConversion" - "github.com/rayakame/sqlc-gen-better-python/internal/types" - "github.com/sqlc-dev/plugin-sdk-go/metadata" - "github.com/sqlc-dev/plugin-sdk-go/plugin" - "strconv" - "strings" -) - -const AioSQLiteConn = "aiosqlite.Connection" - -func AioSQLiteBuildTypeConvFunc(queries []core.Query, body *builders.IndentStringBuilder, conf *core.Config) { - // this function fucking got out of hand - queryValueUses := func(name string, qv core.QueryValue) bool { - if !qv.IsEmpty() { - if qv.IsStruct() && qv.EmitStruct() { - if val, pyType := core.TableUses(name, *qv.Table); val { - if pyType.DoConversion(typeConversion.SqliteDoTypeConversion) { - return true - } - } - } else if qv.IsStruct() { - if val, pyType := core.TableUses(name, *qv.Table); val { - if pyType.DoConversion(typeConversion.SqliteDoTypeConversion) { - return true - } - } - } else { - if qv.Typ.Type == name { - if qv.Typ.DoConversion(typeConversion.SqliteDoTypeConversion) { - return true - } - } - } - } - return false - } - toConvert := make(map[string]bool) - for _, query := range queries { - for sqlType, _ := range typeConversion.SqliteGetConversions() { - name := types.SqliteTypeToPython(&plugin.GenerateRequest{}, &plugin.Column{Type: &plugin.Identifier{ - Catalog: "", - Schema: "", - Name: sqlType, - }}, conf) - if queryValueUses(name, query.Ret) { - toConvert[name] = true - } - for _, arg := range query.Args { - if queryValueUses(name, arg) { - toConvert[name] = true - } - } - } - } - adapters := make([]string, 0) - converters := make([]string, 0) - if _, found := toConvert["datetime.date"]; found { - body.WriteLine("def _adapt_date(val: datetime.date) -> str:") - body.WriteIndentedLine(1, "return val.isoformat()") - body.NNewLine(2) - adapters = append(adapters, "aiosqlite.register_adapter(datetime.date, _adapt_date)") - body.WriteLine("def _convert_date(val: bytes) -> datetime.date:") - if conf.Speedups { - body.WriteIndentedLine(1, "return ciso8601.parse_datetime(val.decode()).date()") - } else { - body.WriteIndentedLine(1, "return datetime.date.fromisoformat(val.decode())") - } - body.NNewLine(2) - converters = append(converters, `aiosqlite.register_converter("date", _convert_date)`) - } - if _, found := toConvert["decimal.Decimal"]; found { - body.WriteLine("def _adapt_decimal(val: decimal.Decimal) -> str:") - body.WriteIndentedLine(1, "return str(val)") - body.NNewLine(2) - adapters = append(adapters, "aiosqlite.register_adapter(decimal.Decimal, _adapt_decimal)") - body.WriteLine("def _convert_decimal(val: bytes) -> decimal.Decimal:") - body.WriteIndentedLine(1, "return decimal.Decimal(val.decode())") - body.NNewLine(2) - converters = append(converters, `aiosqlite.register_converter("decimal", _convert_decimal)`) - } - if _, found := toConvert["datetime.datetime"]; found { - body.WriteLine("def _adapt_datetime(val: datetime.datetime) -> str:") - body.WriteIndentedLine(1, "return val.isoformat()") - body.NNewLine(2) - adapters = append(adapters, "aiosqlite.register_adapter(datetime.datetime, _adapt_datetime)") - body.WriteLine("def _convert_datetime(val: bytes) -> datetime.datetime:") - if conf.Speedups { - body.WriteIndentedLine(1, "return ciso8601.parse_datetime(val.decode())") - } else { - body.WriteIndentedLine(1, "return datetime.datetime.fromisoformat(val.decode())") - } - body.NNewLine(2) - converters = append(converters, `aiosqlite.register_converter("datetime", _convert_datetime)`) - converters = append(converters, `aiosqlite.register_converter("timestamp", _convert_datetime)`) - } - if _, found := toConvert["bool"]; found { - body.WriteLine("def _adapt_bool(val: bool) -> int:") - body.WriteIndentedLine(1, "return int(val)") - body.NNewLine(2) - adapters = append(adapters, "aiosqlite.register_adapter(bool, _adapt_bool)") - body.WriteLine("def _convert_bool(val: bytes) -> bool:") - body.WriteIndentedLine(1, "return bool(int(val))") - body.NNewLine(2) - converters = append(converters, `aiosqlite.register_converter("bool", _convert_bool)`) - converters = append(converters, `aiosqlite.register_converter("boolean", _convert_bool)`) - } - if _, found := toConvert["memoryview"]; found { - body.WriteLine("def _adapt_memoryview(val: memoryview) -> bytes:") - body.WriteIndentedLine(1, "return val.tobytes()") - body.NNewLine(2) - adapters = append(adapters, "aiosqlite.register_adapter(memoryview, _adapt_memoryview)") - body.WriteLine("def _convert_memoryview(val: bytes) -> memoryview:") - body.WriteIndentedLine(1, "return memoryview(val)") - body.NNewLine(2) - converters = append(converters, `aiosqlite.register_converter("blob", _convert_memoryview)`) - } - for i, line := range adapters { - body.WriteLine(line) - if i == len(adapters)-1 { - body.NewLine() - } - } - for i, line := range converters { - body.WriteLine(line) - if i == len(converters)-1 { - body.NNewLine(2) - } - } -} - -func AiosqliteBuildQueryResults(body *builders.IndentStringBuilder) string { - body.WriteAsyncQueryResultsClassHeader(AioSQLiteConn, []string{ - "self._cursor: aiosqlite.Cursor | None = None", - fmt.Sprintf("self._iterator: collections.abc.AsyncIterator[%s] | None = None", Sqlite3Result), - }, Sqlite3Result) - body.WriteQueryResultsAwaitFunction([]string{ - "result = await (await self._conn.execute(self._sql, self._args)).fetchall()", - "return [self._decode_hook(row) for row in result]", - }) - body.NewLine() - body.WriteIndentedLine(1, "async def __anext__(self) -> T:") - body.WriteQueryResultsAnextDocstringAiosqlite() - body.WriteIndentedLine(2, "if self._cursor is None or self._iterator is None:") - body.WriteIndentedLine(3, "self._cursor: aiosqlite.Cursor | None = await self._conn.execute(self._sql, self._args)") - body.WriteIndentedLine(3, "self._iterator = self._cursor.__aiter__()") - body.WriteIndentedLine(2, "try:") - body.WriteIndentedLine(3, "record = await self._iterator.__anext__()") - body.WriteIndentedLine(2, "except StopAsyncIteration:") - body.WriteIndentedLine(3, "self._cursor = None") - body.WriteIndentedLine(3, "self._iterator = None") - body.WriteIndentedLine(3, "raise") - body.WriteIndentedLine(2, "return self._decode_hook(record)") - return "QueryResults" -} - -func AioSQLiteBuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, args []core.FunctionArg, retType core.PyType, conf *core.Config) error { - indentLevel := 0 - params := fmt.Sprintf("conn: %s", AioSQLiteConn) - conn := "conn" - asyncFunc := "async " - docstringConnType := AioSQLiteConn - if conf.EmitClasses { - params = "self" - conn = "self._conn" - indentLevel = 1 - docstringConnType = "" - } - if query.Cmd == metadata.CmdMany { - asyncFunc = "" - } - body.WriteIndentedString(indentLevel, fmt.Sprintf("%sdef %s(%s", asyncFunc, query.FuncName, params)) - body.WriteQueryFunctionArgs(args, conf) - if query.Cmd == metadata.CmdExec { - body.WriteLine(fmt.Sprintf(") -> %s:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("await %s.execute(%s", conn, query.ConstantName)) - sqlite3WriteParams(query, body) - body.WriteLine(")") - } else if query.Cmd == metadata.CmdExecResult { - body.WriteLine(fmt.Sprintf(") -> %s:", "aiosqlite.Cursor")) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, core.PyType{Type: "aiosqlite.Cursor"}) - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return await %s.execute(%s", conn, query.ConstantName)) - sqlite3WriteParams(query, body) - body.WriteLine(")") - } else if query.Cmd == metadata.CmdExecRows { - body.WriteLine(fmt.Sprintf(") -> %s:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return (await %s.execute(%s", conn, query.ConstantName)) - sqlite3WriteParams(query, body) - body.WriteLine(")).rowcount") - } else if query.Cmd == metadata.CmdExecLastId { - body.WriteLine(fmt.Sprintf(") -> %s:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return (await %s.execute(%s", conn, query.ConstantName)) - sqlite3WriteParams(query, body) - body.WriteLine(")).lastrowid") - } else if query.Cmd == metadata.CmdOne { - body.WriteLine(fmt.Sprintf(") -> %s | None:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("row = await (await %s.execute(%s", conn, query.ConstantName)) - sqlite3WriteParams(query, body) - body.WriteLine(")).fetchone()") - body.WriteIndentedLine(indentLevel+1, "if row is None:") - body.WriteIndentedLine(indentLevel+2, "return None") - if query.Ret.IsStruct() { - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return %s(", retType.Type)) - i := 0 - for _, col := range query.Ret.Table.Columns { - if i != 0 { - body.WriteString(", ") - } - if len(col.EmbedFields) != 0 { - var inner []string - body.WriteString(fmt.Sprintf("%s=%s(", col.Name, col.Type.Type)) - for _, embedCol := range col.EmbedFields { - if embedCol.Type.DoOverride() { - inner = append(inner, fmt.Sprintf("%s=%s(row[%s])", embedCol.Name, embedCol.Type.Type, strconv.Itoa(i))) - } else { - inner = append(inner, fmt.Sprintf("%s=row[%s]", embedCol.Name, strconv.Itoa(i))) - } - i++ - } - body.WriteString(strings.Join(inner, ", ") + ")") - } else { - if col.Type.DoOverride() { - body.WriteString(fmt.Sprintf("%s=%s(row[%s])", col.Name, col.Type.Type, strconv.Itoa(i))) - } else { - body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i))) - } - i++ - } - } - body.WriteLine(")") - } else { - if query.Ret.Typ.DoOverride() { - body.WriteIndentedLine(indentLevel+1, fmt.Sprintf("return %s(row[0])", query.Ret.Typ.Type)) - } else { - body.WriteIndentedLine(indentLevel+1, "return row[0]") - } - } - } else if query.Cmd == metadata.CmdMany { - body.WriteLine(fmt.Sprintf(") -> QueryResults[%s]:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - - decodeHook := "_decode_hook" - if !query.Ret.IsStruct() && !query.Ret.Typ.DoOverride() { - decodeHook = "operator.itemgetter(0)" - } else if !query.Ret.IsStruct() && query.Ret.Typ.DoOverride() { - body.WriteIndentedLine(indentLevel+1, fmt.Sprintf("def _decode_hook(row: %s) -> %s:", Sqlite3Result, retType.Type)) - body.WriteIndentedLine(indentLevel+2, fmt.Sprintf("return %s(row[0])", retType.Type)) - } else { - body.WriteIndentedLine(indentLevel+1, fmt.Sprintf("def _decode_hook(row: %s) -> %s:", Sqlite3Result, retType.Type)) - body.WriteIndentedString(indentLevel+2, fmt.Sprintf("return %s(", retType.Type)) - i := 0 - for _, col := range query.Ret.Table.Columns { - if i != 0 { - body.WriteString(", ") - } - if len(col.EmbedFields) != 0 { - var inner []string - body.WriteString(fmt.Sprintf("%s=%s(", col.Name, col.Type.Type)) - for _, embedCol := range col.EmbedFields { - if embedCol.Type.DoOverride() { - inner = append(inner, fmt.Sprintf("%s=%s(row[%s])", embedCol.Name, embedCol.Type.Type, strconv.Itoa(i))) - } else { - inner = append(inner, fmt.Sprintf("%s=row[%s]", embedCol.Name, strconv.Itoa(i))) - } - i++ - } - body.WriteString(strings.Join(inner, ", ") + ")") - } else { - if col.Type.DoOverride() { - body.WriteString(fmt.Sprintf("%s=%s(row[%s])", col.Name, col.Type.Type, strconv.Itoa(i))) - } else { - body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i))) - } - i++ - } - } - body.WriteLine(")") - } - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return QueryResults[%s](%s, %s, %s", retType.Type, conn, query.ConstantName, decodeHook)) - params := "" - for i, arg := range query.Args { - if !arg.IsEmpty() { - if i == len(query.Args)-1 { - params += fmt.Sprintf(" %s", arg.Name) - } else { - params += fmt.Sprintf(" %s,", arg.Name) - } - } - } - if params != "" { - body.WriteString("," + params) - } - body.WriteLine(")") - } - return nil -} - -func AioSQLiteAcceptedDriverCMDs() []string { - return []string{ - metadata.CmdExec, - metadata.CmdExecResult, - metadata.CmdExecLastId, - metadata.CmdExecRows, - metadata.CmdOne, - metadata.CmdMany, - } -} diff --git a/internal/codegen/drivers/asyncpg.go b/internal/codegen/drivers/asyncpg.go deleted file mode 100644 index 87f1e60..0000000 --- a/internal/codegen/drivers/asyncpg.go +++ /dev/null @@ -1,251 +0,0 @@ -package drivers - -import ( - "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" - "github.com/rayakame/sqlc-gen-better-python/internal/core" - "github.com/rayakame/sqlc-gen-better-python/internal/typeConversion" - "github.com/sqlc-dev/plugin-sdk-go/metadata" - "strconv" - "strings" -) - -const AsyncpgConn = "ConnectionLike" -const AsyncpgResult = "asyncpg.Record" - -func AsyncpgTypeCheckingHook() []string { - return []string{ - fmt.Sprintf( - "ConnectionLike: typing.TypeAlias = asyncpg.Connection[%[1]s] | asyncpg.pool.PoolConnectionProxy[%[1]s]", - AsyncpgResult, - ), - } -} - -func AsyncpgBuildQueryResults(body *builders.IndentStringBuilder) string { - body.WriteAsyncQueryResultsClassHeader(AsyncpgConn, []string{ - fmt.Sprintf("self._cursor: asyncpg.cursor.CursorFactory[%s] | None = None", AsyncpgResult), - fmt.Sprintf("self._iterator: asyncpg.cursor.CursorIterator[%s] | None = None", AsyncpgResult), - }, AsyncpgResult) - body.WriteQueryResultsAwaitFunction([]string{ - "result = await self._conn.fetch(self._sql, *self._args)", - "return [self._decode_hook(row) for row in result]", - }) - body.NewLine() - body.WriteIndentedLine(1, "async def __anext__(self) -> T:") - body.WriteQueryResultsAnextDocstringAsyncpg() - body.WriteIndentedLine(2, "if self._cursor is None or self._iterator is None:") - body.WriteIndentedLine(3, "self._cursor = self._conn.cursor(self._sql, *self._args)") - body.WriteIndentedLine(3, "self._iterator = self._cursor.__aiter__()") - body.WriteIndentedLine(2, "try:") - body.WriteIndentedLine(3, "record = await self._iterator.__anext__()") - body.WriteIndentedLine(2, "except StopAsyncIteration:") - body.WriteIndentedLine(3, "self._cursor = None") - body.WriteIndentedLine(3, "self._iterator = None") - body.WriteIndentedLine(3, "raise") - body.WriteIndentedLine(2, "return self._decode_hook(record)") - return "QueryResults" -} - -func AsyncpgBuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, args []core.FunctionArg, retType core.PyType, conf *core.Config) error { - indentLevel := 0 - params := fmt.Sprintf("conn: %s", AsyncpgConn) - conn := "conn" - asyncFunc := "async " - docstringConnType := AsyncpgConn - if conf.EmitClasses { - params = "self" - conn = "self._conn" - indentLevel = 1 - docstringConnType = "" - } - if query.Cmd == metadata.CmdMany { - asyncFunc = "" - } - body.WriteIndentedString(indentLevel, fmt.Sprintf("%sdef %s(%s", asyncFunc, query.FuncName, params)) - body.WriteQueryFunctionArgs(args, conf) - if query.Cmd == metadata.CmdExec { - body.WriteLine(fmt.Sprintf(") -> %s:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("await %s.execute(%s", conn, query.ConstantName)) - asyncpgWriteParams(query, body) - body.WriteLine(")") - } else if query.Cmd == metadata.CmdExecResult { - body.WriteLine(") -> str:") - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, core.PyType{Type: "str"}) - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return await %s.execute(%s", conn, query.ConstantName)) - asyncpgWriteParams(query, body) - body.WriteLine(")") - } else if query.Cmd == metadata.CmdExecRows { - body.WriteLine(fmt.Sprintf(") -> %s:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("r = await %s.execute(%s", conn, query.ConstantName)) - asyncpgWriteParams(query, body) - body.WriteLine(")") - if conf.Speedups { - body.WriteIndentedLine(indentLevel+1, "return int(n) if (n := r.split()[-1]).isdigit() else 0") - } else { - body.WriteIndentedLine(indentLevel+1, "return int(n) if (p := r.split()) and (n := p[-1]).isdigit() else 0") - } - } else if query.Cmd == metadata.CmdCopyFrom { - body.WriteLine(fmt.Sprintf(") -> %s:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - body.WriteIndentedLine(indentLevel+1, "records = [") - params := "" - columns := `` - for i, arg := range query.Args[0].Table.Columns { - if i == len(query.Args[0].Table.Columns)-1 && i != 0 { - params += fmt.Sprintf("%s.%s", "param", arg.Name) - columns += fmt.Sprintf(`"%s"`, arg.Name) - } else { - params += fmt.Sprintf("%s.%s, ", "param", arg.Name) - columns += fmt.Sprintf(`"%s", `, arg.Name) - } - } - body.WriteIndentedLine(indentLevel+2, fmt.Sprintf("(%s)", params)) - body.WriteIndentedLine(indentLevel+2, fmt.Sprintf("for param in %s", query.Args[0].Name)) - body.WriteIndentedLine(indentLevel+1, "]") - body.WriteIndentedLine(indentLevel+1, fmt.Sprintf(`r = await %s.copy_records_to_table("%s", columns=[%s], records=records)`, conn, query.Table.Name, columns)) - if conf.Speedups { - body.WriteIndentedLine(indentLevel+1, "return int(n) if (n := r.split()[-1]).isdigit() else 0") - } else { - body.WriteIndentedLine(indentLevel+1, "return int(n) if (p := r.split()) and (n := p[-1]).isdigit() else 0") - } - } else if query.Cmd == metadata.CmdOne { - body.WriteLine(fmt.Sprintf(") -> %s | None:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("row = await %s.fetchrow(%s", conn, query.ConstantName)) - asyncpgWriteParams(query, body) - body.WriteLine(")") - body.WriteIndentedLine(indentLevel+1, "if row is None:") - body.WriteIndentedLine(indentLevel+2, "return None") - if query.Ret.IsStruct() { - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return %s(", retType.Type)) - i := 0 - for _, col := range query.Ret.Table.Columns { - if i != 0 { - body.WriteString(", ") - } - if len(col.EmbedFields) != 0 { - var inner []string - body.WriteString(fmt.Sprintf("%s=%s(", col.Name, col.Type.Type)) - for _, embedCol := range col.EmbedFields { - if embedCol.Type.DoOverride() || embedCol.Type.DoConversion(typeConversion.AsyncpgDoTypeConversion) { - if embedCol.Type.IsNullable { - inner = append(inner, fmt.Sprintf("%s=%s(row[%s]) if row[%s] is not None else None", embedCol.Name, embedCol.Type.Type, strconv.Itoa(i), strconv.Itoa(i))) - } else { - inner = append(inner, fmt.Sprintf("%s=%s(row[%s])", embedCol.Name, embedCol.Type.Type, strconv.Itoa(i))) - } - } else { - inner = append(inner, fmt.Sprintf("%s=row[%s]", embedCol.Name, strconv.Itoa(i))) - } - i++ - } - body.WriteString(strings.Join(inner, ", ") + ")") - } else { - if col.Type.DoConversion(typeConversion.AsyncpgDoTypeConversion) || col.Type.DoOverride() { - if col.Type.IsNullable { - body.WriteString(fmt.Sprintf("%s=%s(row[%s]) if row[%s] is not None else None", col.Name, col.Type.Type, strconv.Itoa(i), strconv.Itoa(i))) - } else { - body.WriteString(fmt.Sprintf("%s=%s(row[%s])", col.Name, col.Type.Type, strconv.Itoa(i))) - } - } else { - body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i))) - } - i++ - } - } - body.WriteLine(")") - } else { - if retType.DoConversion(typeConversion.AsyncpgDoTypeConversion) || retType.DoOverride() { - body.WriteIndentedLine(indentLevel+1, fmt.Sprintf("return %s(row[0])", retType.Type)) - } else { - body.WriteIndentedLine(indentLevel+1, "return row[0]") - } - } - } else if query.Cmd == metadata.CmdMany { - body.WriteLine(fmt.Sprintf(") -> QueryResults[%s]:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - decode_hook := "_decode_hook" - if !query.Ret.IsStruct() && !(retType.DoConversion(typeConversion.AsyncpgDoTypeConversion) || retType.DoOverride()) { - decode_hook = "operator.itemgetter(0)" - } else if !query.Ret.IsStruct() { - body.WriteIndentedLine(indentLevel+1, fmt.Sprintf("def _decode_hook(row: %s) -> %s:", AsyncpgResult, retType.Type)) - body.WriteIndentedLine(indentLevel+2, fmt.Sprintf("return %s(row[0])", retType.Type)) - } else { - body.WriteIndentedLine(indentLevel+1, fmt.Sprintf("def _decode_hook(row: %s) -> %s:", AsyncpgResult, retType.Type)) - body.WriteIndentedString(indentLevel+2, fmt.Sprintf("return %s(", retType.Type)) - i := 0 - for _, col := range query.Ret.Table.Columns { - if i != 0 { - body.WriteString(", ") - } - if len(col.EmbedFields) != 0 { - var inner []string - body.WriteString(fmt.Sprintf("%s=%s(", col.Name, col.Type.Type)) - for _, embedCol := range col.EmbedFields { - if embedCol.Type.DoOverride() || embedCol.Type.DoConversion(typeConversion.AsyncpgDoTypeConversion) { - if embedCol.Type.IsNullable { - inner = append(inner, fmt.Sprintf("%s=%s(row[%s]) if row[%s] is not None else None", embedCol.Name, embedCol.Type.Type, strconv.Itoa(i), strconv.Itoa(i))) - } else { - inner = append(inner, fmt.Sprintf("%s=%s(row[%s])", embedCol.Name, embedCol.Type.Type, strconv.Itoa(i))) - } - } else { - inner = append(inner, fmt.Sprintf("%s=row[%s]", embedCol.Name, strconv.Itoa(i))) - } - i++ - } - body.WriteString(strings.Join(inner, ", ") + ")") - } else { - if col.Type.DoConversion(typeConversion.AsyncpgDoTypeConversion) || col.Type.DoOverride() { - if col.Type.IsNullable { - body.WriteString(fmt.Sprintf("%s=%s(row[%s]) if row[%s] is not None else None", col.Name, col.Type.Type, strconv.Itoa(i), strconv.Itoa(i))) - } else { - body.WriteString(fmt.Sprintf("%s=%s(row[%s])", col.Name, col.Type.Type, strconv.Itoa(i))) - } - } else { - body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i))) - } - i++ - } - } - body.WriteLine(")") - } - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return QueryResults[%s](%s, %s, %s", retType.Type, conn, query.ConstantName, decode_hook)) - asyncpgWriteParams(query, body) - body.WriteLine(")") - } - return nil -} - -func AsyncpgAcceptedDriverCMDs() []string { - return []string{ - metadata.CmdExec, - metadata.CmdExecResult, - metadata.CmdExecRows, - metadata.CmdOne, - metadata.CmdMany, - metadata.CmdCopyFrom, - } -} - -func asyncpgWriteParams(query *core.Query, body *builders.IndentStringBuilder) { - if len(query.Args) == 0 { - return - } - params := "" - for i, arg := range query.Args { - if !arg.IsEmpty() { - argName := arg.Name - if arg.Typ.DoOverride() { - argName = fmt.Sprintf("%s(%s)", arg.Typ.DefaultType, argName) - } - if i == len(query.Args)-1 { - params += fmt.Sprintf(" %s", argName) - } else { - params += fmt.Sprintf(" %s,", argName) - } - } - } - body.WriteString("," + params) -} diff --git a/internal/codegen/drivers/sqlite3.go b/internal/codegen/drivers/sqlite3.go deleted file mode 100644 index e43f78c..0000000 --- a/internal/codegen/drivers/sqlite3.go +++ /dev/null @@ -1,334 +0,0 @@ -package drivers - -import ( - "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" - "github.com/rayakame/sqlc-gen-better-python/internal/core" - "github.com/rayakame/sqlc-gen-better-python/internal/typeConversion" - "github.com/rayakame/sqlc-gen-better-python/internal/types" - "github.com/sqlc-dev/plugin-sdk-go/metadata" - "github.com/sqlc-dev/plugin-sdk-go/plugin" - "strconv" - "strings" -) - -const Sqlite3Result = "sqlite3.Row" -const SQLite3Conn = "sqlite3.Connection" - -func SQLite3BuildTypeConvFunc(queries []core.Query, body *builders.IndentStringBuilder, conf *core.Config) { - // this function fucking got out of hand - queryValueUses := func(name string, qv core.QueryValue) bool { - if !qv.IsEmpty() { - if qv.IsStruct() && qv.EmitStruct() { - if val, pyType := core.TableUses(name, *qv.Table); val { - if pyType.DoConversion(typeConversion.SqliteDoTypeConversion) { - return true - } - } - } else if qv.IsStruct() { - if val, pyType := core.TableUses(name, *qv.Table); val { - if pyType.DoConversion(typeConversion.SqliteDoTypeConversion) { - return true - } - } - } else { - if qv.Typ.Type == name { - if qv.Typ.DoConversion(typeConversion.SqliteDoTypeConversion) { - return true - } - } - } - } - return false - } - toConvert := make(map[string]bool) - for _, query := range queries { - for sqlType, _ := range typeConversion.SqliteGetConversions() { - name := types.SqliteTypeToPython(&plugin.GenerateRequest{}, &plugin.Column{Type: &plugin.Identifier{ - Catalog: "", - Schema: "", - Name: sqlType, - }}, conf) - if queryValueUses(name, query.Ret) { - toConvert[name] = true - } - for _, arg := range query.Args { - if queryValueUses(name, arg) { - toConvert[name] = true - } - } - } - } - adapters := make([]string, 0) - converters := make([]string, 0) - if _, found := toConvert["datetime.date"]; found { - body.WriteLine("def _adapt_date(val: datetime.date) -> str:") - body.WriteIndentedLine(1, "return val.isoformat()") - body.NNewLine(2) - adapters = append(adapters, "sqlite3.register_adapter(datetime.date, _adapt_date)") - body.WriteLine("def _convert_date(val: bytes) -> datetime.date:") - if conf.Speedups { - body.WriteIndentedLine(1, "return ciso8601.parse_datetime(val.decode()).date()") - } else { - body.WriteIndentedLine(1, "return datetime.date.fromisoformat(val.decode())") - } - body.NNewLine(2) - converters = append(converters, `sqlite3.register_converter("date", _convert_date)`) - } - if _, found := toConvert["decimal.Decimal"]; found { - body.WriteLine("def _adapt_decimal(val: decimal.Decimal) -> str:") - body.WriteIndentedLine(1, "return str(val)") - body.NNewLine(2) - adapters = append(adapters, "sqlite3.register_adapter(decimal.Decimal, _adapt_decimal)") - body.WriteLine("def _convert_decimal(val: bytes) -> decimal.Decimal:") - body.WriteIndentedLine(1, "return decimal.Decimal(val.decode())") - body.NNewLine(2) - converters = append(converters, `sqlite3.register_converter("decimal", _convert_decimal)`) - } - if _, found := toConvert["datetime.datetime"]; found { - body.WriteLine("def _adapt_datetime(val: datetime.datetime) -> str:") - body.WriteIndentedLine(1, "return val.isoformat()") - body.NNewLine(2) - adapters = append(adapters, "sqlite3.register_adapter(datetime.datetime, _adapt_datetime)") - body.WriteLine("def _convert_datetime(val: bytes) -> datetime.datetime:") - if conf.Speedups { - body.WriteIndentedLine(1, "return ciso8601.parse_datetime(val.decode())") - } else { - body.WriteIndentedLine(1, "return datetime.datetime.fromisoformat(val.decode())") - } - body.NNewLine(2) - converters = append(converters, `sqlite3.register_converter("datetime", _convert_datetime)`) - converters = append(converters, `sqlite3.register_converter("timestamp", _convert_datetime)`) - } - if _, found := toConvert["bool"]; found { - body.WriteLine("def _adapt_bool(val: bool) -> int:") - body.WriteIndentedLine(1, "return int(val)") - body.NNewLine(2) - adapters = append(adapters, "sqlite3.register_adapter(bool, _adapt_bool)") - body.WriteLine("def _convert_bool(val: bytes) -> bool:") - body.WriteIndentedLine(1, "return bool(int(val))") - body.NNewLine(2) - converters = append(converters, `sqlite3.register_converter("bool", _convert_bool)`) - converters = append(converters, `sqlite3.register_converter("boolean", _convert_bool)`) - } - if _, found := toConvert["memoryview"]; found { - body.WriteLine("def _adapt_memoryview(val: memoryview) -> bytes:") - body.WriteIndentedLine(1, "return val.tobytes()") - body.NNewLine(2) - adapters = append(adapters, "sqlite3.register_adapter(memoryview, _adapt_memoryview)") - body.WriteLine("def _convert_memoryview(val: bytes) -> memoryview:") - body.WriteIndentedLine(1, "return memoryview(val)") - body.NNewLine(2) - converters = append(converters, `sqlite3.register_converter("blob", _convert_memoryview)`) - } - for i, line := range adapters { - body.WriteLine(line) - if i == len(adapters)-1 { - body.NewLine() - } - } - for i, line := range converters { - body.WriteLine(line) - if i == len(converters)-1 { - body.NNewLine(2) - } - } -} - -func SQLite3BuildQueryResults(body *builders.IndentStringBuilder) string { - body.WriteSyncQueryResultsClassHeader(SQLite3Conn, []string{ - "self._cursor: sqlite3.Cursor | None = None", - fmt.Sprintf("self._iterator: collections.abc.Iterator[%s] | None = None", Sqlite3Result), - }, Sqlite3Result) - body.WriteQueryResultsCallFunction([]string{ - "result = self._conn.execute(self._sql, self._args).fetchall()", - "return [self._decode_hook(row) for row in result]", - }) - body.NewLine() - body.WriteIndentedLine(1, "def __next__(self) -> T:") - body.WriteQueryResultsNextDocstringSqlite() - body.WriteIndentedLine(2, "if self._cursor is None or self._iterator is None:") - body.WriteIndentedLine(3, "self._cursor: sqlite3.Cursor | None = self._conn.execute(self._sql, self._args)") - body.WriteIndentedLine(3, "self._iterator = self._cursor.__iter__()") - body.WriteIndentedLine(2, "try:") - body.WriteIndentedLine(3, "record = self._iterator.__next__()") - body.WriteIndentedLine(2, "except StopIteration:") - body.WriteIndentedLine(3, "self._cursor = None") - body.WriteIndentedLine(3, "self._iterator = None") - body.WriteIndentedLine(3, "raise") - body.WriteIndentedLine(2, "return self._decode_hook(record)") - return "QueryResults" -} - -func SQLite3BuildPyQueryFunc(query *core.Query, body *builders.IndentStringBuilder, args []core.FunctionArg, retType core.PyType, conf *core.Config) error { - indentLevel := 0 - params := fmt.Sprintf("conn: %s", SQLite3Conn) - conn := "conn" - docstringConnType := SQLite3Conn - if conf.EmitClasses { - params = "self" - conn = "self._conn" - indentLevel = 1 - docstringConnType = "" - } - body.WriteIndentedString(indentLevel, fmt.Sprintf("def %s(%s", query.FuncName, params)) - body.WriteQueryFunctionArgs(args, conf) - if query.Cmd == metadata.CmdExec { - body.WriteLine(fmt.Sprintf(") -> %s:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("%s.execute(%s", conn, query.ConstantName)) - sqlite3WriteParams(query, body) - body.WriteLine(")") - } else if query.Cmd == metadata.CmdExecResult { - body.WriteLine(fmt.Sprintf(") -> %s:", "sqlite3.Cursor")) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, core.PyType{Type: "sqlite3.Cursor"}) - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return %s.execute(%s", conn, query.ConstantName)) - sqlite3WriteParams(query, body) - body.WriteLine(")") - } else if query.Cmd == metadata.CmdExecRows { - body.WriteLine(fmt.Sprintf(") -> %s:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return %s.execute(%s", conn, query.ConstantName)) - sqlite3WriteParams(query, body) - body.WriteLine(").rowcount") - } else if query.Cmd == metadata.CmdExecLastId { - body.WriteLine(fmt.Sprintf(") -> %s:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return %s.execute(%s", conn, query.ConstantName)) - sqlite3WriteParams(query, body) - body.WriteLine(").lastrowid") - } else if query.Cmd == metadata.CmdOne { - body.WriteLine(fmt.Sprintf(") -> %s | None:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("row = %s.execute(%s", conn, query.ConstantName)) - sqlite3WriteParams(query, body) - body.WriteLine(").fetchone()") - body.WriteIndentedLine(indentLevel+1, "if row is None:") - body.WriteIndentedLine(indentLevel+2, "return None") - if query.Ret.IsStruct() { - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return %s(", retType.Type)) - i := 0 - for _, col := range query.Ret.Table.Columns { - if i != 0 { - body.WriteString(", ") - } - if len(col.EmbedFields) != 0 { - var inner []string - body.WriteString(fmt.Sprintf("%s=%s(", col.Name, col.Type.Type)) - for _, embedCol := range col.EmbedFields { - if embedCol.Type.DoOverride() { - inner = append(inner, fmt.Sprintf("%s=%s(row[%s])", embedCol.Name, embedCol.Type.Type, strconv.Itoa(i))) - } else { - inner = append(inner, fmt.Sprintf("%s=row[%s]", embedCol.Name, strconv.Itoa(i))) - } - i++ - } - body.WriteString(strings.Join(inner, ", ") + ")") - } else { - if col.Type.DoOverride() { - body.WriteString(fmt.Sprintf("%s=%s(row[%s])", col.Name, col.Type.Type, strconv.Itoa(i))) - } else { - body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i))) - } - i++ - } - } - body.WriteLine(")") - } else { - if query.Ret.Typ.DoOverride() { - body.WriteIndentedLine(indentLevel+1, fmt.Sprintf("return %s(row[0])", query.Ret.Typ.Type)) - } else { - body.WriteIndentedLine(indentLevel+1, "return row[0]") - } - } - } else if query.Cmd == metadata.CmdMany { - body.WriteLine(fmt.Sprintf(") -> QueryResults[%s]:", retType.Type)) - body.WriteQueryFunctionDocstring(indentLevel+1, query, docstringConnType, args, retType) - - decodeHook := "_decode_hook" - if !query.Ret.IsStruct() && !query.Ret.Typ.DoOverride() { - decodeHook = "operator.itemgetter(0)" - } else if !query.Ret.IsStruct() && query.Ret.Typ.DoOverride() { - body.WriteIndentedLine(indentLevel+1, fmt.Sprintf("def _decode_hook(row: %s) -> %s:", Sqlite3Result, retType.Type)) - body.WriteIndentedLine(indentLevel+2, fmt.Sprintf("return %s(row[0])", retType.Type)) - } else { - body.WriteIndentedLine(indentLevel+1, fmt.Sprintf("def _decode_hook(row: %s) -> %s:", Sqlite3Result, retType.Type)) - body.WriteIndentedString(indentLevel+2, fmt.Sprintf("return %s(", retType.Type)) - i := 0 - for _, col := range query.Ret.Table.Columns { - if i != 0 { - body.WriteString(", ") - } - if len(col.EmbedFields) != 0 { - var inner []string - body.WriteString(fmt.Sprintf("%s=%s(", col.Name, col.Type.Type)) - for _, embedCol := range col.EmbedFields { - if embedCol.Type.DoOverride() { - inner = append(inner, fmt.Sprintf("%s=%s(row[%s])", embedCol.Name, embedCol.Type.Type, strconv.Itoa(i))) - } else { - inner = append(inner, fmt.Sprintf("%s=row[%s]", embedCol.Name, strconv.Itoa(i))) - } - i++ - } - body.WriteString(strings.Join(inner, ", ") + ")") - } else { - if col.Type.DoOverride() { - body.WriteString(fmt.Sprintf("%s=%s(row[%s])", col.Name, col.Type.Type, strconv.Itoa(i))) - } else { - body.WriteString(fmt.Sprintf("%s=row[%s]", col.Name, strconv.Itoa(i))) - } - i++ - } - } - body.WriteLine(")") - } - body.WriteIndentedString(indentLevel+1, fmt.Sprintf("return QueryResults[%s](%s, %s, %s", retType.Type, conn, query.ConstantName, decodeHook)) - params := "" - for i, arg := range query.Args { - if !arg.IsEmpty() { - if i == len(query.Args)-1 { - params += fmt.Sprintf(" %s", arg.Name) - } else { - params += fmt.Sprintf(" %s,", arg.Name) - } - } - } - if params != "" { - body.WriteString("," + params) - } - body.WriteLine(")") - } - return nil -} - -func SQLite3AcceptedDriverCMDs() []string { - return []string{ - metadata.CmdExec, - metadata.CmdExecResult, - metadata.CmdExecLastId, - metadata.CmdExecRows, - metadata.CmdOne, - metadata.CmdMany, - } -} - -func sqlite3WriteParams(query *core.Query, body *builders.IndentStringBuilder) { - if len(query.Args) == 0 { - return - } - params := "(" - for i, arg := range query.Args { - if !arg.IsEmpty() { - argName := arg.Name - if arg.Typ.DoOverride() { - argName = fmt.Sprintf("%s(%s)", arg.Typ.DefaultType, argName) - } - if i == len(query.Args)-1 && i != 0 { - params += fmt.Sprintf("%s", argName) - } else { - params += fmt.Sprintf("%s, ", argName) - } - } - } - body.WriteString(", " + params + ")") -} diff --git a/internal/codegen/init.go b/internal/codegen/init.go deleted file mode 100644 index 937106d..0000000 --- a/internal/codegen/init.go +++ /dev/null @@ -1,16 +0,0 @@ -package codegen - -import ( - "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" - "github.com/sqlc-dev/plugin-sdk-go/plugin" -) - -func (dr *Driver) BuildInitFile() *plugin.File { - body := builders.NewIndentStringBuilder(dr.conf.IndentChar, dr.conf.CharsPerIndentLevel) - body.WriteSqlcHeader() - body.WriteInitFileModuleDocstring() - return &plugin.File{ - Name: "__init__.py", - Contents: []byte(body.String()), - } -} diff --git a/internal/codegen/queries.go b/internal/codegen/queries.go deleted file mode 100644 index 6471ce5..0000000 --- a/internal/codegen/queries.go +++ /dev/null @@ -1,207 +0,0 @@ -package codegen - -import ( - "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" - "github.com/rayakame/sqlc-gen-better-python/internal/codegen/drivers" - "github.com/rayakame/sqlc-gen-better-python/internal/core" - "github.com/sqlc-dev/plugin-sdk-go/metadata" - "github.com/sqlc-dev/plugin-sdk-go/plugin" - "sort" - "strings" -) - -func (dr *Driver) prepareFunctionHeader(query *core.Query, body *builders.IndentStringBuilder) ([]core.FunctionArg, string, []string) { - pyTableNames := make([]string, 0) - args := make([]core.FunctionArg, 0) - for _, arg := range query.Args { - if !arg.IsEmpty() { - argType := arg.Typ.Type - if arg.EmitStruct() && arg.IsStruct() { - BuildPyTabel(dr.conf.ModelType, arg.Table, body) - body.NNewLine(2) - pyTableNames = append(pyTableNames, arg.Table.Name) - if query.Cmd == metadata.CmdCopyFrom { - argType = fmt.Sprintf("collections.abc.Sequence[%s]", argType) - } - args = append(args, core.FunctionArg{ - Name: arg.Name, - Type: argType, - FunctionFormat: fmt.Sprintf("%s: %s", arg.Name, argType), - }) - } else { - if arg.Typ.IsList { - argType = fmt.Sprintf("collections.abc.Sequence[%s]", argType) - } - if arg.Typ.IsNullable { - argType = fmt.Sprintf("%s | None", argType) - } - args = append(args, core.FunctionArg{ - Name: arg.Name, - Type: argType, - FunctionFormat: fmt.Sprintf("%s: %s", arg.Name, argType), - }) - } - } - } - retType := "None" - if query.Ret.EmitStruct() && query.Ret.IsStruct() { - BuildPyTabel(dr.conf.ModelType, query.Ret.Table, body) - body.NNewLine(2) - retType = query.Ret.Table.Name - pyTableNames = append(pyTableNames, query.Ret.Table.Name) - } else if !query.Ret.IsEmpty() { - if query.Ret.IsStruct() { - retType = fmt.Sprintf("models.%s", query.Ret.Table.Name) - } else { - retType = query.Ret.Typ.Type - } - } - if query.Cmd == metadata.CmdExecLastId { - retType = "int | None" - } - if query.Cmd == metadata.CmdExecRows || query.Cmd == metadata.CmdCopyFrom { - retType = "int" - } - return args, retType, pyTableNames -} - -func (dr *Driver) BuildPyQueriesFiles(imp *core.Importer, queries []core.Query) ([]*plugin.File, error) { - files := make([]*plugin.File, 0) - fileQueries := make(map[string][]core.Query) - for _, query := range queries { - if err := dr.supportedCMD(query.Cmd); err != nil { - return nil, err - } - if val, found := fileQueries[query.SourceName]; found { - fileQueries[query.SourceName] = append(val, query) - } else { - fileQueries[query.SourceName] = []core.Query{query} - } - } - - for sourceName, queries := range fileQueries { - data, err := dr.buildPyQueriesFile(imp, queries, sourceName) - if err != nil { - return nil, err - } - files = append(files, &plugin.File{ - Name: core.SQLToPyFileName(sourceName), - Contents: data, - }) - } - - return files, nil -} - -func (dr *Driver) buildQueryHeader(query *core.Query, body *builders.IndentStringBuilder) { - body.WriteLine(fmt.Sprintf(`%s: typing.Final[str] = """-- name: %s %s`, query.ConstantName, query.MethodName, query.Cmd)) - body.WriteLine(query.SQL) - body.WriteLine(`"""`) -} - -func (dr *Driver) buildClassTemplate(sourceName string, body *builders.IndentStringBuilder) string { - className := core.SnakeToCamel(strings.ReplaceAll(sourceName, ".sql", ""), dr.conf) - body.WriteLine(fmt.Sprintf("class %s:", className)) - body.WriteQueryClassDocstring(sourceName, dr.connType) - body.WriteIndentedLine(1, `__slots__ = ("_conn",)`) - body.NewLine() - body.WriteIndentedLine(1, fmt.Sprintf(`def __init__(self, conn: %s) -> None:`, dr.connType)) - body.WriteQueryClassInitDocstring(2, dr.connType) - body.WriteIndentedLine(2, "self._conn = conn") - body.NewLine() - body.WriteIndentedLine(1, "@property") - body.WriteIndentedLine(1, fmt.Sprintf(`def conn(self) -> %s:`, dr.connType)) - body.WriteQueryClassConnDocstring(dr.connType) - body.WriteIndentedLine(2, `return self._conn`) - body.NewLine() - return className -} - -func (dr *Driver) buildPyQueriesFile(imp *core.Importer, queries []core.Query, sourceName string) ([]byte, error) { - body := builders.NewIndentStringBuilder(imp.C.IndentChar, imp.C.CharsPerIndentLevel) - body.WriteSqlcHeader() - body.WriteQueryFileModuleDocstring(sourceName) - body.WriteImportAnnotations() - - newLines := 2 - if dr.conf.EmitClasses { - newLines = 1 - } - - allNames := make([]string, 0) - funcBody := builders.NewIndentStringBuilder(imp.C.IndentChar, imp.C.CharsPerIndentLevel) - pyTableBody := builders.NewIndentStringBuilder(imp.C.IndentChar, imp.C.CharsPerIndentLevel) - for _, query := range queries { - if !dr.conf.EmitClasses { - allNames = append(allNames, query.FuncName) - } - dr.buildQueryHeader(&query, funcBody) - funcBody.NewLine() - } - if core.IsAnyQueryMany(queries) { - funcBody.NewLine() - allNames = append(allNames, dr.driverBuildQueryResults(funcBody)) - funcBody.NewLine() - } - funcBody.NewLine() - if dr.conf.EmitClasses { - allNames = append(allNames, dr.buildClassTemplate(sourceName, funcBody)) - } - for i, query := range queries { - args, retType, addedPyTableNames := dr.prepareFunctionHeader(&query, pyTableBody) - returnType := core.PyType{ - SqlType: query.Ret.Typ.SqlType, - Type: retType, - } - allNames = append(allNames, addedPyTableNames...) - err := dr.buildPyQueryFunc(&query, funcBody, args, returnType, dr.conf) - if err != nil { - return nil, err - } - if i != len(queries)-1 { - funcBody.NNewLine(newLines) - } - } - body.WriteLine("__all__: collections.abc.Sequence[str] = (") - if len(allNames) > 0 { - sort.Slice(allNames, func(i, j int) bool { return allNames[i] < allNames[j] }) - } - for _, n := range allNames { - body.WriteIndentedLine(1, fmt.Sprintf("\"%s\",", n)) - } - body.WriteLine(")") - body.NewLine() - std, tye, pkg := imp.Imports(sourceName) - tyeHook := dr.driverTypeCheckingHook() - for _, imp := range std { - body.WriteLine(imp) - } - if len(tye) != 0 || len(tyeHook) != 0 { - if len(std) != 0 { - body.NewLine() - } - body.WriteLine("if typing.TYPE_CHECKING:") - for _, imp := range tye { - body.WriteIndentedLine(1, imp) - } - for i, imp := range tyeHook { - if i == 0 && len(tye) != 0 { - body.NewLine() - } - body.WriteIndentedLine(1, imp) - } - } - body.WriteLine("") - for _, imp := range pkg { - body.WriteLine(imp) - } - body.NNewLine(2) - if dr.conf.SqlDriver == core.SQLDriverAioSQLite { - drivers.AioSQLiteBuildTypeConvFunc(queries, body, dr.conf) - } - if dr.conf.SqlDriver == core.SQLDriverSQLite { - drivers.SQLite3BuildTypeConvFunc(queries, body, dr.conf) - } - return []byte(body.String() + pyTableBody.String() + funcBody.String()), nil -} diff --git a/internal/codegen/tables.go b/internal/codegen/tables.go deleted file mode 100644 index 7cd8312..0000000 --- a/internal/codegen/tables.go +++ /dev/null @@ -1,88 +0,0 @@ -package codegen - -import ( - "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/codegen/builders" - "github.com/rayakame/sqlc-gen-better-python/internal/core" - "github.com/sqlc-dev/plugin-sdk-go/plugin" -) - -func (dr *Driver) BuildPyTablesFile(imp *core.Importer, tables []core.Table) (*plugin.File, error) { - fileName, fileContent, err := dr.buildPyTables(imp, tables) - if err != nil { - return nil, err - } - return &plugin.File{ - Name: core.SQLToPyFileName(fileName), - Contents: fileContent, - }, nil -} - -func BuildPyTabel(modelType string, table *core.Table, body *builders.IndentStringBuilder) { - if modelType == core.ModelTypeDataclass { - body.WriteLine("@dataclasses.dataclass()") - } else if modelType == core.ModelTypeAttrs { - body.WriteLine("@attrs.define()") - } - inheritance := "" - if modelType == core.ModelTypeMsgspec { - inheritance = "(msgspec.Struct)" - } - body.WriteLine(fmt.Sprintf("class %s%s:", table.Name, inheritance)) - body.WriteModelClassDocstring(table) - for _, col := range table.Columns { - type_ := col.Type.Type - if col.Type.IsList { - type_ = "collections.abc.Sequence[" + type_ + "]" - } - if col.Type.IsNullable { - type_ = type_ + " | None" - } - body.WriteIndentedLine(1, col.Name+": "+type_) - } -} - -func (dr *Driver) buildPyTables(imp *core.Importer, tables []core.Table) (string, []byte, error) { - fileName := "models.sql" - body := builders.NewIndentStringBuilder(imp.C.IndentChar, imp.C.CharsPerIndentLevel) - body.WriteSqlcHeader() - body.WriteModelFileModuleDocstring() - body.WriteImportAnnotations() - body.WriteLine("__all__: collections.abc.Sequence[str] = (") - for _, table := range tables { - body.WriteIndentedLine(1, fmt.Sprintf("\"%s\",", table.Name)) - } - body.WriteLine(")") - body.WriteString("\n") - std, tye, pkg := imp.Imports(fileName) - for _, imp := range std { - body.WriteLine(imp) - } - if len(tye) != 0 { - if len(std) != 0 { - body.NewLine() - } - if !dr.conf.OmitTypecheckingBlock { - body.WriteLine("if typing.TYPE_CHECKING:") - for _, imp := range tye { - body.WriteIndentedLine(1, imp) - } - } else { - for _, imp := range tye { - body.WriteLine(imp) - } - } - } - for i, imp := range pkg { - if i == 0 { - body.NewLine() - } - body.WriteLine(imp) - } - for _, table := range tables { - body.WriteString("\n") - body.WriteString("\n") - BuildPyTabel(imp.C.ModelType, &table, body) - } - return fileName, []byte(body.String()), nil -} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..f7752f3 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,128 @@ +package config + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/rayakame/sqlc-gen-better-python/internal/utils" + "github.com/sqlc-dev/plugin-sdk-go/plugin" +) + +type Config struct { + Package string `json:"package" yaml:"package"` + SqlDriver SQLDriver `json:"sql_driver" yaml:"sql_driver"` + ModelType ModelType `json:"model_type" yaml:"model_type"` + Initialisms *[]string `json:"initialisms,omitempty" yaml:"initialisms,omitempty"` + EmitExactTableNames bool `json:"emit_exact_table_names" yaml:"emit_exact_table_names"` + EmitClasses bool `json:"emit_classes" yaml:"emit_classes"` + InflectionExcludeTableNames []string `json:"inflection_exclude_table_names,omitempty" yaml:"inflection_exclude_table_names,omitempty"` + OmitUnusedModels bool `json:"omit_unused_models" yaml:"omit_unused_models"` + OmitTypecheckingBlock bool `json:"omit_typechecking_block" yaml:"omit_typechecking_block"` + QueryParameterLimit *int32 `json:"query_parameter_limit,omitempty" yaml:"query_parameter_limit"` + OmitKwargsLimit *int32 `json:"omit_kwargs_limit,omitempty" yaml:"omit_kwargs_limit"` + EmitInitFile *bool `json:"emit_init_file" yaml:"emit_init_file"` + EmitDocstrings DocstringConvention `json:"docstrings" yaml:"docstrings"` + OmitDocstringsSQL bool `json:"docstrings_emit_sql" yaml:"docstrings_emit_sql"` + Speedups bool `json:"speedups" yaml:"speedups"` + // Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` + + Debug bool `json:"debug" yaml:"debug"` + + IndentChar string `json:"indent_char" yaml:"indent_char"` + CharsPerIndentLevel int `json:"chars_per_indent_level" yaml:"chars_per_indent_level"` + + InitialismsMap map[string]struct{} `json:"-" yaml:"-"` + Async bool `json:"-" yaml:"-"` +} + +func NewConfig(req *plugin.GenerateRequest) (*Config, error) { + config, err := parseConfig(req) + if err != nil { + return nil, err + } + err = validateConf(config, req.Settings.Engine) + if err != nil { + return nil, err + } + + return config, nil +} + +func parseConfig(req *plugin.GenerateRequest) (*Config, error) { + var config Config + if len(req.PluginOptions) == 0 { + return &config, nil + } + if err := json.Unmarshal(req.PluginOptions, &config); err != nil { + return nil, fmt.Errorf("unmarshalling plugin options: %w", err) + } + config.Async = config.SqlDriver.Async() + + /* + for i := range config.Overrides { + if err := config.Overrides[i].parse(req); err != nil { + return nil, err + } + }*/ + + if config.ModelType == "" { + config.ModelType = ModelTypeDataclass + } + if config.QueryParameterLimit == nil { + config.QueryParameterLimit = utils.ToPtr(int32(1)) + } + if config.OmitKwargsLimit == nil { + config.OmitKwargsLimit = new(int32) + } + if config.Initialisms == nil { + config.Initialisms = utils.ToPtr([]string{"id"}) + } + if config.IndentChar == "" { + config.IndentChar = " " + } + if config.CharsPerIndentLevel <= 0 { + config.CharsPerIndentLevel = 4 + } + if config.EmitDocstrings == "" { + config.EmitDocstrings = DocstringConventionNone + } + + config.InitialismsMap = map[string]struct{}{} + for _, initial := range *config.Initialisms { + config.InitialismsMap[initial] = struct{}{} + } + + return &config, nil +} + +func validateConf(conf *Config, engine string) error { + if *conf.QueryParameterLimit < 0 { + return errors.New("invalid options: query parameter limit must not be negative") + } + if *conf.OmitKwargsLimit < 0 { + return errors.New("invalid options: omit kwarg limit must not be negative") + } + + if conf.EmitInitFile == nil { + return errors.New("invalid options: you need to specify emit_init_file") + } + + if conf.Package == "" { + return errors.New("invalid options: package must not be empty") + } + + if err := conf.SqlDriver.Validate(engine); err != nil { + return fmt.Errorf("invalid options: unknown model type: %w", err) + } + + if !conf.ModelType.Valid() { + return fmt.Errorf("invalid options: unknown model type: %s", conf.ModelType) + } + + if !conf.EmitDocstrings.Valid() { + return fmt.Errorf("invalid options: unknown docstring convention: %s", conf.EmitDocstrings) + } + + return nil +} diff --git a/internal/config/constants.go b/internal/config/constants.go new file mode 100644 index 0000000..f762f75 --- /dev/null +++ b/internal/config/constants.go @@ -0,0 +1,85 @@ +package config + +import "fmt" + +const PluginVersion = "v0.4.5" + +type ( + SQLDriver string + DocstringConvention string + ModelType string +) + +func (dr SQLDriver) String() string { + return string(dr) +} + +const ( + SQLDriverSQLite SQLDriver = "sqlite3" + SQLDriverAioSQLite SQLDriver = "aiosqlite" + SQLDriverAsyncpg SQLDriver = "asyncpg" +) + +const ( + ModelTypeDataclass ModelType = "dataclass" + ModelTypeAttrs ModelType = "attrs" + ModelTypeMsgspec ModelType = "msgspec" +) + +var asyncDrivers = map[SQLDriver]bool{ + SQLDriverSQLite: false, + SQLDriverAioSQLite: true, + SQLDriverAsyncpg: true, +} + +var driversEngine = map[SQLDriver]string{ + SQLDriverSQLite: "sqlite", + SQLDriverAioSQLite: "sqlite", + SQLDriverAsyncpg: "postgresql", +} + +const ( + DocstringConventionNone DocstringConvention = "none" + DocstringConventionGoogle DocstringConvention = "google" + DocstringConventionNumpy DocstringConvention = "numpy" + DocstringConventionPEP257 DocstringConvention = "pep257" +) + +func (dr SQLDriver) Async() bool { + val, found := asyncDrivers[dr] + if !found { + return false + } + + return val +} + +func (dr SQLDriver) Validate(engine string) error { + val, found := driversEngine[dr] + if !found { + return fmt.Errorf("unknown SQL driver: %s", dr) + } + if val != engine { + return fmt.Errorf("SQL driver %s does not support %s", dr, engine) + } + + return nil +} + +func (modelType ModelType) Valid() bool { + switch modelType { + case ModelTypeDataclass, ModelTypeMsgspec, ModelTypeAttrs: + return true + default: + return false + } +} + +func (ds DocstringConvention) Valid() bool { + switch ds { + case DocstringConventionNone, DocstringConventionNumpy, DocstringConventionGoogle, DocstringConventionPEP257: + return true + default: + return false + } +} diff --git a/internal/core/Column.go b/internal/core/Column.go deleted file mode 100644 index bda3f79..0000000 --- a/internal/core/Column.go +++ /dev/null @@ -1,13 +0,0 @@ -package core - -import "github.com/sqlc-dev/plugin-sdk-go/plugin" - -type Column struct { - Name string // CamelCased name for Go - DBName string // Name as used in the DB - Type PyType - Comment string - Column *plugin.Column - // EmbedFields contains the embedded fields that require scanning. - EmbedFields []Column -} diff --git a/internal/core/config.go b/internal/core/config.go deleted file mode 100644 index 3481888..0000000 --- a/internal/core/config.go +++ /dev/null @@ -1,126 +0,0 @@ -package core - -import ( - "encoding/json" - "fmt" - "github.com/sqlc-dev/plugin-sdk-go/plugin" -) - -const PluginVersion = "v0.4.5" - -type Config struct { - Package string `json:"package" yaml:"package"` - SqlDriver SQLDriverType `json:"sql_driver" yaml:"sql_driver"` - ModelType string `json:"model_type" yaml:"model_type"` - Initialisms *[]string `json:"initialisms,omitempty" yaml:"initialisms,omitempty"` - EmitExactTableNames bool `json:"emit_exact_table_names" yaml:"emit_exact_table_names"` - EmitClasses bool `json:"emit_classes" yaml:"emit_classes"` - InflectionExcludeTableNames []string `json:"inflection_exclude_table_names,omitempty" yaml:"inflection_exclude_table_names,omitempty"` - OmitUnusedModels bool `json:"omit_unused_models" yaml:"omit_unused_models"` - OmitTypecheckingBlock bool `json:"omit_typechecking_block" yaml:"omit_typechecking_block"` - QueryParameterLimit *int32 `json:"query_parameter_limit,omitempty" yaml:"query_parameter_limit"` - OmitKwargsLimit *int32 `json:"omit_kwargs_limit,omitempty" yaml:"omit_kwargs_limit"` - EmitInitFile *bool `json:"emit_init_file" yaml:"emit_init_file"` - EmitDocstrings *string `json:"docstrings" yaml:"docstrings"` - EmitDocstringsSQL *bool `json:"docstrings_emit_sql" yaml:"docstrings_emit_sql"` - Speedups bool `json:"speedups" yaml:"speedups"` - Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` - - Debug bool `json:"debug" yaml:"debug"` - - IndentChar string `json:"indent_char" yaml:"indent_char"` - CharsPerIndentLevel int `json:"chars_per_indent_level" yaml:"chars_per_indent_level"` - - InitialismsMap map[string]struct{} `json:"-" yaml:"-"` - Async bool -} - -func ParseConfig(req *plugin.GenerateRequest) (*Config, error) { - var config Config - if len(req.PluginOptions) == 0 { - return &config, nil - } - if err := json.Unmarshal(req.PluginOptions, &config); err != nil { - return nil, fmt.Errorf("unmarshalling plugin options: %w", err) - } - if config.SqlDriver == "" { - return nil, fmt.Errorf("invalid options: driver must not be empty") - } - val, err := isDriverAsync(config.SqlDriver) - if err != nil { - return nil, fmt.Errorf("invalid options: %s", err) - } - config.Async = val - - for i := range config.Overrides { - if err := config.Overrides[i].parse(req); err != nil { - return nil, err - } - } - - if config.ModelType == "" { - config.ModelType = ModelTypeDataclass - } - if config.QueryParameterLimit == nil { - config.QueryParameterLimit = new(int32) - *config.QueryParameterLimit = 1 - } - if config.OmitKwargsLimit == nil { - config.OmitKwargsLimit = new(int32) - *config.OmitKwargsLimit = 0 - } - if config.Initialisms == nil { - config.Initialisms = new([]string) - *config.Initialisms = []string{"id"} - } - if config.IndentChar == "" { - config.IndentChar = " " - } - if config.CharsPerIndentLevel == 0 { - config.CharsPerIndentLevel = 4 - } - if config.EmitDocstrings == nil { - config.EmitDocstrings = new(string) - *config.EmitDocstrings = DocstringConventionNone - } - if config.EmitDocstringsSQL == nil { - config.EmitDocstringsSQL = new(bool) - *config.EmitDocstringsSQL = true - } - - config.InitialismsMap = map[string]struct{}{} - for _, initial := range *config.Initialisms { - config.InitialismsMap[initial] = struct{}{} - } - return &config, nil -} -func ValidateConf(conf *Config, engine string) error { - if *conf.QueryParameterLimit < 0 { - return fmt.Errorf("invalid options: query parameter limit must not be negative") - } - if *conf.OmitKwargsLimit < 0 { - return fmt.Errorf("invalid options: omit kwarg limit must not be negative") - } - - if conf.EmitInitFile == nil { - return fmt.Errorf("invalid options: you need to specify emit_init_file") - } - - if conf.Package == "" { - return fmt.Errorf("invalid options: package must not be empty") - } - - if err := isDriverValid(conf.SqlDriver, engine); err != nil { - return err - } - - if err := isModelTypeValid(conf.ModelType); err != nil { - return fmt.Errorf("invalid options: %s", err) - } - - if err := isDocstringValid(conf.EmitDocstrings); err != nil { - return fmt.Errorf("invalid options: %s", err) - } - - return nil -} diff --git a/internal/core/enums.go b/internal/core/enums.go deleted file mode 100644 index caf2bf8..0000000 --- a/internal/core/enums.go +++ /dev/null @@ -1,89 +0,0 @@ -package core - -import "fmt" - -type SQLDriverType string - -func (dr *SQLDriverType) String() string { - return string(*dr) -} - -const ( - SQLDriverSQLite SQLDriverType = "sqlite3" - SQLDriverAioSQLite SQLDriverType = "aiosqlite" - SQLDriverAsyncpg SQLDriverType = "asyncpg" -) - -const ( - ModelTypeDataclass = "dataclass" - ModelTypeAttrs = "attrs" - ModelTypeMsgspec = "msgspec" -) - -var asyncDrivers = map[SQLDriverType]bool{ - SQLDriverSQLite: false, - SQLDriverAioSQLite: true, - SQLDriverAsyncpg: true, -} - -var driversEngine = map[SQLDriverType]string{ - SQLDriverSQLite: "sqlite", - SQLDriverAioSQLite: "sqlite", - SQLDriverAsyncpg: "postgresql", -} - -var validModelTypes = map[string]struct{}{ - string(ModelTypeDataclass): {}, - string(ModelTypeAttrs): {}, - string(ModelTypeMsgspec): {}, -} - -const ( - DocstringConventionNone = "none" - DocstringConventionGoogle = "google" - DocstringConventionNumpy = "numpy" - DocstringConventionPEP257 = "pep257" -) - -var validDocstringConventions = map[string]struct{}{ - DocstringConventionNone: {}, - DocstringConventionGoogle: {}, - DocstringConventionNumpy: {}, - DocstringConventionPEP257: {}, -} - -func isDriverAsync(sqlDriver SQLDriverType) (bool, error) { - val, found := asyncDrivers[sqlDriver] - if !found { - return false, fmt.Errorf("unknown SQL driver: %s", sqlDriver) - } - return val, nil -} - -func isDriverValid(sqlDriver SQLDriverType, engine string) error { - val, found := driversEngine[sqlDriver] - if !found { - return fmt.Errorf("unknown SQL driver: %s", sqlDriver) - } - if val != engine { - return fmt.Errorf("SQL driver %s does not support %s", sqlDriver, engine) - } - return nil -} - -func isModelTypeValid(modelType string) error { - if _, found := validModelTypes[modelType]; !found { - return fmt.Errorf("unknown model type: %s", modelType) - } - return nil -} - -func isDocstringValid(ds *string) error { - if ds == nil { - return nil - } - if _, found := validDocstringConventions[*ds]; !found { - return fmt.Errorf("unknown docstring convention: %s", ds) - } - return nil -} diff --git a/internal/core/function.go b/internal/core/function.go deleted file mode 100644 index 691745e..0000000 --- a/internal/core/function.go +++ /dev/null @@ -1,7 +0,0 @@ -package core - -type FunctionArg struct { - Name string - Type string - FunctionFormat string -} diff --git a/internal/core/importer.go b/internal/core/importer.go deleted file mode 100644 index 3e2ec67..0000000 --- a/internal/core/importer.go +++ /dev/null @@ -1,422 +0,0 @@ -package core - -import ( - "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/typeConversion" - "github.com/sqlc-dev/plugin-sdk-go/metadata" - "sort" - "strings" -) - -type importSpec struct { - Module string - Name string - Alias string - TypeChecking bool -} - -func (i importSpec) String() string { - if i.Alias != "" { - if i.Name == "" { - return fmt.Sprintf("import %s as %s", i.Module, i.Alias) - } - return fmt.Sprintf("from %s import %s as %s", i.Module, i.Name, i.Alias) - } - if i.Name == "" { - return "import " + i.Module - } - return fmt.Sprintf("from %s import %s", i.Module, i.Name) -} - -type Importer struct { - Tables []Table - Queries []Query - Enums []Enum - C *Config -} - -func (i *Importer) Imports(fileName string) ([]string, []string, []string) { - if fileName == "models.sql" { - return i.modelImports() - } - return i.queryImports(fileName) -} - -func TableUses(name string, s Table) (bool, PyType) { - for _, col := range s.Columns { - if col.Type.Type == name { - return true, col.Type - } - } - return false, PyType{} - -} - -func (i *Importer) getModelImportSpec() (string, importSpec, error) { - switch i.C.ModelType { - case ModelTypeAttrs: - return "attrs", importSpec{Module: "attrs"}, nil - case ModelTypeDataclass: - return "dataclasses", importSpec{Module: "dataclasses"}, nil - case ModelTypeMsgspec: - return "msgspec", importSpec{Module: "msgspec"}, nil - default: - return "", importSpec{}, fmt.Errorf("unknown model type: %s", i.C.ModelType) - } -} - -func (i *Importer) splitTypeChecking(pks map[string]importSpec) (map[string]importSpec, map[string]importSpec) { - normalImports := make(map[string]importSpec) - typeChecking := make(map[string]importSpec) - for name, val := range pks { - if val.TypeChecking { - typeChecking[name] = val - } else { - normalImports[name] = val - } - } - return normalImports, typeChecking -} - -func (i *Importer) modelImportSpecs() (map[string]importSpec, map[string]importSpec, map[string]importSpec) { - modelUses := func(name string) (bool, bool) { - for _, table := range i.Tables { - if val, _ := TableUses(name, table); val { - return true, true - } - } - return false, false - } - - std := stdImports(modelUses) - for _, override := range i.C.Overrides { - if val1, val2 := modelUses(override.PyTypeName); val1 { - std[override.PyTypeName] = importSpec{Module: override.PyImportPath, Name: override.PyPackageName, TypeChecking: val2} - } - } - std, typeChecking := i.splitTypeChecking(std) - if len(typeChecking) != 0 { - std["typing"] = importSpec{Module: "typing"} - } - modelName, modelImport, err := i.getModelImportSpec() - if err == nil { - std[modelName] = modelImport - } - if len(i.Enums) > 0 { - std["enum"] = importSpec{Module: fmt.Sprintf("from %s import enums", i.C.Package)} - } - - pkg := make(map[string]importSpec) - - return std, typeChecking, pkg -} - -func (i *Importer) queryValueUses(name string, qv QueryValue) (bool, bool) { - if !qv.IsEmpty() { - if qv.IsStruct() && qv.EmitStruct() { - if val, pyType := TableUses(name, *qv.Table); val { - if i.C.SqlDriver == SQLDriverAsyncpg { - if pyType.DoConversion(typeConversion.AsyncpgDoTypeConversion) { - return true, false - } else { - return true, true - } - } else if i.C.SqlDriver == SQLDriverAioSQLite || i.C.SqlDriver == SQLDriverSQLite { - if pyType.DoConversion(typeConversion.SqliteDoTypeConversion) { - return true, false - } else { - return true, true - } - } - return true, false - } - } else if qv.IsStruct() && (i.C.SqlDriver == SQLDriverAioSQLite || i.C.SqlDriver == SQLDriverSQLite) { - if val, pyType := TableUses(name, *qv.Table); val { - if pyType.DoConversion(typeConversion.SqliteDoTypeConversion) { - return true, false - } - } - } else { - if qv.Typ.Type == name { - if i.C.SqlDriver == SQLDriverAsyncpg { - if qv.Typ.DoConversion(typeConversion.AsyncpgDoTypeConversion) { - return true, false - } else { - return true, true - } - } else if i.C.SqlDriver == SQLDriverAioSQLite || i.C.SqlDriver == SQLDriverSQLite { - if qv.Typ.DoConversion(typeConversion.SqliteDoTypeConversion) { - return true, false - } else { - return true, true - } - } - return true, false - } - } - } - return false, false -} - -func (i *Importer) queryImportSpecs(_ string) (map[string]importSpec, map[string]importSpec, map[string]importSpec, map[string]importSpec) { - addCiso := false - queryUses := func(name string) (bool, bool) { - var uses *bool = nil - var typeChecking *bool = nil - - helper := func(val1, val2 bool) { - if uses == nil || typeChecking == nil { - uses = new(bool) - typeChecking = new(bool) - *uses = val1 - *typeChecking = val2 - } else if *typeChecking == true { - *uses = val1 - *typeChecking = val2 - } - } - for _, q := range i.Queries { - //if q.SourceName != fileName { TODO q.SourceName is the name of the sql file - // continue - //} - if val1, val2 := i.queryValueUses(name, q.Ret); val1 { - if q.Cmd == metadata.CmdMany { - helper(val1, false) - } - // if we have speedups enabled then we don't need datetime in the std imports - // we use ciso8601 for the converting and need datetime only in typechecking - if val2 == false && (i.C.SqlDriver == SQLDriverAioSQLite || i.C.SqlDriver == SQLDriverSQLite) && i.C.Speedups && (name == "datetime.datetime" || name == "datetime.date") { - helper(val1, true) - addCiso = true - } else { - helper(val1, val2) - } - } - for _, arg := range q.Args { - if val1, val2 := i.queryValueUses(name, arg); val1 { - // if we have speedups enabled then we don't need datetime in the std imports - // we use ciso8601 for the converting and need datetime only in typechecking - if val2 == false && (i.C.SqlDriver == SQLDriverAioSQLite || i.C.SqlDriver == SQLDriverSQLite) && i.C.Speedups && (name == "datetime.datetime" || name == "datetime.date") { - helper(val1, true) - addCiso = true - } else { - helper(val1, val2) - } - } - } - } - if uses == nil || typeChecking == nil { - return false, false - } - return *uses, *typeChecking - } - querySimpleReturn := func(conv typeConversion.TypeDoTypeConversion) bool { - for _, q := range i.Queries { - if !q.Ret.IsStruct() && !conv(q.Ret.Typ.SqlType) { - return true - } - } - return false - } - - std := stdImports(queryUses) - for _, override := range i.C.Overrides { - if val1, val2 := queryUses(override.PyTypeName); val1 { - std[override.PyTypeName] = importSpec{Module: override.PyImportPath, Name: override.PyPackageName, TypeChecking: val2} - } - } - std, typeChecking := i.splitTypeChecking(std) - if i.C.SqlDriver == SQLDriverAsyncpg { - typeChecking[string(SQLDriverAsyncpg)] = importSpec{Module: string(SQLDriverAsyncpg)} - - if IsAnyQueryMany(i.Queries) { - typeChecking[string(SQLDriverAsyncpg)+".cursor"] = importSpec{Module: string(SQLDriverAsyncpg) + ".cursor"} - if querySimpleReturn(typeConversion.AsyncpgDoTypeConversion) { - std["operator"] = importSpec{Module: "operator"} - } - } - } else if i.C.SqlDriver == SQLDriverAioSQLite { - // if the std mapping has exactly 2 members, these two are collections and typing, - // but if they are more than 2, we need to add type conversion and for that we - // need the aiosqlite in the normal import block, not in the type checking block - if len(std) > 2 { - std[string(SQLDriverAioSQLite)] = importSpec{Module: string(SQLDriverAioSQLite)} - } else { - typeChecking[string(SQLDriverAioSQLite)] = importSpec{Module: string(SQLDriverAioSQLite)} - } - if IsAnyQueryMany(i.Queries) { - typeChecking[string(SQLDriverSQLite)] = importSpec{Module: string(SQLDriverSQLite)} - if querySimpleReturn(typeConversion.SqliteDoTypeConversion) { - std["operator"] = importSpec{Module: "operator"} - } - } - } else if i.C.SqlDriver == SQLDriverSQLite { - // if the std mapping has exactly 2 members, these two are collections and typing, - // but if they are more than 2, we need to add type conversion and for that we - // need the aiosqlite in the normal import block, not in the type checking block - if len(std) > 2 { - std[string(SQLDriverSQLite)] = importSpec{Module: string(SQLDriverSQLite)} - } else { - typeChecking[string(SQLDriverSQLite)] = importSpec{Module: string(SQLDriverSQLite)} - } - if IsAnyQueryMany(i.Queries) { - if querySimpleReturn(typeConversion.SqliteDoTypeConversion) { - std["operator"] = importSpec{Module: "operator"} - } - } - } - if addCiso { - std["ciso8601"] = importSpec{Module: "ciso8601"} - } - - pkg := make(map[string]importSpec) - loc := make(map[string]importSpec) - - queryValueModelImports := func(qv QueryValue) { - if qv.IsStruct() && qv.EmitStruct() { - modelName, modelImport, err := i.getModelImportSpec() - if err == nil { - std[modelName] = modelImport - } - } - } - - for _, q := range i.Queries { - //if q.SourceName != fileName { TODO - // continue - //} - queryValueModelImports(q.Ret) - if q.Cmd == metadata.CmdCopyFrom { - modelName, modelImport, err := i.getModelImportSpec() - if err == nil { - std[modelName] = modelImport - } - } - } - - loc["models"] = importSpec{Module: i.C.Package, Name: "models"} - - return std, typeChecking, pkg, loc -} - -func (i *Importer) queryImports(fileName string) ([]string, []string, []string) { - std, typeCheck, pkg, loc := i.queryImportSpecs(fileName) - - importLines := make([]string, 0) - typeLines := make([]string, 0) - packageLines := make([]string, 0) - if len(std) != 0 { - importLines = append(importLines, buildImportBlock(std)...) - } - if len(typeCheck) != 0 { - typeLines = append(typeLines, buildImportBlock(typeCheck)...) - } - if IsAnyQueryMany(i.Queries) { - if len(typeCheck) != 0 { - typeLines[len(typeLines)-1] = typeLines[len(typeLines)-1] + "\n" - } - queryResultsArgsType := "QueryResultsArgsType: typing.TypeAlias = int | float | str | memoryview" - if IsInMultipleMaps("decimal", std, typeCheck) { - queryResultsArgsType += " | decimal.Decimal" - } - if IsInMultipleMaps("uuid", std, typeCheck) { - queryResultsArgsType += " | uuid.UUID" - } - if IsInMultipleMaps("datetime", std, typeCheck) { - queryResultsArgsType += " | datetime.date | datetime.time | datetime.datetime | datetime.timedelta" - } - queryResultsArgsType += " | None" - typeLines = append(typeLines, queryResultsArgsType) - } - - if len(pkg) != 0 { - packageLines = append(packageLines, buildImportBlock(pkg)...) - } - if len(loc) != 0 { - if len(packageLines) != 0 { - packageLines = append(packageLines, "") - } - packageLines = append(packageLines, buildImportBlock(loc)...) - } - return importLines, typeLines, packageLines -} - -func (i *Importer) modelImports() ([]string, []string, []string) { - std, typeCheck, pkg := i.modelImportSpecs() - importLines := make([]string, 0) - typeLines := make([]string, 0) - packageLines := make([]string, 0) - if len(std) != 0 { - importLines = append(importLines, buildImportBlock(std)...) - } - if len(typeCheck) != 0 { - typeLines = append(typeLines, buildImportBlock(typeCheck)...) - } - if len(pkg) != 0 { - packageLines = append(packageLines, buildImportBlock(pkg)...) - } - return importLines, typeLines, packageLines -} - -func buildImportBlock(pkgs map[string]importSpec) []string { - pkgImports := make([]importSpec, 0) - fromImports := make(map[string][]string) - for _, is := range pkgs { - if is.Name == "" || is.Alias != "" { - pkgImports = append(pkgImports, is) - } else { - names, ok := fromImports[is.Module] - if !ok { - names = make([]string, 0, 1) - } - names = append(names, is.Name) - fromImports[is.Module] = names - } - } - - importStrings := make([]string, 0, len(pkgImports)+len(fromImports)) - for _, is := range pkgImports { - importStrings = append(importStrings, is.String()) - } - for modName, names := range fromImports { - sort.Strings(names) - nameString := strings.Join(names, ", ") - importStrings = append(importStrings, fmt.Sprintf("from %s import %s", modName, nameString)) - } - sort.Strings(importStrings) - return importStrings -} - -// typeCheckingOverwriteProtection function that takes in importSpec map and adds/replaced imports. -// Important here is that importSpec's with TypeChecking set to false have higher priority then -// type checking imports. -func typeCheckingOverwriteProtection(std map[string]importSpec, name string, newImport importSpec) { - if val, found := std[name]; found { - if val.TypeChecking == true { - std[name] = newImport - } - } else { - std[name] = newImport - } -} - -func stdImports(uses func(name string) (bool, bool)) map[string]importSpec { - std := make(map[string]importSpec) - std["collections"] = importSpec{Module: "collections.abc", TypeChecking: true} - std["typing"] = importSpec{Module: "typing", TypeChecking: false} - add := func(name, module string) { - if use, typeChecking := uses(name); use { - typeCheckingOverwriteProtection(std, module, importSpec{Module: module, TypeChecking: typeChecking}) - } - } - - add("decimal.Decimal", "decimal") - - add("datetime.date", "datetime") - add("datetime.time", "datetime") - add("datetime.datetime", "datetime") - add("datetime.timedelta", "datetime") - - add("uuid.UUID", "uuid") - return std -} diff --git a/internal/core/models.go b/internal/core/models.go deleted file mode 100644 index 22bbbcf..0000000 --- a/internal/core/models.go +++ /dev/null @@ -1,131 +0,0 @@ -package core - -import ( - "github.com/rayakame/sqlc-gen-better-python/internal/typeConversion" - "github.com/sqlc-dev/plugin-sdk-go/metadata" - "github.com/sqlc-dev/plugin-sdk-go/plugin" - "strings" -) - -type Table struct { - Table *plugin.Identifier - Name string - Columns []Column - Comment string -} - -type PyType struct { - SqlType string - Type string - DefaultType string - IsList bool - IsNullable bool - IsEnum bool - IsOverride bool - Override *Override -} - -func (p *PyType) DoConversion(conversion typeConversion.TypeDoTypeConversion) bool { - if p.DoOverride() { - return true - } - return conversion(p.SqlType) -} -func (p *PyType) DoOverride() bool { - return p.IsOverride && p.Override != nil -} - -type Constant struct { - Name string - Type string - Value string -} - -type Enum struct { - Name string - Comment string - Constants []Constant -} - -func enumReplacer(r rune) rune { - if strings.ContainsRune("-/:_", r) { - return '_' - } else if (r >= 'a' && r <= 'z') || - (r >= 'A' && r <= 'Z') || - (r >= '0' && r <= '9') { - return r - } else { - return -1 - } -} - -// EnumReplace removes all non ident symbols (all but letters, numbers and -// underscore) and returns valid ident name for provided name. -func EnumReplace(value string) string { - return strings.Map(enumReplacer, value) -} - -type QueryValue struct { - Emit bool - Name string - DBName string // The name of the field in the database. Only set if Struct==nil. - Table *Table - Typ PyType - - // Column is kept so late in the generation process around to differentiate - // between mysql slices and pg arrays - Column *plugin.Column -} - -func (v QueryValue) EmitStruct() bool { - return v.Emit -} - -func (v QueryValue) IsStruct() bool { - return v.Table != nil -} - -func (v QueryValue) IsEmpty() bool { - return v.Typ.Type == "" && v.Name == "" && v.Table == nil -} - -func (v QueryValue) Type() string { - if v.Typ.Type != "" { - return v.Typ.Type - } - if v.Table != nil { - return v.Table.Name - } - panic("no type for QueryValue: " + v.Name) -} - -type Query struct { - Cmd string - Comments []string - MethodName string - FuncName string - FieldName string - ConstantName string - SQL string - SourceName string - Ret QueryValue - Args []QueryValue - - // Used for :copyfrom - Table *plugin.Identifier -} - -func (q Query) HasRetType() bool { - scanned := q.Cmd == metadata.CmdOne || q.Cmd == metadata.CmdMany || - q.Cmd == metadata.CmdBatchMany || q.Cmd == metadata.CmdBatchOne - return scanned && !q.Ret.IsEmpty() -} - -func IsAnyQueryMany(queries []Query) bool { - for _, query := range queries { - if query.Cmd == metadata.CmdMany { - return true - } - } - return false -} diff --git a/internal/core/naming.go b/internal/core/naming.go new file mode 100644 index 0000000..9a8bc95 --- /dev/null +++ b/internal/core/naming.go @@ -0,0 +1 @@ +package core diff --git a/internal/core/overrides.go b/internal/core/overrides.go deleted file mode 100644 index a5cfb5f..0000000 --- a/internal/core/overrides.go +++ /dev/null @@ -1,196 +0,0 @@ -package core - -import ( - "fmt" - "github.com/sqlc-dev/plugin-sdk-go/pattern" - "github.com/sqlc-dev/plugin-sdk-go/plugin" - "go/types" - "strings" -) - -type OverridePyType struct { - Import string `json:"import" yaml:"import"` - Name string `json:"type" yaml:"type"` - Package string `json:"package" yaml:"package"` - Spec string `json:"-"` - BuiltIn bool `json:"-"` -} - -type ParsedOverridePyType struct { - ImportPath string - TypeName string - PackageName string - BasicType bool -} - -func (gt OverridePyType) parse() (*ParsedOverridePyType, error) { - var o ParsedOverridePyType - - if gt.Spec == "" { - o.ImportPath = gt.Import - o.TypeName = gt.Name - o.PackageName = gt.Package - o.BasicType = gt.Import == "" - return &o, nil - } - - input := gt.Spec - lastDot := strings.LastIndex(input, ".") - lastSlash := strings.LastIndex(input, "/") - typename := input - if lastDot == -1 && lastSlash == -1 { - // if the type name has no slash and no dot, validate that the type is a basic Go type - var found bool - for _, typ := range types.Typ { - info := typ.Info() - if info == 0 { - continue - } - if info&types.IsUntyped != 0 { - continue - } - if typename == typ.Name() { - found = true - } - } - if !found { - return nil, fmt.Errorf("Package override `go_type` specifier %q is not a Go basic type e.g. 'string'", input) - } - o.BasicType = true - } else { - // assume the type lives in a Go package - if lastDot == -1 { - return nil, fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", input) - } - typename = input[lastSlash+1:] - // a package name beginning with "go-" will give syntax errors in - // generated code. We should do the right thing and get the actual - // import name, but in lieu of that, stripping the leading "go-" may get - // us what we want. - typename = strings.TrimPrefix(typename, "go-") - typename = strings.TrimSuffix(typename, "-go") - o.ImportPath = input[:lastDot] - } - o.TypeName = typename - isPointer := input[0] == '*' - if isPointer { - o.ImportPath = o.ImportPath[1:] - o.TypeName = "*" + o.TypeName - } - return &o, nil -} - -type Override struct { - // name of the golang type to use, e.g. `github.com/segmentio/ksuid.KSUID` - PyType OverridePyType `json:"py_type" yaml:"py_type"` - - // fully qualified name of the Go type, e.g. `github.com/segmentio/ksuid.KSUID` - DBType string `json:"db_type" yaml:"db_type"` - - // fully qualified name of the column, e.g. `accounts.id` - Column string `json:"column" yaml:"column"` - - ColumnName *pattern.Match `json:"-"` - TableCatalog *pattern.Match `json:"-"` - TableSchema *pattern.Match `json:"-"` - TableRel *pattern.Match `json:"-"` - PyImportPath string `json:"-"` - PyPackageName string `json:"-"` - PyTypeName string `json:"-"` - PyBasicType bool `json:"-"` -} - -func (o *Override) Matches(n *plugin.Identifier, defaultSchema string) bool { - if n == nil { - return false - } - schema := n.Schema - if n.Schema == "" { - schema = defaultSchema - } - if o.TableCatalog != nil && !o.TableCatalog.MatchString(n.Catalog) { - return false - } - if o.TableSchema == nil && schema != "" { - return false - } - if o.TableSchema != nil && !o.TableSchema.MatchString(schema) { - return false - } - if o.TableRel == nil && n.Name != "" { - return false - } - if o.TableRel != nil && !o.TableRel.MatchString(n.Name) { - return false - } - return true -} - -func (o *Override) parse(req *plugin.GenerateRequest) (err error) { - - schema := "public" - if req != nil && req.Catalog != nil { - schema = req.Catalog.DefaultSchema - } - - // validate option combinations - switch { - case o.Column != "" && o.DBType != "": - return fmt.Errorf("Override specifying both `column` (%q) and `db_type` (%q) is not valid.", o.Column, o.DBType) - case o.Column == "" && o.DBType == "": - return fmt.Errorf("Override must specify one of either `column` or `db_type`") - } - - // validate Column - if o.Column != "" { - colParts := strings.Split(o.Column, ".") - switch len(colParts) { - case 2: - if o.ColumnName, err = pattern.MatchCompile(colParts[1]); err != nil { - return err - } - if o.TableRel, err = pattern.MatchCompile(colParts[0]); err != nil { - return err - } - if o.TableSchema, err = pattern.MatchCompile(schema); err != nil { - return err - } - case 3: - if o.ColumnName, err = pattern.MatchCompile(colParts[2]); err != nil { - return err - } - if o.TableRel, err = pattern.MatchCompile(colParts[1]); err != nil { - return err - } - if o.TableSchema, err = pattern.MatchCompile(colParts[0]); err != nil { - return err - } - case 4: - if o.ColumnName, err = pattern.MatchCompile(colParts[3]); err != nil { - return err - } - if o.TableRel, err = pattern.MatchCompile(colParts[2]); err != nil { - return err - } - if o.TableSchema, err = pattern.MatchCompile(colParts[1]); err != nil { - return err - } - if o.TableCatalog, err = pattern.MatchCompile(colParts[0]); err != nil { - return err - } - default: - return fmt.Errorf("Override `column` specifier %q is not the proper format, expected '[catalog.][schema.]tablename.colname'", o.Column) - } - } - - // validate GoType - parsed, err := o.PyType.parse() - if err != nil { - return err - } - o.PyImportPath = parsed.ImportPath - o.PyTypeName = parsed.TypeName - o.PyBasicType = parsed.BasicType - o.PyPackageName = parsed.PackageName - return nil -} diff --git a/internal/core/reserved.go b/internal/core/reserved.go deleted file mode 100644 index 74f1677..0000000 --- a/internal/core/reserved.go +++ /dev/null @@ -1,89 +0,0 @@ -// Package core Auto-generated using python; DO NOT EDIT -// py 3.13.1 (tags/v3.13.1:0671451, Dec 3 2024, 19:06:28) [MSC v.1942 64 bit (AMD64)] -package core - -func Escape(s string) string { - if IsReserved(s) { - return s + "_" - } - return s -} - -func IsReserved(s string) bool { - switch s { - case "False": - return true - case "None": - return true - case "True": - return true - case "and": - return true - case "as": - return true - case "assert": - return true - case "async": - return true - case "await": - return true - case "break": - return true - case "class": - return true - case "continue": - return true - case "def": - return true - case "del": - return true - case "elif": - return true - case "else": - return true - case "except": - return true - case "finally": - return true - case "for": - return true - case "from": - return true - case "global": - return true - case "if": - return true - case "import": - return true - case "in": - return true - case "is": - return true - case "lambda": - return true - case "nonlocal": - return true - case "not": - return true - case "or": - return true - case "pass": - return true - case "raise": - return true - case "return": - return true - case "try": - return true - case "while": - return true - case "with": - return true - case "yield": - return true - case "id": - return true - default: - return false - } -} diff --git a/internal/typeConversion/asyncpg.go b/internal/doTypeConversion/asyncpg.go similarity index 90% rename from internal/typeConversion/asyncpg.go rename to internal/doTypeConversion/asyncpg.go index f7638c3..59b6e17 100644 --- a/internal/typeConversion/asyncpg.go +++ b/internal/doTypeConversion/asyncpg.go @@ -1,4 +1,4 @@ -package typeConversion +package doTypeConversion func AsyncpgDoTypeConversion(name string) bool { _, found := map[string]struct{}{ @@ -8,5 +8,6 @@ func AsyncpgDoTypeConversion(name string) bool { "inet": {}, "cidr": {}, }[name] + return found } diff --git a/internal/typeConversion/common.go b/internal/doTypeConversion/common.go similarity index 64% rename from internal/typeConversion/common.go rename to internal/doTypeConversion/common.go index bcbca99..16b3b46 100644 --- a/internal/typeConversion/common.go +++ b/internal/doTypeConversion/common.go @@ -1,3 +1,3 @@ -package typeConversion +package doTypeConversion type TypeDoTypeConversion func(string) bool diff --git a/internal/typeConversion/sqlite.go b/internal/doTypeConversion/sqlite.go similarity index 94% rename from internal/typeConversion/sqlite.go rename to internal/doTypeConversion/sqlite.go index dfd8221..e595f83 100644 --- a/internal/typeConversion/sqlite.go +++ b/internal/doTypeConversion/sqlite.go @@ -1,4 +1,4 @@ -package typeConversion +package doTypeConversion import "strings" @@ -19,6 +19,7 @@ func SqliteDoTypeConversion(name string) bool { } else if strings.HasPrefix(name, "decimal") { return true } + return false } diff --git a/internal/gen.go b/internal/gen.go deleted file mode 100644 index fda62d2..0000000 --- a/internal/gen.go +++ /dev/null @@ -1,175 +0,0 @@ -package internal - -import ( - "context" - "encoding/json" - "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/codegen" - "github.com/rayakame/sqlc-gen-better-python/internal/core" - "github.com/rayakame/sqlc-gen-better-python/internal/log" - "github.com/rayakame/sqlc-gen-better-python/internal/types" - "github.com/sqlc-dev/plugin-sdk-go/plugin" - "strings" -) - -type PythonGenerator struct { - req *plugin.GenerateRequest - config *core.Config - - typeConversionFunc types.TypeConversionFunc - sqlDriver *codegen.Driver -} - -func NewPythonGenerator(req *plugin.GenerateRequest) (*PythonGenerator, error) { - config, err := core.ParseConfig(req) - if err != nil { - return nil, err - } - if err = core.ValidateConf(config, req.Settings.Engine); err != nil { - return nil, err - } - var typeConversionFunc types.TypeConversionFunc - switch req.Settings.Engine { - case "postgresql": - typeConversionFunc = types.PostgresTypeToPython - case "sqlite": - typeConversionFunc = types.SqliteTypeToPython - default: - return nil, fmt.Errorf("engine %q is not supported", req.Settings.Engine) - } - - sqlDriver, err := codegen.NewDriver(config) - if err != nil { - return nil, err - } - - return &PythonGenerator{ - req: req, - config: config, - typeConversionFunc: typeConversionFunc, - sqlDriver: sqlDriver, - }, nil -} - -func (gen *PythonGenerator) Run() (*plugin.GenerateResponse, error) { - outputFiles := make([]*plugin.File, 0) - log.GlobalLogger.LogByte(gen.req.PluginOptions) - enums := gen.buildEnums() - tables := gen.buildTables() - queries, err := gen.buildQueries(tables) - if err != nil { - return nil, err - } - - jsonData, _ := json.Marshal(gen.req) - log.GlobalLogger.LogByte(jsonData) - jsonData, _ = json.Marshal(gen.config) - log.GlobalLogger.LogByte(jsonData) - jsonData, _ = json.Marshal(enums) - log.GlobalLogger.LogByte(jsonData) - jsonData, _ = json.Marshal(tables) - log.GlobalLogger.LogByte(jsonData) - jsonData, _ = json.Marshal(queries) - log.GlobalLogger.LogByte(jsonData) - - if gen.config.OmitUnusedModels { - enums, tables = filterUnusedStructs(enums, tables, queries) - } - if err := gen.validate(enums, tables); err != nil { - return nil, err - } - importer := core.Importer{ - Tables: tables, - Queries: queries, - Enums: enums, - C: gen.config, - } - if file, err := gen.sqlDriver.BuildPyTablesFile(&importer, tables); err != nil { - return nil, err - } else { - outputFiles = append(outputFiles, file) - } - if files, err := gen.sqlDriver.BuildPyQueriesFiles(&importer, queries); err != nil { - return nil, err - } else { - outputFiles = append(outputFiles, files...) - } - if *gen.config.EmitInitFile { - outputFiles = append(outputFiles, gen.sqlDriver.BuildInitFile()) - } - jsonData, _ = json.Marshal(outputFiles) - log.GlobalLogger.LogByte(jsonData) - if gen.config.Debug { - fileName, fileContent := log.GlobalLogger.Print() - outputFiles = append(outputFiles, &plugin.File{ - Name: fileName, - Contents: fileContent, - }) - } - return &plugin.GenerateResponse{Files: outputFiles}, nil -} - -func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) { - pythonGenerator, err := NewPythonGenerator(req) - if err != nil { - return nil, err - } - return pythonGenerator.Run() -} - -func (gen *PythonGenerator) validate(enums []core.Enum, structs []core.Table) error { - enumNames := make(map[string]struct{}) - for _, enum := range enums { - enumNames[enum.Name] = struct{}{} - enumNames["Null"+enum.Name] = struct{}{} - } - structNames := make(map[string]struct{}) - for _, struckt := range structs { - if _, ok := enumNames[struckt.Name]; ok { - return fmt.Errorf("struct name conflicts with enum name: %s", struckt.Name) - } - structNames[struckt.Name] = struct{}{} - } - return nil -} - -func filterUnusedStructs(enums []core.Enum, tables []core.Table, queries []core.Query) ([]core.Enum, []core.Table) { - keepTypes := make(map[string]struct{}) - - for _, query := range queries { - for _, arg := range query.Args { - if !arg.IsEmpty() { - keepTypes[arg.Type()] = struct{}{} - } - } - if query.HasRetType() { - keepTypes[query.Ret.Type()] = struct{}{} - if query.Ret.IsStruct() { - for _, field := range query.Ret.Table.Columns { - keepTypes[strings.ReplaceAll(field.Type.Type, "models.", "")] = struct{}{} - for _, embedField := range field.EmbedFields { - keepTypes[strings.ReplaceAll(embedField.Type.Type, "models.", "")] = struct{}{} - } - } - } - } - } - - keepEnums := make([]core.Enum, 0, len(enums)) - for _, enum := range enums { - _, keep := keepTypes[enum.Name] - _, keepNull := keepTypes["Null"+enum.Name] - if keep || keepNull { - keepEnums = append(keepEnums, enum) - } - } - - keepStructs := make([]core.Table, 0, len(tables)) - for _, st := range tables { - if _, ok := keepTypes[st.Name]; ok { - keepStructs = append(keepStructs, st) - } - } - - return keepEnums, keepStructs -} diff --git a/internal/handler.go b/internal/handler.go new file mode 100644 index 0000000..1051a5e --- /dev/null +++ b/internal/handler.go @@ -0,0 +1,44 @@ +package internal + +import ( + "context" + "fmt" + + configPackage "github.com/rayakame/sqlc-gen-better-python/internal/config" + "github.com/rayakame/sqlc-gen-better-python/internal/log" + "github.com/rayakame/sqlc-gen-better-python/internal/transform" + "github.com/rayakame/sqlc-gen-better-python/internal/types" + "github.com/sqlc-dev/plugin-sdk-go/plugin" +) + +func Handler(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) { + config, err := configPackage.NewConfig(req) + if err != nil { + return nil, fmt.Errorf("error trying to parse config: %w", err) + } + + typeConversionFunc, err := types.GetTypeConversionFunc(req.Settings.Engine) + if err != nil { + return nil, fmt.Errorf("error trying to parse config: %w", err) + } + + transformer := transform.NewTransformer(config, req, typeConversionFunc) + enums := transformer.BuildEnums() + tables := transformer.BuildTables() + queries := transformer.BuildQueries() + + log.L().LogAny(enums) + log.L().LogAny(tables) + log.L().LogAny(queries) + + outputFiles := make([]*plugin.File, 0) + if config.Debug { + fileName, fileContent := log.L().Export() + outputFiles = append(outputFiles, &plugin.File{ + Name: fileName, + Contents: fileContent, + }) + } + + return &plugin.GenerateResponse{Files: outputFiles}, nil +} diff --git a/internal/log/logger.go b/internal/log/logger.go index b672da1..dfbc0ac 100644 --- a/internal/log/logger.go +++ b/internal/log/logger.go @@ -1,20 +1,70 @@ package log +import ( + "encoding/json" + "fmt" + "sync" + + "github.com/rayakame/sqlc-gen-better-python/internal/utils" +) + +var ( + loggingInstance *Logger + loggingOnce sync.Once +) + type Logger struct { messages []string } +type logMessage struct { + Message string `json:"message"` +} + +type errMessage struct { + Error string `json:"error"` +} + +func L() *Logger { + loggingOnce.Do(func() { + loggingInstance = utils.ToPtr(Logger{}) + }) + + return loggingInstance +} + +func (logger *Logger) LogErr(message string, err error) { + msg := errMessage{Error: fmt.Sprintf("%s: %e", message, err)} + logger.LogAny(msg) +} func (logger *Logger) Log(message string) { - logger.messages = append(logger.messages, message) + msg := logMessage{Message: message} + logger.LogAny(msg) } -func (logger *Logger) LogByte(message []byte) { - logger.messages = append(logger.messages, string(message)) + +func (logger *Logger) LogAny(message any) { + jsonData, err := json.Marshal(message) + if err != nil { + logger.log(fmt.Sprintf(`{"error": "Error while trying to log any: %e"}`, err)) + } else { + logger.log(string(jsonData)) + } } -func (logger *Logger) Print() (string, []byte) { - var loggedMessages string - for _, message := range logger.messages { - loggedMessages += message + "\n" +func (logger *Logger) Export() (string, []byte) { + loggedMessages := "[\n" + for i, message := range logger.messages { + if i == len(logger.messages)-1 { + loggedMessages += message + "\n" + } else { + loggedMessages += message + ",\n" + } } - return "log.txt", []byte(loggedMessages) + loggedMessages += "]" + + return "log.json", []byte(loggedMessages) +} + +func (logger *Logger) log(data string) { + logger.messages = append(logger.messages, data) } diff --git a/internal/log/main.go b/internal/log/main.go deleted file mode 100644 index f3b31f0..0000000 --- a/internal/log/main.go +++ /dev/null @@ -1,7 +0,0 @@ -package log - -var GlobalLogger Logger - -func init() { - GlobalLogger = Logger{} -} diff --git a/internal/core/utils.go b/internal/model/naming.go similarity index 51% rename from internal/core/utils.go rename to internal/model/naming.go index 51f0896..882d3ff 100644 --- a/internal/core/utils.go +++ b/internal/model/naming.go @@ -1,24 +1,18 @@ -package core +package model import ( - "bufio" "fmt" - "github.com/sqlc-dev/plugin-sdk-go/plugin" - "golang.org/x/text/cases" - "golang.org/x/text/language" "strings" "unicode" "unicode/utf8" -) -func ModelName(enumName string, schemaName string, conf *Config) string { - if schemaName != "" { - enumName = schemaName + "_" + enumName - } - return SnakeToCamel(enumName, conf) -} + "github.com/rayakame/sqlc-gen-better-python/internal/config" + "github.com/sqlc-dev/plugin-sdk-go/plugin" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) -func SnakeToCamel(s string, conf *Config) string { +func SnakeToCamel(conf *config.Config, s string) string { out := "" s = strings.Map(func(r rune) rune { if unicode.IsLetter(r) { @@ -27,6 +21,7 @@ func SnakeToCamel(s string, conf *Config) string { if unicode.IsDigit(r) { return r } + return rune('_') }, s) for _, p := range strings.Split(s, "_") { @@ -44,20 +39,6 @@ func SnakeToCamel(s string, conf *Config) string { } } -func ColumnName(c *plugin.Column, pos int) string { - if c.Name != "" { - return c.Name - } - return fmt.Sprintf("column_%d", pos+1) -} - -func ParamName(p *plugin.Parameter) string { - if p.Column.Name != "" { - return p.Column.Name - } - return fmt.Sprintf("dollar_%d", p.Number) -} - func UpperSnakeCase(s string) string { result := "" for i, r := range s { @@ -68,27 +49,24 @@ func UpperSnakeCase(s string) string { } } result = strings.ToUpper(result) - return result -} -func SQLToPyFileName(s string) string { - return strings.ReplaceAll(s, ".sql", ".py") + return result } -func SplitLines(s string) []string { - var lines []string - sc := bufio.NewScanner(strings.NewReader(s)) - for sc.Scan() { - lines = append(lines, sc.Text()) +func ColumnName(pluginColumn *plugin.Column, pos int) string { + if pluginColumn.Name != "" { + return pluginColumn.Name } - return lines + + return fmt.Sprintf("column_%d", pos+1) } -func IsInMultipleMaps[K comparable, V any](search K, maps ...map[K]V) bool { - for _, m := range maps { - if _, found := m[search]; found { - return true - } +func ModelName(conf *config.Config, modelName string, schemaName string) string { + name := "" + if schemaName != "" { + name += schemaName + "_" } - return false + name += modelName + + return SnakeToCamel(conf, name) } diff --git a/internal/inflection/singular.go b/internal/model/singular.go similarity index 98% rename from internal/inflection/singular.go rename to internal/model/singular.go index 64e042a..ce9e0ac 100644 --- a/internal/inflection/singular.go +++ b/internal/model/singular.go @@ -1,4 +1,4 @@ -package inflection +package model import ( "strings" diff --git a/internal/model/types.go b/internal/model/types.go new file mode 100644 index 0000000..df6ff31 --- /dev/null +++ b/internal/model/types.go @@ -0,0 +1,36 @@ +package model + +type PyType struct { + SQLType string + Type string + IsNullable bool + IsList bool + IsEnum bool +} + +type Enum struct { + Name string + Constants []EnumConstants +} + +type EnumConstants struct { + Name string + Value string +} + +type Table struct { + Name string + Columns []Column +} + +type Column struct { + Name string + Type PyType +} + +type Query struct { + Cmd string // The command of the query: https://docs.sqlc.dev/en/latest/reference/query-annotations.html + SQL string // The raw SQL of the query + ConstantName string // The name of the constant where the raw SQL will be saved in python + FuncName string // The name of the python function +} diff --git a/internal/transform/enums.go b/internal/transform/enums.go new file mode 100644 index 0000000..f86bb4f --- /dev/null +++ b/internal/transform/enums.go @@ -0,0 +1,43 @@ +package transform + +import ( + "cmp" + "slices" + "strings" + + "github.com/rayakame/sqlc-gen-better-python/internal/model" + "github.com/rayakame/sqlc-gen-better-python/internal/utils" +) + +func (t *Transformer) BuildEnums() []model.Enum { + enums := make([]model.Enum, 0) + for _, schema := range t.req.Catalog.Schemas { + if schema.Name == utils.PgCatalog || schema.Name == utils.InformationSchema { + continue + } + for _, enum := range schema.Enums { + var schemaName string + if schema.Name != t.req.Catalog.DefaultSchema { + schemaName = schema.Name + } + + e := model.Enum{ + Name: model.ModelName(t.config, enum.Name, schemaName), + Constants: make([]model.EnumConstants, 0, len(enum.Vals)), + } + + for _, v := range enum.Vals { + e.Constants = append(e.Constants, model.EnumConstants{ + Name: strings.ToUpper(v), + Value: v, + }) + } + enums = append(enums, e) + } + } + slices.SortFunc(enums, func(a, b model.Enum) int { + return cmp.Compare(a.Name, b.Name) + }) + + return enums +} diff --git a/internal/transform/queries.go b/internal/transform/queries.go new file mode 100644 index 0000000..2377458 --- /dev/null +++ b/internal/transform/queries.go @@ -0,0 +1,32 @@ +package transform + +import ( + "strings" + + "github.com/rayakame/sqlc-gen-better-python/internal/model" +) + +func (t *Transformer) BuildQueries() []model.Query { + queries := make([]model.Query, 0, len(t.req.Queries)) + for _, pluginQuery := range t.req.Queries { + if pluginQuery.Name == "" { + continue + } + if pluginQuery.Cmd == "" { + continue + } + + constantName := model.UpperSnakeCase(pluginQuery.Name) + + query := model.Query{ + Cmd: pluginQuery.Cmd, + SQL: pluginQuery.Text, + ConstantName: constantName, + FuncName: strings.ToLower(constantName), + } + + queries = append(queries, query) + } + + return queries +} diff --git a/internal/transform/tables.go b/internal/transform/tables.go new file mode 100644 index 0000000..b6fb4d8 --- /dev/null +++ b/internal/transform/tables.go @@ -0,0 +1,53 @@ +package transform + +import ( + "cmp" + "slices" + + "github.com/rayakame/sqlc-gen-better-python/internal/model" + "github.com/rayakame/sqlc-gen-better-python/internal/utils" + "github.com/sqlc-dev/plugin-sdk-go/plugin" +) + +func (t *Transformer) BuildTables() []model.Table { + tables := make([]model.Table, 0) + for _, schema := range t.req.Catalog.Schemas { + if schema.Name == utils.PgCatalog || schema.Name == utils.InformationSchema { + continue + } + for _, table := range schema.Tables { + tables = append(tables, t.buildTable(schema, table)) + } + } + slices.SortFunc(tables, func(a, b model.Table) int { + return cmp.Compare(a.Name, b.Name) + }) + + return tables +} + +func (t *Transformer) buildTable(pluginSchema *plugin.Schema, pluginTable *plugin.Table) model.Table { + var schemaName string + if pluginSchema.Name != t.req.Catalog.DefaultSchema { + schemaName = pluginSchema.Name + } + tableName := model.ModelName(t.config, pluginTable.Rel.Name, schemaName) + if !t.config.EmitExactTableNames { + tableName = model.Singular(model.SingularParams{ + Name: tableName, + Exclusions: t.config.InflectionExcludeTableNames, + }) + } + table := model.Table{ + Name: tableName, + Columns: make([]model.Column, 0, len(pluginTable.Columns)), + } + for i, column := range pluginTable.Columns { + table.Columns = append(table.Columns, model.Column{ + Name: model.ColumnName(column, i), + Type: t.buildPyType(column), + }) + } + + return table +} diff --git a/internal/transform/transformer.go b/internal/transform/transformer.go new file mode 100644 index 0000000..a4fd076 --- /dev/null +++ b/internal/transform/transformer.go @@ -0,0 +1,17 @@ +package transform + +import ( + "github.com/rayakame/sqlc-gen-better-python/internal/config" + "github.com/rayakame/sqlc-gen-better-python/internal/types" + "github.com/sqlc-dev/plugin-sdk-go/plugin" +) + +type Transformer struct { + typeConversionFunc types.TypeConversionFunc + config *config.Config + req *plugin.GenerateRequest +} + +func NewTransformer(conf *config.Config, req *plugin.GenerateRequest, convFunc types.TypeConversionFunc) *Transformer { + return &Transformer{typeConversionFunc: convFunc, config: conf, req: req} +} diff --git a/internal/transform/type.go b/internal/transform/type.go new file mode 100644 index 0000000..01dec9e --- /dev/null +++ b/internal/transform/type.go @@ -0,0 +1,46 @@ +package transform + +import ( + "github.com/rayakame/sqlc-gen-better-python/internal/model" + "github.com/rayakame/sqlc-gen-better-python/internal/utils" + "github.com/sqlc-dev/plugin-sdk-go/plugin" + "github.com/sqlc-dev/plugin-sdk-go/sdk" +) + +func (t *Transformer) convertType(columnType *plugin.Identifier) string { + return t.typeConversionFunc(t.req, t.config, columnType) +} + +func (t *Transformer) buildPyType(pluginColumn *plugin.Column) model.PyType { + columnType := sdk.DataType(pluginColumn.Type) + strType := t.convertType(pluginColumn.Type) + + isEnum := false + + if pluginColumn.Type.Schema == "" { + pluginColumn.Type.Schema = t.req.Catalog.DefaultSchema + } + + for _, schema := range t.req.Catalog.Schemas { + if schema.Name == utils.PgCatalog || schema.Name == utils.InformationSchema { + continue + } + if pluginColumn.Type.Schema != schema.GetName() { + continue + } + + for _, enum := range schema.Enums { + if pluginColumn.Type.Name == enum.Name { + isEnum = true + } + } + } + + return model.PyType{ + SQLType: columnType, + Type: strType, + IsNullable: !pluginColumn.GetNotNull(), + IsList: pluginColumn.GetIsArray() || pluginColumn.GetIsSqlcSlice(), + IsEnum: isEnum, + } +} diff --git a/internal/types/common.go b/internal/types/common.go index 0267933..101ff15 100644 --- a/internal/types/common.go +++ b/internal/types/common.go @@ -1,8 +1,21 @@ package types import ( - "github.com/rayakame/sqlc-gen-better-python/internal/core" + "fmt" + + "github.com/rayakame/sqlc-gen-better-python/internal/config" "github.com/sqlc-dev/plugin-sdk-go/plugin" ) -type TypeConversionFunc func(req *plugin.GenerateRequest, col *plugin.Column, conf *core.Config) string +type TypeConversionFunc func(*plugin.GenerateRequest, *config.Config, *plugin.Identifier) string + +func GetTypeConversionFunc(engine string) (TypeConversionFunc, error) { + switch engine { + case "postgresql": + return PostgresTypeToPython, nil + case "sqlite": + return SqliteTypeToPython, nil + default: + return nil, fmt.Errorf("engine %q is not supported", engine) + } +} diff --git a/internal/types/constants.go b/internal/types/constants.go new file mode 100644 index 0000000..4171e7c --- /dev/null +++ b/internal/types/constants.go @@ -0,0 +1,9 @@ +package types + +const ( + Bool = "bool" + Boolean = "boolean" + Str = "str" + Int = "int" + Float = "float" +) diff --git a/internal/types/postgresql.go b/internal/types/postgresql.go index 61e09a3..65dac47 100644 --- a/internal/types/postgresql.go +++ b/internal/types/postgresql.go @@ -1,30 +1,47 @@ package types import ( - "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/core" + "github.com/rayakame/sqlc-gen-better-python/internal/config" "github.com/rayakame/sqlc-gen-better-python/internal/log" - + "github.com/rayakame/sqlc-gen-better-python/internal/model" + "github.com/rayakame/sqlc-gen-better-python/internal/utils" "github.com/sqlc-dev/plugin-sdk-go/plugin" "github.com/sqlc-dev/plugin-sdk-go/sdk" ) -func PostgresTypeToPython(req *plugin.GenerateRequest, col *plugin.Column, conf *core.Config) string { - columnType := sdk.DataType(col.Type) - +func PostgresTypeToPython(req *plugin.GenerateRequest, config *config.Config, pluginType *plugin.Identifier) string { + columnType := sdk.DataType(pluginType) switch columnType { - case "serial", "serial4", "pg_catalog.serial4", "bigserial", "serial8", "pg_catalog.serial8", "smallserial", "serial2", "pg_catalog.serial2", "integer", "int", "int4", "pg_catalog.int4", "bigint", "int8", "pg_catalog.int8", "smallint", "int2", "pg_catalog.int2": - return "int" - case "float", "double precision", "float8", "pg_catalog.float8", "real", "float4", "pg_catalog.float4": - return "float" + case "serial", + "serial4", + "pg_catalog.serial4", + "bigserial", + "serial8", + "pg_catalog.serial8", + "smallserial", + "serial2", + "pg_catalog.serial2", + "integer", + Int, + "int4", + "pg_catalog.int4", + "bigint", + "int8", + "pg_catalog.int8", + "smallint", + "int2", + "pg_catalog.int2": + return Int + case Float, "double precision", "float8", "pg_catalog.float8", "real", "float4", "pg_catalog.float4": + return Float case "numeric", "pg_catalog.numeric": return "decimal.Decimal" case "money": - return "str" - case "boolean", "bool", "pg_catalog.bool": - return "bool" + return Str + case Boolean, Bool, "pg_catalog.bool": + return Bool case "pg_catalog.json", "json", "jsonb": - return "str" + return Str case "bytea", "blob", "pg_catalog.bytea": return "memoryview" case "date": @@ -36,31 +53,40 @@ func PostgresTypeToPython(req *plugin.GenerateRequest, col *plugin.Column, conf case "interval", "pg_catalog.interval": return "datetime.timedelta" case "text", "pg_catalog.varchar", "bpchar", "pg_catalog.bpchar", "char", "string", "citext": - return "str" + return Str case "uuid": return "uuid.UUID" case "inet", "cidr", "macaddr", "macaddr8": // psycopg2 does have support for ipaddress objects, but it is not enabled by default // // https://www.psycopg.org/docs/extras.html#adapt-network - return "str" + return Str case "ltree", "lquery", "ltxtquery": - return "str" + return Str default: + if pluginType.Schema == "" { + pluginType.Schema = req.Catalog.DefaultSchema + } for _, schema := range req.Catalog.Schemas { - if schema.Name == "pg_catalog" || schema.Name == "information_schema" { + if schema.Name == utils.PgCatalog || schema.Name == utils.InformationSchema { + continue + } + if schema.Name != pluginType.Schema { continue } for _, enum := range schema.Enums { - if columnType == enum.Name { - if schema.Name == req.Catalog.DefaultSchema { - return "models." + core.ModelName(enum.Name, "", conf) - } - return "models." + core.ModelName(enum.Name, schema.Name, conf) + if pluginType.Name != enum.Name { + continue + } + if schema.Name == req.Catalog.DefaultSchema { + return "enums." + model.ModelName(config, enum.Name, "") } + + return "enums." + model.ModelName(config, enum.Name, schema.Name) } } - log.GlobalLogger.Log(fmt.Sprintf("unknown PostgreSQL type: %s", columnType)) + log.L().Log("unknown PostgreSQL type: " + columnType) + return "typing.Any" } } diff --git a/internal/types/sqlite.go b/internal/types/sqlite.go index a15ed2a..bf8d881 100644 --- a/internal/types/sqlite.go +++ b/internal/types/sqlite.go @@ -1,27 +1,26 @@ package types import ( - "fmt" - "github.com/rayakame/sqlc-gen-better-python/internal/core" - "github.com/rayakame/sqlc-gen-better-python/internal/log" "strings" + "github.com/rayakame/sqlc-gen-better-python/internal/config" + "github.com/rayakame/sqlc-gen-better-python/internal/log" "github.com/sqlc-dev/plugin-sdk-go/plugin" "github.com/sqlc-dev/plugin-sdk-go/sdk" ) -func SqliteTypeToPython(_ *plugin.GenerateRequest, col *plugin.Column, _ *core.Config) string { - columnType := strings.ToLower(sdk.DataType(col.Type)) +func SqliteTypeToPython(_ *plugin.GenerateRequest, _ *config.Config, pluginType *plugin.Identifier) string { + columnType := sdk.DataType(pluginType) switch columnType { - case "int", "integer", "tinyint", "smallint", "mediumint", "bigint", "unsignedbigint", "int2", "int8", "bigserial": - return "int" + case Int, "integer", "tinyint", "smallint", "mediumint", "bigint", "unsignedbigint", "int2", "int8", "bigserial": + return Int case "blob": return "memoryview" - case "real", "double", "double precision", "doubleprecision", "float", "numeric": - return "float" - case "boolean", "bool": - return "bool" + case "real", "double", "double precision", "doubleprecision", Float, "numeric": + return Float + case Boolean, Bool: + return Bool case "date": return "datetime.date" case "datetime", "timestamp": @@ -40,10 +39,11 @@ func SqliteTypeToPython(_ *plugin.GenerateRequest, col *plugin.Column, _ *core.C columnType == "text", columnType == "clob", columnType == "json": - return "str" + return Str default: - log.GlobalLogger.Log(fmt.Sprintf("unknown SQLite type: %s", columnType)) + log.L().Log("unknown SQLite type: " + columnType) + return "typing.Any" } } diff --git a/internal/utils/common.go b/internal/utils/common.go new file mode 100644 index 0000000..e747d24 --- /dev/null +++ b/internal/utils/common.go @@ -0,0 +1,5 @@ +package utils + +func ToPtr[T any](t T) *T { + return &t +} diff --git a/internal/utils/constants.go b/internal/utils/constants.go new file mode 100644 index 0000000..0805d55 --- /dev/null +++ b/internal/utils/constants.go @@ -0,0 +1,6 @@ +package utils + +const ( + InformationSchema = "information_schema" + PgCatalog = "pg_catalog" +) diff --git a/plugin/main.go b/plugin/main.go index 3abf2c9..015065e 100644 --- a/plugin/main.go +++ b/plugin/main.go @@ -6,5 +6,5 @@ import ( ) func main() { - codegen.Run(python.Generate) + codegen.Run(python.Handler) } diff --git a/sqlc.yaml b/sqlc.yaml index a8e10ce..7f1134f 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: python wasm: url: file://sqlc-gen-better-python.wasm - sha256: ee7bd0c07b784b80ea8c5853d9a6a04c51a7abbfd2663f470e9a5d2f623b967e + sha256: a44d4e6a61b821cc908cc054e285dc17d250cb354a03417213b9b11407b6450e sql: - schema: test/schema.sql queries: test/queries.sql diff --git a/test/driver_aiosqlite/sqlc-gen-better-python.wasm b/test/driver_aiosqlite/sqlc-gen-better-python.wasm index 56caaa4..14d3d83 100644 Binary files a/test/driver_aiosqlite/sqlc-gen-better-python.wasm and b/test/driver_aiosqlite/sqlc-gen-better-python.wasm differ diff --git a/test/driver_aiosqlite/sqlc.yaml b/test/driver_aiosqlite/sqlc.yaml index 157a046..b287db2 100644 --- a/test/driver_aiosqlite/sqlc.yaml +++ b/test/driver_aiosqlite/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: python wasm: url: file://sqlc-gen-better-python.wasm - sha256: ee7bd0c07b784b80ea8c5853d9a6a04c51a7abbfd2663f470e9a5d2f623b967e + sha256: a44d4e6a61b821cc908cc054e285dc17d250cb354a03417213b9b11407b6450e sql: - schema: schema.sql queries: queries.sql diff --git a/test/driver_asyncpg/sqlc-gen-better-python.wasm b/test/driver_asyncpg/sqlc-gen-better-python.wasm index 56caaa4..14d3d83 100644 Binary files a/test/driver_asyncpg/sqlc-gen-better-python.wasm and b/test/driver_asyncpg/sqlc-gen-better-python.wasm differ diff --git a/test/driver_asyncpg/sqlc.yaml b/test/driver_asyncpg/sqlc.yaml index 90e1aee..1438480 100644 --- a/test/driver_asyncpg/sqlc.yaml +++ b/test/driver_asyncpg/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: python wasm: url: file://sqlc-gen-better-python.wasm - sha256: ee7bd0c07b784b80ea8c5853d9a6a04c51a7abbfd2663f470e9a5d2f623b967e + sha256: a44d4e6a61b821cc908cc054e285dc17d250cb354a03417213b9b11407b6450e sql: - schema: schema.sql queries: queries.sql diff --git a/test/driver_sqlite3/sqlc-gen-better-python.wasm b/test/driver_sqlite3/sqlc-gen-better-python.wasm index 56caaa4..14d3d83 100644 Binary files a/test/driver_sqlite3/sqlc-gen-better-python.wasm and b/test/driver_sqlite3/sqlc-gen-better-python.wasm differ diff --git a/test/driver_sqlite3/sqlc.yaml b/test/driver_sqlite3/sqlc.yaml index 25560b7..f47421c 100644 --- a/test/driver_sqlite3/sqlc.yaml +++ b/test/driver_sqlite3/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: python wasm: url: file://sqlc-gen-better-python.wasm - sha256: ee7bd0c07b784b80ea8c5853d9a6a04c51a7abbfd2663f470e9a5d2f623b967e + sha256: a44d4e6a61b821cc908cc054e285dc17d250cb354a03417213b9b11407b6450e sql: - schema: schema.sql queries: queries.sql diff --git a/test/queries.sql b/test/queries.sql index edfe618..73c82cb 100644 --- a/test/queries.sql +++ b/test/queries.sql @@ -1,40 +1,5 @@ --- name: GetEmbeddedTestPostgresType1 :one -SELECT *, sqlc.embed(test_inner_postgres_types) -FROM test_postgres_types - JOIN test_inner_postgres_types ON test_inner_postgres_types.table_id = test_postgres_types.id LIMIT 1; - - --- name: TestThatIsReallyImportant :many -SELECT timestamp_test FROM test_postgres_types WHERE id = $1; - --- name: GetEmbeddedTestPostgresType2 :one -SELECT test_postgres_types.*, sqlc.embed(test_inner_postgres_types), test_inner_postgres_types.bool_test -FROM test_postgres_types - JOIN test_inner_postgres_types ON test_inner_postgres_types.table_id = test_postgres_types.id LIMIT 1; - --- name: GetEmbeddedTestPostgresType3 :one -SELECT test_postgres_types.id, - test_postgres_types.serial_test, - sqlc.embed(test_inner_postgres_types), - test_inner_postgres_types.bool_test -FROM test_postgres_types - JOIN test_inner_postgres_types ON test_inner_postgres_types.table_id = test_postgres_types.id LIMIT 1; - --- name: GetEmbeddedTestPostgresType4 :one -SELECT sqlc.embed(test_postgres_types), - sqlc.embed(test_inner_postgres_types), - test_inner_postgres_types.bool_test -FROM test_postgres_types - JOIN test_inner_postgres_types ON test_inner_postgres_types.table_id = test_postgres_types.id LIMIT 1; - -- name: TestExecute :exec -INSERT INTO test_postgres_types (id, serial_test, timestamp_test) -VALUES ($1, $2, $3); - --- name: GetAll :many -SELECT * FROM test_postgres_types; +INSERT INTO test_enum (id, b, b2, m) +VALUES ($1, $2, $3, $4); --- name: TTTT :one -SELECT serial_test -FROM test_postgres_types LIMIT 1; diff --git a/test/schema.sql b/test/schema.sql index b249f21..7fb1d11 100644 --- a/test/schema.sql +++ b/test/schema.sql @@ -1,15 +1,23 @@ -CREATE TABLE test_postgres_types +-- Public schema (default) +CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy'); + +CREATE TABLE test_enum ( - /* ───────────── Integer family ───────────── */ - id int PRIMARY KEY NOT NULL, - serial_test serial NOT NULL, - timestamp_test timestamp NOT NULL + id int PRIMARY KEY NOT NULL, + b boolean NOT NULL, + b2 boolean, + m mood NOT NULL ); -CREATE TABLE test_inner_postgres_types +-- Custom schema +CREATE SCHEMA IF NOT EXISTS custom; + +CREATE TYPE custom.mood AS ENUM ('sad', 'ok', 'happy'); + +CREATE TABLE custom.test_enum ( - /* ───────────── Integer family ───────────── */ - table_id int NOT NULL, - /* ───────────── Boolean ───────────── */ - bool_test boolean NOT NULL + id int PRIMARY KEY NOT NULL, + b boolean NOT NULL, + b2 boolean, + m custom.mood NOT NULL ); \ No newline at end of file