@@ -20,6 +20,7 @@ package apiToken
20
20
import (
21
21
"errors"
22
22
"fmt"
23
+ userBean "github.com/devtron-labs/devtron/pkg/auth/user/bean"
23
24
"regexp"
24
25
"strconv"
25
26
"strings"
@@ -62,14 +63,13 @@ func NewApiTokenServiceImpl(logger *zap.SugaredLogger, apiTokenSecretService Api
62
63
}
63
64
}
64
65
65
- const API_TOKEN_USER_EMAIL_PREFIX = "API-TOKEN:"
66
-
67
66
var invalidCharsInApiTokenName = regexp .MustCompile ("[,\\ s]" )
68
67
69
- type ApiTokenCustomClaims struct {
70
- Email string `json:"email"`
71
- jwt.RegisteredClaims
72
- }
68
+ const (
69
+ ConcurrentTokenUpdateRequest = "there is an ongoing request for the token with the same name, please try again after some time"
70
+ UniqueKeyViolationPgErrorCode = 23505
71
+ TokenVersionMismatch = "token version mismatch"
72
+ )
73
73
74
74
func (impl ApiTokenServiceImpl ) GetAllApiTokensForWebhook (projectName string , environmentName string , appName string , auth func (token string , projectObject string , envObject string ) bool ) ([]* openapi.ApiToken , error ) {
75
75
impl .logger .Info ("Getting active api tokens" )
@@ -181,11 +181,21 @@ func (impl ApiTokenServiceImpl) CreateApiToken(request *openapi.CreateApiTokenRe
181
181
182
182
impl .logger .Info (fmt .Sprintf ("apiTokenExists : %s" , strconv .FormatBool (apiTokenExists )))
183
183
184
- // step-2 - Build email
185
- email := fmt .Sprintf ("%s%s" , API_TOKEN_USER_EMAIL_PREFIX , name )
184
+ // step-2 - Build email and version
185
+ email := fmt .Sprintf ("%s%s" , userBean .API_TOKEN_USER_EMAIL_PREFIX , name )
186
+ var (
187
+ tokenVersion int
188
+ previousTokenVersion int
189
+ )
190
+ if apiTokenExists {
191
+ tokenVersion = apiToken .Version + 1
192
+ previousTokenVersion = apiToken .Version
193
+ } else {
194
+ tokenVersion = 1
195
+ }
186
196
187
197
// step-3 - Build token
188
- token , err := impl .createApiJwtToken (email , * request .ExpireAtInMs )
198
+ token , err := impl .createApiJwtToken (email , tokenVersion , * request .ExpireAtInMs )
189
199
if err != nil {
190
200
return nil , err
191
201
}
@@ -214,21 +224,37 @@ func (impl ApiTokenServiceImpl) CreateApiToken(request *openapi.CreateApiTokenRe
214
224
Description : * request .Description ,
215
225
ExpireAtInMs : * request .ExpireAtInMs ,
216
226
Token : token ,
227
+ Version : tokenVersion ,
217
228
AuditLog : sql.AuditLog {UpdatedOn : time .Now ()},
218
229
}
219
230
if apiTokenExists {
220
231
apiTokenSaveRequest .Id = apiToken .Id
221
232
apiTokenSaveRequest .CreatedBy = apiToken .CreatedBy
222
233
apiTokenSaveRequest .CreatedOn = apiToken .CreatedOn
223
234
apiTokenSaveRequest .UpdatedBy = createdBy
224
- err = impl .apiTokenRepository .Update (apiTokenSaveRequest )
235
+ // update api-token only if `previousTokenVersion` is same as version stored in DB
236
+ // we are checking this to ensure that two users are not updating the same token at the same time
237
+ err = impl .apiTokenRepository .UpdateIf (apiTokenSaveRequest , previousTokenVersion )
225
238
} else {
226
239
apiTokenSaveRequest .CreatedBy = createdBy
227
240
apiTokenSaveRequest .CreatedOn = time .Now ()
228
241
err = impl .apiTokenRepository .Save (apiTokenSaveRequest )
229
242
}
230
243
if err != nil {
231
244
impl .logger .Errorw ("error while saving api-token into DB" , "error" , err )
245
+ // fetching error code from pg error for Unique key violation constraint
246
+ // in case of save
247
+ pgErr , ok := err .(pg.Error )
248
+ if ok {
249
+ errCode , conversionErr := strconv .Atoi (pgErr .Field ('C' ))
250
+ if conversionErr == nil && errCode == UniqueKeyViolationPgErrorCode {
251
+ return nil , fmt .Errorf (ConcurrentTokenUpdateRequest )
252
+ }
253
+ }
254
+ // in case of update
255
+ if errors .Is (err , fmt .Errorf (TokenVersionMismatch )) {
256
+ return nil , fmt .Errorf (ConcurrentTokenUpdateRequest )
257
+ }
232
258
return nil , err
233
259
}
234
260
@@ -254,22 +280,28 @@ func (impl ApiTokenServiceImpl) UpdateApiToken(apiTokenId int, request *openapi.
254
280
return nil , errors .New (fmt .Sprintf ("api-token corresponds to apiTokenId '%d' is not found" , apiTokenId ))
255
281
}
256
282
283
+ previousTokenVersion := apiToken .Version
284
+ tokenVersion := apiToken .Version + 1
285
+
257
286
// step-2 - If expires_at is not same, then token needs to be generated again
258
287
if * request .ExpireAtInMs != apiToken .ExpireAtInMs {
259
288
// regenerate token
260
- token , err := impl .createApiJwtToken (apiToken .User .EmailId , * request .ExpireAtInMs )
289
+ token , err := impl .createApiJwtToken (apiToken .User .EmailId , tokenVersion , * request .ExpireAtInMs )
261
290
if err != nil {
262
291
return nil , err
263
292
}
264
293
apiToken .Token = token
294
+ apiToken .Version = tokenVersion
265
295
}
266
296
267
297
// step-3 - update in DB
268
298
apiToken .Description = * request .Description
269
299
apiToken .ExpireAtInMs = * request .ExpireAtInMs
270
300
apiToken .UpdatedBy = updatedBy
271
301
apiToken .UpdatedOn = time .Now ()
272
- err = impl .apiTokenRepository .Update (apiToken )
302
+ // update api-token only if `previousTokenVersion` is same as version stored in DB
303
+ // we are checking this to ensure that two users are not updating the same token at the same time
304
+ err = impl .apiTokenRepository .UpdateIf (apiToken , previousTokenVersion )
273
305
if err != nil {
274
306
impl .logger .Errorw ("error while updating api-token" , "apiTokenId" , apiTokenId , "error" , err )
275
307
return nil , err
@@ -322,24 +354,41 @@ func (impl ApiTokenServiceImpl) DeleteApiToken(apiTokenId int, deletedBy int32)
322
354
323
355
}
324
356
325
- func (impl ApiTokenServiceImpl ) createApiJwtToken (email string , expireAtInMs int64 ) (string , error ) {
357
+ func (impl ApiTokenServiceImpl ) createApiJwtToken (email string , tokenVersion int , expireAtInMs int64 ) (string , error ) {
358
+ registeredClaims , secretByteArr , err := impl .setRegisteredClaims (expireAtInMs )
359
+ if err != nil {
360
+ return "" , err
361
+ }
362
+ claims := & ApiTokenCustomClaims {
363
+ email ,
364
+ strconv .Itoa (tokenVersion ),
365
+ registeredClaims ,
366
+ }
367
+ token , err := impl .generateToken (claims , secretByteArr )
368
+ if err != nil {
369
+ return "" , err
370
+ }
371
+ return token , nil
372
+ }
373
+
374
+ func (impl ApiTokenServiceImpl ) setRegisteredClaims (expireAtInMs int64 ) (jwt.RegisteredClaims , []byte , error ) {
326
375
secretByteArr , err := impl .apiTokenSecretService .GetApiTokenSecretByteArr ()
327
376
if err != nil {
328
377
impl .logger .Errorw ("error while getting api token secret" , "error" , err )
329
- return "" , err
378
+ return jwt. RegisteredClaims {}, secretByteArr , err
330
379
}
331
380
332
381
registeredClaims := jwt.RegisteredClaims {
333
382
Issuer : middleware .ApiTokenClaimIssuer ,
334
383
}
384
+
335
385
if expireAtInMs > 0 {
336
386
registeredClaims .ExpiresAt = jwt .NewNumericDate (time .Unix (expireAtInMs / 1000 , 0 ))
337
387
}
388
+ return registeredClaims , secretByteArr , nil
389
+ }
338
390
339
- claims := & ApiTokenCustomClaims {
340
- email ,
341
- registeredClaims ,
342
- }
391
+ func (impl ApiTokenServiceImpl ) generateToken (claims * ApiTokenCustomClaims , secretByteArr []byte ) (string , error ) {
343
392
unsignedToken := jwt .NewWithClaims (jwt .SigningMethodHS256 , claims )
344
393
token , err := unsignedToken .SignedString (secretByteArr )
345
394
if err != nil {
0 commit comments