Skip to content

CSR Metadata request acts as a probe for ImdsV2 #5359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f6f627f
Implemented ImdsV2 class and it's probe
Robbie-Microsoft Jun 20, 2025
cb2c763
Implemented unit tests
Robbie-Microsoft Jun 23, 2025
e96c630
added try/catch to network request
Robbie-Microsoft Jun 24, 2025
6ffafec
Merge branch 'main' into rginsburg/msiv2_probe
Robbie-Microsoft Jul 1, 2025
f28b71f
Merged from main
Robbie-Microsoft Jul 1, 2025
b9f12b1
Merge branch 'rginsburg/msiv2_feature_branch' into rginsburg/msiv2_probe
Robbie-Microsoft Jul 10, 2025
6aa43d0
Refactored ManagedIdentityClient
Robbie-Microsoft Jul 11, 2025
c2a8282
Renamed variable since it's no longer static
Robbie-Microsoft Jul 11, 2025
da727eb
ManagedIdentityClient's identitySource and sourceName are now static,…
Robbie-Microsoft Jul 11, 2025
3588084
Fixed many broken unit tests
Robbie-Microsoft Jul 12, 2025
ba43757
ManagedIdentityClient is now per ManagedIdentityApplication instead o…
Robbie-Microsoft Jul 21, 2025
78730ba
Merge branch 'rginsburg/msiv2_feature_branch' into rginsburg/msiv2_probe
Robbie-Microsoft Jul 21, 2025
9706db6
Merge branch 'rginsburg/msiv2_feature_branch' into rginsburg/msiv2_probe
Robbie-Microsoft Jul 23, 2025
d894554
Implemented some GitHub feedback
Robbie-Microsoft Jul 24, 2025
71a654e
Refactored GetCsrMetadataAsync method into smaller methods
Robbie-Microsoft Jul 24, 2025
7450081
Implemented some GitHub feedback and fixed rebase errors
Robbie-Microsoft Jul 28, 2025
8ddfe93
Small updates and fixed some unit tests
Robbie-Microsoft Jul 28, 2025
3783f3d
undo changes to global.json
Robbie-Microsoft Jul 29, 2025
f986c51
Refactor. All unit tests pass now.
Robbie-Microsoft Jul 30, 2025
c7af31c
Implemented some feedback + improvements
Robbie-Microsoft Jul 30, 2025
1d3d3db
Implemented some feedback
Robbie-Microsoft Jul 31, 2025
80215a3
CsrMetadata follows new spec now
Robbie-Microsoft Jul 31, 2025
3a1dfe2
IMDSv2: Created a retry policy for the CSR Metadata Probe (#5419)
Robbie-Microsoft Aug 1, 2025
24bccf7
removed unused imports in test file
Robbie-Microsoft Aug 1, 2025
8788b51
Fixed minor spec issue
Robbie-Microsoft Aug 1, 2025
f0589c1
Implemented some feedback
Robbie-Microsoft Aug 1, 2025
cb2647e
refactored Imds UAMI id switch statement
Robbie-Microsoft Aug 1, 2025
1398d92
Implemented feedback
Robbie-Microsoft Aug 1, 2025
961d787
undid changes to global.json
Robbie-Microsoft Aug 1, 2025
56742cd
Implemented final feedback
Robbie-Microsoft Aug 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ public async Task<AuthenticationResult> ExecuteAsync(
var handler = new ManagedIdentityAuthRequest(
ServiceBundle,
requestParams,
managedIdentityParameters);
managedIdentityParameters,
_managedIdentityApplication.ManagedIdentityClient);

return await handler.RunAsync(cancellationToken).ConfigureAwait(false);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;

namespace Microsoft.Identity.Client.Http.Retry
{
internal class CsrMetadataProbeRetryPolicy : ImdsRetryPolicy
{
protected override bool ShouldRetry(HttpResponse response, Exception exception)
{
return HttpRetryConditions.CsrMetadataProbe(response, exception);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,21 @@ public static bool RegionDiscovery(HttpResponse response, Exception exception)
return (int)response.StatusCode is not (404 or 408);
}

/// <summary>
/// Retry policy specific to CSR Metadata Probe.
/// Extends Imds retry policy but excludes 404 and 408 status codes.
/// </summary>
public static bool CsrMetadataProbe(HttpResponse response, Exception exception)
{
if (!Imds(response, exception))
{
return false;
}

// If Imds would retry but the status code is 404, don't retry
return (int)response.StatusCode is not 404;
}

/// <summary>
/// Retry condition for /token and /authorize endpoints
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ internal virtual Task DelayAsync(int milliseconds)
return Task.Delay(milliseconds);
}

protected virtual bool ShouldRetry(HttpResponse response, Exception exception)
{
return HttpRetryConditions.Imds(response, exception);
}

public async Task<bool> PauseForRetryAsync(HttpResponse response, Exception exception, int retryCount, ILoggerAdapter logger)
{
int httpStatusCode = (int)response.StatusCode;
Expand All @@ -46,7 +51,7 @@ public async Task<bool> PauseForRetryAsync(HttpResponse response, Exception exce
}

// Check if the status code is retriable and if the current retry count is less than max retries
if (HttpRetryConditions.Imds(response, exception) &&
if (ShouldRetry(response, exception) &&
retryCount < _maxRetries)
{
int retryAfterDelay = httpStatusCode == (int)HttpStatusCode.Gone
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ public virtual IRetryPolicy GetRetryPolicy(RequestType requestType)
return new ImdsRetryPolicy();
case RequestType.RegionDiscovery:
return new RegionDiscoveryRetryPolicy();
case RequestType.CsrMetadataProbe:
return new CsrMetadataProbeRetryPolicy();
default:
throw new ArgumentOutOfRangeException(nameof(requestType), requestType, "Unknown request type.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Threading.Tasks;
using Microsoft.Identity.Client.ManagedIdentity;

namespace Microsoft.Identity.Client
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@ namespace Microsoft.Identity.Client.Internal.Requests
internal class ManagedIdentityAuthRequest : RequestBase
{
private readonly AcquireTokenForManagedIdentityParameters _managedIdentityParameters;
private readonly ManagedIdentityClient _managedIdentityClient;
private static readonly SemaphoreSlim s_semaphoreSlim = new SemaphoreSlim(1, 1);

public ManagedIdentityAuthRequest(
IServiceBundle serviceBundle,
AuthenticationRequestParameters authenticationRequestParameters,
AcquireTokenForManagedIdentityParameters managedIdentityParameters)
AcquireTokenForManagedIdentityParameters managedIdentityParameters,
ManagedIdentityClient managedIdentityClient)
: base(serviceBundle, authenticationRequestParameters, managedIdentityParameters)
{
_managedIdentityParameters = managedIdentityParameters;
_managedIdentityClient = managedIdentityClient;
}

protected override async Task<AuthenticationResult> ExecuteAsync(CancellationToken cancellationToken)
Expand Down Expand Up @@ -152,12 +155,9 @@ private async Task<AuthenticationResult> SendTokenRequestForManagedIdentityAsync

await ResolveAuthorityAsync().ConfigureAwait(false);

ManagedIdentityClient managedIdentityClient =
new ManagedIdentityClient(AuthenticationRequestParameters.RequestContext);

ManagedIdentityResponse managedIdentityResponse =
await managedIdentityClient
.SendTokenRequestForManagedIdentityAsync(_managedIdentityParameters, cancellationToken)
await _managedIdentityClient
.SendTokenRequestForManagedIdentityAsync(AuthenticationRequestParameters.RequestContext, _managedIdentityParameters, cancellationToken)
.ConfigureAwait(false);

var msalTokenResponse = MsalTokenResponse.CreateFromManagedIdentityResponse(managedIdentityResponse);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if SUPPORTS_SYSTEM_TEXT_JSON
using Microsoft.Identity.Client.Platforms.net;
using JsonProperty = System.Text.Json.Serialization.JsonPropertyNameAttribute;
#else
using Microsoft.Identity.Json;
#endif

namespace Microsoft.Identity.Client.ManagedIdentity
{
/// <summary>
/// Represents metadata required for Certificate Signing Request (CSR) operations.
/// </summary>
internal class CsrMetadata
{
/// <summary>
/// VM unique Id
/// </summary>
[JsonProperty("cuid")]
public CuidInfo Cuid { get; set; }

/// <summary>
/// client_id of the Managed Identity
/// </summary>
[JsonProperty("clientId")]
public string ClientId { get; set; }

/// <summary>
/// AAD Tenant of the Managed Identity
/// </summary>
[JsonProperty("tenantId")]
public string TenantId { get; set; }

/// <summary>
/// MAA Regional / Custom Endpoint for attestation purposes.
/// </summary>
[JsonProperty("attestationEndpoint")]
public string AttestationEndpoint { get; set; }

// Parameterless constructor for deserialization
public CsrMetadata() { }

/// <summary>
/// Validates a JSON decoded CsrMetadata instance.
/// </summary>
/// <param name="csrMetadata">The CsrMetadata object.</param>
/// <returns>false if any field is null.</returns>
public static bool ValidateCsrMetadata(CsrMetadata csrMetadata)
{
if (csrMetadata == null ||
csrMetadata.Cuid == null ||
string.IsNullOrEmpty(csrMetadata.Cuid.Vmid) ||
string.IsNullOrEmpty(csrMetadata.Cuid.Vmssid) ||
string.IsNullOrEmpty(csrMetadata.ClientId) ||
string.IsNullOrEmpty(csrMetadata.TenantId) ||
string.IsNullOrEmpty(csrMetadata.AttestationEndpoint))
{
return false;
}

return true;
}
}

/// <summary>
/// Represents VM unique Ids for CSR metadata.
/// </summary>
internal class CuidInfo
{
[JsonProperty("vmid")]
public string Vmid { get; set; }

[JsonProperty("vmssid")]
public string Vmssid { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ namespace Microsoft.Identity.Client.ManagedIdentity
internal class ImdsManagedIdentitySource : AbstractManagedIdentity
{
// IMDS constants. Docs for IMDS are available here https://docs.microsoft.com/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http
private static readonly Uri s_imdsEndpoint = new("http://169.254.169.254/metadata/identity/oauth2/token");

private const string DefaultImdsBaseEndpoint= "http://169.254.169.254";
private const string ImdsTokenPath = "/metadata/identity/oauth2/token";
private const string ImdsApiVersion = "2018-02-01";
public const string ImdsApiVersion = "2018-02-01";

private const string DefaultMessage = "[Managed Identity] Service request failed.";

internal const string IdentityUnavailableError = "[Managed Identity] Authentication unavailable. " +
Expand All @@ -37,20 +37,7 @@ internal ImdsManagedIdentitySource(RequestContext requestContext) :
{
requestContext.Logger.Info(() => "[Managed Identity] Defaulting to IMDS endpoint for managed identity.");

if (!string.IsNullOrEmpty(EnvironmentVariables.PodIdentityEndpoint))
{
requestContext.Logger.Verbose(() => "[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: " + EnvironmentVariables.PodIdentityEndpoint);
var builder = new UriBuilder(EnvironmentVariables.PodIdentityEndpoint)
{
Path = ImdsTokenPath
};
_imdsEndpoint = builder.Uri;
}
else
{
requestContext.Logger.Verbose(() => "[Managed Identity] Unable to find AZURE_POD_IDENTITY_AUTHORITY_HOST environment variable for IMDS, using the default endpoint.");
_imdsEndpoint = s_imdsEndpoint;
}
_imdsEndpoint = GetValidatedEndpoint(requestContext.Logger, ImdsTokenPath);

requestContext.Logger.Verbose(() => "[Managed Identity] Creating IMDS managed identity source. Endpoint URI: " + _imdsEndpoint);
}
Expand Down Expand Up @@ -152,5 +139,38 @@ internal static string CreateRequestFailedMessage(HttpResponse response, string

return messageBuilder.ToString();
}

public static Uri GetValidatedEndpoint(
ILoggerAdapter logger,
string subPath,
string queryParams = null
)
{
UriBuilder builder;

if (!string.IsNullOrEmpty(EnvironmentVariables.PodIdentityEndpoint))
{
logger.Verbose(() => "[Managed Identity] Environment variable AZURE_POD_IDENTITY_AUTHORITY_HOST for IMDS returned endpoint: " + EnvironmentVariables.PodIdentityEndpoint);
builder = new UriBuilder(EnvironmentVariables.PodIdentityEndpoint)
{
Path = subPath
};
}
else
{
logger.Verbose(() => "[Managed Identity] Unable to find AZURE_POD_IDENTITY_AUTHORITY_HOST environment variable for IMDS, using the default endpoint.");
builder = new UriBuilder(DefaultImdsBaseEndpoint)
{
Path = subPath
};
}

if (!string.IsNullOrEmpty(queryParams))
{
builder.Query = queryParams;
}

return builder.Uri;
}
}
}
Loading