Skip to content

Commit

Permalink
moving logic from Azure OpenAI to the Langchain Core
Browse files Browse the repository at this point in the history
putting the resolve methods in the interface, not in the embedded Resolver class

updating documentation

formatting

added forgotten unremovable beans configuration
  • Loading branch information
csotiriou committed Jun 25, 2024
1 parent de7c673 commit 3c9bbb2
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import io.quarkiverse.langchain4j.deployment.items.SelectedModerationModelProviderBuildItem;
import io.quarkiverse.langchain4j.runtime.LangChain4jRecorder;
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider;
import io.quarkus.arc.deployment.BeanDiscoveryFinishedBuildItem;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.arc.processor.BeanStream;
Expand Down Expand Up @@ -353,6 +354,7 @@ public void cleanUp(LangChain4jRecorder recorder, ShutdownContextBuildItem shutd
@BuildStep
public void unremoveableBeans(BuildProducer<UnremovableBeanBuildItem> unremoveableProducer) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(ObjectMapper.class));
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(ModelAuthProvider.class));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package io.quarkiverse.langchain4j.runtime.auth;

import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.spi.CDI;

import io.quarkiverse.langchain4j.ModelName;

public interface ModelAuthProvider {
String getAuthorization(Input input);

interface Input {
String method();

URI uri();

Map<String, List<Object>> headers();
}

static Optional<ModelAuthProvider> resolve(String modelName) {
Instance<ModelAuthProvider> beanInstance = modelName == null
? CDI.current().select(ModelAuthProvider.class)
: CDI.current().select(ModelAuthProvider.class, ModelName.Literal.of(modelName));

//get the first one without causing a bean1 resolution exception
ModelAuthProvider authorizer = null;
for (var handle : beanInstance.handles()) {
authorizer = handle.get();
break;
}
return Optional.ofNullable(authorizer);
}

static Optional<ModelAuthProvider> resolve() {
return resolve(null);
}

}
6 changes: 3 additions & 3 deletions docs/modules/ROOT/pages/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -183,20 +183,20 @@ public class RequestFilter implements ResteasyReactiveClientRequestFilter {
----

==== Using `AuthProvider`
One can implement the `AuthProvider` interface and provide the implementation of the `getAuthorization` method.
One can implement the `ModelAuthProvider` interface and provide the implementation of the `getAuthorization` method.

This is useful when you need to provide different authorization headers for different OpenAI models. The `@Named` annotation can be used to specify the model name in this scenario.

[source,java]
----
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.openai.OpenAiRestApi;
import jakarta.enterprise.context.ApplicationScoped;
import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider;import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
@ApplicationScoped
@ModelName("my-model-name") //you can omit this if you have only one model or if you want to use the default model
public class TestClass implements OpenAiRestApi.AuthProvider {
public class TestClass implements ModelAuthProvider {
@Inject MyTokenProviderService tokenProviderService;
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

import com.knuddels.jtokkit.Encodings;

import io.quarkiverse.langchain4j.openai.OpenAiRestApi;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.builditem.IndexDependencyBuildItem;
Expand All @@ -29,11 +27,6 @@ void indexDependencies(BuildProducer<IndexDependencyBuildItem> producer) {
producer.produce(new IndexDependencyBuildItem("dev.ai4j", "openai4j"));
}

@BuildStep
UnremovableBeanBuildItem unremovableBeans() {
return UnremovableBeanBuildItem.beanTypes(OpenAiRestApi.AuthProvider.class);
}

@BuildStep
void nativeImageSupport(BuildProducer<NativeImageResourceBuildItem> resourcesProducer) {
registerJtokkitResources(resourcesProducer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import dev.ai4j.openai4j.moderation.ModerationRequest;
import dev.ai4j.openai4j.moderation.ModerationResponse;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider;
import io.quarkus.rest.client.reactive.ClientExceptionMapper;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
Expand Down Expand Up @@ -182,35 +183,25 @@ public boolean test(SseEvent<String> event) {
}
}

interface AuthProvider {
String getAuthorization(Input input);

interface Input {
String method();

URI uri();

MultivaluedMap<String, Object> headers();
}
}

class OpenAIRestAPIFilter implements ResteasyReactiveClientRequestFilter {
AuthProvider authorizer;
ModelAuthProvider authorizer;

public OpenAIRestAPIFilter(AuthProvider authorizer) {
public OpenAIRestAPIFilter(ModelAuthProvider authorizer) {
this.authorizer = authorizer;
}

@Override
public void filter(ResteasyReactiveClientRequestContext requestContext) {
requestContext.getHeaders().putSingle("Authorization", authorizer.getAuthorization(
new AuthInputImpl(requestContext.getMethod(), requestContext.getUri(), requestContext.getHeaders())));
requestContext
.getHeaders()
.putSingle("Authorization", authorizer.getAuthorization(new AuthInputImpl(requestContext.getMethod(),
requestContext.getUri(), requestContext.getHeaders())));
}

private record AuthInputImpl(
String method,
URI uri,
MultivaluedMap<String, Object> headers) implements AuthProvider.Input {
MultivaluedMap<String, Object> headers) implements ModelAuthProvider.Input {
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
import java.util.function.Consumer;
import java.util.function.Supplier;

import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.spi.CDI;
import jakarta.ws.rs.client.ClientRequestContext;
import jakarta.ws.rs.client.ClientRequestFilter;

Expand Down Expand Up @@ -43,7 +41,7 @@
import dev.ai4j.openai4j.moderation.ModerationResponse;
import dev.ai4j.openai4j.moderation.ModerationResult;
import dev.ai4j.openai4j.spi.OpenAiClientBuilderFactory;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.runtime.auth.ModelAuthProvider;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
Expand Down Expand Up @@ -116,21 +114,11 @@ public void filter(ClientRequestContext requestContext) {
});
}

OpenAiRestApi.AuthProvider authorizer = null;
Instance<OpenAiRestApi.AuthProvider> beanInstance = builder.configName == null
? CDI.current().select(OpenAiRestApi.AuthProvider.class)
: CDI.current().select(OpenAiRestApi.AuthProvider.class, ModelName.Literal.of(builder.configName));
ModelAuthProvider
.resolve(builder.configName)
.ifPresent(modelAuthProvider -> restApiBuilder
.register(new OpenAiRestApi.OpenAIRestAPIFilter(modelAuthProvider)));

//get the first one without causing a bean resolution exception
for (var handle : beanInstance.handles()) {
authorizer = handle.get();
break;
}

if (authorizer != null) {
var filterProvider = new OpenAiRestApi.OpenAIRestAPIFilter(authorizer);
restApiBuilder.register(filterProvider);
}
return restApiBuilder.build(OpenAiRestApi.class);
} catch (URISyntaxException e) {
throw new RuntimeException(e);
Expand Down

0 comments on commit 3c9bbb2

Please sign in to comment.