Skip to content

Commit 6aa43d0

Browse files
Refactored ManagedIdentityClient
1 parent b9f12b1 commit 6aa43d0

25 files changed

+267
-43
lines changed

src/client/Microsoft.Identity.Client/IManagedIdentityApplication.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Collections.Generic;
66
using System.ComponentModel;
77
using System.Threading.Tasks;
8+
using Microsoft.Identity.Client.ManagedIdentity;
89

910
namespace Microsoft.Identity.Client
1011
{
@@ -27,5 +28,11 @@ public interface IManagedIdentityApplication : IApplicationBase
2728
/// <see cref="AcquireTokenForManagedIdentityParameterBuilder.WithForceRefresh(bool)"/>
2829
/// </remarks>
2930
AcquireTokenForManagedIdentityParameterBuilder AcquireTokenForManagedIdentity(string resource);
31+
32+
/// <summary>
33+
/// Detects and returns the managed identity source available on the environment.
34+
/// </summary>
35+
/// <returns>Managed identity source detected on the environment if any.</returns>
36+
Task<ManagedIdentitySource> GetManagedIdentitySourceAsync();
3037
}
3138
}

src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,11 @@ private async Task<AuthenticationResult> SendTokenRequestForManagedIdentityAsync
153153
await ResolveAuthorityAsync().ConfigureAwait(false);
154154

155155
ManagedIdentityClient managedIdentityClient =
156-
new ManagedIdentityClient(AuthenticationRequestParameters.RequestContext);
156+
new ManagedIdentityClient();
157157

158158
ManagedIdentityResponse managedIdentityResponse =
159159
await managedIdentityClient
160-
.SendTokenRequestForManagedIdentityAsync(_managedIdentityParameters, cancellationToken)
160+
.SendTokenRequestForManagedIdentityAsync(AuthenticationRequestParameters.RequestContext, _managedIdentityParameters, cancellationToken)
161161
.ConfigureAwait(false);
162162

163163
var msalTokenResponse = MsalTokenResponse.CreateFromManagedIdentityResponse(managedIdentityResponse);

src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs

Lines changed: 79 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,46 +19,102 @@ internal class ManagedIdentityClient
1919
{
2020
private const string WindowsHimdsFilePath = "%Programfiles%\\AzureConnectedMachineAgent\\himds.exe";
2121
private const string LinuxHimdsFilePath = "/opt/azcmagent/bin/himds";
22-
private static RequestContext _requestContext;
23-
private static AbstractManagedIdentity _identitySource;
22+
private AbstractManagedIdentity s_identitySource;
23+
public static ManagedIdentitySource s_sourceName = ManagedIdentitySource.None;
2424

25-
public ManagedIdentityClient(RequestContext requestContext)
25+
internal async Task<ManagedIdentityResponse> SendTokenRequestForManagedIdentityAsync(
26+
RequestContext requestContext,
27+
AcquireTokenForManagedIdentityParameters parameters,
28+
CancellationToken cancellationToken)
2629
{
27-
_requestContext = requestContext;
28-
}
29-
30-
internal async Task<ManagedIdentityResponse> SendTokenRequestForManagedIdentityAsync(AcquireTokenForManagedIdentityParameters parameters, CancellationToken cancellationToken)
31-
{
32-
if (_identitySource == null)
30+
if (s_identitySource == null)
3331
{
34-
using (_requestContext.Logger.LogMethodDuration())
32+
using (requestContext.Logger.LogMethodDuration())
3533
{
36-
_identitySource = await SelectManagedIdentitySourceAsync().ConfigureAwait(false);
34+
s_identitySource = await SelectManagedIdentitySourceAsync(requestContext).ConfigureAwait(false);
3735
}
3836
}
3937

40-
return await _identitySource.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false);
38+
return await s_identitySource.AuthenticateAsync(parameters, cancellationToken).ConfigureAwait(false);
4139
}
4240

4341
// This method tries to create managed identity source for different sources, if none is created then defaults to IMDS.
44-
private static async Task<AbstractManagedIdentity> SelectManagedIdentitySourceAsync()
42+
private static async Task<AbstractManagedIdentity> SelectManagedIdentitySourceAsync(RequestContext requestContext)
4543
{
46-
return await GetManagedIdentitySourceAsync(_requestContext.Logger).ConfigureAwait(false) switch
47-
{
48-
ManagedIdentitySource.ServiceFabric => ServiceFabricManagedIdentitySource.Create(_requestContext),
49-
ManagedIdentitySource.AppService => AppServiceManagedIdentitySource.Create(_requestContext),
50-
ManagedIdentitySource.MachineLearning => MachineLearningManagedIdentitySource.Create(_requestContext),
51-
ManagedIdentitySource.CloudShell => CloudShellManagedIdentitySource.Create(_requestContext),
52-
ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(_requestContext),
53-
ManagedIdentitySource.ImdsV2 => ImdsV2ManagedIdentitySource.Create(_requestContext),
54-
_ => new ImdsManagedIdentitySource(_requestContext)
44+
var source = (s_sourceName != ManagedIdentitySource.None) ? s_sourceName : await GetManagedIdentitySourceAsync(requestContext).ConfigureAwait(false);
45+
return source switch
46+
{
47+
ManagedIdentitySource.ServiceFabric => ServiceFabricManagedIdentitySource.Create(requestContext),
48+
ManagedIdentitySource.AppService => AppServiceManagedIdentitySource.Create(requestContext),
49+
ManagedIdentitySource.MachineLearning => MachineLearningManagedIdentitySource.Create(requestContext),
50+
ManagedIdentitySource.CloudShell => CloudShellManagedIdentitySource.Create(requestContext),
51+
ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(requestContext),
52+
ManagedIdentitySource.ImdsV2 => ImdsV2ManagedIdentitySource.Create(requestContext),
53+
_ => new ImdsManagedIdentitySource(requestContext)
5554
};
5655
}
5756

57+
// Detect managed identity source based on the availability of environment variables and csr metadata probe request.
58+
// This method is perf sensitive any changes should be benchmarked.
59+
internal static async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync(RequestContext requestContext)
60+
{
61+
string identityEndpoint = EnvironmentVariables.IdentityEndpoint;
62+
string identityHeader = EnvironmentVariables.IdentityHeader;
63+
string identityServerThumbprint = EnvironmentVariables.IdentityServerThumbprint;
64+
string msiSecret = EnvironmentVariables.IdentityHeader;
65+
string msiEndpoint = EnvironmentVariables.MsiEndpoint;
66+
string msiSecretMachineLearning = EnvironmentVariables.MsiSecret;
67+
string imdsEndpoint = EnvironmentVariables.ImdsEndpoint;
68+
69+
var logger = requestContext?.ServiceBundle?.ApplicationLogger;
70+
logger?.Info("[Managed Identity] Detecting managed identity source...");
71+
72+
if (!string.IsNullOrEmpty(identityEndpoint) && !string.IsNullOrEmpty(identityHeader))
73+
{
74+
if (!string.IsNullOrEmpty(identityServerThumbprint))
75+
{
76+
logger?.Info("[Managed Identity] Service Fabric detected.");
77+
s_sourceName = ManagedIdentitySource.ServiceFabric;
78+
}
79+
else
80+
{
81+
logger?.Info("[Managed Identity] App Service detected.");
82+
s_sourceName = ManagedIdentitySource.AppService;
83+
}
84+
}
85+
else if (!string.IsNullOrEmpty(msiSecretMachineLearning) && !string.IsNullOrEmpty(msiEndpoint))
86+
{
87+
logger?.Info("[Managed Identity] Machine Learning detected.");
88+
s_sourceName = ManagedIdentitySource.MachineLearning;
89+
}
90+
else if (!string.IsNullOrEmpty(msiEndpoint))
91+
{
92+
logger?.Info("[Managed Identity] Cloud Shell detected.");
93+
s_sourceName = ManagedIdentitySource.CloudShell;
94+
}
95+
else if (ValidateAzureArcEnvironment(identityEndpoint, imdsEndpoint, logger))
96+
{
97+
logger?.Info("[Managed Identity] Azure Arc detected.");
98+
s_sourceName = ManagedIdentitySource.AzureArc;
99+
}
100+
else if (await ImdsV2ManagedIdentitySource.GetCsrMetadataAsync(requestContext).ConfigureAwait(false))
101+
{
102+
logger?.Info("[Managed Identity] ImdsV2 detected.");
103+
s_sourceName = ManagedIdentitySource.ImdsV2;
104+
}
105+
else
106+
{
107+
s_sourceName = ManagedIdentitySource.DefaultToImds;
108+
}
109+
110+
return s_sourceName;
111+
}
112+
58113
// Detect managed identity source based on the availability of environment variables.
59114
// The result of this method is not cached because reading environment variables is cheap.
60115
// This method is perf sensitive any changes should be benchmarked.
61-
internal static async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync(ILoggerAdapter logger = null)
116+
[Obsolete("Use GetManagedIdentitySourceAsync(RequestContext) instead.")]
117+
internal static ManagedIdentitySource GetManagedIdentitySource(ILoggerAdapter logger = null)
62118
{
63119
string identityEndpoint = EnvironmentVariables.IdentityEndpoint;
64120
string identityHeader = EnvironmentVariables.IdentityHeader;
@@ -98,11 +154,6 @@ internal static async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync(
98154
logger?.Info("[Managed Identity] Azure Arc detected.");
99155
return ManagedIdentitySource.AzureArc;
100156
}
101-
else if (await ImdsV2ManagedIdentitySource.GetCsrMetadataAsync(_requestContext).ConfigureAwait(false))
102-
{
103-
logger?.Info("[Managed Identity] ImdsV2 detected.");
104-
return ManagedIdentitySource.ImdsV2;
105-
}
106157
else
107158
{
108159
return ManagedIdentitySource.DefaultToImds;

src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,28 @@ public AcquireTokenForManagedIdentityParameterBuilder AcquireTokenForManagedIden
5555
resource);
5656
}
5757

58+
/// <inheritdoc/>
59+
public async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync()
60+
{
61+
if (ManagedIdentityClient.s_sourceName != ManagedIdentitySource.None)
62+
{
63+
return ManagedIdentityClient.s_sourceName;
64+
}
65+
66+
// Create a temporary RequestContext for the CSR metadata probe request.
67+
var csrMetadataProbeRequestContext = new RequestContext(this.ServiceBundle, Guid.NewGuid(), null, CancellationToken.None);
68+
69+
return await ManagedIdentityClient.GetManagedIdentitySourceAsync(csrMetadataProbeRequestContext).ConfigureAwait(false);
70+
}
71+
5872
/// <summary>
5973
/// Detects and returns the managed identity source available on the environment.
6074
/// </summary>
6175
/// <returns>Managed identity source detected on the environment if any.</returns>
62-
public static async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync()
76+
[Obsolete("Use GetManagedIdentitySourceAsync() instead as an instance method of ManagedIdentityApplication. This method is no longer static.")]
77+
public static ManagedIdentitySource GetManagedIdentitySource()
6378
{
64-
return await ManagedIdentityClient.GetManagedIdentitySourceAsync().ConfigureAwait(false);
79+
return ManagedIdentityClient.GetManagedIdentitySource();
6580
}
6681
}
6782
}

src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Shipped.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,7 @@ static Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicketManager.GetK
970970
static Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicketManager.GetKrbCred(Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicket ticket) -> byte[]
971971
static Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicketManager.SaveToWindowsTicketCache(Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicket ticket) -> void
972972
static Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicketManager.SaveToWindowsTicketCache(Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicket ticket, long logonId) -> void
973+
static Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySource() -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource
973974
static Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.Create(Microsoft.Identity.Client.AppConfig.ManagedIdentityId managedIdentityId) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder
974975
static Microsoft.Identity.Client.Metrics.TotalAccessTokensFromBroker.get -> long
975976
static Microsoft.Identity.Client.Metrics.TotalAccessTokensFromCache.get -> long

src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
22
const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
3+
Microsoft.Identity.Client.IManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task<Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource>
4+
Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task<Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource>
35
Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource
4-
static Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task<Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource>

src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Shipped.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,7 @@ static Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicketManager.GetK
970970
static Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicketManager.GetKrbCred(Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicket ticket) -> byte[]
971971
static Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicketManager.SaveToWindowsTicketCache(Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicket ticket) -> void
972972
static Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicketManager.SaveToWindowsTicketCache(Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicket ticket, long logonId) -> void
973+
static Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySource() -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource
973974
static Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.Create(Microsoft.Identity.Client.AppConfig.ManagedIdentityId managedIdentityId) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder
974975
static Microsoft.Identity.Client.Metrics.TotalAccessTokensFromBroker.get -> long
975976
static Microsoft.Identity.Client.Metrics.TotalAccessTokensFromCache.get -> long

src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
22
const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
3+
Microsoft.Identity.Client.IManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task<Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource>
4+
Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task<Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource>
35
Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource
4-
static Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task<Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource>

src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Shipped.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,7 @@ static Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicketManager.GetK
943943
static Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicketManager.GetKrbCred(Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicket ticket) -> byte[]
944944
static Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicketManager.SaveToWindowsTicketCache(Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicket ticket) -> void
945945
static Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicketManager.SaveToWindowsTicketCache(Microsoft.Identity.Client.Kerberos.KerberosSupplementalTicket ticket, long logonId) -> void
946+
static Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySource() -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource
946947
static Microsoft.Identity.Client.ManagedIdentityApplicationBuilder.Create(Microsoft.Identity.Client.AppConfig.ManagedIdentityId managedIdentityId) -> Microsoft.Identity.Client.ManagedIdentityApplicationBuilder
947948
static Microsoft.Identity.Client.Metrics.TotalAccessTokensFromBroker.get -> long
948949
static Microsoft.Identity.Client.Metrics.TotalAccessTokensFromCache.get -> long

src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
const Microsoft.Identity.Client.MsalError.InvalidManagedIdentityIdType = "invalid_managed_identity_id_type" -> string
22
const Microsoft.Identity.Client.MsalError.MissingManagedIdentityEnvVar = "missing_managed_identity_env_var" -> string
3+
Microsoft.Identity.Client.IManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task<Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource>
4+
Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task<Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource>
35
Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.ImdsV2 = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource
4-
static Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync() -> System.Threading.Tasks.Task<Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource>

0 commit comments

Comments
 (0)