Skip to content

Commit 8fb3df5

Browse files
committed
feat: change JWT cache to limited LRU based cache
BREAKING CHANGE Our JWT cache implementation had no upper bound for number of cache entries. This caused OOM errors. Additionally, the purge mechanism for expired entries was quite slow. This changes our implementation to a LRU based cache which limits the amount of cached entries.
1 parent c732591 commit 8fb3df5

30 files changed

+191
-192
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
2121
+ The selected columns in the embedded resources are aggregated into arrays
2222
+ Aggregates are not supported
2323
- #2967, Add `Proxy-Status` header for better error response - @taimoorzaeem
24+
- #4003, Add config `jwt-cache-max-entries` for maximum number of cached entries - @taimoorzaeem
2425

2526
### Fixed
2627

@@ -54,6 +55,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
5455
+ Diagnostic error messages instead of exposed internals
5556
+ Return new `PGRST303` error when jwt claims decoding fails
5657
- #3906, Return `PGRST125` and `PGRST126` errors instead of empty json - @taimoorzaeem
58+
- #4003, Remove config `jwt-cache-max-lifetime` and add config `jwt-cache-max-entries` for JWT cache - @taimoorzaeem
5759

5860
## [12.2.10] - 2025-04-18
5961

docs/postgrest.dict

+3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Cloudflare
2323
config
2424
cors
2525
CORS
26+
cryptographic
2627
cryptographically
2728
CSV
2829
durations
@@ -74,6 +75,7 @@ JSON
7475
JSPath
7576
JWK
7677
JWT
78+
JWTs
7779
jwt
7880
Keycloak
7981
Kubernetes
@@ -84,6 +86,7 @@ Logins
8486
LIBPQ
8587
logins
8688
lon
89+
LRU
8790
lt
8891
lte
8992
macOS

docs/references/auth.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ The ``Bearer`` header value can be used with or without capitalization(``bearer`
9898
JWT Caching
9999
-----------
100100

101-
PostgREST validates ``JWTs`` on every request. We can cache ``JWTs`` to avoid this performance overhead.
101+
PostgREST caches ``JWTs`` on every request to avoid performance overhead of parsing and cryptographic operations.
102102

103-
To enable JWT caching, the config :code:`jwt-cache-max-lifetime` is to be set. It is the maximum number of seconds for which the cache stores the JWT validation results. The cache uses the :code:`exp` claim to set the cache entry lifetime. If the JWT does not have an :code:`exp` claim, it uses the config value. See :ref:`jwt-cache-max-lifetime` for more details.
103+
To disable JWT caching, the config :code:`jwt-cache-max-entries` is to be set to ``0``. It is the maximum number of JWTs for which the cache stores their validation results. If the cache reaches its maximum, the `least recently used <https://redis.io/glossary/lru-cache/>`_ entry will be removed. The cache honors :code:`exp` claim. If the JWT does not have an :code:`exp` claim, it is cached until it gets removed by the LRU policy. See :ref:`jwt-cache-max-entries` for configuration details.
104104

105105
.. note::
106106

docs/references/configuration.rst

+7-7
Original file line numberDiff line numberDiff line change
@@ -658,20 +658,20 @@ jwt-secret-is-base64
658658

659659
When this is set to :code:`true`, the value derived from :code:`jwt-secret` will be treated as a base64 encoded secret.
660660

661-
.. _jwt-cache-max-lifetime:
661+
.. _jwt-cache-max-entries:
662662

663-
jwt-cache-max-lifetime
664-
----------------------
663+
jwt-cache-max-entries
664+
---------------------
665665

666666
=============== =================================
667667
**Type** Int
668-
**Default** 0
668+
**Default** 1000
669669
**Reloadable** Y
670-
**Environment** PGRST_JWT_CACHE_MAX_LIFETIME
671-
**In-Database** pgrst.jwt_cache_max_lifetime
670+
**Environment** PGRST_JWT_CACHE_MAX_ENTRIES
671+
**In-Database** pgrst.jwt_cache_max_entries
672672
=============== =================================
673673

674-
Maximum number of seconds of lifetime for cached entries. The default :code:`0` disables caching. See :ref:`jwt_caching`.
674+
Maximum number of JWTs that can be cached. Set to :code:`0` to disable caching. See :ref:`jwt_caching`.
675675

676676
.. _log-level:
677677

postgrest.cabal

+1-2
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,8 @@ library
9999
, auto-update >= 0.1.4 && < 0.2
100100
, base64-bytestring >= 1 && < 1.3
101101
, bytestring >= 0.10.8 && < 0.13
102-
, cache >= 0.1.3 && < 0.2.0
103102
, case-insensitive >= 1.2 && < 1.3
104103
, cassava >= 0.4.5 && < 0.6
105-
, clock >= 0.8.3 && < 0.9.0
106104
, configurator-pg >= 0.2 && < 0.3
107105
, containers >= 0.5.7 && < 0.7
108106
, cookie >= 0.4.2 && < 0.5
@@ -122,6 +120,7 @@ library
122120
, jose-jwt >= 0.9.6 && < 0.11
123121
, lens >= 4.14 && < 5.3
124122
, lens-aeson >= 1.0.1 && < 1.3
123+
, lrucache >= 1.2.0.1 && < 1.3
125124
, mtl >= 2.2.2 && < 2.4
126125
, neat-interpolation >= 0.5 && < 0.6
127126
, network >= 2.6 && < 3.2

src/PostgREST/AppState.hs

+12-12
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ import Data.IORef (IORef, atomicWriteIORef, newIORef,
5757
readIORef)
5858
import Data.Time.Clock (UTCTime, getCurrentTime)
5959

60-
import PostgREST.Auth.JwtCache (JwtCacheState)
6160
import PostgREST.Config (AppConfig (..),
6261
addFallbackAppName,
6362
readAppConfig)
@@ -105,8 +104,8 @@ data AppState = AppState
105104
, stateSocketAdmin :: Maybe NS.Socket
106105
-- | Observation handler
107106
, stateObserver :: ObservationHandler
108-
-- | JWT Cache
109-
, stateJwtCache :: JwtCache.JwtCacheState
107+
-- | JWT Cache, disabled when config jwt-cache-max-entries is set to 0
108+
, stateJwtCache :: IORef JwtCache.JwtCacheState
110109
, stateLogger :: Logger.LoggerState
111110
, stateMetrics :: Metrics.MetricsState
112111
}
@@ -120,14 +119,14 @@ data SchemaCacheStatus
120119
type AppSockets = (NS.Socket, Maybe NS.Socket)
121120

122121
init :: AppConfig -> IO AppState
123-
init conf@AppConfig{configLogLevel, configDbPoolSize} = do
122+
init conf = do
124123
loggerState <- Logger.init
125-
metricsState <- Metrics.init configDbPoolSize
126-
let observer = liftA2 (>>) (Logger.observationLogger loggerState configLogLevel) (Metrics.observationMetrics metricsState)
124+
metricsState <- Metrics.init (configDbPoolSize conf)
125+
let observer = liftA2 (>>) (Logger.observationLogger loggerState (configLogLevel conf)) (Metrics.observationMetrics metricsState)
127126

128127
observer $ AppStartObs prettyVersion
129128

130-
jwtCacheState <- JwtCache.init
129+
jwtCacheState <- JwtCache.init (configJwtCacheMaxEntries conf)
131130
pool <- initPool conf observer
132131
(sock, adminSock) <- initSockets conf
133132
state' <- initWithPool (sock, adminSock) pool conf jwtCacheState loggerState metricsState observer
@@ -150,7 +149,7 @@ initWithPool (sock, adminSock) pool conf jwtCacheState loggerState metricsState
150149
<*> pure sock
151150
<*> pure adminSock
152151
<*> pure observer
153-
<*> pure jwtCacheState
152+
<*> newIORef jwtCacheState
154153
<*> pure loggerState
155154
<*> pure metricsState
156155

@@ -311,8 +310,8 @@ putConfig = atomicWriteIORef . stateConf
311310
getTime :: AppState -> IO UTCTime
312311
getTime = stateGetTime
313312

314-
getJwtCacheState :: AppState -> JwtCacheState
315-
getJwtCacheState = stateJwtCache
313+
getJwtCacheState :: AppState -> IO JwtCache.JwtCacheState
314+
getJwtCacheState = readIORef . stateJwtCache
316315

317316
getSocketREST :: AppState -> NS.Socket
318317
getSocketREST = stateSocketREST
@@ -473,8 +472,9 @@ readInDbConfig startingUp appState@AppState{stateObserver=observer} = do
473472
-- entries, because they were cached using the old secret
474473
if configJwtSecret conf == configJwtSecret newConf then
475474
pass
476-
else
477-
JwtCache.emptyCache (getJwtCacheState appState) -- atomic O(1) operation
475+
else do
476+
newJwtCacheState <- JwtCache.init (configJwtCacheMaxEntries newConf)
477+
atomicWriteIORef (stateJwtCache appState) newJwtCacheState
478478

479479
if startingUp then
480480
pass

src/PostgREST/Auth.hs

+17-20
Original file line numberDiff line numberDiff line change
@@ -160,30 +160,27 @@ middleware appState app req respond = do
160160

161161
let token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req)
162162
parseJwt = runExceptT $ parseToken conf token time >>= parseClaims conf
163-
jwtCacheState = getJwtCacheState appState
163+
164+
jwtCacheState <- getJwtCacheState appState
164165

165166
-- If ServerTimingEnabled -> calculate JWT validation time
166-
-- If JwtCacheMaxLifetime -> cache JWT validation result
167-
req' <- case (configServerTimingEnabled conf, configJwtCacheMaxLifetime conf) of
168-
(True, 0) -> do
169-
(dur, authResult) <- timeItT parseJwt
170-
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
171-
172-
(True, maxLifetime) -> do
173-
(dur, authResult) <- timeItT $ case token of
174-
Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time
175-
Nothing -> parseJwt
176-
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
167+
req' <- if configServerTimingEnabled conf then do
168+
169+
(dur, authResult) <- timeItT $ case token of
170+
171+
Just tkn -> lookupJwtCache jwtCacheState tkn parseJwt time
172+
Nothing -> parseJwt
177173

178-
(False, 0) -> do
179-
authResult <- parseJwt
180-
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
174+
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
175+
176+
else do
177+
178+
authResult <- case token of
179+
180+
Just tkn -> lookupJwtCache jwtCacheState tkn parseJwt time
181+
Nothing -> parseJwt
181182

182-
(False, maxLifetime) -> do
183-
authResult <- case token of
184-
Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time
185-
Nothing -> parseJwt
186-
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
183+
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
187184

188185
app req' respond
189186

src/PostgREST/Auth/JwtCache.hs

+73-70
Original file line numberDiff line numberDiff line change
@@ -4,96 +4,99 @@ Description : PostgREST Jwt Authentication Result Cache.
44
55
This module provides functions to deal with the JWT cache
66
-}
7-
{-# LANGUAGE NamedFieldPuns #-}
87
module PostgREST.Auth.JwtCache
98
( init
109
, JwtCacheState
1110
, lookupJwtCache
12-
, emptyCache
1311
) where
1412

1513
import qualified Data.Aeson as JSON
1614
import qualified Data.Aeson.KeyMap as KM
17-
import qualified Data.Cache as C
15+
import qualified Data.Cache.LRU as C
16+
import qualified Data.IORef as I
1817
import qualified Data.Scientific as Sci
1918

20-
import Control.Debounce
21-
2219
import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds)
2320
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
24-
import System.Clock (TimeSpec (..))
21+
import GHC.Num (integerFromInt)
2522

2623
import PostgREST.Auth.Types (AuthResult (..))
2724
import PostgREST.Error (Error (..))
2825

2926
import Protolude
3027

31-
-- | JWT Cache and IO action that triggers purging old entries from the cache
32-
data JwtCacheState = JwtCacheState
33-
{ jwtCache :: C.Cache ByteString AuthResult
34-
, purgeCache :: IO ()
28+
-- | Jwt Cache State
29+
newtype JwtCacheState = JwtCacheState
30+
{ maybeJwtCache :: Maybe (I.IORef (C.LRU ByteString AuthResult))
3531
}
3632

3733
-- | Initialize JwtCacheState
38-
init :: IO JwtCacheState
39-
init = do
40-
cache <- C.newCache Nothing -- no default expiration
41-
-- purgeExpired has O(n^2) complexity
42-
-- so we wrap it in debounce to make sure it:
43-
-- 1) is executed asynchronously
44-
-- 2) only a single purge operation is running at a time
45-
debounce <- mkDebounce defaultDebounceSettings
46-
-- debounceFreq is set to default 1 second
47-
{ debounceAction = C.purgeExpired cache
48-
, debounceEdge = leadingEdge
49-
}
50-
pure $ JwtCacheState cache debounce
34+
init :: Int -> IO JwtCacheState
35+
init 0 = return $ JwtCacheState Nothing
36+
init maxEntries = do
37+
cache <- I.newIORef $ C.newLRU (Just $ integerFromInt maxEntries)
38+
return $ JwtCacheState $ Just cache
39+
5140

5241
-- | Used to retrieve and insert JWT to JWT Cache
53-
lookupJwtCache :: JwtCacheState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult)
54-
lookupJwtCache JwtCacheState{jwtCache, purgeCache} token maxLifetime parseJwt utc = do
55-
checkCache <- C.lookup jwtCache token
56-
authResult <- maybe parseJwt (pure . Right) checkCache
57-
58-
case (authResult,checkCache) of
59-
-- From comment:
60-
-- https://github.com/PostgREST/postgrest/pull/3801#discussion_r1857987914
61-
--
62-
-- We purge expired cache entries on a cache miss
63-
-- The reasoning is that:
64-
--
65-
-- 1. We expect it to be rare (otherwise there is no point of the cache)
66-
-- 2. It makes sure the cache is not growing (as inserting new entries
67-
-- does garbage collection)
68-
-- 3. Since this is time expiration based cache there is no real risk of
69-
-- starvation - sooner or later we are going to have a cache miss.
70-
71-
(Right res, Nothing) -> do -- cache miss
72-
73-
let timeSpec = getTimeSpec res maxLifetime utc
74-
75-
-- insert new cache entry
76-
C.insert' jwtCache (Just timeSpec) token res
77-
78-
-- Execute IO action to purge the cache
79-
-- It is assumed this action returns immidiately
80-
-- so that request processing is not blocked.
81-
purgeCache
82-
83-
_ -> pure ()
84-
85-
return authResult
86-
87-
-- Used to extract JWT exp claim and add to JWT Cache
88-
getTimeSpec :: AuthResult -> Int -> UTCTime -> TimeSpec
89-
getTimeSpec res maxLifetime utc = do
90-
let expireJSON = KM.lookup "exp" (authClaims res)
91-
utcToSecs = floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds
92-
sciToInt = fromMaybe 0 . Sci.toBoundedInteger
93-
case expireJSON of
94-
Just (JSON.Number seconds) -> TimeSpec (sciToInt seconds - utcToSecs utc) 0
95-
_ -> TimeSpec (fromIntegral maxLifetime :: Int64) 0
96-
97-
-- | Empty the cache (done when the config is reloaded)
98-
emptyCache :: JwtCacheState -> IO ()
99-
emptyCache JwtCacheState{jwtCache} = C.purge jwtCache
42+
lookupJwtCache :: JwtCacheState -> ByteString -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult)
43+
lookupJwtCache jwtCacheState token parseJwt utc = do
44+
case maybeJwtCache jwtCacheState of
45+
Nothing -> parseJwt
46+
Just jwtCacheIORef -> do
47+
-- get cache from IORef
48+
jwtCache <- I.readIORef jwtCacheIORef
49+
50+
-- lookup key = token
51+
-- MAKE SURE WE UPDATE THE CACHE ON ALL PATHS AFTER LOOKUP (changes cache)
52+
let (jwtCache', maybeVal) = C.lookup token jwtCache
53+
54+
case maybeVal of
55+
Nothing -> do -- CACHE MISS
56+
57+
-- When we get a cache miss, we get the parse result, insert it
58+
-- into the cache. After that, we write the cache IO ref with
59+
-- updated cache
60+
authResult <- parseJwt
61+
62+
case authResult of
63+
Right result -> do
64+
-- insert token -> update cache -> return token
65+
let jwtCache'' = C.insert token result jwtCache'
66+
I.writeIORef jwtCacheIORef jwtCache''
67+
return $ Right result
68+
Left e -> do
69+
-- update cache after lookup -> return error
70+
I.writeIORef jwtCacheIORef jwtCache'
71+
return $ Left e
72+
73+
Just result -> -- CACHE HIT
74+
75+
-- For cache hit, we get the result from cache, we check the
76+
-- exp claim. If it expired, we delete it from cache and parse
77+
-- the jwt. Otherwise, the hit result is valid, so we return it
78+
79+
if isExpClaimExpired result utc then do
80+
-- delete token -> update cache -> parse token
81+
let (jwtCache'',_) = C.delete token jwtCache'
82+
I.writeIORef jwtCacheIORef jwtCache''
83+
parseJwt
84+
else do
85+
-- update cache afte lookup -> return result
86+
I.writeIORef jwtCacheIORef jwtCache'
87+
return $ Right result
88+
89+
90+
type Expired = Bool
91+
92+
-- | Check if exp claim is expired when looked up from cache
93+
isExpClaimExpired :: AuthResult -> UTCTime -> Expired
94+
isExpClaimExpired result utc =
95+
case expireJSON of
96+
Nothing -> False -- if exp not present then it is valid
97+
Just (JSON.Number expiredAt) -> (sciToInt expiredAt - now) < 0
98+
Just _ -> False -- if exp is not a number then valid
99+
where
100+
expireJSON = KM.lookup "exp" (authClaims result)
101+
now = (floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds) utc :: Int
102+
sciToInt = fromMaybe 0 . Sci.toBoundedInteger

src/PostgREST/CLI.hs

+2-2
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,8 @@ exampleConfigFile =
203203
|# jwt-secret = "secret_with_at_least_32_characters"
204204
|jwt-secret-is-base64 = false
205205
|
206-
|## Enables and set JWT Cache max lifetime, disables caching with 0
207-
|# jwt-cache-max-lifetime = 0
206+
|## Maximum number of auth token that can be cached
207+
|# jwt-cache-max-entries = 1000
208208
|
209209
|## Logging level, the admitted values are: crit, error, warn, info and debug.
210210
|log-level = "error"

0 commit comments

Comments
 (0)