diff --git a/pkg/apiToken/ApiTokenBean.go b/pkg/apiToken/ApiTokenBean.go new file mode 100644 index 0000000000..c193bccc9c --- /dev/null +++ b/pkg/apiToken/ApiTokenBean.go @@ -0,0 +1,9 @@ +package apiToken + +import "github.com/golang-jwt/jwt/v4" + +type ApiTokenCustomClaims struct { + Email string `json:"email"` + Version string `json:"version"` + jwt.RegisteredClaims +} diff --git a/pkg/apiToken/ApiTokenRepository.go b/pkg/apiToken/ApiTokenRepository.go index 83555fd0c5..7df3eb816d 100644 --- a/pkg/apiToken/ApiTokenRepository.go +++ b/pkg/apiToken/ApiTokenRepository.go @@ -18,6 +18,7 @@ package apiToken import ( + "fmt" "github.com/devtron-labs/devtron/pkg/auth/user/repository" "github.com/devtron-labs/devtron/pkg/sql" "github.com/go-pg/pg" @@ -29,6 +30,7 @@ type ApiToken struct { Id int `sql:"id,pk"` UserId int32 `sql:"user_id, notnull"` Name string `sql:"name, notnull"` + Version int `sql:"version, notnull"` Description string `sql:"description, notnull"` ExpireAtInMs int64 `sql:"expire_at_in_ms"` Token string `sql:"token, notnull"` @@ -42,6 +44,7 @@ type ApiTokenRepository interface { FindAllActive() ([]*ApiToken, error) FindActiveById(id int) (*ApiToken, error) FindByName(name string) (*ApiToken, error) + UpdateIf(apiToken *ApiToken, previousTokenVersion int) error } type ApiTokenRepositoryImpl struct { @@ -60,6 +63,20 @@ func (impl ApiTokenRepositoryImpl) Update(apiToken *ApiToken) error { return impl.dbConnection.Update(apiToken) } +func (impl ApiTokenRepositoryImpl) UpdateIf(apiToken *ApiToken, previousTokenVersion int) error { + res, err := impl.dbConnection.Model(apiToken). + Where("id = ?", apiToken.Id). + Where("version = ?", previousTokenVersion). + Update() + if err != nil { + return err + } + if res.RowsAffected() == 0 { + return fmt.Errorf(TokenVersionMismatch) + } + return nil +} + func (impl ApiTokenRepositoryImpl) FindAllActive() ([]*ApiToken, error) { var apiTokens []*ApiToken err := impl.dbConnection.Model(&apiTokens). diff --git a/pkg/apiToken/ApiTokenService.go b/pkg/apiToken/ApiTokenService.go index 8a271c065f..28a5d5521a 100644 --- a/pkg/apiToken/ApiTokenService.go +++ b/pkg/apiToken/ApiTokenService.go @@ -20,6 +20,7 @@ package apiToken import ( "errors" "fmt" + userBean "github.com/devtron-labs/devtron/pkg/auth/user/bean" "regexp" "strconv" "strings" @@ -62,14 +63,13 @@ func NewApiTokenServiceImpl(logger *zap.SugaredLogger, apiTokenSecretService Api } } -const API_TOKEN_USER_EMAIL_PREFIX = "API-TOKEN:" - var invalidCharsInApiTokenName = regexp.MustCompile("[,\\s]") -type ApiTokenCustomClaims struct { - Email string `json:"email"` - jwt.RegisteredClaims -} +const ( + ConcurrentTokenUpdateRequest = "there is an ongoing request for the token with the same name, please try again after some time" + UniqueKeyViolationPgErrorCode = 23505 + TokenVersionMismatch = "token version mismatch" +) func (impl ApiTokenServiceImpl) GetAllApiTokensForWebhook(projectName string, environmentName string, appName string, auth func(token string, projectObject string, envObject string) bool) ([]*openapi.ApiToken, error) { impl.logger.Info("Getting active api tokens") @@ -181,11 +181,21 @@ func (impl ApiTokenServiceImpl) CreateApiToken(request *openapi.CreateApiTokenRe impl.logger.Info(fmt.Sprintf("apiTokenExists : %s", strconv.FormatBool(apiTokenExists))) - // step-2 - Build email - email := fmt.Sprintf("%s%s", API_TOKEN_USER_EMAIL_PREFIX, name) + // step-2 - Build email and version + email := fmt.Sprintf("%s%s", userBean.API_TOKEN_USER_EMAIL_PREFIX, name) + var ( + tokenVersion int + previousTokenVersion int + ) + if apiTokenExists { + tokenVersion = apiToken.Version + 1 + previousTokenVersion = apiToken.Version + } else { + tokenVersion = 1 + } // step-3 - Build token - token, err := impl.createApiJwtToken(email, *request.ExpireAtInMs) + token, err := impl.createApiJwtToken(email, tokenVersion, *request.ExpireAtInMs) if err != nil { return nil, err } @@ -214,6 +224,7 @@ func (impl ApiTokenServiceImpl) CreateApiToken(request *openapi.CreateApiTokenRe Description: *request.Description, ExpireAtInMs: *request.ExpireAtInMs, Token: token, + Version: tokenVersion, AuditLog: sql.AuditLog{UpdatedOn: time.Now()}, } if apiTokenExists { @@ -221,7 +232,9 @@ func (impl ApiTokenServiceImpl) CreateApiToken(request *openapi.CreateApiTokenRe apiTokenSaveRequest.CreatedBy = apiToken.CreatedBy apiTokenSaveRequest.CreatedOn = apiToken.CreatedOn apiTokenSaveRequest.UpdatedBy = createdBy - err = impl.apiTokenRepository.Update(apiTokenSaveRequest) + // update api-token only if `previousTokenVersion` is same as version stored in DB + // we are checking this to ensure that two users are not updating the same token at the same time + err = impl.apiTokenRepository.UpdateIf(apiTokenSaveRequest, previousTokenVersion) } else { apiTokenSaveRequest.CreatedBy = createdBy apiTokenSaveRequest.CreatedOn = time.Now() @@ -229,6 +242,19 @@ func (impl ApiTokenServiceImpl) CreateApiToken(request *openapi.CreateApiTokenRe } if err != nil { impl.logger.Errorw("error while saving api-token into DB", "error", err) + // fetching error code from pg error for Unique key violation constraint + // in case of save + pgErr, ok := err.(pg.Error) + if ok { + errCode, conversionErr := strconv.Atoi(pgErr.Field('C')) + if conversionErr == nil && errCode == UniqueKeyViolationPgErrorCode { + return nil, fmt.Errorf(ConcurrentTokenUpdateRequest) + } + } + // in case of update + if errors.Is(err, fmt.Errorf(TokenVersionMismatch)) { + return nil, fmt.Errorf(ConcurrentTokenUpdateRequest) + } return nil, err } @@ -254,14 +280,18 @@ func (impl ApiTokenServiceImpl) UpdateApiToken(apiTokenId int, request *openapi. return nil, errors.New(fmt.Sprintf("api-token corresponds to apiTokenId '%d' is not found", apiTokenId)) } + previousTokenVersion := apiToken.Version + tokenVersion := apiToken.Version + 1 + // step-2 - If expires_at is not same, then token needs to be generated again if *request.ExpireAtInMs != apiToken.ExpireAtInMs { // regenerate token - token, err := impl.createApiJwtToken(apiToken.User.EmailId, *request.ExpireAtInMs) + token, err := impl.createApiJwtToken(apiToken.User.EmailId, tokenVersion, *request.ExpireAtInMs) if err != nil { return nil, err } apiToken.Token = token + apiToken.Version = tokenVersion } // step-3 - update in DB @@ -269,7 +299,9 @@ func (impl ApiTokenServiceImpl) UpdateApiToken(apiTokenId int, request *openapi. apiToken.ExpireAtInMs = *request.ExpireAtInMs apiToken.UpdatedBy = updatedBy apiToken.UpdatedOn = time.Now() - err = impl.apiTokenRepository.Update(apiToken) + // update api-token only if `previousTokenVersion` is same as version stored in DB + // we are checking this to ensure that two users are not updating the same token at the same time + err = impl.apiTokenRepository.UpdateIf(apiToken, previousTokenVersion) if err != nil { impl.logger.Errorw("error while updating api-token", "apiTokenId", apiTokenId, "error", err) return nil, err @@ -322,24 +354,41 @@ func (impl ApiTokenServiceImpl) DeleteApiToken(apiTokenId int, deletedBy int32) } -func (impl ApiTokenServiceImpl) createApiJwtToken(email string, expireAtInMs int64) (string, error) { +func (impl ApiTokenServiceImpl) createApiJwtToken(email string, tokenVersion int, expireAtInMs int64) (string, error) { + registeredClaims, secretByteArr, err := impl.setRegisteredClaims(expireAtInMs) + if err != nil { + return "", err + } + claims := &ApiTokenCustomClaims{ + email, + strconv.Itoa(tokenVersion), + registeredClaims, + } + token, err := impl.generateToken(claims, secretByteArr) + if err != nil { + return "", err + } + return token, nil +} + +func (impl ApiTokenServiceImpl) setRegisteredClaims(expireAtInMs int64) (jwt.RegisteredClaims, []byte, error) { secretByteArr, err := impl.apiTokenSecretService.GetApiTokenSecretByteArr() if err != nil { impl.logger.Errorw("error while getting api token secret", "error", err) - return "", err + return jwt.RegisteredClaims{}, secretByteArr, err } registeredClaims := jwt.RegisteredClaims{ Issuer: middleware.ApiTokenClaimIssuer, } + if expireAtInMs > 0 { registeredClaims.ExpiresAt = jwt.NewNumericDate(time.Unix(expireAtInMs/1000, 0)) } + return registeredClaims, secretByteArr, nil +} - claims := &ApiTokenCustomClaims{ - email, - registeredClaims, - } +func (impl ApiTokenServiceImpl) generateToken(claims *ApiTokenCustomClaims, secretByteArr []byte) (string, error) { unsignedToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token, err := unsignedToken.SignedString(secretByteArr) if err != nil { diff --git a/pkg/auth/user/UserAuthService.go b/pkg/auth/user/UserAuthService.go index dd00172da2..bfc82428a0 100644 --- a/pkg/auth/user/UserAuthService.go +++ b/pkg/auth/user/UserAuthService.go @@ -31,7 +31,7 @@ import ( "github.com/devtron-labs/authenticator/middleware" casbin2 "github.com/devtron-labs/devtron/pkg/auth/authorisation/casbin" - bean2 "github.com/devtron-labs/devtron/pkg/auth/user/bean" + userBean "github.com/devtron-labs/devtron/pkg/auth/user/bean" "github.com/devtron-labs/devtron/pkg/auth/user/repository" "github.com/go-pg/pg" @@ -473,7 +473,7 @@ func (impl UserAuthServiceImpl) AuthVerification(r *http.Request) (bool, error) } return false, err } - emailId, err := impl.userService.GetEmailFromToken(token) + emailId, version, err := impl.userService.GetEmailAndVersionFromToken(token) if err != nil { impl.logger.Errorw("AuthVerification failed ", "error", err) return false, err @@ -488,22 +488,33 @@ func (impl UserAuthServiceImpl) AuthVerification(r *http.Request) (bool, error) } return false, err } + // checking length of version, to ensure backward compatibility as earlier we did not + // have version for api-tokens + // therefore, for tokens without version we will skip the below part + if strings.HasPrefix(emailId, userBean.API_TOKEN_USER_EMAIL_PREFIX) && len(version) > 0 { + err := impl.userService.CheckIfTokenIsValid(emailId, version) + if err != nil { + impl.logger.Errorw("token is not valid", "error", err, "token", token) + return false, err + } + } //TODO - extends for other purpose return true, nil } + func (impl UserAuthServiceImpl) DeleteRoles(entityType string, entityName string, tx *pg.Tx, envIdentifier string, workflowName string) (err error) { var roleModels []*repository.RoleModel switch entityType { - case bean2.PROJECT_TYPE: + case userBean.PROJECT_TYPE: roleModels, err = impl.userAuthRepository.GetRolesForProject(entityName) - case bean2.ENV_TYPE: + case userBean.ENV_TYPE: roleModels, err = impl.userAuthRepository.GetRolesForEnvironment(entityName, envIdentifier) - case bean2.APP_TYPE: + case userBean.APP_TYPE: roleModels, err = impl.userAuthRepository.GetRolesForApp(entityName) - case bean2.CHART_GROUP_TYPE: + case userBean.CHART_GROUP_TYPE: roleModels, err = impl.userAuthRepository.GetRolesForChartGroup(entityName) - case bean2.WorkflowType: + case userBean.WorkflowType: roleModels, err = impl.userAuthRepository.GetRolesForWorkflow(workflowName, entityName) } if err != nil { diff --git a/pkg/auth/user/UserService.go b/pkg/auth/user/UserService.go index 7dec4ae8dd..1b5a75edf9 100644 --- a/pkg/auth/user/UserService.go +++ b/pkg/auth/user/UserService.go @@ -21,9 +21,10 @@ import ( "context" "fmt" "github.com/devtron-labs/devtron/pkg/auth/user/adapter" - helper2 "github.com/devtron-labs/devtron/pkg/auth/user/helper" + userHelper "github.com/devtron-labs/devtron/pkg/auth/user/helper" "github.com/devtron-labs/devtron/pkg/auth/user/repository/helper" "net/http" + "strconv" "strings" "sync" "time" @@ -58,6 +59,7 @@ type UserService interface { GetAllWithFilters(request *bean.ListingRequest) (*bean.UserListingResponse, error) GetAllDetailedUsers() ([]bean.UserInfo, error) GetEmailFromToken(token string) (string, error) + GetEmailAndVersionFromToken(token string) (string, string, error) GetEmailById(userId int32) (string, error) GetLoggedInUser(r *http.Request) (int32, error) GetByIds(ids []int32) ([]bean.UserInfo, error) @@ -72,6 +74,7 @@ type UserService interface { UpdateTriggerPolicyForTerminalAccess() (err error) GetRoleFiltersByUserRoleGroups(userRoleGroups []bean.UserRoleGroup) ([]bean.RoleFilter, error) SaveLoginAudit(emailId, clientIp string, id int32) + CheckIfTokenIsValid(email string, version string) error } type UserServiceImpl struct { @@ -1222,7 +1225,7 @@ func (impl *UserServiceImpl) GetLoggedInUser(r *http.Request) (int32, error) { func (impl *UserServiceImpl) GetUserByToken(context context.Context, token string) (int32, string, error) { _, span := otel.Tracer("userService").Start(context, "GetUserByToken") - email, err := impl.GetEmailFromToken(token) + email, version, err := impl.GetEmailAndVersionFromToken(token) span.End() if err != nil { return http.StatusUnauthorized, "", err @@ -1237,9 +1240,35 @@ func (impl *UserServiceImpl) GetUserByToken(context context.Context, token strin } return http.StatusUnauthorized, "", err } + // checking length of version, to ensure backward compatibility as earlier we did not + // have version for api-tokens + // therefore, for tokens without version we will skip the below part + if userInfo.UserType == bean.USER_TYPE_API_TOKEN && len(version) > 0 { + err := impl.CheckIfTokenIsValid(email, version) + if err != nil { + impl.logger.Errorw("token is not valid", "error", err, "token", token) + return http.StatusUnauthorized, "", err + } + } return userInfo.Id, userInfo.UserType, nil } +func (impl *UserServiceImpl) CheckIfTokenIsValid(email string, version string) error { + tokenName := userHelper.ExtractTokenNameFromEmail(email) + embeddedTokenVersion, _ := strconv.Atoi(version) + isProvidedTokenValid, err := impl.userRepository.CheckIfTokenExistsByTokenNameAndVersion(tokenName, embeddedTokenVersion) + if err != nil || !isProvidedTokenValid { + err := &util.ApiError{ + HttpStatusCode: http.StatusUnauthorized, + Code: constants.UserNotFoundForToken, + InternalMessage: "user not found for token", + UserMessage: fmt.Sprintf("no user found against provided token"), + } + return err + } + return nil +} + func (impl *UserServiceImpl) GetEmailFromToken(token string) (string, error) { if token == "" { impl.logger.Infow("no token provided") @@ -1283,6 +1312,50 @@ func (impl *UserServiceImpl) GetEmailFromToken(token string) (string, error) { return email, nil } +func (impl *UserServiceImpl) GetEmailAndVersionFromToken(token string) (string, string, error) { + if token == "" { + impl.logger.Infow("no token provided") + err := &util.ApiError{ + Code: constants.UserNoTokenProvided, + InternalMessage: "no token provided", + } + return "", "", err + } + + claims, err := impl.sessionManager2.VerifyToken(token) + + if err != nil { + impl.logger.Errorw("failed to verify token", "error", err) + err := &util.ApiError{ + Code: constants.UserNoTokenProvided, + InternalMessage: "failed to verify token", + UserMessage: "token verification failed while getting logged in user", + } + return "", "", err + } + + mapClaims, err := jwt.MapClaims(claims) + if err != nil { + impl.logger.Errorw("failed to MapClaims", "error", err) + err := &util.ApiError{ + Code: constants.UserNoTokenProvided, + InternalMessage: "token invalid", + UserMessage: "token verification failed while parsing token", + } + return "", "", err + } + + email := jwt.GetField(mapClaims, "email") + sub := jwt.GetField(mapClaims, "sub") + tokenVersion := jwt.GetField(mapClaims, "version") + + if email == "" && (sub == "admin" || sub == "admin:login") { + email = "admin" + } + + return email, tokenVersion, nil +} + func (impl *UserServiceImpl) GetByIds(ids []int32) ([]bean.UserInfo, error) { var beans []bean.UserInfo models, err := impl.userRepository.GetByIds(ids) @@ -1384,7 +1457,7 @@ func (impl *UserServiceImpl) getUserIdsHonoringFilters(request *bean.ListingRequ // collecting the required user ids from filtered models filteredUserIds := make([]int32, 0, len(models)) for _, model := range models { - if !helper2.IsSystemOrAdminUserByEmail(model.EmailId) { + if !userHelper.IsSystemOrAdminUserByEmail(model.EmailId) { filteredUserIds = append(filteredUserIds, model.Id) } } diff --git a/pkg/auth/user/bean/bean.go b/pkg/auth/user/bean/bean.go index 81212270e6..6ab54884e9 100644 --- a/pkg/auth/user/bean/bean.go +++ b/pkg/auth/user/bean/bean.go @@ -80,3 +80,8 @@ const ( AdminUserId = 2 // we have established Admin user as 2 while setting up devtron SystemUserId = 1 // we have established System user as 1 while setting up devtron, which are being used for auto-trigger operations ) + +const ( + API_TOKEN_USER_EMAIL_PREFIX = "API-TOKEN:" + ApiTokenTableName = "api_token" +) diff --git a/pkg/auth/user/helper/helper.go b/pkg/auth/user/helper/helper.go index 064c6724b7..1d85015272 100644 --- a/pkg/auth/user/helper/helper.go +++ b/pkg/auth/user/helper/helper.go @@ -4,6 +4,7 @@ import ( "github.com/devtron-labs/devtron/internal/util" "github.com/devtron-labs/devtron/pkg/auth/user/bean" "golang.org/x/exp/slices" + "strings" ) func IsSystemOrAdminUser(userId int32) bool { @@ -44,3 +45,7 @@ func CheckIfUserIdsExists(userIds []int32) error { } return nil } + +func ExtractTokenNameFromEmail(email string) string { + return strings.Split(email, ":")[1] +} diff --git a/pkg/auth/user/repository/UserRepository.go b/pkg/auth/user/repository/UserRepository.go index 3b149b07a4..de8352c339 100644 --- a/pkg/auth/user/repository/UserRepository.go +++ b/pkg/auth/user/repository/UserRepository.go @@ -22,6 +22,7 @@ package repository import ( "github.com/devtron-labs/devtron/api/bean" + userBean "github.com/devtron-labs/devtron/pkg/auth/user/bean" "github.com/devtron-labs/devtron/pkg/sql" "github.com/go-pg/pg" "go.uber.org/zap" @@ -46,6 +47,7 @@ type UserRepository interface { FetchActiveOrDeletedUserByEmail(email string) (*UserModel, error) UpdateRoleIdForUserRolesMappings(roleId int, newRoleId int) (*UserRoleModel, error) GetCountExecutingQuery(query string) (int, error) + CheckIfTokenExistsByTokenNameAndVersion(tokenName string, tokenVersion int) (bool, error) } type UserRepositoryImpl struct { @@ -242,3 +244,15 @@ func (impl UserRepositoryImpl) GetCountExecutingQuery(query string) (int, error) } return totalCount, err } + +// below method does operation on api_token table, +// we are writing this method here instead of ApiTokenRepository to avoid cyclic import +func (impl UserRepositoryImpl) CheckIfTokenExistsByTokenNameAndVersion(tokenName string, tokenVersion int) (bool, error) { + query := impl.dbConnection.Model(). + Table(userBean.ApiTokenTableName). + Where("name = ?", tokenName). + Where("version = ?", tokenVersion) + + exists, err := query.Exists() + return exists, err +} diff --git a/scripts/sql/239_api_token_version_support.down.sql b/scripts/sql/239_api_token_version_support.down.sql new file mode 100644 index 0000000000..907b0c5a98 --- /dev/null +++ b/scripts/sql/239_api_token_version_support.down.sql @@ -0,0 +1 @@ +ALTER TABLE api_token DROP COLUMN version; \ No newline at end of file diff --git a/scripts/sql/239_api_token_version_support.up.sql b/scripts/sql/239_api_token_version_support.up.sql new file mode 100644 index 0000000000..7cb506c1c6 --- /dev/null +++ b/scripts/sql/239_api_token_version_support.up.sql @@ -0,0 +1 @@ +ALTER TABLE api_token ADD COLUMN version int NOT NULL DEFAULT 1; \ No newline at end of file