Skip to content

Commit 41739b7

Browse files
enhancements
1 parent f33cf95 commit 41739b7

File tree

3 files changed

+55
-15
lines changed

3 files changed

+55
-15
lines changed

src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -317,16 +317,14 @@ protected virtual void ApplyClaimsAndCapabilities(
317317
_requestContext.Logger.Info("[Managed Identity] Adding client capabilities (xms_cc) to Managed Identity request.");
318318
}
319319

320-
// If claims are present, send the bad token hash to the Managed Identity endpoint.
321-
if (!string.IsNullOrEmpty(parameters.Claims))
320+
// Only include 'token_sha256_to_refresh' if we have both Claims and the old token's hash
321+
if (!string.IsNullOrEmpty(parameters.Claims) &&
322+
!string.IsNullOrEmpty(parameters.BadTokenHash))
322323
{
323-
if (!string.IsNullOrEmpty(parameters.BadTokenHash))
324-
{
325-
SetRequestParameter(request, "token_sha256_to_refresh", parameters.BadTokenHash);
326-
_requestContext.Logger.Info(
327-
"[Managed Identity] Passing SHA-256 of the 'bad' token to Managed Identity endpoint."
328-
);
329-
}
324+
SetRequestParameter(request, "token_sha256_to_refresh", parameters.BadTokenHash);
325+
_requestContext.Logger.Info(
326+
"[Managed Identity] Passing SHA-256 of the 'bad' token to Managed Identity endpoint."
327+
);
330328
}
331329
}
332330

tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(
424424
{
425425
MockHttpMessageHandler httpMessageHandler = new MockHttpMessageHandler();
426426
IDictionary<string, string> expectedQueryParams = new Dictionary<string, string>();
427+
IDictionary<string, string> notExpectedQueryParams = new Dictionary<string, string>();
427428
IDictionary<string, string> expectedRequestHeaders = new Dictionary<string, string>();
428429
IDictionary<string, string> expectedPostData = null; // Only used for Cloud Shell
429430

@@ -471,26 +472,35 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(
471472
break;
472473
}
473474

475+
var manager = new CommonCryptographyManager();
476+
var value = manager.CreateSha256Hash(TestConstants.ATSecret);
477+
474478
// If capabilityEnabled, add "xms_cc": "cp1"
475479
if (capabilityEnabled)
476480
{
477-
if (managedIdentitySourceType == ManagedIdentitySource.AppService
481+
if (managedIdentitySourceType == ManagedIdentitySource.AppService
478482
|| managedIdentitySourceType == ManagedIdentitySource.ServiceFabric)
479483
{
480484
expectedQueryParams.Add("xms_cc", "cp1,cp2");
481485
}
482486
}
487+
else
488+
{
489+
notExpectedQueryParams.Add("xms_cc", "cp1,cp2");
490+
}
483491

484492
if (claimsEnabled)
485493
{
486-
var manager = new CommonCryptographyManager();
487-
var value = manager.CreateSha256Hash(TestConstants.ATSecret);
488494
if (managedIdentitySourceType == ManagedIdentitySource.AppService
489495
|| managedIdentitySourceType == ManagedIdentitySource.ServiceFabric)
490496
{
491497
expectedQueryParams.Add("token_sha256_to_refresh", manager.CreateSha256Hash(TestConstants.ATSecret));
492498
}
493499
}
500+
else
501+
{
502+
notExpectedQueryParams.Add("token_sha256_to_refresh", manager.CreateSha256Hash(TestConstants.ATSecret));
503+
}
494504

495505
if (managedIdentitySourceType != ManagedIdentitySource.CloudShell)
496506
{

tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ internal class MockHttpMessageHandler : HttpClientHandler
2525
public IDictionary<string, string> ExpectedRequestHeaders { get; set; }
2626
public IList<string> UnexpectedRequestHeaders { get; set; }
2727
public IDictionary<string, string> UnExpectedPostData { get; set; }
28+
public IDictionary<string, string> NotExpectedQueryParams { get; set; }
2829
public HttpMethod ExpectedMethod { get; set; }
2930

3031
public Exception ExceptionToThrow { get; set; }
@@ -65,7 +66,9 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
6566

6667
Assert.AreEqual(ExpectedMethod, request.Method);
6768

68-
ValidateQueryParams(uri);
69+
ValidateExpectedQueryParams(uri);
70+
71+
ValidateNotExpectedQueryParams(uri);
6972

7073
await ValidatePostDataAsync(request).ConfigureAwait(false);
7174

@@ -80,12 +83,12 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
8083
return ResponseMessage;
8184
}
8285

83-
private void ValidateQueryParams(Uri uri)
86+
private void ValidateExpectedQueryParams(Uri uri)
8487
{
8588
if (ExpectedQueryParams != null && ExpectedQueryParams.Any())
8689
{
8790
Assert.IsFalse(string.IsNullOrEmpty(uri.Query), $"Provided url ({uri.AbsoluteUri}) does not contain query parameters as expected.");
88-
var inputQp = CoreHelpers.ParseKeyValueList(uri.Query.Substring(1), '&', false, null);
91+
Dictionary<string, string> inputQp = CoreHelpers.ParseKeyValueList(uri.Query.Substring(1), '&', false, null);
8992
Assert.AreEqual(ExpectedQueryParams.Count, inputQp.Count, "Different number of query params.");
9093
foreach (var key in ExpectedQueryParams.Keys)
9194
{
@@ -95,6 +98,35 @@ private void ValidateQueryParams(Uri uri)
9598
}
9699
}
97100

101+
private void ValidateNotExpectedQueryParams(Uri uri)
102+
{
103+
if (NotExpectedQueryParams != null && NotExpectedQueryParams.Any())
104+
{
105+
// Parse actual query params again (or reuse inputQp if you like)
106+
Dictionary<string, string> actualQueryParams = CoreHelpers.ParseKeyValueList(uri.Query.Substring(1), '&', false, null);
107+
List<string> unexpectedKeysFound = new List<string>();
108+
109+
foreach (KeyValuePair<string, string> kvp in NotExpectedQueryParams)
110+
{
111+
// Check if the request's query has this key
112+
if (actualQueryParams.TryGetValue(kvp.Key, out string value))
113+
{
114+
// Optionally, also check if we care about matching the *value*:
115+
if (string.Equals(value, kvp.Value, StringComparison.OrdinalIgnoreCase))
116+
{
117+
unexpectedKeysFound.Add(kvp.Key);
118+
}
119+
}
120+
}
121+
122+
// Fail if any "not expected" key/value pairs were found
123+
Assert.IsTrue(
124+
unexpectedKeysFound.Count == 0,
125+
$"Did not expect to find these query parameter keys/values: {string.Join(", ", unexpectedKeysFound)}"
126+
);
127+
}
128+
}
129+
98130
private async Task ValidatePostDataAsync(HttpRequestMessage request)
99131
{
100132
if (request.Method != HttpMethod.Get && request.Content != null)

0 commit comments

Comments
 (0)