Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Vertex AI Gemini provider to use ModelAuthProvider #708

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-oidc-model-auth-provider-parent</artifactId>
<version>999-SNAPSHOT</version>
</parent>
<artifactId>quarkus-langchain4j-oidc-model-auth-provider-deployment</artifactId>
<name>Quarkus LangChain4j - OpenId Connect (OIDC) ModelAuthProvider - Deployment</name>
<dependencies>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-oidc-model-auth-provider</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core-deployment</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5-internal</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<annotationProcessorPaths>
<path>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-extension-processor</artifactId>
<version>${quarkus.version}</version>
</path>
</annotationProcessorPaths>
</configuration>
</plugin>
</plugins>
</build>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.quarkiverse.langchain4j.oidc.deployment;

import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME;

import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigRoot;
import io.smallrye.config.ConfigMapping;

@ConfigRoot(phase = BUILD_TIME)
@ConfigMapping(prefix = "quarkus.langchain4j.oidc-model-auth-provider")
public interface OidcModelAuthProviderBuildConfig {
/**
* Whether the OIDC ModelAuthProvider should be enabled
*/
@ConfigDocDefault("true")
Optional<Boolean> enabled();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.quarkiverse.langchain4j.oidc.deployment;

import java.util.function.BooleanSupplier;

import io.quarkiverse.langchain4j.oidc.runtime.OidcModelAuthProvider;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.BuildSteps;
import io.quarkus.deployment.builditem.FeatureBuildItem;

@BuildSteps(onlyIf = OidcModelAuthProviderProcessor.IsEnabled.class)
public class OidcModelAuthProviderProcessor {
private static final String FEATURE = "langchain4j-oidc-model-auth-provider";

@BuildStep
FeatureBuildItem feature() {
return new FeatureBuildItem(FEATURE);
}

@BuildStep
public void additionalBeans(BuildProducer<AdditionalBeanBuildItem> additionalBeans) {
AdditionalBeanBuildItem.Builder builder = AdditionalBeanBuildItem.builder().setUnremovable();
builder.addBeanClass(OidcModelAuthProvider.class);
additionalBeans.produce(builder.build());
}

public static class IsEnabled implements BooleanSupplier {
OidcModelAuthProviderBuildConfig config;

public boolean getAsBoolean() {
return config.enabled().orElse(true);
}
}
}
20 changes: 20 additions & 0 deletions model-auth-providers/oidc-model-auth-provider/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-parent</artifactId>
<version>999-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>
<artifactId>quarkus-langchain4j-oidc-model-auth-provider-parent</artifactId>
<name>Quarkus LangChain4j - OpenId Connect (OIDC) ModelAuthProvider - Parent</name>
<packaging>pom</packaging>

<modules>
<module>deployment</module>
<module>runtime</module>
</modules>


</project>
65 changes: 65 additions & 0 deletions model-auth-providers/oidc-model-auth-provider/runtime/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-oidc-model-auth-provider-parent</artifactId>
<version>999-SNAPSHOT</version>
</parent>
<artifactId>quarkus-langchain4j-oidc-model-auth-provider</artifactId>
<name>Quarkus LangChain4j - OpenId Connect (OIDC) ModelAuthProvider - Runtime</name>
<dependencies>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-arc</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus.security</groupId>
<artifactId>quarkus-security</artifactId>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5-internal</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
<plugins>
<plugin>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-extension-maven-plugin</artifactId>
<version>${quarkus.version}</version>
<executions>
<execution>
<phase>compile</phase>
<goals>
<goal>extension-descriptor</goal>
</goals>
<configuration>
<deployment>${project.groupId}:${project.artifactId}-deployment:${project.version}</deployment>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<annotationProcessorPaths>
<path>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-extension-processor</artifactId>
<version>${quarkus.version}</version>
</path>
</annotationProcessorPaths>
</configuration>
</plugin>
</plugins>
</build>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package io.quarkiverse.langchain4j.oidc.runtime;

import jakarta.enterprise.inject.Instance;
import jakarta.inject.Inject;

import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider;
import io.quarkus.security.credential.TokenCredential;

public class OidcModelAuthProvider implements ModelAuthProvider {
@Inject
Instance<TokenCredential> tokenCredential;

@Override
public String getAuthorization(Input input) {
return tokenCredential.isResolvable() ? "Bearer " + tokenCredential.get().getToken() : null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name: LangChain4j OpenId Connect (OIDC) ModelAuthProvider
artifact: ${project.groupId}:${project.artifactId}:${project.version}
description: Provides ModelAuthProvider which uses OIDC bearer or authorization code flow access tokens
metadata:
keywords:
- ai
- langchain4j
- oidc
- security
guide: "https://docs.quarkiverse.io/quarkus-langchain4j/dev/index.html"
categories:
- "security"
status: "experimental"

Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import com.github.tomakehurst.wiremock.verification.LoggedRequest;

import dev.langchain4j.model.chat.ChatLanguageModel;
import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider;
import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
import io.quarkiverse.langchain4j.vertexai.runtime.gemini.VertexAiGeminiChatLanguageModel;
import io.quarkiverse.langchain4j.vertexai.runtime.gemini.VertxAiGeminiRestApi;
import io.quarkus.arc.ClientProxy;
import io.quarkus.test.QuarkusUnitTest;

Expand Down Expand Up @@ -111,11 +111,13 @@ void test() {
}

@Singleton
public static class DummyAuthProvider implements VertxAiGeminiRestApi.AuthProvider {
public static class DummyAuthProvider implements ModelAuthProvider {

@Override
public String getBearerToken() {
return API_KEY;
public String getAuthorization(Input input) {
return "Bearer " + API_KEY;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@

import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.util.concurrent.ExecutorService;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;
import jakarta.inject.Inject;
import jakarta.ws.rs.BeanParam;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.core.MultivaluedMap;

import org.eclipse.microprofile.context.ManagedExecutor;
import org.eclipse.microprofile.rest.client.annotation.RegisterProvider;
import org.jboss.logging.Logger;
import org.jboss.resteasy.reactive.RestPath;
Expand All @@ -25,7 +29,9 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.auth.oauth2.GoogleCredentials;

import io.quarkus.arc.DefaultBean;
import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider;
import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider.Input;
import io.quarkiverse.langchain4j.vertexai.runtime.gemini.config.ChatModelConfig;
import io.quarkus.rest.client.reactive.jackson.ClientObjectMapper;
import io.vertx.core.Handler;
import io.vertx.core.MultiMap;
Expand Down Expand Up @@ -102,17 +108,10 @@ public ApiMetadata build() {
}
}

interface AuthProvider {

String getBearerToken();
}

@ApplicationScoped
@DefaultBean
class ApplicationDefaultAuthProvider implements AuthProvider {
class ApplicationDefaultAuthProvider implements ModelAuthProvider {

@Override
public String getBearerToken() {
public String getAuthorization(Input input) {
try {
var credentials = GoogleCredentials.getApplicationDefault();
credentials.refreshIfExpired();
Expand All @@ -126,11 +125,17 @@ public String getBearerToken() {
class TokenFilter implements ResteasyReactiveClientRequestFilter {

private final ExecutorService executorService;
private final AuthProvider authProvider;
private final ModelAuthProvider defaultAuthorizer;
private final ModelAuthProvider authorizer;

@Inject
Instance<ChatModelConfig> model;

public TokenFilter(ExecutorService executorService, AuthProvider authProvider) {
public TokenFilter(ManagedExecutor executorService) {
this.executorService = executorService;
this.authProvider = authProvider;
this.defaultAuthorizer = new ApplicationDefaultAuthProvider();
this.authorizer = ModelAuthProvider.resolve(
model != null && model.isResolvable() ? model.get().modelId() : null).orElse(null);
}

@Override
Expand All @@ -140,14 +145,25 @@ public void filter(ResteasyReactiveClientRequestContext context) {
@Override
public void run() {
try {
context.getHeaders().add("Authorization", "Bearer " + authProvider.getBearerToken());
final Input authInput = new AuthInputImpl(context.getMethod(), context.getUri(), context.getHeaders());
String authorization = authorizer != null ? authorizer.getAuthorization(authInput) : null;
if (authorization == null) {
authorization = defaultAuthorizer.getAuthorization(authInput);
}
context.getHeaders().add("Authorization", authorization);
context.resume();
} catch (Exception e) {
context.resume(e);
}
}
});
}

private record AuthInputImpl(
String method,
URI uri,
MultivaluedMap<String, Object> headers) implements ModelAuthProvider.Input {
}
}

class VertxAiClientLogger implements ClientLogger {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ interface VertexAiGeminiConfig {
Optional<String> baseUrl();

/**
* Whether to enable the integration. Defaults to {@code true}, which means requests are made to the Anthropic
* Whether to enable the integration. Defaults to {@code true}, which means requests are made to the Vertex AI Gemini
* provider.
* Set to {@code false} to disable all requests.
*/
Expand Down
3 changes: 3 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
<module>model-providers/vertex-ai-gemini</module>
<module>model-providers/watsonx</module>

<module>model-auth-providers/oidc-model-auth-provider</module>

<module>quarkus-integrations/websockets-next</module>

<module>rag/easy-rag</module>
Expand Down Expand Up @@ -199,6 +201,7 @@
<module>samples/review-triage</module>
<module>samples/fraud-detection</module>
<module>samples/secure-fraud-detection</module>
<module>samples/secure-vertex-ai-gemini-poem</module>
<module>samples/chatbot</module>
<module>samples/chatbot-easy-rag</module>
<module>samples/sql-chatbot</module>
Expand Down
Loading
Loading