feat: change JWT cache to limited LRU based cache #4008

Draft: wants to merge 1 commit into base: main
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
+ The selected columns in the embedded resources are aggregated into arrays
+ Aggregates are not supported
- #2967, Add `Proxy-Status` header for better error response - @taimoorzaeem
- #4003, Add config `jwt-cache-max-entries` for maximum number of cached entries - @taimoorzaeem

### Fixed

@@ -55,6 +56,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
+ Diagnostic error messages instead of exposed internals
+ Return new `PGRST303` error when jwt claims decoding fails
- #3906, Return `PGRST125` and `PGRST126` errors instead of empty json - @taimoorzaeem
- #4003, Remove config `jwt-cache-max-lifetime` and add config `jwt-cache-max-entries` for JWT cache - @taimoorzaeem

## [12.2.10] - 2025-04-18

3 changes: 3 additions & 0 deletions docs/postgrest.dict
@@ -23,6 +23,7 @@ Cloudflare
config
cors
CORS
cryptographic
cryptographically
CSV
durations
@@ -74,6 +75,7 @@ JSON
JSPath
JWK
JWT
JWTs
jwt
Keycloak
Kubernetes
@@ -84,6 +86,7 @@ Logins
LIBPQ
logins
lon
LRU
lt
lte
macOS
4 changes: 2 additions & 2 deletions docs/references/auth.rst
@@ -98,9 +98,9 @@ The ``Bearer`` header value can be used with or without capitalization(``bearer``
JWT Caching
-----------

PostgREST validates ``JWTs`` on every request. We can cache ``JWTs`` to avoid this performance overhead.
PostgREST caches ``JWTs`` on every request to avoid performance overhead of parsing and cryptographic operations.

To enable JWT caching, the config :code:`jwt-cache-max-lifetime` is to be set. It is the maximum number of seconds for which the cache stores the JWT validation results. The cache uses the :code:`exp` claim to set the cache entry lifetime. If the JWT does not have an :code:`exp` claim, it uses the config value. See :ref:`jwt-cache-max-lifetime` for more details.
To disable JWT caching, set the config :code:`jwt-cache-max-entries` to ``0``. It is the maximum number of JWTs whose validation results the cache stores. If the cache reaches its maximum, the `least recently used <https://redis.io/glossary/lru-cache/>`_ entry is removed. The cache honors the :code:`exp` claim. If the JWT does not have an :code:`exp` claim, it stays cached until the LRU policy removes it. See :ref:`jwt-cache-max-entries` for configuration details.
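
The eviction behavior described in this paragraph can be illustrated with a minimal sketch. It assumes a toy structure built on `Data.Map` from `containers`; the names `Lru`, `emptyLru`, `insertLru` and `lookupLru` are hypothetical and are not the `lrucache` API used in this PR. Eviction here scans all entries, which is fine for illustration but not for production.

```haskell
import           Data.List       (minimumBy)
import qualified Data.Map.Strict as M
import           Data.Ord        (comparing)

-- | Toy LRU (hypothetical, for illustration): every entry remembers the
-- tick at which it was last used.
data Lru k v = Lru
  { lruMax     :: Int                -- ^ maximum number of entries
  , lruTick    :: Int                -- ^ logical clock, bumped on every use
  , lruEntries :: M.Map k (Int, v)   -- ^ key -> (last-used tick, value)
  }

emptyLru :: Int -> Lru k v
emptyLru n = Lru n 0 M.empty

-- | Insert a value, evicting the least recently used entry when full.
insertLru :: Ord k => k -> v -> Lru k v -> Lru k v
insertLru k v (Lru n t m)
  | n <= 0    = Lru n t m            -- a capacity of 0 stores nothing
  | otherwise = Lru n (t + 1) (M.insert k (t, v) trimmed)
  where
    trimmed
      | M.member k m || M.size m < n = m
      | otherwise                    = M.delete oldest m
    -- Only evaluated when the map is full, hence non-empty.
    oldest = fst (minimumBy (comparing (fst . snd)) (M.toList m))

-- | Look up a key. A hit refreshes the entry, so the cache itself changes.
lookupLru :: Ord k => k -> Lru k v -> (Lru k v, Maybe v)
lookupLru k c@(Lru n t m) = case M.lookup k m of
  Nothing     -> (c, Nothing)
  Just (_, v) -> (Lru n (t + 1) (M.insert k (t, v) m), Just v)
```

With a capacity of 2, inserting `a` and `b`, looking up `a`, and then inserting `c` evicts `b`, because the lookup made `a` the more recently used entry.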

.. note::

14 changes: 7 additions & 7 deletions docs/references/configuration.rst
@@ -658,20 +658,20 @@ jwt-secret-is-base64

When this is set to :code:`true`, the value derived from :code:`jwt-secret` will be treated as a base64 encoded secret.

.. _jwt-cache-max-lifetime:
.. _jwt-cache-max-entries:

jwt-cache-max-lifetime
----------------------
jwt-cache-max-entries
---------------------

=============== =================================
**Type** Int
**Default** 0
**Default** 1000
**Reloadable** Y
**Environment** PGRST_JWT_CACHE_MAX_LIFETIME
**In-Database** pgrst.jwt_cache_max_lifetime
**Environment** PGRST_JWT_CACHE_MAX_ENTRIES
**In-Database** pgrst.jwt_cache_max_entries
=============== =================================

Maximum number of seconds of lifetime for cached entries. The default :code:`0` disables caching. See :ref:`jwt_caching`.
Maximum number of JWTs that can be cached. Set to :code:`0` to disable caching. See :ref:`jwt_caching`.

.. _log-level:

2 changes: 1 addition & 1 deletion nix/tools/loadtest.nix
@@ -58,7 +58,7 @@ let
export PGRST_DB_TX_END="rollback-allow-override"
export PGRST_LOG_LEVEL="crit"
export PGRST_JWT_SECRET="reallyreallyreallyreallyverysafe"
export PGRST_JWT_CACHE_MAX_LIFETIME="86400"
export PGRST_JWT_CACHE_MAX_ENTRIES="1000" # default

mkdir -p "$(dirname "$_arg_output")"
abs_output="$(realpath "$_arg_output")"
3 changes: 1 addition & 2 deletions postgrest.cabal
@@ -99,10 +99,8 @@ library
, auto-update >= 0.1.4 && < 0.2
, base64-bytestring >= 1 && < 1.3
, bytestring >= 0.10.8 && < 0.13
, cache >= 0.1.3 && < 0.2.0
, case-insensitive >= 1.2 && < 1.3
, cassava >= 0.4.5 && < 0.6
, clock >= 0.8.3 && < 0.9.0
, configurator-pg >= 0.2 && < 0.3
, containers >= 0.5.7 && < 0.7
, cookie >= 0.4.2 && < 0.5
@@ -122,6 +120,7 @@ library
, jose-jwt >= 0.9.6 && < 0.11
, lens >= 4.14 && < 5.3
, lens-aeson >= 1.0.1 && < 1.3
, lrucache >= 1.2.0.1 && < 1.3
, mtl >= 2.2.2 && < 2.4
, neat-interpolation >= 0.5 && < 0.6
, network >= 2.6 && < 3.2
24 changes: 12 additions & 12 deletions src/PostgREST/AppState.hs
@@ -57,7 +57,6 @@ import Data.IORef (IORef, atomicWriteIORef, newIORef,
readIORef)
import Data.Time.Clock (UTCTime, getCurrentTime)

import PostgREST.Auth.JwtCache (JwtCacheState)
import PostgREST.Config (AppConfig (..),
addFallbackAppName,
readAppConfig)
@@ -105,8 +104,8 @@ data AppState = AppState
, stateSocketAdmin :: Maybe NS.Socket
-- | Observation handler
, stateObserver :: ObservationHandler
-- | JWT Cache
, stateJwtCache :: JwtCache.JwtCacheState
-- | JWT Cache, disabled when config jwt-cache-max-entries is set to 0
, stateJwtCache :: IORef JwtCache.JwtCacheState
Contributor commented:

I am not sure if IORef is needed here - we already have a mutable variable in JwtCacheState? Why do we need another one here?
A cleaner solution would be to export reset :: JwtCacheState -> IO () function from JwtCache module and call it upon config reload.
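
The suggestion above could look roughly like the following sketch. It assumes the cache module keeps its pure cache behind a single `IORef`, with `Data.Map` standing in for the LRU structure; `initCache`, `insertEntry` and `cacheSize` are hypothetical helpers added only so the example is self-contained. `reset` uses the signature proposed in the comment.

```haskell
import           Data.IORef      (IORef, atomicModifyIORef', atomicWriteIORef,
                                  newIORef, readIORef)
import qualified Data.Map.Strict as M

-- Stand-in for the real cache state (assumption): the mutable variable is
-- owned by the cache module, so AppState does not need a second IORef.
newtype JwtCacheState = JwtCacheState (IORef (M.Map String String))

initCache :: IO JwtCacheState
initCache = JwtCacheState <$> newIORef M.empty

-- | Hypothetical helper: store one validation result.
insertEntry :: JwtCacheState -> String -> String -> IO ()
insertEntry (JwtCacheState ref) k v =
  atomicModifyIORef' ref (\m -> (M.insert k v m, ()))

-- | Drop every cached entry in place, e.g. after the JWT secret changes.
reset :: JwtCacheState -> IO ()
reset (JwtCacheState ref) = atomicWriteIORef ref M.empty

-- | Hypothetical helper: number of cached entries.
cacheSize :: JwtCacheState -> IO Int
cacheSize (JwtCacheState ref) = M.size <$> readIORef ref
```

One open point with this shape: if `jwt-cache-max-entries` changes on reload, `reset` would presumably also need to receive the new maximum, which this sketch does not model.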

, stateLogger :: Logger.LoggerState
, stateMetrics :: Metrics.MetricsState
}
@@ -120,14 +119,14 @@ data SchemaCacheStatus
type AppSockets = (NS.Socket, Maybe NS.Socket)

init :: AppConfig -> IO AppState
init conf@AppConfig{configLogLevel, configDbPoolSize} = do
init conf = do
loggerState <- Logger.init
metricsState <- Metrics.init configDbPoolSize
let observer = liftA2 (>>) (Logger.observationLogger loggerState configLogLevel) (Metrics.observationMetrics metricsState)
metricsState <- Metrics.init (configDbPoolSize conf)
Contributor commented:

Does this change have anything to do with JWT cache?

let observer = liftA2 (>>) (Logger.observationLogger loggerState (configLogLevel conf)) (Metrics.observationMetrics metricsState)

observer $ AppStartObs prettyVersion

jwtCacheState <- JwtCache.init
jwtCacheState <- JwtCache.init (configJwtCacheMaxEntries conf)
pool <- initPool conf observer
(sock, adminSock) <- initSockets conf
state' <- initWithPool (sock, adminSock) pool conf jwtCacheState loggerState metricsState observer
@@ -150,7 +149,7 @@ initWithPool (sock, adminSock) pool conf jwtCacheState loggerState metricsState
<*> pure sock
<*> pure adminSock
<*> pure observer
<*> pure jwtCacheState
<*> newIORef jwtCacheState
<*> pure loggerState
<*> pure metricsState

@@ -311,8 +310,8 @@ putConfig = atomicWriteIORef . stateConf
getTime :: AppState -> IO UTCTime
getTime = stateGetTime

getJwtCacheState :: AppState -> JwtCacheState
getJwtCacheState = stateJwtCache
getJwtCacheState :: AppState -> IO JwtCache.JwtCacheState
getJwtCacheState = readIORef . stateJwtCache

getSocketREST :: AppState -> NS.Socket
getSocketREST = stateSocketREST
@@ -473,8 +472,9 @@ readInDbConfig startingUp appState@AppState{stateObserver=observer} = do
-- entries, because they were cached using the old secret
if configJwtSecret conf == configJwtSecret newConf then
pass
else
JwtCache.emptyCache (getJwtCacheState appState) -- atomic O(1) operation
else do
newJwtCacheState <- JwtCache.init (configJwtCacheMaxEntries newConf)
atomicWriteIORef (stateJwtCache appState) newJwtCacheState

if startingUp then
pass
37 changes: 17 additions & 20 deletions src/PostgREST/Auth.hs
@@ -160,30 +160,27 @@ middleware appState app req respond = do

let token = Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req)
parseJwt = runExceptT $ parseToken conf token time >>= parseClaims conf
jwtCacheState = getJwtCacheState appState

jwtCacheState <- getJwtCacheState appState

-- If ServerTimingEnabled -> calculate JWT validation time
-- If JwtCacheMaxLifetime -> cache JWT validation result
req' <- case (configServerTimingEnabled conf, configJwtCacheMaxLifetime conf) of
(True, 0) -> do
(dur, authResult) <- timeItT parseJwt
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }

(True, maxLifetime) -> do
(dur, authResult) <- timeItT $ case token of
Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time
Nothing -> parseJwt
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }
req' <- if configServerTimingEnabled conf then do

(dur, authResult) <- timeItT $ case token of

Just tkn -> lookupJwtCache jwtCacheState tkn parseJwt time
Nothing -> parseJwt

(False, 0) -> do
authResult <- parseJwt
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult & Vault.insert jwtDurKey dur }

else do

authResult <- case token of

Just tkn -> lookupJwtCache jwtCacheState tkn parseJwt time
Nothing -> parseJwt

(False, maxLifetime) -> do
authResult <- case token of
Just tkn -> lookupJwtCache jwtCacheState tkn maxLifetime parseJwt time
Nothing -> parseJwt
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }
return $ req { Wai.vault = Wai.vault req & Vault.insert authResultKey authResult }

app req' respond

144 changes: 74 additions & 70 deletions src/PostgREST/Auth/JwtCache.hs
@@ -4,96 +4,100 @@

This module provides functions to deal with the JWT cache
-}
{-# LANGUAGE NamedFieldPuns #-}
module PostgREST.Auth.JwtCache
( init
, JwtCacheState
, lookupJwtCache
, emptyCache
) where

import qualified Data.Aeson as JSON
import qualified Data.Aeson.KeyMap as KM
import qualified Data.Cache as C
import qualified Data.Cache.LRU as C
import qualified Data.IORef as I
import qualified Data.Scientific as Sci

import Control.Debounce

import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
import System.Clock (TimeSpec (..))
import GHC.Num (integerFromInt)

import PostgREST.Auth.Types (AuthResult (..))
import PostgREST.Error (Error (..))

import Protolude

-- | JWT Cache and IO action that triggers purging old entries from the cache
data JwtCacheState = JwtCacheState
{ jwtCache :: C.Cache ByteString AuthResult
, purgeCache :: IO ()
-- | Jwt Cache State
newtype JwtCacheState = JwtCacheState
{ maybeJwtCache :: Maybe (I.IORef (C.LRU ByteString AuthResult))
}

-- | Initialize JwtCacheState
init :: IO JwtCacheState
init = do
cache <- C.newCache Nothing -- no default expiration
-- purgeExpired has O(n^2) complexity
-- so we wrap it in debounce to make sure it:
-- 1) is executed asynchronously
-- 2) only a single purge operation is running at a time
debounce <- mkDebounce defaultDebounceSettings
-- debounceFreq is set to default 1 second
{ debounceAction = C.purgeExpired cache
, debounceEdge = leadingEdge
}
pure $ JwtCacheState cache debounce
init :: Int -> IO JwtCacheState
init 0 = return $ JwtCacheState Nothing
init maxEntries = do
cache <- I.newIORef $ C.newLRU (Just $ integerFromInt maxEntries)
return $ JwtCacheState $ Just cache


-- | Used to retrieve and insert JWT to JWT Cache
lookupJwtCache :: JwtCacheState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult)
lookupJwtCache JwtCacheState{jwtCache, purgeCache} token maxLifetime parseJwt utc = do
checkCache <- C.lookup jwtCache token
authResult <- maybe parseJwt (pure . Right) checkCache

case (authResult,checkCache) of
-- From comment:
-- https://github.com/PostgREST/postgrest/pull/3801#discussion_r1857987914
--
-- We purge expired cache entries on a cache miss
-- The reasoning is that:
--
-- 1. We expect it to be rare (otherwise there is no point of the cache)
-- 2. It makes sure the cache is not growing (as inserting new entries
-- does garbage collection)
-- 3. Since this is time expiration based cache there is no real risk of
-- starvation - sooner or later we are going to have a cache miss.

(Right res, Nothing) -> do -- cache miss

let timeSpec = getTimeSpec res maxLifetime utc

-- insert new cache entry
C.insert' jwtCache (Just timeSpec) token res

-- Execute IO action to purge the cache
-- It is assumed this action returns immediately
-- so that request processing is not blocked.
purgeCache

_ -> pure ()

return authResult

-- Used to extract JWT exp claim and add to JWT Cache
getTimeSpec :: AuthResult -> Int -> UTCTime -> TimeSpec
getTimeSpec res maxLifetime utc = do
let expireJSON = KM.lookup "exp" (authClaims res)
utcToSecs = floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds
sciToInt = fromMaybe 0 . Sci.toBoundedInteger
case expireJSON of
Just (JSON.Number seconds) -> TimeSpec (sciToInt seconds - utcToSecs utc) 0
_ -> TimeSpec (fromIntegral maxLifetime :: Int64) 0

-- | Empty the cache (done when the config is reloaded)
emptyCache :: JwtCacheState -> IO ()
emptyCache JwtCacheState{jwtCache} = C.purge jwtCache
lookupJwtCache :: JwtCacheState -> ByteString -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult)
lookupJwtCache jwtCacheState token parseJwt utc = do
case maybeJwtCache jwtCacheState of
Nothing -> parseJwt
Just jwtCacheIORef -> do
-- get cache from IORef
jwtCache <- I.readIORef jwtCacheIORef

-- MAKE SURE WE UPDATE THE CACHE ON ALL PATHS AFTER LOOKUP
-- This is because it is a pure LRU cache, so lookup returns the
-- cache with new state, hence it should be updated
let (jwtCache', maybeVal) = C.lookup token jwtCache

case maybeVal of
Nothing -> do -- CACHE MISS

-- When we get a cache miss, we get the parse result, insert it
-- into the cache. After that, we write the cache IO ref with
-- updated cache
authResult <- parseJwt

case authResult of
Right result -> do
-- insert token -> update cache -> return token
let jwtCache'' = C.insert token result jwtCache'
I.writeIORef jwtCacheIORef jwtCache''
return $ Right result
Left e -> do
-- update cache after lookup -> return error
I.writeIORef jwtCacheIORef jwtCache'
return $ Left e

Just result -> -- CACHE HIT

-- For cache hit, we get the result from cache, we check the
-- exp claim. If it expired, we delete it from cache and parse
-- the jwt. Otherwise, the hit result is valid, so we return it

if isExpClaimExpired result utc then do
-- delete token -> update cache -> parse token
let (jwtCache'',_) = C.delete token jwtCache'

Codecov / codecov/patch warning: added line #L82 in src/PostgREST/Auth/JwtCache.hs was not covered by tests.
I.writeIORef jwtCacheIORef jwtCache''
parseJwt
else do
-- update cache after lookup -> return result
I.writeIORef jwtCacheIORef jwtCache'
return $ Right result

mkleczek (Contributor) commented on Apr 21, 2025:

The above code is not thread safe - it is read-modify-write and causes lost updates.
Granted - in case of caching lost updates might not be that much of a deal (although see my comment about possible performance implications) but maybe TRef instead of IORef would be better here.
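
One way to avoid the lost update without serializing the whole function is to make each read-modify-write step atomic and keep the expensive parse outside of it. The sketch below is an assumption about how that could look, not the PR's code: `Cache`, `lookupTouch` and `cachedLookup` are hypothetical names, and `Data.Map` with a per-entry use counter stands in for the LRU. `lookupTouch` has the same shape as the pure lookup in the diff above (it returns the updated cache together with the result), which is exactly the shape `atomicModifyIORef'` expects.

```haskell
import           Data.IORef      (IORef, atomicModifyIORef', newIORef, readIORef)
import qualified Data.Map.Strict as M

-- | token -> (use count, cached result); a stand-in for the LRU structure.
type Cache = M.Map String (Int, Int)

-- | Pure lookup that also bumps the entry's use counter, mirroring how a
-- pure LRU lookup returns an updated cache together with the result.
lookupTouch :: String -> Cache -> (Cache, Maybe Int)
lookupTouch k c = case M.lookup k c of
  Nothing     -> (c, Nothing)
  Just (n, v) -> (M.insert k (n + 1, v) c, Just v)

-- | Look up a token. The lookup-and-touch step is one atomic update; on a
-- miss the parser runs outside any atomic section and the insert is a
-- second atomic update.
cachedLookup :: IORef Cache -> String -> IO (Either String Int) -> IO (Either String Int)
cachedLookup ref token parse = do
  hit <- atomicModifyIORef' ref (lookupTouch token)
  case hit of
    Just v  -> pure (Right v)
    Nothing -> do
      result <- parse
      case result of
        Right v -> atomicModifyIORef' ref (\c -> (M.insert token (0, v) c, ()))
        Left _  -> pure ()
      pure result
```

Two threads that miss on the same token at the same time would both parse it and both insert, which costs duplicated work but does not corrupt the cache. All threads still contend on the single `IORef`.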

Collaborator Author commented:

I think we should run the entire lookupJwtCache function atomically. This would make it more thread safe.

Contributor commented:

If lookupJwtCache is atomic then you introduce very high contention as all threads wait for each other to just perform lookup. The problem is even worse because with LRU cache lookups are not read-only anymore, so the shared variable has to be updated upon every access.

I don't think there is a good solution to this without a proper mutable concurrent high performance LRU cache implementation.

@steve-chavez @wolfgangwalther - what do you think?

Member commented:

This sounds very plausible, @mkleczek.

I suggest we use real numbers to see the effect. Making the whole lookupJwtCache function atomic should not be hard (?) - and then we can run the JWT loadtest on it. Then compare that to running with https://hackage.haskell.org/package/psqueues. Ideally, we'd have both branches available, so we can run a comparison in one go.

Collaborator Author commented:

> If lookupJwtCache is atomic then you introduce very high contention as all threads wait for each other to just perform lookup. The problem is even worse because with LRU cache lookups are not read-only anymore, so the shared variable has to be updated upon every access.

With many threads, performance for any shared resource will take a hit no matter what because of mutual exclusion.

> I don't think there is a good solution to this without a proper mutable concurrent high performance LRU cache implementation.

So is it even possible to have "high" performance with concurrency when waiting is involved? I think the only way to maintain performance here would be to decrease waiting, which means we need to use a distributed cache, which will get complex. 🤔

Member commented:

> Maintaining a separate library would have been possible if we only used hackage+cabal, but we don't. It takes months for a new version of a library to be available on Stack and NixOS. I think our own implementation in postgrest would be our best bet.

I don't think any of that is a blocker for a separate lib, all of that can easily be done immediately via overrides etc.

Still, building our own library for that is too big of a task to pursue anyway.

Contributor commented:

@taimoorzaeem @wolfgangwalther
After some research I found the right library: https://hackage.haskell.org/package/lrucaching-0.3.4

It is an implementation based on the blog post https://jaspervdj.be/posts/2015-02-24-lru-cache.html and uses psqueues.

It provides a version of the cache that minimises contention by splitting the cache into stripes:
https://hackage.haskell.org/package/lrucaching-0.3.4/docs/Data-LruCache-IO.html#v:stripedCached

I guess we could configure it to https://hackage.haskell.org/package/base-4.21.0.0/docs/GHC-Conc.html#v:getNumCapabilities number of stripes and not bother users with configuring it.

It does not support expiration though. It would be good to add it otherwise the cache is vulnerable to cache thrashing attacks by continuously sending requests with expired JWTs.
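
The striping idea mentioned in this comment can be sketched as follows. This is a hypothetical illustration and not the `lrucaching` API: `Striped`, `newStriped`, `hashKey` and the other names are made up, `Data.Map` stands in for a per-stripe LRU (so there is no eviction here), and the string hash is a toy. Each key is routed to one stripe, so threads touching different stripes do not contend on the same `IORef`.

```haskell
import           Control.Concurrent (getNumCapabilities)
import           Control.Monad      (replicateM)
import           Data.Char          (ord)
import           Data.IORef         (IORef, atomicModifyIORef', newIORef)
import           Data.List          (foldl')
import qualified Data.Map.Strict    as M

-- | A cache split into independent stripes, each behind its own IORef.
newtype Striped v = Striped [IORef (M.Map String v)]

-- | Create a cache with the given number of stripes (at least one).
newStriped :: Int -> IO (Striped v)
newStriped n = Striped <$> replicateM (max 1 n) (newIORef M.empty)

-- | One stripe per capability, as suggested in the comment above.
newStripedPerCapability :: IO (Striped v)
newStripedPerCapability = getNumCapabilities >>= newStriped

-- | Toy, non-negative string hash; for illustration only.
hashKey :: String -> Int
hashKey = foldl' (\h c -> (h * 31 + ord c) `mod` 1000003) 7

-- | Pick the stripe responsible for a key.
stripeFor :: Striped v -> String -> IORef (M.Map String v)
stripeFor (Striped refs) k = refs !! (hashKey k `mod` length refs)

insertStriped :: Striped v -> String -> v -> IO ()
insertStriped s k v =
  atomicModifyIORef' (stripeFor s k) (\m -> (M.insert k v m, ()))

lookupStriped :: Striped v -> String -> IO (Maybe v)
lookupStriped s k =
  atomicModifyIORef' (stripeFor s k) (\m -> (m, M.lookup k m))
```

A list indexed with `!!` keeps the sketch dependency-free; an array or vector would be the more usual choice for constant-time stripe selection.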

taimoorzaeem (Collaborator Author) commented on May 1, 2025:

This implementation is pure as well, so we'd still need to write the new cache on a lookup 😢.

> It does not support expiration though. It would be good to add it otherwise the cache is vulnerable to cache thrashing attacks by continuously sending requests with expired JWTs.

That is not an issue at all. We don't need a built-in expiration support. We can just check the exp claim on lookup. It would just be an O(1) task.

Contributor commented:

> This implementation is pure as well, so we'd still need to write the new cache on a lookup 😢.

It is not - see this link: https://hackage.haskell.org/package/lrucaching-0.3.4/docs/Data-LruCache-IO.html#v:stripedCached

Collaborator Author commented:

The stripedCached does not have any interface for insert, update, delete, etc. Internally, it is still LruCache.


type Expired = Bool

-- | Check if exp claim is expired when looked up from cache
isExpClaimExpired :: AuthResult -> UTCTime -> Expired
isExpClaimExpired result utc =
case expireJSON of
Nothing -> False -- if exp not present then it is valid
Just (JSON.Number expiredAt) -> (sciToInt expiredAt - now) < 0
Just _ -> False -- if exp is not a number then valid

Codecov / codecov/patch warning: added line #L99 in src/PostgREST/Auth/JwtCache.hs was not covered by tests.
where
expireJSON = KM.lookup "exp" (authClaims result)
now = (floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds) utc :: Int
sciToInt = fromMaybe 0 . Sci.toBoundedInteger
Member commented on lines +93 to +103:

Does this take the 30 seconds clock skew we allow into account?
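
If the cached-entry check should tolerate the same skew, a minimal sketch could look like this. It works on plain POSIX seconds rather than on `AuthResult`; `allowedSkew` and `isExpiredWithSkew` are hypothetical names, the value 30 is taken from the question above, and treating a token as still valid exactly at `exp + skew` is an assumption that would need to match the validation code.

```haskell
-- | Seconds of clock skew tolerated (assumption: 30, as in the question above).
allowedSkew :: Int
allowedSkew = 30

-- | Decide whether a cached entry is expired.
-- The first argument is the current time and the second the optional exp
-- claim, both in POSIX seconds. A missing exp claim never expires.
isExpiredWithSkew :: Int -> Maybe Int -> Bool
isExpiredWithSkew _   Nothing          = False
isExpiredWithSkew now (Just expiresAt) = now > expiresAt + allowedSkew
```

Under these assumptions, a token whose `exp` lies 20 seconds in the past is still served from the cache, while one whose `exp` lies 31 seconds in the past is treated as expired and parsed again.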
