Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.MtlsPop.Attestation;
using Microsoft.Win32.SafeHandles;
using Microsoft.Identity.Client.PlatformsCommon.Shared;

namespace Microsoft.Identity.Client.MtlsPop
{
Expand All @@ -24,6 +20,22 @@ public static class ManagedIdentityPopExtensions
public static AcquireTokenForManagedIdentityParameterBuilder WithMtlsProofOfPossession(
this AcquireTokenForManagedIdentityParameterBuilder builder)
{
void MtlsNotSupportedForManagedIdentity(string message)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic was previously in the wrong file

{
throw new MsalClientException(
MsalError.MtlsNotSupportedForManagedIdentity,
message);
}

if (!DesktopOsHelper.IsWindows())
{
MtlsNotSupportedForManagedIdentity(MsalErrorMessage.MtlsNotSupportedForNonWindowsMessage);
}

#if NET462
MtlsNotSupportedForManagedIdentity(MsalErrorMessage.MtlsNotSupportedForManagedIdentityMessage);
#endif

builder.CommonParameters.IsMtlsPopRequested = true;
AddRuntimeSupport(builder);
return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,25 +96,6 @@ public AcquireTokenForClientParameterBuilder WithSendX5C(bool withSendX5C)
/// <returns>The current instance of <see cref="AcquireTokenForClientParameterBuilder"/> to enable method chaining.</returns>
public AcquireTokenForClientParameterBuilder WithMtlsProofOfPossession()
{
if (ServiceBundle.Config.IsManagedIdentity)
{
void MtlsNotSupportedForManagedIdentity(string message)
{
throw new MsalClientException(
MsalError.MtlsNotSupportedForManagedIdentity,
message);
}

if (!DesktopOsHelper.IsWindows())
{
MtlsNotSupportedForManagedIdentity(MsalErrorMessage.MtlsNotSupportedForNonWindowsMessage);
}

#if NET462
MtlsNotSupportedForManagedIdentity(MsalErrorMessage.MtlsNotSupportedForManagedIdentityMessage);
#endif
}

if (ServiceBundle.Config.ClientCredential is CertificateClientCredential certificateCredential)
{
if (certificateCredential.Certificate == null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
// Licensed under the MIT License.

using System;
using System.IO;
using System.Threading.Tasks;
using System.Threading;
using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.ApiConfig.Parameters;
using Microsoft.Identity.Client.PlatformsCommon.Shared;
using System.IO;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.ManagedIdentity.V2;
using System.Security.Cryptography.X509Certificates;
Expand Down Expand Up @@ -37,18 +37,39 @@ internal async Task<ManagedIdentityResponse> SendTokenRequestForManagedIdentityA
AcquireTokenForManagedIdentityParameters parameters,
CancellationToken cancellationToken)
{
AbstractManagedIdentity msi = await GetOrSelectManagedIdentitySourceAsync(requestContext).ConfigureAwait(false);
AbstractManagedIdentity msi = await GetOrSelectManagedIdentitySourceAsync(requestContext, parameters.IsMtlsPopRequested).ConfigureAwait(false);
return await msi.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false);
}

// This method tries to create managed identity source for different sources, if none is created then defaults to IMDS.
private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsync(RequestContext requestContext)
private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsync(RequestContext requestContext, bool isMtlsPopRequested)
{
using (requestContext.Logger.LogMethodDuration())
{
requestContext.Logger.Info($"[Managed Identity] Selecting managed identity source if not cached. Cached value is {s_sourceName} ");

var source = (s_sourceName != ManagedIdentitySource.None) ? s_sourceName : await GetManagedIdentitySourceAsync(requestContext).ConfigureAwait(false);
var source = ManagedIdentitySource.None;

// If the source is not already set, determine it
if (s_sourceName == ManagedIdentitySource.None)
{
source = await GetManagedIdentitySourceAsync(requestContext).ConfigureAwait(false);
}
// Otherwise, check if the source has already been set to ImdsV2 (via this method, or GetManagedIdentitySourceAsync in ManagedIdentityApplication.cs) and mTLS PoP was NOT requested
// In this case, we need to fall back to ImdsV1, because ImdsV2 currently only supports mTLS PoP requests
else if ((s_sourceName == ManagedIdentitySource.ImdsV2) && !isMtlsPopRequested)
{
requestContext.Logger.Info("[Managed Identity] ImdsV2 detected, but mTLS PoP was not requested. Falling back to ImdsV1 for this request only. Please use the \"WithMtlsProofOfPossession\" API to request a token via ImdsV2.");

// keep the cached source (s_sourceName) as ImdsV2, since the developer may decide to use mTLS PoP in subsequent requests

source = ManagedIdentitySource.DefaultToImds;
}
else
{
source = s_sourceName;
}

return source switch
{
ManagedIdentitySource.ServiceFabric => ServiceFabricManagedIdentitySource.Create(requestContext),
Expand All @@ -66,14 +87,17 @@ private async Task<AbstractManagedIdentity> GetOrSelectManagedIdentitySourceAsyn
// This method is perf sensitive any changes should be benchmarked.
internal async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync(RequestContext requestContext)
{
// First check env vars to avoid the probe if possible
ManagedIdentitySource source = GetManagedIdentitySourceNoImdsV2(requestContext.Logger);

// If a source is detected via env vars, use it
if (source != ManagedIdentitySource.DefaultToImds)
{
s_sourceName = source;
return source;
}

// probe IMDSv2
// Otherwise, probe IMDSv2
var response = await ImdsV2ManagedIdentitySource.GetCsrMetadataAsync(requestContext, probeMode: true).ConfigureAwait(false);
if (response != null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ private async Task<CertificateRequestResponse> ExecuteCertificateRequestAsync(
{ OAuth2Header.XMsCorrelationId, _requestContext.CorrelationId.ToString() }
};

if (_isMtlsPopRequested && managedIdentityKeyInfo.Type != ManagedIdentityKeyType.KeyGuard)
if (managedIdentityKeyInfo.Type != ManagedIdentityKeyType.KeyGuard)
{
throw new MsalClientException(
"mtls_pop_requires_keyguard",
Expand Down Expand Up @@ -316,12 +316,10 @@ protected override async Task<ManagedIdentityRequest> CreateRequestAsync(string
request.Headers.Add(ThrottleCommon.ThrottleRetryAfterHeaderName, ThrottleCommon.ThrottleRetryAfterHeaderValue);
request.Headers.Add(OAuth2Header.RequestCorrelationIdInResponse, "true");

var tokenType = _isMtlsPopRequested ? "mtls_pop" : "bearer";

request.BodyParameters.Add("client_id", certificateRequestResponse.ClientId);
request.BodyParameters.Add("grant_type", OAuth2GrantType.ClientCredentials);
request.BodyParameters.Add("scope", resource.TrimEnd('/') + "/.default");
request.BodyParameters.Add("token_type", tokenType);
request.BodyParameters.Add("token_type", "mtls_pop");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest we keep both bearer and mtls_pop for the time being, and we enable MSI v2 bearer based on a feature flag.


request.RequestType = RequestType.STS;

Expand Down
11 changes: 4 additions & 7 deletions tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ public static string GetBridgedHybridSpaTokenResponse(string spaAccountId)
public static string GetMsiSuccessfulResponse(
int expiresInHours = 1,
bool useIsoFormat = false,
bool mTLSPop = false,
bool imdsV2 = false)
{
var expiresOnKey = imdsV2 ? "expires_in" : "expires_on";
Expand All @@ -141,7 +140,7 @@ public static string GetMsiSuccessfulResponse(
expiresOnValue = DateTimeHelpers.DateTimeToUnixTimestamp(DateTime.UtcNow.AddHours(expiresInHours));
}

var tokenType = mTLSPop ? "mtls_pop" : "Bearer";
var tokenType = imdsV2 ? "mtls_pop" : "Bearer";

return
"{\"access_token\":\"" + TestConstants.ATSecret + "\",\"" + expiresOnKey + "\":\"" + expiresOnValue + "\",\"resource\":\"https://management.azure.com/\"," +
Expand Down Expand Up @@ -700,8 +699,7 @@ public static MockHttpMessageHandler MockCertificateRequestResponse(
}

public static MockHttpMessageHandler MockImdsV2EntraTokenRequestResponse(
IdentityLoggerAdapter identityLoggerAdapter,
bool mTLSPop = false)
IdentityLoggerAdapter identityLoggerAdapter)
{
IDictionary<string, string> expectedPostData = new Dictionary<string, string>();
IDictionary<string, string> expectedRequestHeaders = new Dictionary<string, string>
Expand All @@ -719,8 +717,7 @@ public static MockHttpMessageHandler MockImdsV2EntraTokenRequestResponse(
expectedRequestHeaders[idParam.Key] = idParam.Value;
}

var tokenType = mTLSPop ? "mtls_pop" : "bearer";
expectedPostData.Add("token_type", tokenType);
expectedPostData.Add("token_type", "mtls_pop");

var handler = new MockHttpMessageHandler()
{
Expand All @@ -731,7 +728,7 @@ public static MockHttpMessageHandler MockImdsV2EntraTokenRequestResponse(
PresentRequestHeaders = presentRequestHeaders,
ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StringContent(GetMsiSuccessfulResponse(mTLSPop: mTLSPop, imdsV2: true)),
Content = new StringContent(GetMsiSuccessfulResponse(imdsV2: true)),
}
};

Expand Down
Loading