Skip to content

Auth: Support Azure Entra (Event Hub with Kafka Protocol) #530

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@
<version>2.1.0</version>
</dependency>

<dependency>
<groupId>com.azure</groupId>
<artifactId>azure-identity</artifactId>
<version>1.13.0</version>
</dependency>

<dependency>
<groupId>org.apache.avro</groupId>
<artifactId>avro</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package io.kafbat.ui.sasl.azure.entra;

import static org.apache.kafka.clients.CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG;

import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenCredential;
import com.azure.core.credential.TokenRequestContext;
import com.azure.identity.DefaultAzureCredentialBuilder;
import java.net.URI;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.AppConfigurationEntry;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AzureEntraLoginCallbackHandler implements AuthenticateCallbackHandler {

private static final Logger LOGGER = LoggerFactory.getLogger(AzureEntraLoginCallbackHandler.class);

private static final Duration ACCESS_TOKEN_REQUEST_BLOCK_TIME = Duration.ofSeconds(10);

private static final int ACCESS_TOKEN_REQUEST_MAX_RETRIES = 6;

private static final String TOKEN_AUDIENCE_FORMAT = "%s://%s/.default";

static TokenCredential tokenCredential = new DefaultAzureCredentialBuilder().build();

private TokenRequestContext tokenRequestContext;

@Override
public void configure(
Map<String, ?> configs, String mechanism, List<AppConfigurationEntry> jaasConfigEntries) {
tokenRequestContext = buildTokenRequestContext(configs);
}

private TokenRequestContext buildTokenRequestContext(Map<String, ?> configs) {
URI uri = buildEventHubsServerUri(configs);
String tokenAudience = buildTokenAudience(uri);

TokenRequestContext request = new TokenRequestContext();
request.addScopes(tokenAudience);
return request;
}

private URI buildEventHubsServerUri(Map<String, ?> configs) {
final List<String> bootstrapServers = (List<String>) configs.get(BOOTSTRAP_SERVERS_CONFIG);

if (null == bootstrapServers) {
final String message = BOOTSTRAP_SERVERS_CONFIG + " is missing from the Kafka configuration.";
LOGGER.error(message);
throw new IllegalArgumentException(message);
}

if (bootstrapServers.size() != 1) {
final String message =
BOOTSTRAP_SERVERS_CONFIG
+ " contains multiple bootstrap servers. Only a single bootstrap server is supported.";
LOGGER.error(message);
throw new IllegalArgumentException(message);
}

return URI.create("https://" + bootstrapServers.get(0));
}

private String buildTokenAudience(URI uri) {
return String.format(TOKEN_AUDIENCE_FORMAT, uri.getScheme(), uri.getHost());
}

@Override
public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
for (Callback callback : callbacks) {
if (callback instanceof OAuthBearerTokenCallback oauthCallback) {
handleOAuthCallback(oauthCallback);
} else {
throw new UnsupportedCallbackException(callback);
}
}
}

private void handleOAuthCallback(OAuthBearerTokenCallback oauthCallback) {
try {
final OAuthBearerToken token = tokenCredential
.getToken(tokenRequestContext)
.map(AzureEntraOAuthBearerTokenImpl::new)
.timeout(ACCESS_TOKEN_REQUEST_BLOCK_TIME)
.doOnError(e -> LOGGER.warn("Failed to acquire Azure token for Event Hub Authentication. Retrying.", e))
.retry(ACCESS_TOKEN_REQUEST_MAX_RETRIES)
.block();

oauthCallback.token(token);
} catch (final RuntimeException e) {
final String message =
"Failed to acquire Azure token for Event Hub Authentication. "
+ "Please ensure valid Azure credentials are configured.";
LOGGER.error(message, e);
oauthCallback.error("invalid_grant", message, null);
}
}

public void close() {
// NOOP
}

void setTokenCredential(final TokenCredential tokenCredential) {
AzureEntraLoginCallbackHandler.tokenCredential = tokenCredential;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package io.kafbat.ui.sasl.azure.entra;

import com.azure.core.credential.AccessToken;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.JWTParser;
import java.text.ParseException;
import java.util.Arrays;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;

public class AzureEntraOAuthBearerTokenImpl implements OAuthBearerToken {

private final AccessToken accessToken;

private final JWTClaimsSet claims;

public AzureEntraOAuthBearerTokenImpl(AccessToken accessToken) {
this.accessToken = accessToken;

try {
claims = JWTParser.parse(accessToken.getToken()).getJWTClaimsSet();
} catch (ParseException exception) {
throw new SaslAuthenticationException("Unable to parse the access token", exception);
}
}

@Override
public String value() {
return accessToken.getToken();
}

@Override
public Long startTimeMs() {
return claims.getIssueTime().getTime();
}

@Override
public long lifetimeMs() {
return claims.getExpirationTime().getTime();
}

@Override
public Set<String> scope() {
// Referring to
// https://docs.microsoft.com/azure/active-directory/develop/access-tokens#payload-claims, the
// scp
// claim is a String which is presented as a space separated list.
return Optional.ofNullable(claims.getClaim("scp"))
.map(s -> Arrays.stream(((String) s).split(" ")).collect(Collectors.toSet()))
.orElse(null);
}

@Override
public String principalName() {
return (String) claims.getClaim("upn");
}

public boolean isExpired() {
return accessToken.isExpired();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package io.kafbat.ui.sasl.azure.entra;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenCredential;
import com.azure.core.credential.TokenRequestContext;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Mono;

@ExtendWith(MockitoExtension.class)
public class AzureEntraLoginCallbackHandlerTest {

// These are not real tokens. It was generated using fake values with an invalid signature,
// so it is safe to store here.
private static final String VALID_SAMPLE_TOKEN =
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6IjlHbW55RlBraGMzaE91UjIybXZTdmduTG83WSIsImtpZCI6IjlHbW55"
+ "RlBraGMzaE91UjIybXZTdmduTG83WSJ9.eyJhdWQiOiJodHRwczovL3NhbXBsZS5zZXJ2aWNlYnVzLndpbmRvd3MubmV0IiwiaX"
+ "NzIjoiaHR0cHM6Ly9zdHMud2luZG93cy5uZXQvc2FtcGxlLyIsImlhdCI6MTY5ODQxNTkxMiwibmJmIjoxNjk4NDE1OTEzLCJleH"
+ "AiOjE2OTg0MTU5MTQsImFjciI6IjEiLCJhaW8iOiJzYW1wbGUtYWlvIiwiYW1yIjpbXSwiYXBwaWQiOiJzYW1wbGUtYXBwLWlkIi"
+ "wiYXBwaWRhY3IiOiIwIiwiZmFtaWx5X25hbWUiOiJTYW1wbGUiLCJnaXZlbl9uYW1lIjoiU2FtcGxlIiwiZ3JvdXBzIjpbXSwiaX"
+ "BhZGRyIjoiMTI3LjAuMC4xIiwibmFtZSI6IlNhbXBsZSBOYW1lIiwib2lkIjoic2FtcGxlLW9pZCIsIm9ucHJlbV9zaWQiOiJzYW"
+ "1wbGUtb25wcmVtX3NpZCIsInB1aWQiOiJzYW1wbGUtcHVpZCIsInJoIjoic2FtcGxlLXJoIiwic2NwIjoiZXZlbnRfaHViIHN0b3"
+ "JhZ2VfYWNjb3VudCIsInN1YiI6IlNhbXBsZSBTdWJqZWN0IiwidGlkIjoic2FtcGxlLXRpZCIsInVuaXF1ZV9uYW1lIjoic2FtcG"
+ "xlQG1pY3Jvc29mdC5jb20iLCJ1cG4iOiJzYW1wbGVAbWljcm9zb2Z0LmNvbSIsInV0aSI6InNhbXBsZS11dGkiLCJ2ZXIiOiIxLj"
+ "AiLCJ3aWRzIjpbXX0.DC_guYOsDlRc5GsXE39dn_zlBX54_Y8_mDTLXLgienl9dPMX5RE2X1QXGXA9ukZtptMzP_0wcoqDDjNrys"
+ "GrNhztyeOr0YSeMMFq2NQ5vMBzLapwONwsnv55Hn0jOje9cqnMf43z1LHI6q6-rIIRz-SiTuoYUgOTxzFftpt-7FSqLjQpYEH7bL"
+ "p-0yIU_aJUSb5HQTJbtYYOb54hsZ6VXpaiZ013qGtKODbHTG37kdoIw2MPn66CxanLZKeZM31IVxC-duAqxDgK4O2Ne6xRZRIPW1"
+ "yt61QnZutWTJ4bAyhmplym3OWZ369cyiSJek0uyS5tibXeCYG4Kk8UQSFcsyfwgOsD0xvvcXcLexcUcEekoNBj6ixDhWssFzhC8T"
+ "Npy8-QKNe_Tp6qHzJdI6OV71jpDkGvcmseLHC9GOxBWB0IdYbePTFK-rz2dkN3uMUiFwQJvEbORsq1IaQXj2esT0F7sMfqzWQF9h"
+ "koVy4mJg_auvrZlnQkNPdLHfCacU33ZPwtuSS6b-0XolbxZ5DlJ4p1OJPeHl2xsi61qiHuCBsmnkLNtHmyxNTXGs7xc4dEQokaCK"
+ "-FB_lzC3D4mkJMxKWopQGXnQtizaZjyclGpiUFs3mEauxC7RpsbanitxPFs7FK3mY0MQJk9JNVi1oM-8qfEp8nYT2DwFBhLcIp2z"
+ "Q";

@Mock
private OAuthBearerTokenCallback oauthBearerTokenCallBack;

@Mock
private OAuthBearerToken oauthBearerToken;

@Mock
private TokenCredential tokenCredential;

@Mock
private AccessToken accessToken;

private AzureEntraLoginCallbackHandler azureEntraLoginCallbackHandler;

@BeforeEach
public void beforeEach() {
azureEntraLoginCallbackHandler = new AzureEntraLoginCallbackHandler();
azureEntraLoginCallbackHandler.setTokenCredential(tokenCredential);
}

@Test
public void shouldProvideTokenToCallbackWithSuccessfulTokenRequest()
throws UnsupportedCallbackException {
final Map<String, Object> configs = new HashMap<>();
configs.put(
"bootstrap.servers",
List.of("test-eh.servicebus.windows.net:9093"));

when(tokenCredential.getToken(any(TokenRequestContext.class))).thenReturn(Mono.just(accessToken));
when(accessToken.getToken()).thenReturn(VALID_SAMPLE_TOKEN);

azureEntraLoginCallbackHandler.configure(configs, null, null);
azureEntraLoginCallbackHandler.handle(new Callback[] {oauthBearerTokenCallBack});

final ArgumentCaptor<TokenRequestContext> contextCaptor =
ArgumentCaptor.forClass(TokenRequestContext.class);
final ArgumentCaptor<OAuthBearerToken> tokenCaptor =
ArgumentCaptor.forClass(OAuthBearerToken.class);

verify(tokenCredential, times(1)).getToken(contextCaptor.capture());
verify(oauthBearerTokenCallBack, times(0)).error(anyString(), anyString(), anyString());
verify(oauthBearerTokenCallBack, times(1)).token(tokenCaptor.capture());

final TokenRequestContext tokenRequestContext = contextCaptor.getValue();
assertThat(tokenRequestContext, is(notNullValue()));
assertThat(
tokenRequestContext.getScopes(),
is(List.of("https://test-eh.servicebus.windows.net/.default")));
assertThat(tokenRequestContext.getClaims(), is(nullValue()));
assertThat(tokenRequestContext.getTenantId(), is(nullValue()));
assertFalse(tokenRequestContext.isCaeEnabled());

assertThat(tokenCaptor.getValue(), is(notNullValue()));
assertEquals(VALID_SAMPLE_TOKEN, tokenCaptor.getValue().value());
}

@Test
public void shouldProvideErrorToCallbackWithTokenError() throws UnsupportedCallbackException {
final Map<String, Object> configs = new HashMap<>();
configs.put(
"bootstrap.servers",
List.of("test-eh.servicebus.windows.net:9093"));

when(tokenCredential.getToken(any(TokenRequestContext.class)))
.thenThrow(new RuntimeException("failed to acquire token"));

azureEntraLoginCallbackHandler.configure(configs, null, null);
azureEntraLoginCallbackHandler.handle(new Callback[] {oauthBearerTokenCallBack});

verify(oauthBearerTokenCallBack, times(1))
.error(
"invalid_grant",
"Failed to acquire Azure token for Event Hub Authentication. "
+ "Please ensure valid Azure credentials are configured.",
null);
verify(oauthBearerTokenCallBack, times(0)).token(any());
}

@Test
public void shouldThrowExceptionWithNullBootstrapServers() {
final Map<String, Object> configs = new HashMap<>();

assertThrows(IllegalArgumentException.class, () -> azureEntraLoginCallbackHandler.configure(
configs, null, null));
}

@Test
public void shouldThrowExceptionWithMultipleBootstrapServers() {
final Map<String, Object> configs = new HashMap<>();
configs.put("bootstrap.servers", List.of("server1", "server2"));

assertThrows(IllegalArgumentException.class, () -> azureEntraLoginCallbackHandler.configure(
configs, null, null));
}

@Test
public void shouldThrowExceptionWithUnsupportedCallback() {
assertThrows(UnsupportedCallbackException.class, () -> azureEntraLoginCallbackHandler.handle(
new Callback[] {mock(Callback.class)}));
}

@Test
public void shouldDoNothingOnClose() {
azureEntraLoginCallbackHandler.close();
}

@Test
public void shouldSupportDefaultConstructor() {
new AzureEntraLoginCallbackHandler();
}
}
Loading
Loading