diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AzureArcManagedIdentitySource.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AzureArcManagedIdentitySource.java index c6ffb45a..46b3bf0a 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AzureArcManagedIdentitySource.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/AzureArcManagedIdentitySource.java @@ -26,6 +26,7 @@ class AzureArcManagedIdentitySource extends AbstractManagedIdentitySource{ private static final String FILE_EXTENSION = ".key"; private static final int MAX_FILE_SIZE_BYTES = 4096; private static final String WWW_AUTHENTICATE_HEADER = "WWW-Authenticate"; + private static final String FALLBACK_IDENTITY_ENDPOINT = "http://127.0.0.1:40342/metadata/identity/oauth2/token"; private final URI MSI_ENDPOINT; @@ -33,6 +34,12 @@ static AbstractManagedIdentitySource create(MsalRequest msalRequest, ServiceBund { IEnvironmentVariables environmentVariables = getEnvironmentVariables(); String identityEndpoint = environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT); + + if (StringHelper.isNullOrBlank(identityEndpoint)) { + LOG.info("[Managed Identity] Azure Arc was detected through file based detection but the environment variables were not found. Defaulting to known azure arc endpoint."); + identityEndpoint = FALLBACK_IDENTITY_ENDPOINT; + } + String imdsEndpoint = environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT); URI validatedUri = validateAndGetUri(identityEndpoint, imdsEndpoint); diff --git a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityClient.java b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityClient.java index 9b118c1e..4974930b 100644 --- a/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityClient.java +++ b/msal4j-sdk/src/main/java/com/microsoft/aad/msal4j/ManagedIdentityClient.java @@ -6,11 +6,15 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.File; + /** * Class to initialize a managed identity and identify the service. */ class ManagedIdentityClient { private static final Logger LOG = LoggerFactory.getLogger(ManagedIdentityClient.class); + private static String WINDOWS_HIMDS_FILEPATH = "%Programfiles%\\AzureConnectedMachineAgent\\himds.exe"; + private static String LINUX_HIMDS_FILEPATH = "/opt/azcmagent/bin/himds"; static ManagedIdentitySourceType getManagedIdentitySource() { IEnvironmentVariables environmentVariables = AbstractManagedIdentitySource.getEnvironmentVariables(); @@ -24,8 +28,7 @@ static ManagedIdentitySourceType getManagedIdentitySource() { } } else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.MSI_ENDPOINT))) { return ManagedIdentitySourceType.CLOUD_SHELL; - } else if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) && - !StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT))) { + } else if (validateAzureArcEnvironment(environmentVariables)) { return ManagedIdentitySourceType.AZURE_ARC; } else { return ManagedIdentitySourceType.DEFAULT_TO_IMDS; @@ -44,6 +47,10 @@ static ManagedIdentitySourceType getManagedIdentitySource() { } } + ManagedIdentityClient(){ + // Default constructor for testing + } + ManagedIdentityResponse getManagedIdentityResponse(ManagedIdentityParameters parameters) { return managedIdentitySource.getManagedIdentityResponse(parameters); } @@ -65,4 +72,44 @@ private static AbstractManagedIdentitySource createManagedIdentitySource(MsalReq return new IMDSManagedIdentitySource(msalRequest, serviceBundle); } } + + static boolean validateAzureArcEnvironment(IEnvironmentVariables environmentVariables) { + if (!StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IDENTITY_ENDPOINT)) && + !StringHelper.isNullOrBlank(environmentVariables.getEnvironmentVariable(Constants.IMDS_ENDPOINT))) { + LOG.info("[Managed Identity] Azure Arc managed identity is available through environment variables."); + return true; + } + + String osName = System.getProperty("os.name").toLowerCase(); + + if (osName.contains("windows")) { + File windowsFile = new File(WINDOWS_HIMDS_FILEPATH); + if (windowsFile.exists()) { + LOG.info("[Managed Identity] Azure Arc managed identity is available through file detection."); + return true; + } + } else if (osName.contains("linux")) { + File linuxFile = new File(LINUX_HIMDS_FILEPATH); + if (linuxFile.exists()) { + LOG.info("[Managed Identity] Azure Arc managed identity is available through file detection."); + return true; + } + } else { + LOG.warn("[Managed Identity] Azure Arc managed identity cannot be configured on a platform other than Windows and Linux."); + } + + LOG.info("[Managed Identity] Azure Arc managed identity is not available."); + return false; + } + + //These set methods should be used solely for automated testing. + // -The file paths are not normally customizable in this flow, as they should exist in an Azure Arc environment. + // -However, unit tests need some way to adjust them as part of mocking the environment and creating temporary files. + void setWindowsFilePath(String filePath) { + WINDOWS_HIMDS_FILEPATH = filePath; + } + + void setLinuxFilePath(String filePath) { + LINUX_HIMDS_FILEPATH = filePath; + } } \ No newline at end of file diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/EnvironmentVariablesHelper.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/EnvironmentVariablesHelper.java index d337db0f..4c212187 100644 --- a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/EnvironmentVariablesHelper.java +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/EnvironmentVariablesHelper.java @@ -44,4 +44,8 @@ public class EnvironmentVariablesHelper implements IEnvironmentVariables { public String getEnvironmentVariable(String envVariable) { return mockedEnvironmentVariables.get(envVariable); } + + void setEnvironmentVariable(String key, String value) { + mockedEnvironmentVariables.put(key, value); + } } diff --git a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java index 8ea74e5a..4f25af93 100644 --- a/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java +++ b/msal4j-sdk/src/test/java/com/microsoft/aad/msal4j/ManagedIdentityTests.java @@ -4,15 +4,19 @@ package com.microsoft.aad.msal4j; import com.nimbusds.oauth2.sdk.util.URLUtils; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.junit.jupiter.MockitoExtension; +import java.io.File; +import java.io.IOException; import java.net.SocketException; import java.nio.file.Path; import java.nio.file.Paths; @@ -32,6 +36,7 @@ import static java.util.Collections.*; import static org.apache.http.HttpStatus.*; import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assumptions.assumeTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; @@ -961,4 +966,127 @@ private void assertMsalServiceException(String errorCode, String message) throws assertTrue(ex.getMessage().contains(message)); } } + + + @Nested + class OSTests { + + @TempDir + Path tempDir; + + private ManagedIdentityClient client = new ManagedIdentityClient(); + private EnvironmentVariablesHelper envVars; + + @BeforeEach + void setUp() { + envVars = new EnvironmentVariablesHelper(AZURE_ARC, azureArcEndpoint); + ManagedIdentityApplication.setEnvironmentVariables(envVars); + } + + @Test + void validateAzureArc_WithCorrectEnvironmentVariables() { + // Set environment variables for Azure Arc + envVars.setEnvironmentVariable(Constants.IDENTITY_ENDPOINT, "https://example.com"); + envVars.setEnvironmentVariable(Constants.IMDS_ENDPOINT, "https://example2.com"); + + // Test validation + boolean result = ManagedIdentityClient.validateAzureArcEnvironment(envVars); + + assertTrue(result, "Azure Arc should be validated with correct environment variables"); + } + + @Test + void validateAzureArc_WithMissingEnvironmentVariables() { + // Only set one environment variable + envVars.setEnvironmentVariable(Constants.IDENTITY_ENDPOINT, "https://example.com"); + envVars.setEnvironmentVariable(Constants.IMDS_ENDPOINT, null); + + boolean result = ManagedIdentityClient.validateAzureArcEnvironment(envVars); + + assertFalse(result, "Azure Arc validation should fail with missing environment variables"); + } + + @Test + void validateAzureArc_WindowsFileExists() throws IOException { + // Determine OS and skip if not Windows + String osName = System.getProperty("os.name").toLowerCase(); + assumeTrue(osName.contains("windows"), "Test only runs on Windows"); + + // Create temp file to simulate Azure Arc file on Windows + File windowsFile = tempDir.resolve("himds.key").toFile(); + assertTrue(windowsFile.createNewFile(), "Failed to create test file"); + + // Set custom file path for testing + client.setWindowsFilePath(windowsFile.getAbsolutePath()); + + boolean result = ManagedIdentityClient.validateAzureArcEnvironment(envVars); + + assertTrue(result, "Azure Arc should be validated when Windows file exists"); + } + + @Test + void validateAzureArc_LinuxFileExists() throws IOException { + // Determine OS and skip if not Linux + String osName = System.getProperty("os.name").toLowerCase(); + assumeTrue(osName.contains("linux"), "Test only runs on Linux"); + + // Create temp file to simulate Azure Arc file on Linux + File linuxFile = tempDir.resolve("himds.sock").toFile(); + assertTrue(linuxFile.createNewFile(), "Failed to create test file"); + + // Set custom file path for testing + client.setLinuxFilePath(linuxFile.getAbsolutePath()); + + boolean result = ManagedIdentityClient.validateAzureArcEnvironment(envVars); + + assertTrue(result, "Azure Arc should be validated when Linux file exists"); + } + + @Test + void validateAzureArc_FilesNotExist() { + envVars.setEnvironmentVariable(Constants.IMDS_ENDPOINT, null); + + // Set non-existent file paths + client.setWindowsFilePath(tempDir.resolve("nonexistent-himds.key").toString()); + client.setLinuxFilePath(tempDir.resolve("nonexistent-himds.sock").toString()); + + boolean result = ManagedIdentityClient.validateAzureArcEnvironment(envVars); + + assertFalse(result, "Azure Arc validation should fail when files don't exist"); + } + + @Test + void validateAzureArc_CrossPlatformTest() throws IOException { + // This test creates both Windows and Linux files to test the method + // independent of platform in unit tests + envVars.setEnvironmentVariable(Constants.IMDS_ENDPOINT, null); + // Create both temp files + File windowsFile = tempDir.resolve("himds.key").toFile(); + File linuxFile = tempDir.resolve("himds.sock").toFile(); + + // Set custom file paths for testing + client.setWindowsFilePath(windowsFile.getAbsolutePath()); + client.setLinuxFilePath(linuxFile.getAbsolutePath()); + + // Test with no files + boolean noFilesResult = ManagedIdentityClient.validateAzureArcEnvironment(envVars); + assertFalse(noFilesResult, "Validation should fail when no files exist"); + + // Create Windows file + assertTrue(windowsFile.createNewFile(), "Failed to create Windows test file"); + + // The result depends on OS - but at least one path should be checked + boolean windowsFileResult = ManagedIdentityClient.validateAzureArcEnvironment(envVars); + + // Create Linux file + assertTrue(linuxFile.createNewFile(), "Failed to create Linux test file"); + + // Now with both files existing, the result should depend on OS + boolean bothFilesResult = ManagedIdentityClient.validateAzureArcEnvironment(envVars); + + // At least one of the tests with files should pass + assertTrue(windowsFileResult || bothFilesResult, + "At least one validation should succeed when test files exist"); + } + } }