Skip to content

Commit 646713d

Browse files
msi v1 tr
1 parent 579b189 commit 646713d

File tree

11 files changed

+167
-71
lines changed

11 files changed

+167
-71
lines changed

src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenForManagedIdentityParameters.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ internal class AcquireTokenForManagedIdentityParameters : IAcquireTokenParameter
1818

1919
public string Claims { get; set; }
2020

21+
public string BadTokenHash { get; set; }
22+
2123
public void LogParameters(ILoggerAdapter logger)
2224
{
2325
if (logger.IsLoggingEnabled(LogLevel.Info))

src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using Microsoft.Identity.Client.Core;
1111
using Microsoft.Identity.Client.ManagedIdentity;
1212
using Microsoft.Identity.Client.OAuth2;
13+
using Microsoft.Identity.Client.PlatformsCommon.Interfaces;
1314
using Microsoft.Identity.Client.Utils;
1415

1516
namespace Microsoft.Identity.Client.Internal.Requests
@@ -18,6 +19,7 @@ internal class ManagedIdentityAuthRequest : RequestBase
1819
{
1920
private readonly AcquireTokenForManagedIdentityParameters _managedIdentityParameters;
2021
private static readonly SemaphoreSlim s_semaphoreSlim = new SemaphoreSlim(1, 1);
22+
private readonly ICryptographyManager _cryptoManager;
2123

2224
public ManagedIdentityAuthRequest(
2325
IServiceBundle serviceBundle,
@@ -26,72 +28,105 @@ public ManagedIdentityAuthRequest(
2628
: base(serviceBundle, authenticationRequestParameters, managedIdentityParameters)
2729
{
2830
_managedIdentityParameters = managedIdentityParameters;
31+
_cryptoManager = serviceBundle.PlatformProxy.CryptographyManager;
2932
}
3033

3134
protected override async Task<AuthenticationResult> ExecuteAsync(CancellationToken cancellationToken)
3235
{
3336
AuthenticationResult authResult = null;
3437
ILoggerAdapter logger = AuthenticationRequestParameters.RequestContext.Logger;
3538

36-
// Skip checking cache when force refresh or claims is specified
37-
if (_managedIdentityParameters.ForceRefresh || !string.IsNullOrEmpty(AuthenticationRequestParameters.Claims))
39+
// 1. FIRST, handle ForceRefresh
40+
if (_managedIdentityParameters.ForceRefresh)
3841
{
39-
_managedIdentityParameters.Claims = AuthenticationRequestParameters.Claims;
4042
AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ForceRefreshOrClaims;
41-
42-
logger.Info("[ManagedIdentityRequest] Skipped looking for a cached access token because ForceRefresh or Claims were set. " +
43-
"This means either a force refresh was requested or claims were present.");
43+
logger.Info("[ManagedIdentityRequest] Skipped using the cache because ForceRefresh was set.");
44+
45+
// We still respect claims if present
46+
_managedIdentityParameters.Claims = AuthenticationRequestParameters.Claims;
4447

48+
// Straight to the MI endpoint
4549
authResult = await GetAccessTokenAsync(cancellationToken, logger).ConfigureAwait(false);
4650
return authResult;
4751
}
4852

53+
// 2. Otherwise, look for a cached token
4954
MsalAccessTokenCacheItem cachedAccessTokenItem = await GetCachedAccessTokenAsync().ConfigureAwait(false);
5055

51-
// No access token or cached access token needs to be refreshed
56+
// If we have claims, we do NOT use the cached token (but we still need it to compute the hash).
57+
if (!string.IsNullOrEmpty(AuthenticationRequestParameters.Claims))
58+
{
59+
_managedIdentityParameters.Claims = AuthenticationRequestParameters.Claims;
60+
AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ForceRefreshOrClaims;
61+
62+
// If there is a cached token, compute its hash for the “bad token” scenario
63+
if (cachedAccessTokenItem != null)
64+
{
65+
string cachedTokenHash = _cryptoManager.CreateSha256Hash(cachedAccessTokenItem.Secret);
66+
_managedIdentityParameters.BadTokenHash = cachedTokenHash;
67+
68+
logger.Info("[ManagedIdentityRequest] Claims are present. Computed hash of the cached (bad) token. " +
69+
"Will now request a fresh token from the MI endpoint.");
70+
}
71+
else
72+
{
73+
logger.Info("[ManagedIdentityRequest] Claims are present, but no cached token was found. " +
74+
"Requesting a fresh token from the MI endpoint without a bad-token hash.");
75+
}
76+
77+
// In both cases, we skip using the cached token and get a new one
78+
authResult = await GetAccessTokenAsync(cancellationToken, logger).ConfigureAwait(false);
79+
return authResult;
80+
}
81+
82+
// 3. If we have no ForceRefresh and no claims, we can use the cache
5283
if (cachedAccessTokenItem != null)
5384
{
85+
// Found a valid token in cache
5486
authResult = CreateAuthenticationResultFromCache(cachedAccessTokenItem);
55-
5687
logger.Info("[ManagedIdentityRequest] Access token retrieved from cache.");
5788

5889
try
59-
{
60-
var proactivelyRefresh = SilentRequestHelper.NeedsRefresh(cachedAccessTokenItem);
61-
62-
// If needed, refreshes token in the background
90+
{
91+
// If token is close to expiry, proactively refresh it in the background
92+
bool proactivelyRefresh = SilentRequestHelper.NeedsRefresh(cachedAccessTokenItem);
6393
if (proactivelyRefresh)
6494
{
6595
logger.Info("[ManagedIdentityRequest] Initiating a proactive refresh.");
6696

6797
AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ProactivelyRefreshed;
6898

6999
SilentRequestHelper.ProcessFetchInBackground(
70-
cachedAccessTokenItem,
71-
() =>
72-
{
73-
// Use a linked token source, in case the original cancellation token source is disposed before this background task completes.
74-
using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
75-
return GetAccessTokenAsync(tokenSource.Token, logger);
76-
}, logger, ServiceBundle, AuthenticationRequestParameters.RequestContext.ApiEvent,
77-
AuthenticationRequestParameters.RequestContext.ApiEvent.CallerSdkApiId,
78-
AuthenticationRequestParameters.RequestContext.ApiEvent.CallerSdkVersion);
100+
cachedAccessTokenItem,
101+
() =>
102+
{
103+
// Use a linked token source, in case the original cts is disposed
104+
using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
105+
return GetAccessTokenAsync(tokenSource.Token, logger);
106+
},
107+
logger,
108+
ServiceBundle,
109+
AuthenticationRequestParameters.RequestContext.ApiEvent,
110+
AuthenticationRequestParameters.RequestContext.ApiEvent.CallerSdkApiId,
111+
AuthenticationRequestParameters.RequestContext.ApiEvent.CallerSdkVersion);
79112
}
80113
}
81114
catch (MsalServiceException e)
82115
{
116+
// If background refresh fails, we handle the exception
83117
return await HandleTokenRefreshErrorAsync(e, cachedAccessTokenItem).ConfigureAwait(false);
84118
}
85119
}
86120
else
87121
{
88-
// No AT in the cache
122+
// No cached token
89123
if (AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo != CacheRefreshReason.Expired)
90124
{
91125
AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.NoCachedAccessToken;
92126
}
93127

94-
logger.Info("[ManagedIdentityRequest] No cached access token. Getting a token from the managed identity endpoint.");
128+
logger.Info("[ManagedIdentityRequest] No cached access token found. " +
129+
"Getting a token from the managed identity endpoint.");
95130
authResult = await GetAccessTokenAsync(cancellationToken, logger).ConfigureAwait(false);
96131
}
97132

src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -307,24 +307,24 @@ protected virtual void ApplyClaimsAndCapabilities(
307307
{
308308
IEnumerable<string> clientCapabilities = _requestContext.ServiceBundle.Config.ClientCapabilities;
309309

310-
// If claims are present, set bypass_cache=true
311-
if (!string.IsNullOrEmpty(parameters.Claims))
310+
// Set xms_cc only if clientCapabilities exist
311+
if (clientCapabilities != null && clientCapabilities.Any())
312312
{
313-
SetRequestParameter(request, "bypass_cache", "true");
314-
_requestContext.Logger.Info("[Managed Identity] Setting bypass_cache=true in the Managed Identity request due to claims.");
313+
SetRequestParameter(request, "xms_cc", string.Join(",", clientCapabilities));
314+
_requestContext.Logger.Info("[Managed Identity] Adding client capabilities (xms_cc) to Managed Identity request.");
315+
}
315316

316-
// Set xms_cc only if clientCapabilities exist
317-
if (clientCapabilities != null && clientCapabilities.Any())
317+
// If claims are present, send the bad token hash to the Managed Identity endpoint.
318+
if (!string.IsNullOrEmpty(parameters.Claims))
319+
{
320+
if (!string.IsNullOrEmpty(parameters.BadTokenHash))
318321
{
319-
SetRequestParameter(request, "xms_cc", string.Join(",", clientCapabilities));
320-
_requestContext.Logger.Info("[Managed Identity] Adding client capabilities (xms_cc) to Managed Identity request.");
322+
SetRequestParameter(request, "token_sha256_to_refresh", parameters.BadTokenHash);
323+
_requestContext.Logger.Info(
324+
"[Managed Identity] Passing SHA-256 of the 'bad' token to Managed Identity endpoint."
325+
);
321326
}
322327
}
323-
else
324-
{
325-
SetRequestParameter(request, "bypass_cache", "false");
326-
_requestContext.Logger.Info("[Managed Identity] Setting bypass_cache=false (no claims provided).");
327-
}
328328
}
329329

330330
/// <summary>

src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,6 @@ protected override ManagedIdentityRequest CreateRequest(string resource, Acquire
8888
request.Headers.Add("Metadata", "true");
8989
request.QueryParameters["api-version"] = ArcApiVersion;
9090
request.QueryParameters["resource"] = resource;
91-
request.QueryParameters["bypass_cache"] = "false";
92-
93-
ApplyClaimsAndCapabilities(request, parameters);
9491

9592
return request;
9693
}

src/client/Microsoft.Identity.Client/ManagedIdentity/CloudShellManagedIdentitySource.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,6 @@ protected override ManagedIdentityRequest CreateRequest(string resource, Acquire
8383

8484
request.BodyParameters.Add("resource", resource);
8585

86-
ApplyClaimsAndCapabilities(request, parameters);
87-
8886
return request;
8987
}
9088
}

src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ protected override ManagedIdentityRequest CreateRequest(string resource, Acquire
6464
request.QueryParameters["api-version"] = ImdsApiVersion;
6565
request.QueryParameters["resource"] = resource;
6666

67-
ApplyClaimsAndCapabilities(request, parameters);
68-
6967
switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType)
7068
{
7169
case AppConfig.ManagedIdentityIdType.ClientId:

src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,6 @@ protected override ManagedIdentityRequest CreateRequest(string resource, Acquire
7272
request.QueryParameters["api-version"] = MachineLearningMsiApiVersion;
7373
request.QueryParameters["resource"] = resource;
7474

75-
ApplyClaimsAndCapabilities(request, parameters);
76-
7775
switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType)
7876
{
7977
case AppConfig.ManagedIdentityIdType.ClientId:

src/client/Microsoft.Identity.Client/Utils/CoreHelpers.cs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,40 +75,54 @@ public static string ToQueryParameter(this IDictionary<string, string> input)
7575
return builder.ToString();
7676
}
7777

78-
public static Dictionary<string, string> ParseKeyValueList(string input, char delimiter, bool urlDecode,
79-
bool lowercaseKeys,
78+
public static Dictionary<string, string> ParseKeyValueList(
79+
string input,
80+
char delimiter,
81+
bool urlDecode,
82+
bool lowercaseKeys,
8083
RequestContext requestContext)
8184
{
8285
var response = new Dictionary<string, string>();
8386

87+
// Split the full query string on & (or any provided delimiter) to get individual k=v pairs.
8488
var queryPairs = SplitWithQuotes(input, delimiter);
8589

8690
foreach (string queryPair in queryPairs)
8791
{
88-
var pair = SplitWithQuotes(queryPair, '=');
92+
// Instead of splitting on *all* '=' characters, find only the first one.
93+
// This ensures that if the value itself contains '=', such as a trailing '=' in Base64,
94+
// we do not accidentally split the base64 value into extra parts and lose the padding.
95+
int idx = queryPair.IndexOf('=');
8996

90-
if (pair.Count == 2 && !string.IsNullOrWhiteSpace(pair[0]) && !string.IsNullOrWhiteSpace(pair[1]))
97+
// idx > 0 means we found an '=' and have a valid key substring before it
98+
if (idx > 0)
9199
{
92-
string key = pair[0];
93-
string value = pair[1];
100+
// The key is everything before the first '='
101+
string key = queryPair.Substring(0, idx);
102+
103+
// The value is everything after the first '=' (including any trailing '=')
104+
string value = queryPair.Substring(idx + 1);
94105

95-
// Url decoding is needed for parsing OAuth response, but not for parsing WWW-Authenticate header in 401 challenge
106+
// If urlDecode == true, decode both key and value
96107
if (urlDecode)
97108
{
98109
key = UrlDecode(key);
99110
value = UrlDecode(value);
100111
}
101112

113+
// Optionally convert key to lowercase
102114
if (lowercaseKeys)
103115
{
104116
key = key.Trim().ToLowerInvariant();
105117
}
106118

119+
// Trim quotes and whitespace around the value
107120
value = value.Trim().Trim('\"').Trim();
108121

109122
if (response.ContainsKey(key))
110123
{
111-
requestContext?.Logger.Warning(string.Format(CultureInfo.InvariantCulture,
124+
requestContext?.Logger.Warning(
125+
string.Format(CultureInfo.InvariantCulture,
112126
"Key/value pair list contains redundant key '{0}'.", key));
113127
}
114128

tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ public static HttpResponseMessage CreateSuccessTokenResponseMessage(
370370
string[] scope,
371371
bool foci = false,
372372
string utid = TestConstants.Utid,
373-
string accessToken = "some-access-token",
373+
string accessToken = TestConstants.ATSecret,
374374
string refreshToken = "OAAsomethingencrypedQwgAA")
375375
{
376376
HttpResponseMessage responseMessage = new HttpResponseMessage(HttpStatusCode.OK);
@@ -385,7 +385,7 @@ public static string CreateSuccessTokenResponseString(string uniqueId,
385385
string[] scope,
386386
bool foci = false,
387387
string utid = TestConstants.Utid,
388-
string accessToken = "some-access-token",
388+
string accessToken = TestConstants.ATSecret,
389389
string refreshToken = "OAAsomethingencrypedQwgAA")
390390
{
391391
string idToken = CreateIdToken(uniqueId, displayableId, TestConstants.Utid);

0 commit comments

Comments
 (0)