Skip to content

Commit 9cae300

Browse files
committed
feat: change JWT cache to limited LRU based cache
BREAKING CHANGE
1 parent 6b4648d commit 9cae300

21 files changed

+105
-53
lines changed

postgrest.cabal

+1-2
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,8 @@ library
9999
, auto-update >= 0.1.4 && < 0.2
100100
, base64-bytestring >= 1 && < 1.3
101101
, bytestring >= 0.10.8 && < 0.13
102-
, cache >= 0.1.3 && < 0.2.0
103102
, case-insensitive >= 1.2 && < 1.3
104103
, cassava >= 0.4.5 && < 0.6
105-
, clock >= 0.8.3 && < 0.9.0
106104
, configurator-pg >= 0.2 && < 0.3
107105
, containers >= 0.5.7 && < 0.7
108106
, cookie >= 0.4.2 && < 0.5
@@ -122,6 +120,7 @@ library
122120
, jose-jwt >= 0.9.6 && < 0.11
123121
, lens >= 4.14 && < 5.3
124122
, lens-aeson >= 1.0.1 && < 1.3
123+
, lrucache >= 1.2.0.1 && < 1.3
125124
, mtl >= 2.2.2 && < 2.4
126125
, neat-interpolation >= 0.5 && < 0.6
127126
, network >= 2.6 && < 3.2

src/PostgREST/AppState.hs

+4-4
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,14 @@ data SchemaCacheStatus
120120
type AppSockets = (NS.Socket, Maybe NS.Socket)
121121

122122
init :: AppConfig -> IO AppState
123-
init conf@AppConfig{configLogLevel, configDbPoolSize} = do
123+
init conf = do
124124
loggerState <- Logger.init
125-
metricsState <- Metrics.init configDbPoolSize
126-
let observer = liftA2 (>>) (Logger.observationLogger loggerState configLogLevel) (Metrics.observationMetrics metricsState)
125+
metricsState <- Metrics.init (configDbPoolSize conf)
126+
let observer = liftA2 (>>) (Logger.observationLogger loggerState (configLogLevel conf)) (Metrics.observationMetrics metricsState)
127127

128128
observer $ AppStartObs prettyVersion
129129

130-
jwtCacheState <- JwtCache.init
130+
jwtCacheState <- JwtCache.init (configJwtCacheMaxEntries conf)
131131
pool <- initPool conf observer
132132
(sock, adminSock) <- initSockets conf
133133
state' <- initWithPool (sock, adminSock) pool conf jwtCacheState loggerState metricsState observer

src/PostgREST/Auth/JwtCache.hs

+80-46
Original file line numberDiff line numberDiff line change
@@ -13,67 +13,101 @@ module PostgREST.Auth.JwtCache
1313

1414
import qualified Data.Aeson as JSON
1515
import qualified Data.Aeson.KeyMap as KM
16-
import qualified Data.Cache as C
16+
import qualified Data.Cache.LRU as C
17+
import qualified Data.IORef as I
1718
import qualified Data.Scientific as Sci
1819

20+
import Data.Maybe (fromJust)
1921
import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds)
2022
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
21-
import System.Clock (TimeSpec (..))
23+
import GHC.Num (integerFromInt)
2224

2325
import PostgREST.Auth.Types (AuthResult (..))
2426
import PostgREST.Error (Error (..))
2527

2628
import Protolude
2729

2830
newtype JwtCacheState = JwtCacheState
29-
{ jwtCache :: C.Cache ByteString AuthResult
31+
{ jwtCacheIORef :: I.IORef (C.LRU ByteString AuthResult)
3032
}
3133

3234
-- | Initialize JwtCacheState
33-
init :: IO JwtCacheState
34-
init = do
35-
cache <- C.newCache Nothing -- no default expiration
35+
init :: Int -> IO JwtCacheState
36+
init configJwtCacheMaxEntries = do
37+
cache <- I.newIORef $ C.newLRU (Just maxEntries) -- TODO: 0 will throw an error, decide what to do when 0 is set to config
3638
return $ JwtCacheState cache
39+
where
40+
maxEntries = integerFromInt configJwtCacheMaxEntries
41+
3742

3843
-- | Used to retrieve and insert JWT to JWT Cache
3944
lookupJwtCache :: JwtCacheState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult)
40-
lookupJwtCache JwtCacheState{jwtCache} token maxLifetime parseJwt utc = do
41-
checkCache <- C.lookup jwtCache token
42-
authResult <- maybe parseJwt (pure . Right) checkCache
43-
44-
case (authResult,checkCache) of
45-
-- From comment:
46-
-- https://github.com/PostgREST/postgrest/pull/3801#discussion_r1857987914
47-
--
48-
-- We purge expired cache entries on a cache miss
49-
-- The reasoning is that:
50-
--
51-
-- 1. We expect it to be rare (otherwise there is no point of the cache)
52-
-- 2. It makes sure the cache is not growing (as inserting new entries
53-
-- does garbage collection)
54-
-- 3. Since this is time expiration based cache there is no real risk of
55-
-- starvation - sooner or later we are going to have a cache miss.
56-
57-
(Right res, Nothing) -> do -- cache miss
58-
59-
let timeSpec = getTimeSpec res maxLifetime utc
60-
61-
-- purge expired cache entries
62-
C.purgeExpired jwtCache
63-
64-
-- insert new cache entry
65-
C.insert' jwtCache (Just timeSpec) token res
66-
67-
_ -> pure ()
68-
69-
return authResult
70-
71-
-- Used to extract JWT exp claim and add to JWT Cache
72-
getTimeSpec :: AuthResult -> Int -> UTCTime -> TimeSpec
73-
getTimeSpec res maxLifetime utc = do
74-
let expireJSON = KM.lookup "exp" (authClaims res)
75-
utcToSecs = floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds
76-
sciToInt = fromMaybe 0 . Sci.toBoundedInteger
77-
case expireJSON of
78-
Just (JSON.Number seconds) -> TimeSpec (sciToInt seconds - utcToSecs utc) 0
79-
_ -> TimeSpec (fromIntegral maxLifetime :: Int64) 0
45+
lookupJwtCache JwtCacheState{jwtCacheIORef} token maxLifetime parseJwt utc = do
46+
-- get cache from IORef
47+
jwtCache <- I.readIORef jwtCacheIORef
48+
49+
-- lookup key = token
50+
let (jwtCache', maybeVal) = C.lookup token jwtCache
51+
52+
-- check cache value otherwise parse jwt
53+
authResult <- maybe parseJwt (pure . Right) maybeVal
54+
55+
-- get updated authResult
56+
case (authResult, maybeVal) of
57+
58+
(Right res, Nothing) -> do -- CACHE MISS
59+
60+
-- When we get a cache miss, we get the parse result, add the exp
61+
-- claim result and then insert it into the cache. After that
62+
-- we write the cache IO ref with updated cache
63+
64+
let res' = addExpToClaims res maxLifetime utc
65+
jwtCache'' = C.insert token res' jwtCache'
66+
67+
-- update IORef
68+
I.writeIORef jwtCacheIORef jwtCache''
69+
70+
return $ Right res'
71+
72+
(parseJwt', Just res) -> -- CACHE HIT
73+
74+
-- For cache hit, we get the result from cache, we check the
75+
-- exp claim. If it expired, we delete it from cache and parse
76+
-- the jwt. Otherwise, the hit result is valid, so we return it
77+
78+
if isExpClaimExpired res utc then do
79+
80+
let (jwtCache'',_) = C.delete token jwtCache'
81+
82+
I.writeIORef jwtCacheIORef jwtCache''
83+
84+
return parseJwt'
85+
else
86+
return $ Right res
87+
88+
_ -> return authResult -- parsing failed, we fail later
89+
90+
91+
-- | Add the exp claim to result by using maxLifetime
92+
addExpToClaims :: AuthResult -> Int -> UTCTime -> AuthResult
93+
addExpToClaims res maxLifetime utc =
94+
let
95+
now = (floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds) utc :: Int
96+
newExp = now + maxLifetime
97+
authClaims' = KM.insert "exp" (JSON.Number $ Sci.scientific (integerFromInt newExp) 0) (authClaims res)
98+
in
99+
res{authClaims=authClaims'}
100+
101+
102+
-- | Check if exp claim is expired when looked up from cache
103+
isExpClaimExpired :: AuthResult -> UTCTime -> Bool
104+
isExpClaimExpired res utc =
105+
let
106+
now = (floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds) utc :: Int
107+
-- we can use fromJust, because "exp" claim is inserted in all entries
108+
expireJSON = fromJust $ KM.lookup "exp" (authClaims res)
109+
sciToInt = fromMaybe 0 . Sci.toBoundedInteger
110+
in
111+
case expireJSON of
112+
JSON.Number expiredAt -> (sciToInt expiredAt - now) < 0
113+
_ -> True -- impossible case; we will always have "exp" as JSON.Number

src/PostgREST/Config.hs

+3
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ data AppConfig = AppConfig
9898
, configJwtSecret :: Maybe BS.ByteString
9999
, configJwtSecretIsBase64 :: Bool
100100
, configJwtCacheMaxLifetime :: Int
101+
, configJwtCacheMaxEntries :: Int
101102
, configLogLevel :: LogLevel
102103
, configLogQuery :: LogQuery
103104
, configOpenApiMode :: OpenAPIMode
@@ -178,6 +179,7 @@ toText conf =
178179
,("jwt-secret", q . T.decodeUtf8 . showJwtSecret)
179180
,("jwt-secret-is-base64", T.toLower . show . configJwtSecretIsBase64)
180181
,("jwt-cache-max-lifetime", show . configJwtCacheMaxLifetime)
182+
,("jwt-cache-max-entries", show . configJwtCacheMaxEntries)
181183
,("log-level", q . dumpLogLevel . configLogLevel)
182184
,("log-query", q . dumpLogQuery . configLogQuery)
183185
,("openapi-mode", q . dumpOpenApiMode . configOpenApiMode)
@@ -288,6 +290,7 @@ parser optPath env dbSettings roleSettings roleIsolationLvl =
288290
(optBool "jwt-secret-is-base64")
289291
(optBool "secret-is-base64"))
290292
<*> (fromMaybe 0 <$> optInt "jwt-cache-max-lifetime")
293+
<*> (fromMaybe 100 <$> optInt "jwt-cache-max-entries")
291294
<*> parseLogLevel "log-level"
292295
<*> parseLogQuery "log-query"
293296
<*> parseOpenAPIMode "openapi-mode"

test/io/configs/expected/aliases.config

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jwt-role-claim-key = ".\"aliased\""
2424
jwt-secret = ""
2525
jwt-secret-is-base64 = true
2626
jwt-cache-max-lifetime = 0
27+
jwt-cache-max-entries = 100
2728
log-level = "error"
2829
log-query = "disabled"
2930
openapi-mode = "follow-privileges"

test/io/configs/expected/boolean-numeric.config

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jwt-role-claim-key = ".\"role\""
2424
jwt-secret = ""
2525
jwt-secret-is-base64 = true
2626
jwt-cache-max-lifetime = 0
27+
jwt-cache-max-entries = 100
2728
log-level = "error"
2829
log-query = "disabled"
2930
openapi-mode = "follow-privileges"

test/io/configs/expected/boolean-string.config

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jwt-role-claim-key = ".\"role\""
2424
jwt-secret = ""
2525
jwt-secret-is-base64 = true
2626
jwt-cache-max-lifetime = 0
27+
jwt-cache-max-entries = 100
2728
log-level = "error"
2829
log-query = "disabled"
2930
openapi-mode = "follow-privileges"

test/io/configs/expected/defaults.config

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jwt-role-claim-key = ".\"role\""
2424
jwt-secret = ""
2525
jwt-secret-is-base64 = false
2626
jwt-cache-max-lifetime = 0
27+
jwt-cache-max-entries = 100
2728
log-level = "error"
2829
log-query = "disabled"
2930
openapi-mode = "follow-privileges"

test/io/configs/expected/jwt-role-claim-key1.config

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jwt-role-claim-key = ".\"roles\"[?(@ == \"role1\")]"
2424
jwt-secret = ""
2525
jwt-secret-is-base64 = false
2626
jwt-cache-max-lifetime = 0
27+
jwt-cache-max-entries = 100
2728
log-level = "error"
2829
log-query = "disabled"
2930
openapi-mode = "follow-privileges"

test/io/configs/expected/jwt-role-claim-key2.config

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jwt-role-claim-key = ".\"roles\"[?(@ != \"role1\")]"
2424
jwt-secret = ""
2525
jwt-secret-is-base64 = false
2626
jwt-cache-max-lifetime = 0
27+
jwt-cache-max-entries = 100
2728
log-level = "error"
2829
log-query = "disabled"
2930
openapi-mode = "follow-privileges"

test/io/configs/expected/jwt-role-claim-key3.config

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jwt-role-claim-key = ".\"roles\"[?(@ ^== \"role1\")]"
2424
jwt-secret = ""
2525
jwt-secret-is-base64 = false
2626
jwt-cache-max-lifetime = 0
27+
jwt-cache-max-entries = 100
2728
log-level = "error"
2829
log-query = "disabled"
2930
openapi-mode = "follow-privileges"

test/io/configs/expected/jwt-role-claim-key4.config

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jwt-role-claim-key = ".\"roles\"[?(@ ==^ \"role1\")]"
2424
jwt-secret = ""
2525
jwt-secret-is-base64 = false
2626
jwt-cache-max-lifetime = 0
27+
jwt-cache-max-entries = 100
2728
log-level = "error"
2829
log-query = "disabled"
2930
openapi-mode = "follow-privileges"

test/io/configs/expected/jwt-role-claim-key5.config

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jwt-role-claim-key = ".\"roles\"[?(@ *== \"role1\")]"
2424
jwt-secret = ""
2525
jwt-secret-is-base64 = false
2626
jwt-cache-max-lifetime = 0
27+
jwt-cache-max-entries = 100
2728
log-level = "error"
2829
log-query = "disabled"
2930
openapi-mode = "follow-privileges"

test/io/configs/expected/no-defaults-with-db-other-authenticator.config

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jwt-role-claim-key = ".\"other\".\"pre_config_role\""
2424
jwt-secret = "ODERREALLYREALLYREALLYREALLYVERYSAFE"
2525
jwt-secret-is-base64 = false
2626
jwt-cache-max-lifetime = 7200
27+
jwt-cache-max-entries = 1000
2728
log-level = "info"
2829
log-query = "main-query"
2930
openapi-mode = "disabled"

test/io/configs/expected/no-defaults-with-db.config

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jwt-role-claim-key = ".\"a\".\"role\""
2424
jwt-secret = "OVERRIDE=REALLY=REALLY=REALLY=REALLY=VERY=SAFE"
2525
jwt-secret-is-base64 = false
2626
jwt-cache-max-lifetime = 3600
27+
jwt-cache-max-entries = 1000
2728
log-level = "info"
2829
log-query = "main-query"
2930
openapi-mode = "ignore-privileges"

test/io/configs/expected/no-defaults.config

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jwt-role-claim-key = ".\"user\"[0].\"real-role\""
2424
jwt-secret = "c2VjdXJpdHl0aHJvdWdob2JzY3VyaXR5aW5iYXNlNjQ="
2525
jwt-secret-is-base64 = true
2626
jwt-cache-max-lifetime = 86400
27+
jwt-cache-max-entries = 1000
2728
log-level = "info"
2829
log-query = "main-query"
2930
openapi-mode = "ignore-privileges"

test/io/configs/expected/types.config

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jwt-role-claim-key = ".\"role\""
2424
jwt-secret = ""
2525
jwt-secret-is-base64 = false
2626
jwt-cache-max-lifetime = 0
27+
jwt-cache-max-entries = 100
2728
log-level = "error"
2829
log-query = "disabled"
2930
openapi-mode = "follow-privileges"

test/io/configs/no-defaults-env.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ PGRST_JWT_ROLE_CLAIM_KEY: '.user[0]."real-role"'
2727
PGRST_JWT_SECRET: c2VjdXJpdHl0aHJvdWdob2JzY3VyaXR5aW5iYXNlNjQ=
2828
PGRST_JWT_SECRET_IS_BASE64: true
2929
PGRST_JWT_CACHE_MAX_LIFETIME: 86400
30+
PGRST_JWT_CACHE_MAX_ENTRIES: 1000
3031
PGRST_LOG_LEVEL: info
3132
PGRST_LOG_QUERY: 'main-query'
3233
PGRST_OPENAPI_MODE: 'ignore-privileges'

test/io/configs/no-defaults.config

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ jwt-role-claim-key = ".user[0].\"real-role\""
2424
jwt-secret = "c2VjdXJpdHl0aHJvdWdob2JzY3VyaXR5aW5iYXNlNjQ="
2525
jwt-secret-is-base64 = true
2626
jwt-cache-max-lifetime = 86400
27+
jwt-cache-max-entries = 1000
2728
log-level = "info"
2829
log-query = "main-query"
2930
openapi-mode = "ignore-privileges"

test/spec/Main.hs

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ main = do
8585
-- cached schema cache so most tests run fast
8686
baseSchemaCache <- loadSCache pool testCfg
8787
sockets <- AppState.initSockets testCfg
88-
jwtCacheState <- JwtCache.init
88+
jwtCacheState <- JwtCache.init (configJwtCacheMaxEntries testCfg)
8989
loggerState <- Logger.init
9090
metricsState <- Metrics.init (configDbPoolSize testCfg)
9191

test/spec/SpecHelper.hs

+1
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ baseCfg = let secret = encodeUtf8 "reallyreallyreallyreallyverysafe" in
138138
, configJwtSecret = Just secret
139139
, configJwtSecretIsBase64 = False
140140
, configJwtCacheMaxLifetime = 0
141+
, configJwtCacheMaxEntries = 100 -- default
141142
, configLogLevel = LogCrit
142143
, configLogQuery = LogQueryDisabled
143144
, configOpenApiMode = OAFollowPriv

0 commit comments

Comments
 (0)