Skip to content

Commit 992ae6e

Browse files
enhancements
1 parent 25e3a78 commit 992ae6e

File tree

3 files changed

+55
-15
lines changed

3 files changed

+55
-15
lines changed

src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -319,16 +319,14 @@ protected virtual void ApplyClaimsAndCapabilities(
319319
_requestContext.Logger.Info("[Managed Identity] Adding client capabilities (xms_cc) to Managed Identity request.");
320320
}
321321

322-
// If claims are present, send the bad token hash to the Managed Identity endpoint.
323-
if (!string.IsNullOrEmpty(parameters.Claims))
322+
// Only include 'token_sha256_to_refresh' if we have both Claims and the old token's hash
323+
if (!string.IsNullOrEmpty(parameters.Claims) &&
324+
!string.IsNullOrEmpty(parameters.BadTokenHash))
324325
{
325-
if (!string.IsNullOrEmpty(parameters.BadTokenHash))
326-
{
327-
SetRequestParameter(request, "token_sha256_to_refresh", parameters.BadTokenHash);
328-
_requestContext.Logger.Info(
329-
"[Managed Identity] Passing SHA-256 of the 'bad' token to Managed Identity endpoint."
330-
);
331-
}
326+
SetRequestParameter(request, "token_sha256_to_refresh", parameters.BadTokenHash);
327+
_requestContext.Logger.Info(
328+
"[Managed Identity] Passing SHA-256 of the 'bad' token to Managed Identity endpoint."
329+
);
332330
}
333331
}
334332

tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,7 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(
427427
{
428428
MockHttpMessageHandler httpMessageHandler = new MockHttpMessageHandler();
429429
IDictionary<string, string> expectedQueryParams = new Dictionary<string, string>();
430+
IDictionary<string, string> notExpectedQueryParams = new Dictionary<string, string>();
430431
IDictionary<string, string> expectedRequestHeaders = new Dictionary<string, string>();
431432
IDictionary<string, string> expectedPostData = null; // Only used for Cloud Shell
432433

@@ -474,26 +475,35 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(
474475
break;
475476
}
476477

478+
var manager = new CommonCryptographyManager();
479+
var value = manager.CreateSha256Hash(TestConstants.ATSecret);
480+
477481
// If capabilityEnabled, add "xms_cc": "cp1"
478482
if (capabilityEnabled)
479483
{
480-
if (managedIdentitySourceType == ManagedIdentitySource.AppService
484+
if (managedIdentitySourceType == ManagedIdentitySource.AppService
481485
|| managedIdentitySourceType == ManagedIdentitySource.ServiceFabric)
482486
{
483487
expectedQueryParams.Add("xms_cc", "cp1,cp2");
484488
}
485489
}
490+
else
491+
{
492+
notExpectedQueryParams.Add("xms_cc", "cp1,cp2");
493+
}
486494

487495
if (claimsEnabled)
488496
{
489-
var manager = new CommonCryptographyManager();
490-
var value = manager.CreateSha256Hash(TestConstants.ATSecret);
491497
if (managedIdentitySourceType == ManagedIdentitySource.AppService
492498
|| managedIdentitySourceType == ManagedIdentitySource.ServiceFabric)
493499
{
494500
expectedQueryParams.Add("token_sha256_to_refresh", manager.CreateSha256Hash(TestConstants.ATSecret));
495501
}
496502
}
503+
else
504+
{
505+
notExpectedQueryParams.Add("token_sha256_to_refresh", manager.CreateSha256Hash(TestConstants.ATSecret));
506+
}
497507

498508
if (managedIdentitySourceType != ManagedIdentitySource.CloudShell)
499509
{

tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpMessageHandler.cs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ internal class MockHttpMessageHandler : HttpClientHandler
2525
public IDictionary<string, string> ExpectedRequestHeaders { get; set; }
2626
public IList<string> UnexpectedRequestHeaders { get; set; }
2727
public IDictionary<string, string> UnExpectedPostData { get; set; }
28+
public IDictionary<string, string> NotExpectedQueryParams { get; set; }
2829
public HttpMethod ExpectedMethod { get; set; }
2930

3031
public Exception ExceptionToThrow { get; set; }
@@ -65,7 +66,9 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
6566

6667
Assert.AreEqual(ExpectedMethod, request.Method);
6768

68-
ValidateQueryParams(uri);
69+
ValidateExpectedQueryParams(uri);
70+
71+
ValidateNotExpectedQueryParams(uri);
6972

7073
await ValidatePostDataAsync(request).ConfigureAwait(false);
7174

@@ -80,12 +83,12 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
8083
return ResponseMessage;
8184
}
8285

83-
private void ValidateQueryParams(Uri uri)
86+
private void ValidateExpectedQueryParams(Uri uri)
8487
{
8588
if (ExpectedQueryParams != null && ExpectedQueryParams.Any())
8689
{
8790
Assert.IsFalse(string.IsNullOrEmpty(uri.Query), $"Provided url ({uri.AbsoluteUri}) does not contain query parameters as expected.");
88-
var inputQp = CoreHelpers.ParseKeyValueList(uri.Query.Substring(1), '&', false, null);
91+
Dictionary<string, string> inputQp = CoreHelpers.ParseKeyValueList(uri.Query.Substring(1), '&', false, null);
8992
Assert.AreEqual(ExpectedQueryParams.Count, inputQp.Count, "Different number of query params.");
9093
foreach (var key in ExpectedQueryParams.Keys)
9194
{
@@ -95,6 +98,35 @@ private void ValidateQueryParams(Uri uri)
9598
}
9699
}
97100

101+
private void ValidateNotExpectedQueryParams(Uri uri)
102+
{
103+
if (NotExpectedQueryParams != null && NotExpectedQueryParams.Any())
104+
{
105+
// Parse actual query params again (or reuse inputQp if you like)
106+
Dictionary<string, string> actualQueryParams = CoreHelpers.ParseKeyValueList(uri.Query.Substring(1), '&', false, null);
107+
List<string> unexpectedKeysFound = new List<string>();
108+
109+
foreach (KeyValuePair<string, string> kvp in NotExpectedQueryParams)
110+
{
111+
// Check if the request's query has this key
112+
if (actualQueryParams.TryGetValue(kvp.Key, out string value))
113+
{
114+
// Optionally, also check if we care about matching the *value*:
115+
if (string.Equals(value, kvp.Value, StringComparison.OrdinalIgnoreCase))
116+
{
117+
unexpectedKeysFound.Add(kvp.Key);
118+
}
119+
}
120+
}
121+
122+
// Fail if any "not expected" key/value pairs were found
123+
Assert.IsTrue(
124+
unexpectedKeysFound.Count == 0,
125+
$"Did not expect to find these query parameter keys/values: {string.Join(", ", unexpectedKeysFound)}"
126+
);
127+
}
128+
}
129+
98130
private async Task ValidatePostDataAsync(HttpRequestMessage request)
99131
{
100132
if (request.Method != HttpMethod.Get && request.Content != null)

0 commit comments

Comments
 (0)