Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce the ability to inject named models #258

Merged
merged 3 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.CHAT_MODEL;
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.EMBEDDING_MODEL;

import java.util.Optional;
import java.util.List;

import jakarta.enterprise.context.ApplicationScoped;

import org.jboss.jandex.AnnotationInstance;

import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.bam.runtime.BamRecorder;
import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig;
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem;
import io.quarkiverse.langchain4j.runtime.NamedModelUtil;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
Expand Down Expand Up @@ -49,31 +53,43 @@ public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem
@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
void generateBeans(BamRecorder recorder,
Optional<SelectedChatModelProviderBuildItem> selectedChatItem,
Optional<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
List<SelectedChatModelProviderBuildItem> selectedChatItem,
List<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
Langchain4jBamConfig config,
BuildProducer<SyntheticBeanBuildItem> beanProducer) {

if (selectedChatItem.isPresent() && PROVIDER.equals(selectedChatItem.get().getProvider())) {
beanProducer.produce(SyntheticBeanBuildItem
.configure(CHAT_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.chatModel(config))
.done());
for (var selected : selectedChatItem) {
if (PROVIDER.equals(selected.getProvider())) {
String modelName = selected.getModelName();
var builder = SyntheticBeanBuildItem
.configure(CHAT_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.chatModel(config, modelName));
addQualifierIfNecessary(builder, modelName);
beanProducer.produce(builder.done());
}
}

for (var selected : selectedEmbedding) {
if (PROVIDER.equals(selected.getProvider())) {
String modelName = selected.getModelName();
var builder = SyntheticBeanBuildItem
.configure(EMBEDDING_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.embeddingModel(config, modelName));
addQualifierIfNecessary(builder, modelName);
beanProducer.produce(builder.done());
}
}
}

if (selectedEmbedding.isPresent() && PROVIDER.equals(selectedEmbedding.get().getProvider())) {
beanProducer.produce(
SyntheticBeanBuildItem
.configure(EMBEDDING_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.embeddingModel(config))
.unremovable()
.done());
/**
 * Adds a {@code ModelName} qualifier carrying {@code modelName} to the synthetic bean being configured,
 * unless {@code NamedModelUtil.isDefault(modelName)} reports the name as the default one.
 *
 * @param builder the synthetic bean configurator that may receive the qualifier
 * @param modelName the model name obtained from the selected provider build item
 */
private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) {
    if (NamedModelUtil.isDefault(modelName)) {
        // The default model is registered without any qualifier.
        return;
    }
    var qualifier = AnnotationInstance.builder(ModelName.class)
            .add("value", modelName)
            .build();
    builder.addQualifier(qualifier);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,11 @@ interface NewAIService {
NewAIService service;

@Inject
Langchain4jBamConfig config;
Langchain4jBamConfig langchain4jBamConfig;

@Test
void chat() throws Exception {
var config = langchain4jBamConfig.defaultConfig();

var modelId = config.chatModel().modelId();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public class AllPropertiesTest {
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class));

@Inject
Langchain4jBamConfig config;
Langchain4jBamConfig langchain4jBamConfig;

@Inject
ChatLanguageModel model;
Expand All @@ -79,6 +79,7 @@ static void afterAll() {

@Test
void generate() throws Exception {
var config = langchain4jBamConfig.defaultConfig();

assertEquals(WireMockUtil.URL, config.baseUrl().get().toString());
assertEquals(WireMockUtil.API_KEY, config.apiKey());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class DefaultPropertiesTest {
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class));

@Inject
Langchain4jBamConfig config;
Langchain4jBamConfig langchain4jBamConfig;

@Inject
ChatLanguageModel model;
Expand All @@ -59,6 +59,7 @@ static void afterAll() {

@Test
void generate() throws Exception {
var config = langchain4jBamConfig.defaultConfig();

assertEquals(Duration.ofSeconds(10), config.timeout());
assertEquals("2024-01-10", config.version());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import io.quarkiverse.langchain4j.bam.BamException.Code;
import io.quarkiverse.langchain4j.bam.BamException.Reason;
import io.quarkiverse.langchain4j.bam.BamRestApi;
import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig;
import io.quarkus.test.QuarkusUnitTest;

public class HttpErrorTest {
Expand All @@ -36,9 +35,6 @@ public class HttpErrorTest {
static ObjectMapper mapper;
static WireMockUtil mockServers;

@Inject
Langchain4jBamConfig config;

@Inject
ChatLanguageModel model;

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,32 @@
import io.quarkiverse.langchain4j.bam.BamChatModel;
import io.quarkiverse.langchain4j.bam.BamEmbeddingModel;
import io.quarkiverse.langchain4j.bam.runtime.config.ChatModelConfig;
import io.quarkiverse.langchain4j.bam.runtime.config.EmbeddingModelConfig;
import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig;
import io.quarkiverse.langchain4j.runtime.NamedModelUtil;
import io.quarkus.runtime.annotations.Recorder;
import io.smallrye.config.ConfigValidationException;

@Recorder
public class BamRecorder {

public Supplier<?> chatModel(Langchain4jBamConfig runtimeConfig) {
ChatModelConfig chatModelConfig = runtimeConfig.chatModel();
private static final String DUMMY_KEY = "dummy";

public Supplier<?> chatModel(Langchain4jBamConfig runtimeConfig, String modelName) {
Langchain4jBamConfig.BamConfig bamConfig = correspondingBamConfig(runtimeConfig, modelName);
ChatModelConfig chatModelConfig = bamConfig.chatModel();
String apiKey = bamConfig.apiKey();
if (DUMMY_KEY.equals(apiKey)) {
throw new ConfigValidationException(createApiKeyConfigProblem(modelName));
}

var builder = BamChatModel.builder()
.accessToken(runtimeConfig.apiKey())
.timeout(runtimeConfig.timeout())
.logRequests(runtimeConfig.logRequests())
.logResponses(runtimeConfig.logResponses())
.accessToken(bamConfig.apiKey())
.timeout(bamConfig.timeout())
.logRequests(bamConfig.logRequests())
.logResponses(bamConfig.logResponses())
.modelId(chatModelConfig.modelId())
.version(runtimeConfig.version())
.version(bamConfig.version())
.decodingMethod(chatModelConfig.decodingMethod())
.minNewTokens(chatModelConfig.minNewTokens())
.maxNewTokens(chatModelConfig.maxNewTokens())
Expand All @@ -38,8 +48,8 @@ public Supplier<?> chatModel(Langchain4jBamConfig runtimeConfig) {
.truncateInputTokens(firstOrDefault(null, chatModelConfig.truncateInputTokens()))
.beamWidth(firstOrDefault(null, chatModelConfig.beamWidth()));

if (runtimeConfig.baseUrl().isPresent()) {
builder.url(runtimeConfig.baseUrl().get());
if (bamConfig.baseUrl().isPresent()) {
builder.url(bamConfig.baseUrl().get());
}

return new Supplier<>() {
Expand All @@ -50,18 +60,22 @@ public Object get() {
};
}

public Supplier<?> embeddingModel(Langchain4jBamConfig runtimeConfig) {

var embeddingModelConfig = runtimeConfig.embeddingModel();
public Supplier<?> embeddingModel(Langchain4jBamConfig runtimeConfig, String modelName) {
Langchain4jBamConfig.BamConfig bamConfig = correspondingBamConfig(runtimeConfig, modelName);
EmbeddingModelConfig embeddingModelConfig = bamConfig.embeddingModel();
String apiKey = bamConfig.apiKey();
if (DUMMY_KEY.equals(apiKey)) {
throw new ConfigValidationException(createApiKeyConfigProblem(modelName));
}

var builder = BamEmbeddingModel.builder()
.accessToken(runtimeConfig.apiKey())
.timeout(runtimeConfig.timeout())
.version(runtimeConfig.version())
.accessToken(bamConfig.apiKey())
.timeout(bamConfig.timeout())
.version(bamConfig.version())
.modelId(embeddingModelConfig.modelId());

if (runtimeConfig.baseUrl().isPresent()) {
builder.url(runtimeConfig.baseUrl().get());
if (bamConfig.baseUrl().isPresent()) {
builder.url(bamConfig.baseUrl().get());
}

return new Supplier<>() {
Expand All @@ -71,4 +85,28 @@ public Object get() {
}
};
}

/**
 * Selects the configuration group that applies to the given model name.
 *
 * @param runtimeConfig the root BAM runtime configuration
 * @param modelName the model name; when {@code NamedModelUtil.isDefault(modelName)} is true the default
 *        configuration is used, otherwise the entry registered under that name
 * @return the default configuration, or the result of looking up {@code modelName} in the named configurations
 */
private Langchain4jBamConfig.BamConfig correspondingBamConfig(Langchain4jBamConfig runtimeConfig, String modelName) {
    if (NamedModelUtil.isDefault(modelName)) {
        return runtimeConfig.defaultConfig();
    }
    // NOTE(review): a plain Map#get returns null for an unknown key; confirm the config mapping
    // guarantees an entry for every selected model name, since callers dereference the result directly.
    return runtimeConfig.namedConfig().get(modelName);
}

/**
 * Builds the validation problems reported when the {@code api-key} property is missing for a model.
 *
 * @param modelName the model name whose configuration lacks the API key
 * @return the problems produced by {@code createConfigProblems} for the {@code api-key} property
 */
private ConfigValidationException.Problem[] createApiKeyConfigProblem(String modelName) {
    final String apiKeyProperty = "api-key";
    return createConfigProblems(apiKeyProperty, modelName);
}

/**
 * Wraps the single problem for a missing configuration property into an array.
 *
 * @param key the trailing segment of the missing property name
 * @param modelName the model name the property belongs to
 * @return a one-element array holding the problem built by {@code createConfigProblem}
 */
private ConfigValidationException.Problem[] createConfigProblems(String key, String modelName) {
    ConfigValidationException.Problem problem = createConfigProblem(key, modelName);
    return new ConfigValidationException.Problem[] { problem };
}

/**
 * Creates the problem describing a required-but-missing BAM configuration property.
 * <p>
 * The reported property name is {@code quarkus.langchain4j.bam.<key>} for the default model and
 * {@code quarkus.langchain4j.bam.<modelName>.<key>} otherwise.
 *
 * @param key the trailing segment of the missing property name
 * @param modelName the model name used to decide whether a name segment is inserted
 * @return the problem carrying the formatted message
 */
private static ConfigValidationException.Problem createConfigProblem(String key, String modelName) {
    String separator;
    if (NamedModelUtil.isDefault(modelName)) {
        separator = ".";
    } else {
        separator = "." + modelName + ".";
    }
    String message = String.format(
            "SRCFG00014: The config property quarkus.langchain4j.bam%s%s is required but it could not be found in any config source",
            separator, key);
    return new ConfigValidationException.Problem(message);
}
}
Loading
Loading