diff --git a/.golangci.yml b/.golangci.yml index dc7c960..d987d15 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -26,14 +26,16 @@ linters: min-len: 2 min-occurrences: 3 cyclop: - max-complexity: 20 + max-complexity: 30 gocyclo: - min-complexity: 20 + min-complexity: 30 exhaustive: default-signifies-exhaustive: true default-case-required: true lll: line-length: 180 + maintidx: + under: 15 exclusions: generated: lax presets: diff --git a/application.go b/application.go index a5e966e..e055a4a 100644 --- a/application.go +++ b/application.go @@ -10,10 +10,11 @@ import ( "go/types" "log" "os" + "regexp" "strings" "github.com/go-openapi/spec" - "github.com/go-openapi/swag" + "github.com/go-openapi/swag/conv" "golang.org/x/tools/go/packages" ) @@ -21,7 +22,7 @@ import ( const pkgLoadMode = packages.NeedName | packages.NeedFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo func safeConvert(str string) bool { - b, err := swag.ConvertBool(str) + b, err := conv.ConvertBool(str) if err != nil { return false } @@ -29,7 +30,7 @@ func safeConvert(str string) bool { } // Debug is true when process is run with DEBUG=1 env var. 
-var Debug = safeConvert(os.Getenv("DEBUG")) +var Debug = safeConvert(os.Getenv("DEBUG")) //nolint:gochecknoglobals // package-level configuration from environment type node uint32 @@ -286,49 +287,51 @@ func (d *entityDecl) HasParameterAnnotation() bool { } func (s *scanCtx) FindDecl(pkgPath, name string) (*entityDecl, bool) { - if pkg, ok := s.app.AllPackages[pkgPath]; ok { - for _, file := range pkg.Syntax { - for _, d := range file.Decls { - gd, ok := d.(*ast.GenDecl) + pkg, ok := s.app.AllPackages[pkgPath] + if !ok { + return nil, false + } + + for _, file := range pkg.Syntax { + for _, d := range file.Decls { + gd, ok := d.(*ast.GenDecl) + if !ok { + continue + } + + for _, sp := range gd.Specs { + ts, ok := sp.(*ast.TypeSpec) + if !ok || ts.Name.Name != name { + continue + } + + def, ok := pkg.TypesInfo.Defs[ts.Name] if !ok { + debugLogf("couldn't find type info for %s", ts.Name) continue } - for _, sp := range gd.Specs { - if ts, ok := sp.(*ast.TypeSpec); ok && ts.Name.Name == name { - def, ok := pkg.TypesInfo.Defs[ts.Name] - if !ok { - debugLogf("couldn't find type info for %s", ts.Name) - - continue - } - - nt, isNamed := def.Type().(*types.Named) - at, isAliased := def.Type().(*types.Alias) - if !isNamed && !isAliased { - debugLogf("%s is not a named or an aliased type but a %T", ts.Name, def.Type()) - - continue - } - - comments := ts.Doc // type ( /* doc */ Foo struct{} ) - if comments == nil { - comments = gd.Doc // /* doc */ type ( Foo struct{} ) - } - - decl := &entityDecl{ - Comments: comments, - Type: nt, - Alias: at, - Ident: ts.Name, - Spec: ts, - File: file, - Pkg: pkg, - } - - return decl, true - } + nt, isNamed := def.Type().(*types.Named) + at, isAliased := def.Type().(*types.Alias) + if !isNamed && !isAliased { + debugLogf("%s is not a named or an aliased type but a %T", ts.Name, def.Type()) + continue } + + comments := ts.Doc // type ( /* doc */ Foo struct{} ) + if comments == nil { + comments = gd.Doc // /* doc */ type ( Foo struct{} ) + } 
+ + return &entityDecl{ + Comments: comments, + Type: nt, + Alias: at, + Ident: ts.Name, + Spec: ts, + File: file, + Pkg: pkg, + }, true } } } @@ -612,64 +615,72 @@ func (a *typeIndex) processPackage(pkg *packages.Package) error { } for _, file := range pkg.Syntax { - n, err := a.detectNodes(file) - if err != nil { + if err := a.processFile(pkg, file); err != nil { return err } + } - if n&metaNode != 0 { - a.Meta = append(a.Meta, metaSection{Comments: file.Doc}) - } + return nil +} - if n&operationNode != 0 { - for _, cmts := range file.Comments { - pp := parsePathAnnotation(rxOperation, cmts.List) - if pp.Method == "" { - continue // not a valid operation - } - if !shouldAcceptTag(pp.Tags, a.includeTags, a.excludeTags) { - debugLogf("operation %s %s is ignored due to tag rules", pp.Method, pp.Path) - continue - } - a.Operations = append(a.Operations, pp) - } - } +func (a *typeIndex) processFile(pkg *packages.Package, file *ast.File) error { + n, err := a.detectNodes(file) + if err != nil { + return err + } - if n&routeNode != 0 { - for _, cmts := range file.Comments { - pp := parsePathAnnotation(rxRoute, cmts.List) - if pp.Method == "" { - continue // not a valid operation - } - if !shouldAcceptTag(pp.Tags, a.includeTags, a.excludeTags) { - debugLogf("operation %s %s is ignored due to tag rules", pp.Method, pp.Path) - continue - } - a.Routes = append(a.Routes, pp) - } + if n&metaNode != 0 { + a.Meta = append(a.Meta, metaSection{Comments: file.Doc}) + } + + if n&operationNode != 0 { + a.Operations = a.collectPathAnnotations(rxOperation, file.Comments, a.Operations) + } + + if n&routeNode != 0 { + a.Routes = a.collectPathAnnotations(rxRoute, file.Comments, a.Routes) + } + + a.processFileDecls(pkg, file, n) + + return nil +} + +func (a *typeIndex) collectPathAnnotations(rx *regexp.Regexp, comments []*ast.CommentGroup, dst []parsedPathContent) []parsedPathContent { + for _, cmts := range comments { + pp := parsePathAnnotation(rx, cmts.List) + if pp.Method == "" { + 
continue } + if !shouldAcceptTag(pp.Tags, a.includeTags, a.excludeTags) { + debugLogf("operation %s %s is ignored due to tag rules", pp.Method, pp.Path) + continue + } + dst = append(dst, pp) + } + return dst +} - for _, dt := range file.Decls { - switch fd := dt.(type) { - case *ast.BadDecl: +func (a *typeIndex) processFileDecls(pkg *packages.Package, file *ast.File, n node) { + for _, dt := range file.Decls { + switch fd := dt.(type) { + case *ast.BadDecl: + continue + case *ast.FuncDecl: + if fd.Body == nil { continue - case *ast.FuncDecl: - if fd.Body == nil { - continue - } - for _, stmt := range fd.Body.List { - if dstm, ok := stmt.(*ast.DeclStmt); ok { - if gd, isGD := dstm.Decl.(*ast.GenDecl); isGD { - a.processDecl(pkg, file, n, gd) - } + } + for _, stmt := range fd.Body.List { + if dstm, ok := stmt.(*ast.DeclStmt); ok { + if gd, isGD := dstm.Decl.(*ast.GenDecl); isGD { + a.processDecl(pkg, file, n, gd) } } - case *ast.GenDecl: - a.processDecl(pkg, file, n, fd) } + case *ast.GenDecl: + a.processDecl(pkg, file, n, fd) } } - return nil } func (a *typeIndex) processDecl(pkg *packages.Package, file *ast.File, n node, gd *ast.GenDecl) { @@ -748,10 +759,23 @@ func (a *typeIndex) walkImports(pkg *packages.Package) error { return nil } +func checkStructConflict(seenStruct *string, annotation string, text string) error { + if *seenStruct != "" && *seenStruct != annotation { + return fmt.Errorf("classifier: already annotated as %s, can't also be %q - %s: %w", *seenStruct, annotation, text, ErrCodeScan) + } + *seenStruct = annotation + return nil +} + +// detectNodes scans all comment groups in a file and returns a bitmask of +// detected swagger annotation types. Node types like route, operation, and +// meta accumulate freely across comment groups. Struct-level annotations +// (model, parameters, response) are mutually exclusive within a single +// comment group — mixing them is an error. 
func (a *typeIndex) detectNodes(file *ast.File) (node, error) { var n node for _, comments := range file.Comments { - var seenStruct string + var seenStruct string // tracks the struct annotation for this comment group for _, cline := range comments.List { if cline == nil { continue @@ -764,7 +788,7 @@ func (a *typeIndex) detectNodes(file *ast.File) (node, error) { } matches := rxSwaggerAnnotation.FindStringSubmatch(cline.Text) - if len(matches) < 2 { + if len(matches) < minAnnotationMatch { continue } @@ -775,41 +799,36 @@ func (a *typeIndex) detectNodes(file *ast.File) (node, error) { n |= operationNode case "model": n |= modelNode - if seenStruct == "" || seenStruct == matches[1] { - seenStruct = matches[1] - } else { - return 0, fmt.Errorf("classifier: already annotated as %s, can't also be %q - %s", seenStruct, matches[1], cline.Text) + if err := checkStructConflict(&seenStruct, matches[1], cline.Text); err != nil { + return 0, err } case "meta": n |= metaNode case "parameters": n |= parametersNode - if seenStruct == "" || seenStruct == matches[1] { - seenStruct = matches[1] - } else { - return 0, fmt.Errorf("classifier: already annotated as %s, can't also be %q - %s", seenStruct, matches[1], cline.Text) + if err := checkStructConflict(&seenStruct, matches[1], cline.Text); err != nil { + return 0, err } case "response": n |= responseNode - if seenStruct == "" || seenStruct == matches[1] { - seenStruct = matches[1] - } else { - return 0, fmt.Errorf("classifier: already annotated as %s, can't also be %q - %s", seenStruct, matches[1], cline.Text) + if err := checkStructConflict(&seenStruct, matches[1], cline.Text); err != nil { + return 0, err } - case "strfmt", "name", "discriminated", "file", "enum", "default", "alias", "type": + case "strfmt", paramNameKey, "discriminated", "file", "enum", "default", "alias", "type": // TODO: perhaps collect these and pass along to avoid lookups later on case "allOf": case "ignore": default: - return 0, fmt.Errorf("classifier: 
unknown swagger annotation %q", matches[1]) + return 0, fmt.Errorf("classifier: unknown swagger annotation %q: %w", matches[1], ErrCodeScan) } } } + return n, nil } func debugLogf(format string, args ...any) { if Debug { - _ = log.Output(2, fmt.Sprintf(format, args...)) + _ = log.Output(logCallerDepth, fmt.Sprintf(format, args...)) } } diff --git a/application_test.go b/application_test.go index 193d3d1..5352321 100644 --- a/application_test.go +++ b/application_test.go @@ -19,16 +19,16 @@ import ( ) var ( - petstoreCtx *scanCtx - classificationCtx *scanCtx + petstoreCtx *scanCtx //nolint:gochecknoglobals // test package cache shared across test functions + classificationCtx *scanCtx //nolint:gochecknoglobals // test package cache shared across test functions ) var ( - enableSpecOutput bool - enableDebug bool + enableSpecOutput bool //nolint:gochecknoglobals // test flag registered in init + enableDebug bool //nolint:gochecknoglobals // test flag registered in init ) -func init() { +func init() { //nolint:gochecknoinits // registers test flags before TestMain flag.BoolVar(&enableSpecOutput, "enable-spec-output", false, "enable spec gen test to write output to a file") flag.BoolVar(&enableDebug, "enable-debug", false, "enable debug output in tests") } @@ -108,7 +108,7 @@ func loadPetstorePkgsCtx(t *testing.T) *scanCtx { return petstoreCtx } -func loadClassificationPkgsCtx(t *testing.T, extra ...string) *scanCtx { +func loadClassificationPkgsCtx(t *testing.T) *scanCtx { t.Helper() if classificationCtx != nil { @@ -116,11 +116,11 @@ func loadClassificationPkgsCtx(t *testing.T, extra ...string) *scanCtx { } sctx, err := newScanCtx(&Options{ - Packages: append([]string{ + Packages: []string{ "./goparsing/classification", "./goparsing/classification/models", "./goparsing/classification/operations", - }, extra...), + }, WorkDir: "fixtures", }) require.NoError(t, err) diff --git a/assertions.go b/assertions.go index b715b6e..e78f760 100644 --- a/assertions.go +++ 
b/assertions.go @@ -8,14 +8,14 @@ import ( "go/types" ) -type Error string +type testError string -func (e Error) Error() string { +func (e testError) Error() string { return string(e) } const ( - ErrInternal Error = "internal error due to a bug or a mishandling of go types AST. This usually indicates a bug in the scanner" + errInternal testError = "internal error due to a bug or a mishandling of go types AST. This usually indicates a bug in the scanner" ) // code assertions to be explicit about the various expectations when entering a function @@ -25,7 +25,7 @@ func mustNotBeABuiltinType(o *types.TypeName) { return } - panic(fmt.Errorf("type %q expected not to be a builtin: %w", o.Name(), ErrInternal)) + panic(fmt.Errorf("type %q expected not to be a builtin: %w", o.Name(), errInternal)) } func mustHaveRightHandSide(a *types.Alias) { @@ -33,5 +33,5 @@ func mustHaveRightHandSide(a *types.Alias) { return } - panic(fmt.Errorf("type alias %q expected to declare a right-hand-side: %w", a.Obj().Name(), ErrInternal)) + panic(fmt.Errorf("type alias %q expected to declare a right-hand-side: %w", a.Obj().Name(), errInternal)) } diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..16a68ff --- /dev/null +++ b/errors.go @@ -0,0 +1,9 @@ +// SPDX-FileCopyrightText: Copyright 2015-2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +package codescan + +import "errors" + +// ErrCodeScan is the sentinel error for all errors originating from the codescan package. 
+var ErrCodeScan = errors.New("codescan") diff --git a/fixtures/bugs/3125/full/main.go b/fixtures/bugs/3125/full/main.go index 2efb8c8..06a1ccd 100644 --- a/fixtures/bugs/3125/full/main.go +++ b/fixtures/bugs/3125/full/main.go @@ -6,7 +6,6 @@ import ( "log" "net" "net/http" - "swagger/api" ) diff --git a/fixtures/goparsing/bookings/api.go b/fixtures/goparsing/bookings/api.go index 0c2d34f..a720eec 100644 --- a/fixtures/goparsing/bookings/api.go +++ b/fixtures/goparsing/bookings/api.go @@ -71,5 +71,4 @@ type BookingResponse struct { // Responses: // 200: BookingResponse func bookings(w http.ResponseWriter, r *http.Request) { - } diff --git a/fixtures/goparsing/classification/models/extranomodel.go b/fixtures/goparsing/classification/models/extranomodel.go index 038df40..67de821 100644 --- a/fixtures/goparsing/classification/models/extranomodel.go +++ b/fixtures/goparsing/classification/models/extranomodel.go @@ -18,8 +18,8 @@ import ( "database/sql" "time" - "github.com/go-openapi/strfmt" "github.com/go-openapi/codescan/fixtures/goparsing/classification/transitive/mods" + "github.com/go-openapi/strfmt" ) // A Something struct is used by other structs diff --git a/fixtures/goparsing/classification/models/nomodel.go b/fixtures/goparsing/classification/models/nomodel.go index 11a2385..2750761 100644 --- a/fixtures/goparsing/classification/models/nomodel.go +++ b/fixtures/goparsing/classification/models/nomodel.go @@ -18,8 +18,8 @@ import ( "net/url" "time" - "github.com/go-openapi/strfmt" "github.com/go-openapi/codescan/fixtures/goparsing/classification/transitive/mods" + "github.com/go-openapi/strfmt" ) // NoModel is a struct without an annotation. 
diff --git a/fixtures/goparsing/classification/operations/noparams.go b/fixtures/goparsing/classification/operations/noparams.go index af93611..0c546ec 100644 --- a/fixtures/goparsing/classification/operations/noparams.go +++ b/fixtures/goparsing/classification/operations/noparams.go @@ -17,9 +17,9 @@ package operations import ( "bytes" - "github.com/go-openapi/strfmt" "github.com/go-openapi/codescan/fixtures/goparsing/classification/models" "github.com/go-openapi/codescan/fixtures/goparsing/classification/transitive/mods" + "github.com/go-openapi/strfmt" ) // MyFileParams contains the uploaded file data diff --git a/fixtures/goparsing/classification/operations/responses.go b/fixtures/goparsing/classification/operations/responses.go index 03d4989..e4dff5d 100644 --- a/fixtures/goparsing/classification/operations/responses.go +++ b/fixtures/goparsing/classification/operations/responses.go @@ -15,8 +15,8 @@ package operations import ( - "github.com/go-openapi/strfmt" "github.com/go-openapi/codescan/fixtures/goparsing/classification/transitive/mods" + "github.com/go-openapi/strfmt" ) // A GenericError is an error that is used when no other error is appropriate diff --git a/fixtures/goparsing/classification/operations/todo_operation.go b/fixtures/goparsing/classification/operations/todo_operation.go index fdb6372..a12c6b5 100644 --- a/fixtures/goparsing/classification/operations/todo_operation.go +++ b/fixtures/goparsing/classification/operations/todo_operation.go @@ -22,7 +22,6 @@ type ListPetParams struct { // ServeAPI serves the API for this record store func ServeAPI(host, basePath string, schemes []string) error { - // swagger:route GET /pets pets users listPets // // Lists pets filtered by some parameters. 
diff --git a/fixtures/goparsing/classification/operations_body/todo_operation_body.go b/fixtures/goparsing/classification/operations_body/todo_operation_body.go index 2f30270..beb9139 100644 --- a/fixtures/goparsing/classification/operations_body/todo_operation_body.go +++ b/fixtures/goparsing/classification/operations_body/todo_operation_body.go @@ -22,7 +22,6 @@ type ListPetParams struct { // ServeAPI serves the API for this record store func ServeAPI(host, basePath string, schemes []string) error { - // swagger:route GET /pets pets users listPets // // Lists pets filtered by some parameters. diff --git a/fixtures/goparsing/go123/aliased/schema/extra.go b/fixtures/goparsing/go123/aliased/schema/extra.go index db18d3e..39f3652 100644 --- a/fixtures/goparsing/go123/aliased/schema/extra.go +++ b/fixtures/goparsing/go123/aliased/schema/extra.go @@ -86,6 +86,8 @@ type SliceOfStructsAlias = []struct{} // swagger:model type ShouldSee bool -type ShouldNotSee bool -type ShouldNotSeeSlice []int -type ShouldNotSeeMap map[string]int +type ( + ShouldNotSee bool + ShouldNotSeeSlice []int + ShouldNotSeeMap map[string]int +) diff --git a/fixtures/goparsing/invalid_model_param/main.go b/fixtures/goparsing/invalid_model_param/main.go index e7425ef..4db5366 100644 --- a/fixtures/goparsing/invalid_model_param/main.go +++ b/fixtures/goparsing/invalid_model_param/main.go @@ -23,5 +23,4 @@ type ModelAndParams struct { } func main() { - } diff --git a/fixtures/goparsing/invalid_model_response/main.go b/fixtures/goparsing/invalid_model_response/main.go index 98ccd81..200aef7 100644 --- a/fixtures/goparsing/invalid_model_response/main.go +++ b/fixtures/goparsing/invalid_model_response/main.go @@ -23,5 +23,4 @@ type ModelAndResponse struct { } func main() { - } diff --git a/fixtures/goparsing/invalid_param_model/main.go b/fixtures/goparsing/invalid_param_model/main.go index 97b2fb6..414d945 100644 --- a/fixtures/goparsing/invalid_param_model/main.go +++ 
b/fixtures/goparsing/invalid_param_model/main.go @@ -23,5 +23,4 @@ type ParamsAndModel struct { } func main() { - } diff --git a/fixtures/goparsing/invalid_response_model/main.go b/fixtures/goparsing/invalid_response_model/main.go index 0033ff2..56651a4 100644 --- a/fixtures/goparsing/invalid_response_model/main.go +++ b/fixtures/goparsing/invalid_response_model/main.go @@ -23,5 +23,4 @@ type ResponseAndModel struct { } func main() { - } diff --git a/fixtures/goparsing/petstore/models/order.go b/fixtures/goparsing/petstore/models/order.go index 9a49f5a..b5f56fa 100644 --- a/fixtures/goparsing/petstore/models/order.go +++ b/fixtures/goparsing/petstore/models/order.go @@ -37,7 +37,6 @@ type Order struct { // the items for this order // minimum items: 1 Items []struct { - // the id of the pet to order // // required: true diff --git a/fixtures/goparsing/petstore/petstore-fixture/main.go b/fixtures/goparsing/petstore/petstore-fixture/main.go index 8e3b512..c55c2b7 100644 --- a/fixtures/goparsing/petstore/petstore-fixture/main.go +++ b/fixtures/goparsing/petstore/petstore-fixture/main.go @@ -21,10 +21,8 @@ import ( "github.com/go-openapi/codescan/fixtures/goparsing/petstore/rest" ) -var ( - // Version is a compile time constant, injected at build time - Version string -) +// Version is a compile time constant, injected at build time +var Version string // This is an application that doesn't actually do anything, // it's used for testing the scanner diff --git a/fixtures/goparsing/petstore/rest/app.go b/fixtures/goparsing/petstore/rest/app.go index 06ed116..c49f30c 100644 --- a/fixtures/goparsing/petstore/rest/app.go +++ b/fixtures/goparsing/petstore/rest/app.go @@ -17,8 +17,8 @@ package rest import ( "net/http" - "github.com/go-openapi/runtime/middleware/denco" "github.com/go-openapi/codescan/fixtures/goparsing/petstore/rest/handlers" + "github.com/go-openapi/runtime/middleware/denco" ) // ServeAPI serves this api diff --git 
a/fixtures/goparsing/petstore/rest/handlers/orders.go b/fixtures/goparsing/petstore/rest/handlers/orders.go index 397b720..2f328a5 100644 --- a/fixtures/goparsing/petstore/rest/handlers/orders.go +++ b/fixtures/goparsing/petstore/rest/handlers/orders.go @@ -17,8 +17,8 @@ package handlers import ( "net/http" - "github.com/go-openapi/runtime/middleware/denco" "github.com/go-openapi/codescan/fixtures/goparsing/petstore/models" + "github.com/go-openapi/runtime/middleware/denco" ) // An OrderID parameter model. diff --git a/fixtures/goparsing/petstore/rest/handlers/pets.go b/fixtures/goparsing/petstore/rest/handlers/pets.go index 73850e3..126e368 100644 --- a/fixtures/goparsing/petstore/rest/handlers/pets.go +++ b/fixtures/goparsing/petstore/rest/handlers/pets.go @@ -18,9 +18,9 @@ import ( "net/http" "time" - "github.com/go-openapi/runtime/middleware/denco" "github.com/go-openapi/codescan/fixtures/goparsing/petstore/enums" "github.com/go-openapi/codescan/fixtures/goparsing/petstore/models" + "github.com/go-openapi/runtime/middleware/denco" ) // A GenericError is the default error message that is generated. 
diff --git a/fixtures/goparsing/spec/api.go b/fixtures/goparsing/spec/api.go index 674e67c..f6816d0 100644 --- a/fixtures/goparsing/spec/api.go +++ b/fixtures/goparsing/spec/api.go @@ -79,5 +79,4 @@ type BookingResponse struct { // Responses: // 200: BookingResponse func bookings(w http.ResponseWriter, r *http.Request) { - } diff --git a/meta.go b/meta.go index 2f4b3e0..0cf4421 100644 --- a/meta.go +++ b/meta.go @@ -61,7 +61,7 @@ func metaVendorExtensibleSetter(meta *spec.Swagger) func(json.RawMessage) error } for k := range jsonData { if !rxAllowedExtensions.MatchString(k) { - return fmt.Errorf("invalid schema extension name, should start from `x-`: %s", k) + return fmt.Errorf("invalid schema extension name, should start from `x-`: %s: %w", k, ErrCodeScan) } } meta.Extensions = jsonData @@ -78,7 +78,7 @@ func infoVendorExtensibleSetter(meta *spec.Swagger) func(json.RawMessage) error } for k := range jsonData { if !rxAllowedExtensions.MatchString(k) { - return fmt.Errorf("invalid schema extension name, should start from `x-`: %s", k) + return fmt.Errorf("invalid schema extension name, should start from `x-`: %s: %w", k, ErrCodeScan) } } meta.Info.Extensions = jsonData diff --git a/meta_test.go b/meta_test.go index 9f60b7c..8cbe944 100644 --- a/meta_test.go +++ b/meta_test.go @@ -85,7 +85,7 @@ func verifyMeta(t *testing.T, doc *spec.Swagger) { }, } expectedSecuritySchemaOAuth := spec.SecurityScheme{ - SecuritySchemeProps: spec.SecuritySchemeProps{ + SecuritySchemeProps: spec.SecuritySchemeProps{ //nolint:gosec // G101: false positive, test fixture not real credentials Type: "oauth2", In: "header", AuthorizationURL: "/oauth2/auth", diff --git a/operations.go b/operations.go index 44d57f6..2254f17 100644 --- a/operations.go +++ b/operations.go @@ -59,24 +59,31 @@ func parsePathAnnotation(annotation *regexp.Regexp, lines []*ast.Comment) (cnt p txt := cmt.Text for line := range strings.SplitSeq(txt, "\n") { matches := annotation.FindStringSubmatch(line) - if 
len(matches) > 3 { + if len(matches) > routeTagsIndex { cnt.Method, cnt.Path, cnt.ID = matches[1], matches[2], matches[len(matches)-1] cnt.Tags = rxSpace.Split(matches[3], -1) if len(matches[3]) == 0 { cnt.Tags = nil } justMatched = true - } else if cnt.Method != "" { - if cnt.Remaining == nil { - cnt.Remaining = new(ast.CommentGroup) - } - if !justMatched || strings.TrimSpace(rxStripComments.ReplaceAllString(line, "")) != "" { - cc := new(ast.Comment) - cc.Slash = cmt.Slash - cc.Text = line - cnt.Remaining.List = append(cnt.Remaining.List, cc) - justMatched = false - } + + continue + } + + if cnt.Method == "" { + continue + } + + if cnt.Remaining == nil { + cnt.Remaining = new(ast.CommentGroup) + } + + if !justMatched || strings.TrimSpace(rxStripComments.ReplaceAllString(line, "")) != "" { + cc := new(ast.Comment) + cc.Slash = cmt.Slash + cc.Text = line + cnt.Remaining.List = append(cnt.Remaining.List, cc) + justMatched = false } } } @@ -84,6 +91,16 @@ func parsePathAnnotation(annotation *regexp.Regexp, lines []*ast.Comment) (cnt p return cnt } +// assignOrReuse either reuses an existing operation (if the ID matches) +// or assigns op to the slot. 
+func assignOrReuse(slot **spec.Operation, op *spec.Operation, id string) *spec.Operation { + if *slot != nil && id == (*slot).ID { + return *slot + } + *slot = op + return op +} + func setPathOperation(method, id string, pthObj *spec.PathItem, op *spec.Operation) *spec.Operation { if op == nil { op = new(spec.Operation) @@ -92,81 +109,19 @@ func setPathOperation(method, id string, pthObj *spec.PathItem, op *spec.Operati switch strings.ToUpper(method) { case "GET": - if pthObj.Get != nil { - if id == pthObj.Get.ID { - op = pthObj.Get - } else { - pthObj.Get = op - } - } else { - pthObj.Get = op - } - + op = assignOrReuse(&pthObj.Get, op, id) case "POST": - if pthObj.Post != nil { - if id == pthObj.Post.ID { - op = pthObj.Post - } else { - pthObj.Post = op - } - } else { - pthObj.Post = op - } - + op = assignOrReuse(&pthObj.Post, op, id) case "PUT": - if pthObj.Put != nil { - if id == pthObj.Put.ID { - op = pthObj.Put - } else { - pthObj.Put = op - } - } else { - pthObj.Put = op - } - + op = assignOrReuse(&pthObj.Put, op, id) case "PATCH": - if pthObj.Patch != nil { - if id == pthObj.Patch.ID { - op = pthObj.Patch - } else { - pthObj.Patch = op - } - } else { - pthObj.Patch = op - } - + op = assignOrReuse(&pthObj.Patch, op, id) case "HEAD": - if pthObj.Head != nil { - if id == pthObj.Head.ID { - op = pthObj.Head - } else { - pthObj.Head = op - } - } else { - pthObj.Head = op - } - + op = assignOrReuse(&pthObj.Head, op, id) case "DELETE": - if pthObj.Delete != nil { - if id == pthObj.Delete.ID { - op = pthObj.Delete - } else { - pthObj.Delete = op - } - } else { - pthObj.Delete = op - } - + op = assignOrReuse(&pthObj.Delete, op, id) case "OPTIONS": - if pthObj.Options != nil { - if id == pthObj.Options.ID { - op = pthObj.Options - } else { - pthObj.Options = op - } - } else { - pthObj.Options = op - } + op = assignOrReuse(&pthObj.Options, op, id) } return op diff --git a/parameters.go b/parameters.go index 24e7e88..a38dcbb 100644 --- a/parameters.go +++ 
b/parameters.go @@ -5,12 +5,9 @@ package codescan import ( "fmt" - "go/ast" "go/types" "strings" - "golang.org/x/tools/go/ast/astutil" - "github.com/go-openapi/spec" ) @@ -30,7 +27,7 @@ func (pt paramTypable) SetRef(ref spec.Ref) { pt.param.Ref = ref } -func (pt paramTypable) Items() swaggerTypable { +func (pt paramTypable) Items() swaggerTypable { //nolint:ireturn // polymorphic by design bdt, schema := bodyTypable(pt.param.In, pt.param.Schema) if bdt != nil { pt.param.Schema = schema @@ -40,12 +37,12 @@ func (pt paramTypable) Items() swaggerTypable { if pt.param.Items == nil { pt.param.Items = new(spec.Items) } - pt.param.Type = "array" + pt.param.Type = typeArray return itemsTypable{pt.param.Items, 1, pt.param.In} } func (pt paramTypable) Schema() *spec.Schema { - if pt.param.In != "body" { + if pt.param.In != bodyTag { return nil } if pt.param.Schema == nil { @@ -55,7 +52,7 @@ func (pt paramTypable) Schema() *spec.Schema { } func (pt paramTypable) AddExtension(key string, value any) { - if pt.param.In == "body" { + if pt.param.In == bodyTag { pt.Schema().AddExtension(key, value) } else { pt.param.AddExtension(key, value) @@ -95,11 +92,11 @@ func (pt itemsTypable) Schema() *spec.Schema { return nil } -func (pt itemsTypable) Items() swaggerTypable { +func (pt itemsTypable) Items() swaggerTypable { //nolint:ireturn // polymorphic by design if pt.items.Items == nil { pt.items.Items = new(spec.Items) } - pt.items.Type = "array" + pt.items.Type = typeArray return itemsTypable{pt.items.Items, pt.level + 1, pt.in} } @@ -214,14 +211,14 @@ func (p *parameterBuilder) buildFromType(otpe types.Type, op *spec.Operation, se debugLogf("alias(parameters.buildFromType): got alias %v to %v", tpe, tpe.Rhs()) return p.buildAlias(tpe, op, seen) default: - return fmt.Errorf("unhandled type (%T): %s", otpe, tpe.String()) + return fmt.Errorf("unhandled type (%T): %s: %w", otpe, tpe.String(), ErrCodeScan) } } func (p *parameterBuilder) buildNamedType(tpe *types.Named, op 
*spec.Operation, seen map[string]spec.Parameter) error { o := tpe.Obj() if isAny(o) || isStdError(o) { - return fmt.Errorf("%s type not supported in the context of a parameters section definition", o.Name()) + return fmt.Errorf("%s type not supported in the context of a parameters section definition: %w", o.Name(), ErrCodeScan) } mustNotBeABuiltinType(o) @@ -234,14 +231,14 @@ func (p *parameterBuilder) buildNamedType(tpe *types.Named, op *spec.Operation, return p.buildFromStruct(p.decl, stpe, op, seen) default: - return fmt.Errorf("unhandled type (%T): %s", stpe, o.Type().Underlying().String()) + return fmt.Errorf("unhandled type (%T): %s: %w", stpe, o.Type().Underlying().String(), ErrCodeScan) } } func (p *parameterBuilder) buildAlias(tpe *types.Alias, op *spec.Operation, seen map[string]spec.Parameter) error { o := tpe.Obj() if isAny(o) || isStdError(o) { - return fmt.Errorf("%s type not supported in the context of a parameters section definition", o.Name()) + return fmt.Errorf("%s type not supported in the context of a parameters section definition: %w", o.Name(), ErrCodeScan) } mustNotBeABuiltinType(o) mustHaveRightHandSide(tpe) @@ -255,7 +252,7 @@ func (p *parameterBuilder) buildAlias(tpe *types.Alias, op *spec.Operation, seen decl, ok := p.ctx.FindModel(o.Pkg().Path(), o.Name()) if !ok { - return fmt.Errorf("can't find source file for aliased type: %v -> %v", tpe, rhs) + return fmt.Errorf("can't find source file for aliased type: %v -> %v: %w", tpe, rhs, ErrCodeScan) } p.postDecls = append(p.postDecls, decl) // mark the left-hand side as discovered @@ -268,7 +265,7 @@ func (p *parameterBuilder) buildAlias(tpe *types.Alias, op *spec.Operation, seen } decl, found := p.ctx.FindModel(o.Pkg().Path(), o.Name()) if !found { - return fmt.Errorf("can't find source file for target type of alias: %v -> %v", tpe, rtpe) + return fmt.Errorf("can't find source file for target type of alias: %v -> %v: %w", tpe, rtpe, ErrCodeScan) } p.postDecls = append(p.postDecls, decl) 
case *types.Alias: @@ -278,7 +275,7 @@ func (p *parameterBuilder) buildAlias(tpe *types.Alias, op *spec.Operation, seen } decl, found := p.ctx.FindModel(o.Pkg().Path(), o.Name()) if !found { - return fmt.Errorf("can't find source file for target type of alias: %v -> %v", tpe, rtpe) + return fmt.Errorf("can't find source file for target type of alias: %v -> %v: %w", tpe, rtpe, ErrCodeScan) } p.postDecls = append(p.postDecls, decl) } @@ -310,7 +307,7 @@ func (p *parameterBuilder) buildFromField(fld *types.Var, tpe types.Type, typabl debugLogf("alias(parameters.buildFromField): got alias %v to %v", ftpe, ftpe.Rhs()) // TODO return p.buildFieldAlias(ftpe, typable, fld, seen) default: - return fmt.Errorf("unknown type for %s: %T", fld.String(), fld.Type()) + return fmt.Errorf("unknown type for %s: %T: %w", fld.String(), fld.Type(), ErrCodeScan) } } @@ -368,13 +365,13 @@ func (p *parameterBuilder) buildNamedField(ftpe *types.Named, typable swaggerTyp return nil } if isStdError(o) { - return fmt.Errorf("%s type not supported in the context of a parameter definition", o.Name()) + return fmt.Errorf("%s type not supported in the context of a parameter definition: %w", o.Name(), ErrCodeScan) } mustNotBeABuiltinType(o) decl, found := p.ctx.DeclForType(o.Type()) if !found { - return fmt.Errorf("unable to find package and source file for: %s", ftpe.String()) + return fmt.Errorf("unable to find package and source file for: %s: %w", ftpe.String(), ErrCodeScan) } if isStdTime(o) { @@ -407,7 +404,7 @@ func (p *parameterBuilder) buildFieldAlias(tpe *types.Alias, typable swaggerTypa return nil // just leave an empty schema } if isStdError(o) { - return fmt.Errorf("%s type not supported in the context of a parameter definition", o.Name()) + return fmt.Errorf("%s type not supported in the context of a parameter definition: %w", o.Name(), ErrCodeScan) } mustNotBeABuiltinType(o) mustHaveRightHandSide(tpe) @@ -429,11 +426,11 @@ func (p *parameterBuilder) buildFieldAlias(tpe *types.Alias, 
typable swaggerTypa decl, ok := p.ctx.FindModel(o.Pkg().Path(), o.Name()) if !ok { - return fmt.Errorf("can't find source file for aliased type: %v -> %v", tpe, rhs) + return fmt.Errorf("can't find source file for aliased type: %v -> %v: %w", tpe, rhs, ErrCodeScan) } p.postDecls = append(p.postDecls, decl) // mark the left-hand side as discovered - if typable.In() != "body" || !p.ctx.app.refAliases { + if typable.In() != bodyTag || !p.ctx.app.refAliases { // if ref option is disabled, and always for non-body parameters: just expand the alias unaliased := types.Unalias(tpe) return p.buildFromField(fld, unaliased, typable, seen) @@ -450,7 +447,7 @@ func (p *parameterBuilder) buildFieldAlias(tpe *types.Alias, typable swaggerTypa decl, found := p.ctx.FindModel(o.Pkg().Path(), o.Name()) if !found { - return fmt.Errorf("can't find source file for target type of alias: %v -> %v", tpe, rtpe) + return fmt.Errorf("can't find source file for target type of alias: %v -> %v: %w", tpe, rtpe, ErrCodeScan) } return p.makeRef(decl, typable) @@ -462,7 +459,7 @@ func (p *parameterBuilder) buildFieldAlias(tpe *types.Alias, typable swaggerTypa decl, found := p.ctx.FindModel(o.Pkg().Path(), o.Name()) if !found { - return fmt.Errorf("can't find source file for target type of alias: %v -> %v", tpe, rtpe) + return fmt.Errorf("can't find source file for target type of alias: %v -> %v: %w", tpe, rtpe, ErrCodeScan) } return p.makeRef(decl, typable) @@ -498,210 +495,121 @@ func (p *parameterBuilder) buildFromStruct(decl *entityDecl, tpe *types.Struct, continue } - if !fld.Exported() { - debugLogf("skipping field %s because it's not exported", fld.Name()) - continue - } - - tg := tpe.Tag(i) - - var afld *ast.Field - ans, _ := astutil.PathEnclosingInterval(decl.File, fld.Pos(), fld.Pos()) - for _, an := range ans { - at, valid := an.(*ast.Field) - if !valid { - continue - } - - debugLogf("field %s: %s(%T) [%q] ==> %s", fld.Name(), fld.Type().String(), fld.Type(), tg, at.Doc.Text()) - afld = at - 
break - } - - if afld == nil { - debugLogf("can't find source associated with %s for %s", fld.String(), tpe.String()) - continue - } - - // if the field is annotated with swagger:ignore, ignore it - if ignored(afld.Doc) { - continue - } - - name, ignore, _, _, err := parseJSONTag(afld) + name, err := p.processParamField(fld, decl, seen) if err != nil { return err } - if ignore { - continue + if name != "" { + sequence = append(sequence, name) } + } - in := "query" - // scan for param location first, this changes some behavior down the line - if afld.Doc != nil { - for _, cmt := range afld.Doc.List { - for line := range strings.SplitSeq(cmt.Text, "\n") { - matches := rxIn.FindStringSubmatch(line) - if len(matches) > 0 && len(strings.TrimSpace(matches[1])) > 0 { - in = strings.TrimSpace(matches[1]) - } - } + for _, k := range sequence { + p := seen[k] + for i, v := range op.Parameters { + if v.Name == k { + op.Parameters = append(op.Parameters[:i], op.Parameters[i+1:]...) + break } } + op.Parameters = append(op.Parameters, p) + } - ps := seen[name] - ps.In = in - var pty swaggerTypable = paramTypable{&ps} - if in == "body" { - pty = schemaTypable{pty.Schema(), 0} - } - if in == "formData" && afld.Doc != nil && fileParam(afld.Doc) { - pty.Typed("file", "") - } else if err := p.buildFromField(fld, fld.Type(), pty, seen); err != nil { - return err - } + return nil +} - if strfmtName, ok := strfmtName(afld.Doc); ok { - ps.Typed("string", strfmtName) - ps.Ref = spec.Ref{} - ps.Items = nil - } +// processParamField processes a single non-embedded struct field for parameter building. +// Returns the parameter name if the field was processed, or "" if it was skipped. 
+func (p *parameterBuilder) processParamField(fld *types.Var, decl *entityDecl, seen map[string]spec.Parameter) (string, error) { + if !fld.Exported() { + debugLogf("skipping field %s because it's not exported", fld.Name()) + return "", nil + } - sp := new(sectionedParser) - sp.setDescription = func(lines []string) { - ps.Description = joinDropLast(lines) - enumDesc := getEnumDesc(ps.Extensions) - if enumDesc != "" { - ps.Description += "\n" + enumDesc - } - } - if ps.Ref.String() == "" { - sp.taggers = []tagParser{ - newSingleLineTagParser("in", &matchOnlyParam{&ps, rxIn}), - newSingleLineTagParser("maximum", &setMaximum{paramValidations{&ps}, rxf(rxMaximumFmt, "")}), - newSingleLineTagParser("minimum", &setMinimum{paramValidations{&ps}, rxf(rxMinimumFmt, "")}), - newSingleLineTagParser("multipleOf", &setMultipleOf{paramValidations{&ps}, rxf(rxMultipleOfFmt, "")}), - newSingleLineTagParser("minLength", &setMinLength{paramValidations{&ps}, rxf(rxMinLengthFmt, "")}), - newSingleLineTagParser("maxLength", &setMaxLength{paramValidations{&ps}, rxf(rxMaxLengthFmt, "")}), - newSingleLineTagParser("pattern", &setPattern{paramValidations{&ps}, rxf(rxPatternFmt, "")}), - newSingleLineTagParser("collectionFormat", &setCollectionFormat{paramValidations{&ps}, rxf(rxCollectionFormatFmt, "")}), - newSingleLineTagParser("minItems", &setMinItems{paramValidations{&ps}, rxf(rxMinItemsFmt, "")}), - newSingleLineTagParser("maxItems", &setMaxItems{paramValidations{&ps}, rxf(rxMaxItemsFmt, "")}), - newSingleLineTagParser("unique", &setUnique{paramValidations{&ps}, rxf(rxUniqueFmt, "")}), - newSingleLineTagParser("enum", &setEnum{paramValidations{&ps}, rxf(rxEnumFmt, "")}), - newSingleLineTagParser("default", &setDefault{&ps.SimpleSchema, paramValidations{&ps}, rxf(rxDefaultFmt, "")}), - newSingleLineTagParser("example", &setExample{&ps.SimpleSchema, paramValidations{&ps}, rxf(rxExampleFmt, "")}), - newSingleLineTagParser("required", &setRequiredParam{&ps}), - 
newMultiLineTagParser("Extensions", newSetExtensions(spExtensionsSetter(&ps)), true), - } + afld := findASTField(decl.File, fld.Pos()) + if afld == nil { + debugLogf("can't find source associated with %s", fld.String()) + return "", nil + } - itemsTaggers := func(items *spec.Items, level int) []tagParser { - // the expression is 1-index based not 0-index - itemsPrefix := fmt.Sprintf(rxItemsPrefixFmt, level+1) - - return []tagParser{ - newSingleLineTagParser(fmt.Sprintf("items%dMaximum", level), &setMaximum{itemsValidations{items}, rxf(rxMaximumFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dMinimum", level), &setMinimum{itemsValidations{items}, rxf(rxMinimumFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dMultipleOf", level), &setMultipleOf{itemsValidations{items}, rxf(rxMultipleOfFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dMinLength", level), &setMinLength{itemsValidations{items}, rxf(rxMinLengthFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dMaxLength", level), &setMaxLength{itemsValidations{items}, rxf(rxMaxLengthFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dPattern", level), &setPattern{itemsValidations{items}, rxf(rxPatternFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dCollectionFormat", level), &setCollectionFormat{itemsValidations{items}, rxf(rxCollectionFormatFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dMinItems", level), &setMinItems{itemsValidations{items}, rxf(rxMinItemsFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dMaxItems", level), &setMaxItems{itemsValidations{items}, rxf(rxMaxItemsFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dUnique", level), &setUnique{itemsValidations{items}, rxf(rxUniqueFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dEnum", level), &setEnum{itemsValidations{items}, rxf(rxEnumFmt, itemsPrefix)}), - 
newSingleLineTagParser(fmt.Sprintf("items%dDefault", level), &setDefault{&items.SimpleSchema, itemsValidations{items}, rxf(rxDefaultFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dExample", level), &setExample{&items.SimpleSchema, itemsValidations{items}, rxf(rxExampleFmt, itemsPrefix)}), - } - } + if ignored(afld.Doc) { + return "", nil + } - var parseArrayTypes func(expr ast.Expr, items *spec.Items, level int) ([]tagParser, error) - parseArrayTypes = func(expr ast.Expr, items *spec.Items, level int) ([]tagParser, error) { - if items == nil { - return []tagParser{}, nil - } - switch iftpe := expr.(type) { - case *ast.ArrayType: - eleTaggers := itemsTaggers(items, level) - sp.taggers = append(eleTaggers, sp.taggers...) - otherTaggers, err := parseArrayTypes(iftpe.Elt, items.Items, level+1) - if err != nil { - return nil, err - } - return otherTaggers, nil - case *ast.SelectorExpr: - otherTaggers, err := parseArrayTypes(iftpe.Sel, items.Items, level+1) - if err != nil { - return nil, err - } - return otherTaggers, nil - case *ast.Ident: - taggers := []tagParser{} - if iftpe.Obj == nil { - taggers = itemsTaggers(items, level) - } - otherTaggers, err := parseArrayTypes(expr, items.Items, level+1) - if err != nil { - return nil, err - } - return append(taggers, otherTaggers...), nil - case *ast.StarExpr: - otherTaggers, err := parseArrayTypes(iftpe.X, items, level) - if err != nil { - return nil, err - } - return otherTaggers, nil - default: - return nil, fmt.Errorf("unknown field type ele for %q", name) - } - } + name, ignore, _, _, err := parseJSONTag(afld) + if err != nil { + return "", err + } + if ignore { + return "", nil + } - // check if this is a primitive, if so parse the validations from the - // doc comments of the slice declaration. 
- if ftped, ok := afld.Type.(*ast.ArrayType); ok { - taggers, err := parseArrayTypes(ftped.Elt, ps.Items, 0) - if err != nil { - return err + in := paramInQuery + // scan for param location first, this changes some behavior down the line + if afld.Doc != nil { + for _, cmt := range afld.Doc.List { + for line := range strings.SplitSeq(cmt.Text, "\n") { + matches := rxIn.FindStringSubmatch(line) + if len(matches) > 0 && len(strings.TrimSpace(matches[1])) > 0 { + in = strings.TrimSpace(matches[1]) } - sp.taggers = append(taggers, sp.taggers...) - } - } else { - sp.taggers = []tagParser{ - newSingleLineTagParser("in", &matchOnlyParam{&ps, rxIn}), - newSingleLineTagParser("required", &matchOnlyParam{&ps, rxRequired}), - newMultiLineTagParser("Extensions", newSetExtensions(spExtensionsSetter(&ps)), true), } } - if err := sp.Parse(afld.Doc); err != nil { - return err - } - if ps.In == "path" { - ps.Required = true - } + } - if ps.Name == "" { - ps.Name = name - } + ps := seen[name] + ps.In = in + var pty swaggerTypable = paramTypable{&ps} + if in == bodyTag { + pty = schemaTypable{pty.Schema(), 0} + } + + if in == "formData" && afld.Doc != nil && fileParam(afld.Doc) { + pty.Typed("file", "") + } else if err := p.buildFromField(fld, fld.Type(), pty, seen); err != nil { + return "", err + } + + if strfmtName, ok := strfmtName(afld.Doc); ok { + ps.Typed("string", strfmtName) + ps.Ref = spec.Ref{} + ps.Items = nil + } - if name != fld.Name() { - addExtension(&ps.VendorExtensible, "x-go-name", fld.Name()) + sp := new(sectionedParser) + sp.setDescription = func(lines []string) { + ps.Description = joinDropLast(lines) + enumDesc := getEnumDesc(ps.Extensions) + if enumDesc != "" { + ps.Description += "\n" + enumDesc } - seen[name] = ps - sequence = append(sequence, name) } - for _, k := range sequence { - p := seen[k] - for i, v := range op.Parameters { - if v.Name == k { - op.Parameters = append(op.Parameters[:i], op.Parameters[i+1:]...) 
- break - } + if ps.Ref.String() != "" { + setupRefParamTaggers(sp, &ps) + } else { + if err := setupInlineParamTaggers(sp, &ps, name, afld); err != nil { + return "", err } - op.Parameters = append(op.Parameters, p) } - return nil + + if err := sp.Parse(afld.Doc); err != nil { + return "", err + } + if ps.In == "path" { + ps.Required = true + } + + if ps.Name == "" { + ps.Name = name + } + + if name != fld.Name() { + addExtension(&ps.VendorExtensible, "x-go-name", fld.Name()) + } + + seen[name] = ps + return name, nil } func (p *parameterBuilder) makeRef(decl *entityDecl, prop swaggerTypable) error { diff --git a/parameters_test.go b/parameters_test.go index 7c88fa6..1297f57 100644 --- a/parameters_test.go +++ b/parameters_test.go @@ -13,7 +13,15 @@ import ( ) const ( - gcBadEnum = "bad_enum" + gcBadEnum = "bad_enum" + paramID = "id" + paramAge = "age" + paramExtra = "extra" + paramScore = "score" + paramHdrName = "x-hdr-name" + paramCreated = "created" + paramFooSlice = "foo_slice" + paramBarSlice = "bar_slice" ) func getParameter(sctx *scanCtx, nm string) *entityDecl { @@ -29,7 +37,11 @@ func getParameter(sctx *scanCtx, nm string) *entityDecl { func TestScanFileParam(t *testing.T) { sctx := loadClassificationPkgsCtx(t) operations := make(map[string]*spec.Operation) - for _, rn := range []string{"OrderBodyParams", "MultipleOrderParams", "ComplexerOneParams", "NoParams", "NoParamsAlias", "MyFileParams", "MyFuncFileParams", "EmbeddedFileParams"} { + paramNames := []string{ + "OrderBodyParams", "MultipleOrderParams", "ComplexerOneParams", "NoParams", + "NoParamsAlias", "MyFileParams", "MyFuncFileParams", "EmbeddedFileParams", + } + for _, rn := range paramNames { td := getParameter(sctx, rn) prs := ¶meterBuilder{ @@ -76,7 +88,11 @@ func TestScanFileParam(t *testing.T) { func TestParamsParser(t *testing.T) { sctx := loadClassificationPkgsCtx(t) operations := make(map[string]*spec.Operation) - for _, rn := range []string{"OrderBodyParams", "MultipleOrderParams", 
"ComplexerOneParams", "NoParams", "NoParamsAlias", "MyFileParams", "MyFuncFileParams", "EmbeddedFileParams"} { + paramNames := []string{ + "OrderBodyParams", "MultipleOrderParams", "ComplexerOneParams", "NoParams", + "NoParamsAlias", "MyFileParams", "MyFuncFileParams", "EmbeddedFileParams", + } + for _, rn := range paramNames { td := getParameter(sctx, rn) prs := ¶meterBuilder{ @@ -87,24 +103,63 @@ func TestParamsParser(t *testing.T) { } assert.Len(t, operations, 10) + + t.Run("yetAnotherOperation", func(t *testing.T) { + assertYetAnotherOperationParams(t, operations) + }) + + ob, okParam := operations["updateOrder"] + assert.TrueT(t, okParam) + assert.Len(t, ob.Parameters, 1) + bodyParam := ob.Parameters[0] + assert.EqualT(t, "The order to submit.", bodyParam.Description) + assert.EqualT(t, "body", bodyParam.In) + assert.EqualT(t, "#/definitions/order", bodyParam.Schema.Ref.String()) + assert.TrueT(t, bodyParam.Required) + + mop, okParam := operations["getOrders"] + assert.TrueT(t, okParam) + require.Len(t, mop.Parameters, 2) + ordersParam := mop.Parameters[0] + assert.EqualT(t, "The orders", ordersParam.Description) + assert.TrueT(t, ordersParam.Required) + assert.EqualT(t, "array", ordersParam.Type) + otherParam := mop.Parameters[1] + assert.EqualT(t, "And another thing", otherParam.Description) + + t.Run("someOperation", func(t *testing.T) { + assertSomeOperationParams(t, operations) + }) + + t.Run("anotherOperation parameter order", func(t *testing.T) { + assertAnotherOperationParamOrder(t, operations) + }) + + t.Run("someAliasOperation", func(t *testing.T) { + assertSomeAliasOperationParams(t, operations) + }) +} + +func assertYetAnotherOperationParams(t *testing.T, operations map[string]*spec.Operation) { + t.Helper() cr, okParam := operations["yetAnotherOperation"] require.TrueT(t, okParam) assert.Len(t, cr.Parameters, 8) for _, param := range cr.Parameters { switch param.Name { - case "id": + case paramID: assert.EqualT(t, "integer", param.Type) 
assert.EqualT(t, "int64", param.Format) - case "name": + case paramNameKey: assert.EqualT(t, "string", param.Type) assert.Empty(t, param.Format) - case "age": + case paramAge: assert.EqualT(t, "integer", param.Type) assert.EqualT(t, "int32", param.Format) case "notes": assert.EqualT(t, "string", param.Type) assert.Empty(t, param.Format) - case "extra": + case paramExtra: assert.EqualT(t, "string", param.Type) assert.Empty(t, param.Format) case "createdAt": @@ -120,33 +175,17 @@ func TestParamsParser(t *testing.T) { assert.Fail(t, "unknown property: "+param.Name) } } +} - ob, okParam := operations["updateOrder"] - assert.TrueT(t, okParam) - assert.Len(t, ob.Parameters, 1) - bodyParam := ob.Parameters[0] - assert.EqualT(t, "The order to submit.", bodyParam.Description) - assert.EqualT(t, "body", bodyParam.In) - assert.EqualT(t, "#/definitions/order", bodyParam.Schema.Ref.String()) - assert.TrueT(t, bodyParam.Required) - - mop, okParam := operations["getOrders"] - assert.TrueT(t, okParam) - require.Len(t, mop.Parameters, 2) - ordersParam := mop.Parameters[0] - assert.EqualT(t, "The orders", ordersParam.Description) - assert.TrueT(t, ordersParam.Required) - assert.EqualT(t, "array", ordersParam.Type) - otherParam := mop.Parameters[1] - assert.EqualT(t, "And another thing", otherParam.Description) - +func assertSomeOperationParams(t *testing.T, operations map[string]*spec.Operation) { + t.Helper() op, okParam := operations["someOperation"] assert.TrueT(t, okParam) assert.Len(t, op.Parameters, 12) for _, param := range op.Parameters { switch param.Name { - case "id": + case paramID: assert.EqualT(t, "ID of this no model instance.\nids in this application start at 11 and are smaller than 1000", param.Description) assert.EqualT(t, "path", param.In) assert.EqualT(t, "integer", param.Type) @@ -161,7 +200,7 @@ func TestParamsParser(t *testing.T) { assert.TrueT(t, param.ExclusiveMinimum) assert.Equal(t, 1, param.Default, "%s default value is incorrect", param.Name) - case 
"score": + case paramScore: assert.EqualT(t, "The Score of this model", param.Description) assert.EqualT(t, "query", param.In) assert.EqualT(t, "integer", param.Type) @@ -177,7 +216,7 @@ func TestParamsParser(t *testing.T) { assert.EqualValues(t, 2, param.Default, "%s default value is incorrect", param.Name) assert.EqualValues(t, 27, param.Example) - case "x-hdr-name": + case paramHdrName: assert.EqualT(t, "Name of this no model instance", param.Description) assert.EqualT(t, "header", param.In) assert.EqualT(t, "string", param.Type) @@ -189,7 +228,7 @@ func TestParamsParser(t *testing.T) { assert.EqualT(t, int64(50), *param.MaxLength) assert.EqualT(t, "[A-Za-z0-9-.]*", param.Pattern) - case "created": + case paramCreated: assert.EqualT(t, "Created holds the time when this entry was created", param.Description) assert.EqualT(t, "query", param.In) assert.EqualT(t, "string", param.Type) @@ -218,7 +257,7 @@ func TestParamsParser(t *testing.T) { assert.EqualT(t, "query", param.In) assert.EqualT(t, "integer", param.Type) assert.Equal(t, []any{1, 3, 5}, param.Enum, "%s enum values are incorrect", param.Name) - case "type": + case paramTypeKey: assert.EqualT(t, "Type of this model", param.Description) assert.EqualT(t, "query", param.In) assert.EqualT(t, "integer", param.Type) @@ -227,7 +266,7 @@ func TestParamsParser(t *testing.T) { assert.EqualT(t, "query", param.In) assert.EqualT(t, "integer", param.Type) assert.Equal(t, []any{1, "rsq", "qaz"}, param.Enum, "%s enum values are incorrect", param.Name) - case "foo_slice": + case paramFooSlice: assert.EqualT(t, "a FooSlice has foos which are strings", param.Description) assert.Equal(t, "FooSlice", param.Extensions["x-go-name"]) assert.EqualT(t, "query", param.In) @@ -293,7 +332,7 @@ func TestParamsParser(t *testing.T) { assert.TrueT(t, ok) assert.EqualT(t, "Notes to add to this item.\nThis can be used to add special instructions.", iprop.Description) - case "bar_slice": + case paramBarSlice: assert.EqualT(t, "a BarSlice has 
bars which are strings", param.Description) assert.Equal(t, "BarSlice", param.Extensions["x-go-name"]) assert.EqualT(t, "query", param.In) @@ -333,21 +372,24 @@ func TestParamsParser(t *testing.T) { assert.Fail(t, "unknown property: "+param.Name) } } +} + +func assertAnotherOperationParamOrder(t *testing.T, operations map[string]*spec.Operation) { + t.Helper() - // assert that the order of the parameters is maintained order, ok := operations["anotherOperation"] assert.TrueT(t, ok) assert.Len(t, order.Parameters, 12) for index, param := range order.Parameters { switch param.Name { - case "id": + case paramID: assert.EqualT(t, 0, index, "%s index incorrect", param.Name) - case "score": + case paramScore: assert.EqualT(t, 1, index, "%s index incorrect", param.Name) - case "x-hdr-name": + case paramHdrName: assert.EqualT(t, 2, index, "%s index incorrect", param.Name) - case "created": + case paramCreated: assert.EqualT(t, 3, index, "%s index incorrect", param.Name) case "category_old": assert.EqualT(t, 4, index, "%s index incorrect", param.Name) @@ -355,13 +397,13 @@ func TestParamsParser(t *testing.T) { assert.EqualT(t, 5, index, "%s index incorrect", param.Name) case "type_old": assert.EqualT(t, 6, index, "%s index incorrect", param.Name) - case "type": + case paramTypeKey: assert.EqualT(t, 7, index, "%s index incorrect", param.Name) case gcBadEnum: assert.EqualT(t, 8, index, "%s index incorrect", param.Name) - case "foo_slice": + case paramFooSlice: assert.EqualT(t, 9, index, "%s index incorrect", param.Name) - case "bar_slice": + case paramBarSlice: assert.EqualT(t, 10, index, "%s index incorrect", param.Name) case "items": assert.EqualT(t, 11, index, "%s index incorrect", param.Name) @@ -369,8 +411,11 @@ func TestParamsParser(t *testing.T) { assert.Fail(t, "unknown property: "+param.Name) } } +} + +func assertSomeAliasOperationParams(t *testing.T, operations map[string]*spec.Operation) { + t.Helper() - // check that aliases work correctly aliasOp, ok := 
operations["someAliasOperation"] assert.TrueT(t, ok) assert.Len(t, aliasOp.Parameters, 4) diff --git a/parser.go b/parser.go index b85eeea..d3c1827 100644 --- a/parser.go +++ b/parser.go @@ -5,7 +5,6 @@ package codescan import ( "encoding/json" - "errors" "fmt" "go/ast" "go/types" @@ -22,6 +21,25 @@ import ( "github.com/go-openapi/spec" ) +const ( + // Go builtin type names used for type-to-schema mapping. + goTypeByte = "byte" + goTypeFloat64 = "float64" + goTypeInt = "int" + goTypeInt16 = "int16" + goTypeInt32 = "int32" + goTypeInt64 = "int64" + + // kvParts is the number of parts when splitting key:value pairs. + kvParts = 2 + // logCallerDepth is the caller depth for log.Output. + logCallerDepth = 2 + // minAnnotationMatch is the minimum submatch count for an annotation regex. + minAnnotationMatch = 2 + // routeTagsIndex is the regex submatch index where route tags begin. + routeTagsIndex = 3 +) + func shouldAcceptTag(tags []string, includeTags map[string]bool, excludeTags map[string]bool) bool { for _, tag := range tags { if len(includeTags) > 0 { @@ -184,12 +202,12 @@ type swaggerTypable interface { // See https://golang.org/pkg/builtin/ and http://swagger.io/specification/ func swaggerSchemaForType(typeName string, prop swaggerTypable) error { switch typeName { - case "bool": + case typeBool: prop.Typed("boolean", "") - case "byte": + case goTypeByte: prop.Typed("integer", "uint8") case "complex128", "complex64": - return fmt.Errorf("unsupported builtin %q (no JSON marshaller)", typeName) + return fmt.Errorf("unsupported builtin %q (no JSON marshaller): %w", typeName, ErrCodeScan) case "error": // TODO: error is often marshalled into a string but not always (e.g. 
errors package creates // errors that are marshalled into an empty object), this could be handled the same way @@ -197,21 +215,21 @@ func swaggerSchemaForType(typeName string, prop swaggerTypable) error { prop.Typed("string", "") case "float32": prop.Typed("number", "float") - case "float64": + case goTypeFloat64: prop.Typed("number", "double") - case "int": - prop.Typed("integer", "int64") - case "int16": - prop.Typed("integer", "int16") - case "int32": - prop.Typed("integer", "int32") - case "int64": - prop.Typed("integer", "int64") + case goTypeInt: + prop.Typed("integer", goTypeInt64) + case goTypeInt16: + prop.Typed("integer", goTypeInt16) + case goTypeInt32: + prop.Typed("integer", goTypeInt32) + case goTypeInt64: + prop.Typed("integer", goTypeInt64) case "int8": prop.Typed("integer", "int8") case "rune": - prop.Typed("integer", "int32") - case "string": + prop.Typed("integer", goTypeInt32) + case typeString: prop.Typed("string", "") case "uint": prop.Typed("integer", "uint64") @@ -228,7 +246,7 @@ func swaggerSchemaForType(typeName string, prop swaggerTypable) error { case "object": prop.Typed("object", "") default: - return fmt.Errorf("unsupported type %q", typeName) + return fmt.Errorf("unsupported type %q: %w", typeName, ErrCodeScan) } return nil } @@ -284,7 +302,7 @@ func (y *yamlParser) Parse(lines []string) error { return nil } - var uncommented []string + uncommented := make([]string, 0, len(lines)) uncommented = append(uncommented, removeYamlIndent(lines)...) 
yamlContent := strings.Join(uncommented, "\n") @@ -371,11 +389,11 @@ COMMENTS: func (sp *yamlSpecScanner) UnmarshalSpec(u func([]byte) error) (err error) { specYaml := cleanupScannerLines(sp.yamlSpec, rxUncommentYAML) if len(specYaml) == 0 { - return errors.New("no spec available to unmarshal") + return fmt.Errorf("no spec available to unmarshal: %w", ErrCodeScan) } if !strings.Contains(specYaml[0], "---") { - return errors.New("yaml spec has to start with `---`") + return fmt.Errorf("yaml spec has to start with `---`: %w", ErrCodeScan) } // remove indentation @@ -446,7 +464,7 @@ func removeIndent(spec []string) []string { continue } - s[i] = spec[i][loc[1]-1:] + s[i] = spec[i][loc[1]-1:] //nolint:gosec // G602: bounds already checked on line 445 start := rxNotIndent.FindStringIndex(s[i]) if len(start) < 2 || start[1] == 0 { continue @@ -512,70 +530,24 @@ func (st *sectionedParser) Parse(doc *ast.CommentGroup) error { if doc == nil { return nil } + COMMENTS: for _, c := range doc.List { for line := range strings.SplitSeq(c.Text, "\n") { - if rxSwaggerAnnotation.MatchString(line) { - if rxIgnoreOverride.MatchString(line) { - st.ignored = true - break COMMENTS // an explicit ignore terminates this parser - } - if st.annotation == nil || !st.annotation.Matches(line) { - break COMMENTS // a new swagger: annotation terminates this parser - } - - _ = st.annotation.Parse([]string{line}) - if len(st.header) > 0 { - st.seenTag = true - } - continue - } - - var matched bool - for _, tg := range st.taggers { - tagger := tg - if tagger.Matches(line) { - st.seenTag = true - st.currentTagger = &tagger - matched = true - break - } - } - - if st.currentTagger == nil { - if !st.skipHeader && !st.seenTag { - st.header = append(st.header, line) - } - // didn't match a tag, moving on - continue - } - - if st.currentTagger.MultiLine && matched { - // the first line of a multiline tagger doesn't count - continue - } - - ts, ok := st.matched[st.currentTagger.Name] - if !ok { - ts = 
*st.currentTagger - } - ts.Lines = append(ts.Lines, line) - if st.matched == nil { - st.matched = make(map[string]tagParser) - } - st.matched[st.currentTagger.Name] = ts - - if !st.currentTagger.MultiLine { - st.currentTagger = nil + if st.parseLine(line) { + break COMMENTS } } } + if st.setTitle != nil { st.setTitle(st.Title()) } + if st.setDescription != nil { st.setDescription(st.Description()) } + for _, mt := range st.matched { if !mt.SkipCleanUp { mt.Lines = cleanupScannerLines(mt.Lines, rxUncommentHeaders) @@ -584,9 +556,68 @@ COMMENTS: return err } } + return nil } +// parseLine processes a single comment line. It returns true when the +// caller should stop processing further comments (break COMMENTS). +func (st *sectionedParser) parseLine(line string) (stop bool) { + if rxSwaggerAnnotation.MatchString(line) { + if rxIgnoreOverride.MatchString(line) { + st.ignored = true + return true // an explicit ignore terminates this parser + } + if st.annotation == nil || !st.annotation.Matches(line) { + return true // a new swagger: annotation terminates this parser + } + + _ = st.annotation.Parse([]string{line}) + if len(st.header) > 0 { + st.seenTag = true + } + return false + } + + var matched bool + for _, tg := range st.taggers { + tagger := tg + if tagger.Matches(line) { + st.seenTag = true + st.currentTagger = &tagger + matched = true + break + } + } + + if st.currentTagger == nil { + if !st.skipHeader && !st.seenTag { + st.header = append(st.header, line) + } + return false + } + + if st.currentTagger.MultiLine && matched { + // the first line of a multiline tagger doesn't count + return false + } + + ts, ok := st.matched[st.currentTagger.Name] + if !ok { + ts = *st.currentTagger + } + ts.Lines = append(ts.Lines, line) + if st.matched == nil { + st.matched = make(map[string]tagParser) + } + st.matched[st.currentTagger.Name] = ts + + if !st.currentTagger.MultiLine { + st.currentTagger = nil + } + return false +} + func (st *sectionedParser) 
collectTitleDescription() { if st.workedOutTitle { return @@ -600,7 +631,7 @@ func (st *sectionedParser) collectTitleDescription() { st.title, st.header = collectScannerTitleDescription(st.header) } -type validationBuilder interface { +type validationBuilder interface { //nolint:interfacebloat // mirrors the full set of Swagger validation properties SetMaximum(maxium float64, isExclusive bool) SetMinimum(minimum float64, isExclusive bool) SetMultipleOf(multiple float64) @@ -892,15 +923,13 @@ func parseValueFromSchema(s string, schema *spec.SimpleSchema) (any, error) { case "object": var obj map[string]any if err := json.Unmarshal([]byte(s), &obj); err != nil { - // If we can't parse it, just return the string. - return s, nil + return s, nil //nolint:nilerr // fallback: return raw string when JSON is invalid } return obj, nil case "array": var slice []any if err := json.Unmarshal([]byte(s), &slice); err != nil { - // If we can't parse it, just return the string. - return s, nil + return s, nil //nolint:nilerr // fallback: return raw string when JSON is invalid } return slice, nil default: @@ -1085,25 +1114,27 @@ func (su *setRequiredSchema) Parse(lines []string) error { return nil } matches := rxRequired.FindStringSubmatch(lines[0]) - if len(matches) > 1 && len(matches[1]) > 0 { - req, err := strconv.ParseBool(matches[1]) - if err != nil { - return err - } - midx := -1 - for i, nm := range su.schema.Required { - if nm == su.field { - midx = i - break - } + if len(matches) <= 1 || len(matches[1]) == 0 { + return nil + } + + req, err := strconv.ParseBool(matches[1]) + if err != nil { + return err + } + midx := -1 + for i, nm := range su.schema.Required { + if nm == su.field { + midx = i + break } - if req { - if midx < 0 { - su.schema.Required = append(su.schema.Required, su.field) - } - } else if midx >= 0 { - su.schema.Required = append(su.schema.Required[:midx], su.schema.Required[midx+1:]...) 
+ } + if req { + if midx < 0 { + su.schema.Required = append(su.schema.Required, su.field) } + } else if midx >= 0 { + su.schema.Required = append(su.schema.Required[:midx], su.schema.Required[midx+1:]...) } return nil } @@ -1188,7 +1219,7 @@ func (ss *setSecurity) Parse(lines []string) error { var result []map[string][]string for _, line := range lines { - kv := strings.SplitN(line, ":", 2) + kv := strings.SplitN(line, ":", kvParts) scopes := []string{} var key string @@ -1231,21 +1262,21 @@ func (ss *setOpResponses) Matches(line string) bool { return ss.rx.MatchString(line) } -// ResponseTag used when specifying a response to point to a defined swagger:response. -const ResponseTag = "response" +// responseTag used when specifying a response to point to a defined swagger:response. +const responseTag = "response" -// BodyTag used when specifying a response to point to a model/schema. -const BodyTag = "body" +// bodyTag used when specifying a response to point to a model/schema. +const bodyTag = "body" -// DescriptionTag used when specifying a response that gives a description of the response. -const DescriptionTag = "description" +// descriptionTag used when specifying a response that gives a description of the response. 
+const descriptionTag = "description" func parseTags(line string) (modelOrResponse string, arrays int, isDefinitionRef bool, description string, err error) { tags := strings.Split(line, " ") parsedModelOrResponse := false for i, tagAndValue := range tags { - tagValList := strings.SplitN(tagAndValue, ":", 2) + tagValList := strings.SplitN(tagAndValue, ":", kvParts) var tag, value string if len(tagValList) > 1 { tag = tagValList[0] @@ -1254,20 +1285,20 @@ func parseTags(line string) (modelOrResponse string, arrays int, isDefinitionRef // TODO: Print a warning, and in the long term, do not support not tagged values // Add a default tag if none is supplied if i == 0 { - tag = ResponseTag + tag = responseTag } else { - tag = DescriptionTag + tag = descriptionTag } value = tagValList[0] } foundModelOrResponse := false if !parsedModelOrResponse { - if tag == BodyTag { + if tag == bodyTag { foundModelOrResponse = true isDefinitionRef = true } - if tag == ResponseTag { + if tag == responseTag { foundModelOrResponse = true isDefinitionRef = false } @@ -1283,34 +1314,50 @@ func parseTags(line string) (modelOrResponse string, arrays int, isDefinitionRef } // What's left over is the model name modelOrResponse = value - } else { - foundDescription := false - if tag == DescriptionTag { - foundDescription = true - } - if foundDescription { - // Descriptions are special, they make they read the rest of the line - descriptionWords := []string{value} - if i < len(tags)-1 { - descriptionWords = append(descriptionWords, tags[i+1:]...) 
- } - description = strings.Join(descriptionWords, " ") - break - } - if tag == ResponseTag || tag == BodyTag || tag == DescriptionTag { - err = fmt.Errorf("valid tag %s, but not in a valid position", tag) - } else { - err = fmt.Errorf("invalid tag: %s", tag) + continue + } + + if tag == descriptionTag { + // Descriptions are special, they read the rest of the line + descriptionWords := []string{value} + if i < len(tags)-1 { + descriptionWords = append(descriptionWords, tags[i+1:]...) } - // return error - return modelOrResponse, arrays, isDefinitionRef, description, err + description = strings.Join(descriptionWords, " ") + break + } + + if tag == responseTag || tag == bodyTag { + err = fmt.Errorf("valid tag %s, but not in a valid position: %w", tag, ErrCodeScan) + } else { + err = fmt.Errorf("invalid tag: %s: %w", tag, ErrCodeScan) } + // return error + return modelOrResponse, arrays, isDefinitionRef, description, err } // TODO: Maybe do, if !parsedModelOrResponse {return some error} return modelOrResponse, arrays, isDefinitionRef, description, err } +func assignResponse(key string, resp spec.Response, def *spec.Response, scr map[int]spec.Response) (*spec.Response, map[int]spec.Response) { + if strings.EqualFold("default", key) { + if def == nil { + def = &resp + } + return def, scr + } + + if sc, err := strconv.Atoi(key); err == nil { + if scr == nil { + scr = make(map[int]spec.Response) + } + scr[sc] = resp + } + + return def, scr +} + func (ss *setOpResponses) Parse(lines []string) error { if len(lines) == 0 || (len(lines) == 1 && len(lines[0]) == 0) { return nil @@ -1320,95 +1367,84 @@ func (ss *setOpResponses) Parse(lines []string) error { var scr map[int]spec.Response for _, line := range lines { - kv := strings.SplitN(line, ":", 2) - var key, value string + var err error + def, scr, err = ss.parseResponseLine(line, def, scr) + if err != nil { + return err + } + } - if len(kv) > 1 { - key = strings.TrimSpace(kv[0]) - if key == "" { - // this must be some 
weird empty line - continue - } - value = strings.TrimSpace(kv[1]) - if value == "" { - var resp spec.Response - if strings.EqualFold("default", key) { - if def == nil { - def = &resp - } - } else { - if sc, err := strconv.Atoi(key); err == nil { - if scr == nil { - scr = make(map[int]spec.Response) - } - scr[sc] = resp - } - } - continue - } - refTarget, arrays, isDefinitionRef, description, err := parseTags(value) - if err != nil { - return err - } - // A possible exception for having a definition - if _, ok := ss.responses[refTarget]; !ok { - if _, ok := ss.definitions[refTarget]; ok { - isDefinitionRef = true - } - } + ss.set(def, scr) + return nil +} - var ref spec.Ref - if isDefinitionRef { - if description == "" { - description = refTarget - } - ref, err = spec.NewRef("#/definitions/" + refTarget) - } else { - ref, err = spec.NewRef("#/responses/" + refTarget) - } - if err != nil { - return err - } +func (ss *setOpResponses) parseResponseLine(line string, def *spec.Response, scr map[int]spec.Response) (*spec.Response, map[int]spec.Response, error) { + kv := strings.SplitN(line, ":", kvParts) + if len(kv) <= 1 { + return def, scr, nil + } - // description should used on anyway. - resp := spec.Response{ResponseProps: spec.ResponseProps{Description: description}} - - if isDefinitionRef { - resp.Schema = new(spec.Schema) - resp.Description = description - if arrays == 0 { - resp.Schema.Ref = ref - } else { - cs := resp.Schema - for range arrays { - cs.Typed("array", "") - cs.Items = new(spec.SchemaOrArray) - cs.Items.Schema = new(spec.Schema) - cs = cs.Items.Schema - } - cs.Ref = ref - } - // ref. 
could be empty while use description tag - } else if len(refTarget) > 0 { - resp.Ref = ref - } + key := strings.TrimSpace(kv[0]) + if key == "" { + return def, scr, nil + } - if strings.EqualFold("default", key) { - if def == nil { - def = &resp - } - } else { - if sc, err := strconv.Atoi(key); err == nil { - if scr == nil { - scr = make(map[int]spec.Response) - } - scr[sc] = resp - } + value := strings.TrimSpace(kv[1]) + if value == "" { + def, scr = assignResponse(key, spec.Response{}, def, scr) + return def, scr, nil + } + + refTarget, arrays, isDefinitionRef, description, err := parseTags(value) + if err != nil { + return def, scr, err + } + + // A possible exception for having a definition + if _, ok := ss.responses[refTarget]; !ok { + if _, ok := ss.definitions[refTarget]; ok { + isDefinitionRef = true + } + } + + var ref spec.Ref + if isDefinitionRef { + if description == "" { + description = refTarget + } + ref, err = spec.NewRef("#/definitions/" + refTarget) + } else { + ref, err = spec.NewRef("#/responses/" + refTarget) + } + if err != nil { + return def, scr, err + } + + // description should used on anyway. + resp := spec.Response{ResponseProps: spec.ResponseProps{Description: description}} + + if isDefinitionRef { + resp.Schema = new(spec.Schema) + resp.Description = description + if arrays == 0 { + resp.Schema.Ref = ref + } else { + cs := resp.Schema + for range arrays { + cs.Typed("array", "") + cs.Items = new(spec.SchemaOrArray) + cs.Items.Schema = new(spec.Schema) + cs = cs.Items.Schema } + cs.Ref = ref } + // ref. could be empty while use description tag + } else if len(refTarget) > 0 { + resp.Ref = ref } - ss.set(def, scr) - return nil + + def, scr = assignResponse(key, resp, def, scr) + return def, scr, nil } func parseEnumOld(val string, s *spec.SimpleSchema) []any { @@ -1454,8 +1490,8 @@ func parseEnum(val string, s *spec.SimpleSchema) []any { return interfaceSlice } -// AlphaChars used when parsing for Vendor Extensions. 
-const AlphaChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +// alphaChars used when parsing for Vendor Extensions. +const alphaChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" func newSetExtensions(setter func(*spec.Extensions)) *setOpExtensions { return &setOpExtensions{ @@ -1478,23 +1514,25 @@ type extensionParsingStack []any // Helper function to walk back through extensions until the proper nest level is reached. func (stack *extensionParsingStack) walkBack(rawLines []string, lineIndex int) { - indent := strings.IndexAny(rawLines[lineIndex], AlphaChars) - nextIndent := strings.IndexAny(rawLines[lineIndex+1], AlphaChars) - if nextIndent < indent { - // Pop elements off the stack until we're back where we need to be - runbackIndex := 0 - poppedIndent := 1000 - for { - checkIndent := strings.IndexAny(rawLines[lineIndex-runbackIndex], AlphaChars) - if nextIndent == checkIndent { - break - } - if checkIndent < poppedIndent { - *stack = (*stack)[:len(*stack)-1] - poppedIndent = checkIndent - } - runbackIndex++ + indent := strings.IndexAny(rawLines[lineIndex], alphaChars) + nextIndent := strings.IndexAny(rawLines[lineIndex+1], alphaChars) + if nextIndent >= indent { + return + } + + // Pop elements off the stack until we're back where we need to be + runbackIndex := 0 + poppedIndent := 1000 + for { + checkIndent := strings.IndexAny(rawLines[lineIndex-runbackIndex], alphaChars) + if nextIndent == checkIndent { + break } + if checkIndent < poppedIndent { + *stack = (*stack)[:len(*stack)-1] + poppedIndent = checkIndent + } + runbackIndex++ } } @@ -1509,7 +1547,8 @@ func buildExtensionObjects(rawLines []string, cleanLines []string, lineIndex int } return } - kv := strings.SplitN(cleanLines[lineIndex], ":", 2) + + kv := strings.SplitN(cleanLines[lineIndex], ":", kvParts) key := strings.TrimSpace(kv[0]) if key == "" { // Some odd empty line @@ -1518,104 +1557,130 @@ func buildExtensionObjects(rawLines []string, cleanLines []string, lineIndex 
int nextIsList := false if lineIndex < len(rawLines)-1 { - next := strings.SplitAfterN(cleanLines[lineIndex+1], ":", 2) + next := strings.SplitAfterN(cleanLines[lineIndex+1], ":", kvParts) nextIsList = len(next) == 1 } - if len(kv) > 1 { - // Should be the start of a map or a key:value pair - value := strings.TrimSpace(kv[1]) - - if rxAllowedExtensions.MatchString(key) { - // New extension started - if stack != nil { - if ext, ok := (*stack)[0].(extensionObject); ok { - *extObjs = append(*extObjs, ext) - } - } - - if value != "" { - ext := extensionObject{ - Extension: key, - } - // Extension is simple key:value pair, no stack - rootMap := make(map[string]string) - rootMap[key] = value - ext.Root = rootMap - *extObjs = append(*extObjs, ext) - buildExtensionObjects(rawLines, cleanLines, lineIndex+1, extObjs, nil) - } else { - ext := extensionObject{ - Extension: key, - } - if nextIsList { - // Extension is an array - rootMap := make(map[string]*[]string) - rootList := make([]string, 0) - rootMap[key] = &rootList - ext.Root = rootMap - stack = &extensionParsingStack{} - *stack = append(*stack, ext) - *stack = append(*stack, ext.Root.(map[string]*[]string)[key]) - } else { - // Extension is an object - rootMap := make(map[string]any) - innerMap := make(map[string]any) - rootMap[key] = innerMap - ext.Root = rootMap - stack = &extensionParsingStack{} - *stack = append(*stack, ext) - *stack = append(*stack, innerMap) - } - buildExtensionObjects(rawLines, cleanLines, lineIndex+1, extObjs, stack) - } - } else if stack != nil && len(*stack) != 0 { - stackIndex := len(*stack) - 1 - if value == "" { - if nextIsList { - // start of new list - newList := make([]string, 0) - asMap, ok := (*stack)[stackIndex].(map[string]any) - if !ok { - panic(fmt.Errorf("internal error: stack index expected to be map[string]any, but got %T", (*stack)[stackIndex])) - } - asMap[key] = &newList - *stack = append(*stack, &newList) - } else { - // start of new map - newMap := make(map[string]any) - 
asMap, ok := (*stack)[stackIndex].(map[string]any) - if !ok { - panic(fmt.Errorf("internal error: stack index expected to be map[string]any, but got %T", (*stack)[stackIndex])) - } - asMap[key] = newMap - *stack = append(*stack, newMap) - } - } else { - // key:value - if reflect.TypeOf((*stack)[stackIndex]).Kind() == reflect.Map { - asMap, ok := (*stack)[stackIndex].(map[string]any) - if !ok { - panic(fmt.Errorf("internal error: stack index expected to be map[string]any, but got %T", (*stack)[stackIndex])) - } - asMap[key] = value - } - if lineIndex < len(rawLines)-1 && !rxAllowedExtensions.MatchString(cleanLines[lineIndex+1]) { - stack.walkBack(rawLines, lineIndex) - } - } - buildExtensionObjects(rawLines, cleanLines, lineIndex+1, extObjs, stack) - } - } else if stack != nil && len(*stack) != 0 { + if len(kv) <= 1 { // Should be a list item + if stack == nil || len(*stack) == 0 { + return + } stackIndex := len(*stack) - 1 - list := (*stack)[stackIndex].(*[]string) + list, ok := (*stack)[stackIndex].(*[]string) + if !ok { + panic(fmt.Errorf("internal error: expected *[]string, got %T: %w", (*stack)[stackIndex], ErrCodeScan)) + } *list = append(*list, key) (*stack)[stackIndex] = list if lineIndex < len(rawLines)-1 && !rxAllowedExtensions.MatchString(cleanLines[lineIndex+1]) { stack.walkBack(rawLines, lineIndex) } buildExtensionObjects(rawLines, cleanLines, lineIndex+1, extObjs, stack) + return + } + + // Should be the start of a map or a key:value pair + value := strings.TrimSpace(kv[1]) + + if rxAllowedExtensions.MatchString(key) { + buildNewExtension(key, value, nextIsList, stack, rawLines, cleanLines, lineIndex, extObjs) + return + } + + if stack == nil || len(*stack) == 0 { + return + } + + buildStackEntry(key, value, nextIsList, stack, rawLines, cleanLines, lineIndex) + buildExtensionObjects(rawLines, cleanLines, lineIndex+1, extObjs, stack) +} + +// buildNewExtension handles the start of a new x- extension key. 
+func buildNewExtension(key, value string, nextIsList bool, stack *extensionParsingStack, rawLines, cleanLines []string, lineIndex int, extObjs *[]extensionObject) { + // Flush any previous extension on the stack + if stack != nil { + if ext, ok := (*stack)[0].(extensionObject); ok { + *extObjs = append(*extObjs, ext) + } + } + + if value != "" { + ext := extensionObject{ + Extension: key, + } + // Extension is simple key:value pair, no stack + rootMap := make(map[string]string) + rootMap[key] = value + ext.Root = rootMap + *extObjs = append(*extObjs, ext) + buildExtensionObjects(rawLines, cleanLines, lineIndex+1, extObjs, nil) + return + } + + ext := extensionObject{ + Extension: key, + } + if nextIsList { + // Extension is an array + rootMap := make(map[string]*[]string) + rootList := make([]string, 0) + rootMap[key] = &rootList + ext.Root = rootMap + stack = &extensionParsingStack{} + *stack = append(*stack, ext) + rootListMap, ok := ext.Root.(map[string]*[]string) + if !ok { + panic(fmt.Errorf("internal error: expected map[string]*[]string, got %T: %w", ext.Root, ErrCodeScan)) + } + *stack = append(*stack, rootListMap[key]) + } else { + // Extension is an object + rootMap := make(map[string]any) + innerMap := make(map[string]any) + rootMap[key] = innerMap + ext.Root = rootMap + stack = &extensionParsingStack{} + *stack = append(*stack, ext) + *stack = append(*stack, innerMap) + } + buildExtensionObjects(rawLines, cleanLines, lineIndex+1, extObjs, stack) +} + +func assertStackMap(stack *extensionParsingStack, index int) map[string]any { + asMap, ok := (*stack)[index].(map[string]any) + if !ok { + panic(fmt.Errorf("internal error: stack index expected to be map[string]any, but got %T: %w", (*stack)[index], ErrCodeScan)) + } + return asMap +} + +// buildStackEntry adds a key/value, nested list, or nested map to the current stack. 
+func buildStackEntry(key, value string, nextIsList bool, stack *extensionParsingStack, rawLines, cleanLines []string, lineIndex int) { + stackIndex := len(*stack) - 1 + if value == "" { + asMap := assertStackMap(stack, stackIndex) + if nextIsList { + // start of new list + newList := make([]string, 0) + asMap[key] = &newList + *stack = append(*stack, &newList) + } else { + // start of new map + newMap := make(map[string]any) + asMap[key] = newMap + *stack = append(*stack, newMap) + } + return + } + + // key:value + if reflect.TypeOf((*stack)[stackIndex]).Kind() == reflect.Map { + asMap := assertStackMap(stack, stackIndex) + asMap[key] = value + } + if lineIndex < len(rawLines)-1 && !rxAllowedExtensions.MatchString(cleanLines[lineIndex+1]) { + stack.walkBack(rawLines, lineIndex) } } @@ -1639,12 +1704,12 @@ func (ss *setOpExtensions) Parse(lines []string) error { // list/array // object for _, ext := range extList { - if _, ok := ext.Root.(map[string]string); ok { - exts.AddExtension(ext.Extension, ext.Root.(map[string]string)[ext.Extension]) - } else if _, ok := ext.Root.(map[string]*[]string); ok { - exts.AddExtension(ext.Extension, *(ext.Root.(map[string]*[]string)[ext.Extension])) - } else if _, ok := ext.Root.(map[string]any); ok { - exts.AddExtension(ext.Extension, ext.Root.(map[string]any)[ext.Extension]) + if m, ok := ext.Root.(map[string]string); ok { + exts.AddExtension(ext.Extension, m[ext.Extension]) + } else if m, ok := ext.Root.(map[string]*[]string); ok { + exts.AddExtension(ext.Extension, *m[ext.Extension]) + } else if m, ok := ext.Root.(map[string]any); ok { + exts.AddExtension(ext.Extension, m[ext.Extension]) } else { debugLogf("Unknown Extension type: %s", fmt.Sprint(reflect.TypeOf(ext.Root))) } @@ -1654,7 +1719,7 @@ func (ss *setOpExtensions) Parse(lines []string) error { return nil } -var unsupportedTypes = map[string]struct{}{ +var unsupportedTypes = map[string]struct{}{ //nolint:gochecknoglobals // immutable lookup table "complex64": {}, 
"complex128": {}, } diff --git a/parser_go119_test.go b/parser_go119_test.go index 0f09b31..4379d5a 100644 --- a/parser_go119_test.go +++ b/parser_go119_test.go @@ -40,5 +40,8 @@ The punctuation here does indeed matter. But it won't for go. require.NoError(t, err) assert.Equal(t, []string{"This has a title without whitespace."}, st.Title()) - assert.Equal(t, []string{"The punctuation here does indeed matter. But it won't for go.", "", "# There is an inline header here that doesn't count for finding a title"}, st.Description()) + assert.Equal(t, []string{ + "The punctuation here does indeed matter. But it won't for go.", "", + "# There is an inline header here that doesn't count for finding a title", + }, st.Description()) } diff --git a/parser_test.go b/parser_test.go index fa36cdf..6385997 100644 --- a/parser_test.go +++ b/parser_test.go @@ -49,16 +49,17 @@ func (sap *schemaAnnotationParser) Parse(lines []string) error { } func TestSectionedParser_TitleDescription(t *testing.T) { - text := `This has a title, separated by a whitespace line + const ( + text = `This has a title, separated by a whitespace line In this example the punctuation for the title should not matter for swagger. For go it will still make a difference though. ` - text2 := `This has a title without whitespace. + text2 = `This has a title without whitespace. The punctuation here does indeed matter. But it won't for go. 
` - text3 := `This has a title, and markdown in the description + text3 = `This has a title, and markdown in the description See how markdown works now, we can have lists: @@ -69,7 +70,7 @@ See how markdown works now, we can have lists: [Links works too](http://localhost) ` - text4 := `This has whitespace sensitive markdown in the description + text4 = `This has whitespace sensitive markdown in the description |+ first item | + nested item @@ -80,6 +81,7 @@ Sample code block: | fmt.Println("Hello World!") ` + ) var err error @@ -105,7 +107,11 @@ Sample code block: require.NoError(t, err) assert.Equal(t, []string{"This has a title, and markdown in the description"}, st.Title()) - assert.Equal(t, []string{"See how markdown works now, we can have lists:", "", "+ first item", "+ second item", "+ third item", "", "[Links works too](http://localhost)"}, st.Description()) + assert.Equal(t, []string{ + "See how markdown works now, we can have lists:", "", + "+ first item", "+ second item", "+ third item", "", + "[Links works too](http://localhost)", + }, st.Description()) st = §ionedParser{} st.setTitle = func(_ []string) {} @@ -121,17 +127,19 @@ func dummyBuilder() schemaValidations { } func TestSectionedParser_TagsDescription(t *testing.T) { - block := `This has a title without whitespace. + const ( + block = `This has a title without whitespace. The punctuation here does indeed matter. But it won't for go. minimum: 10 maximum: 20 ` - block2 := `This has a title without whitespace. + block2 = `This has a title without whitespace. The punctuation here does indeed matter. But it won't for go. 
minimum: 10 maximum: 20 ` + ) var err error @@ -173,7 +181,7 @@ maximum: 20 } func TestSectionedParser_Empty(t *testing.T) { - block := `swagger:response someResponse` + const block = `swagger:response someResponse` var err error @@ -192,16 +200,13 @@ func TestSectionedParser_Empty(t *testing.T) { assert.EqualT(t, "someResponse", ap.Name) } -func TestSectionedParser_SkipSectionAnnotation(t *testing.T) { - block := `swagger:model someModel - -This has a title without whitespace. -The punctuation here does indeed matter. But it won't for go. - -minimum: 10 -maximum: 20 -` - var err error +func testSectionedParserWithBlock( + t *testing.T, + block string, + expectedMatchedCount int, + maximumExpected bool, +) { + t.Helper() st := §ionedParser{} st.setTitle = func(_ []string) {} @@ -213,52 +218,42 @@ maximum: 20 {"MultipleOf", false, false, nil, &setMultipleOf{dummyBuilder(), regexp.MustCompile(fmt.Sprintf(rxMultipleOfFmt, ""))}}, } - err = st.Parse(ascg(block)) + err := st.Parse(ascg(block)) require.NoError(t, err) assert.Equal(t, []string{"This has a title without whitespace."}, st.Title()) assert.Equal(t, []string{"The punctuation here does indeed matter. But it won't for go."}, st.Description()) - assert.Len(t, st.matched, 2) + assert.Len(t, st.matched, expectedMatchedCount) _, ok := st.matched["Maximum"] - assert.TrueT(t, ok) + assert.EqualT(t, maximumExpected, ok) _, ok = st.matched["Minimum"] assert.TrueT(t, ok) assert.EqualT(t, "SomeModel", ap.GoName) assert.EqualT(t, "someModel", ap.Name) } -func TestSectionedParser_TerminateOnNewAnnotation(t *testing.T) { - block := `swagger:model someModel +func TestSectionedParser_SkipSectionAnnotation(t *testing.T) { + const block = `swagger:model someModel This has a title without whitespace. The punctuation here does indeed matter. But it won't for go. 
minimum: 10 -swagger:meta maximum: 20 ` - var err error + testSectionedParserWithBlock(t, block, 2, true) +} - st := §ionedParser{} - st.setTitle = func(_ []string) {} - ap := newSchemaAnnotationParser("SomeModel") - st.annotation = ap - st.taggers = []tagParser{ - {"Maximum", false, false, nil, &setMaximum{dummyBuilder(), regexp.MustCompile(fmt.Sprintf(rxMaximumFmt, ""))}}, - {"Minimum", false, false, nil, &setMinimum{dummyBuilder(), regexp.MustCompile(fmt.Sprintf(rxMinimumFmt, ""))}}, - {"MultipleOf", false, false, nil, &setMultipleOf{dummyBuilder(), regexp.MustCompile(fmt.Sprintf(rxMultipleOfFmt, ""))}}, - } +func TestSectionedParser_TerminateOnNewAnnotation(t *testing.T) { + const block = `swagger:model someModel - err = st.Parse(ascg(block)) - require.NoError(t, err) - assert.Equal(t, []string{"This has a title without whitespace."}, st.Title()) - assert.Equal(t, []string{"The punctuation here does indeed matter. But it won't for go."}, st.Description()) - assert.Len(t, st.matched, 1) - _, ok := st.matched["Maximum"] - assert.FalseT(t, ok) - _, ok = st.matched["Minimum"] - assert.TrueT(t, ok) - assert.EqualT(t, "SomeModel", ap.GoName) - assert.EqualT(t, "someModel", ap.Name) +This has a title without whitespace. +The punctuation here does indeed matter. But it won't for go. 
+ +minimum: 10 +swagger:meta +maximum: 20 +` + testSectionedParserWithBlock(t, block, 1, false) } func ascg(txt string) *ast.CommentGroup { diff --git a/regexprs_test.go b/regexprs_test.go index 2fa898b..a201704 100644 --- a/regexprs_test.go +++ b/regexprs_test.go @@ -70,14 +70,15 @@ func TestSchemaValueExtractors(t *testing.T) { "date-time", "long-combo-1-with-combo-2-and-a-3rd-one-too", } - invalidParams := []string{ + invalidParams := make([]string, 0, 9) + invalidParams = append(invalidParams, "1-yada-3", "1-2-3", "-yada-3", "-2-3", "*blah", "blah*", - } + ) verifySwaggerOneArgSwaggerTag(t, rxStrFmt, strfmts, validParams, append(invalidParams, "", " ", " ")) verifySwaggerOneArgSwaggerTag(t, rxModelOverride, models, append(validParams, "", " ", " "), invalidParams) @@ -95,26 +96,15 @@ func TestSchemaValueExtractors(t *testing.T) { verifyIntegerMinMaxManyWords(t, rxf(rxMinLengthFmt, ""), "min", []string{"len", "length"}) // pattern - extraSpaces := []string{"", " ", " ", " "} - prefixes := []string{"//", "*", ""} - patArgs := []string{"^\\w+$", "[A-Za-z0-9-.]*"} - patNames := []string{"pattern", "Pattern"} - for _, pref := range prefixes { - for _, es1 := range extraSpaces { - for _, nm := range patNames { - for _, es2 := range extraSpaces { - for _, es3 := range extraSpaces { - for _, arg := range patArgs { - line := strings.Join([]string{pref, es1, nm, es2, ":", es3, arg}, "") - matches := rxf(rxPatternFmt, "").FindStringSubmatch(line) - assert.Len(t, matches, 2) - assert.EqualT(t, arg, matches[1]) - } - } - } - } - } - } + patPrefixes := cartesianJoin( + []string{"//", "*", ""}, + []string{"", " ", " ", " "}, + []string{"pattern", "Pattern"}, + []string{"", " ", " ", " "}, + []string{":"}, + []string{"", " ", " ", " "}, + ) + verifyRegexpArgs(t, rxf(rxPatternFmt, ""), patPrefixes, []string{"^\\w+$", "[A-Za-z0-9-.]*"}, nil, 2, 1) verifyIntegerMinMaxManyWords(t, rxf(rxMinItemsFmt, ""), "min", []string{"items"}) verifyBoolean(t, rxf(rxUniqueFmt, ""), 
[]string{"unique"}, nil) @@ -127,9 +117,59 @@ func makeMinMax(lower string) (res []string) { for _, a := range []string{"", "imum"} { res = append(res, lower+a, strings.Title(lower)+a) //nolint:staticcheck // Title is deprecated, yet still useful here. The replacement is bit heavy for just this test } + return res } +// cartesianJoin returns all concatenations formed by picking one element from each slot. +func cartesianJoin(slots ...[]string) []string { + result := []string{""} + for _, slot := range slots { + next := make([]string, 0, len(result)*len(slot)) + for _, prefix := range result { + for _, s := range slot { + next = append(next, prefix+s) + } + } + result = next + } + + return result +} + +// titleCaseVariants returns each name paired with its Title-cased form. +func titleCaseVariants(names []string) []string { + result := make([]string, 0, len(names)*2) + for _, nm := range names { + result = append(result, nm, strings.Title(nm)) //nolint:staticcheck // Title is deprecated, yet still useful here + } + + return result +} + +// verifyRegexpArgs tests that matcher matches lines formed by each prefix+validArg +// (expecting expectedMatchLen matches with the value at matchIdx) and rejects prefix+invalidArg. 
+func verifyRegexpArgs(t *testing.T, matcher *regexp.Regexp, prefixes, validArgs, invalidArgs []string, expectedMatchLen, matchIdx int) int { + t.Helper() + cnt := 0 + for _, prefix := range prefixes { + for _, vv := range validArgs { + matches := matcher.FindStringSubmatch(prefix + vv) + assert.Len(t, matches, expectedMatchLen) + assert.EqualT(t, vv, matches[matchIdx]) + cnt++ + } + + for _, iv := range invalidArgs { + matches := matcher.FindStringSubmatch(prefix + iv) + assert.Empty(t, matches) + cnt++ + } + } + + return cnt +} + func verifyBoolean(t *testing.T, matcher *regexp.Regexp, names, names2 []string) { t.Helper() @@ -137,23 +177,17 @@ func verifyBoolean(t *testing.T, matcher *regexp.Regexp, names, names2 []string) prefixes := []string{"//", "*", ""} validArgs := []string{"true", "false"} invalidArgs := []string{"TRUE", "FALSE", "t", "f", "1", "0", "True", "False", "true*", "false*"} - nms := make([]string, 0, len(names)) - for _, nm := range names { - nms = append(nms, nm, strings.Title(nm)) //nolint:staticcheck - } - - nms2 := make([]string, 0, len(names2)) - for _, nm := range names2 { - nms2 = append(nms2, nm, strings.Title(nm)) //nolint:staticcheck - } + nms := titleCaseVariants(names) var rnms []string - if len(nms2) > 0 { + if len(names2) > 0 { + nms2 := titleCaseVariants(names2) + spacesAndDash := []string{"", " ", " ", " ", "-"} for _, nm := range nms { - for _, es := range append(extraSpaces, "-") { + for _, sep := range spacesAndDash { for _, nm2 := range nms2 { - rnms = append(rnms, strings.Join([]string{nm, es, nm2}, "")) + rnms = append(rnms, nm+sep+nm2) } } } @@ -161,36 +195,13 @@ func verifyBoolean(t *testing.T, matcher *regexp.Regexp, names, names2 []string) rnms = nms } - var cnt int - for _, pref := range prefixes { - for _, es1 := range extraSpaces { - for _, nm := range rnms { - for _, es2 := range extraSpaces { - for _, es3 := range extraSpaces { - for _, vv := range validArgs { - line := strings.Join([]string{pref, es1, nm, es2, 
":", es3, vv}, "") - matches := matcher.FindStringSubmatch(line) - assert.Len(t, matches, 2) - assert.EqualT(t, vv, matches[1]) - cnt++ - } - for _, iv := range invalidArgs { - line := strings.Join([]string{pref, es1, nm, es2, ":", es3, iv}, "") - matches := matcher.FindStringSubmatch(line) - assert.Empty(t, matches) - cnt++ - } - } - } - } - } - } + linePrefixes := cartesianJoin(prefixes, extraSpaces, rnms, extraSpaces, []string{":"}, extraSpaces) + cnt := verifyRegexpArgs(t, matcher, linePrefixes, validArgs, invalidArgs, 2, 1) var nm2 string if len(names2) > 0 { nm2 = " " + names2[0] } - t.Logf("tested %d %s%s combinations\n", cnt, names[0], nm2) } @@ -199,47 +210,18 @@ func verifyIntegerMinMaxManyWords(t *testing.T, matcher *regexp.Regexp, name1 st extraSpaces := []string{"", " ", " ", " "} prefixes := []string{"//", "*", ""} - validNumericArgs := []string{"0", "1234"} - invalidNumericArgs := []string{"1A3F", "2e10", "*12", "12*", "-1235", "0.0", "1234.0394", "-2948.484"} + validArgs := []string{"0", "1234"} + invalidArgs := []string{"1A3F", "2e10", "*12", "12*", "-1235", "0.0", "1234.0394", "-2948.484"} - names := make([]string, 0, len(words)) - for _, w := range words { - names = append(names, w, strings.Title(w)) //nolint:staticcheck - } + wordVariants := titleCaseVariants(words) + spacesAndDash := []string{"", " ", " ", " ", "-"} + linePrefixes := cartesianJoin(prefixes, extraSpaces, makeMinMax(name1), spacesAndDash, wordVariants, extraSpaces, []string{":"}, extraSpaces) + cnt := verifyRegexpArgs(t, matcher, linePrefixes, validArgs, invalidArgs, 2, 1) - var cnt int - for _, pref := range prefixes { - for _, es1 := range extraSpaces { - for _, nm1 := range makeMinMax(name1) { - for _, es2 := range append(extraSpaces, "-") { - for _, nm2 := range names { - for _, es3 := range extraSpaces { - for _, es4 := range extraSpaces { - for _, vv := range validNumericArgs { - line := strings.Join([]string{pref, es1, nm1, es2, nm2, es3, ":", es4, vv}, "") - matches := 
matcher.FindStringSubmatch(line) - assert.Len(t, matches, 2) - assert.EqualT(t, vv, matches[1]) - cnt++ - } - for _, iv := range invalidNumericArgs { - line := strings.Join([]string{pref, es1, nm1, es2, nm2, es3, ":", es4, iv}, "") - matches := matcher.FindStringSubmatch(line) - assert.Empty(t, matches) - cnt++ - } - } - } - } - } - } - } - } var nm2 string if len(words) > 0 { nm2 = " " + words[0] } - t.Logf("tested %d %s%s combinations\n", cnt, name1, nm2) } @@ -248,48 +230,23 @@ func verifyNumeric2Words(t *testing.T, matcher *regexp.Regexp, name1, name2 stri extraSpaces := []string{"", " ", " ", " "} prefixes := []string{"//", "*", ""} - validNumericArgs := []string{"0", "1234", "-1235", "0.0", "1234.0394", "-2948.484"} - invalidNumericArgs := []string{"1A3F", "2e10", "*12", "12*"} - - var cnt int - for _, pref := range prefixes { - for _, es1 := range extraSpaces { - for _, es2 := range extraSpaces { - for _, es3 := range extraSpaces { - for _, es4 := range extraSpaces { - for _, vv := range validNumericArgs { - lines := []string{ - strings.Join([]string{pref, es1, name1, es2, name2, es3, ":", es4, vv}, ""), - strings.Join([]string{pref, es1, strings.Title(name1), es2, strings.Title(name2), es3, ":", es4, vv}, ""), //nolint:staticcheck - strings.Join([]string{pref, es1, strings.Title(name1), es2, name2, es3, ":", es4, vv}, ""), //nolint:staticcheck - strings.Join([]string{pref, es1, name1, es2, strings.Title(name2), es3, ":", es4, vv}, ""), //nolint:staticcheck - } - for _, line := range lines { - matches := matcher.FindStringSubmatch(line) - assert.Len(t, matches, 2) - assert.EqualT(t, vv, matches[1]) - cnt++ - } - } - for _, iv := range invalidNumericArgs { - lines := []string{ - strings.Join([]string{pref, es1, name1, es2, name2, es3, ":", es4, iv}, ""), - strings.Join([]string{pref, es1, strings.Title(name1), es2, strings.Title(name2), es3, ":", es4, iv}, ""), //nolint:staticcheck - strings.Join([]string{pref, es1, strings.Title(name1), es2, name2, es3, ":", 
es4, iv}, ""), //nolint:staticcheck - strings.Join([]string{pref, es1, name1, es2, strings.Title(name2), es3, ":", es4, iv}, ""), //nolint:staticcheck - } - for _, line := range lines { - matches := matcher.FindStringSubmatch(line) - assert.Empty(t, matches) - cnt++ - } - } - } - } - } - } + validArgs := []string{"0", "1234", "-1235", "0.0", "1234.0394", "-2948.484"} + invalidArgs := []string{"1A3F", "2e10", "*12", "12*"} + + titleName1 := strings.Title(name1) //nolint:staticcheck // Title is deprecated, yet still useful here + titleName2 := strings.Title(name2) //nolint:staticcheck // Title is deprecated, yet still useful here + nameVariants := make([]string, 0, 4*len(extraSpaces)) + for _, es := range extraSpaces { + nameVariants = append(nameVariants, + name1+es+name2, + titleName1+es+titleName2, + titleName1+es+name2, + name1+es+titleName2, + ) } + linePrefixes := cartesianJoin(prefixes, extraSpaces, nameVariants, extraSpaces, []string{":"}, extraSpaces) + cnt := verifyRegexpArgs(t, matcher, linePrefixes, validArgs, invalidArgs, 2, 1) t.Logf("tested %d %s %s combinations\n", cnt, name1, name2) } @@ -298,39 +255,11 @@ func verifyMinMax(t *testing.T, matcher *regexp.Regexp, name string, operators [ extraSpaces := []string{"", " ", " ", " "} prefixes := []string{"//", "*", ""} - validNumericArgs := []string{"0", "1234", "-1235", "0.0", "1234.0394", "-2948.484"} - invalidNumericArgs := []string{"1A3F", "2e10", "*12", "12*"} - - var cnt int - for _, pref := range prefixes { - for _, es1 := range extraSpaces { - for _, wrd := range makeMinMax(name) { - for _, es2 := range extraSpaces { - for _, es3 := range extraSpaces { - for _, op := range operators { - for _, es4 := range extraSpaces { - for _, vv := range validNumericArgs { - line := strings.Join([]string{pref, es1, wrd, es2, ":", es3, op, es4, vv}, "") - matches := matcher.FindStringSubmatch(line) - // fmt.Printf("matching %q with %q, matches (%d): %v\n", line, matcher, len(matches), matches) - assert.Len(t, 
matches, 3) - assert.EqualT(t, vv, matches[2]) - cnt++ - } - for _, iv := range invalidNumericArgs { - line := strings.Join([]string{pref, es1, wrd, es2, ":", es3, op, es4, iv}, "") - matches := matcher.FindStringSubmatch(line) - assert.Empty(t, matches) - cnt++ - } - } - } - } - } - } - } - } + validArgs := []string{"0", "1234", "-1235", "0.0", "1234.0394", "-2948.484"} + invalidArgs := []string{"1A3F", "2e10", "*12", "12*"} + linePrefixes := cartesianJoin(prefixes, extraSpaces, makeMinMax(name), extraSpaces, []string{":"}, extraSpaces, operators, extraSpaces) + cnt := verifyRegexpArgs(t, matcher, linePrefixes, validArgs, invalidArgs, 3, 2) t.Logf("tested %d %s combinations\n", cnt, name) } @@ -365,7 +294,7 @@ func verifySwaggerMultiArgSwaggerTag(t *testing.T, matcher *regexp.Regexp, prefi for i := range validParams { vp = vp[:0] for j := range i + 1 { - vp = append(vp, validParams[j]) + vp = append(vp, validParams[j]) //nolint:gosec // G602: j is bounded by i+1 which is bounded by len(validParams) } actualParams = append(actualParams, strings.Join(vp, " ")) diff --git a/responses.go b/responses.go index 77e7b50..82b49fe 100644 --- a/responses.go +++ b/responses.go @@ -4,14 +4,10 @@ package codescan import ( - "errors" "fmt" - "go/ast" "go/types" "strings" - "golang.org/x/tools/go/ast/astutil" - "github.com/go-openapi/spec" ) @@ -29,8 +25,8 @@ func (ht responseTypable) Typed(tpe, format string) { ht.header.Typed(tpe, format) } -func bodyTypable(in string, schema *spec.Schema) (swaggerTypable, *spec.Schema) { - if in == "body" { +func bodyTypable(in string, schema *spec.Schema) (swaggerTypable, *spec.Schema) { //nolint:ireturn // polymorphic by design + if in == bodyTag { // get the schema for items on the schema property if schema == nil { schema = new(spec.Schema) @@ -47,7 +43,7 @@ func bodyTypable(in string, schema *spec.Schema) (swaggerTypable, *spec.Schema) return nil, nil } -func (ht responseTypable) Items() swaggerTypable { +func (ht responseTypable) Items() 
swaggerTypable { //nolint:ireturn // polymorphic by design bdt, schema := bodyTypable(ht.in, ht.response.Schema) if bdt != nil { ht.response.Schema = schema @@ -208,7 +204,7 @@ func (r *responseBuilder) buildFromField(fld *types.Var, tpe types.Type, typable debugLogf("alias(responses.buildFromField): got alias %v to %v", ftpe, ftpe.Rhs()) return r.buildFieldAlias(ftpe, typable, fld, seen) default: - return fmt.Errorf("unknown type for %s: %T", fld.String(), fld.Type()) + return fmt.Errorf("unknown type for %s: %T: %w", fld.String(), fld.Type(), ErrCodeScan) } } @@ -270,14 +266,14 @@ func (r *responseBuilder) buildFromType(otpe types.Type, resp *spec.Response, se debugLogf("alias(responses.buildFromType): got alias %v to %v", tpe, tpe.Rhs()) return r.buildAlias(tpe, resp, seen) default: - return errors.New("anonymous types are currently not supported for responses") + return fmt.Errorf("anonymous types are currently not supported for responses: %w", ErrCodeScan) } } func (r *responseBuilder) buildNamedType(tpe *types.Named, resp *spec.Response, seen map[string]bool) error { o := tpe.Obj() if isAny(o) || isStdError(o) { - return fmt.Errorf("%s type not supported in the context of a responses section definition", o.Name()) + return fmt.Errorf("%s type not supported in the context of a responses section definition: %w", o.Name(), ErrCodeScan) } mustNotBeABuiltinType(o) // ICI @@ -313,7 +309,7 @@ func (r *responseBuilder) buildNamedType(tpe *types.Named, resp *spec.Response, r.postDecls = append(r.postDecls, sb.postDecls...) 
return nil } - return fmt.Errorf("responses can only be structs, did you mean for %s to be the response body?", tpe.String()) + return fmt.Errorf("responses can only be structs, did you mean for %s to be the response body?: %w", tpe.String(), ErrCodeScan) } } @@ -322,7 +318,7 @@ func (r *responseBuilder) buildAlias(tpe *types.Alias, resp *spec.Response, seen o := tpe.Obj() if isAny(o) || isStdError(o) { // wrong: TODO(fred): see what object exactly we want to build here - figure out with specific tests - return fmt.Errorf("%s type not supported in the context of a responses section definition", o.Name()) + return fmt.Errorf("%s type not supported in the context of a responses section definition: %w", o.Name(), ErrCodeScan) } mustNotBeABuiltinType(o) mustHaveRightHandSide(tpe) @@ -336,7 +332,7 @@ func (r *responseBuilder) buildAlias(tpe *types.Alias, resp *spec.Response, seen decl, ok := r.ctx.FindModel(o.Pkg().Path(), o.Name()) if !ok { - return fmt.Errorf("can't find source file for aliased type: %v -> %v", tpe, rhs) + return fmt.Errorf("can't find source file for aliased type: %v -> %v: %w", tpe, rhs, ErrCodeScan) } r.postDecls = append(r.postDecls, decl) // mark the left-hand side as discovered @@ -373,7 +369,7 @@ func (r *responseBuilder) buildAlias(tpe *types.Alias, resp *spec.Response, seen func (r *responseBuilder) buildNamedField(ftpe *types.Named, typable swaggerTypable) error { decl, found := r.ctx.DeclForType(ftpe.Obj().Type()) if !found { - return fmt.Errorf("unable to find package and source file for: %s", ftpe.String()) + return fmt.Errorf("unable to find package and source file for: %s: %w", ftpe.String(), ErrCodeScan) } d := decl.Obj() @@ -424,7 +420,7 @@ func (r *responseBuilder) buildFieldAlias(tpe *types.Alias, typable swaggerTypab decl, ok := r.ctx.FindModel(o.Pkg().Path(), o.Name()) if !ok { - return fmt.Errorf("can't find source file for aliased type: %v", tpe) + return fmt.Errorf("can't find source file for aliased type: %v: %w", tpe, 
ErrCodeScan) } r.postDecls = append(r.postDecls, decl) // mark the left-hand side as discovered @@ -436,8 +432,7 @@ func (r *responseBuilder) buildFromStruct(decl *entityDecl, tpe *types.Struct, r return nil } - for i := range tpe.NumFields() { - fld := tpe.Field(i) + for fld := range tpe.Fields() { if fld.Embedded() { if err := r.buildFromType(fld.Type(), resp, seen); err != nil { return err @@ -449,177 +444,91 @@ func (r *responseBuilder) buildFromStruct(decl *entityDecl, tpe *types.Struct, r continue } - tg := tpe.Tag(i) - - var afld *ast.Field - ans, _ := astutil.PathEnclosingInterval(decl.File, fld.Pos(), fld.Pos()) - for _, an := range ans { - at, valid := an.(*ast.Field) - if !valid { - continue - } - - debugLogf("field %s: %s(%T) [%q] ==> %s", fld.Name(), fld.Type().String(), fld.Type(), tg, at.Doc.Text()) - afld = at - break - } - - if afld == nil { - debugLogf("can't find source associated with %s for %s", fld.String(), tpe.String()) - continue - } - - // if the field is annotated with swagger:ignore, ignore it - if ignored(afld.Doc) { - debugLogf("field %v of type %v is deliberately ignored", fld, tpe) - continue - } - - name, ignore, _, _, err := parseJSONTag(afld) - if err != nil { + if err := r.processResponseField(fld, decl, resp, seen); err != nil { return err } - if ignore { - continue - } + } - var in string - // scan for param location first, this changes some behavior down the line - if afld.Doc != nil { - for _, cmt := range afld.Doc.List { - for line := range strings.SplitSeq(cmt.Text, "\n") { - matches := rxIn.FindStringSubmatch(line) - if len(matches) > 0 && len(strings.TrimSpace(matches[1])) > 0 { - in = strings.TrimSpace(matches[1]) - } - } - } + for k := range resp.Headers { + if !seen[k] { + delete(resp.Headers, k) } + } + return nil +} - ps := resp.Headers[name] - - // support swagger:file for response - // An API operation can return a file, such as an image or PDF. 
In this case, - // define the response schema with type: file and specify the appropriate MIME types in the produces section. - if afld.Doc != nil && fileParam(afld.Doc) { - resp.Schema = &spec.Schema{} - resp.Schema.Typed("file", "") - } else { - debugLogf("build response %v (%v) (not a file)", fld, fld.Type()) - if err := r.buildFromField(fld, fld.Type(), responseTypable{in, &ps, resp}, seen); err != nil { - return err - } - } +func (r *responseBuilder) processResponseField(fld *types.Var, decl *entityDecl, resp *spec.Response, seen map[string]bool) error { + if !fld.Exported() { + return nil + } - if strfmtName, ok := strfmtName(afld.Doc); ok { - ps.Typed("string", strfmtName) - } + afld := findASTField(decl.File, fld.Pos()) + if afld == nil { + debugLogf("can't find source associated with %s", fld.String()) + return nil + } - sp := new(sectionedParser) - sp.setDescription = func(lines []string) { ps.Description = joinDropLast(lines) } - sp.taggers = []tagParser{ - newSingleLineTagParser("maximum", &setMaximum{headerValidations{&ps}, rxf(rxMaximumFmt, "")}), - newSingleLineTagParser("minimum", &setMinimum{headerValidations{&ps}, rxf(rxMinimumFmt, "")}), - newSingleLineTagParser("multipleOf", &setMultipleOf{headerValidations{&ps}, rxf(rxMultipleOfFmt, "")}), - newSingleLineTagParser("minLength", &setMinLength{headerValidations{&ps}, rxf(rxMinLengthFmt, "")}), - newSingleLineTagParser("maxLength", &setMaxLength{headerValidations{&ps}, rxf(rxMaxLengthFmt, "")}), - newSingleLineTagParser("pattern", &setPattern{headerValidations{&ps}, rxf(rxPatternFmt, "")}), - newSingleLineTagParser("collectionFormat", &setCollectionFormat{headerValidations{&ps}, rxf(rxCollectionFormatFmt, "")}), - newSingleLineTagParser("minItems", &setMinItems{headerValidations{&ps}, rxf(rxMinItemsFmt, "")}), - newSingleLineTagParser("maxItems", &setMaxItems{headerValidations{&ps}, rxf(rxMaxItemsFmt, "")}), - newSingleLineTagParser("unique", &setUnique{headerValidations{&ps}, rxf(rxUniqueFmt, 
"")}), - newSingleLineTagParser("enum", &setEnum{headerValidations{&ps}, rxf(rxEnumFmt, "")}), - newSingleLineTagParser("default", &setDefault{&ps.SimpleSchema, headerValidations{&ps}, rxf(rxDefaultFmt, "")}), - newSingleLineTagParser("example", &setExample{&ps.SimpleSchema, headerValidations{&ps}, rxf(rxExampleFmt, "")}), - } - itemsTaggers := func(items *spec.Items, level int) []tagParser { - // the expression is 1-index based not 0-index - itemsPrefix := fmt.Sprintf(rxItemsPrefixFmt, level+1) - - return []tagParser{ - newSingleLineTagParser(fmt.Sprintf("items%dMaximum", level), &setMaximum{itemsValidations{items}, rxf(rxMaximumFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dMinimum", level), &setMinimum{itemsValidations{items}, rxf(rxMinimumFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dMultipleOf", level), &setMultipleOf{itemsValidations{items}, rxf(rxMultipleOfFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dMinLength", level), &setMinLength{itemsValidations{items}, rxf(rxMinLengthFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dMaxLength", level), &setMaxLength{itemsValidations{items}, rxf(rxMaxLengthFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dPattern", level), &setPattern{itemsValidations{items}, rxf(rxPatternFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dCollectionFormat", level), &setCollectionFormat{itemsValidations{items}, rxf(rxCollectionFormatFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dMinItems", level), &setMinItems{itemsValidations{items}, rxf(rxMinItemsFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dMaxItems", level), &setMaxItems{itemsValidations{items}, rxf(rxMaxItemsFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dUnique", level), &setUnique{itemsValidations{items}, rxf(rxUniqueFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dEnum", level), 
&setEnum{itemsValidations{items}, rxf(rxEnumFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dDefault", level), &setDefault{&items.SimpleSchema, itemsValidations{items}, rxf(rxDefaultFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dExample", level), &setExample{&items.SimpleSchema, itemsValidations{items}, rxf(rxExampleFmt, itemsPrefix)}), - } - } + if ignored(afld.Doc) { + debugLogf("field %v is deliberately ignored", fld) + return nil + } - var parseArrayTypes func(expr ast.Expr, items *spec.Items, level int) ([]tagParser, error) - parseArrayTypes = func(expr ast.Expr, items *spec.Items, level int) ([]tagParser, error) { - if items == nil { - return []tagParser{}, nil - } - switch iftpe := expr.(type) { - case *ast.ArrayType: - eleTaggers := itemsTaggers(items, level) - sp.taggers = append(eleTaggers, sp.taggers...) - otherTaggers, err := parseArrayTypes(iftpe.Elt, items.Items, level+1) - if err != nil { - return nil, err - } - return otherTaggers, nil - case *ast.Ident: - taggers := []tagParser{} - if iftpe.Obj == nil { - taggers = itemsTaggers(items, level) - } - otherTaggers, err := parseArrayTypes(expr, items.Items, level+1) - if err != nil { - return nil, err - } - return append(taggers, otherTaggers...), nil - case *ast.SelectorExpr: - otherTaggers, err := parseArrayTypes(iftpe.Sel, items.Items, level+1) - if err != nil { - return nil, err - } - return otherTaggers, nil - case *ast.StarExpr: - otherTaggers, err := parseArrayTypes(iftpe.X, items, level) - if err != nil { - return nil, err + name, ignore, _, _, err := parseJSONTag(afld) + if err != nil { + return err + } + if ignore { + return nil + } + + var in string + // scan for param location first, this changes some behavior down the line + if afld.Doc != nil { + for _, cmt := range afld.Doc.List { + for line := range strings.SplitSeq(cmt.Text, "\n") { + matches := rxIn.FindStringSubmatch(line) + if len(matches) > 0 && len(strings.TrimSpace(matches[1])) > 0 { + in = 
strings.TrimSpace(matches[1]) } - return otherTaggers, nil - default: - return nil, fmt.Errorf("unknown field type ele for %q", name) - } - } - // check if this is a primitive, if so parse the validations from the - // doc comments of the slice declaration. - if ftped, ok := afld.Type.(*ast.ArrayType); ok { - taggers, err := parseArrayTypes(ftped.Elt, ps.Items, 0) - if err != nil { - return err } - sp.taggers = append(taggers, sp.taggers...) } + } - if err := sp.Parse(afld.Doc); err != nil { + ps := resp.Headers[name] + + // support swagger:file for response + // An API operation can return a file, such as an image or PDF. In this case, + // define the response schema with type: file and specify the appropriate MIME types in the produces section. + if afld.Doc != nil && fileParam(afld.Doc) { + resp.Schema = &spec.Schema{} + resp.Schema.Typed("file", "") + } else { + debugLogf("build response %v (%v) (not a file)", fld, fld.Type()) + if err := r.buildFromField(fld, fld.Type(), responseTypable{in, &ps, resp}, seen); err != nil { return err } + } - if in != "body" { - seen[name] = true - if resp.Headers == nil { - resp.Headers = make(map[string]spec.Header) - } - resp.Headers[name] = ps - } + if strfmtName, ok := strfmtName(afld.Doc); ok { + ps.Typed("string", strfmtName) } - for k := range resp.Headers { - if !seen[k] { - delete(resp.Headers, k) + sp := new(sectionedParser) + sp.setDescription = func(lines []string) { ps.Description = joinDropLast(lines) } + if err := setupResponseHeaderTaggers(sp, &ps, name, afld); err != nil { + return err + } + + if err := sp.Parse(afld.Doc); err != nil { + return err + } + + if in != bodyTag { + seen[name] = true + if resp.Headers == nil { + resp.Headers = make(map[string]spec.Header) } + resp.Headers[name] = ps } return nil } diff --git a/responses_test.go b/responses_test.go index 3710243..1a53a0c 100644 --- a/responses_test.go +++ b/responses_test.go @@ -24,7 +24,12 @@ func getResponse(sctx *scanCtx, nm string) *entityDecl { 
func TestParseResponses(t *testing.T) { sctx := loadClassificationPkgsCtx(t) responses := make(map[string]spec.Response) - for _, rn := range []string{"ComplexerOne", "SimpleOnes", "SimpleOnesFunc", "ComplexerPointerOne", "SomeResponse", "ValidationError", "Resp", "FileResponse", "GenericError", "ValidationError"} { + responseNames := []string{ + "ComplexerOne", "SimpleOnes", "SimpleOnesFunc", "ComplexerPointerOne", + "SomeResponse", "ValidationError", "Resp", "FileResponse", + "GenericError", "ValidationError", + } + for _, rn := range responseNames { td := getResponse(sctx, rn) prs := &responseBuilder{ ctx: sctx, @@ -34,12 +39,41 @@ func TestParseResponses(t *testing.T) { } require.Len(t, responses, 9) + + t.Run("complexerOne headers", func(t *testing.T) { + assertComplexerOneHeaders(t, responses) + }) + + t.Run("complexerPointerOne headers", func(t *testing.T) { + assertComplexerPointerOneHeaders(t, responses) + }) + + sos, ok := responses["simpleOnes"] + assert.TrueT(t, ok) + assert.Len(t, sos.Headers, 1) + + sosf, ok := responses["simpleOnesFunc"] + assert.TrueT(t, ok) + assert.Len(t, sosf.Headers, 1) + + t.Run("someResponse headers and schema", func(t *testing.T) { + assertSomeResponseHeaders(t, responses) + }) + + res, ok := responses["resp"] + assert.TrueT(t, ok) + assert.NotNil(t, res.Schema) + assert.EqualT(t, "#/definitions/user", res.Schema.Ref.String()) +} + +func assertComplexerOneHeaders(t *testing.T, responses map[string]spec.Response) { + t.Helper() cr, ok := responses["complexerOne"] assert.TrueT(t, ok) assert.Len(t, cr.Headers, 7) for k, header := range cr.Headers { switch k { - case "id": + case paramID: assert.EqualT(t, "integer", header.Type) assert.EqualT(t, "int64", header.Format) case "name": @@ -64,13 +98,16 @@ func TestParseResponses(t *testing.T) { assert.Fail(t, "unknown header: "+k) } } +} +func assertComplexerPointerOneHeaders(t *testing.T, responses map[string]spec.Response) { + t.Helper() cpr, ok := responses["complexerPointerOne"] 
assert.TrueT(t, ok) assert.Len(t, cpr.Headers, 4) for k, header := range cpr.Headers { switch k { - case "id": + case paramID: assert.EqualT(t, "integer", header.Type) assert.EqualT(t, "int64", header.Format) case "name": @@ -86,22 +123,17 @@ func TestParseResponses(t *testing.T) { assert.Fail(t, "unknown header: "+k) } } +} - sos, ok := responses["simpleOnes"] - assert.TrueT(t, ok) - assert.Len(t, sos.Headers, 1) - - sosf, ok := responses["simpleOnesFunc"] - assert.TrueT(t, ok) - assert.Len(t, sosf.Headers, 1) - +func assertSomeResponseHeaders(t *testing.T, responses map[string]spec.Response) { + t.Helper() res, ok := responses["someResponse"] assert.TrueT(t, ok) assert.Len(t, res.Headers, 7) for k, header := range res.Headers { switch k { - case "id": + case paramID: assert.EqualT(t, "ID of this some response instance.\nids in this application start at 11 and are smaller than 1000", header.Description) assert.EqualT(t, "integer", header.Type) assert.EqualT(t, "int64", header.Format) @@ -249,11 +281,6 @@ func TestParseResponses(t *testing.T) { iprop, ok = itprop.Properties["notes"] assert.TrueT(t, ok) assert.EqualT(t, "Notes to add to this item.\nThis can be used to add special instructions.", iprop.Description) - - res, ok = responses["resp"] - assert.TrueT(t, ok) - assert.NotNil(t, res.Schema) - assert.EqualT(t, "#/definitions/user", res.Schema.Ref.String()) } func TestParseResponses_TransparentAliases(t *testing.T) { diff --git a/resume b/resume new file mode 100644 index 0000000..5edfecf --- /dev/null +++ b/resume @@ -0,0 +1 @@ +claude --resume bede006e-f3b0-4a5e-a6e3-3fa30ffed38d diff --git a/route_params.go b/route_params.go index 1be6e65..2a6248a 100644 --- a/route_params.go +++ b/route_params.go @@ -4,7 +4,7 @@ package codescan import ( - "errors" + "fmt" "slices" "strconv" "strings" @@ -13,53 +13,56 @@ import ( ) const ( - // ParamDescriptionKey indicates the tag used to define a parameter description in swagger:route. 
- ParamDescriptionKey = "description" - // ParamNameKey indicates the tag used to define a parameter name in swagger:route. - ParamNameKey = "name" - // ParamInKey indicates the tag used to define a parameter location in swagger:route. - ParamInKey = "in" - // ParamRequiredKey indicates the tag used to declare whether a parameter is required in swagger:route. - ParamRequiredKey = "required" - // ParamTypeKey indicates the tag used to define the parameter type in swagger:route. - ParamTypeKey = "type" - // ParamAllowEmptyKey indicates the tag used to indicate whether a parameter allows empty values in swagger:route. - ParamAllowEmptyKey = "allowempty" - - // SchemaMinKey indicates the tag used to indicate the minimum value allowed for this type in swagger:route. - SchemaMinKey = "min" - // SchemaMaxKey indicates the tag used to indicate the maximum value allowed for this type in swagger:route. - SchemaMaxKey = "max" - // SchemaEnumKey indicates the tag used to specify the allowed values for this type in swagger:route. - SchemaEnumKey = "enum" - // SchemaFormatKey indicates the expected format for this field in swagger:route. - SchemaFormatKey = "format" - // SchemaDefaultKey indicates the default value for this field in swagger:route. - SchemaDefaultKey = "default" - // SchemaMinLenKey indicates the minimum length this field in swagger:route. - SchemaMinLenKey = "minlength" - // SchemaMaxLenKey indicates the minimum length this field in swagger:route. - SchemaMaxLenKey = "maxlength" - - // TypeArray is the identifier for an array type in swagger:route. - TypeArray = "array" - // TypeNumber is the identifier for a number type in swagger:route. - TypeNumber = "number" - // TypeInteger is the identifier for an integer type in swagger:route. - TypeInteger = "integer" - // TypeBoolean is the identifier for a boolean type in swagger:route. - TypeBoolean = "boolean" - // TypeBool is the identifier for a boolean type in swagger:route. 
- TypeBool = "bool" - // TypeObject is the identifier for an object type in swagger:route. - TypeObject = "object" - // TypeString is the identifier for a string type in swagger:route. - TypeString = "string" + // paramDescriptionKey indicates the tag used to define a parameter description in swagger:route. + paramDescriptionKey = "description" + // paramNameKey indicates the tag used to define a parameter name in swagger:route. + paramNameKey = "name" + // paramInKey indicates the tag used to define a parameter location in swagger:route. + paramInKey = "in" + // paramRequiredKey indicates the tag used to declare whether a parameter is required in swagger:route. + paramRequiredKey = "required" + // paramTypeKey indicates the tag used to define the parameter type in swagger:route. + paramTypeKey = "type" + // paramAllowEmptyKey indicates the tag used to indicate whether a parameter allows empty values in swagger:route. + paramAllowEmptyKey = "allowempty" + + // schemaMinKey indicates the tag used to indicate the minimum value allowed for this type in swagger:route. + schemaMinKey = "min" + // schemaMaxKey indicates the tag used to indicate the maximum value allowed for this type in swagger:route. + schemaMaxKey = "max" + // schemaEnumKey indicates the tag used to specify the allowed values for this type in swagger:route. + schemaEnumKey = "enum" + // schemaFormatKey indicates the expected format for this field in swagger:route. + schemaFormatKey = "format" + // schemaDefaultKey indicates the default value for this field in swagger:route. + schemaDefaultKey = "default" + // schemaMinLenKey indicates the minimum length of this field in swagger:route. + schemaMinLenKey = "minlength" + // schemaMaxLenKey indicates the maximum length of this field in swagger:route. + schemaMaxLenKey = "maxlength" + + // paramInQuery is the default parameter location for query parameters. + paramInQuery = "query" + + // typeArray is the identifier for an array type in swagger:route.
+ typeArray = "array" + // typeNumber is the identifier for a number type in swagger:route. + typeNumber = "number" + // typeInteger is the identifier for an integer type in swagger:route. + typeInteger = "integer" + // typeBoolean is the identifier for a boolean type in swagger:route. + typeBoolean = "boolean" + // typeBool is the identifier for a boolean type in swagger:route. + typeBool = "bool" + // typeObject is the identifier for an object type in swagger:route. + typeObject = "object" + // typeString is the identifier for a string type in swagger:route. + typeString = "string" ) var ( - validIn = []string{"path", "query", "header", "body", "form"} - basicTypes = []string{TypeInteger, TypeNumber, TypeString, TypeBoolean, TypeBool, TypeArray} + validIn = []string{"path", "query", "header", "body", "form"} //nolint:gochecknoglobals // immutable lookup table + basicTypes = []string{typeInteger, typeNumber, typeString, typeBoolean, typeBool, typeArray} //nolint:gochecknoglobals // immutable lookup table ) func newSetParams(params []*spec.Parameter, setter func([]*spec.Parameter)) *setOpParams { @@ -96,7 +99,7 @@ func (s *setOpParams) Parse(lines []string) error { l = strings.TrimPrefix(l, "+") } - kv := strings.SplitN(l, ":", 2) + kv := strings.SplitN(l, ":", kvParts) if len(kv) <= 1 { continue @@ -106,44 +109,10 @@ func (s *setOpParams) Parse(lines []string) error { value := strings.TrimSpace(kv[1]) if current == nil { - return errors.New("invalid route/operation schema provided") + return fmt.Errorf("invalid route/operation schema provided: %w", ErrCodeScan) } - switch key { - case ParamDescriptionKey: - current.Description = value - case ParamNameKey: - current.Name = value - case ParamInKey: - v := strings.ToLower(value) - if contains(validIn, v) { - current.In = v - } - case ParamRequiredKey: - if v, err := strconv.ParseBool(value); err == nil { - current.Required = v - } - case ParamTypeKey: - if current.Schema == nil { - current.Schema = new(spec.Schema) - 
} - if contains(basicTypes, value) { - current.Type = strings.ToLower(value) - if current.Type == TypeBool { - current.Type = TypeBoolean - } - } else if ref, err := spec.NewRef("#/definitions/" + value); err == nil { - current.Type = TypeObject - current.Schema.Ref = ref - } - current.Schema.Type = spec.StringOrArray{current.Type} - case ParamAllowEmptyKey: - if v, err := strconv.ParseBool(value); err == nil { - current.AllowEmptyValue = v - } - default: - extraData[key] = value - } + applyParamField(current, extraData, key, value) } s.finalizeParam(current, extraData) @@ -151,6 +120,44 @@ func (s *setOpParams) Parse(lines []string) error { return nil } +func applyParamField(current *spec.Parameter, extraData map[string]string, key, value string) { + switch key { + case paramDescriptionKey: + current.Description = value + case paramNameKey: + current.Name = value + case paramInKey: + v := strings.ToLower(value) + if contains(validIn, v) { + current.In = v + } + case paramRequiredKey: + if v, err := strconv.ParseBool(value); err == nil { + current.Required = v + } + case paramTypeKey: + if current.Schema == nil { + current.Schema = new(spec.Schema) + } + if contains(basicTypes, value) { + current.Type = strings.ToLower(value) + if current.Type == typeBool { + current.Type = typeBoolean + } + } else if ref, err := spec.NewRef("#/definitions/" + value); err == nil { + current.Type = typeObject + current.Schema.Ref = ref + } + current.Schema.Type = spec.StringOrArray{current.Type} + case paramAllowEmptyKey: + if v, err := strconv.ParseBool(value); err == nil { + current.AllowEmptyValue = v + } + default: + extraData[key] = value + } +} + func (s *setOpParams) finalizeParam(param *spec.Parameter, data map[string]string) { if param == nil { return @@ -161,7 +168,7 @@ func (s *setOpParams) finalizeParam(param *spec.Parameter, data map[string]strin // schema is only allowed for parameters in "body" // see https://swagger.io/specification/v2/#parameterObject switch { - 
case param.In == "body": + case param.In == bodyTag: param.Type = "" case param.Schema != nil: @@ -184,31 +191,31 @@ func processSchema(data map[string]string, param *spec.Parameter) { for key, value := range data { switch key { - case SchemaMinKey: - if t := getType(param.Schema); t == TypeNumber || t == TypeInteger { + case schemaMinKey: + if t := getType(param.Schema); t == typeNumber || t == typeInteger { v, _ := strconv.ParseFloat(value, 64) param.Schema.Minimum = &v } - case SchemaMaxKey: - if t := getType(param.Schema); t == TypeNumber || t == TypeInteger { + case schemaMaxKey: + if t := getType(param.Schema); t == typeNumber || t == typeInteger { v, _ := strconv.ParseFloat(value, 64) param.Schema.Maximum = &v } - case SchemaMinLenKey: - if getType(param.Schema) == TypeArray { + case schemaMinLenKey: + if getType(param.Schema) == typeArray { v, _ := strconv.ParseInt(value, 10, 64) param.Schema.MinLength = &v } - case SchemaMaxLenKey: - if getType(param.Schema) == TypeArray { + case schemaMaxLenKey: + if getType(param.Schema) == typeArray { v, _ := strconv.ParseInt(value, 10, 64) param.Schema.MaxLength = &v } - case SchemaEnumKey: + case schemaEnumKey: enumValues = strings.Split(value, ",") - case SchemaFormatKey: + case schemaFormatKey: param.Schema.Format = value - case SchemaDefaultKey: + case schemaDefaultKey: param.Schema.Default = convert(param.Type, value) } } @@ -234,15 +241,15 @@ func convertEnum(schema *spec.Schema, enumValues []string) { func convert(typeStr, valueStr string) any { switch typeStr { - case TypeInteger: + case typeInteger: fallthrough - case TypeNumber: + case typeNumber: if num, err := strconv.ParseFloat(valueStr, 64); err == nil { return num } - case TypeBoolean: + case typeBoolean: fallthrough - case TypeBool: + case typeBool: if b, err := strconv.ParseBool(valueStr); err == nil { return b } diff --git a/schema.go b/schema.go index 113ab59..911ed8e 100644 --- a/schema.go +++ b/schema.go @@ -5,10 +5,10 @@ package codescan import ( 
"encoding/json" - "errors" "fmt" "go/ast" "go/importer" + "go/token" "go/types" "log" "os" @@ -17,6 +17,7 @@ import ( "strings" "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/go/packages" "github.com/go-openapi/spec" ) @@ -48,7 +49,7 @@ func (st schemaTypable) Schema() *spec.Schema { return st.schema } -func (st schemaTypable) Items() swaggerTypable { +func (st schemaTypable) Items() swaggerTypable { //nolint:ireturn // polymorphic by design if st.schema.Items == nil { st.schema.Items = new(spec.SchemaOrArray) } @@ -60,7 +61,7 @@ func (st schemaTypable) Items() swaggerTypable { return schemaTypable{st.schema.Items.Schema, st.level + 1} } -func (st schemaTypable) AdditionalProperties() swaggerTypable { +func (st schemaTypable) AdditionalProperties() swaggerTypable { //nolint:ireturn // polymorphic by design if st.schema.AdditionalProperties == nil { st.schema.AdditionalProperties = new(spec.SchemaOrBool) } @@ -264,7 +265,7 @@ func (s *schemaBuilder) buildDeclNamed(tpe *types.Named, schema *spec.Schema) er ps := schemaTypable{schema, 0} ti := s.decl.Pkg.TypesInfo.Types[s.decl.Spec.Type] if !ti.IsType() { - return fmt.Errorf("declaration is not a type: %v", o) + return fmt.Errorf("declaration is not a type: %v: %w", o, ErrCodeScan) } return s.buildFromType(ti.Type, ps) @@ -375,7 +376,7 @@ func (s *schemaBuilder) buildFromType(tpe types.Type, tgt swaggerTypable) error log.Printf("WARNING: functions are not supported %[1]v (%[1]T). 
Skipped", tpe) return nil default: - panic(fmt.Errorf("ERROR: can't determine refined type %[1]v (%[1]T): %w", titpe, ErrInternal)) + panic(fmt.Errorf("ERROR: can't determine refined type %[1]v (%[1]T): %w", titpe, errInternal)) } } @@ -385,6 +386,7 @@ func (s *schemaBuilder) buildNamedType(titpe *types.Named, tgt swaggerTypable) e log.Printf("WARNING: skipped unsupported builtin type: %v", titpe) return nil } + if isAny(tio) { // e.g type X any or type X interface{} _ = tgt.Schema() @@ -445,123 +447,22 @@ func (s *schemaBuilder) buildNamedType(titpe *types.Named, tgt swaggerTypable) e // invariant: the Underlying cannot be an alias or named type switch utitpe := titpe.Underlying().(type) { case *types.Struct: - debugLogf("found struct: %s.%s", tio.Pkg().Path(), tio.Name()) - - decl, ok := s.ctx.FindModel(tio.Pkg().Path(), tio.Name()) - if !ok { - debugLogf("could not find model in index: %s.%s", tio.Pkg().Path(), tio.Name()) - return nil - } - - o := decl.Obj() - if isStdTime(o) { - tgt.Typed("string", "date-time") - return nil - } - - if sfnm, isf := strfmtName(cmt); isf { - tgt.Typed("string", sfnm) - return nil - } - - if typeName, ok := typeName(cmt); ok { - _ = swaggerSchemaForType(typeName, tgt) - return nil - } - - return s.makeRef(decl, tgt) + return s.buildNamedStruct(tio, cmt, tgt) case *types.Interface: debugLogf("found interface: %s.%s", tio.Pkg().Path(), tio.Name()) decl, found := s.ctx.FindModel(tio.Pkg().Path(), tio.Name()) if !found { - return fmt.Errorf("can't find source file for type: %v", utitpe) + return fmt.Errorf("can't find source file for type: %v: %w", utitpe, ErrCodeScan) } return s.makeRef(decl, tgt) case *types.Basic: - if unsupportedBuiltinType(utitpe) { - log.Printf("WARNING: skipped unsupported builtin type: %v", utitpe) - return nil - } - - debugLogf("found primitive type: %s.%s", tio.Pkg().Path(), tio.Name()) - - if sfnm, isf := strfmtName(cmt); isf { - tgt.Typed("string", sfnm) - return nil - } - - if enumName, ok := 
enumName(cmt); ok { - enumValues, enumDesces, _ := s.ctx.FindEnumValues(pkg, enumName) - if len(enumValues) > 0 { - tgt.WithEnum(enumValues...) - enumTypeName := reflect.TypeOf(enumValues[0]).String() - _ = swaggerSchemaForType(enumTypeName, tgt) - } - if len(enumDesces) > 0 { - tgt.WithEnumDescription(strings.Join(enumDesces, "\n")) - } - return nil - } - - if defaultName, ok := defaultName(cmt); ok { - debugLogf("default name: %s", defaultName) - return nil - } - - if typeName, ok := typeName(cmt); ok { - _ = swaggerSchemaForType(typeName, tgt) - return nil - } - - if isAliasParam(tgt) || aliasParam(cmt) { - err := swaggerSchemaForType(utitpe.Name(), tgt) - if err == nil { - return nil - } - } - - if decl, ok := s.ctx.FindModel(tio.Pkg().Path(), tio.Name()); ok { - return s.makeRef(decl, tgt) - } - - return swaggerSchemaForType(utitpe.String(), tgt) + return s.buildNamedBasic(tio, pkg, cmt, utitpe, tgt) case *types.Array: - debugLogf("found array type: %s.%s", tio.Pkg().Path(), tio.Name()) - - if sfnm, isf := strfmtName(cmt); isf { - if sfnm == "byte" { - tgt.Typed("string", sfnm) - return nil - } - if sfnm == "bsonobjectid" { - tgt.Typed("string", sfnm) - return nil - } - - tgt.Items().Typed("string", sfnm) - return nil - } - if decl, ok := s.ctx.FindModel(tio.Pkg().Path(), tio.Name()); ok { - return s.makeRef(decl, tgt) - } - return s.buildFromType(utitpe.Elem(), tgt.Items()) + return s.buildNamedArray(tio, cmt, utitpe.Elem(), tgt) case *types.Slice: - debugLogf("found slice type: %s.%s", tio.Pkg().Path(), tio.Name()) - - if sfnm, isf := strfmtName(cmt); isf { - if sfnm == "byte" { - tgt.Typed("string", sfnm) - return nil - } - tgt.Items().Typed("string", sfnm) - return nil - } - if decl, ok := s.ctx.FindModel(tio.Pkg().Path(), tio.Name()); ok { - return s.makeRef(decl, tgt) - } - return s.buildFromType(utitpe.Elem(), tgt.Items()) + return s.buildNamedSlice(tio, cmt, utitpe.Elem(), tgt) case *types.Map: debugLogf("found map type: %s.%s", tio.Pkg().Path(), 
tio.Name()) @@ -588,6 +489,123 @@ func (s *schemaBuilder) buildNamedType(titpe *types.Named, tgt swaggerTypable) e } } +func (s *schemaBuilder) buildNamedBasic(tio *types.TypeName, pkg *packages.Package, cmt *ast.CommentGroup, utitpe *types.Basic, tgt swaggerTypable) error { + if unsupportedBuiltinType(utitpe) { + log.Printf("WARNING: skipped unsupported builtin type: %v", utitpe) + return nil + } + + debugLogf("found primitive type: %s.%s", tio.Pkg().Path(), tio.Name()) + + if sfnm, isf := strfmtName(cmt); isf { + tgt.Typed("string", sfnm) + return nil + } + + if enumName, ok := enumName(cmt); ok { + enumValues, enumDesces, _ := s.ctx.FindEnumValues(pkg, enumName) + if len(enumValues) > 0 { + tgt.WithEnum(enumValues...) + enumTypeName := reflect.TypeOf(enumValues[0]).String() + _ = swaggerSchemaForType(enumTypeName, tgt) + } + if len(enumDesces) > 0 { + tgt.WithEnumDescription(strings.Join(enumDesces, "\n")) + } + return nil + } + + if defaultName, ok := defaultName(cmt); ok { + debugLogf("default name: %s", defaultName) + return nil + } + + if typeName, ok := typeName(cmt); ok { + _ = swaggerSchemaForType(typeName, tgt) + return nil + } + + if isAliasParam(tgt) || aliasParam(cmt) { + err := swaggerSchemaForType(utitpe.Name(), tgt) + if err == nil { + return nil + } + } + + if decl, ok := s.ctx.FindModel(tio.Pkg().Path(), tio.Name()); ok { + return s.makeRef(decl, tgt) + } + + return swaggerSchemaForType(utitpe.String(), tgt) +} + +func (s *schemaBuilder) buildNamedStruct(tio *types.TypeName, cmt *ast.CommentGroup, tgt swaggerTypable) error { + debugLogf("found struct: %s.%s", tio.Pkg().Path(), tio.Name()) + + decl, ok := s.ctx.FindModel(tio.Pkg().Path(), tio.Name()) + if !ok { + debugLogf("could not find model in index: %s.%s", tio.Pkg().Path(), tio.Name()) + return nil + } + + o := decl.Obj() + if isStdTime(o) { + tgt.Typed("string", "date-time") + return nil + } + + if sfnm, isf := strfmtName(cmt); isf { + tgt.Typed("string", sfnm) + return nil + } + + if 
typeName, ok := typeName(cmt); ok { + _ = swaggerSchemaForType(typeName, tgt) + return nil + } + + return s.makeRef(decl, tgt) +} + +func (s *schemaBuilder) buildNamedArray(tio *types.TypeName, cmt *ast.CommentGroup, elem types.Type, tgt swaggerTypable) error { + debugLogf("found array type: %s.%s", tio.Pkg().Path(), tio.Name()) + + if sfnm, isf := strfmtName(cmt); isf { + if sfnm == goTypeByte { + tgt.Typed("string", sfnm) + return nil + } + if sfnm == "bsonobjectid" { + tgt.Typed("string", sfnm) + return nil + } + + tgt.Items().Typed("string", sfnm) + return nil + } + if decl, ok := s.ctx.FindModel(tio.Pkg().Path(), tio.Name()); ok { + return s.makeRef(decl, tgt) + } + return s.buildFromType(elem, tgt.Items()) +} + +func (s *schemaBuilder) buildNamedSlice(tio *types.TypeName, cmt *ast.CommentGroup, elem types.Type, tgt swaggerTypable) error { + debugLogf("found slice type: %s.%s", tio.Pkg().Path(), tio.Name()) + + if sfnm, isf := strfmtName(cmt); isf { + if sfnm == goTypeByte { + tgt.Typed("string", sfnm) + return nil + } + tgt.Items().Typed("string", sfnm) + return nil + } + if decl, ok := s.ctx.FindModel(tio.Pkg().Path(), tio.Name()); ok { + return s.makeRef(decl, tgt) + } + return s.buildFromType(elem, tgt.Items()) +} + // buildDeclAlias builds a top-level alias declaration. 
func (s *schemaBuilder) buildDeclAlias(tpe *types.Alias, tgt swaggerTypable) error { if unsupportedBuiltinType(tpe) { @@ -621,7 +639,7 @@ func (s *schemaBuilder) buildDeclAlias(tpe *types.Alias, tgt swaggerTypable) err decl, ok := s.ctx.FindModel(o.Pkg().Path(), o.Name()) if !ok { - return fmt.Errorf("can't find source file for aliased type: %v -> %v", tpe, rhs) + return fmt.Errorf("can't find source file for aliased type: %v -> %v: %w", tpe, rhs, ErrCodeScan) } s.postDecls = append(s.postDecls, decl) // mark the left-hand side as discovered @@ -638,7 +656,7 @@ func (s *schemaBuilder) buildDeclAlias(tpe *types.Alias, tgt swaggerTypable) err ro := rtpe.Obj() rdecl, found := s.ctx.FindModel(ro.Pkg().Path(), ro.Name()) if !found { - return fmt.Errorf("can't find source file for target type of alias: %v -> %v", tpe, rtpe) + return fmt.Errorf("can't find source file for target type of alias: %v -> %v: %w", tpe, rtpe, ErrCodeScan) } return s.makeRef(rdecl, tgt) @@ -662,7 +680,7 @@ func (s *schemaBuilder) buildDeclAlias(tpe *types.Alias, tgt swaggerTypable) err rdecl, found := s.ctx.FindModel(ro.Pkg().Path(), ro.Name()) if !found { - return fmt.Errorf("can't find source file for target type of alias: %v -> %v", tpe, rtpe) + return fmt.Errorf("can't find source file for target type of alias: %v -> %v: %w", tpe, rtpe, ErrCodeScan) } return s.makeRef(rdecl, tgt) @@ -676,88 +694,72 @@ func (s *schemaBuilder) buildAnonymousInterface(it *types.Interface, tgt swagger tgt.Typed("object", "") for fld := range it.ExplicitMethods() { - if !fld.Exported() { - continue - } - sig, isSignature := fld.Type().(*types.Signature) - if !isSignature { - continue - } - if sig.Params().Len() > 0 { - continue - } - if sig.Results() == nil || sig.Results().Len() != 1 { - continue + if err := s.processAnonInterfaceMethod(fld, it, decl, tgt.Schema()); err != nil { + return err } + } - var afld *ast.Field - ans, _ := astutil.PathEnclosingInterval(decl.File, fld.Pos(), fld.Pos()) - // debugLogf("got 
%d nodes (exact: %t)", len(ans), isExact) - for _, an := range ans { - at, valid := an.(*ast.Field) - if !valid { - continue - } + return nil +} - debugLogf("maybe interface field %s: %s(%T)", fld.Name(), fld.Type().String(), fld.Type()) - afld = at - break - } +func (s *schemaBuilder) processAnonInterfaceMethod(fld *types.Func, it *types.Interface, decl *entityDecl, schema *spec.Schema) error { + if !fld.Exported() { + return nil + } + sig, isSignature := fld.Type().(*types.Signature) + if !isSignature { + return nil + } + if sig.Params().Len() > 0 { + return nil + } + if sig.Results() == nil || sig.Results().Len() != 1 { + return nil + } - if afld == nil { - debugLogf("can't find source associated with %s for %s", fld.String(), it.String()) - continue - } + afld := findASTField(decl.File, fld.Pos()) + if afld == nil { + debugLogf("can't find source associated with %s for %s", fld.String(), it.String()) + return nil + } - // if the field is annotated with swagger:ignore, ignore it - if ignored(afld.Doc) { - continue - } + if ignored(afld.Doc) { + return nil + } - name := fld.Name() - if afld.Doc != nil { - for _, cmt := range afld.Doc.List { - for ln := range strings.SplitSeq(cmt.Text, "\n") { - matches := rxName.FindStringSubmatch(ln) - ml := len(matches) - if ml > 1 { - name = matches[ml-1] - } - } - } - } + name := nameOverride(fld.Name(), afld.Doc) - if tgt.Schema().Properties == nil { - tgt.Schema().Properties = make(map[string]spec.Schema) - } - ps := tgt.Schema().Properties[name] - if err := s.buildFromType(sig.Results().At(0).Type(), schemaTypable{&ps, 0}); err != nil { - return err - } - if sfName, isStrfmt := strfmtName(afld.Doc); isStrfmt { - ps.Typed("string", sfName) - ps.Ref = spec.Ref{} - ps.Items = nil - } + if schema.Properties == nil { + schema.Properties = make(map[string]spec.Schema) + } + ps := schema.Properties[name] + if err := s.buildFromType(sig.Results().At(0).Type(), schemaTypable{&ps, 0}); err != nil { + return err + } + if sfName, 
isStrfmt := strfmtName(afld.Doc); isStrfmt { + ps.Typed("string", sfName) + ps.Ref = spec.Ref{} + ps.Items = nil + } - if err := s.createParser(name, tgt.Schema(), &ps, afld).Parse(afld.Doc); err != nil { - return err - } + if err := s.createParser(name, schema, &ps, afld).Parse(afld.Doc); err != nil { + return err + } - if ps.Ref.String() == "" && name != fld.Name() { - ps.AddExtension("x-go-name", fld.Name()) - } + if ps.Ref.String() == "" && name != fld.Name() { + ps.AddExtension("x-go-name", fld.Name()) + } - if s.ctx.app.setXNullableForPointers { - if _, isPointer := fld.Type().(*types.Signature).Results().At(0).Type().(*types.Pointer); isPointer && (ps.Extensions == nil || (ps.Extensions["x-nullable"] == nil && ps.Extensions["x-isnullable"] == nil)) { - ps.AddExtension("x-nullable", true) - } + if s.ctx.app.setXNullableForPointers { + _, isPointer := fld.Type().(*types.Signature).Results().At(0).Type().(*types.Pointer) + noNullableExt := ps.Extensions == nil || + (ps.Extensions["x-nullable"] == nil && ps.Extensions["x-isnullable"] == nil) + if isPointer && noNullableExt { + ps.AddExtension("x-nullable", true) } - - // seen[name] = fld.Name() - tgt.Schema().Properties[name] = ps } + schema.Properties[name] = ps return nil } @@ -783,7 +785,7 @@ func (s *schemaBuilder) buildAlias(tpe *types.Alias, tgt swaggerTypable) error { decl, ok := s.ctx.FindModel(o.Pkg().Path(), o.Name()) if !ok { - return fmt.Errorf("can't find source file for aliased type: %v", tpe) + return fmt.Errorf("can't find source file for aliased type: %v: %w", tpe, ErrCodeScan) } return s.makeRef(decl, tgt) @@ -797,7 +799,7 @@ func (s *schemaBuilder) buildFromMap(titpe *types.Map, tgt swaggerTypable) error sch := tgt.Schema() if sch == nil { - return errors.New("items doesn't support maps") + return fmt.Errorf("items doesn't support maps: %w", ErrCodeScan) } eleProp := schemaTypable{sch, tgt.Level()} @@ -832,68 +834,14 @@ func (s *schemaBuilder) buildFromInterface(decl *entityDecl, it 
*types.Interface // 1. the embedded interface is decorated with an allOf annotation // 2. the embedded interface is an alias for fld := range it.EmbeddedTypes() { - debugLogf("inspecting embedded type in interface: %v", fld) - var ( - fieldHasAllOf bool - err error - ) - if tgt == nil { tgt = &spec.Schema{} } - switch ftpe := fld.(type) { - case *types.Named: - debugLogf("embedded named type (buildInterface): %v", ftpe) - o := ftpe.Obj() - if isAny(o) || isStdError(o) { - // ignore bultin interfaces - continue - } - - if fieldHasAllOf, err = s.buildNamedInterface(ftpe, flist, decl, schema, seen); err != nil { - return err - } - case *types.Interface: - debugLogf("embedded anonymous interface type (buildInterface): %v", ftpe) // e.g. type X interface{ interface{Error() string}} - var aliasedSchema spec.Schema - ps := schemaTypable{schema: &aliasedSchema} - if err = s.buildAnonymousInterface(ftpe, ps, decl); err != nil { - return err - } - - if aliasedSchema.Ref.String() != "" || len(aliasedSchema.Properties) > 0 || len(aliasedSchema.AllOf) > 0 { - schema.AddToAllOf(aliasedSchema) - fieldHasAllOf = true - } - case *types.Alias: - debugLogf("embedded alias (buildInterface): %v -> %v", ftpe, ftpe.Rhs()) - var aliasedSchema spec.Schema - ps := schemaTypable{schema: &aliasedSchema} - if err = s.buildAlias(ftpe, ps); err != nil { - return err - } - - if aliasedSchema.Ref.String() != "" || len(aliasedSchema.Properties) > 0 || len(aliasedSchema.AllOf) > 0 { - schema.AddToAllOf(aliasedSchema) - fieldHasAllOf = true - } - case *types.Union: // e.g. type X interface{ ~uint16 | ~float32 } - log.Printf("WARNING: union type constraints are not supported yet %[1]v (%[1]T). Skipped", ftpe) - case *types.TypeParam: - log.Printf("WARNING: generic type parameters are not supported yet %[1]v (%[1]T). Skipped", ftpe) - case *types.Chan: - log.Printf("WARNING: channels are not supported %[1]v (%[1]T). 
Skipped", ftpe) - case *types.Signature: - log.Printf("WARNING: functions are not supported %[1]v (%[1]T). Skipped", ftpe) - default: - log.Printf( - "WARNING: can't figure out object type for allOf named type (%T): %v", - ftpe, ftpe.Underlying(), - ) + fieldHasAllOf, err := s.processEmbeddedType(fld, flist, decl, schema, seen) + if err != nil { + return err } - - debugLogf("got embedded interface: %v {%T}, fieldHasAllOf: %t", fld, fld, fieldHasAllOf) hasAllOf = hasAllOf || fieldHasAllOf } @@ -908,97 +856,160 @@ func (s *schemaBuilder) buildFromInterface(decl *entityDecl, it *types.Interface tgt.Typed("object", "") for fld := range it.ExplicitMethods() { - if !fld.Exported() { - continue - } - sig, isSignature := fld.Type().(*types.Signature) - if !isSignature { - continue - } - if sig.Params().Len() > 0 { - continue - } - if sig.Results() == nil || sig.Results().Len() != 1 { - continue + if err := s.processInterfaceMethod(fld, it, decl, tgt, seen); err != nil { + return err } + } - var afld *ast.Field - ans, _ := astutil.PathEnclosingInterval(decl.File, fld.Pos(), fld.Pos()) - // debugLogf("got %d nodes (exact: %t)", len(ans), isExact) - for _, an := range ans { - at, valid := an.(*ast.Field) - if !valid { - continue - } + if tgt == nil { + return nil + } + if hasAllOf && len(tgt.Properties) > 0 { + schema.AllOf = append(schema.AllOf, *tgt) + } - debugLogf("maybe interface field %s: %s(%T)", fld.Name(), fld.Type().String(), fld.Type()) - afld = at - break + for k := range tgt.Properties { + if _, ok := seen[k]; !ok { + delete(tgt.Properties, k) } + } - if afld == nil { - debugLogf("can't find source associated with %s for %s", fld.String(), it.String()) - continue - } + return nil +} - // if the field is annotated with swagger:ignore, ignore it - if ignored(afld.Doc) { - continue - } +func (s *schemaBuilder) processEmbeddedType(fld types.Type, flist []*ast.Field, decl *entityDecl, schema *spec.Schema, seen map[string]string) (fieldHasAllOf bool, err error) { + 
debugLogf("inspecting embedded type in interface: %v", fld) - name := fld.Name() - if afld.Doc != nil { - for _, cmt := range afld.Doc.List { - for ln := range strings.SplitSeq(cmt.Text, "\n") { - matches := rxName.FindStringSubmatch(ln) - ml := len(matches) - if ml > 1 { - name = matches[ml-1] - } - } - } + switch ftpe := fld.(type) { + case *types.Named: + debugLogf("embedded named type (buildInterface): %v", ftpe) + o := ftpe.Obj() + if isAny(o) || isStdError(o) { + return false, nil } - ps := tgt.Properties[name] - if err := s.buildFromType(sig.Results().At(0).Type(), schemaTypable{&ps, 0}); err != nil { - return err + return s.buildNamedInterface(ftpe, flist, decl, schema, seen) + case *types.Interface: + debugLogf("embedded anonymous interface type (buildInterface): %v", ftpe) + var aliasedSchema spec.Schema + ps := schemaTypable{schema: &aliasedSchema} + if err = s.buildAnonymousInterface(ftpe, ps, decl); err != nil { + return false, err } - if sfName, isStrfmt := strfmtName(afld.Doc); isStrfmt { - ps.Typed("string", sfName) - ps.Ref = spec.Ref{} - ps.Items = nil + if aliasedSchema.Ref.String() != "" || len(aliasedSchema.Properties) > 0 || len(aliasedSchema.AllOf) > 0 { + fieldHasAllOf = true + schema.AddToAllOf(aliasedSchema) } - - if err := s.createParser(name, tgt, &ps, afld).Parse(afld.Doc); err != nil { - return err + case *types.Alias: + debugLogf("embedded alias (buildInterface): %v -> %v", ftpe, ftpe.Rhs()) + var aliasedSchema spec.Schema + ps := schemaTypable{schema: &aliasedSchema} + if err = s.buildAlias(ftpe, ps); err != nil { + return false, err } + if aliasedSchema.Ref.String() != "" || len(aliasedSchema.Properties) > 0 || len(aliasedSchema.AllOf) > 0 { + fieldHasAllOf = true + schema.AddToAllOf(aliasedSchema) + } + case *types.Union: + log.Printf("WARNING: union type constraints are not supported yet %[1]v (%[1]T). Skipped", ftpe) + case *types.TypeParam: + log.Printf("WARNING: generic type parameters are not supported yet %[1]v (%[1]T). 
Skipped", ftpe) + case *types.Chan: + log.Printf("WARNING: channels are not supported %[1]v (%[1]T). Skipped", ftpe) + case *types.Signature: + log.Printf("WARNING: functions are not supported %[1]v (%[1]T). Skipped", ftpe) + default: + log.Printf( + "WARNING: can't figure out object type for allOf named type (%T): %v", + ftpe, ftpe.Underlying(), + ) + } - if ps.Ref.String() == "" && name != fld.Name() { - ps.AddExtension("x-go-name", fld.Name()) + debugLogf("got embedded interface: %v {%T}, fieldHasAllOf: %t", fld, fld, fieldHasAllOf) + return fieldHasAllOf, nil +} + +func findASTField(file *ast.File, pos token.Pos) *ast.Field { + ans, _ := astutil.PathEnclosingInterval(file, pos, pos) + for _, an := range ans { + if at, valid := an.(*ast.Field); valid { + return at } + } + return nil +} - if s.ctx.app.setXNullableForPointers { - if _, isPointer := fld.Type().(*types.Signature).Results().At(0).Type().(*types.Pointer); isPointer && (ps.Extensions == nil || (ps.Extensions["x-nullable"] == nil && ps.Extensions["x-isnullable"] == nil)) { - ps.AddExtension("x-nullable", true) +func nameOverride(defaultName string, doc *ast.CommentGroup) string { + name := defaultName + if doc != nil { + for _, cmt := range doc.List { + for ln := range strings.SplitSeq(cmt.Text, "\n") { + matches := rxName.FindStringSubmatch(ln) + if ml := len(matches); ml > 1 { + name = matches[ml-1] + } } } + } + return name +} - seen[name] = fld.Name() - tgt.Properties[name] = ps +func (s *schemaBuilder) processInterfaceMethod(fld *types.Func, it *types.Interface, decl *entityDecl, tgt *spec.Schema, seen map[string]string) error { + if !fld.Exported() { + return nil + } + sig, isSignature := fld.Type().(*types.Signature) + if !isSignature { + return nil + } + if sig.Params().Len() > 0 { + return nil + } + if sig.Results() == nil || sig.Results().Len() != 1 { + return nil } - if tgt == nil { + afld := findASTField(decl.File, fld.Pos()) + if afld == nil { + debugLogf("can't find source associated with 
%s for %s", fld.String(), it.String()) return nil } - if hasAllOf && len(tgt.Properties) > 0 { - schema.AllOf = append(schema.AllOf, *tgt) + + // if the field is annotated with swagger:ignore, ignore it + if ignored(afld.Doc) { + return nil } - for k := range tgt.Properties { - if _, ok := seen[k]; !ok { - delete(tgt.Properties, k) + name := nameOverride(fld.Name(), afld.Doc) + ps := tgt.Properties[name] + if err := s.buildFromType(sig.Results().At(0).Type(), schemaTypable{&ps, 0}); err != nil { + return err + } + if sfName, isStrfmt := strfmtName(afld.Doc); isStrfmt { + ps.Typed("string", sfName) + ps.Ref = spec.Ref{} + ps.Items = nil + } + + if err := s.createParser(name, tgt, &ps, afld).Parse(afld.Doc); err != nil { + return err + } + + if ps.Ref.String() == "" && name != fld.Name() { + ps.AddExtension("x-go-name", fld.Name()) + } + + if s.ctx.app.setXNullableForPointers { + _, isPointer := fld.Type().(*types.Signature).Results().At(0).Type().(*types.Pointer) + noNullableExt := ps.Extensions == nil || + (ps.Extensions["x-nullable"] == nil && ps.Extensions["x-isnullable"] == nil) + if isPointer && noNullableExt { + ps.AddExtension("x-nullable", true) } } + seen[name] = fld.Name() + tgt.Properties[name] = ps return nil } @@ -1074,8 +1085,23 @@ func (s *schemaBuilder) buildNamedInterface(ftpe *types.Named, flist []*ast.Fiel return hasAllOf, nil } +func extractAllOfClass(doc *ast.CommentGroup, schema *spec.Schema) { + if doc == nil { + return + } + for _, cmt := range doc.List { + for ln := range strings.SplitSeq(cmt.Text, "\n") { + matches := rxAllOf.FindStringSubmatch(ln) + if ml := len(matches); ml > 1 { + if mv := matches[ml-1]; mv != "" { + schema.AddExtension("x-class", mv) + } + } + } + } +} + func (s *schemaBuilder) buildFromStruct(decl *entityDecl, st *types.Struct, schema *spec.Schema, seen map[string]string) error { - s.ctx.FindComments(decl.Pkg, decl.Obj().Name()) cmt, hasComments := s.ctx.FindComments(decl.Pkg, decl.Obj().Name()) if !hasComments { cmt = 
new(ast.CommentGroup) @@ -1085,14 +1111,53 @@ func (s *schemaBuilder) buildFromStruct(decl *entityDecl, st *types.Struct, sche _ = swaggerSchemaForType(name, schemaTypable{schema: schema}) return nil } - // First check for all of schemas - var tgt *spec.Schema - hasAllOf := false + // First pass: scan anonymous/embedded fields for allOf composition. + // Returns the target schema for properties (may differ from schema when allOf is used). + tgt, hasAllOf, err := s.scanEmbeddedFields(decl, st, schema, seen) + if err != nil { + return err + } + if tgt == nil { + if schema != nil { + tgt = schema + } else { + tgt = &spec.Schema{} + } + } + if tgt.Properties == nil { + tgt.Properties = make(map[string]spec.Schema) + } + tgt.Typed("object", "") + + // Second pass: build properties from non-embedded exported fields. + if err := s.buildStructFields(decl, st, tgt, seen); err != nil { + return err + } + + if tgt == nil { + return nil + } + if hasAllOf && len(tgt.Properties) > 0 { + schema.AllOf = append(schema.AllOf, *tgt) + } + for k := range tgt.Properties { + if _, ok := seen[k]; !ok { + delete(tgt.Properties, k) + } + } + return nil +} + +// scanEmbeddedFields iterates over anonymous struct fields to detect allOf composition. +// It returns: +// - tgt: the schema that should receive properties (nil if no embedded fields were processed, +// schema itself for plain embeds, or a new schema when allOf is detected) +// - hasAllOf: whether any allOf member was found +func (s *schemaBuilder) scanEmbeddedFields(decl *entityDecl, st *types.Struct, schema *spec.Schema, seen map[string]string) (tgt *spec.Schema, hasAllOf bool, err error) { for i := range st.NumFields() { fld := st.Field(i) if !fld.Anonymous() { - // e.g. 
struct { _ struct{} } debugLogf("skipping field %q for allOf scan because not anonymous", fld.Name()) continue } @@ -1102,33 +1167,19 @@ func (s *schemaBuilder) buildFromStruct(decl *entityDecl, st *types.Struct, sche "maybe allof field(%t) %s: %s (%T) [%q](anon: %t, embedded: %t)", fld.IsField(), fld.Name(), fld.Type().String(), fld.Type(), tg, fld.Anonymous(), fld.Embedded(), ) - var afld *ast.Field - ans, _ := astutil.PathEnclosingInterval(decl.File, fld.Pos(), fld.Pos()) - // debugLogf("got %d nodes (exact: %t)", len(ans), isExact) - for _, an := range ans { - at, valid := an.(*ast.Field) - if !valid { - continue - } - - debugLogf("maybe allof field %s: %s(%T) [%q]", fld.Name(), fld.Type().String(), fld.Type(), tg) - afld = at - break - } - + afld := findASTField(decl.File, fld.Pos()) if afld == nil { debugLogf("can't find source associated with %s for %s", fld.String(), st.String()) continue } - // if the field is annotated with swagger:ignore, ignore it if ignored(afld.Doc) { continue } _, ignore, _, _, err := parseJSONTag(afld) if err != nil { - return err + return nil, false, err } if ignore { continue @@ -1137,11 +1188,12 @@ func (s *schemaBuilder) buildFromStruct(decl *entityDecl, st *types.Struct, sche _, isAliased := fld.Type().(*types.Alias) if !allOfMember(afld.Doc) && !isAliased { + // Plain embed: merge fields into the main schema if tgt == nil { tgt = schema } if err := s.buildEmbedded(fld.Type(), tgt, seen); err != nil { - return err + return nil, false, err } continue } @@ -1150,151 +1202,99 @@ func (s *schemaBuilder) buildFromStruct(decl *entityDecl, st *types.Struct, sche debugLogf("alias member in struct: %v", fld) } - // if this created an allOf property then we have to rejig the schema var - // because all the fields collected that aren't from embedded structs should go in - // their own proper schema - // first process embedded structs in order of embedding + // allOf member: fields go into a separate schema, embedded struct becomes an 
allOf entry hasAllOf = true if tgt == nil { tgt = &spec.Schema{} } var newSch spec.Schema - // when the embedded struct is annotated with swagger:allOf it will be used as allOf property - // otherwise the fields will just be included as normal properties if err := s.buildAllOf(fld.Type(), &newSch); err != nil { - return err - } - - if afld.Doc != nil { - for _, cmt := range afld.Doc.List { - for ln := range strings.SplitSeq(cmt.Text, "\n") { - matches := rxAllOf.FindStringSubmatch(ln) - ml := len(matches) - if ml > 1 { - mv := matches[ml-1] - if mv != "" { - schema.AddExtension("x-class", mv) - } - } - } - } + return nil, false, err } + extractAllOfClass(afld.Doc, schema) schema.AllOf = append(schema.AllOf, newSch) } - if tgt == nil { - if schema != nil { - tgt = schema - } else { - tgt = &spec.Schema{} - } - } - // We can finally build the actual schema for the struct - if tgt.Properties == nil { - tgt.Properties = make(map[string]spec.Schema) - } - tgt.Typed("object", "") - - for i := range st.NumFields() { - fld := st.Field(i) - tg := st.Tag(i) - - if fld.Embedded() { - continue - } + return tgt, hasAllOf, nil +} - if !fld.Exported() { - debugLogf("skipping field %s because it's not exported", fld.Name()) - continue +func (s *schemaBuilder) buildStructFields(decl *entityDecl, st *types.Struct, tgt *spec.Schema, seen map[string]string) error { + for fld := range st.Fields() { + if err := s.processStructField(fld, decl, tgt, seen); err != nil { + return err } + } + return nil +} - var afld *ast.Field - ans, _ := astutil.PathEnclosingInterval(decl.File, fld.Pos(), fld.Pos()) - for _, an := range ans { - at, valid := an.(*ast.Field) - if !valid { - continue - } +func (s *schemaBuilder) processStructField(fld *types.Var, decl *entityDecl, tgt *spec.Schema, seen map[string]string) error { + if fld.Embedded() || !fld.Exported() { + return nil + } - debugLogf("field %s: %s(%T) [%q] ==> %s", fld.Name(), fld.Type().String(), fld.Type(), tg, at.Doc.Text()) - afld = at - 
break - } + afld := findASTField(decl.File, fld.Pos()) + if afld == nil { + debugLogf("can't find source associated with %s", fld.String()) + return nil + } - if afld == nil { - debugLogf("can't find source associated with %s", fld.String()) - continue - } + if ignored(afld.Doc) { + return nil + } - // if the field is annotated with swagger:ignore, ignore it - if ignored(afld.Doc) { - continue - } + name, ignore, isString, omitEmpty, err := parseJSONTag(afld) + if err != nil { + return err + } - name, ignore, isString, omitEmpty, err := parseJSONTag(afld) - if err != nil { - return err - } - if ignore { - for seenTagName, seenFieldName := range seen { - if seenFieldName == fld.Name() { - delete(tgt.Properties, seenTagName) - break - } + if ignore { + for seenTagName, seenFieldName := range seen { + if seenFieldName == fld.Name() { + delete(tgt.Properties, seenTagName) + break } - continue - } - - ps := tgt.Properties[name] - if err = s.buildFromType(fld.Type(), schemaTypable{&ps, 0}); err != nil { - return err - } - if isString { - ps.Typed("string", ps.Format) - ps.Ref = spec.Ref{} - ps.Items = nil - } - if sfName, isStrfmt := strfmtName(afld.Doc); isStrfmt { - ps.Typed("string", sfName) - ps.Ref = spec.Ref{} - ps.Items = nil - } - - if err = s.createParser(name, tgt, &ps, afld).Parse(afld.Doc); err != nil { - return err - } - - if ps.Ref.String() == "" && name != fld.Name() { - addExtension(&ps.VendorExtensible, "x-go-name", fld.Name()) } + return nil + } - if s.ctx.app.setXNullableForPointers { - if _, isPointer := fld.Type().(*types.Pointer); isPointer && !omitEmpty && - (ps.Extensions == nil || (ps.Extensions["x-nullable"] == nil && ps.Extensions["x-isnullable"] == nil)) { - ps.AddExtension("x-nullable", true) - } - } + ps := tgt.Properties[name] + if err = s.buildFromType(fld.Type(), schemaTypable{&ps, 0}); err != nil { + return err + } + if isString { + ps.Typed("string", ps.Format) + ps.Ref = spec.Ref{} + ps.Items = nil + } - // we have 2 cases: - // 1. 
field with different name override tag - // 2. field with different name removes tag - // so we need to save both tag&name - seen[name] = fld.Name() - tgt.Properties[name] = ps + if sfName, isStrfmt := strfmtName(afld.Doc); isStrfmt { + ps.Typed("string", sfName) + ps.Ref = spec.Ref{} + ps.Items = nil } - if tgt == nil { - return nil + if err = s.createParser(name, tgt, &ps, afld).Parse(afld.Doc); err != nil { + return err } - if hasAllOf && len(tgt.Properties) > 0 { - schema.AllOf = append(schema.AllOf, *tgt) + + if ps.Ref.String() == "" && name != fld.Name() { + addExtension(&ps.VendorExtensible, "x-go-name", fld.Name()) } - for k := range tgt.Properties { - if _, ok := seen[k]; !ok { - delete(tgt.Properties, k) + + if s.ctx.app.setXNullableForPointers { + if _, isPointer := fld.Type().(*types.Pointer); isPointer && !omitEmpty && + (ps.Extensions == nil || (ps.Extensions["x-nullable"] == nil && ps.Extensions["x-isnullable"] == nil)) { + ps.AddExtension("x-nullable", true) } } + + // we have 2 cases: + // 1. field with different name override tag + // 2. 
field with different name removes tag + // so we need to save both tag&name + seen[name] = fld.Name() + tgt.Properties[name] = ps return nil } @@ -1321,7 +1321,7 @@ func (s *schemaBuilder) buildAllOf(tpe types.Type, schema *spec.Schema) error { return nil default: log.Printf("WARNING: missing allOf parser for a %T, skipping field", ftpe) - return fmt.Errorf("unable to resolve allOf member for: %v", ftpe) + return fmt.Errorf("unable to resolve allOf member for: %v: %w", ftpe, ErrCodeScan) } } @@ -1330,7 +1330,7 @@ func (s *schemaBuilder) buildNamedAllOf(ftpe *types.Named, schema *spec.Schema) case *types.Struct: decl, found := s.ctx.FindModel(ftpe.Obj().Pkg().Path(), ftpe.Obj().Name()) if !found { - return fmt.Errorf("can't find source file for struct: %s", ftpe.String()) + return fmt.Errorf("can't find source file for struct: %s: %w", ftpe.String(), ErrCodeScan) } if isStdTime(ftpe.Obj()) { @@ -1351,7 +1351,7 @@ func (s *schemaBuilder) buildNamedAllOf(ftpe *types.Named, schema *spec.Schema) case *types.Interface: decl, found := s.ctx.FindModel(ftpe.Obj().Pkg().Path(), ftpe.Obj().Name()) if !found { - return fmt.Errorf("can't find source file for interface: %s", ftpe.String()) + return fmt.Errorf("can't find source file for interface: %s: %w", ftpe.String(), ErrCodeScan) } if sfnm, isf := strfmtName(decl.Comments); isf { @@ -1378,8 +1378,8 @@ func (s *schemaBuilder) buildNamedAllOf(ftpe *types.Named, schema *spec.Schema) "WARNING: can't figure out object type for allOf named type (%T): %v", ftpe, utpe, ) - return fmt.Errorf("unable to locate source file for allOf (%T): %v", - ftpe, utpe, + return fmt.Errorf("unable to locate source file for allOf (%T): %v: %w", + ftpe, utpe, ErrCodeScan, ) } } @@ -1426,7 +1426,7 @@ func (s *schemaBuilder) buildNamedEmbedded(ftpe *types.Named, schema *spec.Schem case *types.Struct: decl, found := s.ctx.FindModel(ftpe.Obj().Pkg().Path(), ftpe.Obj().Name()) if !found { - return fmt.Errorf("can't find source file for struct: %s", 
ftpe.String()) + return fmt.Errorf("can't find source file for struct: %s: %w", ftpe.String(), ErrCodeScan) } return s.buildFromStruct(decl, utpe, schema, seen) @@ -1447,7 +1447,7 @@ func (s *schemaBuilder) buildNamedEmbedded(ftpe *types.Named, schema *spec.Schem decl, found := s.ctx.FindModel(o.Pkg().Path(), o.Name()) if !found { - return fmt.Errorf("can't find source file for struct: %s", ftpe.String()) + return fmt.Errorf("can't find source file for struct: %s: %w", ftpe.String(), ErrCodeScan) } return s.buildFromInterface(decl, utpe, schema, seen) case *types.Union: // e.g. type X interface{ ~uint16 | ~float32 } @@ -1544,8 +1544,10 @@ func (s *schemaBuilder) createParser(nm string, schema, ps *spec.Schema, fld *as newSingleLineTagParser(fmt.Sprintf("items%dMaxItems", level), &setMaxItems{schemaValidations{items}, rxf(rxMaxItemsFmt, itemsPrefix)}), newSingleLineTagParser(fmt.Sprintf("items%dUnique", level), &setUnique{schemaValidations{items}, rxf(rxUniqueFmt, itemsPrefix)}), newSingleLineTagParser(fmt.Sprintf("items%dEnum", level), &setEnum{schemaValidations{items}, rxf(rxEnumFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dDefault", level), &setDefault{&spec.SimpleSchema{Type: string(schemeType)}, schemaValidations{items}, rxf(rxDefaultFmt, itemsPrefix)}), - newSingleLineTagParser(fmt.Sprintf("items%dExample", level), &setExample{&spec.SimpleSchema{Type: string(schemeType)}, schemaValidations{items}, rxf(rxExampleFmt, itemsPrefix)}), + newSingleLineTagParser(fmt.Sprintf("items%dDefault", level), + &setDefault{&spec.SimpleSchema{Type: string(schemeType)}, schemaValidations{items}, rxf(rxDefaultFmt, itemsPrefix)}), + newSingleLineTagParser(fmt.Sprintf("items%dExample", level), + &setExample{&spec.SimpleSchema{Type: string(schemeType)}, schemaValidations{items}, rxf(rxExampleFmt, itemsPrefix)}), } } @@ -1580,7 +1582,7 @@ func (s *schemaBuilder) createParser(nm string, schema, ps *spec.Schema, fld *as } return otherTaggers, nil default: - return 
nil, fmt.Errorf("unknown field type element for %q", nm) + return nil, fmt.Errorf("unknown field type element for %q: %w", nm, ErrCodeScan) } } @@ -1612,7 +1614,7 @@ func schemaVendorExtensibleSetter(meta *spec.Schema) func(json.RawMessage) error } for k := range jsonData { if !rxAllowedExtensions.MatchString(k) { - return fmt.Errorf("invalid schema extension name, should start from `x-`: %s", k) + return fmt.Errorf("invalid schema extension name, should start from `x-`: %s: %w", k, ErrCodeScan) } } meta.Extensions = jsonData @@ -1678,9 +1680,9 @@ func parseJSONTag(field *ast.Field) (name string, ignore, isString, omitEmpty bo func isFieldStringable(tpe ast.Expr) bool { if ident, ok := tpe.(*ast.Ident); ok { switch ident.Name { - case "int", "int8", "int16", "int32", "int64", + case goTypeInt, "int8", goTypeInt16, goTypeInt32, goTypeInt64, "uint", "uint8", "uint16", "uint32", "uint64", - "float64", "string", "bool": + goTypeFloat64, typeString, typeBool: return true } } else if starExpr, ok := tpe.(*ast.StarExpr); ok { diff --git a/schema_go118_test.go b/schema_go118_test.go index d49edab..4c59f17 100644 --- a/schema_go118_test.go +++ b/schema_go118_test.go @@ -12,9 +12,9 @@ import ( "github.com/go-openapi/spec" ) -var go118ClassificationCtx *scanCtx +var go118ClassificationCtx *scanCtx //nolint:gochecknoglobals // test package cache shared across test functions -func loadGo118ClassificationPkgsCtx(t *testing.T, extra ...string) *scanCtx { +func loadGo118ClassificationPkgsCtx(t *testing.T) *scanCtx { t.Helper() if go118ClassificationCtx != nil { @@ -22,9 +22,9 @@ func loadGo118ClassificationPkgsCtx(t *testing.T, extra ...string) *scanCtx { } sctx, err := newScanCtx(&Options{ - Packages: append([]string{ + Packages: []string{ "./goparsing/go118", - }, extra...), + }, WorkDir: "fixtures", }) require.NoError(t, err) diff --git a/schema_test.go b/schema_test.go index b85059c..26f19ca 100644 --- a/schema_test.go +++ b/schema_test.go @@ -24,6 +24,28 @@ const ( 
fixturesModule = "github.com/go-openapi/codescan/fixtures" ) +func assertHasExtension(t *testing.T, sch spec.Schema, ext string) { + t.Helper() + pkg, hasExt := sch.Extensions.GetString(ext) + assert.TrueT(t, hasExt) + assert.NotEmpty(t, pkg) +} + +func assertHasGoPackageExt(t *testing.T, sch spec.Schema) { + t.Helper() + assertHasExtension(t, sch, "x-go-package") +} + +func assertHasTitle(t *testing.T, sch spec.Schema) { + t.Helper() + assert.NotEmpty(t, sch.Title) +} + +func assertHasNoTitle(t *testing.T, sch spec.Schema) { + t.Helper() + assert.Empty(t, sch.Title) +} + func TestSchemaBuilder_Struct_Tag(t *testing.T) { sctx := loadPetstorePkgsCtx(t) var td *entityDecl @@ -604,45 +626,20 @@ func TestOverridingOneIgnore(t *testing.T) { assert.Len(t, schema.Properties, 2) } -func TestParseSliceFields(t *testing.T) { - sctx := loadClassificationPkgsCtx(t) - decl := getClassificationModel(sctx, "SliceAndDice") - require.NotNil(t, decl) - prs := &schemaBuilder{ - ctx: sctx, - decl: decl, - } - models := make(map[string]spec.Schema) - require.NoError(t, prs.Build(models)) - - schema := models["SliceAndDice"] - - assertArrayProperty(t, &schema, "integer", "ids", "int64", "IDs") - assertArrayProperty(t, &schema, "string", "names", "", "Names") - assertArrayProperty(t, &schema, "string", "uuids", "uuid", "UUIDs") - assertArrayProperty(t, &schema, "object", "embs", "", "Embs") - eSchema := schema.Properties["embs"].Items.Schema - assertArrayProperty(t, eSchema, "integer", "cid", "int64", "CID") - assertArrayProperty(t, eSchema, "string", "baz", "", "Baz") - - assertArrayRef(t, &schema, "tops", "Tops", "#/definitions/Something") - assertArrayRef(t, &schema, "notSels", "NotSels", "#/definitions/NotSelected") - - assertArrayProperty(t, &schema, "integer", "ptrIds", "int64", "PtrIDs") - assertArrayProperty(t, &schema, "string", "ptrNames", "", "PtrNames") - assertArrayProperty(t, &schema, "string", "ptrUuids", "uuid", "PtrUUIDs") - assertArrayProperty(t, &schema, "object", 
"ptrEmbs", "", "PtrEmbs") - eSchema = schema.Properties["ptrEmbs"].Items.Schema - assertArrayProperty(t, eSchema, "integer", "ptrCid", "int64", "PtrCID") - assertArrayProperty(t, eSchema, "string", "ptrBaz", "", "PtrBaz") - - assertArrayRef(t, &schema, "ptrTops", "PtrTops", "#/definitions/Something") - assertArrayRef(t, &schema, "ptrNotSels", "PtrNotSels", "#/definitions/NotSelected") +type collectionAssertions struct { + assertProperty func(t *testing.T, schema *spec.Schema, typeName, jsonName, format, goName string) + assertRef func(t *testing.T, schema *spec.Schema, jsonName, goName, fragment string) + nestedSchema func(prop spec.Schema) *spec.Schema } -func TestParseMapFields(t *testing.T) { +func testParseCollectionFields( + t *testing.T, + modelName string, + ca collectionAssertions, +) { + t.Helper() sctx := loadClassificationPkgsCtx(t) - decl := getClassificationModel(sctx, "MapTastic") + decl := getClassificationModel(sctx, modelName) require.NotNil(t, decl) prs := &schemaBuilder{ ctx: sctx, @@ -651,29 +648,45 @@ func TestParseMapFields(t *testing.T) { models := make(map[string]spec.Schema) require.NoError(t, prs.Build(models)) - schema := models["MapTastic"] - - assertMapProperty(t, &schema, "integer", "ids", "int64", "IDs") - assertMapProperty(t, &schema, "string", "names", "", "Names") - assertMapProperty(t, &schema, "string", "uuids", "uuid", "UUIDs") - assertMapProperty(t, &schema, "object", "embs", "", "Embs") - eSchema := schema.Properties["embs"].AdditionalProperties.Schema - assertMapProperty(t, eSchema, "integer", "cid", "int64", "CID") - assertMapProperty(t, eSchema, "string", "baz", "", "Baz") - - assertMapRef(t, &schema, "tops", "Tops", "#/definitions/Something") - assertMapRef(t, &schema, "notSels", "NotSels", "#/definitions/NotSelected") + schema := models[modelName] + + ca.assertProperty(t, &schema, "integer", "ids", "int64", "IDs") + ca.assertProperty(t, &schema, "string", "names", "", "Names") + ca.assertProperty(t, &schema, "string", 
"uuids", "uuid", "UUIDs") + ca.assertProperty(t, &schema, "object", "embs", "", "Embs") + eSchema := ca.nestedSchema(schema.Properties["embs"]) + ca.assertProperty(t, eSchema, "integer", "cid", "int64", "CID") + ca.assertProperty(t, eSchema, "string", "baz", "", "Baz") + + ca.assertRef(t, &schema, "tops", "Tops", "#/definitions/Something") + ca.assertRef(t, &schema, "notSels", "NotSels", "#/definitions/NotSelected") + + ca.assertProperty(t, &schema, "integer", "ptrIds", "int64", "PtrIDs") + ca.assertProperty(t, &schema, "string", "ptrNames", "", "PtrNames") + ca.assertProperty(t, &schema, "string", "ptrUuids", "uuid", "PtrUUIDs") + ca.assertProperty(t, &schema, "object", "ptrEmbs", "", "PtrEmbs") + eSchema = ca.nestedSchema(schema.Properties["ptrEmbs"]) + ca.assertProperty(t, eSchema, "integer", "ptrCid", "int64", "PtrCID") + ca.assertProperty(t, eSchema, "string", "ptrBaz", "", "PtrBaz") + + ca.assertRef(t, &schema, "ptrTops", "PtrTops", "#/definitions/Something") + ca.assertRef(t, &schema, "ptrNotSels", "PtrNotSels", "#/definitions/NotSelected") +} - assertMapProperty(t, &schema, "integer", "ptrIds", "int64", "PtrIDs") - assertMapProperty(t, &schema, "string", "ptrNames", "", "PtrNames") - assertMapProperty(t, &schema, "string", "ptrUuids", "uuid", "PtrUUIDs") - assertMapProperty(t, &schema, "object", "ptrEmbs", "", "PtrEmbs") - eSchema = schema.Properties["ptrEmbs"].AdditionalProperties.Schema - assertMapProperty(t, eSchema, "integer", "ptrCid", "int64", "PtrCID") - assertMapProperty(t, eSchema, "string", "ptrBaz", "", "PtrBaz") +func TestParseSliceFields(t *testing.T) { + testParseCollectionFields(t, "SliceAndDice", collectionAssertions{ + assertProperty: assertArrayProperty, + assertRef: assertArrayRef, + nestedSchema: func(prop spec.Schema) *spec.Schema { return prop.Items.Schema }, + }) +} - assertMapRef(t, &schema, "ptrTops", "PtrTops", "#/definitions/Something") - assertMapRef(t, &schema, "ptrNotSels", "PtrNotSels", "#/definitions/NotSelected") +func 
TestParseMapFields(t *testing.T) { + testParseCollectionFields(t, "MapTastic", collectionAssertions{ + assertProperty: assertMapProperty, + assertRef: assertMapRef, + nestedSchema: func(prop spec.Schema) *spec.Schema { return prop.AdditionalProperties.Schema }, + }) } func TestInterfaceField(t *testing.T) { @@ -798,10 +811,10 @@ func TestAliasedModels(t *testing.T) { } if assert.Empty(t, names) { // single value types - assertDefinition(t, defs, "SomeStringType", "string", "", "") - assertDefinition(t, defs, "SomeIntType", "integer", "int64", "") - assertDefinition(t, defs, "SomeTimeType", "string", "date-time", "") - assertDefinition(t, defs, "SomeTimedType", "string", "date-time", "") + assertDefinition(t, defs, "SomeStringType", "string", "") + assertDefinition(t, defs, "SomeIntType", "integer", "int64") + assertDefinition(t, defs, "SomeTimeType", "string", "date-time") + assertDefinition(t, defs, "SomeTimedType", "string", "date-time") assertRefDefinition(t, defs, "SomePettedType", "#/definitions/pet", "") assertRefDefinition(t, defs, "SomethingType", "#/definitions/Something", "") @@ -981,31 +994,12 @@ func TestAliasedSchemas(t *testing.T) { _, _ = os.Stdout.Write(yml) } - shouldHaveExt := func(t *testing.T, sch spec.Schema, ext string) { - t.Helper() - pkg, hasExt := sch.Extensions.GetString(ext) - assert.TrueT(t, hasExt) - assert.NotEmpty(t, pkg) - } - shouldHaveGoPackageExt := func(t *testing.T, sch spec.Schema) { - t.Helper() - shouldHaveExt(t, sch, "x-go-package") - } - shouldHaveTitle := func(t *testing.T, sch spec.Schema) { - t.Helper() - assert.NotEmpty(t, sch.Title) - } - shouldNotHaveTitle := func(t *testing.T, sch spec.Schema) { - t.Helper() - assert.Empty(t, sch.Title) - } - t.Run("type aliased to any should yield an empty schema", func(t *testing.T) { anything, ok := sp.Definitions["Anything"] require.TrueT(t, ok) - shouldHaveGoPackageExt(t, anything) - shouldHaveTitle(t, anything) + assertHasGoPackageExt(t, anything) + assertHasTitle(t, anything) 
// after stripping extension and title, should be empty anything.VendorExtensible = spec.VendorExtensible{} @@ -1017,8 +1011,8 @@ func TestAliasedSchemas(t *testing.T) { empty, ok := sp.Definitions["Empty"] require.TrueT(t, ok) - shouldHaveGoPackageExt(t, empty) - shouldHaveTitle(t, empty) + assertHasGoPackageExt(t, empty) + assertHasTitle(t, empty) // after stripping extension and title, should be empty empty.VendorExtensible = spec.VendorExtensible{} @@ -1029,67 +1023,15 @@ func TestAliasedSchemas(t *testing.T) { }) t.Run("struct fields defined as any or interface{} should yield properties with an empty schema", func(t *testing.T) { - extended, ok := sp.Definitions["ExtendedID"] - require.TrueT(t, ok) - - t.Run("struct with an embedded alias should render as allOf", func(t *testing.T) { - require.Len(t, extended.AllOf, 2) - shouldHaveTitle(t, extended) - - foundAliased := false - foundProps := false - for idx, member := range extended.AllOf { - isProps := len(member.Properties) > 0 - isAlias := member.Ref.String() != "" - - switch { - case isProps: - props := member - t.Run("with property of type any", func(t *testing.T) { - evenMore, ok := props.Properties["EvenMore"] - require.TrueT(t, ok) - assert.Equal(t, spec.Schema{}, evenMore) - }) - - t.Run("with property of type interface{}", func(t *testing.T) { - evenMore, ok := props.Properties["StillMore"] - require.TrueT(t, ok) - assert.Equal(t, spec.Schema{}, evenMore) - }) - - t.Run("non-aliased properties remain unaffected", func(t *testing.T) { - more, ok := props.Properties["more"] - require.TrueT(t, ok) - - shouldHaveExt(t, more, "x-go-name") // because we have a struct tag - shouldNotHaveTitle(t, more) - - // after stripping extension and title, should be empty - more.VendorExtensible = spec.VendorExtensible{} - - strSchema := &spec.Schema{} - strSchema = strSchema.Typed("string", "") - assert.Equal(t, *strSchema, more) - }) - foundProps = true - case isAlias: - assertIsRef(t, &member, "#/definitions/Empty") 
- foundAliased = true - default: - assert.Failf(t, "embedded members in struct are not as expected", "unexpected member in allOf: %d", idx) - } - } - require.TrueT(t, foundProps) - require.TrueT(t, foundAliased) - }) + testAliasedExtendedIDAllOf(t, sp) }) t.Run("aliased primitive types remain unaffected", func(t *testing.T) { uuid, ok := sp.Definitions["UUID"] require.TrueT(t, ok) - shouldHaveGoPackageExt(t, uuid) - shouldHaveTitle(t, uuid) + assertHasGoPackageExt(t, uuid) + assertHasTitle(t, uuid) // after strip extension, should be equal to integer with format uuid.VendorExtensible = spec.VendorExtensible{} @@ -1137,7 +1079,7 @@ func TestAliasedSchemas(t *testing.T) { require.MapContainsT(t, itemsSchema.Properties, "extra_options") extraOptions := itemsSchema.Properties["extra_options"] - shouldHaveExt(t, extraOptions, "x-go-name") + assertHasExtension(t, extraOptions, "x-go-name") extraOptions.VendorExtensible = spec.VendorExtensible{} empty := spec.Schema{} @@ -1196,8 +1138,8 @@ func TestAliasedSchemas(t *testing.T) { empty, ok := sp.Definitions["empty_redefinition"] require.TrueT(t, ok) - shouldHaveGoPackageExt(t, empty) - shouldNotHaveTitle(t, empty) + assertHasGoPackageExt(t, empty) + assertHasNoTitle(t, empty) // after stripping extension and title, should be empty empty.VendorExtensible = spec.VendorExtensible{} @@ -1318,178 +1260,247 @@ func TestAliasedSchemas(t *testing.T) { }) t.Run("with aliases in interfaces", func(t *testing.T) { - t.Run("should render anonymous interface as a schema", func(t *testing.T) { - iface, ok := sp.Definitions["anonymous_iface"] // e.g. 
type X interface{ String() string} - require.TrueT(t, ok) + testAliasedInterfaceVariants(t, sp) + }) - require.TrueT(t, iface.Type.Contains("object")) - require.MapContainsT(t, iface.Properties, "String") - prop := iface.Properties["String"] - require.TrueT(t, prop.Type.Contains("string")) - assert.Len(t, iface.Properties, 1) - }) + t.Run("with aliases in embedded types", func(t *testing.T) { + testAliasedEmbeddedTypes(t, sp) + }) +} - t.Run("alias to an anonymous interface should render as a $ref", func(t *testing.T) { - iface, ok := sp.Definitions["anonymous_iface_alias"] - require.TrueT(t, ok) +func testAliasedExtendedIDAllOf(t *testing.T, sp *spec.Swagger) { + t.Helper() + extended, ok := sp.Definitions["ExtendedID"] + require.TrueT(t, ok) - assertIsRef(t, &iface, "#/definitions/anonymous_iface") // points to an anonymous interface - }) + t.Run("struct with an embedded alias should render as allOf", func(t *testing.T) { + require.Len(t, extended.AllOf, 2) + assertHasTitle(t, extended) + + foundAliased := false + foundProps := false + for idx, member := range extended.AllOf { + isProps := len(member.Properties) > 0 + isAlias := member.Ref.String() != "" + + switch { + case isProps: + props := member + t.Run("with property of type any", func(t *testing.T) { + evenMore, ok := props.Properties["EvenMore"] + require.TrueT(t, ok) + assert.Equal(t, spec.Schema{}, evenMore) + }) - t.Run("named interface should render as a schema", func(t *testing.T) { - iface, ok := sp.Definitions["iface"] - require.TrueT(t, ok) + t.Run("with property of type interface{}", func(t *testing.T) { + evenMore, ok := props.Properties["StillMore"] + require.TrueT(t, ok) + assert.Equal(t, spec.Schema{}, evenMore) + }) - require.TrueT(t, iface.Type.Contains("object")) - require.MapContainsT(t, iface.Properties, "Get") - prop := iface.Properties["Get"] - require.TrueT(t, prop.Type.Contains("string")) - assert.Len(t, iface.Properties, 1) - }) + t.Run("non-aliased properties remain unaffected", 
func(t *testing.T) { + more, ok := props.Properties["more"] + require.TrueT(t, ok) - t.Run("named interface with embedded types should render as allOf", func(t *testing.T) { - iface, ok := sp.Definitions["iface_embedded"] - require.TrueT(t, ok) + assertHasExtension(t, more, "x-go-name") // because we have a struct tag + assertHasNoTitle(t, more) - require.Len(t, iface.AllOf, 2) - foundEmbedded := false - foundMethod := false - for idx, member := range iface.AllOf { - require.TrueT(t, member.Type.Contains("object")) - require.NotEmpty(t, member.Properties) - require.Len(t, member.Properties, 1) - propGet, isEmbedded := member.Properties["Get"] - propMethod, isMethod := member.Properties["Dump"] - - switch { - case isEmbedded: - assert.TrueT(t, propGet.Type.Contains("string")) - foundEmbedded = true - case isMethod: - assert.TrueT(t, propMethod.Type.Contains("array")) - foundMethod = true - default: - assert.Failf(t, "embedded members in interface are not as expected", "unexpected member in allOf: %d", idx) - } + // after stripping extension and title, should be empty + more.VendorExtensible = spec.VendorExtensible{} + + strSchema := &spec.Schema{} + strSchema = strSchema.Typed("string", "") + assert.Equal(t, *strSchema, more) + }) + foundProps = true + case isAlias: + assertIsRef(t, &member, "#/definitions/Empty") + foundAliased = true + default: + assert.Failf(t, "embedded members in struct are not as expected", "unexpected member in allOf: %d", idx) } - require.TrueT(t, foundEmbedded) - require.TrueT(t, foundMethod) - }) + } + require.TrueT(t, foundProps) + require.TrueT(t, foundAliased) + }) +} - t.Run("named interface with embedded anonymous interface should render as allOf", func(t *testing.T) { - iface, ok := sp.Definitions["iface_embedded_anonymous"] - require.TrueT(t, ok) +func testAliasedInterfaceVariants(t *testing.T, sp *spec.Swagger) { + t.Helper() - require.Len(t, iface.AllOf, 2) - foundEmbedded := false - foundAnonymous := false - for idx, member := 
range iface.AllOf { - require.TrueT(t, member.Type.Contains("object")) - require.NotEmpty(t, member.Properties) - require.Len(t, member.Properties, 1) - propGet, isEmbedded := member.Properties["String"] - propAnonymous, isAnonymous := member.Properties["Error"] - - switch { - case isEmbedded: - assert.TrueT(t, propGet.Type.Contains("string")) - foundEmbedded = true - case isAnonymous: - assert.TrueT(t, propAnonymous.Type.Contains("string")) - foundAnonymous = true - default: - assert.Failf(t, "embedded members in interface are not as expected", "unexpected member in allOf: %d", idx) - } + t.Run("should render anonymous interface as a schema", func(t *testing.T) { + iface, ok := sp.Definitions["anonymous_iface"] // e.g. type X interface{ String() string} + require.TrueT(t, ok) + + require.TrueT(t, iface.Type.Contains("object")) + require.MapContainsT(t, iface.Properties, "String") + prop := iface.Properties["String"] + require.TrueT(t, prop.Type.Contains("string")) + assert.Len(t, iface.Properties, 1) + }) + + t.Run("alias to an anonymous interface should render as a $ref", func(t *testing.T) { + iface, ok := sp.Definitions["anonymous_iface_alias"] + require.TrueT(t, ok) + + assertIsRef(t, &iface, "#/definitions/anonymous_iface") // points to an anonymous interface + }) + + t.Run("named interface should render as a schema", func(t *testing.T) { + iface, ok := sp.Definitions["iface"] + require.TrueT(t, ok) + + require.TrueT(t, iface.Type.Contains("object")) + require.MapContainsT(t, iface.Properties, "Get") + prop := iface.Properties["Get"] + require.TrueT(t, prop.Type.Contains("string")) + assert.Len(t, iface.Properties, 1) + }) + + t.Run("named interface with embedded types should render as allOf", func(t *testing.T) { + iface, ok := sp.Definitions["iface_embedded"] + require.TrueT(t, ok) + + require.Len(t, iface.AllOf, 2) + foundEmbedded := false + foundMethod := false + for idx, member := range iface.AllOf { + require.TrueT(t, member.Type.Contains("object")) + 
require.NotEmpty(t, member.Properties) + require.Len(t, member.Properties, 1) + propGet, isEmbedded := member.Properties["Get"] + propMethod, isMethod := member.Properties["Dump"] + + switch { + case isEmbedded: + assert.TrueT(t, propGet.Type.Contains("string")) + foundEmbedded = true + case isMethod: + assert.TrueT(t, propMethod.Type.Contains("array")) + foundMethod = true + default: + assert.Failf(t, "embedded members in interface are not as expected", "unexpected member in allOf: %d", idx) } - require.TrueT(t, foundEmbedded) - require.TrueT(t, foundAnonymous) - }) + } + require.TrueT(t, foundEmbedded) + require.TrueT(t, foundMethod) + }) - t.Run("composition of empty interfaces is rendered as an empty schema", func(t *testing.T) { - iface, ok := sp.Definitions["iface_embedded_empty"] - require.TrueT(t, ok) + t.Run("named interface with embedded anonymous interface should render as allOf", func(t *testing.T) { + iface, ok := sp.Definitions["iface_embedded_anonymous"] + require.TrueT(t, ok) - iface.VendorExtensible = spec.VendorExtensible{} - assert.Equal(t, spec.Schema{}, iface) - }) + require.Len(t, iface.AllOf, 2) + foundEmbedded := false + foundAnonymous := false + for idx, member := range iface.AllOf { + require.TrueT(t, member.Type.Contains("object")) + require.NotEmpty(t, member.Properties) + require.Len(t, member.Properties, 1) + propGet, isEmbedded := member.Properties["String"] + propAnonymous, isAnonymous := member.Properties["Error"] + + switch { + case isEmbedded: + assert.TrueT(t, propGet.Type.Contains("string")) + foundEmbedded = true + case isAnonymous: + assert.TrueT(t, propAnonymous.Type.Contains("string")) + foundAnonymous = true + default: + assert.Failf(t, "embedded members in interface are not as expected", "unexpected member in allOf: %d", idx) + } + } + require.TrueT(t, foundEmbedded) + require.TrueT(t, foundAnonymous) + }) - t.Run("interface embedded with an alias should be rendered as allOf, with a ref", func(t *testing.T) { - iface, ok 
:= sp.Definitions["iface_embedded_with_alias"] - require.TrueT(t, ok) + t.Run("composition of empty interfaces is rendered as an empty schema", func(t *testing.T) { + iface, ok := sp.Definitions["iface_embedded_empty"] + require.TrueT(t, ok) - require.Len(t, iface.AllOf, 3) - foundEmbedded := false - foundEmbeddedAnon := false - foundRef := false - for idx, member := range iface.AllOf { - propGet, isEmbedded := member.Properties["String"] - propAnonymous, isAnonymous := member.Properties["Dump"] - isRef := member.Ref.String() != "" - - switch { - case isEmbedded: - require.TrueT(t, member.Type.Contains("object")) - require.Len(t, member.Properties, 1) - assert.TrueT(t, propGet.Type.Contains("string")) - foundEmbedded = true - case isAnonymous: - require.TrueT(t, member.Type.Contains("object")) - require.Len(t, member.Properties, 1) - assert.TrueT(t, propAnonymous.Type.Contains("array")) - foundEmbeddedAnon = true - case isRef: - require.Empty(t, member.Properties) - assertIsRef(t, &member, "#/definitions/iface_alias") - foundRef = true - default: - assert.Failf(t, "embedded members in interface are not as expected", "unexpected member in allOf: %d", idx) - } + iface.VendorExtensible = spec.VendorExtensible{} + assert.Equal(t, spec.Schema{}, iface) + }) + + t.Run("interface embedded with an alias should be rendered as allOf, with a ref", func(t *testing.T) { + iface, ok := sp.Definitions["iface_embedded_with_alias"] + require.TrueT(t, ok) + + require.Len(t, iface.AllOf, 3) + foundEmbedded := false + foundEmbeddedAnon := false + foundRef := false + for idx, member := range iface.AllOf { + propGet, isEmbedded := member.Properties["String"] + propAnonymous, isAnonymous := member.Properties["Dump"] + isRef := member.Ref.String() != "" + + switch { + case isEmbedded: + require.TrueT(t, member.Type.Contains("object")) + require.Len(t, member.Properties, 1) + assert.TrueT(t, propGet.Type.Contains("string")) + foundEmbedded = true + case isAnonymous: + require.TrueT(t, 
member.Type.Contains("object")) + require.Len(t, member.Properties, 1) + assert.TrueT(t, propAnonymous.Type.Contains("array")) + foundEmbeddedAnon = true + case isRef: + require.Empty(t, member.Properties) + assertIsRef(t, &member, "#/definitions/iface_alias") + foundRef = true + default: + assert.Failf(t, "embedded members in interface are not as expected", "unexpected member in allOf: %d", idx) } - require.TrueT(t, foundEmbedded) - require.TrueT(t, foundEmbeddedAnon) - require.TrueT(t, foundRef) - }) + } + require.TrueT(t, foundEmbedded) + require.TrueT(t, foundEmbeddedAnon) + require.TrueT(t, foundRef) }) +} - t.Run("with aliases in embedded types", func(t *testing.T) { - t.Run("embedded alias should render as a $ref", func(t *testing.T) { - iface, ok := sp.Definitions["embedded_with_alias"] - require.TrueT(t, ok) +func testAliasedEmbeddedTypes(t *testing.T, sp *spec.Swagger) { + t.Helper() - require.Len(t, iface.AllOf, 3) - foundAnything := false - foundUUID := false - foundProps := false - for idx, member := range iface.AllOf { - isProps := len(member.Properties) > 0 - isRef := member.Ref.String() != "" - - switch { - case isProps: - require.TrueT(t, member.Type.Contains("object")) - require.Len(t, member.Properties, 3) - assert.MapContainsT(t, member.Properties, "EvenMore") - foundProps = true - case isRef: - switch member.Ref.String() { - case "#/definitions/Anything": - foundAnything = true - case "#/definitions/UUID": - foundUUID = true - default: - assert.Failf(t, - "embedded members in interface are not as expected", "unexpected $ref for member (%v): %d", - member.Ref, idx, - ) - } + t.Run("embedded alias should render as a $ref", func(t *testing.T) { + iface, ok := sp.Definitions["embedded_with_alias"] + require.TrueT(t, ok) + + require.Len(t, iface.AllOf, 3) + foundAnything := false + foundUUID := false + foundProps := false + for idx, member := range iface.AllOf { + isProps := len(member.Properties) > 0 + isRef := member.Ref.String() != "" + + switch 
{ + case isProps: + require.TrueT(t, member.Type.Contains("object")) + require.Len(t, member.Properties, 3) + assert.MapContainsT(t, member.Properties, "EvenMore") + foundProps = true + case isRef: + switch member.Ref.String() { + case "#/definitions/Anything": + foundAnything = true + case "#/definitions/UUID": + foundUUID = true default: - assert.Failf(t, "embedded members in interface are not as expected", "unexpected member in allOf: %d", idx) + assert.Failf(t, + "embedded members in interface are not as expected", "unexpected $ref for member (%v): %d", + member.Ref, idx, + ) } + default: + assert.Failf(t, "embedded members in interface are not as expected", "unexpected member in allOf: %d", idx) } - require.TrueT(t, foundAnything) - require.TrueT(t, foundUUID) - require.TrueT(t, foundProps) - }) + } + require.TrueT(t, foundAnything) + require.TrueT(t, foundUUID) + require.TrueT(t, foundProps) }) } @@ -1587,244 +1598,7 @@ func TestSpecialSchemas(t *testing.T) { }) t.Run("with SpecialTypes struct", func(t *testing.T) { - t.Run("in spite of all the pitfalls, the struct should be rendered", func(t *testing.T) { - special, ok := sp.Definitions["special_types"] - require.TrueT(t, ok) - require.TrueT(t, special.Type.Contains("object")) - props := special.Properties - require.NotEmpty(t, props) - require.Empty(t, special.AllOf) - - t.Run("property pointer to struct should render as a ref", func(t *testing.T) { - ptr, ok := props["PtrStruct"] - require.TrueT(t, ok) - assertIsRef(t, &ptr, "#/definitions/GoStruct") - }) - - t.Run("property as time.Time should render as a formatted string", func(t *testing.T) { - str, ok := props["ShouldBeStringTime"] - require.TrueT(t, ok) - require.TrueT(t, str.Type.Contains("string")) - require.EqualT(t, "date-time", str.Format) - }) - - t.Run("property as *time.Time should also render as a formatted string", func(t *testing.T) { - str, ok := props["ShouldAlsoBeStringTime"] - require.TrueT(t, ok) - require.TrueT(t, 
str.Type.Contains("string")) - require.EqualT(t, "date-time", str.Format) - }) - - t.Run("property as builtin error should render as a string", func(t *testing.T) { - goerror, ok := props["Err"] - require.TrueT(t, ok) - require.TrueT(t, goerror.Type.Contains("string")) - - t.Run("a type based on the error builtin should be decorated with a x-go-type: error extension", func(t *testing.T) { - val, hasExt := goerror.Extensions.GetString("x-go-type") - assert.TrueT(t, hasExt) - assert.EqualT(t, "error", val) - }) - }) - - t.Run("type recognized as a text marshaler should render as a string", func(t *testing.T) { - m, ok := props["Marshaler"] - require.TrueT(t, ok) - require.TrueT(t, m.Type.Contains("string")) - - t.Run("a type based on the encoding.TextMarshaler decorated with a x-go-type extension", func(t *testing.T) { - val, hasExt := m.Extensions.GetString("x-go-type") - assert.TrueT(t, hasExt) - assert.EqualT(t, fixturesModule+"/goparsing/go123/special.IsATextMarshaler", val) - }) - }) - - t.Run("a json.RawMessage should be recognized and render as an object (yes this is wrong)", func(t *testing.T) { - m, ok := props["Message"] - require.TrueT(t, ok) - require.TrueT(t, m.Type.Contains("object")) - }) - - t.Run("type time.Duration is not recognized as a special type and should just render as a ref", func(t *testing.T) { - d, ok := props["Duration"] - require.TrueT(t, ok) - assertIsRef(t, &d, "#/definitions/Duration") - - t.Run("discovered definition should be an integer", func(t *testing.T) { - duration, ok := sp.Definitions["Duration"] - require.TrueT(t, ok) - require.TrueT(t, duration.Type.Contains("integer")) - require.EqualT(t, "int64", duration.Format) - - t.Run("time.Duration schema should be decorated with a x-go-package: time", func(t *testing.T) { - val, hasExt := duration.Extensions.GetString("x-go-package") - assert.TrueT(t, hasExt) - assert.EqualT(t, "time", val) - }) - }) - }) - - t.Run("with strfmt types", func(t *testing.T) { - t.Run("a strfmt.Date 
should be recognized and render as a formatted string", func(t *testing.T) { - d, ok := props["FormatDate"] - require.TrueT(t, ok) - require.TrueT(t, d.Type.Contains("string")) - require.EqualT(t, "date", d.Format) - }) - - t.Run("a strfmt.DateTime should be recognized and render as a formatted string", func(t *testing.T) { - d, ok := props["FormatTime"] - require.TrueT(t, ok) - require.TrueT(t, d.Type.Contains("string")) - require.EqualT(t, "date-time", d.Format) - }) - - t.Run("a strfmt.UUID should be recognized and render as a formatted string", func(t *testing.T) { - u, ok := props["FormatUUID"] - require.TrueT(t, ok) - require.TrueT(t, u.Type.Contains("string")) - require.EqualT(t, "uuid", u.Format) - }) - - t.Run("a pointer to strfmt.UUID should be recognized and render as a formatted string", func(t *testing.T) { - u, ok := props["PtrFormatUUID"] - require.TrueT(t, ok) - require.TrueT(t, u.Type.Contains("string")) - require.EqualT(t, "uuid", u.Format) - }) - }) - - t.Run("a property which is a map should render just fine, with a ref", func(t *testing.T) { - mm, ok := props["Map"] - require.TrueT(t, ok) - require.TrueT(t, mm.Type.Contains("object")) - require.NotNil(t, mm.AdditionalProperties) - mapSchema := mm.AdditionalProperties.Schema - require.NotNil(t, mapSchema) - assertIsRef(t, mapSchema, "#/definitions/GoStruct") - }) - - t.Run(`with the "WhatNot" anonymous inner struct`, func(t *testing.T) { - t.Run("should render as an anonymous schema, in spite of all the unsupported things", func(t *testing.T) { - wn, ok := props["WhatNot"] - require.TrueT(t, ok) - require.TrueT(t, wn.Type.Contains("object")) - require.NotEmpty(t, wn.Properties) - - markedProps := make([]string, 0) - - for _, unsupportedProp := range []string{ - "AA", // complex128 - "A", // complex64 - "B", // chan int - "C", // func() - "D", // func() string - "E", // unsafe.Pointer - } { - t.Run("with property "+unsupportedProp, func(t *testing.T) { - prop, ok := wn.Properties[unsupportedProp] 
- require.TrueT(t, ok) - markedProps = append(markedProps, unsupportedProp) - - t.Run("unsupported type in property should render as an empty schema", func(t *testing.T) { - var empty spec.Schema - require.Equal(t, empty, prop) - }) - }) - } - - for _, supportedProp := range []string{ - "F", // uintptr - "G", - "H", - "I", - "J", - "K", - } { - t.Run("with property "+supportedProp, func(t *testing.T) { - prop, ok := wn.Properties[supportedProp] - require.TrueT(t, ok) - markedProps = append(markedProps, supportedProp) - - switch supportedProp { - case "F": - t.Run("uintptr should render as integer", func(t *testing.T) { - require.TrueT(t, prop.Type.Contains("integer")) - require.EqualT(t, "uint64", prop.Format) - }) - case "G", "H": - t.Run( - "math/big types are not recognized as special types and as TextMarshalers they render as string", - func(t *testing.T) { - require.TrueT(t, prop.Type.Contains("string")) - }) - case "I": - t.Run("go array should render as a json array", func(t *testing.T) { - require.TrueT(t, prop.Type.Contains("array")) - require.NotNil(t, prop.Items) - itemsSchema := prop.Items.Schema - require.NotNil(t, itemsSchema) - - require.TrueT(t, itemsSchema.Type.Contains("integer")) - // [5]byte is not recognized an array of bytes, but of uint8 - // (internally this is the same for go) - require.EqualT(t, "uint8", itemsSchema.Format) - }) - case "J", "K": - t.Run("reflect types should render just fine", func(t *testing.T) { - var dest string - if supportedProp == "J" { - dest = "Type" - } else { - dest = "Value" - } - assertIsRef(t, &prop, "#/definitions/"+dest) - - t.Run("the $ref should exist", func(t *testing.T) { - deref, ok := sp.Definitions[dest] - require.TrueT(t, ok) - val, hasExt := deref.Extensions.GetString("x-go-package") - assert.TrueT(t, hasExt) - assert.EqualT(t, "reflect", val) - }) - }) - } - }) - } - - t.Run("we should not have any property left in WhatNot", func(t *testing.T) { - for _, key := range markedProps { - 
delete(wn.Properties, key) - } - - require.Empty(t, wn.Properties) - }) - - t.Run("surprisingly, a tagged unexported top-level definition can be rendered", func(t *testing.T) { - unexported, ok := sp.Definitions["unexported"] - require.TrueT(t, ok) - require.TrueT(t, unexported.Type.Contains("object")) - }) - - t.Run("the IsATextMarshaler type is not identified as a discovered type and is not rendered", func(t *testing.T) { - _, ok := sp.Definitions["IsATextMarshaler"] - require.FalseT(t, ok) - }) - - t.Run("a top-level go array should render just fine", func(t *testing.T) { - // Notice that the semantics of fixed length are lost in this mapping - goarray, ok := sp.Definitions["go_array"] - require.TrueT(t, ok) - require.TrueT(t, goarray.Type.Contains("array")) - require.NotNil(t, goarray.Items) - itemsSchema := goarray.Items.Schema - require.NotNil(t, itemsSchema) - require.TrueT(t, itemsSchema.Type.Contains("integer")) - require.EqualT(t, "int64", itemsSchema.Format) - }) - }) - }) - }) + testSpecialTypesStruct(t, sp) }) t.Run("with generic types", func(t *testing.T) { @@ -1903,6 +1677,261 @@ func TestSpecialSchemas(t *testing.T) { }) } +func testSpecialTypesStruct(t *testing.T, sp *spec.Swagger) { + t.Helper() + + t.Run("in spite of all the pitfalls, the struct should be rendered", func(t *testing.T) { + special, ok := sp.Definitions["special_types"] + require.TrueT(t, ok) + require.TrueT(t, special.Type.Contains("object")) + props := special.Properties + require.NotEmpty(t, props) + require.Empty(t, special.AllOf) + + t.Run("property pointer to struct should render as a ref", func(t *testing.T) { + ptr, ok := props["PtrStruct"] + require.TrueT(t, ok) + assertIsRef(t, &ptr, "#/definitions/GoStruct") + }) + + t.Run("property as time.Time should render as a formatted string", func(t *testing.T) { + str, ok := props["ShouldBeStringTime"] + require.TrueT(t, ok) + require.TrueT(t, str.Type.Contains("string")) + require.EqualT(t, "date-time", str.Format) + }) + + 
t.Run("property as *time.Time should also render as a formatted string", func(t *testing.T) { + str, ok := props["ShouldAlsoBeStringTime"] + require.TrueT(t, ok) + require.TrueT(t, str.Type.Contains("string")) + require.EqualT(t, "date-time", str.Format) + }) + + t.Run("property as builtin error should render as a string", func(t *testing.T) { + goerror, ok := props["Err"] + require.TrueT(t, ok) + require.TrueT(t, goerror.Type.Contains("string")) + + t.Run("a type based on the error builtin should be decorated with a x-go-type: error extension", func(t *testing.T) { + val, hasExt := goerror.Extensions.GetString("x-go-type") + assert.TrueT(t, hasExt) + assert.EqualT(t, "error", val) + }) + }) + + t.Run("type recognized as a text marshaler should render as a string", func(t *testing.T) { + m, ok := props["Marshaler"] + require.TrueT(t, ok) + require.TrueT(t, m.Type.Contains("string")) + + t.Run("a type based on the encoding.TextMarshaler decorated with a x-go-type extension", func(t *testing.T) { + val, hasExt := m.Extensions.GetString("x-go-type") + assert.TrueT(t, hasExt) + assert.EqualT(t, fixturesModule+"/goparsing/go123/special.IsATextMarshaler", val) + }) + }) + + t.Run("a json.RawMessage should be recognized and render as an object (yes this is wrong)", func(t *testing.T) { + m, ok := props["Message"] + require.TrueT(t, ok) + require.TrueT(t, m.Type.Contains("object")) + }) + + t.Run("type time.Duration is not recognized as a special type and should just render as a ref", func(t *testing.T) { + d, ok := props["Duration"] + require.TrueT(t, ok) + assertIsRef(t, &d, "#/definitions/Duration") + + t.Run("discovered definition should be an integer", func(t *testing.T) { + duration, ok := sp.Definitions["Duration"] + require.TrueT(t, ok) + require.TrueT(t, duration.Type.Contains("integer")) + require.EqualT(t, "int64", duration.Format) + + t.Run("time.Duration schema should be decorated with a x-go-package: time", func(t *testing.T) { + val, hasExt := 
duration.Extensions.GetString("x-go-package") + assert.TrueT(t, hasExt) + assert.EqualT(t, "time", val) + }) + }) + }) + + testSpecialTypesStrfmt(t, props) + + t.Run("a property which is a map should render just fine, with a ref", func(t *testing.T) { + mm, ok := props["Map"] + require.TrueT(t, ok) + require.TrueT(t, mm.Type.Contains("object")) + require.NotNil(t, mm.AdditionalProperties) + mapSchema := mm.AdditionalProperties.Schema + require.NotNil(t, mapSchema) + assertIsRef(t, mapSchema, "#/definitions/GoStruct") + }) + + testSpecialTypesWhatNot(t, sp, props) + }) +} + +func testSpecialTypesStrfmt(t *testing.T, props map[string]spec.Schema) { + t.Helper() + + t.Run("with strfmt types", func(t *testing.T) { + t.Run("a strfmt.Date should be recognized and render as a formatted string", func(t *testing.T) { + d, ok := props["FormatDate"] + require.TrueT(t, ok) + require.TrueT(t, d.Type.Contains("string")) + require.EqualT(t, "date", d.Format) + }) + + t.Run("a strfmt.DateTime should be recognized and render as a formatted string", func(t *testing.T) { + d, ok := props["FormatTime"] + require.TrueT(t, ok) + require.TrueT(t, d.Type.Contains("string")) + require.EqualT(t, "date-time", d.Format) + }) + + t.Run("a strfmt.UUID should be recognized and render as a formatted string", func(t *testing.T) { + u, ok := props["FormatUUID"] + require.TrueT(t, ok) + require.TrueT(t, u.Type.Contains("string")) + require.EqualT(t, "uuid", u.Format) + }) + + t.Run("a pointer to strfmt.UUID should be recognized and render as a formatted string", func(t *testing.T) { + u, ok := props["PtrFormatUUID"] + require.TrueT(t, ok) + require.TrueT(t, u.Type.Contains("string")) + require.EqualT(t, "uuid", u.Format) + }) + }) +} + +func testSpecialTypesWhatNot(t *testing.T, sp *spec.Swagger, props map[string]spec.Schema) { + t.Helper() + + t.Run(`with the "WhatNot" anonymous inner struct`, func(t *testing.T) { + t.Run("should render as an anonymous schema, in spite of all the unsupported 
things", func(t *testing.T) { + wn, ok := props["WhatNot"] + require.TrueT(t, ok) + require.TrueT(t, wn.Type.Contains("object")) + require.NotEmpty(t, wn.Properties) + + markedProps := make([]string, 0) + + for _, unsupportedProp := range []string{ + "AA", // complex128 + "A", // complex64 + "B", // chan int + "C", // func() + "D", // func() string + "E", // unsafe.Pointer + } { + t.Run("with property "+unsupportedProp, func(t *testing.T) { + prop, ok := wn.Properties[unsupportedProp] + require.TrueT(t, ok) + markedProps = append(markedProps, unsupportedProp) + + t.Run("unsupported type in property should render as an empty schema", func(t *testing.T) { + var empty spec.Schema + require.Equal(t, empty, prop) + }) + }) + } + + for _, supportedProp := range []string{ + "F", // uintptr + "G", + "H", + "I", + "J", + "K", + } { + t.Run("with property "+supportedProp, func(t *testing.T) { + prop, ok := wn.Properties[supportedProp] + require.TrueT(t, ok) + markedProps = append(markedProps, supportedProp) + + switch supportedProp { + case "F": + t.Run("uintptr should render as integer", func(t *testing.T) { + require.TrueT(t, prop.Type.Contains("integer")) + require.EqualT(t, "uint64", prop.Format) + }) + case "G", "H": + t.Run( + "math/big types are not recognized as special types and as TextMarshalers they render as string", + func(t *testing.T) { + require.TrueT(t, prop.Type.Contains("string")) + }) + case "I": + t.Run("go array should render as a json array", func(t *testing.T) { + require.TrueT(t, prop.Type.Contains("array")) + require.NotNil(t, prop.Items) + itemsSchema := prop.Items.Schema + require.NotNil(t, itemsSchema) + + require.TrueT(t, itemsSchema.Type.Contains("integer")) + // [5]byte is not recognized an array of bytes, but of uint8 + // (internally this is the same for go) + require.EqualT(t, "uint8", itemsSchema.Format) + }) + case "J", "K": + t.Run("reflect types should render just fine", func(t *testing.T) { + var dest string + if supportedProp == "J" { 
+ dest = "Type" + } else { + dest = "Value" + } + assertIsRef(t, &prop, "#/definitions/"+dest) + + t.Run("the $ref should exist", func(t *testing.T) { + deref, ok := sp.Definitions[dest] + require.TrueT(t, ok) + val, hasExt := deref.Extensions.GetString("x-go-package") + assert.TrueT(t, hasExt) + assert.EqualT(t, "reflect", val) + }) + }) + } + }) + } + + t.Run("we should not have any property left in WhatNot", func(t *testing.T) { + for _, key := range markedProps { + delete(wn.Properties, key) + } + + require.Empty(t, wn.Properties) + }) + + t.Run("surprisingly, a tagged unexported top-level definition can be rendered", func(t *testing.T) { + unexported, ok := sp.Definitions["unexported"] + require.TrueT(t, ok) + require.TrueT(t, unexported.Type.Contains("object")) + }) + + t.Run("the IsATextMarshaler type is not identified as a discovered type and is not rendered", func(t *testing.T) { + _, ok := sp.Definitions["IsATextMarshaler"] + require.FalseT(t, ok) + }) + + t.Run("a top-level go array should render just fine", func(t *testing.T) { + // Notice that the semantics of fixed length are lost in this mapping + goarray, ok := sp.Definitions["go_array"] + require.TrueT(t, ok) + require.TrueT(t, goarray.Type.Contains("array")) + require.NotNil(t, goarray.Items) + itemsSchema := goarray.Items.Schema + require.NotNil(t, itemsSchema) + require.TrueT(t, itemsSchema.Type.Contains("integer")) + require.EqualT(t, "int64", itemsSchema.Format) + }) + }) + }) +} + func TestEmbeddedAllOf(t *testing.T) { sctx := loadClassificationPkgsCtx(t) decl := getClassificationModel(sctx, "AllOfModel") @@ -2274,18 +2303,14 @@ func assertIsRef(t *testing.T, schema *spec.Schema, fragment string) { assert.EqualT(t, fragment, schema.Ref.String()) } -func assertDefinition(t *testing.T, defs map[string]spec.Schema, defName, typeName, formatName, goName string) { +func assertDefinition(t *testing.T, defs map[string]spec.Schema, defName, typeName, formatName string) { t.Helper() schema, ok := 
defs[defName] if assert.TrueT(t, ok) { if assert.NotEmpty(t, schema.Type) { assert.EqualT(t, typeName, schema.Type[0]) - if goName != "" { - assert.Equal(t, goName, schema.Extensions["x-go-name"]) - } else { - assert.Nil(t, schema.Extensions["x-go-name"]) - } + assert.Nil(t, schema.Extensions["x-go-name"]) assert.EqualT(t, formatName, schema.Format) } } diff --git a/spec.go b/spec.go index c466657..5dc9ea1 100644 --- a/spec.go +++ b/spec.go @@ -234,29 +234,31 @@ func (s *specBuilder) joinExtraModels() error { func collectOperationsFromInput(input *spec.Swagger) map[string]*spec.Operation { operations := make(map[string]*spec.Operation) - if input != nil && input.Paths != nil { - for _, pth := range input.Paths.Paths { - if pth.Get != nil { - operations[pth.Get.ID] = pth.Get - } - if pth.Post != nil { - operations[pth.Post.ID] = pth.Post - } - if pth.Put != nil { - operations[pth.Put.ID] = pth.Put - } - if pth.Patch != nil { - operations[pth.Patch.ID] = pth.Patch - } - if pth.Delete != nil { - operations[pth.Delete.ID] = pth.Delete - } - if pth.Head != nil { - operations[pth.Head.ID] = pth.Head - } - if pth.Options != nil { - operations[pth.Options.ID] = pth.Options - } + if input == nil || input.Paths == nil { + return operations + } + + for _, pth := range input.Paths.Paths { + if pth.Get != nil { + operations[pth.Get.ID] = pth.Get + } + if pth.Post != nil { + operations[pth.Post.ID] = pth.Post + } + if pth.Put != nil { + operations[pth.Put.ID] = pth.Put + } + if pth.Patch != nil { + operations[pth.Patch.ID] = pth.Patch + } + if pth.Delete != nil { + operations[pth.Delete.ID] = pth.Delete + } + if pth.Head != nil { + operations[pth.Head.ID] = pth.Head + } + if pth.Options != nil { + operations[pth.Options.ID] = pth.Options } } return operations diff --git a/taggers.go b/taggers.go new file mode 100644 index 0000000..fbcd2dc --- /dev/null +++ b/taggers.go @@ -0,0 +1,136 @@ +// SPDX-FileCopyrightText: Copyright 2015-2025 go-swagger maintainers +// 
SPDX-License-Identifier: Apache-2.0 + +package codescan + +import ( + "fmt" + "go/ast" + + "github.com/go-openapi/spec" +) + +// itemsTaggers builds tag parsers for array items at a given nesting level. +func itemsTaggers(items *spec.Items, level int) []tagParser { + // the expression is 1-index based not 0-index + itemsPrefix := fmt.Sprintf(rxItemsPrefixFmt, level+1) + + return []tagParser{ + newSingleLineTagParser(fmt.Sprintf("items%dMaximum", level), &setMaximum{itemsValidations{items}, rxf(rxMaximumFmt, itemsPrefix)}), + newSingleLineTagParser(fmt.Sprintf("items%dMinimum", level), &setMinimum{itemsValidations{items}, rxf(rxMinimumFmt, itemsPrefix)}), + newSingleLineTagParser(fmt.Sprintf("items%dMultipleOf", level), &setMultipleOf{itemsValidations{items}, rxf(rxMultipleOfFmt, itemsPrefix)}), + newSingleLineTagParser(fmt.Sprintf("items%dMinLength", level), &setMinLength{itemsValidations{items}, rxf(rxMinLengthFmt, itemsPrefix)}), + newSingleLineTagParser(fmt.Sprintf("items%dMaxLength", level), &setMaxLength{itemsValidations{items}, rxf(rxMaxLengthFmt, itemsPrefix)}), + newSingleLineTagParser(fmt.Sprintf("items%dPattern", level), &setPattern{itemsValidations{items}, rxf(rxPatternFmt, itemsPrefix)}), + newSingleLineTagParser(fmt.Sprintf("items%dCollectionFormat", level), &setCollectionFormat{itemsValidations{items}, rxf(rxCollectionFormatFmt, itemsPrefix)}), + newSingleLineTagParser(fmt.Sprintf("items%dMinItems", level), &setMinItems{itemsValidations{items}, rxf(rxMinItemsFmt, itemsPrefix)}), + newSingleLineTagParser(fmt.Sprintf("items%dMaxItems", level), &setMaxItems{itemsValidations{items}, rxf(rxMaxItemsFmt, itemsPrefix)}), + newSingleLineTagParser(fmt.Sprintf("items%dUnique", level), &setUnique{itemsValidations{items}, rxf(rxUniqueFmt, itemsPrefix)}), + newSingleLineTagParser(fmt.Sprintf("items%dEnum", level), &setEnum{itemsValidations{items}, rxf(rxEnumFmt, itemsPrefix)}), + newSingleLineTagParser(fmt.Sprintf("items%dDefault", level), 
&setDefault{&items.SimpleSchema, itemsValidations{items}, rxf(rxDefaultFmt, itemsPrefix)}), + newSingleLineTagParser(fmt.Sprintf("items%dExample", level), &setExample{&items.SimpleSchema, itemsValidations{items}, rxf(rxExampleFmt, itemsPrefix)}), + } +} + +// parseArrayTypes recursively builds tag parsers for nested array types. +func parseArrayTypes(sp *sectionedParser, name string, expr ast.Expr, items *spec.Items, level int) ([]tagParser, error) { + if items == nil { + return []tagParser{}, nil + } + switch iftpe := expr.(type) { + case *ast.ArrayType: + eleTaggers := itemsTaggers(items, level) + sp.taggers = append(eleTaggers, sp.taggers...) + return parseArrayTypes(sp, name, iftpe.Elt, items.Items, level+1) + case *ast.SelectorExpr: + return parseArrayTypes(sp, name, iftpe.Sel, items.Items, level+1) + case *ast.Ident: + taggers := []tagParser{} + if iftpe.Obj == nil { + taggers = itemsTaggers(items, level) + } + otherTaggers, err := parseArrayTypes(sp, name, expr, items.Items, level+1) + if err != nil { + return nil, err + } + return append(taggers, otherTaggers...), nil + case *ast.StarExpr: + return parseArrayTypes(sp, name, iftpe.X, items, level) + default: + return nil, fmt.Errorf("unknown field type ele for %q: %w", name, ErrCodeScan) + } +} + +// setupRefParamTaggers configures taggers for a parameter that is a $ref. +func setupRefParamTaggers(sp *sectionedParser, ps *spec.Parameter) { + sp.taggers = []tagParser{ + newSingleLineTagParser("in", &matchOnlyParam{ps, rxIn}), + newSingleLineTagParser("required", &matchOnlyParam{ps, rxRequired}), + newMultiLineTagParser("Extensions", newSetExtensions(spExtensionsSetter(ps)), true), + } +} + +// setupInlineParamTaggers configures taggers for a fully-defined inline parameter. 
+func setupInlineParamTaggers(sp *sectionedParser, ps *spec.Parameter, name string, afld *ast.Field) error { + sp.taggers = []tagParser{ + newSingleLineTagParser("in", &matchOnlyParam{ps, rxIn}), + newSingleLineTagParser("maximum", &setMaximum{paramValidations{ps}, rxf(rxMaximumFmt, "")}), + newSingleLineTagParser("minimum", &setMinimum{paramValidations{ps}, rxf(rxMinimumFmt, "")}), + newSingleLineTagParser("multipleOf", &setMultipleOf{paramValidations{ps}, rxf(rxMultipleOfFmt, "")}), + newSingleLineTagParser("minLength", &setMinLength{paramValidations{ps}, rxf(rxMinLengthFmt, "")}), + newSingleLineTagParser("maxLength", &setMaxLength{paramValidations{ps}, rxf(rxMaxLengthFmt, "")}), + newSingleLineTagParser("pattern", &setPattern{paramValidations{ps}, rxf(rxPatternFmt, "")}), + newSingleLineTagParser("collectionFormat", &setCollectionFormat{paramValidations{ps}, rxf(rxCollectionFormatFmt, "")}), + newSingleLineTagParser("minItems", &setMinItems{paramValidations{ps}, rxf(rxMinItemsFmt, "")}), + newSingleLineTagParser("maxItems", &setMaxItems{paramValidations{ps}, rxf(rxMaxItemsFmt, "")}), + newSingleLineTagParser("unique", &setUnique{paramValidations{ps}, rxf(rxUniqueFmt, "")}), + newSingleLineTagParser("enum", &setEnum{paramValidations{ps}, rxf(rxEnumFmt, "")}), + newSingleLineTagParser("default", &setDefault{&ps.SimpleSchema, paramValidations{ps}, rxf(rxDefaultFmt, "")}), + newSingleLineTagParser("example", &setExample{&ps.SimpleSchema, paramValidations{ps}, rxf(rxExampleFmt, "")}), + newSingleLineTagParser("required", &setRequiredParam{ps}), + newMultiLineTagParser("Extensions", newSetExtensions(spExtensionsSetter(ps)), true), + } + + // check if this is a primitive, if so parse the validations from the + // doc comments of the slice declaration. + if ftped, ok := afld.Type.(*ast.ArrayType); ok { + taggers, err := parseArrayTypes(sp, name, ftped.Elt, ps.Items, 0) + if err != nil { + return err + } + sp.taggers = append(taggers, sp.taggers...) 
+ } + + return nil +} + +// setupResponseHeaderTaggers configures taggers for a response header field. +func setupResponseHeaderTaggers(sp *sectionedParser, ps *spec.Header, name string, afld *ast.Field) error { + sp.taggers = []tagParser{ + newSingleLineTagParser("maximum", &setMaximum{headerValidations{ps}, rxf(rxMaximumFmt, "")}), + newSingleLineTagParser("minimum", &setMinimum{headerValidations{ps}, rxf(rxMinimumFmt, "")}), + newSingleLineTagParser("multipleOf", &setMultipleOf{headerValidations{ps}, rxf(rxMultipleOfFmt, "")}), + newSingleLineTagParser("minLength", &setMinLength{headerValidations{ps}, rxf(rxMinLengthFmt, "")}), + newSingleLineTagParser("maxLength", &setMaxLength{headerValidations{ps}, rxf(rxMaxLengthFmt, "")}), + newSingleLineTagParser("pattern", &setPattern{headerValidations{ps}, rxf(rxPatternFmt, "")}), + newSingleLineTagParser("collectionFormat", &setCollectionFormat{headerValidations{ps}, rxf(rxCollectionFormatFmt, "")}), + newSingleLineTagParser("minItems", &setMinItems{headerValidations{ps}, rxf(rxMinItemsFmt, "")}), + newSingleLineTagParser("maxItems", &setMaxItems{headerValidations{ps}, rxf(rxMaxItemsFmt, "")}), + newSingleLineTagParser("unique", &setUnique{headerValidations{ps}, rxf(rxUniqueFmt, "")}), + newSingleLineTagParser("enum", &setEnum{headerValidations{ps}, rxf(rxEnumFmt, "")}), + newSingleLineTagParser("default", &setDefault{&ps.SimpleSchema, headerValidations{ps}, rxf(rxDefaultFmt, "")}), + newSingleLineTagParser("example", &setExample{&ps.SimpleSchema, headerValidations{ps}, rxf(rxExampleFmt, "")}), + } + + // check if this is a primitive, if so parse the validations from the + // doc comments of the slice declaration. + if ftped, ok := afld.Type.(*ast.ArrayType); ok { + taggers, err := parseArrayTypes(sp, name, ftped.Elt, ps.Items, 0) + if err != nil { + return err + } + sp.taggers = append(taggers, sp.taggers...) 
+ } + + return nil +} diff --git a/yamlparser_test.go b/yamlparser_test.go index 7fde231..592ccef 100644 --- a/yamlparser_test.go +++ b/yamlparser_test.go @@ -45,7 +45,8 @@ func TestYamlParser(t *testing.T) { require.NoError(t, parser.Parse(lines)) require.EqualT(t, 1, setterCalled) - const expectedJSON = `{"SecurityDefinitions":{"api_key":{"name":"X-API-KEY","type":"apiKey"},"petstore_auth":{"scopes":{"read:pets":"read your pets","write:pets":"modify pets in your account"},"type":"oauth2"}}}` + const expectedJSON = `{"SecurityDefinitions":{"api_key":{"name":"X-API-KEY","type":"apiKey"},` + + `"petstore_auth":{"scopes":{"read:pets":"read your pets","write:pets":"modify pets in your account"},"type":"oauth2"}}}` require.JSONEqT(t, expectedJSON, actualJSON) })