Skip to content

Commit 4ac538d

Browse files
initial
1 parent dbbda87 commit 4ac538d

File tree

11 files changed

+134
-18
lines changed

11 files changed

+134
-18
lines changed

src/client/Microsoft.Identity.Client/ApiConfig/Parameters/AcquireTokenForManagedIdentityParameters.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ internal class AcquireTokenForManagedIdentityParameters : IAcquireTokenParameter
1616

1717
public string Resource { get; set; }
1818

19+
public string Claims { get; set; }
20+
1921
public void LogParameters(ILoggerAdapter logger)
2022
{
2123
if (logger.IsLoggingEnabled(LogLevel.Info))

src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ protected override async Task<AuthenticationResult> ExecuteAsync(CancellationTok
3636
// Skip checking cache when force refresh or claims is specified
3737
if (_managedIdentityParameters.ForceRefresh || !string.IsNullOrEmpty(AuthenticationRequestParameters.Claims))
3838
{
39+
_managedIdentityParameters.Claims = AuthenticationRequestParameters.Claims;
3940
AuthenticationRequestParameters.RequestContext.ApiEvent.CacheInfo = CacheRefreshReason.ForceRefreshOrClaims;
4041

4142
logger.Info("[ManagedIdentityRequest] Skipped looking for a cached access token because ForceRefresh or Claims were set. " +

src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
using System.Net;
1313
using Microsoft.Identity.Client.ApiConfig.Parameters;
1414
using System.Text;
15+
using System.Collections.Generic;
16+
using System.Linq;
1517
#if SUPPORTS_SYSTEM_TEXT_JSON
1618
using System.Text.Json;
1719
#else
@@ -48,7 +50,7 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
4850
// Convert the scopes to a resource string.
4951
string resource = parameters.Resource;
5052

51-
ManagedIdentityRequest request = CreateRequest(resource);
53+
ManagedIdentityRequest request = CreateRequest(resource, parameters);
5254

5355
_requestContext.Logger.Info("[Managed Identity] Sending request to managed identity endpoints.");
5456

@@ -125,7 +127,7 @@ protected virtual Task<ManagedIdentityResponse> HandleResponseAsync(
125127
throw exception;
126128
}
127129

128-
protected abstract ManagedIdentityRequest CreateRequest(string resource);
130+
protected abstract ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters);
129131

130132
protected ManagedIdentityResponse GetSuccessfulResponse(HttpResponse response)
131133
{
@@ -293,5 +295,54 @@ private static void CreateAndThrowException(string errorCode,
293295

294296
throw exception;
295297
}
298+
299+
/// <summary>
300+
/// Sets the claims and capabilities in the request.
301+
/// </summary>
302+
/// <param name="request"></param>
303+
/// <param name="parameters"></param>
304+
protected virtual void ApplyClaimsAndCapabilities(
305+
ManagedIdentityRequest request,
306+
AcquireTokenForManagedIdentityParameters parameters)
307+
{
308+
IEnumerable<string> clientCapabilities = _requestContext.ServiceBundle.Config.ClientCapabilities;
309+
310+
// If claims are present, set bypass_cache=true
311+
if (!string.IsNullOrEmpty(parameters.Claims))
312+
{
313+
SetRequestParameter(request, "bypass_cache", "true");
314+
_requestContext.Logger.Info("[Managed Identity] Setting bypass_cache=true in the Managed Identity request due to claims.");
315+
316+
// Set xms_cc only if clientCapabilities exist
317+
if (clientCapabilities != null && clientCapabilities.Any())
318+
{
319+
SetRequestParameter(request, "xms_cc", string.Join(",", clientCapabilities));
320+
_requestContext.Logger.Info("[Managed Identity] Adding client capabilities (xms_cc) to Managed Identity request.");
321+
}
322+
}
323+
else
324+
{
325+
SetRequestParameter(request, "bypass_cache", "false");
326+
_requestContext.Logger.Info("[Managed Identity] Setting bypass_cache=false (no claims provided).");
327+
}
328+
}
329+
330+
/// <summary>
331+
/// Sets the request parameter in either the query or body based on the request method.
332+
/// </summary>
333+
/// <param name="request"></param>
334+
/// <param name="key"></param>
335+
/// <param name="value"></param>
336+
protected void SetRequestParameter(ManagedIdentityRequest request, string key, string value)
337+
{
338+
if (request.Method == HttpMethod.Post)
339+
{
340+
request.BodyParameters[key] = value;
341+
}
342+
else
343+
{
344+
request.QueryParameters[key] = value;
345+
}
346+
}
296347
}
297348
}

src/client/Microsoft.Identity.Client/ManagedIdentity/AppServiceManagedIdentitySource.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
using System;
55
using System.Collections.Generic;
66
using System.Globalization;
7+
using System.Linq;
8+
using Microsoft.Identity.Client.ApiConfig.Parameters;
79
using Microsoft.Identity.Client.Core;
810
using Microsoft.Identity.Client.Internal;
911
using Microsoft.Identity.Client.Utils;
@@ -65,14 +67,16 @@ private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger
6567
return true;
6668
}
6769

68-
protected override ManagedIdentityRequest CreateRequest(string resource)
70+
protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters)
6971
{
7072
ManagedIdentityRequest request = new(System.Net.Http.HttpMethod.Get, _endpoint);
7173

7274
request.Headers.Add(SecretHeaderName, _secret);
7375
request.QueryParameters["api-version"] = AppServiceMsiApiVersion;
7476
request.QueryParameters["resource"] = resource;
7577

78+
ApplyClaimsAndCapabilities(request, parameters);
79+
7680
switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType)
7781
{
7882
case AppConfig.ManagedIdentityIdType.ClientId:

src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,16 @@ private AzureArcManagedIdentitySource(Uri endpoint, RequestContext requestContex
8181
}
8282
}
8383

84-
protected override ManagedIdentityRequest CreateRequest(string resource)
84+
protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters)
8585
{
8686
ManagedIdentityRequest request = new ManagedIdentityRequest(System.Net.Http.HttpMethod.Get, _endpoint);
8787

8888
request.Headers.Add("Metadata", "true");
8989
request.QueryParameters["api-version"] = ArcApiVersion;
9090
request.QueryParameters["resource"] = resource;
91+
request.QueryParameters["bypass_cache"] = "false";
92+
93+
ApplyClaimsAndCapabilities(request, parameters);
9194

9295
return request;
9396
}
@@ -121,7 +124,7 @@ protected override async Task<ManagedIdentityResponse> HandleResponseAsync(
121124

122125
var authHeaderValue = "Basic " + File.ReadAllText(splitChallenge[1]);
123126

124-
ManagedIdentityRequest request = CreateRequest(parameters.Resource);
127+
ManagedIdentityRequest request = CreateRequest(parameters.Resource, parameters);
125128

126129
_requestContext.Logger.Verbose(() => "[Managed Identity] Adding authorization header to the request.");
127130
request.Headers.Add("Authorization", authHeaderValue);

src/client/Microsoft.Identity.Client/ManagedIdentity/CloudShellManagedIdentitySource.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System;
55
using System.Globalization;
66
using System.Net.Http;
7+
using Microsoft.Identity.Client.ApiConfig.Parameters;
78
using Microsoft.Identity.Client.Core;
89
using Microsoft.Identity.Client.Internal;
910

@@ -73,14 +74,16 @@ private CloudShellManagedIdentitySource(Uri endpoint, RequestContext requestCont
7374
}
7475
}
7576

76-
protected override ManagedIdentityRequest CreateRequest(string resource)
77+
protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters)
7778
{
7879
ManagedIdentityRequest request = new ManagedIdentityRequest(HttpMethod.Post, _endpoint);
7980

8081
request.Headers.Add("ContentType", "application/x-www-form-urlencoded");
8182
request.Headers.Add("Metadata", "true");
8283

8384
request.BodyParameters.Add("resource", resource);
85+
86+
ApplyClaimsAndCapabilities(request, parameters);
8487

8588
return request;
8689
}

src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsManagedIdentitySource.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,16 @@ internal ImdsManagedIdentitySource(RequestContext requestContext) :
5656
requestContext.Logger.Verbose(() => "[Managed Identity] Creating IMDS managed identity source. Endpoint URI: " + _imdsEndpoint);
5757
}
5858

59-
protected override ManagedIdentityRequest CreateRequest(string resource)
59+
protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters)
6060
{
6161
ManagedIdentityRequest request = new(HttpMethod.Get, _imdsEndpoint);
6262

6363
request.Headers.Add("Metadata", "true");
6464
request.QueryParameters["api-version"] = ImdsApiVersion;
6565
request.QueryParameters["resource"] = resource;
6666

67+
ApplyClaimsAndCapabilities(request, parameters);
68+
6769
switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType)
6870
{
6971
case AppConfig.ManagedIdentityIdType.ClientId:

src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System;
55
using System.Globalization;
6+
using Microsoft.Identity.Client.ApiConfig.Parameters;
67
using Microsoft.Identity.Client.Core;
78
using Microsoft.Identity.Client.Internal;
89

@@ -62,7 +63,7 @@ private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger
6263
return true;
6364
}
6465

65-
protected override ManagedIdentityRequest CreateRequest(string resource)
66+
protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters)
6667
{
6768
ManagedIdentityRequest request = new(System.Net.Http.HttpMethod.Get, _endpoint);
6869

@@ -71,6 +72,8 @@ protected override ManagedIdentityRequest CreateRequest(string resource)
7172
request.QueryParameters["api-version"] = MachineLearningMsiApiVersion;
7273
request.QueryParameters["resource"] = resource;
7374

75+
ApplyClaimsAndCapabilities(request, parameters);
76+
7477
switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType)
7578
{
7679
case AppConfig.ManagedIdentityIdType.ClientId:

src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System;
55
using System.Globalization;
66
using System.Net.Http;
7+
using Microsoft.Identity.Client.ApiConfig.Parameters;
78
using Microsoft.Identity.Client.Core;
89
using Microsoft.Identity.Client.Internal;
910

@@ -77,7 +78,6 @@ internal HttpClientHandler CreateHandlerWithSslValidation(ILoggerAdapter logger)
7778
#endif
7879
}
7980

80-
8181
private ServiceFabricManagedIdentitySource(RequestContext requestContext, Uri endpoint, string identityHeaderValue) :
8282
base(requestContext, ManagedIdentitySource.ServiceFabric)
8383
{
@@ -90,7 +90,7 @@ private ServiceFabricManagedIdentitySource(RequestContext requestContext, Uri en
9090
}
9191
}
9292

93-
protected override ManagedIdentityRequest CreateRequest(string resource)
93+
protected override ManagedIdentityRequest CreateRequest(string resource, AcquireTokenForManagedIdentityParameters parameters)
9494
{
9595
ManagedIdentityRequest request = new ManagedIdentityRequest(HttpMethod.Get, _endpoint);
9696

@@ -99,6 +99,8 @@ protected override ManagedIdentityRequest CreateRequest(string resource)
9999
request.QueryParameters["api-version"] = ServiceFabricMsiApiVersion;
100100
request.QueryParameters["resource"] = resource;
101101

102+
ApplyClaimsAndCapabilities(request, parameters);
103+
102104
switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType)
103105
{
104106
case AppConfig.ManagedIdentityIdType.ClientId:

tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -367,14 +367,21 @@ public static void AddManagedIdentityMockHandler(
367367
ManagedIdentitySource managedIdentitySourceType,
368368
string userAssignedId = null,
369369
UserAssignedIdentityId userAssignedIdentityId = UserAssignedIdentityId.None,
370-
HttpStatusCode statusCode = HttpStatusCode.OK
370+
HttpStatusCode statusCode = HttpStatusCode.OK,
371+
bool capabilityEnabled = false,
372+
bool claimsEnabled = false
371373
)
372374
{
373375
HttpResponseMessage responseMessage = new HttpResponseMessage(statusCode);
374376
HttpContent content = new StringContent(response);
375377
responseMessage.Content = content;
376378

377-
MockHttpMessageHandler httpMessageHandler = BuildMockHandlerForManagedIdentitySource(managedIdentitySourceType, resource);
379+
MockHttpMessageHandler httpMessageHandler = BuildMockHandlerForManagedIdentitySource(
380+
managedIdentitySourceType,
381+
resource,
382+
capabilityEnabled,
383+
claimsEnabled
384+
);
378385

379386
if (userAssignedIdentityId == UserAssignedIdentityId.ClientId)
380387
{
@@ -396,12 +403,19 @@ public static void AddManagedIdentityMockHandler(
396403

397404
httpManager.AddMockHandler(httpMessageHandler);
398405
}
399-
400-
private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(ManagedIdentitySource managedIdentitySourceType, string resource)
406+
407+
private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(
408+
ManagedIdentitySource managedIdentitySourceType,
409+
string resource,
410+
bool capabilityEnabled = false,
411+
bool claimsEnabled = false)
401412
{
402413
MockHttpMessageHandler httpMessageHandler = new MockHttpMessageHandler();
403414
IDictionary<string, string> expectedQueryParams = new Dictionary<string, string>();
404415
IDictionary<string, string> expectedRequestHeaders = new Dictionary<string, string>();
416+
IDictionary<string, string> expectedPostData = null; // Only used for Cloud Shell
417+
418+
string bypassCacheValue = claimsEnabled ? "true" : "false"; // Set based on claimsEnabled
405419

406420
switch (managedIdentitySourceType)
407421
{
@@ -410,45 +424,73 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(M
410424
expectedQueryParams.Add("api-version", "2019-08-01");
411425
expectedQueryParams.Add("resource", resource);
412426
expectedRequestHeaders.Add("X-IDENTITY-HEADER", "secret");
427+
expectedQueryParams.Add("bypass_cache", bypassCacheValue);
413428
break;
414429
case ManagedIdentitySource.AzureArc:
415430
httpMessageHandler.ExpectedMethod = HttpMethod.Get;
416431
expectedQueryParams.Add("api-version", "2019-11-01");
417432
expectedQueryParams.Add("resource", resource);
418433
expectedRequestHeaders.Add("Metadata", "true");
434+
expectedQueryParams.Add("bypass_cache", bypassCacheValue);
419435
break;
420436
case ManagedIdentitySource.Imds:
421437
httpMessageHandler.ExpectedMethod = HttpMethod.Get;
422438
expectedQueryParams.Add("api-version", "2018-02-01");
423439
expectedQueryParams.Add("resource", resource);
424440
expectedRequestHeaders.Add("Metadata", "true");
441+
expectedQueryParams.Add("bypass_cache", bypassCacheValue);
425442
break;
426443
case ManagedIdentitySource.CloudShell:
427444
httpMessageHandler.ExpectedMethod = HttpMethod.Post;
428445
expectedRequestHeaders.Add("Metadata", "true");
429446
expectedRequestHeaders.Add("ContentType", "application/x-www-form-urlencoded");
430-
httpMessageHandler.ExpectedPostData = new Dictionary<string, string> { { "resource", resource } };
447+
448+
expectedPostData = new Dictionary<string, string>
449+
{
450+
{ "resource", resource },
451+
{ "bypass_cache", bypassCacheValue }
452+
};
431453
break;
432454
case ManagedIdentitySource.ServiceFabric:
433455
httpMessageHandler.ExpectedMethod = HttpMethod.Get;
434456
expectedRequestHeaders.Add("secret", "secret");
435457
expectedQueryParams.Add("api-version", "2019-07-01-preview");
436458
expectedQueryParams.Add("resource", resource);
459+
expectedQueryParams.Add("bypass_cache", bypassCacheValue);
437460
break;
438461
case ManagedIdentitySource.MachineLearning:
439462
httpMessageHandler.ExpectedMethod = HttpMethod.Get;
440463
expectedRequestHeaders.Add("secret", "secret");
441464
expectedRequestHeaders.Add("Metadata", "true");
442465
expectedQueryParams.Add("api-version", "2017-09-01");
443466
expectedQueryParams.Add("resource", resource);
467+
expectedQueryParams.Add("bypass_cache", bypassCacheValue);
444468
break;
445469
}
446470

471+
// If capabilityEnabled, add "xms_cc": "cp1"
472+
if (capabilityEnabled)
473+
{
474+
if (managedIdentitySourceType == ManagedIdentitySource.CloudShell)
475+
{
476+
expectedPostData ??= new Dictionary<string, string>();
477+
expectedPostData.Add("xms_cc", "cp1,cp2");
478+
}
479+
else
480+
{
481+
expectedQueryParams.Add("xms_cc", "cp1,cp2");
482+
}
483+
}
484+
447485
if (managedIdentitySourceType != ManagedIdentitySource.CloudShell)
448486
{
449487
httpMessageHandler.ExpectedQueryParams = expectedQueryParams;
450488
}
451-
489+
else
490+
{
491+
httpMessageHandler.ExpectedPostData = expectedPostData;
492+
}
493+
452494
httpMessageHandler.ExpectedRequestHeaders = expectedRequestHeaders;
453495

454496
return httpMessageHandler;

0 commit comments

Comments
 (0)