From 153c5d0d1e5b9a259fba220d00c8a89e10d8f94f Mon Sep 17 00:00:00 2001 From: Igor Lazarev Date: Thu, 12 Feb 2026 22:55:47 +0300 Subject: [PATCH] refactor: replace golang-jwt dependency with muonsoft/api-testing/jwt - Updated import paths in multiple files to use the new JWT package. - Removed golang-jwt dependency from go.mod and go.sum. - Introduced new JWT implementation with claims handling and signing methods. - Updated assertions to work with the new JWT structure. - Added error handling for token parsing and verification. --- EXAMPLES.md | 2 +- apitest/response_test.go | 2 +- assertions/jwt.go | 10 ++--- assertjson/assertjson_test.go | 6 +-- assertjson/jwt.go | 10 ++--- go.mod | 1 - go.sum | 2 - jwt/errors.go | 11 +++++ jwt/hmac.go | 43 ++++++++++++++++++ jwt/map_claims.go | 5 +++ jwt/parse.go | 84 +++++++++++++++++++++++++++++++++++ jwt/signing.go | 28 ++++++++++++ jwt/token.go | 67 ++++++++++++++++++++++++++++ 13 files changed, 253 insertions(+), 18 deletions(-) create mode 100644 jwt/errors.go create mode 100644 jwt/hmac.go create mode 100644 jwt/map_claims.go create mode 100644 jwt/parse.go create mode 100644 jwt/signing.go create mode 100644 jwt/token.go diff --git a/EXAMPLES.md b/EXAMPLES.md index 283f58e..3c84930 100644 --- a/EXAMPLES.md +++ b/EXAMPLES.md @@ -234,7 +234,7 @@ assertjson.Has(t, data, func(json *assertjson.AssertJSON) { ```go import ( "time" - "github.com/golang-jwt/jwt/v5" + "github.com/muonsoft/api-testing/jwt" ) assertjson.Has(t, data, func(json *assertjson.AssertJSON) { diff --git a/apitest/response_test.go b/apitest/response_test.go index 8891844..01c222f 100644 --- a/apitest/response_test.go +++ b/apitest/response_test.go @@ -5,7 +5,7 @@ import ( "net/http/httptest" "testing" - "github.com/golang-jwt/jwt/v5" + "github.com/muonsoft/api-testing/jwt" "github.com/muonsoft/api-testing/apitest" "github.com/muonsoft/api-testing/assertjson" "github.com/muonsoft/api-testing/internal/mock" diff --git a/assertions/jwt.go b/assertions/jwt.go index dfc5f9b..4f72629 100644 --- a/assertions/jwt.go +++ b/assertions/jwt.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v5" + "github.com/muonsoft/api-testing/jwt" "github.com/muonsoft/api-testing/assertjson" "github.com/stretchr/testify/assert" ) @@ -84,7 +84,7 @@ func (a *JWTAssertion) WithPayload(jsonAssert assertjson.JSONAssertFunc) *JWTAss jsonAssert(assertjson.NewAssertJSON( a.t, a.messagePrefix+`is JWT with payload: `, - map[string]interface{}(a.token.Claims.(jwt.MapClaims)), + map[string]interface{}(a.token.Claims), )) return a @@ -188,7 +188,7 @@ func (a *JWTAssertion) Assert(assertFunc func(tb testing.TB, token *jwt.Token)) func (a *JWTAssertion) assertStringField(title string, name string, expected string, msgAndArgs ...interface{}) *JWTAssertion { a.t.Helper() - raw, exist := a.token.Claims.(jwt.MapClaims)[name] + raw, exist := a.token.Claims[name] if !exist { return a.failOnMissingField(title, name, strconv.Quote(expected), msgAndArgs...) } @@ -208,7 +208,7 @@ func (a *JWTAssertion) assertStringField(title string, name string, expected str func (a *JWTAssertion) assertStringsField(title string, name string, expected []string, msgAndArgs ...interface{}) *JWTAssertion { a.t.Helper() - raw, exist := a.token.Claims.(jwt.MapClaims)[name] + raw, exist := a.token.Claims[name] if !exist { return a.failOnMissingField(title, name, wrapArray(formatStrings(expected)), msgAndArgs...) } @@ -226,7 +226,7 @@ func (a *JWTAssertion) assertStringsField(title string, name string, expected [] } func (a *JWTAssertion) assertTimeField(title string, name string) *TimeAssertion { - raw, exist := a.token.Claims.(jwt.MapClaims)[name] + raw, exist := a.token.Claims[name] if !exist { a.failOnMissingField(title, name, "") return nil diff --git a/assertjson/assertjson_test.go b/assertjson/assertjson_test.go index 5d7ce3b..41f04c2 100644 --- a/assertjson/assertjson_test.go +++ b/assertjson/assertjson_test.go @@ -9,7 +9,7 @@ import ( "time" "github.com/gofrs/uuid/v5" - "github.com/golang-jwt/jwt/v5" + "github.com/muonsoft/api-testing/jwt" "github.com/muonsoft/api-testing/assertjson" "github.com/muonsoft/api-testing/internal/mock" "github.com/stretchr/testify/assert" @@ -2783,7 +2783,7 @@ func TestHas(t *testing.T) { json.Node().IsJWT(getJWTSecret).WithExpiresAt() }, wantMessages: []string{ - `failed asserting that JSON node "" is JWT: token has invalid claims: invalid type for claim: exp is invalid`, + `is JWT with expires at ("exp") : number is expected`, }, }, { @@ -2820,7 +2820,7 @@ func TestHas(t *testing.T) { json.Node().IsJWT(getJWTSecret).WithNotBefore() }, wantMessages: []string{ - `failed asserting that JSON node "" is JWT: token has invalid claims: invalid type for claim: nbf is invalid`, + `is JWT with not before ("nbf") : number is expected`, }, }, { diff --git a/assertjson/jwt.go b/assertjson/jwt.go index 97ce7ea..f125aa2 100644 --- a/assertjson/jwt.go +++ b/assertjson/jwt.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v5" + "github.com/muonsoft/api-testing/jwt" "github.com/stretchr/testify/assert" ) @@ -107,7 +107,7 @@ func (a *JWTAssertion) WithPayload(jsonAssert JSONAssertFunc) *JWTAssertion { jsonAssert(&AssertJSON{ t: a.t, message: a.message + `is JWT with payload: `, - data: map[string]interface{}(a.token.Claims.(jwt.MapClaims)), + data: map[string]interface{}(a.token.Claims), }) return a @@ -217,7 +217,7 @@ func (a *JWTAssertion) Assert(assertFunc func(tb testing.TB, token *jwt.Token)) func (a *JWTAssertion) assertStringField(title string, name string, expected string, msgAndArgs ...interface{}) *JWTAssertion { a.t.Helper() - raw, exist := a.token.Claims.(jwt.MapClaims)[name] + raw, exist := a.token.Claims[name] if !exist { return a.failOnMissingField(title, name, strconv.Quote(expected), msgAndArgs...) } @@ -237,7 +237,7 @@ func (a *JWTAssertion) assertStringField(title string, name string, expected str func (a *JWTAssertion) assertStringsField(title string, name string, expected []string, msgAndArgs ...interface{}) *JWTAssertion { a.t.Helper() - raw, exist := a.token.Claims.(jwt.MapClaims)[name] + raw, exist := a.token.Claims[name] if !exist { return a.failOnMissingField(title, name, wrapArray(formatStrings(expected)), msgAndArgs...) } @@ -255,7 +255,7 @@ func (a *JWTAssertion) assertStringsField(title string, name string, expected [] } func (a *JWTAssertion) assertTimeField(title string, name string) *TimeAssertion { - raw, exist := a.token.Claims.(jwt.MapClaims)[name] + raw, exist := a.token.Claims[name] if !exist { a.failOnMissingField(title, name, "") return nil diff --git a/go.mod b/go.mod index 0d083c9..74655a0 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.16 require ( github.com/gofrs/uuid/v5 v5.3.2 - github.com/golang-jwt/jwt/v5 v5.2.2 github.com/json-iterator/go v1.1.12 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/stretchr/testify v1.10.0 diff --git a/go.sum b/go.sum index 59c2830..41f6199 100644 --- a/go.sum +++ b/go.sum @@ -3,8 +3,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gofrs/uuid/v5 v5.3.2 h1:2jfO8j3XgSwlz/wHqemAEugfnTlikAYHhnqQ8Xh4fE0= github.com/gofrs/uuid/v5 v5.3.2/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= -github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= diff --git a/jwt/errors.go b/jwt/errors.go new file mode 100644 index 0000000..af2bced --- /dev/null +++ b/jwt/errors.go @@ -0,0 +1,11 @@ +package jwt + +import "errors" + +var ( + ErrTokenMalformed = errors.New("token is malformed") + ErrTokenUnverifiable = errors.New("token is unverifiable") + ErrTokenSignatureInvalid = errors.New("token signature is invalid") + ErrSignatureInvalid = errors.New("signature is invalid") + ErrInvalidKeyType = errors.New("key is of invalid type") +) diff --git a/jwt/hmac.go b/jwt/hmac.go new file mode 100644 index 0000000..afda6a8 --- /dev/null +++ b/jwt/hmac.go @@ -0,0 +1,43 @@ +package jwt + +import ( + "crypto/hmac" + "crypto/sha256" +) + +// SigningMethodHMAC implements HS256. +type SigningMethodHMAC struct { + Name string +} + +var signingMethodHS256 = &SigningMethodHMAC{Name: "HS256"} + +// SigningMethodHS256 is the HMAC-SHA256 signing method. +var SigningMethodHS256 SigningMethod = signingMethodHS256 + +func (m *SigningMethodHMAC) Alg() string { + return m.Name +} + +func (m *SigningMethodHMAC) Verify(signingString string, sig []byte, key interface{}) error { + keyBytes, ok := key.([]byte) + if !ok { + return ErrInvalidKeyType + } + hasher := hmac.New(sha256.New, keyBytes) + hasher.Write([]byte(signingString)) + if !hmac.Equal(sig, hasher.Sum(nil)) { + return ErrSignatureInvalid + } + return nil +} + +func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) ([]byte, error) { + keyBytes, ok := key.([]byte) + if !ok { + return nil, ErrInvalidKeyType + } + hasher := hmac.New(sha256.New, keyBytes) + hasher.Write([]byte(signingString)) + return hasher.Sum(nil), nil +} diff --git a/jwt/map_claims.go b/jwt/map_claims.go new file mode 100644 index 0000000..ebfda89 --- /dev/null +++ b/jwt/map_claims.go @@ -0,0 +1,5 @@ +package jwt + +// MapClaims is a claims type that uses map[string]interface{} for JSON decoding. +// Used as the default claims type for parsing and creating tokens. +type MapClaims map[string]interface{} diff --git a/jwt/parse.go b/jwt/parse.go new file mode 100644 index 0000000..6e5d0a2 --- /dev/null +++ b/jwt/parse.go @@ -0,0 +1,84 @@ +package jwt + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" +) + +const tokenDelimiter = "." + +// Parse parses and verifies the JWT and returns the token. +// Only HS256 signature verification is supported. +func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { + parts, ok := splitToken(tokenString) + if !ok { + return nil, fmt.Errorf("%w: token contains an invalid number of segments", ErrTokenMalformed) + } + + token := &Token{Raw: tokenString} + + // Decode header + headerBytes, err := decodeSegment(parts[0]) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrTokenMalformed, err) + } + if err := json.Unmarshal(headerBytes, &token.Header); err != nil { + return nil, fmt.Errorf("%w: %v", ErrTokenMalformed, err) + } + + // Decode claims + claimBytes, err := decodeSegment(parts[1]) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrTokenMalformed, err) + } + token.Claims = MapClaims{} + if err := json.Unmarshal(claimBytes, &token.Claims); err != nil { + return nil, fmt.Errorf("%w: %v", ErrTokenMalformed, err) + } + + // Resolve signing method from header + alg, _ := token.Header["alg"].(string) + if alg == "" { + return nil, fmt.Errorf("%w: signing method (alg) is unspecified", ErrTokenUnverifiable) + } + token.Method = &methodByAlg{alg: alg} + + // Decode signature + token.Signature, err = decodeSegment(parts[2]) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrTokenMalformed, err) + } + + if keyFunc == nil { + return nil, fmt.Errorf("%w: no keyfunc was provided", ErrTokenUnverifiable) + } + key, err := keyFunc(token) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrTokenUnverifiable, err) + } + + signingString := strings.Join(parts[0:2], ".") + if err := token.Method.Verify(signingString, token.Signature, key); err != nil { + return nil, fmt.Errorf("%w: %v", ErrTokenSignatureInvalid, err) + } + + token.Valid = true + return token, nil +} + +func splitToken(s string) ([]string, bool) { + parts := strings.SplitN(s, tokenDelimiter, 4) + if len(parts) != 3 { + return nil, false + } + if parts[0] == "" || parts[1] == "" || parts[2] == "" { + return nil, false + } + return parts, true +} + +func decodeSegment(seg string) ([]byte, error) { + return base64.RawURLEncoding.DecodeString(seg) +} diff --git a/jwt/signing.go b/jwt/signing.go new file mode 100644 index 0000000..d95df69 --- /dev/null +++ b/jwt/signing.go @@ -0,0 +1,28 @@ +package jwt + +// SigningMethod is used to sign and verify tokens. +type SigningMethod interface { + Verify(signingString string, sig []byte, key interface{}) error + Sign(signingString string, key interface{}) ([]byte, error) + Alg() string +} + +// methodByAlg holds algorithm name from token header; only HS256 is verified. +type methodByAlg struct { + alg string +} + +func (m *methodByAlg) Alg() string { + return m.alg +} + +func (m *methodByAlg) Verify(signingString string, sig []byte, key interface{}) error { + if m.alg != "HS256" { + return ErrTokenSignatureInvalid + } + return signingMethodHS256.Verify(signingString, sig, key) +} + +func (m *methodByAlg) Sign(signingString string, key interface{}) ([]byte, error) { + return nil, ErrTokenUnverifiable +} diff --git a/jwt/token.go b/jwt/token.go new file mode 100644 index 0000000..fc2d06f --- /dev/null +++ b/jwt/token.go @@ -0,0 +1,67 @@ +package jwt + +import ( + "encoding/base64" + "encoding/json" +) + +// Keyfunc is used by Parse to supply the key for verification. +// The function receives the parsed but unverified Token (e.g. to read "alg" from header). +type Keyfunc func(*Token) (interface{}, error) + +// Token represents a JWT. +type Token struct { + Raw string + Method SigningMethod + Header map[string]interface{} + Claims MapClaims + Signature []byte + Valid bool +} + +// NewWithClaims creates a new Token with the given signing method and claims. +func NewWithClaims(method SigningMethod, claims MapClaims) *Token { + if claims == nil { + claims = MapClaims{} + } + return &Token{ + Header: map[string]interface{}{ + "typ": "JWT", + "alg": method.Alg(), + }, + Claims: claims, + Method: method, + } +} + +// SignedString signs the token and returns the full JWT string. +func (t *Token) SignedString(key interface{}) (string, error) { + sstr, err := t.SigningString() + if err != nil { + return "", err + } + sig, err := t.Method.Sign(sstr, key) + if err != nil { + return "", err + } + t.Signature = sig + return sstr + "." + t.EncodeSegment(sig), nil +} + +// SigningString returns the base64url(header).base64url(claims) string. +func (t *Token) SigningString() (string, error) { + h, err := json.Marshal(t.Header) + if err != nil { + return "", err + } + c, err := json.Marshal(t.Claims) + if err != nil { + return "", err + } + return t.EncodeSegment(h) + "." + t.EncodeSegment(c), nil +} + +// EncodeSegment encodes bytes to base64url without padding. +func (t *Token) EncodeSegment(seg []byte) string { + return base64.RawURLEncoding.EncodeToString(seg) +}