Skip to content

Commit 33e8677

Browse files
authored
add TokenFetcher struct for automatic token refresh (#13)
* add TokenFetcher struct for automatic token refresh * add tests for TokenFetcher * use arc-swap so we don't have to make fetcher mut * fix comment * add TokenFetcher usage to README
1 parent 30c8b4c commit 33e8677

File tree

6 files changed

+309
-7
lines changed

6 files changed

+309
-7
lines changed

Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,18 @@ name = "goauth"
1515
path = "src/lib.rs"
1616

1717
[dependencies]
18+
arc-swap = "^0.4.7"
1819
serde = "^1.0"
1920
serde_derive = "^1.0"
2021
serde_json = "^1.0"
2122
time = "0.2.14"
2223
log = "0.4.8"
2324
smpl_jwt = "^0.5"
24-
reqwest = {version = "0.10.4", features = ["blocking", "json"]}
25+
reqwest = { version = "0.10.4", features = ["blocking", "json"] }
2526
futures = "0.3.4"
2627
simpl = "0.1.0"
27-
tokio = "^0.2"
28+
tokio = { version = "^0.2", features = ["macros"] }
2829

2930
[dev-dependencies]
3031
doc-comment = "0.3.3"
32+
mockito = "^0.27.0"

README.md

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ use goauth::auth::JwtClaims;
1919
use goauth::scopes::Scope;
2020
use goauth::{get_token, get_token_blocking, GoErr};
2121
use goauth::credentials::Credentials;
22+
use goauth::fetcher::TokenFetcher;
2223
use smpl_jwt::{RSAKey, Jwt};
24+
use time::Duration;
2325
2426
fn main() -> Result<(), GoErr>{
2527
let token_url = "https://www.googleapis.com/oauth2/v4/token";
@@ -45,7 +47,39 @@ fn main() -> Result<(), GoErr>{
4547
4648
// Token provides `access_token` method that outputs a value that should be placed in the Authorization header
4749
50+
// Or use the TokenFetcher abstraction which will automatically refresh tokens
51+
let fetcher = TokenFetcher::new(jwt, credentials, Duration::new(1, 0));
52+
53+
let token = async {
54+
match fetcher.fetch_token().await {
55+
Ok(token) => token,
56+
Err(e) => panic!(e)
57+
}
58+
};
59+
60+
// Now a couple seconds later we want the token again - the initial token is cached so it will re-use
61+
// the same token, saving a network trip to fetch another token
62+
let new_token = async {
63+
match fetcher.fetch_token().await {
64+
Ok(token) => token,
65+
Err(e) => panic!(e)
66+
}
67+
};
68+
69+
assert_eq!(token, new_token);
70+
71+
// Now say the token has expired or is close to expiring ("close" defined by the configurable
72+
// `refresh_buffer` parameter) at this point "later in the program." The next call to
73+
// `fetch_token` will notice this and automatically fetch a new token, cache it, and return it.
74+
let new_token = async {
75+
match fetcher.fetch_token().await {
76+
Ok(token) => token,
77+
Err(e) => panic!(e)
78+
}
79+
};
80+
81+
assert_ne!(token, new_token);
82+
4883
Ok(())
49-
5084
}
5185
```

src/auth.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ impl Error for TokenErr {
6767
}
6868
}
6969

70-
#[derive(Serialize, Deserialize, Debug, Clone)]
70+
#[derive(Eq, PartialEq, Serialize, Deserialize, Debug, Clone)]
7171
pub struct Token {
7272
access_token: String,
7373
token_type: String,

src/credentials.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ pub struct Credentials {
1515
client_email: String,
1616
client_id: String,
1717
auth_uri: String,
18-
token_uri: String,
18+
// pub(crate) to this can be overriden in tests
19+
pub(crate) token_uri: String,
1920
auth_provider_x509_cert_url: String,
2021
client_x509_cert_url: String
2122
}

src/fetcher.rs

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
//! Defines a `TokenFetcher` struct that will automatically refresh tokens
2+
//! at some configured time prior to the token's expiration.
3+
4+
use crate::auth::{JwtClaims, Token};
5+
use crate::credentials::Credentials;
6+
use crate::{get_token_with_client, Result};
7+
8+
use arc_swap::ArcSwapOption;
9+
use reqwest::Client;
10+
use smpl_jwt::Jwt;
11+
use std::sync::Arc;
12+
use time::{Duration, OffsetDateTime};
13+
14+
/// A `TokenFetcher` stores a `Token` on first fetch and will continue returning
15+
/// that token until it needs to be refreshed, as determined by the token's
16+
/// `expires_in` field and the configured `refresh_buffer`.
17+
///
18+
/// Specifically on each token fetch request, it will check the current time
19+
/// against the expected time the currently stored token will expire. If it
20+
/// is within the `refresh_buffer` window, it will fetch a new token, store
21+
/// that (along with the new expired time), and return the new token.
22+
pub struct TokenFetcher {
23+
client: Client,
24+
jwt: Jwt<JwtClaims>,
25+
credentials: Credentials,
26+
token_state: ArcSwapOption<TokenState>,
27+
refresh_buffer: Duration,
28+
}
29+
30+
struct TokenState {
31+
/// The currently stored token
32+
token: Token,
33+
/// The lower bound of the time at which the token needs to be refreshed
34+
refresh_at: OffsetDateTime,
35+
}
36+
37+
impl TokenFetcher {
38+
pub fn new(
39+
jwt: Jwt<JwtClaims>,
40+
credentials: Credentials,
41+
refresh_buffer: Duration
42+
) -> TokenFetcher {
43+
TokenFetcher::with_client(Client::new(), jwt, credentials, refresh_buffer)
44+
}
45+
46+
pub fn with_client(
47+
client: Client,
48+
jwt: Jwt<JwtClaims>,
49+
credentials: Credentials,
50+
refresh_buffer: Duration
51+
) -> TokenFetcher {
52+
let token_state = ArcSwapOption::from(None);
53+
54+
TokenFetcher {
55+
client,
56+
jwt,
57+
credentials,
58+
token_state,
59+
refresh_buffer,
60+
}
61+
}
62+
63+
/// Returns a token if the token is still considered "valid" per the
64+
/// currently stored token's `expires_in` field and the configured
65+
/// `refresh_buffer`. If it is, return the stored token. If not,
66+
/// fetch a new token, store it, and return the new token.
67+
pub async fn fetch_token(&self) -> Result<Token> {
68+
let token_state = self.token_state.load();
69+
70+
match &*token_state {
71+
// First time calling `fetch_token` since initialization, so fetch
72+
// a token.
73+
None => self.get_token().await,
74+
Some(token_state) => {
75+
let now = OffsetDateTime::now_utc();
76+
77+
if now >= token_state.refresh_at {
78+
// We have an existing token but it is time to refresh it
79+
self.get_token().await
80+
} else {
81+
// We have an existing, valid token, so return immediately
82+
Ok(token_state.token.clone())
83+
}
84+
},
85+
}
86+
}
87+
88+
/// Refresh the token
89+
async fn get_token(&self) -> Result<Token> {
90+
let now = OffsetDateTime::now_utc();
91+
92+
let token = get_token_with_client(&self.client, &self.jwt, &self.credentials).await?;
93+
let expires_in = Duration::new(token.expires_in().into(), 0);
94+
95+
assert!(expires_in >= self.refresh_buffer, "Received a token whose expires_in is less than the configured refresh buffer!");
96+
97+
let refresh_at = now + (expires_in - self.refresh_buffer);
98+
let token_state = TokenState {
99+
token: token.clone(),
100+
refresh_at,
101+
};
102+
103+
self.token_state.swap(Some(Arc::new(token_state)));
104+
Ok(token)
105+
}
106+
}
107+
108+
#[cfg(test)]
109+
mod tests {
110+
use crate::auth::{JwtClaims, Token};
111+
use crate::credentials::Credentials;
112+
use crate::fetcher::TokenFetcher;
113+
use crate::scopes::Scope;
114+
use mockito::{self, mock};
115+
use smpl_jwt::Jwt;
116+
use std::thread;
117+
use std::time::{Duration as StdDuration};
118+
use time::Duration;
119+
120+
fn get_mocks() -> (Jwt<JwtClaims>, Credentials) {
121+
let token_url = mockito::server_url();
122+
let iss = "some_iss";
123+
124+
let mut credentials = Credentials::from_file("dummy_credentials_file_for_tests.json").unwrap();
125+
credentials.token_uri = token_url.clone();
126+
127+
let claims = JwtClaims::new(
128+
String::from(iss),
129+
&Scope::DevStorageReadWrite,
130+
String::from(token_url.clone()),
131+
None,
132+
None,
133+
);
134+
135+
let jwt = Jwt::new(claims, credentials.rsa_key().unwrap(), None);
136+
137+
(jwt, credentials)
138+
}
139+
140+
fn token_json(access_token: &str, token_type: &str, expires_in: u32) -> (Token, String) {
141+
let json = serde_json::json!({
142+
"access_token": access_token.to_string(),
143+
"token_type": token_type.to_string(),
144+
"expires_in": expires_in
145+
});
146+
147+
let token = serde_json::from_value(json.clone()).unwrap();
148+
149+
(token, json.to_string())
150+
}
151+
152+
#[tokio::test]
153+
async fn basic_token_fetch() {
154+
let (jwt, credentials) = get_mocks();
155+
156+
let refresh_buffer = Duration::new(0, 0);
157+
let fetcher = TokenFetcher::new(jwt, credentials, refresh_buffer);
158+
159+
let (expected_token, json) = token_json("token", "Bearer", 1);
160+
161+
let _mock = mock("POST", "/")
162+
.with_status(200)
163+
.with_body(json)
164+
.create();
165+
166+
let token = fetcher.fetch_token().await.unwrap();
167+
assert_eq!(expected_token, token);
168+
}
169+
170+
#[tokio::test]
171+
async fn basic_token_refresh() {
172+
let (jwt, credentials) = get_mocks();
173+
174+
let refresh_buffer = Duration::new(0, 0);
175+
let fetcher = TokenFetcher::new(jwt, credentials, refresh_buffer);
176+
177+
let expires_in = 1;
178+
let (_expected_token, json) = token_json("token", "Bearer", expires_in);
179+
180+
let mock = mock("POST", "/")
181+
.with_status(200)
182+
.with_body(json)
183+
.expect(2) // we expect to be hit twice due to refresh
184+
.create();
185+
186+
// this should work
187+
fetcher.fetch_token().await.unwrap();
188+
189+
// sleep for `expires_in`
190+
thread::sleep(StdDuration::from_secs(expires_in.into()));
191+
192+
// this should refresh
193+
fetcher.fetch_token().await.unwrap();
194+
195+
mock.assert();
196+
}
197+
198+
#[tokio::test]
199+
async fn token_refresh_with_buffer() {
200+
let (jwt, credentials) = get_mocks();
201+
202+
let refresh_buffer = 4;
203+
let fetcher = TokenFetcher::new(jwt, credentials, Duration::new(refresh_buffer, 0));
204+
205+
let expires_in = 5;
206+
let (_expected_token, json) = token_json("token", "Bearer", expires_in);
207+
208+
let mock = mock("POST", "/")
209+
.with_status(200)
210+
.with_body(json)
211+
.expect(2) // we expect to be hit twice due to refresh
212+
.create();
213+
214+
// this should work
215+
fetcher.fetch_token().await.unwrap();
216+
217+
// sleep for `expires_in`
218+
let sleep_for = expires_in - (refresh_buffer as u32);
219+
thread::sleep(StdDuration::from_secs(sleep_for.into()));
220+
221+
// this should refresh
222+
fetcher.fetch_token().await.unwrap();
223+
224+
mock.assert();
225+
}
226+
227+
#[tokio::test]
228+
async fn doesnt_token_refresh_unnecessarily() {
229+
let (jwt, credentials) = get_mocks();
230+
231+
let refresh_buffer = Duration::new(0, 0);
232+
let fetcher = TokenFetcher::new(jwt, credentials, refresh_buffer);
233+
234+
let expires_in = 1;
235+
let (_expected_token, json) = token_json("token", "Bearer", expires_in);
236+
237+
let mock = mock("POST", "/")
238+
.with_status(200)
239+
.with_body(json)
240+
.expect(1) // we expect to be hit only once
241+
.create();
242+
243+
// this should work
244+
fetcher.fetch_token().await.unwrap();
245+
246+
// fetch again, should not refresh
247+
fetcher.fetch_token().await.unwrap();
248+
249+
mock.assert();
250+
}
251+
}

src/lib.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ extern crate doc_comment;
77

88
pub mod auth;
99
pub mod credentials;
10+
pub mod fetcher;
1011
pub mod scopes;
1112

1213
use auth::{JwtClaims, Token};
@@ -182,16 +183,29 @@ pub fn get_token_blocking(
182183
pub async fn get_token(
183184
jwt: &Jwt<JwtClaims>,
184185
credentials: &Credentials,
186+
) -> Result<Token> {
187+
let client = Client::new();
188+
189+
get_token_with_client(
190+
&client,
191+
jwt,
192+
credentials
193+
).await
194+
}
195+
196+
pub async fn get_token_with_client(
197+
client: &Client,
198+
jwt: &Jwt<JwtClaims>,
199+
credentials: &Credentials,
185200
) -> Result<Token> {
186201
let final_jwt = jwt.finalize()?;
187202
let request_body = form_body(&final_jwt);
188203

189-
let client = Client::new();
190204
let response = client
191205
.post(&credentials.token_uri())
192206
.form(&request_body)
193207
.send().await?;
194-
208+
195209
let token = response.json::<Token>().await?;
196210
Ok(token)
197211
}

0 commit comments

Comments
 (0)