Skip to content

Auth refactoring #108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void testResolveTokenSuccessfully(VaultEnvironment vaultEnvironment) throws Exce

final var jwtToken =
new JsonwebtokenResolver(
new JwksKeyLocator.Builder()
JwksKeyLocator.builder()
.jwksUri(vaultEnvironment.jwksUri())
.connectTimeout(Duration.ofSeconds(3))
.requestTimeout(Duration.ofSeconds(3))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public class VaultServiceTokenTests {
void testGetServiceTokenUsingWrongCredentials(VaultEnvironment vaultEnvironment)
throws Exception {
final var serviceTokenSupplier =
new VaultServiceTokenSupplier.Builder()
VaultServiceTokenSupplier.builder()
.vaultAddress(vaultEnvironment.vaultAddr())
.vaultTokenSupplier(() -> completedFuture(randomAlphabetic(16)))
.serviceRole(randomAlphabetic(16))
Expand All @@ -52,7 +52,7 @@ void testGetNonExistingServiceToken(VaultEnvironment vaultEnvironment) throws Ex
final var nonExistingServiceRole = "non-existing-role-" + System.currentTimeMillis();

final var serviceTokenSupplier =
new VaultServiceTokenSupplier.Builder()
VaultServiceTokenSupplier.builder()
.vaultAddress(vaultEnvironment.vaultAddr())
.vaultTokenSupplier(() -> completedFuture(vaultEnvironment.login()))
.serviceRole(nonExistingServiceRole)
Expand All @@ -76,7 +76,7 @@ void testGetServiceTokenByWrongServiceRole(VaultEnvironment vaultEnvironment) th
final var serviceRole2 = "role2-" + now;
final var serviceRole3 = "role3-" + now;

new VaultServiceRolesInstaller.Builder()
VaultServiceRolesInstaller.builder()
.vaultAddress(vaultEnvironment.vaultAddr())
.vaultTokenSupplier(() -> completedFuture(vaultEnvironment.login()))
.keyNameSupplier(() -> "key-" + now)
Expand All @@ -94,7 +94,7 @@ void testGetServiceTokenByWrongServiceRole(VaultEnvironment vaultEnvironment) th
.install();

final var serviceTokenSupplier =
new VaultServiceTokenSupplier.Builder()
VaultServiceTokenSupplier.builder()
.vaultAddress(vaultEnvironment.vaultAddr())
.vaultTokenSupplier(() -> completedFuture(vaultEnvironment.login()))
.serviceRole(serviceRole1)
Expand All @@ -117,7 +117,7 @@ void testGetServiceTokenSuccessfully(VaultEnvironment vaultEnvironment) throws E
final var serviceRole = "role-" + now;
final var tags = Map.of("type", "ops", "ns", "develop");

new VaultServiceRolesInstaller.Builder()
VaultServiceRolesInstaller.builder()
.vaultAddress(vaultEnvironment.vaultAddr())
.vaultTokenSupplier(() -> completedFuture(vaultEnvironment.login()))
.keyNameSupplier(() -> "key-" + now)
Expand All @@ -128,7 +128,7 @@ void testGetServiceTokenSuccessfully(VaultEnvironment vaultEnvironment) throws E
.install();

final var serviceTokenSupplier =
new VaultServiceTokenSupplier.Builder()
VaultServiceTokenSupplier.builder()
.vaultAddress(vaultEnvironment.vaultAddr())
.vaultTokenSupplier(() -> completedFuture(vaultEnvironment.login()))
.serviceRole(serviceRole)
Expand All @@ -142,7 +142,7 @@ void testGetServiceTokenSuccessfully(VaultEnvironment vaultEnvironment) throws E

final var jwtToken =
new JsonwebtokenResolver(
new JwksKeyLocator.Builder().jwksUri(vaultEnvironment.jwksUri()).build())
JwksKeyLocator.builder().jwksUri(vaultEnvironment.jwksUri()).build())
.resolve(serviceToken)
.get(3, TimeUnit.SECONDS);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.Locator;
import java.security.Key;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -14,7 +15,7 @@ public class JsonwebtokenResolver implements JwtTokenResolver {
private final Locator<Key> keyLocator;

public JsonwebtokenResolver(Locator<Key> keyLocator) {
this.keyLocator = keyLocator;
this.keyLocator = Objects.requireNonNull(keyLocator, "keyLocator");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.time.Duration;
import java.util.Base64;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

Expand All @@ -41,12 +42,16 @@ public class JwksKeyLocator extends LocatorAdapter<Key> {
private final ReentrantLock cleanupLock = new ReentrantLock();

private JwksKeyLocator(Builder builder) {
this.jwksUri = builder.jwksUri;
this.connectTimeout = builder.connectTimeout;
this.requestTimeout = builder.requestTimeout;
this.jwksUri = Objects.requireNonNull(builder.jwksUri, "jwksUri");
this.connectTimeout = Objects.requireNonNull(builder.connectTimeout, "connectTimeout");
this.requestTimeout = Objects.requireNonNull(builder.requestTimeout, "requestTimeout");
this.keyTtl = builder.keyTtl;
}

public static Builder builder() {
return new Builder();
}

@Override
protected Key locate(JwsHeader header) {
try {
Expand Down Expand Up @@ -160,6 +165,8 @@ public static class Builder {
private Duration requestTimeout = Duration.ofSeconds(10);
private int keyTtl = 60 * 1000;

private Builder() {}

/**
* Setter for JWKS URI. The JWKS URI typically follows a well-known pattern, such as
* https://server_domain/.well-known/jwks.json. This endpoint is a read-only URL that responds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.StringReader;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -49,32 +50,39 @@ public class VaultServiceRolesInstaller {
private final TimeUnit timeUnit;

private VaultServiceRolesInstaller(Builder builder) {
this.vaultAddress = builder.vaultAddress;
this.vaultTokenSupplier = builder.vaultTokenSupplier;
this.keyNameSupplier = builder.keyNameSupplier;
this.roleNameBuilder = builder.roleNameBuilder;
this.serviceRolesSources = builder.serviceRolesSources;
this.keyAlgorithm = builder.keyAlgorithm;
this.keyRotationPeriod = builder.keyRotationPeriod;
this.keyVerificationTtl = builder.keyVerificationTtl;
this.roleTtl = builder.roleTtl;
this.vaultAddress = Objects.requireNonNull(builder.vaultAddress, "vaultAddress");
this.vaultTokenSupplier =
Objects.requireNonNull(builder.vaultTokenSupplier, "vaultTokenSupplier");
this.keyNameSupplier = Objects.requireNonNull(builder.keyNameSupplier, "keyNameSupplier");
this.roleNameBuilder = Objects.requireNonNull(builder.roleNameBuilder, "roleNameBuilder");
this.serviceRolesSources =
Objects.requireNonNull(builder.serviceRolesSources, "serviceRolesSources");
this.keyAlgorithm = Objects.requireNonNull(builder.keyAlgorithm, "keyAlgorithm");
this.keyRotationPeriod = Objects.requireNonNull(builder.keyRotationPeriod, "keyRotationPeriod");
this.keyVerificationTtl =
Objects.requireNonNull(builder.keyVerificationTtl, "keyVerificationTtl");
this.roleTtl = Objects.requireNonNull(builder.roleTtl, "roleTtl");
this.timeout = builder.timeout;
this.timeUnit = builder.timeUnit;
}

public static Builder builder() {
return new Builder();
}

/**
* Builds vault oidc micro-infrastructure (identity roles and keys) to use it for
* machine-to-machine authentication.
*/
public void install() {
if (isNullOrNoneOrEmpty(vaultAddress)) {
LOGGER.debug("Skipping serviceRoles installation, vaultAddress not set");
LOGGER.debug("Skipping service roles installation, vault address not set");
return;
}

final ServiceRoles serviceRoles = loadServiceRoles();
if (serviceRoles == null || serviceRoles.roles.isEmpty()) {
LOGGER.debug("Skipping serviceRoles installation, serviceRoles not set");
LOGGER.debug("Skipping service roles installation, service roles not set");
return;
}

Expand All @@ -87,17 +95,19 @@ public void install() {
final var keyName = keyNameSupplier.get();

createVaultIdentityKey(rest.url(vaultIdentityKeyUri(keyName)), keyName);
LOGGER.debug("Vault identity key: {}", keyName);

for (var role : serviceRoles.roles) {
String roleName = roleNameBuilder.apply(role.role);
final var roleName = roleNameBuilder.apply(role.role);
createVaultIdentityRole(
rest.url(vaultIdentityRoleUri(roleName)),
keyName,
roleName,
role.role,
role.permissions);
LOGGER.debug("Vault identity role: {}", roleName);
}

LOGGER.debug("Installed serviceRoles ({})", serviceRoles);
LOGGER.debug("Installed service roles: {}", serviceRoles);
})
.get(timeout, timeUnit);
} catch (Exception e) {
Expand All @@ -106,10 +116,6 @@ public void install() {
}

private ServiceRoles loadServiceRoles() {
if (serviceRolesSources == null) {
return null;
}

for (Supplier<ServiceRoles> serviceRolesSource : serviceRolesSources) {
final ServiceRoles serviceRoles = serviceRolesSource.get();
if (serviceRoles != null) {
Expand All @@ -134,11 +140,10 @@ private void createVaultIdentityKey(Rest rest, String keyName) {
.add("allowed_client_ids", "*")
.add("algorithm", keyAlgorithm)
.toString()
.getBytes();
.getBytes(StandardCharsets.UTF_8);

try {
awaitSuccess(rest.body(body).post().getStatus());
LOGGER.debug("Created vault identity key: {}", keyName);
} catch (RestException e) {
throw new RuntimeException("Failed to create vault identity key: " + keyName, e);
}
Expand All @@ -149,23 +154,26 @@ private void createVaultIdentityRole(
final byte[] body =
Json.object()
.add("key", keyName)
.add("template", createTemplate(permissions))
.add("template", createTemplate(roleName, permissions))
.add("ttl", roleTtl)
.toString()
.getBytes();
.getBytes(StandardCharsets.UTF_8);

try {
awaitSuccess(rest.body(body).post().getStatus());
LOGGER.debug("Created vault identity role: {}", roleName);
} catch (RestException e) {
throw new RuntimeException("Failed to create vault identity role: " + roleName, e);
}
}

private static String createTemplate(List<String> permissions) {
private static String createTemplate(String roleName, List<String> permissions) {
return Base64.getUrlEncoder()
.encodeToString(
Json.object().add("permissions", String.join(",", permissions)).toString().getBytes());
Json.object()
.add("role", roleName)
.add("permissions", String.join(",", permissions))
.toString()
.getBytes(StandardCharsets.UTF_8));
}

private String vaultIdentityKeyUri(String keyName) {
Expand Down Expand Up @@ -363,7 +371,7 @@ public static class Builder {
private long timeout = 10;
private TimeUnit timeUnit = TimeUnit.SECONDS;

public Builder() {}
private Builder() {}

public Builder vaultAddress(String vaultAddress) {
this.vaultAddress = vaultAddress;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ private VaultServiceTokenSupplier(Builder builder) {
Objects.requireNonNull(builder.serviceTokenNameBuilder, "serviceTokenNameBuilder");
}

public static Builder builder() {
return new Builder();
}

/**
* Obtains vault service token (aka identity token or oidc token).
*
Expand Down Expand Up @@ -98,7 +102,7 @@ public static class Builder {
private Supplier<CompletableFuture<String>> vaultTokenSupplier;
private BiFunction<String, Map<String, String>, String> serviceTokenNameBuilder;

public Builder() {}
private Builder() {}

/**
* Setter for {@code vaultAddress}.
Expand Down