Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 150 additions & 9 deletions sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ use crate::{pipeline::signature_target::SignatureTarget, resource_context::Resou
use crate::utils::url_encode;

const AZURE_VERSION: &str = "2020-07-15";
const ENV_SCOPE_OVERRIDE: &str = "AZURE_COSMOS_AAD_SCOPE_OVERRIDE";
const PUBLIC_COSMOS_SCOPE: &str = "https://cosmos.azure.com/.default";

#[derive(Debug, Clone)]
enum Credential {
Expand Down Expand Up @@ -111,19 +113,23 @@ impl Policy for AuthorizationPolicy {
async fn generate_authorization(
auth_token: &Credential,
url: &Url,

// Unused unless feature="key_auth", but I don't want to mess with excluding it since it makes call sites more complicated
#[allow(unused_variables)] signature_target: SignatureTarget<'_>,
) -> azure_core::Result<String> {
let token = match auth_token {
Credential::Token(token_credential) => {
let token = token_credential
.get_token(&[&scope_from_url(url)], None)
.await?
.token
.secret()
.to_string();
format!("type=aad&ver=1.0&sig={token}")
// Env override: use ONLY this scope, no fallback.
if let Ok(s) = std::env::var(ENV_SCOPE_OVERRIDE) {
let override_scope = s.trim();
if !override_scope.is_empty() {
let at = token_credential.get_token(&[override_scope], None).await?;
format!("type=aad&ver=1.0&sig={}", at.token.secret())
} else {
// 2) No override => host scope, with single fallback on AADSTS500011
acquire_with_account_scope_then_maybe_fallback(token_credential.as_ref(), url).await?
}
} else {
acquire_with_account_scope_then_maybe_fallback(token_credential.as_ref(), url).await?
}
}

#[cfg(feature = "key_auth")]
Expand All @@ -133,6 +139,25 @@ async fn generate_authorization(
Ok(url_encode(token))
}

async fn acquire_with_account_scope_then_maybe_fallback(
cred: &dyn TokenCredential,
url: &Url,
) -> azure_core::Result<String> {
let account_scope = scope_from_url(url);
match cred.get_token(&[&account_scope], None).await {
Ok(at) => Ok(format!("type=aad&ver=1.0&sig={}", at.token.secret())),
Err(e) => {
let msg = e.to_string();
if msg.contains("AADSTS500011") {
let at = cred.get_token(&[PUBLIC_COSMOS_SCOPE], None).await?;
Ok(format!("type=aad&ver=1.0&sig={}", at.token.secret()))
} else {
Err(e)
}
}
}
}

/// This function generates the scope string from the passed url. The scope string is used to
/// request the AAD token.
fn scope_from_url(url: &Url) -> String {
Expand Down Expand Up @@ -277,9 +302,125 @@ mod tests {
assert_eq!(ret, expected);
}

#[tokio::test]
async fn generate_authorization_with_env_override_only() {
use crate::pipeline::authorization_policy::ENV_SCOPE_OVERRIDE;

let _guard = EnvGuard::set(ENV_SCOPE_OVERRIDE, "https://custom.example/.default");

let time_nonce =
azure_core::time::parse_rfc3339("1900-01-01T01:00:00.000000000+00:00").unwrap();
let date_string = azure_core::time::to_rfc7231(&time_nonce).to_lowercase();

let cred = std::sync::Arc::new(TestTokenCredential("test_token".to_string()));
let auth_token = Credential::Token(cred);

let url = url::Url::parse("https://acct.documents.azure.com/dbs/x").unwrap();

let ret = generate_authorization(
&auth_token,
&url,
SignatureTarget::new(
azure_core::http::Method::Get,
&ResourceLink::root(ResourceType::Databases).item("x"),
&date_string,
),
)
.await
.unwrap();

let expected = url_encode(b"type=aad&ver=1.0&sig=test_token+https://custom.example/.default");
assert_eq!(ret, expected);
}

#[derive(Debug)]
struct FallbackTokenCredential;
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl TokenCredential for FallbackTokenCredential {
async fn get_token(
&self,
scopes: &[&str],
_opts: Option<TokenRequestOptions<'_>>,
) -> azure_core::Result<AccessToken> {
let requested = scopes.join(",");
// Simulate failure for the account/host scope; success for public scope
if requested.starts_with("https://acct.documents.azure.com/.default") {
return Err(azure_core::error::Error::with_message(
azure_core::error::ErrorKind::Other,
"AADSTS500011: The resource principal named ... was not found",
));
}
// Success for the public fallback scope
Ok(AccessToken::new(
format!("ok_token+{requested}"),
OffsetDateTime::now_utc().saturating_add(Duration::minutes(5)),
))
}
}

#[tokio::test]
async fn generate_authorization_fallback_to_public_scope_on_500011() {
use crate::pipeline::authorization_policy::ENV_SCOPE_OVERRIDE;
use crate::pipeline::authorization_policy::PUBLIC_COSMOS_SCOPE;

let _guard = EnvGuard::remove(ENV_SCOPE_OVERRIDE);

let time_nonce =
azure_core::time::parse_rfc3339("1900-01-01T01:00:00.000000000+00:00").unwrap();
let date_string = azure_core::time::to_rfc7231(&time_nonce).to_lowercase();

let cred = std::sync::Arc::new(FallbackTokenCredential);
let auth_token = Credential::Token(cred);
let url = url::Url::parse("https://acct.documents.azure.com/dbs/todo").unwrap();

let ret = generate_authorization(
&auth_token,
&url,
SignatureTarget::new(
azure_core::http::Method::Get,
&ResourceLink::root(ResourceType::Databases).item("todo"),
&date_string,
),
)
.await
.unwrap();

let expected = url_encode(
format!("type=aad&ver=1.0&sig=ok_token+{PUBLIC_COSMOS_SCOPE}").as_bytes(),
);
assert_eq!(ret, expected);
}

#[test]
fn scope_from_url_extracts_correct_scope() {
let scope = scope_from_url(&Url::parse("https://example.com/dbs/test_db/colls").unwrap());
assert_eq!(scope, "https://example.com/.default");
}

struct EnvGuard {
key: &'static str,
original: Option<String>,
}
impl EnvGuard {
fn set(key: &'static str, val: &str) -> Self {
let original = std::env::var(key).ok();
std::env::set_var(key, val);
Self { key, original }
}
fn remove(key: &'static str) -> Self {
let original = std::env::var(key).ok();
std::env::remove_var(key);
Self { key, original }
}
}
impl Drop for EnvGuard {
fn drop(&mut self) {
if let Some(ref v) = self.original {
std::env::set_var(self.key, v);
} else {
std::env::remove_var(self.key);
}
}
}
}
Loading
Loading