From e227fc1ae2dd7ef5c40cb6abb31e8a86f7e21bc9 Mon Sep 17 00:00:00 2001 From: "Arooshi Avasthy (from Dev Box)" Date: Thu, 16 Oct 2025 17:18:22 -0700 Subject: [PATCH] Add AAD sope fallback --- .../src/pipeline/authorization_policy.rs | 159 ++++++++++- .../azure_data_cosmos/tests/aad_scope.rs | 251 ++++++++++++++++++ .../tests/framework/test_account.rs | 38 ++- 3 files changed, 437 insertions(+), 11 deletions(-) create mode 100644 sdk/cosmos/azure_data_cosmos/tests/aad_scope.rs diff --git a/sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs b/sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs index 19cf51c2a0..c9918c5b79 100644 --- a/sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs +++ b/sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs @@ -26,6 +26,8 @@ use crate::{pipeline::signature_target::SignatureTarget, resource_context::Resou use crate::utils::url_encode; const AZURE_VERSION: &str = "2020-07-15"; +const ENV_SCOPE_OVERRIDE: &str = "AZURE_COSMOS_AAD_SCOPE_OVERRIDE"; +const PUBLIC_COSMOS_SCOPE: &str = "https://cosmos.azure.com/.default"; #[derive(Debug, Clone)] enum Credential { @@ -111,19 +113,23 @@ impl Policy for AuthorizationPolicy { async fn generate_authorization( auth_token: &Credential, url: &Url, - - // Unused unless feature="key_auth", but I don't want to mess with excluding it since it makes call sites more complicated #[allow(unused_variables)] signature_target: SignatureTarget<'_>, ) -> azure_core::Result { let token = match auth_token { Credential::Token(token_credential) => { - let token = token_credential - .get_token(&[&scope_from_url(url)], None) - .await? - .token - .secret() - .to_string(); - format!("type=aad&ver=1.0&sig={token}") + // Env override: use ONLY this scope, no fallback. + if let Ok(s) = std::env::var(ENV_SCOPE_OVERRIDE) { + let override_scope = s.trim(); + if !override_scope.is_empty() { + let at = token_credential.get_token(&[override_scope], None).await?; + format!("type=aad&ver=1.0&sig={}", at.token.secret()) + } else { + // 2) No override => host scope, with single fallback on AADSTS500011 + acquire_with_account_scope_then_maybe_fallback(token_credential.as_ref(), url).await? + } + } else { + acquire_with_account_scope_then_maybe_fallback(token_credential.as_ref(), url).await? + } } #[cfg(feature = "key_auth")] @@ -133,6 +139,25 @@ async fn generate_authorization( Ok(url_encode(token)) } +async fn acquire_with_account_scope_then_maybe_fallback( + cred: &dyn TokenCredential, + url: &Url, +) -> azure_core::Result { + let account_scope = scope_from_url(url); + match cred.get_token(&[&account_scope], None).await { + Ok(at) => Ok(format!("type=aad&ver=1.0&sig={}", at.token.secret())), + Err(e) => { + let msg = e.to_string(); + if msg.contains("AADSTS500011") { + let at = cred.get_token(&[PUBLIC_COSMOS_SCOPE], None).await?; + Ok(format!("type=aad&ver=1.0&sig={}", at.token.secret())) + } else { + Err(e) + } + } + } +} + /// This function generates the scope string from the passed url. The scope string is used to /// request the AAD token. fn scope_from_url(url: &Url) -> String { @@ -277,9 +302,125 @@ mod tests { assert_eq!(ret, expected); } + #[tokio::test] + async fn generate_authorization_with_env_override_only() { + use crate::pipeline::authorization_policy::ENV_SCOPE_OVERRIDE; + + let _guard = EnvGuard::set(ENV_SCOPE_OVERRIDE, "https://custom.example/.default"); + + let time_nonce = + azure_core::time::parse_rfc3339("1900-01-01T01:00:00.000000000+00:00").unwrap(); + let date_string = azure_core::time::to_rfc7231(&time_nonce).to_lowercase(); + + let cred = std::sync::Arc::new(TestTokenCredential("test_token".to_string())); + let auth_token = Credential::Token(cred); + + let url = url::Url::parse("https://acct.documents.azure.com/dbs/x").unwrap(); + + let ret = generate_authorization( + &auth_token, + &url, + SignatureTarget::new( + azure_core::http::Method::Get, + &ResourceLink::root(ResourceType::Databases).item("x"), + &date_string, + ), + ) + .await + .unwrap(); + + let expected = url_encode(b"type=aad&ver=1.0&sig=test_token+https://custom.example/.default"); + assert_eq!(ret, expected); + } + + #[derive(Debug)] + struct FallbackTokenCredential; + #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] + #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] + impl TokenCredential for FallbackTokenCredential { + async fn get_token( + &self, + scopes: &[&str], + _opts: Option>, + ) -> azure_core::Result { + let requested = scopes.join(","); + // Simulate failure for the account/host scope; success for public scope + if requested.starts_with("https://acct.documents.azure.com/.default") { + return Err(azure_core::error::Error::with_message( + azure_core::error::ErrorKind::Other, + "AADSTS500011: The resource principal named ... was not found", + )); + } + // Success for the public fallback scope + Ok(AccessToken::new( + format!("ok_token+{requested}"), + OffsetDateTime::now_utc().saturating_add(Duration::minutes(5)), + )) + } + } + + #[tokio::test] + async fn generate_authorization_fallback_to_public_scope_on_500011() { + use crate::pipeline::authorization_policy::ENV_SCOPE_OVERRIDE; + use crate::pipeline::authorization_policy::PUBLIC_COSMOS_SCOPE; + + let _guard = EnvGuard::remove(ENV_SCOPE_OVERRIDE); + + let time_nonce = + azure_core::time::parse_rfc3339("1900-01-01T01:00:00.000000000+00:00").unwrap(); + let date_string = azure_core::time::to_rfc7231(&time_nonce).to_lowercase(); + + let cred = std::sync::Arc::new(FallbackTokenCredential); + let auth_token = Credential::Token(cred); + let url = url::Url::parse("https://acct.documents.azure.com/dbs/todo").unwrap(); + + let ret = generate_authorization( + &auth_token, + &url, + SignatureTarget::new( + azure_core::http::Method::Get, + &ResourceLink::root(ResourceType::Databases).item("todo"), + &date_string, + ), + ) + .await + .unwrap(); + + let expected = url_encode( + format!("type=aad&ver=1.0&sig=ok_token+{PUBLIC_COSMOS_SCOPE}").as_bytes(), + ); + assert_eq!(ret, expected); + } + #[test] fn scope_from_url_extracts_correct_scope() { let scope = scope_from_url(&Url::parse("https://example.com/dbs/test_db/colls").unwrap()); assert_eq!(scope, "https://example.com/.default"); } + + struct EnvGuard { + key: &'static str, + original: Option, + } + impl EnvGuard { + fn set(key: &'static str, val: &str) -> Self { + let original = std::env::var(key).ok(); + std::env::set_var(key, val); + Self { key, original } + } + fn remove(key: &'static str) -> Self { + let original = std::env::var(key).ok(); + std::env::remove_var(key); + Self { key, original } + } + } + impl Drop for EnvGuard { + fn drop(&mut self) { + if let Some(ref v) = self.original { + std::env::set_var(self.key, v); + } else { + std::env::remove_var(self.key); + } + } + } } diff --git a/sdk/cosmos/azure_data_cosmos/tests/aad_scope.rs b/sdk/cosmos/azure_data_cosmos/tests/aad_scope.rs new file mode 100644 index 0000000000..64cf2cb511 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos/tests/aad_scope.rs @@ -0,0 +1,251 @@ +// Licensed under the MIT License. + +#![cfg(feature = "key_auth")] + +use std::sync::{Arc, Mutex}; +use azure_core::credentials::{AccessToken, TokenCredential, TokenRequestOptions}; +use azure_core::time::{Duration, OffsetDateTime}; +use azure_core_test::{recorded, TestContext}; +use serde_json::json; + +mod framework; +use framework::TestAccount; +use azure_data_cosmos::{CosmosClient, models::ContainerProperties}; + +// +// ========== Helpers for capturing scopes & simulating failures ========== +// + +#[derive(Debug, Clone)] +struct CapturedScopes(Arc>>); + +impl CapturedScopes { + fn new() -> Self { + Self(Arc::new(Mutex::new(Vec::new()))) + } + fn push(&self, scope: &str) { + self.0.lock().unwrap().push(scope.to_string()); + } + fn take(&self) -> Vec { + std::mem::take(&mut *self.0.lock().unwrap()) + } +} + +#[derive(Debug)] +struct RecordingCredential { + tag: &'static str, + captured: CapturedScopes, +} + +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +impl TokenCredential for RecordingCredential { + async fn get_token( + &self, + scopes: &[&str], + _opts: Option>, + ) -> azure_core::Result { + let scope = scopes.join(","); + self.captured.push(&scope); + Ok(AccessToken::new( + format!("{}_token_for_{}", self.tag, scope), + OffsetDateTime::now_utc().saturating_add(Duration::minutes(5)), + )) + } +} + +#[derive(Debug)] +struct AlwaysFailCredential { + captured: CapturedScopes, + message: &'static str, +} + +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +impl TokenCredential for AlwaysFailCredential { + async fn get_token( + &self, + scopes: &[&str], + _opts: Option>, + ) -> azure_core::Result { + let scope = scopes.join(","); + self.captured.push(&scope); + Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + self.message, + )) + } +} + +/// Fails once with AADSTS500011 on account scope, then succeeds +#[derive(Debug)] +struct FailOnceThenSucceedCredential { + captured: CapturedScopes, + first_call_done: Arc>, + account_scope_prefix: String, +} + +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +impl TokenCredential for FailOnceThenSucceedCredential { + async fn get_token( + &self, + scopes: &[&str], + _opts: Option>, + ) -> azure_core::Result { + let scope = scopes.join(","); + self.captured.push(&scope); + + let mut done = self.first_call_done.lock().unwrap(); + if !*done && scope.starts_with(&self.account_scope_prefix) { + *done = true; + return Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "AADSTS500011: Simulated error for fallback", + )); + } + + Ok(AccessToken::new( + format!("ok_token_for_{}", scope), + OffsetDateTime::now_utc().saturating_add(Duration::minutes(5)), + )) + } +} + +// Env override guard + +struct TestEnvGuard { + key: &'static str, + original: Option, +} +impl TestEnvGuard { + fn set(key: &'static str, val: &str) -> Self { + let original = std::env::var(key).ok(); + std::env::set_var(key, val); + Self { key, original } + } +} +impl Drop for TestEnvGuard { + fn drop(&mut self) { + if let Some(ref v) = self.original { + std::env::set_var(self.key, v); + } else { + std::env::remove_var(self.key); + } + } +} + + +async fn create_db_container_and_item( + client: &CosmosClient, + db_id: &str, + container_id: &str, +) -> Result<(), Box> { + + let _ = client.create_database(db_id, None).await?; + + let db_client = client.database_client(db_id); + + db_client + .create_container( + ContainerProperties { + id: container_id.to_string().into(), + partition_key: "/pk".into(), + ..Default::default() + }, + None, + ) + .await?; + + let cont = db_client.container_client(container_id); + let doc = json!({"id":"Item_1","pk":"pk"}); + cont.create_item("pk", &doc, None).await?; + Ok(()) +} + +// AAD Tests + +const ENV_SCOPE_OVERRIDE: &str = "AZURE_COSMOS_AAD_SCOPE_OVERRIDE"; +const PUBLIC_COSMOS_SCOPE: &str = "https://cosmos.azure.com/.default"; + +#[recorded::test] +async fn aad_override_scope_no_fallback(context: TestContext) -> Result<(), Box> { + let _guard = TestEnvGuard::set(ENV_SCOPE_OVERRIDE, "https://cosmos.azure.com/.default"); + + let captured = CapturedScopes::new(); + let cred = Arc::new(RecordingCredential { tag: "override", captured: captured.clone() }); + + let account = TestAccount::from_env(context, None).await?; + let client = account.connect_with_token(cred)?; + + create_db_container_and_item(&client, "AAD_Override_DB", "AAD_Override_Cont").await?; + + let scopes = captured.take(); + assert!(scopes.iter().all(|s| s == "https://cosmos.azure.com/.default")); + Ok(()) +} + +#[recorded::test] +async fn aad_override_scope_auth_error_no_fallback(context: TestContext) -> Result<(), Box> { + let _guard = TestEnvGuard::set(ENV_SCOPE_OVERRIDE, "https://my.custom.scope/.default"); + + let captured = CapturedScopes::new(); + let cred = Arc::new(AlwaysFailCredential { captured: captured.clone(), message: "fail" }); + + let account = TestAccount::from_env(context, None).await?; + let client = account.connect_with_token(cred)?; + + let result = create_db_container_and_item(&client, "AAD_OverrideFail_DB", "AAD_OverrideFail_Cont").await; + assert!(result.is_err()); + + let scopes = captured.take(); + assert_eq!(scopes, vec!["https://my.custom.scope/.default"]); + Ok(()) +} + +#[recorded::test] +async fn aad_account_scope_only(context: TestContext) -> Result<(), Box> { + // Empty override -> use account scope (no fallback unless error) + let _guard = TestEnvGuard::set(ENV_SCOPE_OVERRIDE, ""); + + let captured = CapturedScopes::new(); + let cred = Arc::new(RecordingCredential { tag: "account", captured: captured.clone() }); + + let account = TestAccount::from_env(context, None).await?; + let client = account.connect_with_token(cred)?; + + create_db_container_and_item(&client, "AAD_Account_DB", "AAD_Account_Cont").await?; + + let scopes = captured.take(); + assert!(!scopes.is_empty()); + Ok(()) +} + +#[recorded::test] +async fn aad_account_scope_fallback_on_error(context: TestContext) -> Result<(), Box> { + // Empty override -> use account/host scope, and if AADSTS500011 then fallback to public scope + let _guard = TestEnvGuard::set(ENV_SCOPE_OVERRIDE, ""); + + let account = TestAccount::from_env(context, None).await?; + + let captured_probe = CapturedScopes::new(); + let probe_cred = Arc::new(RecordingCredential { tag: "probe", captured: captured_probe.clone() }); + let probe_client = account.connect_with_token(probe_cred)?; + let _ = probe_client.create_database("AAD_Fallback_Probe_DB", None).await; + + let scopes_seen = captured_probe.take(); + let account_scope = scopes_seen.into_iter().find(|s| s.ends_with("/.default")).unwrap(); + + //Run with FailOnceThenSucceedCredential to trigger fallback + let captured = CapturedScopes::new(); + let cred = Arc::new(FailOnceThenSucceedCredential { + captured: captured.clone(), + first_call_done: Arc::new(Mutex::new(false)), + account_scope_prefix: account_scope.clone(), + }); + + let client = account.connect_with_token(cred)?; + create_db_container_and_item(&client, "AAD_Fallback_DB", "AAD_Fallback_Cont").await?; + + let scopes = captured.take(); + assert!(scopes.contains(&account_scope)); + assert!(scopes.contains(&PUBLIC_COSMOS_SCOPE.to_string())); + Ok(()) +} \ No newline at end of file diff --git a/sdk/cosmos/azure_data_cosmos/tests/framework/test_account.rs b/sdk/cosmos/azure_data_cosmos/tests/framework/test_account.rs index 6cd8eeb215..b601844fb3 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/framework/test_account.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/framework/test_account.rs @@ -4,9 +4,13 @@ use std::{borrow::Cow, str::FromStr, sync::Arc}; -use azure_core::{credentials::Secret, http::Transport, test::TestMode}; +use azure_core::{ + credentials::{Secret, TokenCredential}, + http::Transport, + test::TestMode, +}; use azure_core_test::TestContext; -use azure_data_cosmos::{ConnectionString, CosmosClientOptions, Query}; +use azure_data_cosmos::{ConnectionString, CosmosClient, CosmosClientOptions, Query}; use reqwest::ClientBuilder; /// Represents a Cosmos DB account for testing purposes. @@ -120,6 +124,36 @@ impl TestAccount { )?) } + pub fn connect_with_token( + &self, + cred: Arc, + ) -> Result> { + let allow_invalid_certificates = match self.options.allow_invalid_certificates { + Some(b) => b, + None => std::env::var(ALLOW_INVALID_CERTS_ENV_VAR).map(|s| s.parse())??, + }; + + let mut options = CosmosClientOptions::default(); + + if allow_invalid_certificates { + let client = ClientBuilder::new() + .danger_accept_invalid_certs(true) + .pool_max_idle_per_host(0) + .build()?; + options.client_options.transport = Some(Transport::new(Arc::new(client))); + } + + self.context + .recording() + .instrument(&mut options.client_options); + + Ok(CosmosClient::new( + &self.endpoint, + cred, + Some(options), + )?) + } + /// Generates a unique database ID including the [`TestAccount::context_id`]. /// /// This database will be automatically deleted when [`TestAccount::cleanup`] is called.