Skip to content

Commit

Permalink
OIDC: support CredentialsProvider recorded at runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
michalvavrik committed Jun 30, 2024
1 parent c20fa84 commit 69564e6
Show file tree
Hide file tree
Showing 10 changed files with 322 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
import io.quarkus.arc.deployment.GeneratedBeanBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.deployment.SyntheticBeansRuntimeInitBuildItem;
import io.quarkus.arc.processor.DotNames;
import io.quarkus.deployment.ApplicationArchive;
import io.quarkus.deployment.Feature;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.BuildSteps;
import io.quarkus.deployment.annotations.Consume;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.ApplicationArchivesBuildItem;
Expand Down Expand Up @@ -89,6 +91,13 @@ private Set<String> oidcClientNamesOf(ApplicationArchivesBuildItem beanArchiveIn
.collect(Collectors.toSet());
}

@Consume(SyntheticBeansRuntimeInitBuildItem.class)
@Record(ExecutionTime.RUNTIME_INIT)
@BuildStep
void initOidcClients(OidcClientRecorder recorder) {
recorder.initOidcClients();
}

@Record(ExecutionTime.RUNTIME_INIT)
@BuildStep
public void setup(
Expand All @@ -99,40 +108,36 @@ public void setup(
TlsRegistryBuildItem tlsRegistry,
BuildProducer<SyntheticBeanBuildItem> syntheticBean) {

OidcClients clients = recorder.setup(oidcConfig, vertxBuildItem.getVertx(), tlsRegistry.registry());

syntheticBean.produce(SyntheticBeanBuildItem.configure(OidcClient.class).unremovable()
.types(OidcClient.class)
.supplier(recorder.createOidcClientBean(clients))
syntheticBean.produce(SyntheticBeanBuildItem.configure(OidcClients.class).unremovable()
.types(OidcClients.class)
.supplier(recorder.createOidcClientsBean(oidcConfig, vertxBuildItem.getVertx(), tlsRegistry.registry()))
.scope(Singleton.class)
.setRuntimeInit()
.destroyer(BeanDestroyer.CloseableDestroyer.class)
.done());

syntheticBean.produce(SyntheticBeanBuildItem.configure(OidcClients.class).unremovable()
.types(OidcClients.class)
.supplier(recorder.createOidcClientsBean(clients))
syntheticBean.produce(SyntheticBeanBuildItem.configure(OidcClient.class).unremovable()
.types(OidcClient.class)
.supplier(recorder.createOidcClientBean())
.scope(Singleton.class)
.setRuntimeInit()
.destroyer(BeanDestroyer.CloseableDestroyer.class)
.done());

produceNamedOidcClientBeans(syntheticBean, oidcClientNames.oidcClientNames(), recorder, clients);
produceNamedOidcClientBeans(syntheticBean, oidcClientNames.oidcClientNames(), recorder);
}

private void produceNamedOidcClientBeans(BuildProducer<SyntheticBeanBuildItem> syntheticBean,
Set<String> injectedOidcClientNames,
OidcClientRecorder recorder, OidcClients clients) {
Set<String> injectedOidcClientNames, OidcClientRecorder recorder) {
injectedOidcClientNames.stream()
.map(clientName -> syntheticNamedOidcClientBeanFor(clientName, recorder, clients))
.map(clientName -> syntheticNamedOidcClientBeanFor(clientName, recorder))
.forEach(syntheticBean::produce);
}

private SyntheticBeanBuildItem syntheticNamedOidcClientBeanFor(String clientName, OidcClientRecorder recorder,
OidcClients clients) {
private SyntheticBeanBuildItem syntheticNamedOidcClientBeanFor(String clientName, OidcClientRecorder recorder) {
return SyntheticBeanBuildItem.configure(OidcClient.class).unremovable()
.types(OidcClient.class)
.supplier(recorder.createOidcClientBean(clients, clientName))
.supplier(recorder.createOidcClientBean(clientName))
.scope(Singleton.class)
.addQualifier().annotation(NamedOidcClient.class).addValue("value", clientName).done()
.setRuntimeInit()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,22 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import java.lang.reflect.Method;
import java.util.function.Consumer;

import jakarta.enterprise.context.ApplicationScoped;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.builder.BuildChainBuilder;
import io.quarkus.builder.BuildContext;
import io.quarkus.builder.BuildStep;
import io.quarkus.credentials.CredentialsProvider;
import io.quarkus.deployment.builditem.MainBytecodeRecorderBuildItem;
import io.quarkus.deployment.recording.BytecodeRecorderImpl;
import io.quarkus.runtime.annotations.Recorder;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.WithTestResource;
import io.restassured.RestAssured;
Expand All @@ -17,14 +30,17 @@ public class OidcClientCredentialsJwtSecretTestCase {
private static Class<?>[] testClasses = {
OidcClientsResource.class,
ProtectedResource.class,
SecretProvider.class
RuntimeSecretProvider.class,
TestRecorder.class,
OidcClientCredentialsJwtSecretTestCase.class
};

@RegisterExtension
static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot((jar) -> jar
.addClasses(testClasses)
.addAsResource("application-oidc-client-credentials-jwt-secret.properties", "application.properties"));
.addAsResource("application-oidc-client-credentials-jwt-secret.properties", "application.properties"))
.addBuildChainCustomizer(buildCustomizer());

@Test
public void testGetTokenJwtClient() {
Expand Down Expand Up @@ -53,4 +69,51 @@ private static void assertTokensNotNull(String[] tokens) {
assertNotNull(tokens[0]);
assertEquals("null", tokens[1]);
}

@Recorder
public static class TestRecorder {

public RuntimeSecretProvider createRuntimeSecretProvider() {
return new RuntimeSecretProvider();
}

}

private static Consumer<BuildChainBuilder> buildCustomizer() {
// whole purpose of this step is to have a bean recorded during runtime init
return new Consumer<BuildChainBuilder>() {

@Override
public void accept(BuildChainBuilder builder) {
builder.addBuildStep(new BuildStep() {
@Override
public void execute(BuildContext context) {
BytecodeRecorderImpl bytecodeRecorder = new BytecodeRecorderImpl(false,
TestRecorder.class.getSimpleName(), "createRuntimeSecretProvider",
"" + TestRecorder.class.hashCode(), true, s -> null);
context.produce(new MainBytecodeRecorderBuildItem(bytecodeRecorder));

// We need to use reflection due to some class loading problems
Object recorderProxy = bytecodeRecorder.getRecordingProxy(TestRecorder.class);
try {
Method creator = recorderProxy.getClass().getDeclaredMethod("createRuntimeSecretProvider");
Object proxy1 = creator.invoke(recorderProxy, new Object[] {});

context.produce(SyntheticBeanBuildItem
.configure(RuntimeSecretProvider.class)
.types(CredentialsProvider.class)
.scope(ApplicationScoped.class)
.setRuntimeInit()
.unremovable()
.runtimeProxy(proxy1)
.done());

} catch (Exception e) {
throw new RuntimeException(e);
}
}
}).produces(MainBytecodeRecorderBuildItem.class).produces(SyntheticBeanBuildItem.class).build();
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package io.quarkus.oidc.client;

import java.util.HashMap;
import java.util.Map;

import io.quarkus.credentials.CredentialsProvider;

public class RuntimeSecretProvider implements CredentialsProvider {

@Override
public Map<String, String> getCredentials(String credentialsProviderName) {
Map<String, String> creds = new HashMap<>();
creds.put("secret-from-vault-for-jwt",
"AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow");
return creds;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ public class SecretProvider implements CredentialsProvider {
public Map<String, String> getCredentials(String credentialsProviderName) {
Map<String, String> creds = new HashMap<>();
creds.put("secret-from-vault", "secret");
creds.put("secret-from-vault-for-jwt",
"AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow");
return creds;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
import java.util.function.Supplier;
import java.util.stream.Collectors;

import jakarta.enterprise.inject.CreationException;

import org.jboss.logging.Logger;

import io.quarkus.arc.Arc;
import io.quarkus.oidc.client.OidcClient;
import io.quarkus.oidc.client.OidcClientConfig;
import io.quarkus.oidc.client.OidcClientConfig.Grant;
Expand Down Expand Up @@ -40,7 +43,7 @@ public class OidcClientRecorder {
private static final String CLIENT_ID_ATTRIBUTE = "client-id";
private static final String DEFAULT_OIDC_CLIENT_ID = "Default";

public OidcClients setup(OidcClientsConfig oidcClientsConfig, Supplier<Vertx> vertx,
private static OidcClients setup(OidcClientsConfig oidcClientsConfig, Supplier<Vertx> vertx,
Supplier<TlsConfigurationRegistry> registrySupplier) {

String defaultClientId = oidcClientsConfig.defaultClient.getId().orElse(DEFAULT_OIDC_CLIENT_ID);
Expand All @@ -66,32 +69,33 @@ public Uni<OidcClient> apply(OidcClientConfig config) {
});
}

public Supplier<OidcClient> createOidcClientBean(OidcClients clients) {
public Supplier<OidcClient> createOidcClientBean() {
return new Supplier<OidcClient>() {

@Override
public OidcClient get() {
return clients.getClient();
return Arc.container().instance(OidcClients.class).get().getClient();
}
};
}

public Supplier<OidcClient> createOidcClientBean(OidcClients clients, String clientName) {
public Supplier<OidcClient> createOidcClientBean(String clientName) {
return new Supplier<OidcClient>() {

@Override
public OidcClient get() {
return clients.getClient(clientName);
return Arc.container().instance(OidcClients.class).get().getClient(clientName);
}
};
}

public Supplier<OidcClients> createOidcClientsBean(OidcClients clients) {
public Supplier<OidcClients> createOidcClientsBean(OidcClientsConfig oidcClientsConfig, Supplier<Vertx> vertx,
Supplier<TlsConfigurationRegistry> registrySupplier) {
return new Supplier<OidcClients>() {

@Override
public OidcClients get() {
return clients;
return setup(oidcClientsConfig, vertx, registrySupplier);
}
};
}
Expand Down Expand Up @@ -245,6 +249,19 @@ protected static OidcClientException toOidcClientException(String authServerUrlS
return new OidcClientException(OidcCommonUtils.formatConnectionErrorMessage(authServerUrlString), cause);
}

public void initOidcClients() {
try {
// makes sure that OIDC Clients are created at the latest when runtime synthetic beans are ready
Arc.container().instance(OidcClients.class).get();
} catch (CreationException wrapper) {
if (wrapper.getCause() instanceof RuntimeException runtimeException) {
// so that users see ConfigurationException etc. without noise
throw runtimeException;
}
throw wrapper;
}
}

private static class DisabledOidcClient implements OidcClient {
String message;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.quarkus.arc.deployment.InjectionPointTransformerBuildItem;
import io.quarkus.arc.deployment.QualifierRegistrarBuildItem;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.deployment.SyntheticBeansRuntimeInitBuildItem;
import io.quarkus.arc.processor.Annotations;
import io.quarkus.arc.processor.DotNames;
import io.quarkus.arc.processor.InjectionPointInfo;
Expand All @@ -42,6 +43,7 @@
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.BuildSteps;
import io.quarkus.deployment.annotations.Consume;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
Expand Down Expand Up @@ -73,7 +75,6 @@
import io.quarkus.oidc.runtime.OidcUtils;
import io.quarkus.oidc.runtime.TenantConfigBean;
import io.quarkus.oidc.runtime.providers.AzureAccessTokenCustomizer;
import io.quarkus.smallrye.context.deployment.ContextPropagationInitializedBuildItem;
import io.quarkus.tls.TlsRegistryBuildItem;
import io.quarkus.vertx.core.deployment.CoreVertxBuildItem;
import io.quarkus.vertx.http.deployment.EagerSecurityInterceptorBindingBuildItem;
Expand Down Expand Up @@ -274,18 +275,23 @@ public SyntheticBeanBuildItem setup(
OidcConfig config,
OidcRecorder recorder,
CoreVertxBuildItem vertxBuildItem,
TlsRegistryBuildItem tlsRegistryBuildItem,
// this is required for setup ordering: we need CP set up
ContextPropagationInitializedBuildItem cpInitializedBuildItem) {
TlsRegistryBuildItem tlsRegistryBuildItem) {
return SyntheticBeanBuildItem.configure(TenantConfigBean.class).unremovable().types(TenantConfigBean.class)
.supplier(recorder.setup(config, vertxBuildItem.getVertx(), tlsRegistryBuildItem.registry(),
detectUserInfoRequired(beanRegistration)))
.supplier(recorder.createTenantConfigBean(config, vertxBuildItem.getVertx(),
tlsRegistryBuildItem.registry(), detectUserInfoRequired(beanRegistration)))
.destroyer(TenantConfigBean.Destroyer.class)
.scope(Singleton.class) // this should have been @ApplicationScoped but fails for some reason
.setRuntimeInit()
.done();
}

@Consume(SyntheticBeansRuntimeInitBuildItem.class)
@Record(ExecutionTime.RUNTIME_INIT)
@BuildStep
void initTenantConfigBean(OidcRecorder recorder) {
recorder.initTenantConfigBean();
}

@BuildStep
@Record(ExecutionTime.STATIC_INIT)
public void registerTenantResolverInterceptor(Capabilities capabilities, OidcRecorder recorder,
Expand Down
Loading

0 comments on commit 69564e6

Please sign in to comment.