1- use std:: { collections:: HashMap , sync:: Arc , time:: Duration } ;
1+ use std:: {
2+ collections:: HashMap ,
3+ sync:: Arc ,
4+ time:: { Duration , SystemTime , UNIX_EPOCH } ,
5+ } ;
26
37use async_trait:: async_trait;
48use oauth2:: {
@@ -61,6 +65,8 @@ pub struct StoredCredentials {
6165 pub token_response : Option < OAuthTokenResponse > ,
6266 #[ serde( default ) ]
6367 pub granted_scopes : Vec < String > ,
68+ #[ serde( default ) ]
69+ pub token_received_at : Option < u64 > ,
6470}
6571
6672/// Trait for storing and retrieving OAuth2 credentials
@@ -943,34 +949,67 @@ impl AuthorizationManager {
943949 client_id,
944950 token_response : Some ( token_result. clone ( ) ) ,
945951 granted_scopes,
952+ token_received_at : Some ( Self :: now_epoch_secs ( ) ) ,
946953 } ;
947954 self . credential_store . save ( stored) . await ?;
948955
949956 Ok ( token_result)
950957 }
951958
959+ fn now_epoch_secs ( ) -> u64 {
960+ SystemTime :: now ( )
961+ . duration_since ( UNIX_EPOCH )
962+ . unwrap_or_default ( )
963+ . as_secs ( )
964+ }
965+
966+ /// Proactive refresh buffer: refresh tokens this many seconds before they expire
967+ /// to avoid races between token retrieval and the actual HTTP request.
968+ const REFRESH_BUFFER_SECS : u64 = 30 ;
969+
952970 /// get access token, if expired, refresh it automatically
953971 pub async fn get_access_token ( & self ) -> Result < String , AuthError > {
954- // Load credentials from store
955972 let stored = self . credential_store . load ( ) . await ?;
956- let credentials = stored. and_then ( |s| s. token_response ) ;
957-
958- if let Some ( creds) = credentials. as_ref ( ) {
959- // check token expiry if we have a refresh token or an expiry time
960- if creds. refresh_token ( ) . is_some ( ) || creds. expires_in ( ) . is_some ( ) {
961- let expires_in = creds. expires_in ( ) . unwrap_or ( Duration :: from_secs ( 0 ) ) ;
962- if expires_in <= Duration :: from_secs ( 0 ) {
963- tracing:: info!( "Access token expired, refreshing." ) ;
964-
965- let new_creds = self . refresh_token ( ) . await ?;
966- tracing:: info!( "Refreshed access token." ) ;
967- return Ok ( new_creds. access_token ( ) . secret ( ) . to_string ( ) ) ;
968- }
973+ let Some ( stored_creds) = stored else {
974+ return Err ( AuthError :: AuthorizationRequired ) ;
975+ } ;
976+ let Some ( creds) = stored_creds. token_response . as_ref ( ) else {
977+ return Err ( AuthError :: AuthorizationRequired ) ;
978+ } ;
979+
980+ if let ( Some ( expires_in) , Some ( received_at) ) =
981+ ( creds. expires_in ( ) , stored_creds. token_received_at )
982+ {
983+ let elapsed = Self :: now_epoch_secs ( ) . saturating_sub ( received_at) ;
984+ let remaining = expires_in. as_secs ( ) . saturating_sub ( elapsed) ;
985+
986+ if remaining < Self :: REFRESH_BUFFER_SECS {
987+ tracing:: info!(
988+ remaining_secs = remaining,
989+ "Access token expired or nearly expired, refreshing."
990+ ) ;
991+ return self . try_refresh_or_reauth ( ) . await ;
969992 }
993+ }
970994
971- Ok ( creds. access_token ( ) . secret ( ) . to_string ( ) )
972- } else {
973- Err ( AuthError :: AuthorizationRequired )
995+ Ok ( creds. access_token ( ) . secret ( ) . to_string ( ) )
996+ }
997+
998+ /// Attempt to refresh the token. If refresh fails because there is no
999+ /// refresh token or the server rejected it, return `AuthorizationRequired`
1000+ /// so the caller can re-prompt the user. Infrastructure errors (e.g. store
1001+ /// I/O failures, misconfigured client) are propagated as-is.
1002+ async fn try_refresh_or_reauth ( & self ) -> Result < String , AuthError > {
1003+ match self . refresh_token ( ) . await {
1004+ Ok ( new_creds) => {
1005+ tracing:: info!( "Refreshed access token." ) ;
1006+ Ok ( new_creds. access_token ( ) . secret ( ) . to_string ( ) )
1007+ }
1008+ Err ( AuthError :: AuthorizationRequired | AuthError :: TokenRefreshFailed ( _) ) => {
1009+ tracing:: warn!( "Token refresh not possible, re-authorization required." ) ;
1010+ Err ( AuthError :: AuthorizationRequired )
1011+ }
1012+ Err ( e) => Err ( e) ,
9741013 }
9751014 }
9761015
@@ -999,10 +1038,10 @@ impl AuthorizationManager {
9991038 . await
10001039 . map_err ( |e| AuthError :: TokenRefreshFailed ( e. to_string ( ) ) ) ?;
10011040
1002- let granted_scopes: Vec < String > = token_result
1003- . scopes ( )
1004- . map ( |scopes| scopes . iter ( ) . map ( |s| s . to_string ( ) ) . collect ( ) )
1005- . unwrap_or_else ( || self . current_scopes . blocking_read ( ) . clone ( ) ) ;
1041+ let granted_scopes: Vec < String > = match token_result. scopes ( ) {
1042+ Some ( scopes) => scopes . iter ( ) . map ( |s| s . to_string ( ) ) . collect ( ) ,
1043+ None => self . current_scopes . read ( ) . await . clone ( ) ,
1044+ } ;
10061045
10071046 * self . current_scopes . write ( ) . await = granted_scopes. clone ( ) ;
10081047
@@ -1011,6 +1050,7 @@ impl AuthorizationManager {
10111050 client_id,
10121051 token_response : Some ( token_result. clone ( ) ) ,
10131052 granted_scopes,
1053+ token_received_at : Some ( Self :: now_epoch_secs ( ) ) ,
10141054 } ;
10151055 self . credential_store . save ( stored) . await ?;
10161056
@@ -1618,6 +1658,7 @@ impl OAuthState {
16181658 client_id : client_id. to_string ( ) ,
16191659 token_response : Some ( credentials) ,
16201660 granted_scopes,
1661+ token_received_at : Some ( AuthorizationManager :: now_epoch_secs ( ) ) ,
16211662 } ;
16221663 manager. credential_store . save ( stored) . await ?;
16231664
@@ -2636,4 +2677,116 @@ mod tests {
26362677 * manager. scope_upgrade_attempts . write ( ) . await = 1 ;
26372678 assert ! ( manager. can_attempt_scope_upgrade( ) . await ) ;
26382679 }
2680+
2681+ // -- get_access_token --
2682+
2683+ fn make_token_response ( access_token : & str , expires_in_secs : Option < u64 > ) -> OAuthTokenResponse {
2684+ use oauth2:: { AccessToken , EmptyExtraTokenFields , basic:: BasicTokenType } ;
2685+ let mut resp = OAuthTokenResponse :: new (
2686+ AccessToken :: new ( access_token. to_string ( ) ) ,
2687+ BasicTokenType :: Bearer ,
2688+ EmptyExtraTokenFields { } ,
2689+ ) ;
2690+ if let Some ( secs) = expires_in_secs {
2691+ resp. set_expires_in ( Some ( & std:: time:: Duration :: from_secs ( secs) ) ) ;
2692+ }
2693+ resp
2694+ }
2695+
2696+ use super :: { OAuthTokenResponse , StoredCredentials } ;
2697+
2698+ #[ tokio:: test]
2699+ async fn get_access_token_returns_error_when_no_credentials ( ) {
2700+ let manager = AuthorizationManager :: new ( "http://localhost" ) . await . unwrap ( ) ;
2701+ let err = manager. get_access_token ( ) . await . unwrap_err ( ) ;
2702+ assert ! ( matches!( err, AuthError :: AuthorizationRequired ) ) ;
2703+ }
2704+
2705+ #[ tokio:: test]
2706+ async fn get_access_token_returns_token_when_not_expired ( ) {
2707+ let manager = AuthorizationManager :: new ( "http://localhost" ) . await . unwrap ( ) ;
2708+ let stored = StoredCredentials {
2709+ client_id : "test" . to_string ( ) ,
2710+ token_response : Some ( make_token_response ( "my-access-token" , Some ( 3600 ) ) ) ,
2711+ granted_scopes : vec ! [ ] ,
2712+ token_received_at : Some ( AuthorizationManager :: now_epoch_secs ( ) ) ,
2713+ } ;
2714+ manager. credential_store . save ( stored) . await . unwrap ( ) ;
2715+
2716+ let token = manager. get_access_token ( ) . await . unwrap ( ) ;
2717+ assert_eq ! ( token, "my-access-token" ) ;
2718+ }
2719+
2720+ #[ tokio:: test]
2721+ async fn get_access_token_requires_reauth_when_expired_and_no_refresh_token ( ) {
2722+ let mut manager = manager_with_metadata ( None ) . await ;
2723+ manager. configure_client ( test_client_config ( ) ) . unwrap ( ) ;
2724+
2725+ let stored = StoredCredentials {
2726+ client_id : "my-client" . to_string ( ) ,
2727+ token_response : Some ( make_token_response ( "stale-token" , Some ( 3600 ) ) ) ,
2728+ granted_scopes : vec ! [ ] ,
2729+ token_received_at : Some ( AuthorizationManager :: now_epoch_secs ( ) - 7200 ) ,
2730+ } ;
2731+ manager. credential_store . save ( stored) . await . unwrap ( ) ;
2732+
2733+ let err = manager. get_access_token ( ) . await . unwrap_err ( ) ;
2734+ assert ! (
2735+ matches!( err, AuthError :: AuthorizationRequired ) ,
2736+ "expected AuthorizationRequired when token is expired and refresh is impossible, got: {err:?}"
2737+ ) ;
2738+ }
2739+
2740+ #[ tokio:: test]
2741+ async fn get_access_token_returns_token_without_expiry_info ( ) {
2742+ let manager = AuthorizationManager :: new ( "http://localhost" ) . await . unwrap ( ) ;
2743+ let stored = StoredCredentials {
2744+ client_id : "test" . to_string ( ) ,
2745+ token_response : Some ( make_token_response ( "no-expiry-token" , None ) ) ,
2746+ granted_scopes : vec ! [ ] ,
2747+ token_received_at : None ,
2748+ } ;
2749+ manager. credential_store . save ( stored) . await . unwrap ( ) ;
2750+
2751+ let token = manager. get_access_token ( ) . await . unwrap ( ) ;
2752+ assert_eq ! ( token, "no-expiry-token" ) ;
2753+ }
2754+
2755+ #[ tokio:: test]
2756+ async fn get_access_token_requires_reauth_when_within_refresh_buffer ( ) {
2757+ let mut manager = manager_with_metadata ( None ) . await ;
2758+ manager. configure_client ( test_client_config ( ) ) . unwrap ( ) ;
2759+
2760+ let stored = StoredCredentials {
2761+ client_id : "my-client" . to_string ( ) ,
2762+ token_response : Some ( make_token_response ( "almost-expired" , Some ( 3600 ) ) ) ,
2763+ granted_scopes : vec ! [ ] ,
2764+ token_received_at : Some ( AuthorizationManager :: now_epoch_secs ( ) - 3590 ) ,
2765+ } ;
2766+ manager. credential_store . save ( stored) . await . unwrap ( ) ;
2767+
2768+ let err = manager. get_access_token ( ) . await . unwrap_err ( ) ;
2769+ assert ! (
2770+ matches!( err, AuthError :: AuthorizationRequired ) ,
2771+ "expected AuthorizationRequired when token is within refresh buffer, got: {err:?}"
2772+ ) ;
2773+ }
2774+
2775+ #[ tokio:: test]
2776+ async fn get_access_token_propagates_internal_errors ( ) {
2777+ let manager = AuthorizationManager :: new ( "http://localhost" ) . await . unwrap ( ) ;
2778+ let stored = StoredCredentials {
2779+ client_id : "test" . to_string ( ) ,
2780+ token_response : Some ( make_token_response ( "stale-token" , Some ( 3600 ) ) ) ,
2781+ granted_scopes : vec ! [ ] ,
2782+ token_received_at : Some ( AuthorizationManager :: now_epoch_secs ( ) - 7200 ) ,
2783+ } ;
2784+ manager. credential_store . save ( stored) . await . unwrap ( ) ;
2785+
2786+ let err = manager. get_access_token ( ) . await . unwrap_err ( ) ;
2787+ assert ! (
2788+ matches!( err, AuthError :: InternalError ( _) ) ,
2789+ "expected InternalError when OAuth client is not configured, got: {err:?}"
2790+ ) ;
2791+ }
26392792}
0 commit comments