diff --git a/auth_jwt.go b/auth_jwt.go index 35db57b..5bd46c9 100644 --- a/auth_jwt.go +++ b/auth_jwt.go @@ -2,7 +2,6 @@ package jwt import ( "crypto/rsa" - "encoding/json" "errors" "net/http" "os" @@ -10,7 +9,7 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" ) // MapClaims type that uses the map[string]interface{} for JSON decoding @@ -155,7 +154,8 @@ type GinJWTMiddleware struct { // CookieSameSite allow use http.SameSite cookie param CookieSameSite http.SameSite - // ParseOptions allow to modify jwt's parser methods + // ParseOptions allow to modify jwt's parser methods. + // WithTimeFunc is always added to ensure the TimeFunc is propagated to the validator ParseOptions []jwt.ParserOption } @@ -414,6 +414,12 @@ func (mw *GinJWTMiddleware) MiddlewareInit() error { if mw.Key == nil { return ErrMissingSecretKey } + + if len(mw.ParseOptions) == 0 { + mw.ParseOptions = []jwt.ParserOption{} + } + mw.ParseOptions = append(mw.ParseOptions, jwt.WithTimeFunc(mw.TimeFunc)) + return nil } @@ -427,31 +433,24 @@ func (mw *GinJWTMiddleware) MiddlewareFunc() gin.HandlerFunc { func (mw *GinJWTMiddleware) middlewareImpl(c *gin.Context) { claims, err := mw.GetClaimsFromJWT(c) if err != nil { - mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(err, c)) - return - } - - switch v := claims["exp"].(type) { - case nil: - mw.unauthorized(c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrMissingExpField, c)) - return - case float64: - if int64(v) < mw.TimeFunc().Unix() { + if errors.Is(err, jwt.ErrTokenExpired) { mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrExpiredToken, c)) return - } - case json.Number: - n, err := v.Int64() - if err != nil { + } else if errors.Is(err, jwt.ErrInvalidType) && strings.Contains(err.Error(), "exp is invalid") { mw.unauthorized(c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrWrongFormatOfExp, c)) return - } - if n < mw.TimeFunc().Unix() { - mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(ErrExpiredToken, c)) + } else if errors.Is(err, jwt.ErrTokenRequiredClaimMissing) && strings.Contains(err.Error(), "exp claim is required") { + mw.unauthorized(c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrMissingExpField, c)) + return + } else { + mw.unauthorized(c, http.StatusUnauthorized, mw.HTTPStatusMessageFunc(err, c)) return } - default: - mw.unauthorized(c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrWrongFormatOfExp, c)) + } + + // For backwards compatibility since technically exp is not required in the spec but has been in gin-jwt + if claims["exp"] == nil { + mw.unauthorized(c, http.StatusBadRequest, mw.HTTPStatusMessageFunc(ErrMissingExpField, c)) return } @@ -607,16 +606,13 @@ func (mw *GinJWTMiddleware) RefreshToken(c *gin.Context) (string, time.Time, err // CheckIfTokenExpire check if token expire func (mw *GinJWTMiddleware) CheckIfTokenExpire(c *gin.Context) (jwt.MapClaims, error) { token, err := mw.ParseToken(c) - if err != nil { - // If we receive an error, and the error is anything other than a single - // ValidationErrorExpired, we want to return the error. - // If the error is just ValidationErrorExpired, we want to continue, as we can still - // refresh the token if it's within the MaxRefresh time. - // (see https://github.com/appleboy/gin-jwt/issues/176) - validationErr, ok := err.(*jwt.ValidationError) - if !ok || validationErr.Errors != jwt.ValidationErrorExpired { - return nil, err - } + // If we receive an error, and the error is anything other than a single + // ErrTokenExpired, we want to return the error. + // If the error is just ErrTokenExpired, we want to continue, as we can still + // refresh the token if it's within the MaxRefresh time. + // (see https://github.com/appleboy/gin-jwt/issues/176) + if err != nil && !errors.Is(err, jwt.ErrTokenExpired) { + return nil, err } claims := token.Claims.(jwt.MapClaims) diff --git a/auth_jwt_test.go b/auth_jwt_test.go index 1abd564..1107444 100644 --- a/auth_jwt_test.go +++ b/auth_jwt_test.go @@ -14,7 +14,7 @@ import ( "github.com/appleboy/gofight/v2" "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) @@ -1228,7 +1228,7 @@ func TestExpiredField(t *testing.T) { }) // wrong format - claims["exp"] = "wrongFormatForExpiryIgnoredByJwtLibrary" + claims["exp"] = "wrongFormatForExpiry" tokenString, _ = token.SignedString(key) r.GET("/auth/hello"). @@ -1238,8 +1238,55 @@ func TestExpiredField(t *testing.T) { Run(handler, func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { message := gjson.Get(r.Body.String(), "message") - assert.Equal(t, ErrExpiredToken.Error(), strings.ToLower(message.String())) - assert.Equal(t, http.StatusUnauthorized, r.Code) + assert.Equal(t, ErrWrongFormatOfExp.Error(), strings.ToLower(message.String())) + assert.Equal(t, http.StatusBadRequest, r.Code) + }) +} + +func TestExpiredFieldRequiredParserOption(t *testing.T) { + // the middleware to test + authMiddleware, _ := New(&GinJWTMiddleware{ + Realm: "test zone", + Key: key, + Timeout: time.Hour, + Authenticator: defaultAuthenticator, + ParseOptions: []jwt.ParserOption{jwt.WithExpirationRequired()}, + }) + + handler := ginHandler(authMiddleware) + + r := gofight.New() + + token := jwt.New(jwt.GetSigningMethod("HS256")) + claims := token.Claims.(jwt.MapClaims) + claims["identity"] = "admin" + claims["orig_iat"] = 0 + tokenString, _ := token.SignedString(key) + + r.GET("/auth/hello"). + SetHeader(gofight.H{ + "Authorization": "Bearer " + tokenString, + }). + Run(handler, func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + message := gjson.Get(r.Body.String(), "message") + + assert.Equal(t, ErrMissingExpField.Error(), message.String()) + assert.Equal(t, http.StatusBadRequest, r.Code) + }) + + // wrong format + claims["exp"] = "wrongFormatForExpiry" + tokenString, _ = token.SignedString(key) + + r.GET("/auth/hello"). + SetHeader(gofight.H{ + "Authorization": "Bearer " + tokenString, + }). + Run(handler, func(r gofight.HTTPResponse, rq gofight.HTTPRequest) { + message := gjson.Get(r.Body.String(), "message") + + assert.Equal(t, ErrWrongFormatOfExp.Error(), strings.ToLower(message.String())) + assert.Equal(t, http.StatusBadRequest, r.Code) }) } diff --git a/go.mod b/go.mod index 724854c..b7fe3d0 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.21.0 require ( github.com/appleboy/gofight/v2 v2.1.2 github.com/gin-gonic/gin v1.10.0 - github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/golang-jwt/jwt/v5 v5.2.0 github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.17.1 ) diff --git a/go.sum b/go.sum index 51f24a4..95725de 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,8 @@ github.com/go-playground/validator/v10 v10.22.0 h1:k6HsTZ0sTnROkhS//R0O+55JgM8C4 github.com/go-playground/validator/v10 v10.22.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= -github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= +github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=