-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
feat: change JWT cache to limited LRU based cache #4008
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,7 +57,6 @@ import Data.IORef (IORef, atomicWriteIORef, newIORef, | |
readIORef) | ||
import Data.Time.Clock (UTCTime, getCurrentTime) | ||
|
||
import PostgREST.Auth.JwtCache (JwtCacheState) | ||
import PostgREST.Config (AppConfig (..), | ||
addFallbackAppName, | ||
readAppConfig) | ||
|
@@ -105,8 +104,8 @@ data AppState = AppState | |
, stateSocketAdmin :: Maybe NS.Socket | ||
-- | Observation handler | ||
, stateObserver :: ObservationHandler | ||
-- | JWT Cache | ||
, stateJwtCache :: JwtCache.JwtCacheState | ||
-- | JWT Cache, disabled when config jwt-cache-max-entries is set to 0 | ||
, stateJwtCache :: IORef JwtCache.JwtCacheState | ||
, stateLogger :: Logger.LoggerState | ||
, stateMetrics :: Metrics.MetricsState | ||
} | ||
|
@@ -120,14 +119,14 @@ data SchemaCacheStatus | |
type AppSockets = (NS.Socket, Maybe NS.Socket) | ||
|
||
init :: AppConfig -> IO AppState | ||
init conf@AppConfig{configLogLevel, configDbPoolSize} = do | ||
init conf = do | ||
loggerState <- Logger.init | ||
metricsState <- Metrics.init configDbPoolSize | ||
let observer = liftA2 (>>) (Logger.observationLogger loggerState configLogLevel) (Metrics.observationMetrics metricsState) | ||
metricsState <- Metrics.init (configDbPoolSize conf) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this change have anything to do with JWT cache? |
||
let observer = liftA2 (>>) (Logger.observationLogger loggerState (configLogLevel conf)) (Metrics.observationMetrics metricsState) | ||
|
||
observer $ AppStartObs prettyVersion | ||
|
||
jwtCacheState <- JwtCache.init | ||
jwtCacheState <- JwtCache.init (configJwtCacheMaxEntries conf) | ||
pool <- initPool conf observer | ||
(sock, adminSock) <- initSockets conf | ||
state' <- initWithPool (sock, adminSock) pool conf jwtCacheState loggerState metricsState observer | ||
|
@@ -150,7 +149,7 @@ initWithPool (sock, adminSock) pool conf jwtCacheState loggerState metricsState | |
<*> pure sock | ||
<*> pure adminSock | ||
<*> pure observer | ||
<*> pure jwtCacheState | ||
<*> newIORef jwtCacheState | ||
<*> pure loggerState | ||
<*> pure metricsState | ||
|
||
|
@@ -311,8 +310,8 @@ putConfig = atomicWriteIORef . stateConf | |
getTime :: AppState -> IO UTCTime | ||
getTime = stateGetTime | ||
|
||
getJwtCacheState :: AppState -> JwtCacheState | ||
getJwtCacheState = stateJwtCache | ||
-- | Read the JWT cache state currently held in the application state.
--
-- The state lives in an 'IORef' inside 'AppState', so fetching it is an
-- 'IO' action that returns whatever value is stored at the time of the call.
getJwtCacheState :: AppState -> IO JwtCache.JwtCacheState
getJwtCacheState appState = readIORef (stateJwtCache appState)
|
||
getSocketREST :: AppState -> NS.Socket | ||
getSocketREST = stateSocketREST | ||
|
@@ -473,8 +472,9 @@ readInDbConfig startingUp appState@AppState{stateObserver=observer} = do | |
-- entries, because they were cached using the old secret | ||
if configJwtSecret conf == configJwtSecret newConf then | ||
pass | ||
else | ||
JwtCache.emptyCache (getJwtCacheState appState) -- atomic O(1) operation | ||
else do | ||
newJwtCacheState <- JwtCache.init (configJwtCacheMaxEntries newConf) | ||
atomicWriteIORef (stateJwtCache appState) newJwtCacheState | ||
|
||
if startingUp then | ||
pass | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,96 +4,100 @@ | |
|
||
This module provides functions to deal with the JWT cache | ||
-} | ||
{-# LANGUAGE NamedFieldPuns #-} | ||
module PostgREST.Auth.JwtCache | ||
( init | ||
, JwtCacheState | ||
, lookupJwtCache | ||
, emptyCache | ||
) where | ||
|
||
import qualified Data.Aeson as JSON | ||
import qualified Data.Aeson.KeyMap as KM | ||
import qualified Data.Cache as C | ||
import qualified Data.Cache.LRU as C | ||
import qualified Data.IORef as I | ||
import qualified Data.Scientific as Sci | ||
|
||
import Control.Debounce | ||
|
||
import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds) | ||
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds) | ||
import System.Clock (TimeSpec (..)) | ||
import GHC.Num (integerFromInt) | ||
|
||
import PostgREST.Auth.Types (AuthResult (..)) | ||
import PostgREST.Error (Error (..)) | ||
|
||
import Protolude | ||
|
||
-- | JWT Cache and IO action that triggers purging old entries from the cache | ||
data JwtCacheState = JwtCacheState | ||
{ jwtCache :: C.Cache ByteString AuthResult | ||
, purgeCache :: IO () | ||
-- | Jwt Cache State | ||
newtype JwtCacheState = JwtCacheState | ||
{ maybeJwtCache :: Maybe (I.IORef (C.LRU ByteString AuthResult)) | ||
} | ||
|
||
-- | Initialize JwtCacheState | ||
init :: IO JwtCacheState | ||
init = do | ||
cache <- C.newCache Nothing -- no default expiration | ||
-- purgeExpired has O(n^2) complexity | ||
-- so we wrap it in debounce to make sure it: | ||
-- 1) is executed asynchronously | ||
-- 2) only a single purge operation is running at a time | ||
debounce <- mkDebounce defaultDebounceSettings | ||
-- debounceFreq is set to default 1 second | ||
{ debounceAction = C.purgeExpired cache | ||
, debounceEdge = leadingEdge | ||
} | ||
pure $ JwtCacheState cache debounce | ||
-- | Initialize the JWT cache state.
--
-- A @maxEntries@ of zero or less disables the cache ('maybeJwtCache' is
-- 'Nothing'). Any positive value creates an empty LRU limited to that many
-- entries, stored in a fresh 'I.IORef'.
init :: Int -> IO JwtCacheState
init maxEntries
  -- C.newLRU does not accept a non-positive size limit, so negative values
  -- are handled here the same way as 0 instead of being passed through.
  | maxEntries <= 0 = pure $ JwtCacheState Nothing
  | otherwise       = JwtCacheState . Just <$> I.newIORef (C.newLRU (Just $ integerFromInt maxEntries))
|
||
|
||
-- | Used to retrieve and insert JWT to JWT Cache | ||
lookupJwtCache :: JwtCacheState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult) | ||
lookupJwtCache JwtCacheState{jwtCache, purgeCache} token maxLifetime parseJwt utc = do | ||
checkCache <- C.lookup jwtCache token | ||
authResult <- maybe parseJwt (pure . Right) checkCache | ||
|
||
case (authResult,checkCache) of | ||
-- From comment: | ||
-- https://github.com/PostgREST/postgrest/pull/3801#discussion_r1857987914 | ||
-- | ||
-- We purge expired cache entries on a cache miss | ||
-- The reasoning is that: | ||
-- | ||
-- 1. We expect it to be rare (otherwise there is no point of the cache) | ||
-- 2. It makes sure the cache is not growing (as inserting new entries | ||
-- does garbage collection) | ||
-- 3. Since this is time expiration based cache there is no real risk of | ||
-- starvation - sooner or later we are going to have a cache miss. | ||
|
||
(Right res, Nothing) -> do -- cache miss | ||
|
||
let timeSpec = getTimeSpec res maxLifetime utc | ||
|
||
-- insert new cache entry | ||
C.insert' jwtCache (Just timeSpec) token res | ||
|
||
-- Execute IO action to purge the cache | ||
-- It is assumed this action returns immediately | ||
-- so that request processing is not blocked. | ||
purgeCache | ||
|
||
_ -> pure () | ||
|
||
return authResult | ||
|
||
-- Used to extract JWT exp claim and add to JWT Cache | ||
getTimeSpec :: AuthResult -> Int -> UTCTime -> TimeSpec | ||
getTimeSpec res maxLifetime utc = do | ||
let expireJSON = KM.lookup "exp" (authClaims res) | ||
utcToSecs = floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds | ||
sciToInt = fromMaybe 0 . Sci.toBoundedInteger | ||
case expireJSON of | ||
Just (JSON.Number seconds) -> TimeSpec (sciToInt seconds - utcToSecs utc) 0 | ||
_ -> TimeSpec (fromIntegral maxLifetime :: Int64) 0 | ||
|
||
-- | Empty the cache (done when the config is reloaded) | ||
emptyCache :: JwtCacheState -> IO () | ||
emptyCache JwtCacheState{jwtCache} = C.purge jwtCache | ||
-- | Look a JWT up in the cache, falling back to @parseJwt@ when needed.
--
-- * Cache disabled ('maybeJwtCache' is 'Nothing'): run @parseJwt@.
-- * Cache miss: run @parseJwt@ and store the result only when it is 'Right'.
-- * Cache hit: return the cached result, unless 'isExpClaimExpired' reports
--   it as expired; then the entry is dropped and @parseJwt@ runs again (its
--   result is not re-inserted on this path).
--
-- The LRU is a pure value kept in an 'I.IORef', and even a lookup produces a
-- new LRU value that has to be stored. Every change is therefore applied with
-- @atomicModifyIORef'@ to the value that is current at that moment, instead
-- of writing back a snapshot that was read earlier, so concurrent requests do
-- not overwrite each other's updates. @parseJwt@ runs between the atomic
-- steps, never inside one.
lookupJwtCache :: JwtCacheState -> ByteString -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult)
lookupJwtCache jwtCacheState token parseJwt utc =
  case maybeJwtCache jwtCacheState of
    Nothing       -> parseJwt
    Just cacheRef -> do
      -- C.lookup returns (updated LRU, Maybe value), which is exactly the
      -- shape atomicModifyIORef' expects: the lookup and the write-back of
      -- the updated LRU happen as a single atomic step.
      cached <- I.atomicModifyIORef' cacheRef (C.lookup token)
      case cached of
        Nothing -> do -- CACHE MISS
          authResult <- parseJwt
          case authResult of
            -- only successfully parsed tokens are cached
            Right result -> modifyCache cacheRef (C.insert token result)
            Left _       -> pure ()
          pure authResult
        Just result
          | isExpClaimExpired result utc -> do -- CACHE HIT, but expired
              -- drop the stale entry, then parse the token again
              modifyCache cacheRef (fst . C.delete token)
              parseJwt
          | otherwise -> pure (Right result) -- CACHE HIT
  where
    -- Atomically replace the stored value with the result of applying @f@.
    modifyCache :: I.IORef a -> (a -> a) -> IO ()
    modifyCache ref f = I.atomicModifyIORef' ref (\cache -> (f cache, ()))
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The above code is not thread safe - it is read-modify-write and causes lost updates. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should run the entire There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If lookupJwtCache is atomic then you introduce very high contention as all threads wait for each other to just perform lookup. The problem is even worse because with LRU cache lookups are not read-only anymore, so the shared variable has to be updated upon every access. I don't think there is a good solution to this without a proper mutable concurrent high performance LRU cache implementation. @steve-chavez @wolfgangwalther - what do you think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This sounds very plausible, @mkleczek. I suggest we use real numbers to see the effect. Making the whole There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
With many threads, performance for any shared resource will take a hit no matter what because of mutual exclusion.
So is it even possible to have "high" performance with concurrency when waiting is involved? I think the only way to maintain performance here would be to decrease waiting, which means we would need to use a distributed cache, which will get complex. 🤔 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I don't think any of that is a blocker for a separate lib, all of that can easily be done immediately via overrides etc. Still, building our own library for that is too big of a task to pursue anyway. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @taimoorzaeem @wolfgangwalther It is implementation based on the blogpost: https://jaspervdj.be/posts/2015-02-24-lru-cache.html and uses psqueues. It provides a version of the cache that minimises contention by splitting the cache into stripes: I guess we could configure it to https://hackage.haskell.org/package/base-4.21.0.0/docs/GHC-Conc.html#v:getNumCapabilities number of stripes and not bother users with configuring it. It does not support expiration though. It would be good to add it otherwise the cache is vulnerable to cache thrashing attacks by continuously sending requests with expired JWTs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This implementation is pure as well, so we'd still need to write the new cache on a lookup 😢.
That is not an issue at all. We don't need a built-in expiration support. We can just check the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It is not - see this link: https://hackage.haskell.org/package/lrucaching-0.3.4/docs/Data-LruCache-IO.html#v:stripedCached There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
||
type Expired = Bool | ||
|
||
-- | Decide whether a cached 'AuthResult' has outlived its @exp@ claim.
--
-- Only a numeric @exp@ can mark the entry as expired; a missing or
-- non-numeric claim is treated as still valid.
isExpClaimExpired :: AuthResult -> UTCTime -> Expired
isExpClaimExpired result utc =
  case KM.lookup "exp" (authClaims result) of
    Just (JSON.Number expiredAt) -> secondsLeft expiredAt < 0
    _                            -> False
  where
    -- current POSIX time, rounded down to whole seconds
    nowSecs :: Int
    nowSecs = floor (nominalDiffTimeToSeconds (utcTimeToPOSIXSeconds utc))
    -- an @exp@ that 'Sci.toBoundedInteger' cannot convert counts as 0
    secondsLeft expiredAt = fromMaybe 0 (Sci.toBoundedInteger expiredAt) - nowSecs
Comment on lines
+93
to
+103
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this take the 30 seconds clock skew we allow into account?
This comment was marked as outdated.
Sorry, something went wrong. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure if IORef is needed here - we already have a mutable variable in JwtCacheState? Why do we need another one here?
A cleaner solution would be to export reset :: JwtCacheState -> IO () function from JwtCache module and call it upon config reload.