From caed2227403c53d92af7c19254a6cfa04dc107ed Mon Sep 17 00:00:00 2001 From: raisul Date: Sun, 19 Jan 2025 02:03:58 +0600 Subject: [PATCH 1/6] refactor: improve readability and reusability of request method validation --- const.go | 15 +++++++++++++++ errors/response.go | 1 + server/config.go | 17 +++++++++++------ server/server.go | 15 ++++++++++++--- 4 files changed, 39 insertions(+), 9 deletions(-) diff --git a/const.go b/const.go index 193e839..4f1589f 100644 --- a/const.go +++ b/const.go @@ -74,3 +74,18 @@ func (ccm CodeChallengeMethod) Validate(cc, ver string) bool { return false } } + +// AuthorizeRequestMethod the type of authorization request method +type AuthorizeRequestMethod string + +const ( + AuthorizeRequestGet AuthorizeRequestMethod = "GET" + AuthorizeRequestPost AuthorizeRequestMethod = "POST" +) + +func (ar AuthorizeRequestMethod) String() string { + if ar == AuthorizeRequestGet || ar == AuthorizeRequestPost { + return string(ar) + } + return "" +} diff --git a/errors/response.go b/errors/response.go index c8d5902..7a16989 100644 --- a/errors/response.go +++ b/errors/response.go @@ -35,6 +35,7 @@ func (r *Response) SetHeader(key, value string) { // https://tools.ietf.org/html/rfc6749#section-5.2 var ( ErrInvalidRequest = errors.New("invalid_request") + ErrInvalidRequestMethod = errors.New("invalid_request_method") ErrUnauthorizedClient = errors.New("unauthorized_client") ErrAccessDenied = errors.New("access_denied") ErrUnsupportedResponseType = errors.New("unsupported_response_type") diff --git a/server/config.go b/server/config.go index 3bbb884..e9bad11 100644 --- a/server/config.go +++ b/server/config.go @@ -9,12 +9,13 @@ import ( // Config configuration parameters type Config struct { - TokenType string // token type - AllowGetAccessRequest bool // to allow GET requests for the token - AllowedResponseTypes []oauth2.ResponseType // allow the authorization type - AllowedGrantTypes []oauth2.GrantType // allow the grant type - AllowedCodeChallengeMethods []oauth2.CodeChallengeMethod - ForcePKCE bool + TokenType string // token type + AllowGetAccessRequest bool // to allow GET requests for the token + AllowedResponseTypes []oauth2.ResponseType // allow the authorization type + AllowedGrantTypes []oauth2.GrantType // allow the grant type + AllowedCodeChallengeMethods []oauth2.CodeChallengeMethod + AllowedAuthorizeRequestMethods []oauth2.AuthorizeRequestMethod //allowed `authorize request methods` + ForcePKCE bool } // NewConfig create to configuration instance @@ -32,6 +33,10 @@ func NewConfig() *Config { oauth2.CodeChallengePlain, oauth2.CodeChallengeS256, }, + AllowedAuthorizeRequestMethods: []oauth2.AuthorizeRequestMethod{ + oauth2.AuthorizeRequestGet, + oauth2.AuthorizeRequestPost, + }, } } diff --git a/server/server.go b/server/server.go index df19d1f..a3ba08b 100755 --- a/server/server.go +++ b/server/server.go @@ -163,13 +163,22 @@ func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool { return false } +// CheckAuthorizeRequestMethod checks for allowed code challenge method +func (s *Server) CheckAuthorizeRequestMethod(requestMethod oauth2.AuthorizeRequestMethod) bool { + for _, method := range s.Config.AllowedAuthorizeRequestMethods { + if method == requestMethod { + return true + } + } + return false +} + // ValidationAuthorizeRequest the authorization request validation func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { redirectURI := r.FormValue("redirect_uri") clientID := r.FormValue("client_id") - if !(r.Method == "GET" || r.Method == "POST") || - clientID == "" { - return nil, errors.ErrInvalidRequest + if isMethodAllowed := s.CheckAuthorizeRequestMethod(oauth2.AuthorizeRequestMethod(r.Method)); !isMethodAllowed || clientID == "" { + return nil, errors.ErrInvalidRequestMethod } resType := oauth2.ResponseType(r.FormValue("response_type")) From 8dc1be6ca0daba88e36f336fc06ee4471a56df03 Mon Sep 17 00:00:00 2001 From: raisul Date: Sun, 19 Jan 2025 17:12:30 +0600 Subject: [PATCH 2/6] refactor: new error messages for better understanding --- errors/response.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/errors/response.go b/errors/response.go index 7a16989..b8e8eef 100644 --- a/errors/response.go +++ b/errors/response.go @@ -39,6 +39,8 @@ var ( ErrUnauthorizedClient = errors.New("unauthorized_client") ErrAccessDenied = errors.New("access_denied") ErrUnsupportedResponseType = errors.New("unsupported_response_type") + ErrUnauthorizedResponseType = errors.New("unauthorized_response_type") + ErrMissingResponseType = errors.New("missing_response_type") ErrInvalidScope = errors.New("invalid_scope") ErrServerError = errors.New("server_error") ErrTemporarilyUnavailable = errors.New("temporarily_unavailable") From cd2d32be8545985e6669b557c5c477d64c2af32f Mon Sep 17 00:00:00 2001 From: raisul Date: Sun, 19 Jan 2025 17:20:05 +0600 Subject: [PATCH 3/6] refactor: simplify response type validation logic --- server/server.go | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/server/server.go b/server/server.go index a3ba08b..1a15428 100755 --- a/server/server.go +++ b/server/server.go @@ -143,14 +143,19 @@ func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface return u.String(), nil } -// CheckResponseType check allows response type -func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { - for _, art := range s.Config.AllowedResponseTypes { - if art == rt { - return true +// CheckResponseType checks for an allowed response type +func (s *Server) CheckResponseType(responseType oauth2.ResponseType) error { + if responseType.String() == "" { + return errors.ErrMissingResponseType + } + + for _, rType := range s.Config.AllowedResponseTypes { + if rType == responseType { + return nil } } - return false + + return errors.ErrUnsupportedResponseType } // CheckCodeChallengeMethod checks for allowed code challenge method @@ -177,19 +182,18 @@ func (s *Server) CheckAuthorizeRequestMethod(requestMethod oauth2.AuthorizeReque func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { redirectURI := r.FormValue("redirect_uri") clientID := r.FormValue("client_id") + if isMethodAllowed := s.CheckAuthorizeRequestMethod(oauth2.AuthorizeRequestMethod(r.Method)); !isMethodAllowed || clientID == "" { return nil, errors.ErrInvalidRequestMethod } - resType := oauth2.ResponseType(r.FormValue("response_type")) - if resType.String() == "" { - return nil, errors.ErrUnsupportedResponseType - } else if allowed := s.CheckResponseType(resType); !allowed { - return nil, errors.ErrUnauthorizedClient + responseType := oauth2.ResponseType(r.FormValue("response_type")) + if err := s.CheckResponseType(responseType); err != nil { + return nil, err } cc := r.FormValue("code_challenge") - if cc == "" && s.Config.ForcePKCE { + if s.Config.ForcePKCE || cc == "" { return nil, errors.ErrCodeChallengeRquired } if cc != "" && (len(cc) < 43 || len(cc) > 128) { @@ -207,7 +211,7 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, req := &AuthorizeRequest{ RedirectURI: redirectURI, - ResponseType: resType, + ResponseType: responseType, ClientID: clientID, State: r.FormValue("state"), Scope: r.FormValue("scope"), From 0ec0f1348c2c87b8178c2cfaf9ca593b7f1876c1 Mon Sep 17 00:00:00 2001 From: raisul Date: Sun, 19 Jan 2025 17:34:16 +0600 Subject: [PATCH 4/6] refactor: unnecessary logic cleared for code challenge validation and new error message updated --- errors/response.go | 9 +++++---- server/server.go | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/errors/response.go b/errors/response.go index b8e8eef..d2e8697 100644 --- a/errors/response.go +++ b/errors/response.go @@ -39,7 +39,6 @@ var ( ErrUnauthorizedClient = errors.New("unauthorized_client") ErrAccessDenied = errors.New("access_denied") ErrUnsupportedResponseType = errors.New("unsupported_response_type") - ErrUnauthorizedResponseType = errors.New("unauthorized_response_type") ErrMissingResponseType = errors.New("missing_response_type") ErrInvalidScope = errors.New("invalid_scope") ErrServerError = errors.New("server_error") @@ -47,7 +46,7 @@ var ( ErrInvalidClient = errors.New("invalid_client") ErrInvalidGrant = errors.New("invalid_grant") ErrUnsupportedGrantType = errors.New("unsupported_grant_type") - ErrCodeChallengeRquired = errors.New("invalid_request") + ErrCodeChallengeRequired = errors.New("invalid_request") ErrUnsupportedCodeChallengeMethod = errors.New("invalid_request") ErrInvalidCodeChallengeLen = errors.New("invalid_request") ) @@ -58,13 +57,14 @@ var Descriptions = map[error]string{ ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method", ErrAccessDenied: "The resource owner or authorization server denied the request", ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method", + ErrMissingResponseType: "The requested response type is empty", ErrInvalidScope: "The requested scope is invalid, unknown, or malformed", ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request", ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server", ErrInvalidClient: "Client authentication failed", ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client", ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server", - ErrCodeChallengeRquired: "PKCE is required. code_challenge is missing", + ErrCodeChallengeRequired: "PKCE is required. code_challenge is missing", ErrUnsupportedCodeChallengeMethod: "Selected code_challenge_method not supported", ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 charachters long", } @@ -75,13 +75,14 @@ var StatusCodes = map[error]int{ ErrUnauthorizedClient: 401, ErrAccessDenied: 403, ErrUnsupportedResponseType: 401, + ErrMissingResponseType: 400, ErrInvalidScope: 400, ErrServerError: 500, ErrTemporarilyUnavailable: 503, ErrInvalidClient: 401, ErrInvalidGrant: 401, ErrUnsupportedGrantType: 401, - ErrCodeChallengeRquired: 400, + ErrCodeChallengeRequired: 400, ErrUnsupportedCodeChallengeMethod: 400, ErrInvalidCodeChallengeLen: 400, } diff --git a/server/server.go b/server/server.go index 1a15428..3d6fe89 100755 --- a/server/server.go +++ b/server/server.go @@ -194,9 +194,9 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, cc := r.FormValue("code_challenge") if s.Config.ForcePKCE || cc == "" { - return nil, errors.ErrCodeChallengeRquired + return nil, errors.ErrCodeChallengeRequired } - if cc != "" && (len(cc) < 43 || len(cc) > 128) { + if len(cc) < 43 || len(cc) > 128 { return nil, errors.ErrInvalidCodeChallengeLen } From 3ca761c7ec838915157b7578b2eb72886773adbc Mon Sep 17 00:00:00 2001 From: raisul Date: Sun, 19 Jan 2025 18:05:34 +0600 Subject: [PATCH 5/6] refactor: simplified and updated code challenge and code challenge validation logic --- errors/response.go | 7 ++++++- server/server.go | 49 ++++++++++++++++++++++++++++++---------------- 2 files changed, 38 insertions(+), 18 deletions(-) diff --git a/errors/response.go b/errors/response.go index d2e8697..d53ac19 100644 --- a/errors/response.go +++ b/errors/response.go @@ -35,6 +35,7 @@ func (r *Response) SetHeader(key, value string) { // https://tools.ietf.org/html/rfc6749#section-5.2 var ( ErrInvalidRequest = errors.New("invalid_request") + ErrMissingClientID = errors.New("missing_client_id") ErrInvalidRequestMethod = errors.New("invalid_request_method") ErrUnauthorizedClient = errors.New("unauthorized_client") ErrAccessDenied = errors.New("access_denied") @@ -54,6 +55,8 @@ var ( // Descriptions error description var Descriptions = map[error]string{ ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed", + ErrMissingClientID: "The request is missing client_id", + ErrInvalidRequestMethod: "The request method is invalid, unknown, or malformed", ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method", ErrAccessDenied: "The resource owner or authorization server denied the request", ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method", @@ -66,12 +69,14 @@ var Descriptions = map[error]string{ ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server", ErrCodeChallengeRequired: "PKCE is required. code_challenge is missing", ErrUnsupportedCodeChallengeMethod: "Selected code_challenge_method not supported", - ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 charachters long", + ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 characters long", } // StatusCodes response error HTTP status code var StatusCodes = map[error]int{ ErrInvalidRequest: 400, + ErrMissingClientID: 400, + ErrInvalidRequestMethod: 400, ErrUnauthorizedClient: 401, ErrAccessDenied: 403, ErrUnsupportedResponseType: 401, diff --git a/server/server.go b/server/server.go index 3d6fe89..726d56f 100755 --- a/server/server.go +++ b/server/server.go @@ -178,12 +178,31 @@ func (s *Server) CheckAuthorizeRequestMethod(requestMethod oauth2.AuthorizeReque return false } +// CheckCodeChallenge checks if the Code Challenge is valid +func (s *Server) CheckCodeChallenge(codeChallenge string, isForcePKCE bool) error { + if isForcePKCE && codeChallenge == "" { + return errors.ErrCodeChallengeRequired + } + if len(codeChallenge) < 43 || len(codeChallenge) > 128 { + return errors.ErrInvalidCodeChallengeLen + } + return nil +} + // ValidationAuthorizeRequest the authorization request validation func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { + if r == nil { + return nil, errors.ErrInvalidRequest + } + redirectURI := r.FormValue("redirect_uri") + clientID := r.FormValue("client_id") + if clientID == "" { + return nil, errors.ErrMissingClientID + } - if isMethodAllowed := s.CheckAuthorizeRequestMethod(oauth2.AuthorizeRequestMethod(r.Method)); !isMethodAllowed || clientID == "" { + if isMethodAllowed := s.CheckAuthorizeRequestMethod(oauth2.AuthorizeRequestMethod(r.Method)); !isMethodAllowed { return nil, errors.ErrInvalidRequestMethod } @@ -192,34 +211,30 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, return nil, err } - cc := r.FormValue("code_challenge") - if s.Config.ForcePKCE || cc == "" { - return nil, errors.ErrCodeChallengeRequired - } - if len(cc) < 43 || len(cc) > 128 { - return nil, errors.ErrInvalidCodeChallengeLen + codeChallenge := r.FormValue("code_challenge") + if err := s.CheckCodeChallenge(codeChallenge, s.Config.ForcePKCE); err != nil { + return nil, err } - ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method")) - // set default - if ccm == "" { - ccm = oauth2.CodeChallengePlain + codeChallengeMethod := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method")) + // Default to plain method if not specified + if codeChallengeMethod == "" { + codeChallengeMethod = oauth2.CodeChallengePlain } - if ccm != "" && !s.CheckCodeChallengeMethod(ccm) { + if !s.CheckCodeChallengeMethod(codeChallengeMethod) { return nil, errors.ErrUnsupportedCodeChallengeMethod } - req := &AuthorizeRequest{ + return &AuthorizeRequest{ RedirectURI: redirectURI, ResponseType: responseType, ClientID: clientID, State: r.FormValue("state"), Scope: r.FormValue("scope"), Request: r, - CodeChallenge: cc, - CodeChallengeMethod: ccm, - } - return req, nil + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + }, nil } // GetAuthorizeToken get authorization token(code) From 21281bceed0121c6e336273d3cf970f28fec0d1b Mon Sep 17 00:00:00 2001 From: raisul Date: Sun, 19 Jan 2025 18:23:45 +0600 Subject: [PATCH 6/6] ensured passing all test cases --- server/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/server.go b/server/server.go index 726d56f..5bddddf 100755 --- a/server/server.go +++ b/server/server.go @@ -183,7 +183,7 @@ func (s *Server) CheckCodeChallenge(codeChallenge string, isForcePKCE bool) erro if isForcePKCE && codeChallenge == "" { return errors.ErrCodeChallengeRequired } - if len(codeChallenge) < 43 || len(codeChallenge) > 128 { + if len(codeChallenge) > 0 && len(codeChallenge) < 43 || len(codeChallenge) > 128 { return errors.ErrInvalidCodeChallengeLen } return nil