@@ -13,67 +13,101 @@ module PostgREST.Auth.JwtCache
13
13
14
14
import qualified Data.Aeson as JSON
15
15
import qualified Data.Aeson.KeyMap as KM
16
- import qualified Data.Cache as C
16
+ import qualified Data.Cache.LRU as C
17
+ import qualified Data.IORef as I
17
18
import qualified Data.Scientific as Sci
18
19
20
+ import Data.Maybe (fromJust )
19
21
import Data.Time.Clock (UTCTime , nominalDiffTimeToSeconds )
20
22
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds )
21
- import System.Clock ( TimeSpec ( .. ) )
23
+ import GHC.Num ( integerFromInt )
22
24
23
25
import PostgREST.Auth.Types (AuthResult (.. ))
24
26
import PostgREST.Error (Error (.. ))
25
27
26
28
import Protolude
27
29
28
30
newtype JwtCacheState = JwtCacheState
29
- { jwtCache :: C. Cache ByteString AuthResult
31
+ { jwtCacheIORef :: I. IORef ( C. LRU ByteString AuthResult )
30
32
}
31
33
32
34
-- | Initialize JwtCacheState
33
- init :: IO JwtCacheState
34
- init = do
35
- cache <- C. newCache Nothing -- no default expiration
35
+ init :: Int -> IO JwtCacheState
36
+ init configJwtCacheMaxEntries = do
37
+ cache <- I. newIORef $ C. newLRU ( Just maxEntries) -- TODO: 0 will throw an error, decide what to do when 0 is set to config
36
38
return $ JwtCacheState cache
39
+ where
40
+ maxEntries = integerFromInt configJwtCacheMaxEntries
41
+
37
42
38
43
-- | Used to retrieve and insert JWT to JWT Cache
39
44
lookupJwtCache :: JwtCacheState -> ByteString -> Int -> IO (Either Error AuthResult ) -> UTCTime -> IO (Either Error AuthResult )
40
- lookupJwtCache JwtCacheState {jwtCache} token maxLifetime parseJwt utc = do
41
- checkCache <- C. lookup jwtCache token
42
- authResult <- maybe parseJwt (pure . Right ) checkCache
43
-
44
- case (authResult,checkCache) of
45
- -- From comment:
46
- -- https://github.com/PostgREST/postgrest/pull/3801#discussion_r1857987914
47
- --
48
- -- We purge expired cache entries on a cache miss
49
- -- The reasoning is that:
50
- --
51
- -- 1. We expect it to be rare (otherwise there is no point of the cache)
52
- -- 2. It makes sure the cache is not growing (as inserting new entries
53
- -- does garbage collection)
54
- -- 3. Since this is time expiration based cache there is no real risk of
55
- -- starvation - sooner or later we are going to have a cache miss.
56
-
57
- (Right res, Nothing ) -> do -- cache miss
58
-
59
- let timeSpec = getTimeSpec res maxLifetime utc
60
-
61
- -- purge expired cache entries
62
- C. purgeExpired jwtCache
63
-
64
- -- insert new cache entry
65
- C. insert' jwtCache (Just timeSpec) token res
66
-
67
- _ -> pure ()
68
-
69
- return authResult
70
-
71
- -- Used to extract JWT exp claim and add to JWT Cache
72
- getTimeSpec :: AuthResult -> Int -> UTCTime -> TimeSpec
73
- getTimeSpec res maxLifetime utc = do
74
- let expireJSON = KM. lookup " exp" (authClaims res)
75
- utcToSecs = floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds
76
- sciToInt = fromMaybe 0 . Sci. toBoundedInteger
77
- case expireJSON of
78
- Just (JSON. Number seconds) -> TimeSpec (sciToInt seconds - utcToSecs utc) 0
79
- _ -> TimeSpec (fromIntegral maxLifetime :: Int64 ) 0
45
+ lookupJwtCache JwtCacheState {jwtCacheIORef} token maxLifetime parseJwt utc = do
46
+ -- get cache from IORef
47
+ jwtCache <- I. readIORef jwtCacheIORef
48
+
49
+ -- lookup key = token
50
+ let (jwtCache', maybeVal) = C. lookup token jwtCache
51
+
52
+ -- check cache value otherwise parse jwt
53
+ authResult <- maybe parseJwt (pure . Right ) maybeVal
54
+
55
+ -- get updated authResult
56
+ case (authResult, maybeVal) of
57
+
58
+ (Right res, Nothing ) -> do -- CACHE MISS
59
+
60
+ -- When we get a cache miss, we get the parse result, add the exp
61
+ -- claim result and then insert it into the cache. After that
62
+ -- we write the cache IO ref with updated cache
63
+
64
+ let res' = addExpToClaims res maxLifetime utc
65
+ jwtCache'' = C. insert token res' jwtCache'
66
+
67
+ -- update IORef
68
+ I. writeIORef jwtCacheIORef jwtCache''
69
+
70
+ return $ Right res'
71
+
72
+ (parseJwt', Just res) -> -- CACHE HIT
73
+
74
+ -- For cache hit, we get the result from cache, we check the
75
+ -- exp claim. If it expired, we delete it from cache and parse
76
+ -- the jwt. Otherwise, the hit result is valid, so we return it
77
+
78
+ if isExpClaimExpired res utc then do
79
+
80
+ let (jwtCache'',_) = C. delete token jwtCache'
81
+
82
+ I. writeIORef jwtCacheIORef jwtCache''
83
+
84
+ return parseJwt'
85
+ else
86
+ return $ Right res
87
+
88
+ _ -> return authResult -- parsing failed, we fail later
89
+
90
+
91
+ -- | Add the exp claim to result by using maxLifetime
92
+ addExpToClaims :: AuthResult -> Int -> UTCTime -> AuthResult
93
+ addExpToClaims res maxLifetime utc =
94
+ let
95
+ now = (floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds) utc :: Int
96
+ newExp = now + maxLifetime
97
+ authClaims' = KM. insert " exp" (JSON. Number $ Sci. scientific (integerFromInt newExp) 0 ) (authClaims res)
98
+ in
99
+ res{authClaims= authClaims'}
100
+
101
+
102
+ -- | Check if exp claim is expired when looked up from cache
103
+ isExpClaimExpired :: AuthResult -> UTCTime -> Bool
104
+ isExpClaimExpired res utc =
105
+ let
106
+ now = (floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds) utc :: Int
107
+ -- we can use fromJust, because "exp" claim is inserted in all entries
108
+ expireJSON = fromJust $ KM. lookup " exp" (authClaims res)
109
+ sciToInt = fromMaybe 0 . Sci. toBoundedInteger
110
+ in
111
+ case expireJSON of
112
+ JSON. Number expiredAt -> (sciToInt expiredAt - now) < 0
113
+ _ -> True -- impossible case; we will always have "exp" as JSON.Number
0 commit comments