Skip to content

Commit 92b01c9

Browse files
enhancements
1 parent 646713d commit 92b01c9

File tree

3 files changed

+55
-15
lines changed

3 files changed

+55
-15
lines changed

src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -314,16 +314,14 @@ protected virtual void ApplyClaimsAndCapabilities(
314314
_requestContext.Logger.Info("[Managed Identity] Adding client capabilities (xms_cc) to Managed Identity request.");
315315
}
316316

317-
// If claims are present, send the bad token hash to the Managed Identity endpoint.
318-
if (!string.IsNullOrEmpty(parameters.Claims))
317+
// Only include 'token_sha256_to_refresh' if we have both Claims and the old token's hash
318+
if (!string.IsNullOrEmpty(parameters.Claims) &&
319+
!string.IsNullOrEmpty(parameters.BadTokenHash))
319320
{
320-
if (!string.IsNullOrEmpty(parameters.BadTokenHash))
321-
{
322-
SetRequestParameter(request, "token_sha256_to_refresh", parameters.BadTokenHash);
323-
_requestContext.Logger.Info(
324-
"[Managed Identity] Passing SHA-256 of the 'bad' token to Managed Identity endpoint."
325-
);
326-
}
321+
SetRequestParameter(request, "token_sha256_to_refresh", parameters.BadTokenHash);
322+
_requestContext.Logger.Info(
323+
"[Managed Identity] Passing SHA-256 of the 'bad' token to Managed Identity endpoint."
324+
);
327325
}
328326
}
329327

tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,7 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(
415415
{
416416
MockHttpMessageHandler httpMessageHandler = new MockHttpMessageHandler();
417417
IDictionary<string, string> expectedQueryParams = new Dictionary<string, string>();
418+
IDictionary<string, string> notExpectedQueryParams = new Dictionary<string, string>();
418419
IDictionary<string, string> expectedRequestHeaders = new Dictionary<string, string>();
419420
IDictionary<string, string> expectedPostData = null; // Only used for Cloud Shell
420421

@@ -462,26 +463,35 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(
462463
break;
463464
}
464465

466+
var manager = new CommonCryptographyManager();
467+
var value = manager.CreateSha256Hash(TestConstants.ATSecret);
468+
465469
// If capabilityEnabled, add "xms_cc": "cp1"
466470
if (capabilityEnabled)
467471
{
468-
if (managedIdentitySourceType == ManagedIdentitySource.AppService
472+
if (managedIdentitySourceType == ManagedIdentitySource.AppService
469473
|| managedIdentitySourceType == ManagedIdentitySource.ServiceFabric)
470474
{
471475
expectedQueryParams.Add("xms_cc", "cp1,cp2");
472476
}
473477
}
478+
else
479+
{
480+
notExpectedQueryParams.Add("xms_cc", "cp1,cp2");
481+
}
474482

475483
if (claimsEnabled)
476484
{
477-
var manager = new CommonCryptographyManager();
478-
var value = manager.CreateSha256Hash(TestConstants.ATSecret);
479485
if (managedIdentitySourceType == ManagedIdentitySource.AppService
480486
|| managedIdentitySourceType == ManagedIdentitySource.ServiceFabric)
481487
{
482488
expectedQueryParams.Add("token_sha256_to_refresh", manager.CreateSha256Hash(TestConstants.ATSecret));
483489
}
484490
}
491+
else
492+
{
493+
notExpectedQueryParams.Add("token_sha256_to_refresh", manager.CreateSha256Hash(TestConstants.ATSecret));
494+
}
485495

486496
if (managedIdentitySourceType != ManagedIdentitySource.CloudShell)
487497
{

tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ internal class MockHttpMessageHandler : HttpClientHandler
2525
public IDictionary<string, string> ExpectedRequestHeaders { get; set; }
2626
public IList<string> UnexpectedRequestHeaders { get; set; }
2727
public IDictionary<string, string> UnExpectedPostData { get; set; }
28+
public IDictionary<string, string> NotExpectedQueryParams { get; set; }
2829
public HttpMethod ExpectedMethod { get; set; }
2930

3031
public Exception ExceptionToThrow { get; set; }
@@ -65,7 +66,9 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
6566

6667
Assert.AreEqual(ExpectedMethod, request.Method);
6768

68-
ValidateQueryParams(uri);
69+
ValidateExpectedQueryParams(uri);
70+
71+
ValidateNotExpectedQueryParams(uri);
6972

7073
await ValidatePostDataAsync(request).ConfigureAwait(false);
7174

@@ -80,12 +83,12 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
8083
return ResponseMessage;
8184
}
8285

83-
private void ValidateQueryParams(Uri uri)
86+
private void ValidateExpectedQueryParams(Uri uri)
8487
{
8588
if (ExpectedQueryParams != null && ExpectedQueryParams.Any())
8689
{
8790
Assert.IsFalse(string.IsNullOrEmpty(uri.Query), $"Provided url ({uri.AbsoluteUri}) does not contain query parameters as expected.");
88-
var inputQp = CoreHelpers.ParseKeyValueList(uri.Query.Substring(1), '&', false, null);
91+
Dictionary<string, string> inputQp = CoreHelpers.ParseKeyValueList(uri.Query.Substring(1), '&', false, null);
8992
Assert.AreEqual(ExpectedQueryParams.Count, inputQp.Count, "Different number of query params.");
9093
foreach (var key in ExpectedQueryParams.Keys)
9194
{
@@ -95,6 +98,35 @@ private void ValidateQueryParams(Uri uri)
9598
}
9699
}
97100

101+
private void ValidateNotExpectedQueryParams(Uri uri)
102+
{
103+
if (NotExpectedQueryParams != null && NotExpectedQueryParams.Any())
104+
{
105+
// Parse actual query params again (or reuse inputQp if you like)
106+
Dictionary<string, string> actualQueryParams = CoreHelpers.ParseKeyValueList(uri.Query.Substring(1), '&', false, null);
107+
List<string> unexpectedKeysFound = new List<string>();
108+
109+
foreach (KeyValuePair<string, string> kvp in NotExpectedQueryParams)
110+
{
111+
// Check if the request's query has this key
112+
if (actualQueryParams.TryGetValue(kvp.Key, out string value))
113+
{
114+
// Optionally, also check if we care about matching the *value*:
115+
if (string.Equals(value, kvp.Value, StringComparison.OrdinalIgnoreCase))
116+
{
117+
unexpectedKeysFound.Add(kvp.Key);
118+
}
119+
}
120+
}
121+
122+
// Fail if any "not expected" key/value pairs were found
123+
Assert.IsTrue(
124+
unexpectedKeysFound.Count == 0,
125+
$"Did not expect to find these query parameter keys/values: {string.Join(", ", unexpectedKeysFound)}"
126+
);
127+
}
128+
}
129+
98130
private async Task ValidatePostDataAsync(HttpRequestMessage request)
99131
{
100132
if (request.Method != HttpMethod.Get && request.Content != null)

0 commit comments

Comments
 (0)