@@ -19,46 +19,102 @@ internal class ManagedIdentityClient
19
19
{
20
20
private const string WindowsHimdsFilePath = "%Programfiles%\\ AzureConnectedMachineAgent\\ himds.exe" ;
21
21
private const string LinuxHimdsFilePath = "/opt/azcmagent/bin/himds" ;
22
- private static RequestContext _requestContext ;
23
- private static AbstractManagedIdentity _identitySource ;
22
+ private AbstractManagedIdentity s_identitySource ;
23
+ public static ManagedIdentitySource s_sourceName = ManagedIdentitySource . None ;
24
24
25
- public ManagedIdentityClient ( RequestContext requestContext )
25
+ internal async Task < ManagedIdentityResponse > SendTokenRequestForManagedIdentityAsync (
26
+ RequestContext requestContext ,
27
+ AcquireTokenForManagedIdentityParameters parameters ,
28
+ CancellationToken cancellationToken )
26
29
{
27
- _requestContext = requestContext ;
28
- }
29
-
30
- internal async Task < ManagedIdentityResponse > SendTokenRequestForManagedIdentityAsync ( AcquireTokenForManagedIdentityParameters parameters , CancellationToken cancellationToken )
31
- {
32
- if ( _identitySource == null )
30
+ if ( s_identitySource == null )
33
31
{
34
- using ( _requestContext . Logger . LogMethodDuration ( ) )
32
+ using ( requestContext . Logger . LogMethodDuration ( ) )
35
33
{
36
- _identitySource = await SelectManagedIdentitySourceAsync ( ) . ConfigureAwait ( false ) ;
34
+ s_identitySource = await SelectManagedIdentitySourceAsync ( requestContext ) . ConfigureAwait ( false ) ;
37
35
}
38
36
}
39
37
40
- return await _identitySource . AuthenticateAsync ( parameters , cancellationToken ) . ConfigureAwait ( false ) ;
38
+ return await s_identitySource . AuthenticateAsync ( parameters , cancellationToken ) . ConfigureAwait ( false ) ;
41
39
}
42
40
43
41
// This method tries to create managed identity source for different sources, if none is created then defaults to IMDS.
44
- private static async Task < AbstractManagedIdentity > SelectManagedIdentitySourceAsync ( )
42
+ private static async Task < AbstractManagedIdentity > SelectManagedIdentitySourceAsync ( RequestContext requestContext )
45
43
{
46
- return await GetManagedIdentitySourceAsync ( _requestContext . Logger ) . ConfigureAwait ( false ) switch
47
- {
48
- ManagedIdentitySource . ServiceFabric => ServiceFabricManagedIdentitySource . Create ( _requestContext ) ,
49
- ManagedIdentitySource . AppService => AppServiceManagedIdentitySource . Create ( _requestContext ) ,
50
- ManagedIdentitySource . MachineLearning => MachineLearningManagedIdentitySource . Create ( _requestContext ) ,
51
- ManagedIdentitySource . CloudShell => CloudShellManagedIdentitySource . Create ( _requestContext ) ,
52
- ManagedIdentitySource . AzureArc => AzureArcManagedIdentitySource . Create ( _requestContext ) ,
53
- ManagedIdentitySource . ImdsV2 => ImdsV2ManagedIdentitySource . Create ( _requestContext ) ,
54
- _ => new ImdsManagedIdentitySource ( _requestContext )
44
+ var source = ( s_sourceName != ManagedIdentitySource . None ) ? s_sourceName : await GetManagedIdentitySourceAsync ( requestContext ) . ConfigureAwait ( false ) ;
45
+ return source switch
46
+ {
47
+ ManagedIdentitySource . ServiceFabric => ServiceFabricManagedIdentitySource . Create ( requestContext ) ,
48
+ ManagedIdentitySource . AppService => AppServiceManagedIdentitySource . Create ( requestContext ) ,
49
+ ManagedIdentitySource . MachineLearning => MachineLearningManagedIdentitySource . Create ( requestContext ) ,
50
+ ManagedIdentitySource . CloudShell => CloudShellManagedIdentitySource . Create ( requestContext ) ,
51
+ ManagedIdentitySource . AzureArc => AzureArcManagedIdentitySource . Create ( requestContext ) ,
52
+ ManagedIdentitySource . ImdsV2 => ImdsV2ManagedIdentitySource . Create ( requestContext ) ,
53
+ _ => new ImdsManagedIdentitySource ( requestContext )
55
54
} ;
56
55
}
57
56
57
+ // Detect managed identity source based on the availability of environment variables and csr metadata probe request.
58
+ // This method is perf sensitive any changes should be benchmarked.
59
+ internal static async Task < ManagedIdentitySource > GetManagedIdentitySourceAsync ( RequestContext requestContext )
60
+ {
61
+ string identityEndpoint = EnvironmentVariables . IdentityEndpoint ;
62
+ string identityHeader = EnvironmentVariables . IdentityHeader ;
63
+ string identityServerThumbprint = EnvironmentVariables . IdentityServerThumbprint ;
64
+ string msiSecret = EnvironmentVariables . IdentityHeader ;
65
+ string msiEndpoint = EnvironmentVariables . MsiEndpoint ;
66
+ string msiSecretMachineLearning = EnvironmentVariables . MsiSecret ;
67
+ string imdsEndpoint = EnvironmentVariables . ImdsEndpoint ;
68
+
69
+ var logger = requestContext ? . ServiceBundle ? . ApplicationLogger ;
70
+ logger ? . Info ( "[Managed Identity] Detecting managed identity source..." ) ;
71
+
72
+ if ( ! string . IsNullOrEmpty ( identityEndpoint ) && ! string . IsNullOrEmpty ( identityHeader ) )
73
+ {
74
+ if ( ! string . IsNullOrEmpty ( identityServerThumbprint ) )
75
+ {
76
+ logger ? . Info ( "[Managed Identity] Service Fabric detected." ) ;
77
+ s_sourceName = ManagedIdentitySource . ServiceFabric ;
78
+ }
79
+ else
80
+ {
81
+ logger ? . Info ( "[Managed Identity] App Service detected." ) ;
82
+ s_sourceName = ManagedIdentitySource . AppService ;
83
+ }
84
+ }
85
+ else if ( ! string . IsNullOrEmpty ( msiSecretMachineLearning ) && ! string . IsNullOrEmpty ( msiEndpoint ) )
86
+ {
87
+ logger ? . Info ( "[Managed Identity] Machine Learning detected." ) ;
88
+ s_sourceName = ManagedIdentitySource . MachineLearning ;
89
+ }
90
+ else if ( ! string . IsNullOrEmpty ( msiEndpoint ) )
91
+ {
92
+ logger ? . Info ( "[Managed Identity] Cloud Shell detected." ) ;
93
+ s_sourceName = ManagedIdentitySource . CloudShell ;
94
+ }
95
+ else if ( ValidateAzureArcEnvironment ( identityEndpoint , imdsEndpoint , logger ) )
96
+ {
97
+ logger ? . Info ( "[Managed Identity] Azure Arc detected." ) ;
98
+ s_sourceName = ManagedIdentitySource . AzureArc ;
99
+ }
100
+ else if ( await ImdsV2ManagedIdentitySource . GetCsrMetadataAsync ( requestContext ) . ConfigureAwait ( false ) )
101
+ {
102
+ logger ? . Info ( "[Managed Identity] ImdsV2 detected." ) ;
103
+ s_sourceName = ManagedIdentitySource . ImdsV2 ;
104
+ }
105
+ else
106
+ {
107
+ s_sourceName = ManagedIdentitySource . DefaultToImds ;
108
+ }
109
+
110
+ return s_sourceName ;
111
+ }
112
+
58
113
// Detect managed identity source based on the availability of environment variables.
59
114
// The result of this method is not cached because reading environment variables is cheap.
60
115
// This method is perf sensitive any changes should be benchmarked.
61
- internal static async Task < ManagedIdentitySource > GetManagedIdentitySourceAsync ( ILoggerAdapter logger = null )
116
+ [ Obsolete ( "Use GetManagedIdentitySourceAsync(RequestContext) instead." ) ]
117
+ internal static ManagedIdentitySource GetManagedIdentitySource ( ILoggerAdapter logger = null )
62
118
{
63
119
string identityEndpoint = EnvironmentVariables . IdentityEndpoint ;
64
120
string identityHeader = EnvironmentVariables . IdentityHeader ;
@@ -98,11 +154,6 @@ internal static async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync(
98
154
logger ? . Info ( "[Managed Identity] Azure Arc detected." ) ;
99
155
return ManagedIdentitySource . AzureArc ;
100
156
}
101
- else if ( await ImdsV2ManagedIdentitySource . GetCsrMetadataAsync ( _requestContext ) . ConfigureAwait ( false ) )
102
- {
103
- logger ? . Info ( "[Managed Identity] ImdsV2 detected." ) ;
104
- return ManagedIdentitySource . ImdsV2 ;
105
- }
106
157
else
107
158
{
108
159
return ManagedIdentitySource . DefaultToImds ;
0 commit comments