Skip to content

Commit 83808d3

Browse files
authored
fix: refresh token expiry (#680)
1 parent 66c7000 commit 83808d3

File tree

1 file changed

+175
-22
lines changed

1 file changed

+175
-22
lines changed

crates/rmcp/src/transport/auth.rs

Lines changed: 175 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
use std::{collections::HashMap, sync::Arc, time::Duration};
1+
use std::{
2+
collections::HashMap,
3+
sync::Arc,
4+
time::{Duration, SystemTime, UNIX_EPOCH},
5+
};
26

37
use async_trait::async_trait;
48
use oauth2::{
@@ -61,6 +65,8 @@ pub struct StoredCredentials {
6165
pub token_response: Option<OAuthTokenResponse>,
6266
#[serde(default)]
6367
pub granted_scopes: Vec<String>,
68+
#[serde(default)]
69+
pub token_received_at: Option<u64>,
6470
}
6571

6672
/// Trait for storing and retrieving OAuth2 credentials
@@ -943,34 +949,67 @@ impl AuthorizationManager {
943949
client_id,
944950
token_response: Some(token_result.clone()),
945951
granted_scopes,
952+
token_received_at: Some(Self::now_epoch_secs()),
946953
};
947954
self.credential_store.save(stored).await?;
948955

949956
Ok(token_result)
950957
}
951958

959+
fn now_epoch_secs() -> u64 {
960+
SystemTime::now()
961+
.duration_since(UNIX_EPOCH)
962+
.unwrap_or_default()
963+
.as_secs()
964+
}
965+
966+
/// Proactive refresh buffer: refresh tokens this many seconds before they expire
967+
/// to avoid races between token retrieval and the actual HTTP request.
968+
const REFRESH_BUFFER_SECS: u64 = 30;
969+
952970
/// get access token, if expired, refresh it automatically
953971
pub async fn get_access_token(&self) -> Result<String, AuthError> {
954-
// Load credentials from store
955972
let stored = self.credential_store.load().await?;
956-
let credentials = stored.and_then(|s| s.token_response);
957-
958-
if let Some(creds) = credentials.as_ref() {
959-
// check token expiry if we have a refresh token or an expiry time
960-
if creds.refresh_token().is_some() || creds.expires_in().is_some() {
961-
let expires_in = creds.expires_in().unwrap_or(Duration::from_secs(0));
962-
if expires_in <= Duration::from_secs(0) {
963-
tracing::info!("Access token expired, refreshing.");
964-
965-
let new_creds = self.refresh_token().await?;
966-
tracing::info!("Refreshed access token.");
967-
return Ok(new_creds.access_token().secret().to_string());
968-
}
973+
let Some(stored_creds) = stored else {
974+
return Err(AuthError::AuthorizationRequired);
975+
};
976+
let Some(creds) = stored_creds.token_response.as_ref() else {
977+
return Err(AuthError::AuthorizationRequired);
978+
};
979+
980+
if let (Some(expires_in), Some(received_at)) =
981+
(creds.expires_in(), stored_creds.token_received_at)
982+
{
983+
let elapsed = Self::now_epoch_secs().saturating_sub(received_at);
984+
let remaining = expires_in.as_secs().saturating_sub(elapsed);
985+
986+
if remaining < Self::REFRESH_BUFFER_SECS {
987+
tracing::info!(
988+
remaining_secs = remaining,
989+
"Access token expired or nearly expired, refreshing."
990+
);
991+
return self.try_refresh_or_reauth().await;
969992
}
993+
}
970994

971-
Ok(creds.access_token().secret().to_string())
972-
} else {
973-
Err(AuthError::AuthorizationRequired)
995+
Ok(creds.access_token().secret().to_string())
996+
}
997+
998+
/// Attempt to refresh the token. If refresh fails because there is no
999+
/// refresh token or the server rejected it, return `AuthorizationRequired`
1000+
/// so the caller can re-prompt the user. Infrastructure errors (e.g. store
1001+
/// I/O failures, misconfigured client) are propagated as-is.
1002+
async fn try_refresh_or_reauth(&self) -> Result<String, AuthError> {
1003+
match self.refresh_token().await {
1004+
Ok(new_creds) => {
1005+
tracing::info!("Refreshed access token.");
1006+
Ok(new_creds.access_token().secret().to_string())
1007+
}
1008+
Err(AuthError::AuthorizationRequired | AuthError::TokenRefreshFailed(_)) => {
1009+
tracing::warn!("Token refresh not possible, re-authorization required.");
1010+
Err(AuthError::AuthorizationRequired)
1011+
}
1012+
Err(e) => Err(e),
9741013
}
9751014
}
9761015

@@ -999,10 +1038,10 @@ impl AuthorizationManager {
9991038
.await
10001039
.map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?;
10011040

1002-
let granted_scopes: Vec<String> = token_result
1003-
.scopes()
1004-
.map(|scopes| scopes.iter().map(|s| s.to_string()).collect())
1005-
.unwrap_or_else(|| self.current_scopes.blocking_read().clone());
1041+
let granted_scopes: Vec<String> = match token_result.scopes() {
1042+
Some(scopes) => scopes.iter().map(|s| s.to_string()).collect(),
1043+
None => self.current_scopes.read().await.clone(),
1044+
};
10061045

10071046
*self.current_scopes.write().await = granted_scopes.clone();
10081047

@@ -1011,6 +1050,7 @@ impl AuthorizationManager {
10111050
client_id,
10121051
token_response: Some(token_result.clone()),
10131052
granted_scopes,
1053+
token_received_at: Some(Self::now_epoch_secs()),
10141054
};
10151055
self.credential_store.save(stored).await?;
10161056

@@ -1618,6 +1658,7 @@ impl OAuthState {
16181658
client_id: client_id.to_string(),
16191659
token_response: Some(credentials),
16201660
granted_scopes,
1661+
token_received_at: Some(AuthorizationManager::now_epoch_secs()),
16211662
};
16221663
manager.credential_store.save(stored).await?;
16231664

@@ -2636,4 +2677,116 @@ mod tests {
26362677
*manager.scope_upgrade_attempts.write().await = 1;
26372678
assert!(manager.can_attempt_scope_upgrade().await);
26382679
}
2680+
2681+
// -- get_access_token --
2682+
2683+
fn make_token_response(access_token: &str, expires_in_secs: Option<u64>) -> OAuthTokenResponse {
2684+
use oauth2::{AccessToken, EmptyExtraTokenFields, basic::BasicTokenType};
2685+
let mut resp = OAuthTokenResponse::new(
2686+
AccessToken::new(access_token.to_string()),
2687+
BasicTokenType::Bearer,
2688+
EmptyExtraTokenFields {},
2689+
);
2690+
if let Some(secs) = expires_in_secs {
2691+
resp.set_expires_in(Some(&std::time::Duration::from_secs(secs)));
2692+
}
2693+
resp
2694+
}
2695+
2696+
use super::{OAuthTokenResponse, StoredCredentials};
2697+
2698+
#[tokio::test]
2699+
async fn get_access_token_returns_error_when_no_credentials() {
2700+
let manager = AuthorizationManager::new("http://localhost").await.unwrap();
2701+
let err = manager.get_access_token().await.unwrap_err();
2702+
assert!(matches!(err, AuthError::AuthorizationRequired));
2703+
}
2704+
2705+
#[tokio::test]
2706+
async fn get_access_token_returns_token_when_not_expired() {
2707+
let manager = AuthorizationManager::new("http://localhost").await.unwrap();
2708+
let stored = StoredCredentials {
2709+
client_id: "test".to_string(),
2710+
token_response: Some(make_token_response("my-access-token", Some(3600))),
2711+
granted_scopes: vec![],
2712+
token_received_at: Some(AuthorizationManager::now_epoch_secs()),
2713+
};
2714+
manager.credential_store.save(stored).await.unwrap();
2715+
2716+
let token = manager.get_access_token().await.unwrap();
2717+
assert_eq!(token, "my-access-token");
2718+
}
2719+
2720+
#[tokio::test]
2721+
async fn get_access_token_requires_reauth_when_expired_and_no_refresh_token() {
2722+
let mut manager = manager_with_metadata(None).await;
2723+
manager.configure_client(test_client_config()).unwrap();
2724+
2725+
let stored = StoredCredentials {
2726+
client_id: "my-client".to_string(),
2727+
token_response: Some(make_token_response("stale-token", Some(3600))),
2728+
granted_scopes: vec![],
2729+
token_received_at: Some(AuthorizationManager::now_epoch_secs() - 7200),
2730+
};
2731+
manager.credential_store.save(stored).await.unwrap();
2732+
2733+
let err = manager.get_access_token().await.unwrap_err();
2734+
assert!(
2735+
matches!(err, AuthError::AuthorizationRequired),
2736+
"expected AuthorizationRequired when token is expired and refresh is impossible, got: {err:?}"
2737+
);
2738+
}
2739+
2740+
#[tokio::test]
2741+
async fn get_access_token_returns_token_without_expiry_info() {
2742+
let manager = AuthorizationManager::new("http://localhost").await.unwrap();
2743+
let stored = StoredCredentials {
2744+
client_id: "test".to_string(),
2745+
token_response: Some(make_token_response("no-expiry-token", None)),
2746+
granted_scopes: vec![],
2747+
token_received_at: None,
2748+
};
2749+
manager.credential_store.save(stored).await.unwrap();
2750+
2751+
let token = manager.get_access_token().await.unwrap();
2752+
assert_eq!(token, "no-expiry-token");
2753+
}
2754+
2755+
#[tokio::test]
2756+
async fn get_access_token_requires_reauth_when_within_refresh_buffer() {
2757+
let mut manager = manager_with_metadata(None).await;
2758+
manager.configure_client(test_client_config()).unwrap();
2759+
2760+
let stored = StoredCredentials {
2761+
client_id: "my-client".to_string(),
2762+
token_response: Some(make_token_response("almost-expired", Some(3600))),
2763+
granted_scopes: vec![],
2764+
token_received_at: Some(AuthorizationManager::now_epoch_secs() - 3590),
2765+
};
2766+
manager.credential_store.save(stored).await.unwrap();
2767+
2768+
let err = manager.get_access_token().await.unwrap_err();
2769+
assert!(
2770+
matches!(err, AuthError::AuthorizationRequired),
2771+
"expected AuthorizationRequired when token is within refresh buffer, got: {err:?}"
2772+
);
2773+
}
2774+
2775+
#[tokio::test]
2776+
async fn get_access_token_propagates_internal_errors() {
2777+
let manager = AuthorizationManager::new("http://localhost").await.unwrap();
2778+
let stored = StoredCredentials {
2779+
client_id: "test".to_string(),
2780+
token_response: Some(make_token_response("stale-token", Some(3600))),
2781+
granted_scopes: vec![],
2782+
token_received_at: Some(AuthorizationManager::now_epoch_secs() - 7200),
2783+
};
2784+
manager.credential_store.save(stored).await.unwrap();
2785+
2786+
let err = manager.get_access_token().await.unwrap_err();
2787+
assert!(
2788+
matches!(err, AuthError::InternalError(_)),
2789+
"expected InternalError when OAuth client is not configured, got: {err:?}"
2790+
);
2791+
}
26392792
}

0 commit comments

Comments
 (0)