Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Putting Generic Authentication flow to the Core #694

Merged
merged 1 commit into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.quarkiverse.langchain4j.deployment.items.SelectedModerationModelProviderBuildItem;
import io.quarkiverse.langchain4j.runtime.LangChain4jRecorder;
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider;
import io.quarkus.arc.deployment.BeanDiscoveryFinishedBuildItem;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.arc.processor.BeanStream;
Expand Down Expand Up @@ -353,6 +354,7 @@ public void cleanUp(LangChain4jRecorder recorder, ShutdownContextBuildItem shutd
@BuildStep
public void unremoveableBeans(BuildProducer<UnremovableBeanBuildItem> unremoveableProducer) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(ObjectMapper.class));
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(ModelAuthProvider.class));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package io.quarkiverse.langchain4j.runtime.auth;

import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.spi.CDI;

import io.quarkiverse.langchain4j.ModelName;

public interface ModelAuthProvider {
String getAuthorization(Input input);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @csotiriou @geoand, only a couple of questions/suggestions for you to consider.

I was thinking, can we expect situations going forward where local models will require a secure token access, and where a token is passed, for example, as a property ?

I guess one can say that if Input is null then do not return a Bearer prefix in the response... This is related to the next question...

The other question, so ModelAuthProvider is expected to return a complete HTTP Authorization header value, including the scheme, Bearer <token>, instead of only <token>. In #708, moving to ModelAuthProvider has an impact of having to update the test filter to include Bearer - I don't think it is a breaking change though, since the current AuthProvider is not public.

I wonder, should we expect providers returns the token only, and then filters, depending on a context (access to the remote or local model) will either add Bearer or not.

IMHO it might be a bit better, and a little bit simpler for custom auth providers too...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to worry about that now, it's very early days to be able to know exactly how systems using auth will evolve.
For now, I think the proposal in this PR is fine.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@geoand Sure, Ok then, this PR works well for #708 so it is good to go from my perspective, thanks


interface Input {
String method();

URI uri();

Map<String, List<Object>> headers();
}

static Optional<ModelAuthProvider> resolve(String modelName) {
Instance<ModelAuthProvider> beanInstance = modelName == null
? CDI.current().select(ModelAuthProvider.class)
: CDI.current().select(ModelAuthProvider.class, ModelName.Literal.of(modelName));

//get the first one without causing a bean1 resolution exception
ModelAuthProvider authorizer = null;
for (var handle : beanInstance.handles()) {
authorizer = handle.get();
break;
}
return Optional.ofNullable(authorizer);
}

static Optional<ModelAuthProvider> resolve() {
return resolve(null);
}

}
5 changes: 3 additions & 2 deletions docs/modules/ROOT/pages/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -183,20 +183,21 @@ public class RequestFilter implements ResteasyReactiveClientRequestFilter {
----

==== Using `AuthProvider`
One can implement the `AuthProvider` interface and provide the implementation of the `getAuthorization` method.
One can implement the `ModelAuthProvider` interface and provide the implementation of the `getAuthorization` method.

This is useful when you need to provide different authorization headers for different OpenAI models. The `@Named` annotation can be used to specify the model name in this scenario.

[source,java]
----
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.openai.OpenAiRestApi;
import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

@ApplicationScoped
@ModelName("my-model-name") //you can omit this if you have only one model or if you want to use the default model
public class TestClass implements OpenAiRestApi.AuthProvider {
public class TestClass implements ModelAuthProvider {
@Inject MyTokenProviderService tokenProviderService;

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

import com.knuddels.jtokkit.Encodings;

import io.quarkiverse.langchain4j.openai.OpenAiRestApi;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.builditem.IndexDependencyBuildItem;
Expand All @@ -29,11 +27,6 @@ void indexDependencies(BuildProducer<IndexDependencyBuildItem> producer) {
producer.produce(new IndexDependencyBuildItem("dev.ai4j", "openai4j"));
}

@BuildStep
UnremovableBeanBuildItem unremovableBeans() {
return UnremovableBeanBuildItem.beanTypes(OpenAiRestApi.AuthProvider.class);
}

csotiriou marked this conversation as resolved.
Show resolved Hide resolved
@BuildStep
void nativeImageSupport(BuildProducer<NativeImageResourceBuildItem> resourcesProducer) {
registerJtokkitResources(resourcesProducer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import dev.ai4j.openai4j.moderation.ModerationRequest;
import dev.ai4j.openai4j.moderation.ModerationResponse;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider;
import io.quarkus.rest.client.reactive.ClientExceptionMapper;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
Expand Down Expand Up @@ -182,35 +183,25 @@ public boolean test(SseEvent<String> event) {
}
}

interface AuthProvider {
String getAuthorization(Input input);

interface Input {
String method();

URI uri();

MultivaluedMap<String, Object> headers();
}
}

class OpenAIRestAPIFilter implements ResteasyReactiveClientRequestFilter {
AuthProvider authorizer;
ModelAuthProvider authorizer;

public OpenAIRestAPIFilter(AuthProvider authorizer) {
public OpenAIRestAPIFilter(ModelAuthProvider authorizer) {
Copy link
Contributor

@sberyozkin sberyozkin Jun 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO it should be aligned with how it is done for Vertex AI Gemini TokenFilter, to have ExecutorService injected as well and then call a custom provider from this ExecutorService as providers may want to use some blocking calls etc. Using ManagedExecutor is best as it retains the original RequestScope which lets users inject Quarkus API beans which require it, as proposed in #708.

We confirmed with Georgios that using ManagedExecutot helps in the secure-fraud-detection demo

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point!

this.authorizer = authorizer;
}

@Override
public void filter(ResteasyReactiveClientRequestContext requestContext) {
requestContext.getHeaders().putSingle("Authorization", authorizer.getAuthorization(
new AuthInputImpl(requestContext.getMethod(), requestContext.getUri(), requestContext.getHeaders())));
requestContext
.getHeaders()
.putSingle("Authorization", authorizer.getAuthorization(new AuthInputImpl(requestContext.getMethod(),
requestContext.getUri(), requestContext.getHeaders())));
}

private record AuthInputImpl(
String method,
URI uri,
MultivaluedMap<String, Object> headers) implements AuthProvider.Input {
MultivaluedMap<String, Object> headers) implements ModelAuthProvider.Input {
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
import java.util.function.Consumer;
import java.util.function.Supplier;

import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.spi.CDI;
import jakarta.ws.rs.client.ClientRequestContext;
import jakarta.ws.rs.client.ClientRequestFilter;

Expand Down Expand Up @@ -43,7 +41,7 @@
import dev.ai4j.openai4j.moderation.ModerationResponse;
import dev.ai4j.openai4j.moderation.ModerationResult;
import dev.ai4j.openai4j.spi.OpenAiClientBuilderFactory;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
Expand Down Expand Up @@ -116,21 +114,11 @@ public void filter(ClientRequestContext requestContext) {
});
}

OpenAiRestApi.AuthProvider authorizer = null;
Instance<OpenAiRestApi.AuthProvider> beanInstance = builder.configName == null
? CDI.current().select(OpenAiRestApi.AuthProvider.class)
: CDI.current().select(OpenAiRestApi.AuthProvider.class, ModelName.Literal.of(builder.configName));
ModelAuthProvider
.resolve(builder.configName)
.ifPresent(modelAuthProvider -> restApiBuilder
.register(new OpenAiRestApi.OpenAIRestAPIFilter(modelAuthProvider)));

//get the first one without causing a bean resolution exception
for (var handle : beanInstance.handles()) {
authorizer = handle.get();
break;
}

if (authorizer != null) {
var filterProvider = new OpenAiRestApi.OpenAIRestAPIFilter(authorizer);
restApiBuilder.register(filterProvider);
}
return restApiBuilder.build(OpenAiRestApi.class);
} catch (URISyntaxException e) {
throw new RuntimeException(e);
Expand Down
Loading