diff --git a/bam/deployment/src/main/java/io/quarkiverse/langchain4j/bam/deployment/BamProcessor.java b/bam/deployment/src/main/java/io/quarkiverse/langchain4j/bam/deployment/BamProcessor.java index 180f67162..988ef7f7a 100644 --- a/bam/deployment/src/main/java/io/quarkiverse/langchain4j/bam/deployment/BamProcessor.java +++ b/bam/deployment/src/main/java/io/quarkiverse/langchain4j/bam/deployment/BamProcessor.java @@ -3,16 +3,20 @@ import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.CHAT_MODEL; import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.EMBEDDING_MODEL; -import java.util.Optional; +import java.util.List; import jakarta.enterprise.context.ApplicationScoped; +import org.jboss.jandex.AnnotationInstance; + +import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.bam.runtime.BamRecorder; import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig; import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem; import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.deployment.annotations.BuildProducer; import io.quarkus.deployment.annotations.BuildStep; @@ -49,31 +53,43 @@ public void providerCandidates(BuildProducer selectedChatItem, - Optional selectedEmbedding, + List selectedChatItem, + List selectedEmbedding, Langchain4jBamConfig config, BuildProducer beanProducer) { - if (selectedChatItem.isPresent() && PROVIDER.equals(selectedChatItem.get().getProvider())) { - beanProducer.produce(SyntheticBeanBuildItem - .configure(CHAT_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - 
.supplier(recorder.chatModel(config)) - .done()); + for (var selected : selectedChatItem) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.chatModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + } + } + + for (var selected : selectedEmbedding) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(EMBEDDING_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.embeddingModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + } } + } - if (selectedEmbedding.isPresent() && PROVIDER.equals(selectedEmbedding.get().getProvider())) { - beanProducer.produce( - SyntheticBeanBuildItem - .configure(EMBEDDING_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(recorder.embeddingModel(config)) - .unremovable() - .done()); + private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) { + if (!NamedModelUtil.isDefault(modelName)) { + builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", modelName).build()); } } } diff --git a/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/AiServiceTest.java b/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/AiServiceTest.java index 4dc43f6b7..171aec57a 100644 --- a/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/AiServiceTest.java +++ b/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/AiServiceTest.java @@ -66,10 +66,11 @@ interface NewAIService { NewAIService service; @Inject - 
Langchain4jBamConfig config; + Langchain4jBamConfig langchain4jBamConfig; @Test void chat() throws Exception { + var config = langchain4jBamConfig.defaultConfig(); var modelId = config.chatModel().modelId(); diff --git a/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/AllPropertiesTest.java b/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/AllPropertiesTest.java index 4bb53a6f4..24a49f26e 100644 --- a/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/AllPropertiesTest.java +++ b/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/AllPropertiesTest.java @@ -59,7 +59,7 @@ public class AllPropertiesTest { .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); @Inject - Langchain4jBamConfig config; + Langchain4jBamConfig langchain4jBamConfig; @Inject ChatLanguageModel model; @@ -79,6 +79,7 @@ static void afterAll() { @Test void generate() throws Exception { + var config = langchain4jBamConfig.defaultConfig(); assertEquals(WireMockUtil.URL, config.baseUrl().get().toString()); assertEquals(WireMockUtil.API_KEY, config.apiKey()); diff --git a/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/DefaultPropertiesTest.java b/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/DefaultPropertiesTest.java index 87832b1b3..c28762306 100644 --- a/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/DefaultPropertiesTest.java +++ b/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/DefaultPropertiesTest.java @@ -39,7 +39,7 @@ public class DefaultPropertiesTest { .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class)); @Inject - Langchain4jBamConfig config; + Langchain4jBamConfig langchain4jBamConfig; @Inject ChatLanguageModel model; @@ -59,6 +59,7 @@ static void afterAll() { @Test void generate() throws Exception { + var config = 
langchain4jBamConfig.defaultConfig(); assertEquals(Duration.ofSeconds(10), config.timeout()); assertEquals("2024-01-10", config.version()); diff --git a/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/HttpErrorTest.java b/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/HttpErrorTest.java index 0e6e15484..1257bae8b 100644 --- a/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/HttpErrorTest.java +++ b/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/HttpErrorTest.java @@ -27,7 +27,6 @@ import io.quarkiverse.langchain4j.bam.BamException.Code; import io.quarkiverse.langchain4j.bam.BamException.Reason; import io.quarkiverse.langchain4j.bam.BamRestApi; -import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig; import io.quarkus.test.QuarkusUnitTest; public class HttpErrorTest { @@ -36,9 +35,6 @@ public class HttpErrorTest { static ObjectMapper mapper; static WireMockUtil mockServers; - @Inject - Langchain4jBamConfig config; - @Inject ChatLanguageModel model; diff --git a/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/MissingPropertiesTest.java b/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/MissingPropertiesTest.java deleted file mode 100644 index 339be348c..000000000 --- a/bam/deployment/src/test/java/io/quarkiverse/langchain4j/bam/deployment/MissingPropertiesTest.java +++ /dev/null @@ -1,24 +0,0 @@ -package io.quarkiverse.langchain4j.bam.deployment; - -import static org.junit.jupiter.api.Assertions.fail; - -import org.jboss.shrinkwrap.api.ShrinkWrap; -import org.jboss.shrinkwrap.api.spec.JavaArchive; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; - -import io.quarkus.test.QuarkusUnitTest; -import io.smallrye.config.ConfigValidationException; - -public class MissingPropertiesTest { - - @RegisterExtension - static QuarkusUnitTest unitTest = new QuarkusUnitTest() - 
.setExpectedException(ConfigValidationException.class) - .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)); - - @Test - void test() { - fail("Should not be called"); - } -} diff --git a/bam/runtime/src/main/java/io/quarkiverse/langchain4j/bam/runtime/BamRecorder.java b/bam/runtime/src/main/java/io/quarkiverse/langchain4j/bam/runtime/BamRecorder.java index 286a253f9..a53d844d9 100644 --- a/bam/runtime/src/main/java/io/quarkiverse/langchain4j/bam/runtime/BamRecorder.java +++ b/bam/runtime/src/main/java/io/quarkiverse/langchain4j/bam/runtime/BamRecorder.java @@ -7,22 +7,32 @@ import io.quarkiverse.langchain4j.bam.BamChatModel; import io.quarkiverse.langchain4j.bam.BamEmbeddingModel; import io.quarkiverse.langchain4j.bam.runtime.config.ChatModelConfig; +import io.quarkiverse.langchain4j.bam.runtime.config.EmbeddingModelConfig; import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.runtime.annotations.Recorder; +import io.smallrye.config.ConfigValidationException; @Recorder public class BamRecorder { - public Supplier chatModel(Langchain4jBamConfig runtimeConfig) { - ChatModelConfig chatModelConfig = runtimeConfig.chatModel(); + private static final String DUMMY_KEY = "dummy"; + + public Supplier chatModel(Langchain4jBamConfig runtimeConfig, String modelName) { + Langchain4jBamConfig.BamConfig bamConfig = correspondingBamConfig(runtimeConfig, modelName); + ChatModelConfig chatModelConfig = bamConfig.chatModel(); + String apiKey = bamConfig.apiKey(); + if (DUMMY_KEY.equals(apiKey)) { + throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); + } var builder = BamChatModel.builder() - .accessToken(runtimeConfig.apiKey()) - .timeout(runtimeConfig.timeout()) - .logRequests(runtimeConfig.logRequests()) - .logResponses(runtimeConfig.logResponses()) + .accessToken(bamConfig.apiKey()) + .timeout(bamConfig.timeout()) + 
.logRequests(bamConfig.logRequests()) + .logResponses(bamConfig.logResponses()) .modelId(chatModelConfig.modelId()) - .version(runtimeConfig.version()) + .version(bamConfig.version()) .decodingMethod(chatModelConfig.decodingMethod()) .minNewTokens(chatModelConfig.minNewTokens()) .maxNewTokens(chatModelConfig.maxNewTokens()) @@ -38,8 +48,8 @@ public Supplier chatModel(Langchain4jBamConfig runtimeConfig) { .truncateInputTokens(firstOrDefault(null, chatModelConfig.truncateInputTokens())) .beamWidth(firstOrDefault(null, chatModelConfig.beamWidth())); - if (runtimeConfig.baseUrl().isPresent()) { - builder.url(runtimeConfig.baseUrl().get()); + if (bamConfig.baseUrl().isPresent()) { + builder.url(bamConfig.baseUrl().get()); } return new Supplier<>() { @@ -50,18 +60,22 @@ public Object get() { }; } - public Supplier embeddingModel(Langchain4jBamConfig runtimeConfig) { - - var embeddingModelConfig = runtimeConfig.embeddingModel(); + public Supplier embeddingModel(Langchain4jBamConfig runtimeConfig, String modelName) { + Langchain4jBamConfig.BamConfig bamConfig = correspondingBamConfig(runtimeConfig, modelName); + EmbeddingModelConfig embeddingModelConfig = bamConfig.embeddingModel(); + String apiKey = bamConfig.apiKey(); + if (DUMMY_KEY.equals(apiKey)) { + throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); + } var builder = BamEmbeddingModel.builder() - .accessToken(runtimeConfig.apiKey()) - .timeout(runtimeConfig.timeout()) - .version(runtimeConfig.version()) + .accessToken(bamConfig.apiKey()) + .timeout(bamConfig.timeout()) + .version(bamConfig.version()) .modelId(embeddingModelConfig.modelId()); - if (runtimeConfig.baseUrl().isPresent()) { - builder.url(runtimeConfig.baseUrl().get()); + if (bamConfig.baseUrl().isPresent()) { + builder.url(bamConfig.baseUrl().get()); } return new Supplier<>() { @@ -71,4 +85,28 @@ public Object get() { } }; } + + private Langchain4jBamConfig.BamConfig correspondingBamConfig(Langchain4jBamConfig runtimeConfig, String 
modelName) { + Langchain4jBamConfig.BamConfig bamConfig; + if (NamedModelUtil.isDefault(modelName)) { + bamConfig = runtimeConfig.defaultConfig(); + } else { + bamConfig = runtimeConfig.namedConfig().get(modelName); + } + return bamConfig; + } + + private ConfigValidationException.Problem[] createApiKeyConfigProblem(String modelName) { + return createConfigProblems("api-key", modelName); + } + + private ConfigValidationException.Problem[] createConfigProblems(String key, String modelName) { + return new ConfigValidationException.Problem[] { createConfigProblem(key, modelName) }; + } + + private static ConfigValidationException.Problem createConfigProblem(String key, String modelName) { + return new ConfigValidationException.Problem(String.format( + "SRCFG00014: The config property quarkus.langchain4j.bam%s%s is required but it could not be found in any config source", + NamedModelUtil.isDefault(modelName) ? "." : ("." + modelName + "."), key)); + } } diff --git a/bam/runtime/src/main/java/io/quarkiverse/langchain4j/bam/runtime/config/Langchain4jBamConfig.java b/bam/runtime/src/main/java/io/quarkiverse/langchain4j/bam/runtime/config/Langchain4jBamConfig.java index c13bf59c5..3d1c698ce 100644 --- a/bam/runtime/src/main/java/io/quarkiverse/langchain4j/bam/runtime/config/Langchain4jBamConfig.java +++ b/bam/runtime/src/main/java/io/quarkiverse/langchain4j/bam/runtime/config/Langchain4jBamConfig.java @@ -4,59 +4,84 @@ import java.net.URL; import java.time.Duration; +import java.util.Map; import java.util.Optional; import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigDocMapKey; +import io.quarkus.runtime.annotations.ConfigDocSection; +import io.quarkus.runtime.annotations.ConfigGroup; import io.quarkus.runtime.annotations.ConfigRoot; import io.smallrye.config.ConfigMapping; import io.smallrye.config.WithDefault; +import io.smallrye.config.WithDefaults; +import io.smallrye.config.WithParentName; @ConfigRoot(phase = RUN_TIME) 
@ConfigMapping(prefix = "quarkus.langchain4j.bam") public interface Langchain4jBamConfig { /** - * Base URL where the Ollama serving is running + * Default model config. */ - @ConfigDocDefault("https://bam-api.res.ibm.com") - Optional baseUrl(); + @WithParentName + BamConfig defaultConfig(); /** - * BAM API key + * Named model config. */ - String apiKey(); + @ConfigDocSection + @ConfigDocMapKey("model-name") + @WithParentName + @WithDefaults + Map namedConfig(); - /** - * Timeout for BAM calls - */ - @WithDefault("10s") - Duration timeout(); + @ConfigGroup + interface BamConfig { + /** + * Base URL where the BAM server is running + */ + @ConfigDocDefault("https://bam-api.res.ibm.com") + Optional baseUrl(); - /** - * Version to use - */ - @WithDefault("2024-01-10") - String version(); + /** + * BAM API key + */ + @WithDefault("dummy") // TODO: this should be optional but Smallrye Config doesn't like it + String apiKey(); - /** - * Whether the BAM client should log requests - */ - @WithDefault("false") - Boolean logRequests(); + /** + * Timeout for BAM calls + */ + @WithDefault("10s") + Duration timeout(); - /** - * Whether the BAM client should log responses - */ - @WithDefault("false") - Boolean logResponses(); + /** + * Version to use + */ + @WithDefault("2024-01-10") + String version(); - /** - * Chat model related settings - */ - ChatModelConfig chatModel(); + /** + * Whether the BAM client should log requests + */ + @WithDefault("false") + Boolean logRequests(); - /** - * Embedding model related settings - */ - EmbeddingModelConfig embeddingModel(); + /** + * Whether the BAM client should log responses + */ + @WithDefault("false") + Boolean logResponses(); + + /** + * Chat model related settings + */ + ChatModelConfig chatModel(); + + /** + * Embedding model related settings + */ + EmbeddingModelConfig embeddingModel(); + } } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java 
b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index 1972471d9..ba2049726 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -48,8 +48,10 @@ import dev.langchain4j.exception.IllegalConfigurationException; import dev.langchain4j.service.V; +import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; import io.quarkiverse.langchain4j.runtime.AiServicesRecorder; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo; import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo; import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport; @@ -120,6 +122,7 @@ public class AiServicesProcessor { public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class); private static final String[] EMPTY_STRING_ARRAY = new String[0]; private static final String METRICS_DEFAULT_NAME = "langchain4j.aiservices"; + public static final ClassType CHAT_MODEL_CLASS_TYPE = ClassType.create(Langchain4jDotNames.CHAT_MODEL); @BuildStep public void nativeSupport(CombinedIndexBuildItem indexBuildItem, @@ -167,7 +170,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, BuildProducer reflectiveClassProducer) { IndexView index = indexBuildItem.getIndex(); - boolean needChatModelBean = false; + Set chatModelNames = new HashSet<>(); boolean needModerationModelBean = false; for (AnnotationInstance instance : index.getAnnotations(Langchain4jDotNames.REGISTER_AI_SERVICES)) { if (instance.target().kind() != AnnotationTarget.Kind.CLASS) { @@ -187,8 +190,16 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, } } + String modeName = 
NamedModelUtil.DEFAULT_NAME; if (chatLanguageModelSupplierClassDotName == null) { - needChatModelBean = true; + AnnotationValue modelNameValue = instance.value("modelName"); + if (modelNameValue != null) { + String modelNameValueStr = modelNameValue.asString(); + if ((modelNameValueStr != null) && !modelNameValueStr.isEmpty()) { + modeName = modelNameValueStr; + } + } + chatModelNames.add(modeName); } List toolDotNames = Collections.emptyList(); @@ -251,11 +262,12 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, retrieverClassDotName, auditServiceSupplierClassName, moderationModelSupplierClassName, - cdiScope)); + cdiScope, + modeName)); } - if (needChatModelBean) { - requestChatModelBeanProducer.produce(new RequestChatModelBeanBuildItem()); + for (String chatModelName : chatModelNames) { + requestChatModelBeanProducer.produce(new RequestChatModelBeanBuildItem(chatModelName)); } if (needModerationModelBean) { requestModerationModelBeanProducer.produce(new RequestModerationModelBeanBuildItem()); @@ -282,7 +294,7 @@ private void validateSupplierAndRegisterForReflection(DotName supplierDotName, I @Record(ExecutionTime.STATIC_INIT) public void handleDeclarativeServices(AiServicesRecorder recorder, List declarativeAiServiceItems, - Optional selectedChatModelProvider, + List selectedChatModelProvider, BuildProducer syntheticBeanProducer, BuildProducer unremoveableProducer) { @@ -319,6 +331,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, ? 
bi.getModerationModelSupplierDotName().toString() : null); + String chatModelName = bi.getChatModelName(); SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem .configure(QuarkusAiServiceContext.class) .createWith(recorder.createDeclarativeAiService( @@ -326,14 +339,20 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, toolClassNames, chatMemoryProviderSupplierClassName, retrieverClassName, auditServiceClassSupplierName, - moderationModelSupplierClassName))) + moderationModelSupplierClassName, chatModelName))) .setRuntimeInit() .addQualifier() .annotation(Langchain4jDotNames.QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER).addValue("value", serviceClassName) .done() .scope(Dependent.class); - if ((chatLanguageModelSupplierClassName == null) && selectedChatModelProvider.isPresent()) { // TODO: is second condition needed? - configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.CHAT_MODEL)); + if ((chatLanguageModelSupplierClassName == null) && !selectedChatModelProvider.isEmpty()) { + if (NamedModelUtil.isDefault(chatModelName)) { + configurator.addInjectionPoint(CHAT_MODEL_CLASS_TYPE); + } else { + configurator.addInjectionPoint(CHAT_MODEL_CLASS_TYPE, + AnnotationInstance.builder(ModelName.class).add("value", chatModelName).build()); + + } needsChatModelBean = true; } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/BeansProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/BeansProcessor.java index 52b053887..e82e9fb88 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/BeansProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/BeansProcessor.java @@ -3,21 +3,34 @@ import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.CHAT_MODEL; import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.EMBEDDING_MODEL; import static 
io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.IMAGE_MODEL; +import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.MODEL_NAME; import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.MODERATION_MODEL; import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.STREAMING_CHAT_MODEL; +import java.util.HashSet; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; -import org.apache.poi.ss.formula.functions.T; +import org.jboss.jandex.AnnotationInstance; import org.jboss.jandex.DotName; import com.fasterxml.jackson.databind.ObjectMapper; import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig; -import io.quarkiverse.langchain4j.deployment.items.*; +import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem; +import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem; +import io.quarkiverse.langchain4j.deployment.items.ImageModelProviderCandidateBuildItem; +import io.quarkiverse.langchain4j.deployment.items.InProcessEmbeddingBuildItem; +import io.quarkiverse.langchain4j.deployment.items.ModerationModelProviderCandidateBuildItem; +import io.quarkiverse.langchain4j.deployment.items.ProviderHolder; +import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; +import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem; +import io.quarkiverse.langchain4j.deployment.items.SelectedImageModelProviderBuildItem; +import io.quarkiverse.langchain4j.deployment.items.SelectedModerationModelProviderBuildItem; import io.quarkiverse.langchain4j.runtime.Langchain4jRecorder; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.arc.deployment.BeanDiscoveryFinishedBuildItem; import io.quarkus.arc.deployment.UnremovableBeanBuildItem; import io.quarkus.arc.processor.InjectionPointInfo; @@ -59,91 +72,172 @@ public void 
handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished BuildProducer selectedImageProducer, List inProcessEmbeddingBuildItems) { - boolean chatModelBeanRequested = false; - boolean streamingChatModelBeanRequested = false; - boolean embeddingModelBeanRequested = false; - boolean moderationModelBeanRequested = false; - boolean imageModelBeanRequested = false; + Set requestedChatModels = new HashSet<>(); + Set requestedStreamingChatModels = new HashSet<>(); + Set requestEmbeddingModels = new HashSet<>(); + Set requestedModerationModels = new HashSet<>(); + Set requestedImageModels = new HashSet<>(); + for (InjectionPointInfo ip : beanDiscoveryFinished.getInjectionPoints()) { DotName requiredName = ip.getRequiredType().name(); + String modelName = determineModelName(ip); if (CHAT_MODEL.equals(requiredName)) { - chatModelBeanRequested = true; + requestedChatModels.add(modelName); } else if (STREAMING_CHAT_MODEL.equals(requiredName)) { - streamingChatModelBeanRequested = true; + requestedStreamingChatModels.add(modelName); } else if (EMBEDDING_MODEL.equals(requiredName)) { - embeddingModelBeanRequested = true; + requestEmbeddingModels.add(modelName); } else if (MODERATION_MODEL.equals(requiredName)) { - moderationModelBeanRequested = true; + requestedModerationModels.add(modelName); } else if (IMAGE_MODEL.equals(requiredName)) { - imageModelBeanRequested = true; + requestedImageModels.add(modelName); } } - if (!requestChatModelBeanItems.isEmpty()) { - chatModelBeanRequested = true; + for (var bi : requestChatModelBeanItems) { + requestedChatModels.add(bi.getModelName()); } - if (!requestModerationModelBeanBuildItems.isEmpty()) { - moderationModelBeanRequested = true; + for (var bi : requestModerationModelBeanBuildItems) { + requestedModerationModels.add(bi.getModelName()); } - if (chatModelBeanRequested || streamingChatModelBeanRequested) { - selectedChatProducer.produce( - new SelectedChatModelProviderBuildItem( - selectProvider( - chatCandidateItems, - 
buildConfig.chatModel().provider(), - "ChatLanguageModel or StreamingChatLanguageModel", - "chat-model"))); + if (!requestedChatModels.isEmpty() || !requestedStreamingChatModels.isEmpty()) { + Set allChatModelNames = new HashSet<>(requestedChatModels); + allChatModelNames.addAll(requestedStreamingChatModels); + for (String modelName : allChatModelNames) { + Optional userSelectedProvider; + String configNamespace; + if (NamedModelUtil.isDefault(modelName)) { + userSelectedProvider = buildConfig.defaultConfig().chatModel().provider(); + configNamespace = "chat-model"; + } else { + if (buildConfig.namedConfig().containsKey(modelName)) { + userSelectedProvider = buildConfig.namedConfig().get(modelName).chatModel().provider(); + } else { + userSelectedProvider = Optional.empty(); + } + configNamespace = modelName + ".chat-model"; + } + + selectedChatProducer.produce( + new SelectedChatModelProviderBuildItem( + selectProvider( + chatCandidateItems, + userSelectedProvider, + "ChatLanguageModel or StreamingChatLanguageModel", + configNamespace), + modelName)); + } + } - if (embeddingModelBeanRequested) { + + for (String modelName : requestEmbeddingModels) { + Optional userSelectedProvider; + String configNamespace; + if (NamedModelUtil.isDefault(modelName)) { + userSelectedProvider = buildConfig.defaultConfig().embeddingModel().provider(); + configNamespace = "embedding-model"; + } else { + if (buildConfig.namedConfig().containsKey(modelName)) { + userSelectedProvider = buildConfig.namedConfig().get(modelName).embeddingModel().provider(); + } else { + userSelectedProvider = Optional.empty(); + } + configNamespace = modelName + ".embedding-model"; + } + selectedEmbeddingProducer.produce( new SelectedEmbeddingModelCandidateBuildItem( selectEmbeddingModelProvider( inProcessEmbeddingBuildItems, embeddingCandidateItems, - buildConfig.embeddingModel().provider(), + userSelectedProvider, "EmbeddingModel", - "embedding-model"))); + configNamespace), + modelName)); } - if 
(moderationModelBeanRequested) { + + for (String modelName : requestedModerationModels) { + Optional userSelectedProvider; + String configNamespace; + if (NamedModelUtil.isDefault(modelName)) { + userSelectedProvider = buildConfig.defaultConfig().moderationModel().provider(); + configNamespace = "moderation-model"; + } else { + if (buildConfig.namedConfig().containsKey(modelName)) { + userSelectedProvider = buildConfig.namedConfig().get(modelName).moderationModel().provider(); + } else { + userSelectedProvider = Optional.empty(); + } + configNamespace = modelName + ".moderation-model"; + } + selectedModerationProducer.produce( new SelectedModerationModelProviderBuildItem( selectProvider( moderationCandidateItems, - buildConfig.moderationModel().provider(), + userSelectedProvider, "ModerationModel", - "moderation-model"))); + configNamespace), + modelName)); } - if (imageModelBeanRequested) { + + for (String modelName : requestedImageModels) { + Optional userSelectedProvider; + String configNamespace; + if (NamedModelUtil.isDefault(modelName)) { + userSelectedProvider = buildConfig.defaultConfig().imageModel().provider(); + configNamespace = "image-model"; + } else { + if (buildConfig.namedConfig().containsKey(modelName)) { + userSelectedProvider = buildConfig.namedConfig().get(modelName).imageModel().provider(); + } else { + userSelectedProvider = Optional.empty(); + } + configNamespace = modelName + ".image-model"; + } + selectedImageProducer.produce( new SelectedImageModelProviderBuildItem( selectProvider( imageCandidateItems, - buildConfig.moderationModel().provider(), + userSelectedProvider, "ImageModel", - "image-model"))); + configNamespace), + modelName)); } } + private String determineModelName(InjectionPointInfo ip) { + AnnotationInstance modelNameInstance = ip.getRequiredQualifier(MODEL_NAME); + if (modelNameInstance != null) { + String value = modelNameInstance.value().asString(); + if ((value != null) && !value.isEmpty()) { + return value; + } + } + 
return NamedModelUtil.DEFAULT_NAME; + } + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") private String selectProvider( - List chatCandidateItems, + List candidateItems, Optional userSelectedProvider, - String requestedBeanName, + String beanType, String configNamespace) { - List availableProviders = chatCandidateItems.stream().map(ProviderHolder::getProvider) + List availableProviders = candidateItems.stream().map(ProviderHolder::getProvider) .collect(Collectors.toList()); if (availableProviders.isEmpty()) { throw new ConfigurationException(String.format( "A %s bean was requested, but no langchain4j providers were configured. Consider adding an extension like 'quarkus-langchain4j-openai'", - requestedBeanName)); + beanType)); } if (availableProviders.size() == 1) { // user has selected a provider, but it's not the one that is available if (userSelectedProvider.isPresent() && !availableProviders.get(0).equals(userSelectedProvider.get())) { throw new ConfigurationException(String.format( "A %s bean with provider=%s was requested was requested via configuration, but the only provider found on the classpath is %s.", - requestedBeanName, userSelectedProvider.get(), availableProviders.get(0))); + beanType, userSelectedProvider.get(), availableProviders.get(0))); } return availableProviders.get(0); } @@ -151,7 +245,7 @@ private String selectProvider( if (userSelectedProvider.isEmpty()) { throw new ConfigurationException(String.format( "A %s bean was requested, but since there are multiple available providers, the 'quarkus.langchain4j.%s.provider' needs to be set to one of the available options (%s).", - requestedBeanName, configNamespace, String.join(",", availableProviders))); + beanType, configNamespace, String.join(",", availableProviders))); } boolean matches = availableProviders.stream().anyMatch(ap -> ap.equals(userSelectedProvider.get())); if (matches) { @@ -159,7 +253,7 @@ private String selectProvider( } throw new ConfigurationException(String.format( 
"A %s bean was requested, but the value of 'quarkus.langchain4j.%s.provider' does not match any of the available options (%s).", - requestedBeanName, configNamespace, String.join(",", availableProviders))); + beanType, configNamespace, String.join(",", availableProviders))); } private String selectEmbeddingModelProvider( diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java index 1cb709b02..847ec7371 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java @@ -22,6 +22,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem { private final DotName auditServiceClassSupplierDotName; private final DotName moderationModelSupplierDotName; private final ScopeInfo cdiScope; + private final String chatModelName; public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName, List toolDotNames, @@ -29,7 +30,8 @@ public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languag DotName retrieverClassDotName, DotName auditServiceClassSupplierDotName, DotName moderationModelSupplierDotName, - ScopeInfo cdiScope) { + ScopeInfo cdiScope, + String chatModelName) { this.serviceClassInfo = serviceClassInfo; this.languageModelSupplierClassDotName = languageModelSupplierClassDotName; this.toolDotNames = toolDotNames; @@ -38,6 +40,7 @@ public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languag this.auditServiceClassSupplierDotName = auditServiceClassSupplierDotName; this.moderationModelSupplierDotName = moderationModelSupplierDotName; this.cdiScope = cdiScope; + this.chatModelName = chatModelName; } public ClassInfo getServiceClassInfo() { @@ -71,4 +74,8 
@@ public DotName getModerationModelSupplierDotName() { public ScopeInfo getCdiScope() { return cdiScope; } + + public String getChatModelName() { + return chatModelName; + } } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java index 77e07ab8a..44652016d 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java @@ -19,6 +19,7 @@ import dev.langchain4j.service.UserMessage; import dev.langchain4j.service.UserName; import io.quarkiverse.langchain4j.CreatedAware; +import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.audit.AuditService; import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContextQualifier; @@ -39,6 +40,8 @@ public class Langchain4jDotNames { static final DotName DESCRIPTION = DotName.createSimple(Description.class); static final DotName STRUCTURED_PROMPT = DotName.createSimple(StructuredPrompt.class); static final DotName STRUCTURED_PROMPT_PROCESSOR = DotName.createSimple(StructuredPromptProcessor.class); + + static final DotName MODEL_NAME = DotName.createSimple(ModelName.class); static final DotName REGISTER_AI_SERVICES = DotName.createSimple(RegisterAiService.class); static final DotName BEAN_CHAT_MODEL_SUPPLIER = DotName.createSimple( diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestChatModelBeanBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestChatModelBeanBuildItem.java index 3a3ddbad0..1196afedf 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestChatModelBeanBuildItem.java +++ 
b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestChatModelBeanBuildItem.java @@ -7,4 +7,14 @@ * even if no injection point exists. */ public final class RequestChatModelBeanBuildItem extends MultiBuildItem { + + private final String modelName; + + public RequestChatModelBeanBuildItem(String modelName) { + this.modelName = modelName; + } + + public String getModelName() { + return modelName; + } } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestModerationModelBeanBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestModerationModelBeanBuildItem.java index da94cb365..446cef955 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestModerationModelBeanBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestModerationModelBeanBuildItem.java @@ -1,5 +1,6 @@ package io.quarkiverse.langchain4j.deployment; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.builder.item.MultiBuildItem; /** @@ -7,4 +8,15 @@ * even if no injection point exists. 
*/ public final class RequestModerationModelBeanBuildItem extends MultiBuildItem { + + private final String modelName; + + // TODO: this is in anticipation of actually needing a configurable moderation model + public RequestModerationModelBeanBuildItem() { + this.modelName = NamedModelUtil.DEFAULT_NAME; + } + + public String getModelName() { + return modelName; + } } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/ImageModelConfig.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/ImageModelConfig.java new file mode 100644 index 000000000..a73f8ce71 --- /dev/null +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/ImageModelConfig.java @@ -0,0 +1,14 @@ +package io.quarkiverse.langchain4j.deployment.config; + +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigGroup; + +@ConfigGroup +public interface ImageModelConfig { + + /** + * The model provider to use + */ + Optional provider(); +} diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/LangChain4jBuildConfig.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/LangChain4jBuildConfig.java index f12776669..f462b39c9 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/LangChain4jBuildConfig.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/config/LangChain4jBuildConfig.java @@ -2,25 +2,52 @@ import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME; +import java.util.Map; + +import io.quarkus.runtime.annotations.ConfigDocMapKey; +import io.quarkus.runtime.annotations.ConfigDocSection; import io.quarkus.runtime.annotations.ConfigRoot; import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithParentName; @ConfigRoot(phase = BUILD_TIME) @ConfigMapping(prefix = "quarkus.langchain4j") public interface LangChain4jBuildConfig { /** - * Chat model + * 
Default model config. */ - ChatModelConfig chatModel(); + @WithParentName + @ConfigDocSection + BaseConfig defaultConfig(); /** - * Embedding model + * Named model config. */ - EmbeddingModelConfig embeddingModel(); + @WithParentName + @ConfigDocMapKey("model-name") + @ConfigDocSection + Map namedConfig(); - /** - * Moderation model - */ - ModerationModelConfig moderationModel(); + interface BaseConfig { + /** + * Chat model + */ + ChatModelConfig chatModel(); + + /** + * Embedding model + */ + EmbeddingModelConfig embeddingModel(); + + /** + * Moderation model + */ + ModerationModelConfig moderationModel(); + + /** + * Image model + */ + ImageModelConfig imageModel(); + } } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedChatModelProviderBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedChatModelProviderBuildItem.java index 010187f7f..826ac2e3c 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedChatModelProviderBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedChatModelProviderBuildItem.java @@ -1,16 +1,22 @@ package io.quarkiverse.langchain4j.deployment.items; -import io.quarkus.builder.item.SimpleBuildItem; +import io.quarkus.builder.item.MultiBuildItem; -public final class SelectedChatModelProviderBuildItem extends SimpleBuildItem { +public final class SelectedChatModelProviderBuildItem extends MultiBuildItem { private final String provider; + private final String modelName; - public SelectedChatModelProviderBuildItem(String provider) { + public SelectedChatModelProviderBuildItem(String provider, String modelName) { this.provider = provider; + this.modelName = modelName; } public String getProvider() { return provider; } + + public String getModelName() { + return modelName; + } } diff --git 
a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedEmbeddingModelCandidateBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedEmbeddingModelCandidateBuildItem.java index 4ab22059c..156352800 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedEmbeddingModelCandidateBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedEmbeddingModelCandidateBuildItem.java @@ -1,16 +1,22 @@ package io.quarkiverse.langchain4j.deployment.items; -import io.quarkus.builder.item.SimpleBuildItem; +import io.quarkus.builder.item.MultiBuildItem; -public final class SelectedEmbeddingModelCandidateBuildItem extends SimpleBuildItem { +public final class SelectedEmbeddingModelCandidateBuildItem extends MultiBuildItem { private final String provider; + private final String modelName; - public SelectedEmbeddingModelCandidateBuildItem(String provider) { + public SelectedEmbeddingModelCandidateBuildItem(String provider, String modelName) { this.provider = provider; + this.modelName = modelName; } public String getProvider() { return provider; } + + public String getModelName() { + return modelName; + } } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedImageModelProviderBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedImageModelProviderBuildItem.java index ad9130029..398060107 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedImageModelProviderBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedImageModelProviderBuildItem.java @@ -1,16 +1,22 @@ package io.quarkiverse.langchain4j.deployment.items; -import io.quarkus.builder.item.SimpleBuildItem; +import io.quarkus.builder.item.MultiBuildItem; -public final class 
SelectedImageModelProviderBuildItem extends SimpleBuildItem { +public final class SelectedImageModelProviderBuildItem extends MultiBuildItem { private final String provider; + private final String modelName; - public SelectedImageModelProviderBuildItem(String provider) { + public SelectedImageModelProviderBuildItem(String provider, String modelName) { this.provider = provider; + this.modelName = modelName; } public String getProvider() { return provider; } + + public String getModelName() { + return modelName; + } } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedModerationModelProviderBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedModerationModelProviderBuildItem.java index b95cb5475..0e142356f 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedModerationModelProviderBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/SelectedModerationModelProviderBuildItem.java @@ -1,16 +1,22 @@ package io.quarkiverse.langchain4j.deployment.items; -import io.quarkus.builder.item.SimpleBuildItem; +import io.quarkus.builder.item.MultiBuildItem; -public final class SelectedModerationModelProviderBuildItem extends SimpleBuildItem { +public final class SelectedModerationModelProviderBuildItem extends MultiBuildItem { private final String provider; + private final String modelName; - public SelectedModerationModelProviderBuildItem(String provider) { + public SelectedModerationModelProviderBuildItem(String provider, String modelName) { this.provider = provider; + this.modelName = modelName; } public String getProvider() { return provider; } + + public String getModelName() { + return modelName; + } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/ModelName.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/ModelName.java new file mode 100644 index 000000000..e831410b5 --- 
/dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/ModelName.java @@ -0,0 +1,63 @@ +package io.quarkiverse.langchain4j; + +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import jakarta.enterprise.util.AnnotationLiteral; +import jakarta.inject.Qualifier; + +/** + * Marker annotation to select a named model + * Configure the {@code name} parameter to select the model instance. + *

+ * For example, when configuring OpenAI like so: + * + *

+ * quarkus.langchain4j.openai.somename.api-key=somekey
+ * 
+ * + * Then to inject the proper {@code ChatLanguageModel}, you would need to use {@code Model} like so: + * + *
+ *     @Inject
+ *     &#64;ModelName("somename")
+ *     ChatLanguageModel model;
+ * 
+ * + * For the case of {@link RegisterAiService}, instead of using this annotation, users should set the {@code modelName} property + * instead. + */ +@Target({ ElementType.TYPE, ElementType.METHOD, ElementType.FIELD, ElementType.PARAMETER }) +@Retention(RUNTIME) +@Documented +@Qualifier +public @interface ModelName { + /** + * Specify the cluster name of the connection. + * + * @return the value + */ + String value() default ""; + + class Literal extends AnnotationLiteral implements ModelName { + + public static Literal of(String value) { + return new Literal(value); + } + + private final String value; + + public Literal(String value) { + this.value = value; + } + + @Override + public String value() { + return value; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java index 344bf2e6f..1e0069c27 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java @@ -47,6 +47,20 @@ */ Class> chatLanguageModelSupplier() default BeanChatLanguageModelSupplier.class; + /** + * When {@code chatLanguageModelSupplier} is set to {@code BeanChatLanguageModelSupplier.class} (which is the default) + * this allows the selection of the {@link ChatLanguageModel} CDI bean to use. + *

+ * If not set, the default model (i.e. the one configured without setting the model name) is used. + * An example of the default model configuration is the following: + * {@code quarkus.langchain4j.openai.chat-model.model-name=gpt-4-turbo-preview} + * + * If set, it uses the model configured by name. For example if this is set to {@code somename} + * an example configuration value for that named model could be: + * {@code quarkus.langchain4j.somename.openai.chat-model.model-name=gpt-4-turbo-preview} + */ + String modelName() default ""; + /** * Tool classes to use. All tools are expected to be CDI beans. */ diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java index b9cdbc5e5..e1145147d 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java @@ -18,6 +18,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.moderation.ModerationModel; import dev.langchain4j.retriever.Retriever; +import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.audit.AuditService; import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo; @@ -85,7 +86,13 @@ public T apply(SyntheticCreationalContext creationalContext) { .getConstructor().newInstance(); quarkusAiServices.chatLanguageModel(supplier.get()); } else { - quarkusAiServices.chatLanguageModel(creationalContext.getInjectedReference(ChatLanguageModel.class)); + if (NamedModelUtil.isDefault(info.getChatModelName())) { + quarkusAiServices + .chatLanguageModel(creationalContext.getInjectedReference(ChatLanguageModel.class)); + } else { + quarkusAiServices.chatLanguageModel(creationalContext.getInjectedReference(ChatLanguageModel.class, + 
ModelName.Literal.of(info.getChatModelName()))); + } } List toolsClasses = info.getToolsClassNames(); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/NamedModelUtil.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/NamedModelUtil.java new file mode 100644 index 000000000..70190aeb8 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/NamedModelUtil.java @@ -0,0 +1,10 @@ +package io.quarkiverse.langchain4j.runtime; + +public class NamedModelUtil { + + public static final String DEFAULT_NAME = ""; + + public static boolean isDefault(String modelName) { + return DEFAULT_NAME.equals(modelName); + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java index 24eb20928..b6a229cb0 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java @@ -14,13 +14,15 @@ public class DeclarativeAiServiceCreateInfo { private final String auditServiceClassSupplierName; private final String moderationModelSupplierClassName; + private final String chatModelName; @RecordableConstructor public DeclarativeAiServiceCreateInfo(String serviceClassName, String languageModelSupplierClassName, List toolsClassNames, String chatMemoryProviderSupplierClassName, String retrieverClassName, String auditServiceClassSupplierName, - String moderationModelSupplierClassName) { + String moderationModelSupplierClassName, + String chatModelName) { this.serviceClassName = serviceClassName; this.languageModelSupplierClassName = languageModelSupplierClassName; this.toolsClassNames = toolsClassNames; @@ -28,6 +30,7 @@ public DeclarativeAiServiceCreateInfo(String serviceClassName, String 
languageMo this.retrieverClassName = retrieverClassName; this.auditServiceClassSupplierName = auditServiceClassSupplierName; this.moderationModelSupplierClassName = moderationModelSupplierClassName; + this.chatModelName = chatModelName; } public String getServiceClassName() { @@ -57,4 +60,8 @@ public String getAuditServiceClassSupplierName() { public String getModerationModelSupplierClassName() { return moderationModelSupplierClassName; } + + public String getChatModelName() { + return chatModelName; + } } diff --git a/docs/modules/ROOT/pages/ai-services.adoc b/docs/modules/ROOT/pages/ai-services.adoc index db6bd7dba..1e40228f6 100644 --- a/docs/modules/ROOT/pages/ai-services.adoc +++ b/docs/modules/ROOT/pages/ai-services.adoc @@ -141,34 +141,26 @@ We'll explore an alternative approach to avoid manual memory handling in the <` interface: +The configuration of the various models could look like so: -[source,java] +[source,properties,subs=attributes+] ---- -package io.quarkiverse.langchain4j.sample; - -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.openai.OpenAiChatModel; +# ensure that the model with the name 'm1', is provided by OpenAI +quarkus.langchain4j.m1.chat-model.provider=openai +# ensure that the model with the name 'm1', is provided by OpenAI +quarkus.langchain4j.m2.chat-model.provider=huggingface -import java.util.function.Supplier; - -public class MyChatModelSupplier implements Supplier { - @Override - public ChatLanguageModel get() { - return OpenAiChatModel.builder() - .apiKey("...") - .build(); - } -} +# configure the various aspects of each model +quarkus.langchain4j.openai.m1.api-key=sk-... +quarkus.langchain4j.huggingface.m2.api-key=sk-... 
---- [#memory] diff --git a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-huggingface.adoc b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-huggingface.adoc index 3ae2b8422..c44800add 100644 --- a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-huggingface.adoc +++ b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-huggingface.adoc @@ -75,7 +75,7 @@ ifndef::add-copy-button-to-env-var[] Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE_API_KEY+++` endif::add-copy-button-to-env-var[] --|string -| +|`dummy` a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.timeout]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.timeout[quarkus.langchain4j.huggingface.timeout]` @@ -360,6 +360,311 @@ endif::add-copy-button-to-env-var[] --|boolean |`false` + +h|[[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.named-config-named-model-config]]link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.named-config-named-model-config[Named model config] + +h|Type +h|Default + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.api-key]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.api-key[quarkus.langchain4j.huggingface."model-name".api-key]` + + +[.description] +-- +HuggingFace API key + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__API_KEY+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__API_KEY+++` +endif::add-copy-button-to-env-var[] +--|string +|`dummy` + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.timeout]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.timeout[quarkus.langchain4j.huggingface."model-name".timeout]` + + +[.description] +-- +Timeout for 
HuggingFace calls + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__TIMEOUT+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__TIMEOUT+++` +endif::add-copy-button-to-env-var[] +--|link:https://docs.oracle.com/javase/8/docs/api/java/time/Duration.html[Duration] + link:#duration-note-anchor-{summaryTableId}[icon:question-circle[], title=More information about the Duration format] +|`10S` + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.inference-endpoint-url]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.inference-endpoint-url[quarkus.langchain4j.huggingface."model-name".chat-model.inference-endpoint-url]` + + +[.description] +-- +The URL of the inference endpoint for the chat model. + +When using Hugging Face with the inference API, the URL is `https://api-inference.huggingface.co/models/`, for example `https://api-inference.huggingface.co/models/google/flan-t5-small`. + +When using a deployed inference endpoint, the URL is the URL of the endpoint. When using a local hugging face model, the URL is the URL of the local model. 
+ +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_INFERENCE_ENDPOINT_URL+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_INFERENCE_ENDPOINT_URL+++` +endif::add-copy-button-to-env-var[] +--|link:https://docs.oracle.com/javase/8/docs/api/java/net/URL.html[URL] + +|`https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct` + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.temperature]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.temperature[quarkus.langchain4j.huggingface."model-name".chat-model.temperature]` + + +[.description] +-- +Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, 0 means always take the highest score, 100.0 is getting closer to uniform probability + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_TEMPERATURE+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_TEMPERATURE+++` +endif::add-copy-button-to-env-var[] +--|double +|`1.0` + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.max-new-tokens]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.max-new-tokens[quarkus.langchain4j.huggingface."model-name".chat-model.max-new-tokens]` + + +[.description] +-- +Int (0-250). The amount of new tokens to be generated, this does not include the input length it is a estimate of the size of generated text you want. 
Each new tokens slows down the request, so look for balance between response times and length of text generated + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_MAX_NEW_TOKENS+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_MAX_NEW_TOKENS+++` +endif::add-copy-button-to-env-var[] +--|int +| + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.return-full-text]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.return-full-text[quarkus.langchain4j.huggingface."model-name".chat-model.return-full-text]` + + +[.description] +-- +If set to `false`, the return results will not contain the original query making it easier for prompting + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_RETURN_FULL_TEXT+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_RETURN_FULL_TEXT+++` +endif::add-copy-button-to-env-var[] +--|boolean +| + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.wait-for-model]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.wait-for-model[quarkus.langchain4j.huggingface."model-name".chat-model.wait-for-model]` + + +[.description] +-- +If the model is not ready, wait for it instead of receiving 503. It limits the number of requests required to get your inference done. 
It is advised to only set this flag to true after receiving a 503 error as it will limit hanging in your application to known places + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_WAIT_FOR_MODEL+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_WAIT_FOR_MODEL+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`true` + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.do-sample]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.do-sample[quarkus.langchain4j.huggingface."model-name".chat-model.do-sample]` + + +[.description] +-- +Whether or not to use sampling ; use greedy decoding otherwise. + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_DO_SAMPLE+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_DO_SAMPLE+++` +endif::add-copy-button-to-env-var[] +--|boolean +| + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.top-k]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.top-k[quarkus.langchain4j.huggingface."model-name".chat-model.top-k]` + + +[.description] +-- +The number of highest probability vocabulary tokens to keep for top-k-filtering. 
+ +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_TOP_K+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_TOP_K+++` +endif::add-copy-button-to-env-var[] +--|int +| + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.top-p]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.top-p[quarkus.langchain4j.huggingface."model-name".chat-model.top-p]` + + +[.description] +-- +If set to less than `1`, only the most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_TOP_P+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_TOP_P+++` +endif::add-copy-button-to-env-var[] +--|double +| + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.repetition-penalty]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.repetition-penalty[quarkus.langchain4j.huggingface."model-name".chat-model.repetition-penalty]` + + +[.description] +-- +The parameter for repetition penalty. 1.0 means no penalty. See link:https://arxiv.org/pdf/1909.05858.pdf[this paper] for more details. 
+ +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_REPETITION_PENALTY+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_REPETITION_PENALTY+++` +endif::add-copy-button-to-env-var[] +--|double +| + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.log-requests]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.log-requests[quarkus.langchain4j.huggingface."model-name".chat-model.log-requests]` + + +[.description] +-- +Whether chat model requests should be logged + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_LOG_REQUESTS+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_LOG_REQUESTS+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.log-responses]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.chat-model.log-responses[quarkus.langchain4j.huggingface."model-name".chat-model.log-responses]` + + +[.description] +-- +Whether chat model responses should be logged + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_LOG_RESPONSES+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__CHAT_MODEL_LOG_RESPONSES+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| 
[[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.embedding-model.inference-endpoint-url]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.embedding-model.inference-endpoint-url[quarkus.langchain4j.huggingface."model-name".embedding-model.inference-endpoint-url]` + + +[.description] +-- +The URL of the inference endpoint for the embedding. + +When using Hugging Face with the inference API, the URL is `https://api-inference.huggingface.co/pipeline/feature-extraction/`, for example `https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/all-mpnet-base-v2`. + +When using a deployed inference endpoint, the URL is the URL of the endpoint. When using a local hugging face model, the URL is the URL of the local model. + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__EMBEDDING_MODEL_INFERENCE_ENDPOINT_URL+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__EMBEDDING_MODEL_INFERENCE_ENDPOINT_URL+++` +endif::add-copy-button-to-env-var[] +--|link:https://docs.oracle.com/javase/8/docs/api/java/net/URL.html[URL] + +|`https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/all-MiniLM-L6-v2` + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.embedding-model.wait-for-model]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.embedding-model.wait-for-model[quarkus.langchain4j.huggingface."model-name".embedding-model.wait-for-model]` + + +[.description] +-- +If the model is not ready, wait for it instead of receiving 503. It limits the number of requests required to get your inference done. 
It is advised to only set this flag to true after receiving a 503 error as it will limit hanging in your application to known places + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__EMBEDDING_MODEL_WAIT_FOR_MODEL+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__EMBEDDING_MODEL_WAIT_FOR_MODEL+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`true` + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.log-requests]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.log-requests[quarkus.langchain4j.huggingface."model-name".log-requests]` + + +[.description] +-- +Whether the HuggingFace client should log requests + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__LOG_REQUESTS+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__LOG_REQUESTS+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| [[quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.log-responses]]`link:#quarkus-langchain4j-huggingface_quarkus.langchain4j.huggingface.-model-name-.log-responses[quarkus.langchain4j.huggingface."model-name".log-responses]` + + +[.description] +-- +Whether the HuggingFace client should log responses + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__LOG_RESPONSES+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_HUGGINGFACE__MODEL_NAME__LOG_RESPONSES+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + |=== 
ifndef::no-duration-note[] [NOTE] diff --git a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-openai.adoc b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-openai.adoc index b352d064d..2cc80b8d7 100644 --- a/docs/modules/ROOT/pages/includes/quarkus-langchain4j-openai.adoc +++ b/docs/modules/ROOT/pages/includes/quarkus-langchain4j-openai.adoc @@ -109,7 +109,7 @@ ifndef::add-copy-button-to-env-var[] Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI_API_KEY+++` endif::add-copy-button-to-env-var[] --|string -| +|`dummy` a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.organization-id]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.organization-id[quarkus.langchain4j.openai.organization-id]` @@ -657,6 +657,591 @@ endif::add-copy-button-to-env-var[] --|boolean |`false` + +h|[[quarkus-langchain4j-openai_quarkus.langchain4j.openai.named-config-named-model-config]]link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.named-config-named-model-config[Named model config] + +h|Type +h|Default + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.base-url]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.base-url[quarkus.langchain4j.openai."model-name".base-url]` + + +[.description] +-- +Base URL of OpenAI API + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__BASE_URL+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__BASE_URL+++` +endif::add-copy-button-to-env-var[] +--|string +|`https://api.openai.com/v1/` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.api-key]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.api-key[quarkus.langchain4j.openai."model-name".api-key]` + + +[.description] +-- +OpenAI API key + +ifdef::add-copy-button-to-env-var[] +Environment 
variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__API_KEY+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__API_KEY+++` +endif::add-copy-button-to-env-var[] +--|string +|`dummy` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.organization-id]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.organization-id[quarkus.langchain4j.openai."model-name".organization-id]` + + +[.description] +-- +OpenAI Organization ID (https://platform.openai.com/docs/api-reference/organization-optional) + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__ORGANIZATION_ID+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__ORGANIZATION_ID+++` +endif::add-copy-button-to-env-var[] +--|string +| + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.timeout]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.timeout[quarkus.langchain4j.openai."model-name".timeout]` + + +[.description] +-- +Timeout for OpenAI calls + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__TIMEOUT+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__TIMEOUT+++` +endif::add-copy-button-to-env-var[] +--|link:https://docs.oracle.com/javase/8/docs/api/java/time/Duration.html[Duration] + link:#duration-note-anchor-{summaryTableId}[icon:question-circle[], title=More information about the Duration format] +|`10S` + + +a| 
[[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.max-retries]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.max-retries[quarkus.langchain4j.openai."model-name".max-retries]` + + +[.description] +-- +The maximum number of times to retry + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__MAX_RETRIES+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__MAX_RETRIES+++` +endif::add-copy-button-to-env-var[] +--|int +|`3` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.log-requests]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.log-requests[quarkus.langchain4j.openai."model-name".log-requests]` + + +[.description] +-- +Whether the OpenAI client should log requests + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__LOG_REQUESTS+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__LOG_REQUESTS+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.log-responses]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.log-responses[quarkus.langchain4j.openai."model-name".log-responses]` + + +[.description] +-- +Whether the OpenAI client should log responses + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__LOG_RESPONSES+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__LOG_RESPONSES+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| 
[[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.model-name]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.model-name[quarkus.langchain4j.openai."model-name".chat-model.model-name]` + + +[.description] +-- +Model name to use + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_MODEL_NAME+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_MODEL_NAME+++` +endif::add-copy-button-to-env-var[] +--|string +|`gpt-3.5-turbo` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.temperature]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.temperature[quarkus.langchain4j.openai."model-name".chat-model.temperature]` + + +[.description] +-- +What sampling temperature to use, with values between 0 and 2. Higher values means the model will take more risks. A value of 0.9 is good for more creative applications, while 0 (argmax sampling) is good for ones with a well-defined answer. It is recommended to alter this or topP, but not both. 
+ +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_TEMPERATURE+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_TEMPERATURE+++` +endif::add-copy-button-to-env-var[] +--|double +|`1.0` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.top-p]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.top-p[quarkus.langchain4j.openai."model-name".chat-model.top-p]` + + +[.description] +-- +An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with topP probability mass. 0.1 means only the tokens comprising the top 10% probability mass are considered. It is recommended to alter this or temperature, but not both. + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_TOP_P+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_TOP_P+++` +endif::add-copy-button-to-env-var[] +--|double +|`1.0` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.max-tokens]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.max-tokens[quarkus.langchain4j.openai."model-name".chat-model.max-tokens]` + + +[.description] +-- +The maximum number of tokens to generate in the completion. The token count of your prompt plus max_tokens can't exceed the model's context length. Most models have a context length of 2048 tokens (except for the newest models, which support 4096). 
+ +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_MAX_TOKENS+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_MAX_TOKENS+++` +endif::add-copy-button-to-env-var[] +--|int +| + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.presence-penalty]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.presence-penalty[quarkus.langchain4j.openai."model-name".chat-model.presence-penalty]` + + +[.description] +-- +Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_PRESENCE_PENALTY+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_PRESENCE_PENALTY+++` +endif::add-copy-button-to-env-var[] +--|double +|`0` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.frequency-penalty]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.frequency-penalty[quarkus.langchain4j.openai."model-name".chat-model.frequency-penalty]` + + +[.description] +-- +Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. 
+ +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_FREQUENCY_PENALTY+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_FREQUENCY_PENALTY+++` +endif::add-copy-button-to-env-var[] +--|double +|`0` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.log-requests]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.log-requests[quarkus.langchain4j.openai."model-name".chat-model.log-requests]` + + +[.description] +-- +Whether chat model requests should be logged + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_LOG_REQUESTS+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_LOG_REQUESTS+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.log-responses]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.chat-model.log-responses[quarkus.langchain4j.openai."model-name".chat-model.log-responses]` + + +[.description] +-- +Whether chat model responses should be logged + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_LOG_RESPONSES+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__CHAT_MODEL_LOG_RESPONSES+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| 
[[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.embedding-model.model-name]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.embedding-model.model-name[quarkus.langchain4j.openai."model-name".embedding-model.model-name]` + + +[.description] +-- +Model name to use + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__EMBEDDING_MODEL_MODEL_NAME+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__EMBEDDING_MODEL_MODEL_NAME+++` +endif::add-copy-button-to-env-var[] +--|string +|`text-embedding-ada-002` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.embedding-model.log-requests]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.embedding-model.log-requests[quarkus.langchain4j.openai."model-name".embedding-model.log-requests]` + + +[.description] +-- +Whether embedding model requests should be logged + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__EMBEDDING_MODEL_LOG_REQUESTS+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__EMBEDDING_MODEL_LOG_REQUESTS+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.embedding-model.log-responses]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.embedding-model.log-responses[quarkus.langchain4j.openai."model-name".embedding-model.log-responses]` + + +[.description] +-- +Whether embedding model responses should be logged + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__EMBEDDING_MODEL_LOG_RESPONSES+++[] 
+endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__EMBEDDING_MODEL_LOG_RESPONSES+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.embedding-model.user]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.embedding-model.user[quarkus.langchain4j.openai."model-name".embedding-model.user]` + + +[.description] +-- +A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__EMBEDDING_MODEL_USER+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__EMBEDDING_MODEL_USER+++` +endif::add-copy-button-to-env-var[] +--|string +| + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.moderation-model.model-name]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.moderation-model.model-name[quarkus.langchain4j.openai."model-name".moderation-model.model-name]` + + +[.description] +-- +Model name to use + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__MODERATION_MODEL_MODEL_NAME+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__MODERATION_MODEL_MODEL_NAME+++` +endif::add-copy-button-to-env-var[] +--|string +|`text-moderation-latest` + + +a| 
[[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.moderation-model.log-requests]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.moderation-model.log-requests[quarkus.langchain4j.openai."model-name".moderation-model.log-requests]` + + +[.description] +-- +Whether moderation model requests should be logged + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__MODERATION_MODEL_LOG_REQUESTS+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__MODERATION_MODEL_LOG_REQUESTS+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.moderation-model.log-responses]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.moderation-model.log-responses[quarkus.langchain4j.openai."model-name".moderation-model.log-responses]` + + +[.description] +-- +Whether moderation model responses should be logged + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__MODERATION_MODEL_LOG_RESPONSES+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__MODERATION_MODEL_LOG_RESPONSES+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.model-name]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.model-name[quarkus.langchain4j.openai."model-name".image-model.model-name]` + + +[.description] +-- +Model name to use + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_MODEL_NAME+++[] 
+endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_MODEL_NAME+++` +endif::add-copy-button-to-env-var[] +--|string +|`dall-e-3` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.persist]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.persist[quarkus.langchain4j.openai."model-name".image-model.persist]` + + +[.description] +-- +Configure whether the generated images will be saved to disk. By default, persisting is disabled, but it is implicitly enabled when `quarkus.langchain4j.openai.image-model.persist-directory` is set and this property is not set to `false` + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_PERSIST+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_PERSIST+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.persist-directory]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.persist-directory[quarkus.langchain4j.openai."model-name".image-model.persist-directory]` + + +[.description] +-- +The path where the generated images will be persisted to disk. This only applies if `quarkus.langchain4j.openai.image-model.persist` is not set to `false`. 
+ +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_PERSIST_DIRECTORY+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_PERSIST_DIRECTORY+++` +endif::add-copy-button-to-env-var[] +--|path +|`${java.io.tmpdir}/dall-e-images` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.response-format]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.response-format[quarkus.langchain4j.openai."model-name".image-model.response-format]` + + +[.description] +-- +The format in which the generated images are returned. + +Must be one of `url` or `b64_json` + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_RESPONSE_FORMAT+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_RESPONSE_FORMAT+++` +endif::add-copy-button-to-env-var[] +--|string +|`url` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.size]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.size[quarkus.langchain4j.openai."model-name".image-model.size]` + + +[.description] +-- +The size of the generated images. + +Must be one of `1024x1024`, `1792x1024`, or `1024x1792` when the model is `dall-e-3`. + +Must be one of `256x256`, `512x512`, or `1024x1024` when the model is `dall-e-2`. 
+ +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_SIZE+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_SIZE+++` +endif::add-copy-button-to-env-var[] +--|string +|`1024x1024` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.quality]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.quality[quarkus.langchain4j.openai."model-name".image-model.quality]` + + +[.description] +-- +The quality of the image that will be generated. + +`hd` creates images with finer details and greater consistency across the image. + +This param is only supported for when the model is `dall-e-3`. + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_QUALITY+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_QUALITY+++` +endif::add-copy-button-to-env-var[] +--|string +|`standard` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.number]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.number[quarkus.langchain4j.openai."model-name".image-model.number]` + + +[.description] +-- +The number of images to generate. + +Must be between 1 and 10. + +When the model is dall-e-3, only n=1 is supported. 
+ +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_NUMBER+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_NUMBER+++` +endif::add-copy-button-to-env-var[] +--|int +|`1` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.style]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.style[quarkus.langchain4j.openai."model-name".image-model.style]` + + +[.description] +-- +The style of the generated images. + +Must be one of `vivid` or `natural`. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. + +This param is only supported for when the model is `dall-e-3`. + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_STYLE+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_STYLE+++` +endif::add-copy-button-to-env-var[] +--|string +|`vivid` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.user]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.user[quarkus.langchain4j.openai."model-name".image-model.user]` + + +[.description] +-- +A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. 
+ +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_USER+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_USER+++` +endif::add-copy-button-to-env-var[] +--|string +| + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.log-requests]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.log-requests[quarkus.langchain4j.openai."model-name".image-model.log-requests]` + + +[.description] +-- +Whether image model requests should be logged + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_LOG_REQUESTS+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_LOG_REQUESTS+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + + +a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.log-responses]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.-model-name-.image-model.log-responses[quarkus.langchain4j.openai."model-name".image-model.log-responses]` + + +[.description] +-- +Whether image model responses should be logged + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_LOG_RESPONSES+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI__MODEL_NAME__IMAGE_MODEL_LOG_RESPONSES+++` +endif::add-copy-button-to-env-var[] +--|boolean +|`false` + |=== ifndef::no-duration-note[] [NOTE] diff --git a/docs/modules/ROOT/pages/includes/quarkus-langchain4j.adoc 
b/docs/modules/ROOT/pages/includes/quarkus-langchain4j.adoc index eee803d16..4f042e988 100644 --- a/docs/modules/ROOT/pages/includes/quarkus-langchain4j.adoc +++ b/docs/modules/ROOT/pages/includes/quarkus-langchain4j.adoc @@ -5,7 +5,7 @@ icon:lock[title=Fixed at build time] Configuration property fixed at build time [.configuration-reference.searchable, cols="80,.^10,.^10"] |=== -h|[[quarkus-langchain4j_configuration]]link:#quarkus-langchain4j_configuration[Configuration property] +h|[[quarkus-langchain4j_quarkus.langchain4j.default-config-default-model-config]]link:#quarkus-langchain4j_quarkus.langchain4j.default-config-default-model-config[Default model config] h|Type h|Default @@ -60,4 +60,94 @@ endif::add-copy-button-to-env-var[] --|string | + +a|icon:lock[title=Fixed at build time] [[quarkus-langchain4j_quarkus.langchain4j.image-model.provider]]`link:#quarkus-langchain4j_quarkus.langchain4j.image-model.provider[quarkus.langchain4j.image-model.provider]` + + +[.description] +-- +The model provider to use + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_IMAGE_MODEL_PROVIDER+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J_IMAGE_MODEL_PROVIDER+++` +endif::add-copy-button-to-env-var[] +--|string +| + + +h|[[quarkus-langchain4j_quarkus.langchain4j.named-config-named-model-config]]link:#quarkus-langchain4j_quarkus.langchain4j.named-config-named-model-config[Named model config] + +h|Type +h|Default + +a|icon:lock[title=Fixed at build time] [[quarkus-langchain4j_quarkus.langchain4j.-model-name-.chat-model.provider]]`link:#quarkus-langchain4j_quarkus.langchain4j.-model-name-.chat-model.provider[quarkus.langchain4j."model-name".chat-model.provider]` + + +[.description] +-- +The model provider to use + +ifdef::add-copy-button-to-env-var[] +Environment variable: 
env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J__MODEL_NAME__CHAT_MODEL_PROVIDER+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J__MODEL_NAME__CHAT_MODEL_PROVIDER+++` +endif::add-copy-button-to-env-var[] +--|string +| + + +a|icon:lock[title=Fixed at build time] [[quarkus-langchain4j_quarkus.langchain4j.-model-name-.embedding-model.provider]]`link:#quarkus-langchain4j_quarkus.langchain4j.-model-name-.embedding-model.provider[quarkus.langchain4j."model-name".embedding-model.provider]` + + +[.description] +-- +The model provider to use + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J__MODEL_NAME__EMBEDDING_MODEL_PROVIDER+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J__MODEL_NAME__EMBEDDING_MODEL_PROVIDER+++` +endif::add-copy-button-to-env-var[] +--|string +| + + +a|icon:lock[title=Fixed at build time] [[quarkus-langchain4j_quarkus.langchain4j.-model-name-.moderation-model.provider]]`link:#quarkus-langchain4j_quarkus.langchain4j.-model-name-.moderation-model.provider[quarkus.langchain4j."model-name".moderation-model.provider]` + + +[.description] +-- +The model provider to use + +ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J__MODEL_NAME__MODERATION_MODEL_PROVIDER+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J__MODEL_NAME__MODERATION_MODEL_PROVIDER+++` +endif::add-copy-button-to-env-var[] +--|string +| + + +a|icon:lock[title=Fixed at build time] [[quarkus-langchain4j_quarkus.langchain4j.-model-name-.image-model.provider]]`link:#quarkus-langchain4j_quarkus.langchain4j.-model-name-.image-model.provider[quarkus.langchain4j."model-name".image-model.provider]` + + +[.description] +-- +The model provider to use + 
+ifdef::add-copy-button-to-env-var[] +Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J__MODEL_NAME__IMAGE_MODEL_PROVIDER+++[] +endif::add-copy-button-to-env-var[] +ifndef::add-copy-button-to-env-var[] +Environment variable: `+++QUARKUS_LANGCHAIN4J__MODEL_NAME__IMAGE_MODEL_PROVIDER+++` +endif::add-copy-button-to-env-var[] +--|string +| + |=== \ No newline at end of file diff --git a/hugging-face/deployment/src/main/java/io/quarkiverse/langchain4j/huggingface/deployment/HuggingFaceProcessor.java b/hugging-face/deployment/src/main/java/io/quarkiverse/langchain4j/huggingface/deployment/HuggingFaceProcessor.java index 0978f8ce1..3e83e3114 100644 --- a/hugging-face/deployment/src/main/java/io/quarkiverse/langchain4j/huggingface/deployment/HuggingFaceProcessor.java +++ b/hugging-face/deployment/src/main/java/io/quarkiverse/langchain4j/huggingface/deployment/HuggingFaceProcessor.java @@ -3,18 +3,20 @@ import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.CHAT_MODEL; import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.EMBEDDING_MODEL; -import java.util.Optional; -import java.util.function.Supplier; +import java.util.List; import jakarta.enterprise.context.ApplicationScoped; +import org.jboss.jandex.AnnotationInstance; + +import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem; import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem; -import io.quarkiverse.langchain4j.deployment.items.ModerationModelProviderCandidateBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem; import io.quarkiverse.langchain4j.huggingface.runtime.HuggingFaceRecorder; import io.quarkiverse.langchain4j.huggingface.runtime.config.Langchain4jHuggingFaceConfig; +import 
io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.deployment.annotations.BuildProducer; import io.quarkus.deployment.annotations.BuildStep; @@ -35,7 +37,6 @@ FeatureBuildItem feature() { @BuildStep public void providerCandidates(BuildProducer chatProducer, BuildProducer embeddingProducer, - BuildProducer moderationProducer, Langchain4jHuggingFaceBuildConfig config) { if (config.chatModel().enabled().isEmpty() || config.chatModel().enabled().get()) { chatProducer.produce(new ChatModelProviderCandidateBuildItem(PROVIDER)); @@ -43,40 +44,49 @@ public void providerCandidates(BuildProducer selectedChatItem, - Optional selectedEmbedding, + List selectedChatItem, + List selectedEmbedding, Langchain4jHuggingFaceConfig config, BuildProducer beanProducer) { - if (selectedChatItem.isPresent() && PROVIDER.equals(selectedChatItem.get().getProvider())) { - beanProducer.produce(SyntheticBeanBuildItem - .configure(CHAT_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(recorder.chatModel(config)) - .done()); + + for (var selected : selectedChatItem) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.chatModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + } } - if (selectedEmbedding.isPresent() && PROVIDER.equals(selectedEmbedding.get().getProvider())) { - Supplier supplier = recorder.embeddingModel(config); - if (supplier != null) { - beanProducer.produce(SyntheticBeanBuildItem + for (var selected : selectedEmbedding) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem .configure(EMBEDDING_MODEL) .setRuntimeInit() 
.defaultBean() .scope(ApplicationScoped.class) - .supplier(supplier) - .done()); + .supplier(recorder.embeddingModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); } } } + + private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) { + if (!NamedModelUtil.isDefault(modelName)) { + builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", modelName).build()); + } + } } diff --git a/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/HuggingFaceRecorder.java b/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/HuggingFaceRecorder.java index 644a1e713..bdf37b453 100644 --- a/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/HuggingFaceRecorder.java +++ b/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/HuggingFaceRecorder.java @@ -3,7 +3,6 @@ import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault; import java.net.URL; -import java.util.Optional; import java.util.function.Supplier; import io.quarkiverse.langchain4j.huggingface.QuarkusHuggingFaceChatModel; @@ -11,34 +10,41 @@ import io.quarkiverse.langchain4j.huggingface.runtime.config.ChatModelConfig; import io.quarkiverse.langchain4j.huggingface.runtime.config.EmbeddingModelConfig; import io.quarkiverse.langchain4j.huggingface.runtime.config.Langchain4jHuggingFaceConfig; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.runtime.annotations.Recorder; import io.smallrye.config.ConfigValidationException; @Recorder public class HuggingFaceRecorder { - public Supplier chatModel(Langchain4jHuggingFaceConfig runtimeConfig) { - Optional apiKeyOpt = runtimeConfig.apiKey(); - URL url = runtimeConfig.chatModel().inferenceEndpointUrl(); - if (apiKeyOpt.isEmpty() && 
url.toExternalForm().contains("api-inference.huggingface.co")) { // when using the default base URL an API key is required - throw new ConfigValidationException(createApiKeyConfigProblems()); + private static final String DUMMY_KEY = "dummy"; + private static final String HUGGING_FACE_URL_MARKER = "api-inference.huggingface.co"; + + public Supplier chatModel(Langchain4jHuggingFaceConfig runtimeConfig, String modelName) { + Langchain4jHuggingFaceConfig.HuggingFaceConfig huggingFaceConfig = correspondingHuggingFaceConfig(runtimeConfig, + modelName); + String apiKey = huggingFaceConfig.apiKey(); + ChatModelConfig chatModelConfig = huggingFaceConfig.chatModel(); + URL url = chatModelConfig.inferenceEndpointUrl(); + + if (DUMMY_KEY.equals(apiKey) && url.toExternalForm().contains(HUGGING_FACE_URL_MARKER)) { // when using the default base URL an API key is required + throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); } - ChatModelConfig chatModelConfig = runtimeConfig.chatModel(); var builder = QuarkusHuggingFaceChatModel.builder() .url(url) - .timeout(runtimeConfig.timeout()) + .timeout(huggingFaceConfig.timeout()) .temperature(chatModelConfig.temperature()) .waitForModel(chatModelConfig.waitForModel()) .doSample(chatModelConfig.doSample()) .topP(chatModelConfig.topP()) .topK(chatModelConfig.topK()) .repetitionPenalty(chatModelConfig.repetitionPenalty()) - .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests())) - .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses())); + .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), huggingFaceConfig.logRequests())) + .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), huggingFaceConfig.logResponses())); - if (apiKeyOpt.isPresent()) { - builder.accessToken(apiKeyOpt.get()); + if (!DUMMY_KEY.equals(apiKey)) { + builder.accessToken(apiKey); } if (chatModelConfig.returnFullText().isPresent()) { 
builder.returnFullText(chatModelConfig.returnFullText().get()); @@ -56,24 +62,24 @@ public Object get() { }; } - public Supplier embeddingModel(Langchain4jHuggingFaceConfig runtimeConfig) { - Optional apiKeyOpt = runtimeConfig.apiKey(); - EmbeddingModelConfig embeddingModelConfig = runtimeConfig.embeddingModel(); - Optional urlOpt = embeddingModelConfig.inferenceEndpointUrl(); - if (urlOpt.isEmpty()) { - return null; - } - if (apiKeyOpt.isEmpty() && urlOpt.isPresent() - && urlOpt.get().toExternalForm().contains("api-inference.huggingface.co")) { // when using the default base URL an API key is required - throw new ConfigValidationException(createApiKeyConfigProblems()); + public Supplier embeddingModel(Langchain4jHuggingFaceConfig runtimeConfig, String modelName) { + Langchain4jHuggingFaceConfig.HuggingFaceConfig huggingFaceConfig = correspondingHuggingFaceConfig(runtimeConfig, + modelName); + String apiKey = huggingFaceConfig.apiKey(); + EmbeddingModelConfig embeddingModelConfig = huggingFaceConfig.embeddingModel(); + URL url = embeddingModelConfig.inferenceEndpointUrl(); + + if (DUMMY_KEY.equals(apiKey) && url.toExternalForm().contains(HUGGING_FACE_URL_MARKER)) { // when using the default base URL an API key is required + throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); } + var builder = QuarkusHuggingFaceEmbeddingModel.builder() - .url(urlOpt.get()) - .timeout(runtimeConfig.timeout()) + .url(url) + .timeout(huggingFaceConfig.timeout()) .waitForModel(embeddingModelConfig.waitForModel()); - if (apiKeyOpt.isPresent()) { - builder.accessToken(apiKeyOpt.get()); + if (!DUMMY_KEY.equals(apiKey)) { + builder.accessToken(apiKey); } return new Supplier<>() { @@ -84,17 +90,28 @@ public Object get() { }; } - private ConfigValidationException.Problem[] createApiKeyConfigProblems() { - return createConfigProblems("api-key"); + private Langchain4jHuggingFaceConfig.HuggingFaceConfig correspondingHuggingFaceConfig( + Langchain4jHuggingFaceConfig 
runtimeConfig, String modelName) { + Langchain4jHuggingFaceConfig.HuggingFaceConfig huggingFaceConfig; + if (NamedModelUtil.isDefault(modelName)) { + huggingFaceConfig = runtimeConfig.defaultConfig(); + } else { + huggingFaceConfig = runtimeConfig.namedConfig().get(modelName); + } + return huggingFaceConfig; + } + + private ConfigValidationException.Problem[] createApiKeyConfigProblem(String modelName) { + return createConfigProblems("api-key", modelName); } - private ConfigValidationException.Problem[] createConfigProblems(String key) { - return new ConfigValidationException.Problem[] { createConfigProblem(key) }; + private ConfigValidationException.Problem[] createConfigProblems(String key, String modelName) { + return new ConfigValidationException.Problem[] { createConfigProblem(key, modelName) }; } - private ConfigValidationException.Problem createConfigProblem(String key) { + private static ConfigValidationException.Problem createConfigProblem(String key, String modelName) { return new ConfigValidationException.Problem(String.format( - "SRCFG00014: The config property quarkus.langchain4j.huggingface.%s is required but it could not be found in any config source", - key)); + "SRCFG00014: The config property quarkus.langchain4j.huggingface%s%s is required but it could not be found in any config source", + NamedModelUtil.isDefault(modelName) ? "." : ("." 
+ modelName + "."), key)); } } diff --git a/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/EmbeddingModelConfig.java b/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/EmbeddingModelConfig.java index 504f85763..341a4a7dc 100644 --- a/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/EmbeddingModelConfig.java +++ b/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/EmbeddingModelConfig.java @@ -1,7 +1,6 @@ package io.quarkiverse.langchain4j.huggingface.runtime.config; import java.net.URL; -import java.util.Optional; import io.quarkus.runtime.annotations.ConfigGroup; import io.smallrye.config.WithDefault; @@ -23,7 +22,7 @@ public interface EmbeddingModelConfig { * When using a local hugging face model, the URL is the URL of the local model. */ @WithDefault(DEFAULT_INFERENCE_ENDPOINT_EMBEDDING) - Optional inferenceEndpointUrl(); + URL inferenceEndpointUrl(); /** * If the model is not ready, wait for it instead of receiving 503. 
It limits the number of requests required to get your diff --git a/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/Langchain4jHuggingFaceConfig.java b/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/Langchain4jHuggingFaceConfig.java index 7813b7d54..b097d4435 100644 --- a/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/Langchain4jHuggingFaceConfig.java +++ b/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/Langchain4jHuggingFaceConfig.java @@ -3,47 +3,72 @@ import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME; import java.time.Duration; +import java.util.Map; import java.util.Optional; import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigDocMapKey; +import io.quarkus.runtime.annotations.ConfigDocSection; +import io.quarkus.runtime.annotations.ConfigGroup; import io.quarkus.runtime.annotations.ConfigRoot; import io.smallrye.config.ConfigMapping; import io.smallrye.config.WithDefault; +import io.smallrye.config.WithDefaults; +import io.smallrye.config.WithParentName; @ConfigRoot(phase = RUN_TIME) @ConfigMapping(prefix = "quarkus.langchain4j.huggingface") public interface Langchain4jHuggingFaceConfig { /** - * HuggingFace API key + * Default model config. */ - Optional apiKey(); + @WithParentName + HuggingFaceConfig defaultConfig(); /** - * Timeout for HuggingFace calls + * Named model config. 
*/ - @WithDefault("10s") - Duration timeout(); + @ConfigDocSection + @ConfigDocMapKey("model-name") + @WithParentName + @WithDefaults + Map namedConfig(); - /** - * Chat model related settings - */ - ChatModelConfig chatModel(); + @ConfigGroup + interface HuggingFaceConfig { + /** + * HuggingFace API key + */ + @WithDefault("dummy") // TODO: this should be optional but Smallrye Config doesn't like it + String apiKey(); - /** - * Embedding model related settings - */ - EmbeddingModelConfig embeddingModel(); + /** + * Timeout for HuggingFace calls + */ + @WithDefault("10s") + Duration timeout(); - /** - * Whether the HuggingFace client should log requests - */ - @ConfigDocDefault("false") - Optional logRequests(); + /** + * Chat model related settings + */ + ChatModelConfig chatModel(); - /** - * Whether the HuggingFace client should log responses - */ - @ConfigDocDefault("false") - Optional logResponses(); + /** + * Embedding model related settings + */ + EmbeddingModelConfig embeddingModel(); + + /** + * Whether the HuggingFace client should log requests + */ + @ConfigDocDefault("false") + Optional logRequests(); + + /** + * Whether the HuggingFace client should log responses + */ + @ConfigDocDefault("false") + Optional logResponses(); + } } diff --git a/integration-tests/multiple-providers/pom.xml b/integration-tests/multiple-providers/pom.xml new file mode 100644 index 000000000..7a2e5bbcf --- /dev/null +++ b/integration-tests/multiple-providers/pom.xml @@ -0,0 +1,148 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-integration-tests-parent + 999-SNAPSHOT + + quarkus-langchain4j-integration-tests-multiple-providers + Quarkus LangChain4j - Integration Tests - Multiple providers + + true + + + + io.quarkus + quarkus-resteasy-reactive-jackson + + + io.quarkiverse.langchain4j + quarkus-langchain4j-openai + ${project.version} + + + io.quarkiverse.langchain4j + quarkus-langchain4j-azure-openai + ${project.version} + + + io.quarkiverse.langchain4j + 
quarkus-langchain4j-hugging-face + ${project.version} + + + io.quarkiverse.langchain4j + quarkus-langchain4j-bam + ${project.version} + + + io.quarkiverse.langchain4j + quarkus-langchain4j-ollama + ${project.version} + + + io.quarkiverse.langchain4j + quarkus-langchain4j-openshift-ai + ${project.version} + + + io.quarkiverse.langchain4j + quarkus-langchain4j-watsonx + ${project.version} + + + io.quarkus + quarkus-junit5 + test + + + io.rest-assured + rest-assured + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + io.quarkus + quarkus-devtools-testing + test + + + + + io.quarkiverse.langchain4j + quarkus-langchain4j-azure-openai-deployment + ${project.version} + pom + test + + + * + * + + + + + + + + io.quarkus + quarkus-maven-plugin + + + + build + + + + + + maven-failsafe-plugin + + + + integration-test + verify + + + + ${project.build.directory}/${project.build.finalName}-runner + org.jboss.logmanager.LogManager + ${maven.home} + + + + + + + + + + native-image + + + native + + + + + + maven-surefire-plugin + + ${native.surefire.skip} + + + + + + false + native + + + + diff --git a/integration-tests/multiple-providers/src/main/resources/application.properties b/integration-tests/multiple-providers/src/main/resources/application.properties new file mode 100644 index 000000000..2ce520e6a --- /dev/null +++ b/integration-tests/multiple-providers/src/main/resources/application.properties @@ -0,0 +1,40 @@ +quarkus.langchain4j.chat-model.provider=openai +quarkus.langchain4j.openai.api-key=test1 + +quarkus.langchain4j.c1.chat-model.provider=openai +quarkus.langchain4j.c1.embedding-model.provider=azure-openai + +quarkus.langchain4j.openai.c1.api-key=test2 + +quarkus.langchain4j.azure-openai.c1.resource-name=res +quarkus.langchain4j.azure-openai.c1.deployment-name=dep +quarkus.langchain4j.azure-openai.c1.api-key=test + +quarkus.langchain4j.c2.chat-model.provider=azure-openai +quarkus.langchain4j.c2.embedding-model.provider=azure-openai + 
+quarkus.langchain4j.azure-openai.c2.resource-name=res +quarkus.langchain4j.azure-openai.c2.deployment-name=dep +quarkus.langchain4j.azure-openai.c2.api-key=test3 + +quarkus.langchain4j.c3.chat-model.provider=huggingface +quarkus.langchain4j.huggingface.c3.api-key=test4 + +quarkus.langchain4j.c4.chat-model.provider=bam +quarkus.langchain4j.bam.c4.api-key=test5 + +quarkus.langchain4j.c5.chat-model.provider=ollama + +quarkus.langchain4j.c6.chat-model.provider=openshift-ai +quarkus.langchain4j.openshift-ai.c6.base-url=https://somecluster.somedomain.ai:443/api +quarkus.langchain4j.openshift-ai.c6.chat-model.model-id=somemodel + +quarkus.langchain4j.c7.chat-model.provider=watsonx +quarkus.langchain4j.watsonx.c7.base-url=https://somecluster.somedomain.ai:443/api +quarkus.langchain4j.watsonx.c7.api-key=test8 +quarkus.langchain4j.watsonx.c7.project-id=proj + +quarkus.langchain4j.e1.embedding-model.provider=openai +quarkus.langchain4j.openai.e1.api-key=test5 +quarkus.langchain4j.e2.embedding-model.provider=ollama + diff --git a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java new file mode 100644 index 000000000..bca18de46 --- /dev/null +++ b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleChatProvidersTest.java @@ -0,0 +1,94 @@ +package org.acme.example.multiple; + +import static org.assertj.core.api.Assertions.assertThat; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.openai.OpenAiChatModel; +import io.quarkiverse.langchain4j.ModelName; +import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiChatModel; +import io.quarkiverse.langchain4j.bam.BamChatModel; +import io.quarkiverse.langchain4j.huggingface.QuarkusHuggingFaceChatModel; +import 
io.quarkiverse.langchain4j.ollama.OllamaChatLanguageModel; +import io.quarkiverse.langchain4j.openshiftai.OpenshiftAiChatModel; +import io.quarkiverse.langchain4j.watsonx.WatsonChatModel; +import io.quarkus.arc.ClientProxy; +import io.quarkus.test.junit.QuarkusTest; + +@QuarkusTest +public class MultipleChatProvidersTest { + + @Inject + ChatLanguageModel defaultModel; + + @Inject + @ModelName("c1") + ChatLanguageModel firstNamedModel; + + @Inject + @ModelName("c2") + ChatLanguageModel secondNamedModel; + + @Inject + @ModelName("c3") + ChatLanguageModel thirdNamedModel; + + @Inject + @ModelName("c4") + ChatLanguageModel fourthNamedModel; + + @Inject + @ModelName("c5") + ChatLanguageModel fifthNamedModel; + + @Inject + @ModelName("c6") + ChatLanguageModel sixthNamedModel; + + @Inject + @ModelName("c7") + ChatLanguageModel seventhNamedModel; + + @Test + void defaultModel() { + assertThat(ClientProxy.unwrap(defaultModel)).isInstanceOf(OpenAiChatModel.class); + } + + @Test + void firstNamedModel() { + assertThat(ClientProxy.unwrap(firstNamedModel)).isInstanceOf(OpenAiChatModel.class); + } + + @Test + void secondNamedModel() { + assertThat(ClientProxy.unwrap(secondNamedModel)).isInstanceOf(AzureOpenAiChatModel.class); + } + + @Test + void thirdNamedModel() { + assertThat(ClientProxy.unwrap(thirdNamedModel)).isInstanceOf(QuarkusHuggingFaceChatModel.class); + } + + @Test + void fourthNamedModel() { + assertThat(ClientProxy.unwrap(fourthNamedModel)).isInstanceOf(BamChatModel.class); + } + + @Test + void fifthNamedModel() { + assertThat(ClientProxy.unwrap(fifthNamedModel)).isInstanceOf(OllamaChatLanguageModel.class); + } + + @Test + void sixthNamedModel() { + assertThat(ClientProxy.unwrap(sixthNamedModel)).isInstanceOf(OpenshiftAiChatModel.class); + } + + @Test + void seventhNamedModel() { + assertThat(ClientProxy.unwrap(seventhNamedModel)).isInstanceOf(WatsonChatModel.class); + } +} diff --git 
a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleEmbeddingModelsTest.java b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleEmbeddingModelsTest.java new file mode 100644 index 000000000..a189589b9 --- /dev/null +++ b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleEmbeddingModelsTest.java @@ -0,0 +1,55 @@ +package org.acme.example.multiple; + +import static org.assertj.core.api.Assertions.assertThat; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; + +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.openai.OpenAiEmbeddingModel; +import io.quarkiverse.langchain4j.ModelName; +import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiEmbeddingModel; +import io.quarkiverse.langchain4j.ollama.OllamaEmbeddingModel; +import io.quarkus.arc.ClientProxy; +import io.quarkus.test.junit.QuarkusTest; + +@QuarkusTest +public class MultipleEmbeddingModelsTest { + + @Inject + @ModelName("e1") + EmbeddingModel firstNamedModel; + + @Inject + @ModelName("e2") + EmbeddingModel secondNamedModel; + + @Inject + @ModelName("c1") + EmbeddingModel thirdNamedModel; + + @Inject + @ModelName("c2") + EmbeddingModel fourthNamedModel; + + @Test + void firstNamedModel() { + assertThat(ClientProxy.unwrap(firstNamedModel)).isInstanceOf(OpenAiEmbeddingModel.class); + } + + @Test + void secondNamedModel() { + assertThat(ClientProxy.unwrap(secondNamedModel)).isInstanceOf(OllamaEmbeddingModel.class); + } + + @Test + void thirdNamedModel() { + assertThat(ClientProxy.unwrap(thirdNamedModel)).isInstanceOf(AzureOpenAiEmbeddingModel.class); + } + + @Test + void fourthNamedModel() { + assertThat(ClientProxy.unwrap(fourthNamedModel)).isInstanceOf(AzureOpenAiEmbeddingModel.class); + } +} diff --git a/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleModerationProvidersTest.java 
b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleModerationProvidersTest.java new file mode 100644 index 000000000..67e82e76f --- /dev/null +++ b/integration-tests/multiple-providers/src/test/java/org/acme/example/multiple/MultipleModerationProvidersTest.java @@ -0,0 +1,24 @@ +package org.acme.example.multiple; + +import static org.assertj.core.api.Assertions.assertThat; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; + +import dev.langchain4j.model.moderation.ModerationModel; +import dev.langchain4j.model.openai.OpenAiModerationModel; +import io.quarkus.arc.ClientProxy; +import io.quarkus.test.junit.QuarkusTest; + +@QuarkusTest +public class MultipleModerationProvidersTest { + + @Inject + ModerationModel defaultModel; + + @Test + void defaultModel() { + assertThat(ClientProxy.unwrap(defaultModel)).isInstanceOf(OpenAiModerationModel.class); + } +} diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/QuarkusRestApiResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/QuarkusRestApiResource.java index eb9aec48a..051be47f6 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/QuarkusRestApiResource.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/QuarkusRestApiResource.java @@ -52,12 +52,12 @@ public class QuarkusRestApiResource { public QuarkusRestApiResource(Langchain4jOpenAiConfig runtimeConfig) throws URISyntaxException { + Langchain4jOpenAiConfig.OpenAiConfig openAiConfig = runtimeConfig.defaultConfig(); this.restApi = QuarkusRestClientBuilder.newBuilder() - .baseUri(new URI(runtimeConfig.baseUrl())) + .baseUri(new URI(openAiConfig.baseUrl())) .build(OpenAiRestApi.class); - this.token = runtimeConfig.apiKey() - .orElseThrow(() -> new IllegalArgumentException("quarkus.langchain4j.openai.api-key must be provided")); - this.organizationId = runtimeConfig.organizationId().orElse(null); + this.token = 
openAiConfig.apiKey(); + this.organizationId = openAiConfig.organizationId().orElse(null); } @GET diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/chat/QuarkusOpenAiClientChatResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/chat/QuarkusOpenAiClientChatResource.java index 764c09040..0fe987b90 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/chat/QuarkusOpenAiClientChatResource.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/chat/QuarkusOpenAiClientChatResource.java @@ -21,9 +21,8 @@ public class QuarkusOpenAiClientChatResource { private final QuarkusOpenAiClient quarkusOpenAiClient; public QuarkusOpenAiClientChatResource(Langchain4jOpenAiConfig runtimeConfig) { - String token = runtimeConfig.apiKey() - .orElseThrow(() -> new IllegalArgumentException("quarkus.langchain4j.openai.api-key must be provided")); - String baseUrl = runtimeConfig.baseUrl(); + String token = runtimeConfig.defaultConfig().apiKey(); + String baseUrl = runtimeConfig.defaultConfig().baseUrl(); quarkusOpenAiClient = QuarkusOpenAiClient.builder().openAiApiKey(token).baseUrl(baseUrl).build(); } diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index e2754abff..f4bfe311d 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -17,6 +17,7 @@ ollama simple-ollama azure-openai + multiple-providers devui embed-all-minilm-l6-v2-q embed-all-minilm-l6-v2 diff --git a/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaProcessor.java b/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaProcessor.java index b9b7eecb6..ff465eb0e 100644 --- a/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaProcessor.java +++ b/ollama/deployment/src/main/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaProcessor.java @@ -3,16 +3,20 @@ import static 
io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.CHAT_MODEL; import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.EMBEDDING_MODEL; -import java.util.Optional; +import java.util.List; import jakarta.enterprise.context.ApplicationScoped; +import org.jboss.jandex.AnnotationInstance; + +import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem; import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem; import io.quarkiverse.langchain4j.ollama.runtime.OllamaRecorder; import io.quarkiverse.langchain4j.ollama.runtime.config.Langchain4jOllamaConfig; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.deployment.annotations.BuildProducer; import io.quarkus.deployment.annotations.BuildStep; @@ -46,28 +50,43 @@ public void providerCandidates(BuildProducer selectedChatItem, - Optional selectedEmbedding, + List selectedChatItem, + List selectedEmbedding, Langchain4jOllamaConfig config, BuildProducer beanProducer) { - if (selectedChatItem.isPresent() && PROVIDER.equals(selectedChatItem.get().getProvider())) { - beanProducer.produce(SyntheticBeanBuildItem - .configure(CHAT_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(recorder.chatModel(config)) - .done()); + + for (var selected : selectedChatItem) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.chatModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + 
beanProducer.produce(builder.done()); + } } - if (selectedEmbedding.isPresent() && PROVIDER.equals(selectedEmbedding.get().getProvider())) { - beanProducer.produce(SyntheticBeanBuildItem - .configure(EMBEDDING_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(recorder.embeddingModel(config)) - .done()); + for (var selected : selectedEmbedding) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(EMBEDDING_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.embeddingModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + } + } + } + + private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) { + if (!NamedModelUtil.isDefault(modelName)) { + builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", modelName).build()); } } } diff --git a/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/OllamaRecorder.java b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/OllamaRecorder.java index 54e76f950..6dc4900ce 100644 --- a/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/OllamaRecorder.java +++ b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/OllamaRecorder.java @@ -6,14 +6,17 @@ import io.quarkiverse.langchain4j.ollama.OllamaEmbeddingModel; import io.quarkiverse.langchain4j.ollama.Options; import io.quarkiverse.langchain4j.ollama.runtime.config.ChatModelConfig; +import io.quarkiverse.langchain4j.ollama.runtime.config.EmbeddingModelConfig; import io.quarkiverse.langchain4j.ollama.runtime.config.Langchain4jOllamaConfig; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.runtime.annotations.Recorder; @Recorder public class OllamaRecorder { - public 
Supplier chatModel(Langchain4jOllamaConfig runtimeConfig) { - ChatModelConfig chatModelConfig = runtimeConfig.chatModel(); + public Supplier chatModel(Langchain4jOllamaConfig runtimeConfig, String modelName) { + Langchain4jOllamaConfig.OllamaConfig ollamaConfig = correspondingOllamaConfig(runtimeConfig, modelName); + ChatModelConfig chatModelConfig = ollamaConfig.chatModel(); Options.Builder optionsBuilder = Options.builder() .temperature(chatModelConfig.temperature()) .topK(chatModelConfig.topK()) @@ -23,11 +26,11 @@ public Supplier chatModel(Langchain4jOllamaConfig runtimeConfig) { optionsBuilder.stop(chatModelConfig.stop().get()); } var builder = OllamaChatLanguageModel.builder() - .baseUrl(runtimeConfig.baseUrl()) - .timeout(runtimeConfig.timeout()) + .baseUrl(ollamaConfig.baseUrl()) + .timeout(ollamaConfig.timeout()) + .logRequests(ollamaConfig.logRequests()) + .logResponses(ollamaConfig.logResponses()) .model(chatModelConfig.modelId()) - .logRequests(runtimeConfig.logRequests()) - .logResponses(runtimeConfig.logResponses()) .options(optionsBuilder.build()); return new Supplier<>() { @@ -38,20 +41,21 @@ public Object get() { }; } - public Supplier embeddingModel(Langchain4jOllamaConfig runtimeConfig) { - ChatModelConfig chatModelConfig = runtimeConfig.chatModel(); + public Supplier embeddingModel(Langchain4jOllamaConfig runtimeConfig, String modelName) { + Langchain4jOllamaConfig.OllamaConfig ollamaConfig = correspondingOllamaConfig(runtimeConfig, modelName); + EmbeddingModelConfig embeddingModelConfig = ollamaConfig.embeddingModel(); Options.Builder optionsBuilder = Options.builder() - .temperature(chatModelConfig.temperature()) - .topK(chatModelConfig.topK()) - .topP(chatModelConfig.topP()) - .numPredict(chatModelConfig.numPredict()); - if (chatModelConfig.stop().isPresent()) { - optionsBuilder.stop(chatModelConfig.stop().get()); + .temperature(embeddingModelConfig.temperature()) + .topK(embeddingModelConfig.topK()) + .topP(embeddingModelConfig.topP()) + 
.numPredict(embeddingModelConfig.numPredict()); + if (embeddingModelConfig.stop().isPresent()) { + optionsBuilder.stop(embeddingModelConfig.stop().get()); } var builder = OllamaEmbeddingModel.builder() - .baseUrl(runtimeConfig.baseUrl()) - .timeout(runtimeConfig.timeout()) - .model(chatModelConfig.modelId()); + .baseUrl(ollamaConfig.baseUrl()) + .timeout(ollamaConfig.timeout()) + .model(embeddingModelConfig.modelId()); return new Supplier<>() { @Override @@ -60,4 +64,15 @@ public Object get() { } }; } + + private Langchain4jOllamaConfig.OllamaConfig correspondingOllamaConfig(Langchain4jOllamaConfig runtimeConfig, + String modelName) { + Langchain4jOllamaConfig.OllamaConfig ollamaConfig; + if (NamedModelUtil.isDefault(modelName)) { + ollamaConfig = runtimeConfig.defaultConfig(); + } else { + ollamaConfig = runtimeConfig.namedConfig().get(modelName); + } + return ollamaConfig; + } } diff --git a/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/EmbeddingModelConfig.java b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/EmbeddingModelConfig.java new file mode 100644 index 000000000..cd3bff075 --- /dev/null +++ b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/EmbeddingModelConfig.java @@ -0,0 +1,52 @@ +package io.quarkiverse.langchain4j.ollama.runtime.config; + +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigGroup; +import io.smallrye.config.WithDefault; + +@ConfigGroup +public interface EmbeddingModelConfig { + + // TODO: we need to check if these are the correct settings... + + /** + * Model to use. According to Ollama + * docs, + * the default value is {@code latest} + */ + @WithDefault("latest") + String modelId(); + + /** + * The temperature of the model. Increasing the temperature will make the model answer with + * more variability. A lower temperature will make the model answer more conservatively. 
+ */ + @WithDefault("0.8") + Double temperature(); + + /** + * Maximum number of tokens to predict when generating text + */ + @WithDefault("128") + Integer numPredict(); + + /** + * Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return + */ + Optional stop(); + + /** + * Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) + * will generate more focused and conservative text + */ + @WithDefault("0.9") + Double topP(); + + /** + * Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower + * value (e.g. 10) will be more conservative + */ + @WithDefault("40") + Integer topK(); +} diff --git a/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/Langchain4jOllamaConfig.java b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/Langchain4jOllamaConfig.java index b61dcf6e8..652674320 100644 --- a/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/Langchain4jOllamaConfig.java +++ b/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/Langchain4jOllamaConfig.java @@ -3,41 +3,68 @@ import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME; import java.time.Duration; +import java.util.Map; +import io.quarkus.runtime.annotations.ConfigDocMapKey; +import io.quarkus.runtime.annotations.ConfigDocSection; import io.quarkus.runtime.annotations.ConfigRoot; import io.smallrye.config.ConfigMapping; import io.smallrye.config.WithDefault; +import io.smallrye.config.WithDefaults; +import io.smallrye.config.WithParentName; @ConfigRoot(phase = RUN_TIME) @ConfigMapping(prefix = "quarkus.langchain4j.ollama") public interface Langchain4jOllamaConfig { /** - * Base URL where the Ollama serving is running + * Default model config. 
*/ - @WithDefault("http://localhost:11434") - String baseUrl(); + @WithParentName + OllamaConfig defaultConfig(); /** - * Timeout for Ollama calls + * Named model config. */ - @WithDefault("10s") - Duration timeout(); + @ConfigDocSection + @ConfigDocMapKey("model-name") + @WithParentName + @WithDefaults + Map namedConfig(); - /** - * Whether the Ollama client should log requests - */ - @WithDefault("false") - Boolean logRequests(); + interface OllamaConfig { + /** + * Base URL where the Ollama serving is running + */ + @WithDefault("http://localhost:11434") + String baseUrl(); - /** - * Whether the Ollama client should log responses - */ - @WithDefault("false") - Boolean logResponses(); + /** + * Timeout for Ollama calls + */ + @WithDefault("10s") + Duration timeout(); - /** - * Chat model related settings - */ - ChatModelConfig chatModel(); + /** + * Whether the Ollama client should log requests + */ + @WithDefault("false") + Boolean logRequests(); + + /** + * Whether the Ollama client should log responses + */ + @WithDefault("false") + Boolean logResponses(); + + /** + * Chat model related settings + */ + ChatModelConfig chatModel(); + + /** + * Embedding model related settings + */ + EmbeddingModelConfig embeddingModel(); + } } diff --git a/openai/azure-openai/deployment/src/main/java/io/quarkiverse/langchain4j/azure/openai/deployment/AzureOpenAiProcessor.java b/openai/azure-openai/deployment/src/main/java/io/quarkiverse/langchain4j/azure/openai/deployment/AzureOpenAiProcessor.java index b426a1bcd..3e022e9ad 100644 --- a/openai/azure-openai/deployment/src/main/java/io/quarkiverse/langchain4j/azure/openai/deployment/AzureOpenAiProcessor.java +++ b/openai/azure-openai/deployment/src/main/java/io/quarkiverse/langchain4j/azure/openai/deployment/AzureOpenAiProcessor.java @@ -4,10 +4,13 @@ import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.EMBEDDING_MODEL; import static 
io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.STREAMING_CHAT_MODEL; -import java.util.Optional; +import java.util.List; import jakarta.enterprise.context.ApplicationScoped; +import org.jboss.jandex.AnnotationInstance; + +import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.azure.openai.runtime.AzureOpenAiRecorder; import io.quarkiverse.langchain4j.azure.openai.runtime.config.Langchain4jAzureOpenAiConfig; import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem; @@ -15,7 +18,7 @@ import io.quarkiverse.langchain4j.deployment.items.ModerationModelProviderCandidateBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem; -import io.quarkiverse.langchain4j.deployment.items.SelectedModerationModelProviderBuildItem; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.deployment.annotations.BuildProducer; import io.quarkus.deployment.annotations.BuildStep; @@ -51,37 +54,51 @@ public void providerCandidates(BuildProducer selectedChatItem, - Optional selectedEmbedding, - Optional selectedModeration, + List selectedChatItem, + List selectedEmbedding, Langchain4jAzureOpenAiConfig config, BuildProducer beanProducer) { - if (selectedChatItem.isPresent() && PROVIDER.equals(selectedChatItem.get().getProvider())) { - beanProducer.produce(SyntheticBeanBuildItem - .configure(CHAT_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(recorder.chatModel(config)) - .done()); + for (var selected : selectedChatItem) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + 
.supplier(recorder.chatModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + + var streamingBuilder = SyntheticBeanBuildItem + .configure(STREAMING_CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.streamingChatModel(config, modelName)); + addQualifierIfNecessary(streamingBuilder, modelName); + beanProducer.produce(streamingBuilder.done()); + } + } - beanProducer.produce(SyntheticBeanBuildItem - .configure(STREAMING_CHAT_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(recorder.streamingChatModel(config)) - .done()); + for (var selected : selectedEmbedding) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(EMBEDDING_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.embeddingModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + } } + } - if (selectedEmbedding.isPresent() && PROVIDER.equals(selectedEmbedding.get().getProvider())) { - beanProducer.produce(SyntheticBeanBuildItem - .configure(EMBEDDING_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(recorder.embeddingModel(config)) - .done()); + private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) { + if (!NamedModelUtil.isDefault(modelName)) { + builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", modelName).build()); } } diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java index 58ebd2aed..f11606d35 100644 --- 
a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java @@ -3,6 +3,7 @@ import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault; import java.util.ArrayList; +import java.util.List; import java.util.function.Supplier; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -15,6 +16,7 @@ import io.quarkiverse.langchain4j.azure.openai.runtime.config.EmbeddingModelConfig; import io.quarkiverse.langchain4j.azure.openai.runtime.config.Langchain4jAzureOpenAiConfig; import io.quarkiverse.langchain4j.openai.QuarkusOpenAiClient; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.runtime.ShutdownContext; import io.quarkus.runtime.annotations.Recorder; import io.smallrye.config.ConfigValidationException; @@ -22,18 +24,26 @@ @Recorder public class AzureOpenAiRecorder { - static final String AZURE_ENDPOINT_URL_PATTERN = "https://%s.openai.azure.com/openai/deployments/%s"; - public Supplier chatModel(Langchain4jAzureOpenAiConfig runtimeConfig) { - ChatModelConfig chatModelConfig = runtimeConfig.chatModel(); + private static final String DUMMY_KEY = "dummy"; + static final String AZURE_ENDPOINT_URL_PATTERN = "https://%s.openai.azure.com/openai/deployments/%s"; + public static final Problem[] EMPTY_PROBLEMS = new Problem[0]; + + public Supplier chatModel(Langchain4jAzureOpenAiConfig runtimeConfig, String modelName) { + Langchain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, modelName); + ChatModelConfig chatModelConfig = azureAiConfig.chatModel(); + String apiKey = azureAiConfig.apiKey(); + if (DUMMY_KEY.equals(apiKey)) { + throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); + } var builder = AzureOpenAiChatModel.builder() - .endpoint(getEndpoint(runtimeConfig)) - 
.apiKey(runtimeConfig.apiKey()) - .apiVersion(runtimeConfig.apiVersion()) - .timeout(runtimeConfig.timeout()) - .maxRetries(runtimeConfig.maxRetries()) - .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests())) - .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses())) + .endpoint(getEndpoint(azureAiConfig, modelName)) + .apiKey(apiKey) + .apiVersion(azureAiConfig.apiVersion()) + .timeout(azureAiConfig.timeout()) + .maxRetries(azureAiConfig.maxRetries()) + .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), azureAiConfig.logRequests())) + .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), azureAiConfig.logResponses())) .temperature(chatModelConfig.temperature()) .topP(chatModelConfig.topP()) @@ -52,15 +62,21 @@ public ChatLanguageModel get() { }; } - public Supplier streamingChatModel(Langchain4jAzureOpenAiConfig runtimeConfig) { - ChatModelConfig chatModelConfig = runtimeConfig.chatModel(); + public Supplier streamingChatModel(Langchain4jAzureOpenAiConfig runtimeConfig, + String modelName) { + Langchain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, modelName); + ChatModelConfig chatModelConfig = azureAiConfig.chatModel(); + String apiKey = azureAiConfig.apiKey(); + if (DUMMY_KEY.equals(apiKey)) { + throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); + } var builder = AzureOpenAiStreamingChatModel.builder() - .endpoint(getEndpoint(runtimeConfig)) - .apiKey(runtimeConfig.apiKey()) - .apiVersion(runtimeConfig.apiVersion()) - .timeout(runtimeConfig.timeout()) - .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests())) - .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses())) + .endpoint(getEndpoint(azureAiConfig, modelName)) + .apiKey(apiKey) + .apiVersion(azureAiConfig.apiVersion()) + 
.timeout(azureAiConfig.timeout()) + .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), azureAiConfig.logRequests())) + .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), azureAiConfig.logResponses())) .temperature(chatModelConfig.temperature()) .topP(chatModelConfig.topP()) @@ -79,16 +95,21 @@ public StreamingChatLanguageModel get() { }; } - public Supplier embeddingModel(Langchain4jAzureOpenAiConfig runtimeConfig) { - EmbeddingModelConfig embeddingModelConfig = runtimeConfig.embeddingModel(); + public Supplier embeddingModel(Langchain4jAzureOpenAiConfig runtimeConfig, String modelName) { + Langchain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, modelName); + EmbeddingModelConfig embeddingModelConfig = azureAiConfig.embeddingModel(); + String apiKey = azureAiConfig.apiKey(); + if (DUMMY_KEY.equals(apiKey)) { + throw new ConfigValidationException(createApiKeyConfigProblem(modelName)); + } var builder = AzureOpenAiEmbeddingModel.builder() - .endpoint(getEndpoint(runtimeConfig)) - .apiKey(runtimeConfig.apiKey()) - .apiVersion(runtimeConfig.apiVersion()) - .timeout(runtimeConfig.timeout()) - .maxRetries(runtimeConfig.maxRetries()) - .logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), runtimeConfig.logRequests())) - .logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), runtimeConfig.logResponses())); + .endpoint(getEndpoint(azureAiConfig, modelName)) + .apiKey(apiKey) + .apiVersion(azureAiConfig.apiVersion()) + .timeout(azureAiConfig.timeout()) + .maxRetries(azureAiConfig.maxRetries()) + .logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), azureAiConfig.logRequests())) + .logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), azureAiConfig.logResponses())); return new Supplier<>() { @Override @@ -98,38 +119,59 @@ public EmbeddingModel get() { }; } - static String getEndpoint(Langchain4jAzureOpenAiConfig runtimeConfig) 
{ - var endpoint = runtimeConfig.endpoint(); + static String getEndpoint(Langchain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig, String modelName) { + var endpoint = azureAiConfig.endpoint(); return (endpoint.isPresent() && !endpoint.get().trim().isBlank()) ? endpoint.get() - : constructEndpointFromConfig(runtimeConfig); + : constructEndpointFromConfig(azureAiConfig, modelName); } - private static String constructEndpointFromConfig(Langchain4jAzureOpenAiConfig runtimeConfig) { - var resourceName = runtimeConfig.resourceName(); - var deploymentName = runtimeConfig.deploymentName(); + private static String constructEndpointFromConfig(Langchain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig, + String modelName) { + var resourceName = azureAiConfig.resourceName(); + var deploymentName = azureAiConfig.deploymentName(); if (resourceName.isEmpty() || deploymentName.isEmpty()) { - var configProblems = new ArrayList<>(); + List configProblems = new ArrayList<>(); if (resourceName.isEmpty()) { - configProblems.add(createConfigProblem("resource-name")); + configProblems.add(createConfigProblem("resource-name", modelName)); } if (deploymentName.isEmpty()) { - configProblems.add(createConfigProblem("deployment-name")); + configProblems.add(createConfigProblem("deployment-name", modelName)); } - throw new ConfigValidationException(configProblems.toArray(new Problem[configProblems.size()])); + throw new ConfigValidationException(configProblems.toArray(EMPTY_PROBLEMS)); } return String.format(AZURE_ENDPOINT_URL_PATTERN, resourceName.get(), deploymentName.get()); } - private static ConfigValidationException.Problem createConfigProblem(String key) { + private Langchain4jAzureOpenAiConfig.AzureAiConfig correspondingAzureOpenAiConfig( + Langchain4jAzureOpenAiConfig runtimeConfig, + String modelName) { + Langchain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig; + if (NamedModelUtil.isDefault(modelName)) { + azureAiConfig = runtimeConfig.defaultConfig(); + } else { + azureAiConfig = 
runtimeConfig.namedConfig().get(modelName); + } + return azureAiConfig; + } + + private ConfigValidationException.Problem[] createApiKeyConfigProblem(String modelName) { + return createConfigProblems("api-key", modelName); + } + + private ConfigValidationException.Problem[] createConfigProblems(String key, String modelName) { + return new ConfigValidationException.Problem[] { createConfigProblem(key, modelName) }; + } + + private static ConfigValidationException.Problem createConfigProblem(String key, String modelName) { return new ConfigValidationException.Problem(String.format( - "SRCFG00014: The config property quarkus.langchain4j.azure-openai.%s is required but it could not be found in any config source", - key)); + "SRCFG00014: The config property quarkus.langchain4j.azure-openai%s%s is required but it could not be found in any config source", + NamedModelUtil.isDefault(modelName) ? "." : ("." + modelName + "."), key)); } public void cleanUp(ShutdownContext shutdown) { diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/Langchain4jAzureOpenAiConfig.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/Langchain4jAzureOpenAiConfig.java index b041fbecb..97ed086a3 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/Langchain4jAzureOpenAiConfig.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/Langchain4jAzureOpenAiConfig.java @@ -3,90 +3,115 @@ import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME; import java.time.Duration; +import java.util.Map; import java.util.Optional; import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigDocMapKey; +import io.quarkus.runtime.annotations.ConfigDocSection; +import io.quarkus.runtime.annotations.ConfigGroup; import 
io.quarkus.runtime.annotations.ConfigRoot; import io.smallrye.config.ConfigMapping; import io.smallrye.config.WithDefault; +import io.smallrye.config.WithDefaults; +import io.smallrye.config.WithParentName; @ConfigRoot(phase = RUN_TIME) @ConfigMapping(prefix = "quarkus.langchain4j.azure-openai") public interface Langchain4jAzureOpenAiConfig { /** - * The name of your Azure OpenAI Resource. You're required to first deploy a model before you can make calls. - *

- * This and {@code quarkus.langchain4j.azure-openai.deployment-name} are required if - * {@code quarkus.langchain4j.azure-openai.endpoint} is not set. - * If {@code quarkus.langchain4j.azure-openai.endpoint} is not set then this is never read. - *

+ * Default model config. */ - Optional resourceName(); + @WithParentName + AzureAiConfig defaultConfig(); /** - * The name of your model deployment. You're required to first deploy a model before you can make calls. - *

- * This and {@code quarkus.langchain4j.azure-openai.resource-name} are required if - * {@code quarkus.langchain4j.azure-openai.endpoint} is not set. - * If {@code quarkus.langchain4j.azure-openai.endpoint} is not set then this is never read. - *

+ * Named model config. */ - Optional deploymentName(); + @ConfigDocSection + @ConfigDocMapKey("model-name") + @WithParentName + @WithDefaults + Map namedConfig(); - /** - * The endpoint for the Azure OpenAI resource. - *

- * If not specified, then {@code quarkus.langchain4j.azure-openai.resource-name} and - * {@code quarkus.langchain4j.azure-openai.deployment-name} are required. - * In this case the endpoint will be set to - * {@code https://${quarkus.langchain4j.azure-openai.resource-name}.openai.azure.com/openai/deployments/${quarkus.langchain4j.azure-openai.deployment-name}} - *

- */ - Optional endpoint(); + @ConfigGroup + interface AzureAiConfig { + /** + * The name of your Azure OpenAI Resource. You're required to first deploy a model before you can make calls. + *

+ * This and {@code quarkus.langchain4j.azure-openai.deployment-name} are required if + * {@code quarkus.langchain4j.azure-openai.endpoint} is not set. + * If {@code quarkus.langchain4j.azure-openai.endpoint} is not set then this is never read. + *

+ */ + Optional resourceName(); - /** - * The API version to use for this operation. This follows the YYYY-MM-DD format - */ - @WithDefault("2023-05-15") - String apiVersion(); + /** + * The name of your model deployment. You're required to first deploy a model before you can make calls. + *

+ * This and {@code quarkus.langchain4j.azure-openai.resource-name} are required if + * {@code quarkus.langchain4j.azure-openai.endpoint} is not set. + * If {@code quarkus.langchain4j.azure-openai.endpoint} is not set then this is never read. + *

+ */ + Optional deploymentName(); - /** - * Azure OpenAI API key - */ - String apiKey(); + /** + * The endpoint for the Azure OpenAI resource. + *

+ * If not specified, then {@code quarkus.langchain4j.azure-openai.resource-name} and + * {@code quarkus.langchain4j.azure-openai.deployment-name} are required. + * In this case the endpoint will be set to + * {@code https://${quarkus.langchain4j.azure-openai.resource-name}.openai.azure.com/openai/deployments/${quarkus.langchain4j.azure-openai.deployment-name}} + *

+ */ + Optional endpoint(); - /** - * Timeout for OpenAI calls - */ - @WithDefault("10s") - Duration timeout(); + /** + * The API version to use for this operation. This follows the YYYY-MM-DD format + */ + @WithDefault("2023-05-15") + String apiVersion(); - /** - * The maximum number of times to retry - */ - @WithDefault("3") - Integer maxRetries(); + /** + * Azure OpenAI API key + */ + @WithDefault("dummy") // TODO: this should be optional but Smallrye Config doesn't like it.. + String apiKey(); - /** - * Whether the OpenAI client should log requests - */ - @ConfigDocDefault("false") - Optional logRequests(); + /** + * Timeout for OpenAI calls + */ + @WithDefault("10s") + Duration timeout(); - /** - * Whether the OpenAI client should log responses - */ - @ConfigDocDefault("false") - Optional logResponses(); + /** + * The maximum number of times to retry + */ + @WithDefault("3") + Integer maxRetries(); - /** - * Chat model related settings - */ - ChatModelConfig chatModel(); + /** + * Whether the OpenAI client should log requests + */ + @ConfigDocDefault("false") + Optional logRequests(); - /** - * Embedding model related settings - */ - EmbeddingModelConfig embeddingModel(); + /** + * Whether the OpenAI client should log responses + */ + @ConfigDocDefault("false") + Optional logResponses(); + + /** + * Chat model related settings + */ + ChatModelConfig chatModel(); + + /** + * Embedding model related settings + */ + EmbeddingModelConfig embeddingModel(); + } } diff --git a/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java b/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java index 68f422510..eb9dc4a26 100644 --- a/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java +++ 
b/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java @@ -4,6 +4,7 @@ import static org.mockito.Mockito.*; import java.time.Duration; +import java.util.Map; import java.util.Optional; import org.junit.jupiter.api.Test; @@ -11,17 +12,19 @@ import io.quarkiverse.langchain4j.azure.openai.runtime.config.ChatModelConfig; import io.quarkiverse.langchain4j.azure.openai.runtime.config.EmbeddingModelConfig; import io.quarkiverse.langchain4j.azure.openai.runtime.config.Langchain4jAzureOpenAiConfig; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.smallrye.config.ConfigValidationException; import io.smallrye.config.ConfigValidationException.Problem; class AzureOpenAiRecorderEndpointTests { private static final String CONFIG_ERROR_MESSAGE_TEMPLATE = "SRCFG00014: The config property quarkus.langchain4j.azure-openai.%s is required but it could not be found in any config source"; - Langchain4jAzureOpenAiConfig config = spy(Config.class); + Langchain4jAzureOpenAiConfig.AzureAiConfig config = spy(CustomAzureAiConfig.class); @Test void noEndpointConfigSet() { - var configValidationException = catchThrowableOfType(() -> AzureOpenAiRecorder.getEndpoint(this.config), + var configValidationException = catchThrowableOfType(() -> AzureOpenAiRecorder.getEndpoint(this.config, + NamedModelUtil.DEFAULT_NAME), ConfigValidationException.class); assertThat(configValidationException.getProblemCount()) @@ -44,7 +47,8 @@ void onlyResourceNameSet() { .when(this.config) .resourceName(); - var configValidationException = catchThrowableOfType(() -> AzureOpenAiRecorder.getEndpoint(this.config), + var configValidationException = catchThrowableOfType(() -> AzureOpenAiRecorder.getEndpoint(this.config, + NamedModelUtil.DEFAULT_NAME), ConfigValidationException.class); assertThat(configValidationException.getProblemCount()) @@ -62,7 +66,8 @@ void onlyDeploymentNameSet() { .when(this.config) .deploymentName(); - var 
configValidationException = catchThrowableOfType(() -> AzureOpenAiRecorder.getEndpoint(this.config), + var configValidationException = catchThrowableOfType(() -> AzureOpenAiRecorder.getEndpoint(this.config, + NamedModelUtil.DEFAULT_NAME), ConfigValidationException.class); assertThat(configValidationException.getProblemCount()) @@ -80,7 +85,7 @@ void endpointSet() { .when(this.config) .endpoint(); - assertThat(AzureOpenAiRecorder.getEndpoint(this.config)) + assertThat(AzureOpenAiRecorder.getEndpoint(this.config, NamedModelUtil.DEFAULT_NAME)) .isNotNull() .isEqualTo("https://somewhere.com"); } @@ -95,12 +100,31 @@ void resourceNameAndDeploymentNameSet() { .when(this.config) .deploymentName(); - assertThat(AzureOpenAiRecorder.getEndpoint(this.config)) + assertThat(AzureOpenAiRecorder.getEndpoint(this.config, NamedModelUtil.DEFAULT_NAME)) .isNotNull() .isEqualTo(String.format(AzureOpenAiRecorder.AZURE_ENDPOINT_URL_PATTERN, "resourceName", "deploymentName")); } - static class Config implements Langchain4jAzureOpenAiConfig { + static class CustomLangchain4JAzureOpenAiConfig implements Langchain4jAzureOpenAiConfig { + + private final AzureAiConfig azureAiConfig; + + CustomLangchain4JAzureOpenAiConfig(AzureAiConfig azureAiConfig) { + this.azureAiConfig = azureAiConfig; + } + + @Override + public AzureAiConfig defaultConfig() { + return azureAiConfig; + } + + @Override + public Map namedConfig() { + throw new IllegalStateException("should not be called"); + } + } + + static class CustomAzureAiConfig implements Langchain4jAzureOpenAiConfig.AzureAiConfig { @Override public Optional resourceName() { return Optional.empty(); @@ -201,4 +225,4 @@ public Optional logResponses() { }; } } -} \ No newline at end of file +} diff --git a/openai/openai-vanilla/deployment/src/main/java/io/quarkiverse/langchain4j/openai/deployment/OpenAiProcessor.java b/openai/openai-vanilla/deployment/src/main/java/io/quarkiverse/langchain4j/openai/deployment/OpenAiProcessor.java index 
f1bcbe4d0..20efeffe9 100644 --- a/openai/openai-vanilla/deployment/src/main/java/io/quarkiverse/langchain4j/openai/deployment/OpenAiProcessor.java +++ b/openai/openai-vanilla/deployment/src/main/java/io/quarkiverse/langchain4j/openai/deployment/OpenAiProcessor.java @@ -6,10 +6,13 @@ import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.MODERATION_MODEL; import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.STREAMING_CHAT_MODEL; -import java.util.Optional; +import java.util.List; import jakarta.enterprise.context.ApplicationScoped; +import org.jboss.jandex.AnnotationInstance; + +import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.deployment.EmbeddingModelBuildItem; import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem; import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem; @@ -21,6 +24,7 @@ import io.quarkiverse.langchain4j.deployment.items.SelectedModerationModelProviderBuildItem; import io.quarkiverse.langchain4j.openai.runtime.OpenAiRecorder; import io.quarkiverse.langchain4j.openai.runtime.config.Langchain4jOpenAiConfig; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.deployment.annotations.BuildProducer; import io.quarkus.deployment.annotations.BuildStep; @@ -65,58 +69,82 @@ public void providerCandidates(BuildProducer embeddingM @BuildStep @Record(ExecutionTime.RUNTIME_INIT) void generateBeans(OpenAiRecorder recorder, - Optional selectedChatItem, - Optional selectedEmbedding, - Optional selectedModeration, - Optional selectedImage, + List selectedChatItem, + List selectedEmbedding, + List selectedModeration, + List selectedImage, Langchain4jOpenAiConfig config, BuildProducer beanProducer) { - if (selectedChatItem.isPresent() && PROVIDER.equals(selectedChatItem.get().getProvider())) { - beanProducer.produce(SyntheticBeanBuildItem - 
.configure(CHAT_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(recorder.chatModel(config)) - .done()); - - beanProducer.produce(SyntheticBeanBuildItem - .configure(STREAMING_CHAT_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(recorder.streamingChatModel(config)) - .done()); + + for (var selected : selectedChatItem) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.chatModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + + var streamingBuilder = SyntheticBeanBuildItem + .configure(STREAMING_CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.streamingChatModel(config, modelName)); + addQualifierIfNecessary(streamingBuilder, modelName); + beanProducer.produce(streamingBuilder.done()); + } } - if (selectedEmbedding.isPresent() && PROVIDER.equals(selectedEmbedding.get().getProvider())) { - beanProducer.produce(SyntheticBeanBuildItem - .configure(EMBEDDING_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(recorder.embeddingModel(config)) - .done()); + for (var selected : selectedEmbedding) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(EMBEDDING_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.embeddingModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + } } - if (selectedModeration.isPresent() && PROVIDER.equals(selectedModeration.get().getProvider())) { - beanProducer.produce(SyntheticBeanBuildItem - 
.configure(MODERATION_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(recorder.moderationModel(config)) - .done()); + for (var selected : selectedModeration) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(MODERATION_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.moderationModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + } } - if (selectedImage.isPresent() && PROVIDER.equals(selectedImage.get().getProvider())) { - beanProducer.produce(SyntheticBeanBuildItem - .configure(IMAGE_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(recorder.imageModel(config)) - .done()); + for (var selected : selectedImage) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(IMAGE_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.imageModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + } + } + } + + private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) { + if (!NamedModelUtil.isDefault(modelName)) { + builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", modelName).build()); } } diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/MultipleChatModelsDeclarativeServiceTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/MultipleChatModelsDeclarativeServiceTest.java new file mode 100644 index 000000000..06c35bddc --- /dev/null +++ 
b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/MultipleChatModelsDeclarativeServiceTest.java @@ -0,0 +1,150 @@ +package org.acme.examples.aiservices; + +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; +import static org.acme.examples.aiservices.MessageAssertUtils.assertSingleRequestMessage; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.Map; +import java.util.Optional; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.stubbing.ServeEvent; +import com.github.tomakehurst.wiremock.verification.LoggedRequest; + +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.openai.test.WiremockUtils; +import io.quarkus.test.QuarkusUnitTest; + +public class MultipleChatModelsDeclarativeServiceTest { + + public static final String MESSAGE_CONTENT = "Tell me a joke about developers"; + private static final int WIREMOCK_PORT = 8089; + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class).addClasses(WiremockUtils.class, MessageAssertUtils.class)) + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.api-key", "defaultKey") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url", "http://localhost:" + WIREMOCK_PORT + "/v1") + 
.overrideRuntimeConfigKey("quarkus.langchain4j.openai.model1.api-key", "key1") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.model1.base-url", "http://localhost:" + WIREMOCK_PORT + "/v1") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.model2.api-key", "key2") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.model2.base-url", + "http://localhost:" + WIREMOCK_PORT + "/v1"); + private static final TypeReference> MAP_TYPE_REF = new TypeReference<>() { + }; + + static WireMockServer wireMockServer; + + static ObjectMapper mapper; + + @BeforeAll + static void beforeAll() { + wireMockServer = new WireMockServer(options().port(WIREMOCK_PORT)); + wireMockServer.start(); + + mapper = new ObjectMapper(); + } + + @AfterAll + static void afterAll() { + wireMockServer.stop(); + } + + @BeforeEach + void setup() { + wireMockServer.resetAll(); + } + + @RegisterAiService + interface ChatWithDefaultModel { + + String chat(String userMessage); + } + + @RegisterAiService(modelName = "model1") + interface ChatWithModel1 { + + String chat(String userMessage); + } + + @RegisterAiService(modelName = "model2") + interface ChatWithModel2 { + + String chat(String userMessage); + } + + @Inject + ChatWithDefaultModel chatWithDefaultModel; + + @Inject + ChatWithModel1 chatWithModel1; + + @Inject + ChatWithModel2 chatWithModel2; + + @Test + @ActivateRequestContext + public void testDefaultModel() throws IOException { + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.of("defaultKey"), MESSAGE_CONTENT)); + String result = chatWithDefaultModel.chat(MESSAGE_CONTENT); + assertThat(result).isNotBlank(); + + assertSingleRequestMessage(getRequestAsMap(), MESSAGE_CONTENT); + } + + @Test + @ActivateRequestContext + public void testNamedModel1() throws IOException { + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.of("key1"), MESSAGE_CONTENT)); + String result = chatWithModel1.chat(MESSAGE_CONTENT); + 
assertThat(result).isNotBlank(); + + assertSingleRequestMessage(getRequestAsMap(), MESSAGE_CONTENT); + } + + @Test + @ActivateRequestContext + public void testNamedModel2() throws IOException { + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.of("key2"), MESSAGE_CONTENT)); + String result = chatWithModel2.chat(MESSAGE_CONTENT); + assertThat(result).isNotBlank(); + + assertSingleRequestMessage(getRequestAsMap(), MESSAGE_CONTENT); + } + + private Map getRequestAsMap() throws IOException { + return getRequestAsMap(getRequestBody()); + } + + private Map getRequestAsMap(byte[] body) throws IOException { + return mapper.readValue(body, MAP_TYPE_REF); + } + + private byte[] getRequestBody() { + assertThat(wireMockServer.getAllServeEvents()).hasSize(1); + ServeEvent serveEvent = wireMockServer.getAllServeEvents().get(0); // this works because we reset requests for Wiremock before each test + return getRequestBody(serveEvent); + } + + private byte[] getRequestBody(ServeEvent serveEvent) { + LoggedRequest request = serveEvent.getRequest(); + assertThat(request.getBody()).isNotEmpty(); + return request.getBody(); + } + +} diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/MultipleChatModesTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/MultipleChatModesTest.java new file mode 100644 index 000000000..a00a42f90 --- /dev/null +++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/MultipleChatModesTest.java @@ -0,0 +1,135 @@ +package org.acme.examples.aiservices; + +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; +import static org.acme.examples.aiservices.MessageAssertUtils.assertSingleRequestMessage; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.Map; +import java.util.Optional; + +import jakarta.enterprise.context.control.ActivateRequestContext; 
+import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.stubbing.ServeEvent; +import com.github.tomakehurst.wiremock.verification.LoggedRequest; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import io.quarkiverse.langchain4j.ModelName; +import io.quarkiverse.langchain4j.openai.test.WiremockUtils; +import io.quarkus.test.QuarkusUnitTest; + +public class MultipleChatModesTest { + + public static final String MESSAGE_CONTENT = "Tell me a joke about developers"; + private static final int WIREMOCK_PORT = 8089; + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class).addClasses(WiremockUtils.class, MessageAssertUtils.class)) + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.api-key", "defaultKey") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url", "http://localhost:" + WIREMOCK_PORT + "/v1") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.model1.api-key", "key1") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.model1.base-url", "http://localhost:" + WIREMOCK_PORT + "/v1") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.model2.api-key", "key2") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.model2.base-url", + "http://localhost:" + WIREMOCK_PORT + "/v1"); + private static final TypeReference> MAP_TYPE_REF = new TypeReference<>() { + }; + + static WireMockServer wireMockServer; + + static ObjectMapper mapper; + + @BeforeAll 
+ static void beforeAll() { + wireMockServer = new WireMockServer(options().port(WIREMOCK_PORT)); + wireMockServer.start(); + + mapper = new ObjectMapper(); + } + + @AfterAll + static void afterAll() { + wireMockServer.stop(); + } + + @BeforeEach + void setup() { + wireMockServer.resetAll(); + } + + @Inject + ChatLanguageModel chatWithDefaultModel; + + @Inject + @ModelName("model1") + ChatLanguageModel chatWithModel1; + + @Inject + @ModelName("model2") + ChatLanguageModel chatWithModel2; + + @Test + @ActivateRequestContext + public void testDefaultModel() throws IOException { + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.of("defaultKey"), MESSAGE_CONTENT)); + String result = chatWithDefaultModel.generate(MESSAGE_CONTENT); + assertThat(result).isNotBlank(); + + assertSingleRequestMessage(getRequestAsMap(), MESSAGE_CONTENT); + } + + @Test + @ActivateRequestContext + public void testNamedModel1() throws IOException { + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.of("key1"), MESSAGE_CONTENT)); + String result = chatWithModel1.generate(MESSAGE_CONTENT); + assertThat(result).isNotBlank(); + + assertSingleRequestMessage(getRequestAsMap(), MESSAGE_CONTENT); + } + + @Test + @ActivateRequestContext + public void testNamedModel2() throws IOException { + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.of("key2"), MESSAGE_CONTENT)); + String result = chatWithModel2.generate(MESSAGE_CONTENT); + assertThat(result).isNotBlank(); + + assertSingleRequestMessage(getRequestAsMap(), MESSAGE_CONTENT); + } + + private Map getRequestAsMap() throws IOException { + return getRequestAsMap(getRequestBody()); + } + + private Map getRequestAsMap(byte[] body) throws IOException { + return mapper.readValue(body, MAP_TYPE_REF); + } + + private byte[] getRequestBody() { + assertThat(wireMockServer.getAllServeEvents()).hasSize(1); + ServeEvent serveEvent = wireMockServer.getAllServeEvents().get(0); // this 
works because we reset requests for Wiremock before each test + return getRequestBody(serveEvent); + } + + private byte[] getRequestBody(ServeEvent serveEvent) { + LoggedRequest request = serveEvent.getRequest(); + assertThat(request.getBody()).isNotEmpty(); + return request.getBody(); + } + +} diff --git a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java index 97fd51742..ac55281db 100644 --- a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java +++ b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java @@ -18,6 +18,7 @@ import io.quarkiverse.langchain4j.openai.runtime.config.ImageModelConfig; import io.quarkiverse.langchain4j.openai.runtime.config.Langchain4jOpenAiConfig; import io.quarkiverse.langchain4j.openai.runtime.config.ModerationModelConfig; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.runtime.ShutdownContext; import io.quarkus.runtime.annotations.Recorder; import io.smallrye.config.ConfigValidationException; @@ -25,26 +26,29 @@ @Recorder public class OpenAiRecorder { - public Supplier chatModel(Langchain4jOpenAiConfig runtimeConfig) { - Optional apiKeyOpt = runtimeConfig.apiKey(); - if (apiKeyOpt.isEmpty()) { - throw new ConfigValidationException(createApiKeyConfigProblems()); + private static final String DUMMY_KEY = "dummy"; + + public Supplier chatModel(Langchain4jOpenAiConfig runtimeConfig, String modelName) { + Langchain4jOpenAiConfig.OpenAiConfig openAiConfig = correspondingOpenAiConfig(runtimeConfig, modelName); + String apiKey = openAiConfig.apiKey(); + if (DUMMY_KEY.equals(apiKey)) { + throw new ConfigValidationException(createApiKeyConfigProblems(modelName)); } - ChatModelConfig chatModelConfig = runtimeConfig.chatModel(); + ChatModelConfig 
chatModelConfig = openAiConfig.chatModel(); var builder = OpenAiChatModel.builder() - .baseUrl(runtimeConfig.baseUrl()) - .apiKey(apiKeyOpt.get()) - .timeout(runtimeConfig.timeout()) - .maxRetries(runtimeConfig.maxRetries()) - .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests())) - .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses())) + .baseUrl(openAiConfig.baseUrl()) + .apiKey(apiKey) + .timeout(openAiConfig.timeout()) + .maxRetries(openAiConfig.maxRetries()) + .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), openAiConfig.logRequests())) + .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), openAiConfig.logResponses())) .modelName(chatModelConfig.modelName()) .temperature(chatModelConfig.temperature()) .topP(chatModelConfig.topP()) .presencePenalty(chatModelConfig.presencePenalty()) .frequencyPenalty(chatModelConfig.frequencyPenalty()); - runtimeConfig.organizationId().ifPresent(builder::organizationId); + openAiConfig.organizationId().ifPresent(builder::organizationId); if (chatModelConfig.maxTokens().isPresent()) { builder.maxTokens(chatModelConfig.maxTokens().get()); @@ -58,25 +62,26 @@ public Object get() { }; } - public Supplier streamingChatModel(Langchain4jOpenAiConfig runtimeConfig) { - Optional apiKeyOpt = runtimeConfig.apiKey(); - if (apiKeyOpt.isEmpty()) { - throw new ConfigValidationException(createApiKeyConfigProblems()); + public Supplier streamingChatModel(Langchain4jOpenAiConfig runtimeConfig, String modelName) { + Langchain4jOpenAiConfig.OpenAiConfig openAiConfig = correspondingOpenAiConfig(runtimeConfig, modelName); + String apiKey = openAiConfig.apiKey(); + if (DUMMY_KEY.equals(apiKey)) { + throw new ConfigValidationException(createApiKeyConfigProblems(modelName)); } - ChatModelConfig chatModelConfig = runtimeConfig.chatModel(); + ChatModelConfig chatModelConfig = openAiConfig.chatModel(); var builder = 
OpenAiStreamingChatModel.builder() - .baseUrl(runtimeConfig.baseUrl()) - .apiKey(apiKeyOpt.get()) - .timeout(runtimeConfig.timeout()) - .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests())) - .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses())) + .baseUrl(openAiConfig.baseUrl()) + .apiKey(apiKey) + .timeout(openAiConfig.timeout()) + .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), openAiConfig.logRequests())) + .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), openAiConfig.logResponses())) .modelName(chatModelConfig.modelName()) .temperature(chatModelConfig.temperature()) .topP(chatModelConfig.topP()) .presencePenalty(chatModelConfig.presencePenalty()) .frequencyPenalty(chatModelConfig.frequencyPenalty()); - runtimeConfig.organizationId().ifPresent(builder::organizationId); + openAiConfig.organizationId().ifPresent(builder::organizationId); if (chatModelConfig.maxTokens().isPresent()) { builder.maxTokens(chatModelConfig.maxTokens().get()); @@ -90,26 +95,27 @@ public Object get() { }; } - public Supplier embeddingModel(Langchain4jOpenAiConfig runtimeConfig) { - Optional apiKeyOpt = runtimeConfig.apiKey(); - if (apiKeyOpt.isEmpty()) { - throw new ConfigValidationException(createApiKeyConfigProblems()); + public Supplier embeddingModel(Langchain4jOpenAiConfig runtimeConfig, String modelName) { + Langchain4jOpenAiConfig.OpenAiConfig openAiConfig = correspondingOpenAiConfig(runtimeConfig, modelName); + String apiKeyOpt = openAiConfig.apiKey(); + if (DUMMY_KEY.equals(apiKeyOpt)) { + throw new ConfigValidationException(createApiKeyConfigProblems(modelName)); } - EmbeddingModelConfig embeddingModelConfig = runtimeConfig.embeddingModel(); + EmbeddingModelConfig embeddingModelConfig = openAiConfig.embeddingModel(); var builder = OpenAiEmbeddingModel.builder() - .baseUrl(runtimeConfig.baseUrl()) - .apiKey(apiKeyOpt.get()) - 
.timeout(runtimeConfig.timeout()) - .maxRetries(runtimeConfig.maxRetries()) - .logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), runtimeConfig.logRequests())) - .logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), runtimeConfig.logResponses())) + .baseUrl(openAiConfig.baseUrl()) + .apiKey(apiKeyOpt) + .timeout(openAiConfig.timeout()) + .maxRetries(openAiConfig.maxRetries()) + .logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), openAiConfig.logRequests())) + .logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), openAiConfig.logResponses())) .modelName(embeddingModelConfig.modelName()); if (embeddingModelConfig.user().isPresent()) { builder.user(embeddingModelConfig.user().get()); } - runtimeConfig.organizationId().ifPresent(builder::organizationId); + openAiConfig.organizationId().ifPresent(builder::organizationId); return new Supplier<>() { @Override @@ -119,22 +125,23 @@ public Object get() { }; } - public Supplier moderationModel(Langchain4jOpenAiConfig runtimeConfig) { - Optional apiKeyOpt = runtimeConfig.apiKey(); - if (apiKeyOpt.isEmpty()) { - throw new ConfigValidationException(createApiKeyConfigProblems()); + public Supplier moderationModel(Langchain4jOpenAiConfig runtimeConfig, String modelName) { + Langchain4jOpenAiConfig.OpenAiConfig openAiConfig = correspondingOpenAiConfig(runtimeConfig, modelName); + String apiKey = openAiConfig.apiKey(); + if (DUMMY_KEY.equals(apiKey)) { + throw new ConfigValidationException(createApiKeyConfigProblems(modelName)); } - ModerationModelConfig moderationModelConfig = runtimeConfig.moderationModel(); + ModerationModelConfig moderationModelConfig = openAiConfig.moderationModel(); var builder = OpenAiModerationModel.builder() - .baseUrl(runtimeConfig.baseUrl()) - .apiKey(apiKeyOpt.get()) - .timeout(runtimeConfig.timeout()) - .maxRetries(runtimeConfig.maxRetries()) - .logRequests(firstOrDefault(false, moderationModelConfig.logRequests(), 
runtimeConfig.logRequests())) - .logResponses(firstOrDefault(false, moderationModelConfig.logResponses(), runtimeConfig.logResponses())) + .baseUrl(openAiConfig.baseUrl()) + .apiKey(apiKey) + .timeout(openAiConfig.timeout()) + .maxRetries(openAiConfig.maxRetries()) + .logRequests(firstOrDefault(false, moderationModelConfig.logRequests(), openAiConfig.logRequests())) + .logResponses(firstOrDefault(false, moderationModelConfig.logResponses(), openAiConfig.logResponses())) .modelName(moderationModelConfig.modelName()); - runtimeConfig.organizationId().ifPresent(builder::organizationId); + openAiConfig.organizationId().ifPresent(builder::organizationId); return new Supplier<>() { @Override @@ -144,19 +151,20 @@ public Object get() { }; } - public Supplier imageModel(Langchain4jOpenAiConfig runtimeConfig) { - Optional apiKeyOpt = runtimeConfig.apiKey(); - if (apiKeyOpt.isEmpty()) { - throw new ConfigValidationException(createApiKeyConfigProblems()); + public Supplier imageModel(Langchain4jOpenAiConfig runtimeConfig, String modelName) { + Langchain4jOpenAiConfig.OpenAiConfig openAiConfig = correspondingOpenAiConfig(runtimeConfig, modelName); + String apiKey = openAiConfig.apiKey(); + if (DUMMY_KEY.equals(apiKey)) { + throw new ConfigValidationException(createApiKeyConfigProblems(modelName)); } - ImageModelConfig imageModelConfig = runtimeConfig.imageModel(); + ImageModelConfig imageModelConfig = openAiConfig.imageModel(); var builder = QuarkusOpenAiImageModel.builder() - .baseUrl(runtimeConfig.baseUrl()) - .apiKey(apiKeyOpt.get()) - .timeout(runtimeConfig.timeout()) - .maxRetries(runtimeConfig.maxRetries()) - .logRequests(firstOrDefault(false, imageModelConfig.logRequests(), runtimeConfig.logRequests())) - .logResponses(firstOrDefault(false, imageModelConfig.logResponses(), runtimeConfig.logResponses())) + .baseUrl(openAiConfig.baseUrl()) + .apiKey(apiKey) + .timeout(openAiConfig.timeout()) + .maxRetries(openAiConfig.maxRetries()) + .logRequests(firstOrDefault(false, 
imageModelConfig.logRequests(), openAiConfig.logRequests())) + .logResponses(firstOrDefault(false, imageModelConfig.logResponses(), openAiConfig.logResponses())) .modelName(imageModelConfig.modelName()) .size(imageModelConfig.size()) .quality(imageModelConfig.quality()) @@ -164,7 +172,7 @@ public Supplier imageModel(Langchain4jOpenAiConfig runtimeConfig) { .responseFormat(imageModelConfig.responseFormat()) .user(imageModelConfig.user()); - runtimeConfig.organizationId().ifPresent(builder::organizationId); + openAiConfig.organizationId().ifPresent(builder::organizationId); // we persist if the directory was set explicitly and the boolean flag was not set to false // or if the boolean flag was set explicitly to true @@ -196,18 +204,29 @@ public Object get() { } - private ConfigValidationException.Problem[] createApiKeyConfigProblems() { - return createConfigProblems("api-key"); + private Langchain4jOpenAiConfig.OpenAiConfig correspondingOpenAiConfig(Langchain4jOpenAiConfig runtimeConfig, + String modelName) { + Langchain4jOpenAiConfig.OpenAiConfig openAiConfig; + if (NamedModelUtil.isDefault(modelName)) { + openAiConfig = runtimeConfig.defaultConfig(); + } else { + openAiConfig = runtimeConfig.namedConfig().get(modelName); + } + return openAiConfig; + } + + private ConfigValidationException.Problem[] createApiKeyConfigProblems(String modelName) { + return createConfigProblems("api-key", modelName); } - private ConfigValidationException.Problem[] createConfigProblems(String key) { - return new ConfigValidationException.Problem[] { createConfigProblem(key) }; + private ConfigValidationException.Problem[] createConfigProblems(String key, String modelName) { + return new ConfigValidationException.Problem[] { createConfigProblem(key, modelName) }; } - private ConfigValidationException.Problem createConfigProblem(String key) { + private ConfigValidationException.Problem createConfigProblem(String key, String modelName) { return new 
ConfigValidationException.Problem(String.format( - "SRCFG00014: The config property quarkus.langchain4j.openai.%s is required but it could not be found in any config source", - key)); + "SRCFG00014: The config property quarkus.langchain4j.openai%s%s is required but it could not be found in any config source", + NamedModelUtil.isDefault(modelName) ? "." : ("." + modelName + "."), key)); } public void cleanUp(ShutdownContext shutdown) { diff --git a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/Langchain4jOpenAiConfig.java b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/Langchain4jOpenAiConfig.java index 0b3e172d0..e94f24bdd 100644 --- a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/Langchain4jOpenAiConfig.java +++ b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/Langchain4jOpenAiConfig.java @@ -3,74 +3,100 @@ import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME; import java.time.Duration; +import java.util.Map; import java.util.Optional; import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigDocMapKey; +import io.quarkus.runtime.annotations.ConfigDocSection; +import io.quarkus.runtime.annotations.ConfigGroup; import io.quarkus.runtime.annotations.ConfigRoot; import io.smallrye.config.ConfigMapping; import io.smallrye.config.WithDefault; +import io.smallrye.config.WithDefaults; +import io.smallrye.config.WithParentName; @ConfigRoot(phase = RUN_TIME) @ConfigMapping(prefix = "quarkus.langchain4j.openai") public interface Langchain4jOpenAiConfig { /** - * Base URL of OpenAI API + * Default model config. */ - @WithDefault("https://api.openai.com/v1/") - String baseUrl(); + @WithParentName + OpenAiConfig defaultConfig(); /** - * OpenAI API key + * Named model config. 
*/ - Optional apiKey(); + @ConfigDocSection + @ConfigDocMapKey("model-name") + @WithParentName + @WithDefaults + Map namedConfig(); - /** - * OpenAI Organization ID (https://platform.openai.com/docs/api-reference/organization-optional) - */ - Optional organizationId(); + @ConfigGroup + interface OpenAiConfig { - /** - * Timeout for OpenAI calls - */ - @WithDefault("10s") - Duration timeout(); + /** + * Base URL of OpenAI API + */ + @WithDefault("https://api.openai.com/v1/") + String baseUrl(); - /** - * The maximum number of times to retry - */ - @WithDefault("3") - Integer maxRetries(); + /** + * OpenAI API key + */ + @WithDefault("dummy") // TODO: this should be Optional but Smallrye Config doesn't like it... + String apiKey(); - /** - * Whether the OpenAI client should log requests - */ - @ConfigDocDefault("false") - Optional logRequests(); + /** + * OpenAI Organization ID (https://platform.openai.com/docs/api-reference/organization-optional) + */ + Optional organizationId(); - /** - * Whether the OpenAI client should log responses - */ - @ConfigDocDefault("false") - Optional logResponses(); + /** + * Timeout for OpenAI calls + */ + @WithDefault("10s") + Duration timeout(); - /** - * Chat model related settings - */ - ChatModelConfig chatModel(); + /** + * The maximum number of times to retry + */ + @WithDefault("3") + Integer maxRetries(); - /** - * Embedding model related settings - */ - EmbeddingModelConfig embeddingModel(); + /** + * Whether the OpenAI client should log requests + */ + @ConfigDocDefault("false") + Optional logRequests(); - /** - * Moderation model related settings - */ - ModerationModelConfig moderationModel(); + /** + * Whether the OpenAI client should log responses + */ + @ConfigDocDefault("false") + Optional logResponses(); - /** - * Image model related settings - */ - ImageModelConfig imageModel(); + /** + * Chat model related settings + */ + ChatModelConfig chatModel(); + + /** + * Embedding model related settings + */ + 
EmbeddingModelConfig embeddingModel(); + + /** + * Moderation model related settings + */ + ModerationModelConfig moderationModel(); + + /** + * Image model related settings + */ + ImageModelConfig imageModel(); + } } diff --git a/openshift-ai/deployment/src/main/java/io/quarkiverse/langchain4j/openshift/ai/deployment/OpenshiftAiProcessor.java b/openshift-ai/deployment/src/main/java/io/quarkiverse/langchain4j/openshift/ai/deployment/OpenshiftAiProcessor.java index 56d265290..33e81c860 100644 --- a/openshift-ai/deployment/src/main/java/io/quarkiverse/langchain4j/openshift/ai/deployment/OpenshiftAiProcessor.java +++ b/openshift-ai/deployment/src/main/java/io/quarkiverse/langchain4j/openshift/ai/deployment/OpenshiftAiProcessor.java @@ -2,14 +2,18 @@ import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.CHAT_MODEL; -import java.util.Optional; +import java.util.List; import jakarta.enterprise.context.ApplicationScoped; +import org.jboss.jandex.AnnotationInstance; + +import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; import io.quarkiverse.langchain4j.openshiftai.runtime.OpenshiftAiRecorder; import io.quarkiverse.langchain4j.openshiftai.runtime.config.Langchain4jOpenshiftAiConfig; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.deployment.annotations.BuildProducer; import io.quarkus.deployment.annotations.BuildStep; @@ -40,17 +44,28 @@ public void providerCandidates(BuildProducer selectedChatItem, + List selectedChatItem, Langchain4jOpenshiftAiConfig config, BuildProducer beanProducer) { - if (selectedChatItem.isPresent() && PROVIDER.equals(selectedChatItem.get().getProvider())) { - beanProducer.produce(SyntheticBeanBuildItem - .configure(CHAT_MODEL) - .setRuntimeInit() - .defaultBean() - 
.scope(ApplicationScoped.class) - .supplier(recorder.chatModel(config)) - .done()); + + for (var selected : selectedChatItem) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.chatModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + } + } + } + + private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) { + if (!NamedModelUtil.isDefault(modelName)) { + builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", modelName).build()); } } } diff --git a/openshift-ai/runtime/src/main/java/io/quarkiverse/langchain4j/openshiftai/runtime/OpenshiftAiRecorder.java b/openshift-ai/runtime/src/main/java/io/quarkiverse/langchain4j/openshiftai/runtime/OpenshiftAiRecorder.java index 4cbf722d6..25d139973 100644 --- a/openshift-ai/runtime/src/main/java/io/quarkiverse/langchain4j/openshiftai/runtime/OpenshiftAiRecorder.java +++ b/openshift-ai/runtime/src/main/java/io/quarkiverse/langchain4j/openshiftai/runtime/OpenshiftAiRecorder.java @@ -1,25 +1,50 @@ package io.quarkiverse.langchain4j.openshiftai.runtime; +import java.net.URL; +import java.util.ArrayList; +import java.util.List; import java.util.function.Supplier; import io.quarkiverse.langchain4j.openshiftai.OpenshiftAiChatModel; import io.quarkiverse.langchain4j.openshiftai.runtime.config.ChatModelConfig; import io.quarkiverse.langchain4j.openshiftai.runtime.config.Langchain4jOpenshiftAiConfig; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkus.runtime.annotations.Recorder; +import io.smallrye.config.ConfigValidationException; @Recorder public class OpenshiftAiRecorder { - public Supplier chatModel(Langchain4jOpenshiftAiConfig runtimeConfig) { - ChatModelConfig 
chatModelConfig = runtimeConfig.chatModel(); + private static final String DUMMY_URL = "https://dummy.ai/api"; + private static final String DUMMY_MODEL_ID = "dummy"; + public static final ConfigValidationException.Problem[] EMPTY_PROBLEMS = new ConfigValidationException.Problem[0]; + + public Supplier chatModel(Langchain4jOpenshiftAiConfig runtimeConfig, String modelName) { + Langchain4jOpenshiftAiConfig.OpenshiftAiConfig openshiftAiConfig = correspondingOpenshiftAiConfig(runtimeConfig, + modelName); + ChatModelConfig chatModelConfig = openshiftAiConfig.chatModel(); + + List configProblems = new ArrayList<>(); + URL baseUrl = openshiftAiConfig.baseUrl(); + if (DUMMY_URL.equals(baseUrl.toString())) { + configProblems.add(createBaseURLConfigProblem(modelName)); + } + String modelId = chatModelConfig.modelId(); + if (DUMMY_MODEL_ID.equals(modelId)) { + configProblems.add(createModelIdConfigProblem(modelName)); + } + + if (!configProblems.isEmpty()) { + throw new ConfigValidationException(configProblems.toArray(EMPTY_PROBLEMS)); + } var builder = OpenshiftAiChatModel.builder() - .url(runtimeConfig.baseUrl()) - .timeout(runtimeConfig.timeout()) - .logRequests(runtimeConfig.logRequests()) - .logResponses(runtimeConfig.logResponses()) + .url(baseUrl) + .timeout(openshiftAiConfig.timeout()) + .logRequests(openshiftAiConfig.logRequests()) + .logResponses(openshiftAiConfig.logResponses()) - .modelId(chatModelConfig.modelId()); + .modelId(modelId); return new Supplier<>() { @Override @@ -28,4 +53,30 @@ public Object get() { } }; } + + private Langchain4jOpenshiftAiConfig.OpenshiftAiConfig correspondingOpenshiftAiConfig( + Langchain4jOpenshiftAiConfig runtimeConfig, + String modelName) { + Langchain4jOpenshiftAiConfig.OpenshiftAiConfig openshiftAiConfig; + if (NamedModelUtil.isDefault(modelName)) { + openshiftAiConfig = runtimeConfig.defaultConfig(); + } else { + openshiftAiConfig = runtimeConfig.namedConfig().get(modelName); + } + return openshiftAiConfig; + } + + private 
ConfigValidationException.Problem createBaseURLConfigProblem(String modelName) { + return createConfigProblem("base-url", modelName); + } + + private ConfigValidationException.Problem createModelIdConfigProblem(String modelName) { + return createConfigProblem("chat-model.model-id", modelName); + } + + private static ConfigValidationException.Problem createConfigProblem(String key, String modelName) { + return new ConfigValidationException.Problem(String.format( + "SRCFG00014: The config property quarkus.langchain4j.openshift-ai%s%s is required but it could not be found in any config source", + NamedModelUtil.isDefault(modelName) ? "." : ("." + modelName + "."), key)); + } } diff --git a/openshift-ai/runtime/src/main/java/io/quarkiverse/langchain4j/openshiftai/runtime/config/ChatModelConfig.java b/openshift-ai/runtime/src/main/java/io/quarkiverse/langchain4j/openshiftai/runtime/config/ChatModelConfig.java index 054404c0a..c80b07e04 100644 --- a/openshift-ai/runtime/src/main/java/io/quarkiverse/langchain4j/openshiftai/runtime/config/ChatModelConfig.java +++ b/openshift-ai/runtime/src/main/java/io/quarkiverse/langchain4j/openshiftai/runtime/config/ChatModelConfig.java @@ -1,6 +1,7 @@ package io.quarkiverse.langchain4j.openshiftai.runtime.config; import io.quarkus.runtime.annotations.ConfigGroup; +import io.smallrye.config.WithDefault; @ConfigGroup public interface ChatModelConfig { @@ -8,5 +9,6 @@ public interface ChatModelConfig { /** * Model to use */ + @WithDefault("dummy") // TODO: this is set to a dummy value because otherwise Smallrye Config cannot give a proper error for named models String modelId(); } diff --git a/openshift-ai/runtime/src/main/java/io/quarkiverse/langchain4j/openshiftai/runtime/config/Langchain4jOpenshiftAiConfig.java b/openshift-ai/runtime/src/main/java/io/quarkiverse/langchain4j/openshiftai/runtime/config/Langchain4jOpenshiftAiConfig.java index 0191ff453..813ec649b 100644 --- 
a/openshift-ai/runtime/src/main/java/io/quarkiverse/langchain4j/openshiftai/runtime/config/Langchain4jOpenshiftAiConfig.java +++ b/openshift-ai/runtime/src/main/java/io/quarkiverse/langchain4j/openshiftai/runtime/config/Langchain4jOpenshiftAiConfig.java @@ -4,41 +4,66 @@ import java.net.URL; import java.time.Duration; +import java.util.Map; +import io.quarkus.runtime.annotations.ConfigDocMapKey; +import io.quarkus.runtime.annotations.ConfigDocSection; +import io.quarkus.runtime.annotations.ConfigGroup; import io.quarkus.runtime.annotations.ConfigRoot; import io.smallrye.config.ConfigMapping; import io.smallrye.config.WithDefault; +import io.smallrye.config.WithDefaults; +import io.smallrye.config.WithParentName; @ConfigRoot(phase = RUN_TIME) @ConfigMapping(prefix = "quarkus.langchain4j.openshift-ai") public interface Langchain4jOpenshiftAiConfig { /** - * Base URL where OpenShift AI serving is running, such as - * {@code https://flant5s-l-predictor-ch2023.apps.cluster-hj2qv.dynamic.redhatworkshops.io:443/api} + * Default model config. */ - URL baseUrl(); + @WithParentName + OpenshiftAiConfig defaultConfig(); /** - * Timeout for OpenShift AI calls + * Named model config. 
*/ - @WithDefault("10s") - Duration timeout(); + @ConfigDocSection + @ConfigDocMapKey("model-name") + @WithParentName + @WithDefaults + Map namedConfig(); - /** - * Whether the OpenShift AI client should log requests - */ - @WithDefault("false") - Boolean logRequests(); + @ConfigGroup + interface OpenshiftAiConfig { + /** + * Base URL where OpenShift AI serving is running, such as + * {@code https://flant5s-l-predictor-ch2023.apps.cluster-hj2qv.dynamic.redhatworkshops.io:443/api} + */ + @WithDefault("https://dummy.ai/api") // TODO: this should be Optional but Smallrye Config doesn't like it + URL baseUrl(); - /** - * Whether the OpenShift AI client should log responses - */ - @WithDefault("false") - Boolean logResponses(); + /** + * Timeout for OpenShift AI calls + */ + @WithDefault("10s") + Duration timeout(); - /** - * Chat model related settings - */ - ChatModelConfig chatModel(); + /** + * Whether the OpenShift AI client should log requests + */ + @WithDefault("false") + Boolean logRequests(); + + /** + * Whether the OpenShift AI client should log responses + */ + @WithDefault("false") + Boolean logResponses(); + + /** + * Chat model related settings + */ + ChatModelConfig chatModel(); + } } diff --git a/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonProcessor.java b/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonProcessor.java index 6c5f14b6a..6a3dd3a3f 100644 --- a/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonProcessor.java +++ b/watsonx/deployment/src/main/java/io/quarkiverse/langchain4j/watsonx/deployment/WatsonProcessor.java @@ -2,14 +2,16 @@ import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.CHAT_MODEL; -import java.util.Optional; +import java.util.List; import jakarta.enterprise.context.ApplicationScoped; +import org.jboss.jandex.AnnotationInstance; + +import io.quarkiverse.langchain4j.ModelName; import 
io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem; -import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; -import io.quarkiverse.langchain4j.watsonx.TokenGenerator; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkiverse.langchain4j.watsonx.runtime.WatsonRecorder; import io.quarkiverse.langchain4j.watsonx.runtime.config.Langchain4jWatsonConfig; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; @@ -32,7 +34,6 @@ FeatureBuildItem feature() { @BuildStep public void providerCandidates(BuildProducer chatProducer, - BuildProducer embeddingProducer, Langchain4jWatsonBuildConfig config) { if (config.chatModel().enabled().isEmpty() || config.chatModel().enabled().get()) { @@ -44,25 +45,29 @@ public void providerCandidates(BuildProducer selectedChatItem, + List selectedChatItem, Langchain4jWatsonConfig config, BuildProducer beanProducer) { - if (selectedChatItem.isPresent() && PROVIDER.equals(selectedChatItem.get().getProvider())) { - beanProducer.produce(SyntheticBeanBuildItem - .configure(CHAT_MODEL) - .setRuntimeInit() - .defaultBean() - .scope(ApplicationScoped.class) - .supplier(recorder.chatModel(config)) - .done()); + for (var selected : selectedChatItem) { + if (PROVIDER.equals(selected.getProvider())) { + String modelName = selected.getModelName(); + var builder = SyntheticBeanBuildItem + .configure(CHAT_MODEL) + .setRuntimeInit() + .defaultBean() + .scope(ApplicationScoped.class) + .supplier(recorder.chatModel(config, modelName)); + addQualifierIfNecessary(builder, modelName); + beanProducer.produce(builder.done()); + } } - beanProducer.produce(SyntheticBeanBuildItem - .configure(TokenGenerator.class) - .setRuntimeInit() - .scope(ApplicationScoped.class) - .supplier(recorder.tokenGenerator(config)) - .done()); + } + + private void 
addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) { + if (!NamedModelUtil.isDefault(modelName)) { + builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", modelName).build()); + } } } diff --git a/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AiServiceTest.java b/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AiServiceTest.java index 878c83739..d28ab6ae3 100644 --- a/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AiServiceTest.java +++ b/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AiServiceTest.java @@ -25,6 +25,7 @@ import io.quarkiverse.langchain4j.watsonx.bean.Parameters; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; import io.quarkiverse.langchain4j.watsonx.client.WatsonRestApi; +import io.quarkiverse.langchain4j.watsonx.runtime.config.ChatModelConfig; import io.quarkiverse.langchain4j.watsonx.runtime.config.Langchain4jWatsonConfig; import io.quarkus.test.QuarkusUnitTest; @@ -35,7 +36,7 @@ public class AiServiceTest { static ObjectMapper mapper; @Inject - Langchain4jWatsonConfig config; + Langchain4jWatsonConfig langchain4jWatsonConfig; @Inject ChatLanguageModel model; @@ -84,8 +85,10 @@ interface NewAIService { @Test void chat() throws Exception { - String modelId = config.chatModel().modelId(); - String projectId = config.projectId(); + Langchain4jWatsonConfig.WatsonConfig watsonConfig = langchain4jWatsonConfig.defaultConfig(); + ChatModelConfig chatModelConfig = watsonConfig.chatModel(); + String modelId = chatModelConfig.modelId(); + String projectId = watsonConfig.projectId(); String input = new StringBuilder() .append("This is a systemMessage") .append("\n\n") @@ -93,10 +96,10 @@ void chat() throws Exception { .append("\n") .toString(); Parameters parameters = Parameters.builder() - .decodingMethod(config.chatModel().decodingMethod()) - 
.temperature(config.chatModel().temperature()) - .minNewTokens(config.chatModel().minNewTokens()) - .maxNewTokens(config.chatModel().maxNewTokens()) + .decodingMethod(chatModelConfig.decodingMethod()) + .temperature(chatModelConfig.temperature()) + .minNewTokens(chatModelConfig.minNewTokens()) + .maxNewTokens(chatModelConfig.maxNewTokens()) .build(); TextGenerationRequest body = new TextGenerationRequest(modelId, projectId, input, parameters); diff --git a/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AllPropertiesTest.java b/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AllPropertiesTest.java index fd5d04b93..633754481 100644 --- a/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AllPropertiesTest.java +++ b/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/AllPropertiesTest.java @@ -33,7 +33,7 @@ public class AllPropertiesTest { static ObjectMapper mapper; @Inject - Langchain4jWatsonConfig config; + Langchain4jWatsonConfig langchain4jWatsonConfig; @Inject ChatLanguageModel model; @@ -85,7 +85,7 @@ static void afterAll() { @Test void generate() throws Exception { - + var config = langchain4jWatsonConfig.defaultConfig(); assertEquals(WireMockUtil.URL_WATSONX_SERVER, config.baseUrl().toString()); assertEquals(WireMockUtil.URL_IAM_SERVER, config.iam().baseUrl().toString()); assertEquals(WireMockUtil.API_KEY, config.apiKey()); diff --git a/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/DefaultPropertiesTest.java b/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/DefaultPropertiesTest.java index 51e40d7b5..781bd8aca 100644 --- a/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/DefaultPropertiesTest.java +++ b/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/DefaultPropertiesTest.java @@ -32,7 +32,7 @@ public class DefaultPropertiesTest { static ObjectMapper mapper; @Inject - 
Langchain4jWatsonConfig config; + Langchain4jWatsonConfig langchain4jWatsonConfig; @Inject ChatLanguageModel model; @@ -68,7 +68,7 @@ static void afterAll() { @Test void generate() throws Exception { - + var config = langchain4jWatsonConfig.defaultConfig(); assertEquals(Duration.ofSeconds(10), config.timeout()); assertEquals("2023-05-29", config.version()); assertEquals(false, config.logRequests()); diff --git a/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/HttpErrorTest.java b/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/HttpErrorTest.java index 9ea70b5cf..52bd8edfa 100644 --- a/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/HttpErrorTest.java +++ b/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/HttpErrorTest.java @@ -98,6 +98,7 @@ void error_404_model_not_supported() { .build(); mockServers.mockWatsonBuilder(404) + .responseMediaType(MediaType.APPLICATION_JSON) .response(""" { "errors": [ @@ -113,8 +114,6 @@ void error_404_model_not_supported() { .build(); WatsonException ex = assertThrowsExactly(WatsonException.class, () -> model.generate("message")); - assertNotNull(ex.details()); - assertNotNull(ex.details().trace()); assertEquals(404, ex.details().statusCode()); assertNotNull(ex.details().errors()); assertEquals(1, ex.details().errors().size()); @@ -197,7 +196,6 @@ void error_500() { WatsonException ex = assertThrowsExactly(WatsonException.class, () -> model.generate("message")); assertEquals(500, ex.statusCode()); - assertTrue(ex.getMessage().contains("Unexpected end-of-input")); } @Test diff --git a/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/MissingPropertiesTest.java b/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/MissingPropertiesTest.java deleted file mode 100644 index 901593846..000000000 --- a/watsonx/deployment/src/test/java/com/ibm/langchain4j/watsonx/deployment/MissingPropertiesTest.java +++ 
/dev/null @@ -1,24 +0,0 @@ -package com.ibm.langchain4j.watsonx.deployment; - -import static org.junit.jupiter.api.Assertions.fail; - -import org.jboss.shrinkwrap.api.ShrinkWrap; -import org.jboss.shrinkwrap.api.spec.JavaArchive; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; - -import io.quarkus.test.QuarkusUnitTest; -import io.smallrye.config.ConfigValidationException; - -public class MissingPropertiesTest { - - @RegisterExtension - static QuarkusUnitTest unitTest = new QuarkusUnitTest() - .setExpectedException(ConfigValidationException.class) - .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)); - - @Test - void test() { - fail("Should not be called"); - } -} diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TokenGenerator.java b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TokenGenerator.java index 215f45995..355158638 100644 --- a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TokenGenerator.java +++ b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/TokenGenerator.java @@ -13,11 +13,11 @@ public class TokenGenerator { - private static final ReentrantLock lock = new ReentrantLock();; - private IAMRestApi client; + private static final ReentrantLock lock = new ReentrantLock(); + private final IAMRestApi client; + private final String apiKey; + private final String grantType; private IdentityTokenResponse token; - private String apiKey; - private String grantType; public TokenGenerator(URL url, Duration timeout, String grantType, String apiKey) { diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonChatModel.java b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonChatModel.java index d8878552f..4e463416c 100644 --- a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonChatModel.java +++ 
b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonChatModel.java @@ -20,6 +20,7 @@ import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse.Result; import io.quarkiverse.langchain4j.watsonx.client.WatsonRestApi; +import io.quarkiverse.langchain4j.watsonx.client.filter.BearerRequestFilter; import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder; public class WatsonChatModel implements ChatLanguageModel { @@ -52,6 +53,10 @@ public WatsonChatModel(Builder config) { config.logResponses)); } + if (config.tokenGenerator != null) { + builder.register(new BearerRequestFilter(config.tokenGenerator)); + } + this.client = builder.build(WatsonRestApi.class); this.modelId = config.modelId; this.version = config.version; @@ -160,6 +165,7 @@ public static final class Builder { private Double repetitionPenalty; public boolean logResponses; public boolean logRequests; + private TokenGenerator tokenGenerator; public Builder modelId(String modelId) { this.modelId = modelId; @@ -236,8 +242,9 @@ public Builder stopSequences(List stopSequences) { return this; } - public WatsonChatModel build() { - return new WatsonChatModel(this); + public Builder tokenGenerator(TokenGenerator tokenGenerator) { + this.tokenGenerator = tokenGenerator; + return this; } public Builder logRequests(boolean logRequests) { @@ -249,5 +256,9 @@ public Builder logResponses(boolean logResponses) { this.logResponses = logResponses; return this; } + + public WatsonChatModel build() { + return new WatsonChatModel(this); + } } } diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/WatsonRestApi.java b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/WatsonRestApi.java index 03284d9ae..2bc6da7a9 100644 --- a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/WatsonRestApi.java +++ 
b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/WatsonRestApi.java @@ -10,14 +10,11 @@ import jakarta.ws.rs.Consumes; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; -import jakarta.ws.rs.ProcessingException; import jakarta.ws.rs.Produces; import jakarta.ws.rs.QueryParam; import jakarta.ws.rs.core.MediaType; -import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; import org.jboss.logging.Logger; -import org.jboss.resteasy.reactive.ClientWebApplicationException; import org.jboss.resteasy.reactive.client.api.ClientLogger; import com.fasterxml.jackson.databind.ObjectMapper; @@ -26,7 +23,6 @@ import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest; import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse; import io.quarkiverse.langchain4j.watsonx.bean.WatsonError; -import io.quarkiverse.langchain4j.watsonx.client.filter.BearerRequestFilter; import io.quarkiverse.langchain4j.watsonx.exception.WatsonException; import io.quarkus.rest.client.reactive.ClientExceptionMapper; import io.quarkus.rest.client.reactive.jackson.ClientObjectMapper; @@ -42,38 +38,34 @@ * in Quarkus. 
*/ @Path("/ml/v1-beta") -@RegisterProvider(BearerRequestFilter.class) @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) public interface WatsonRestApi { - static Logger logger = Logger.getLogger(WatsonRestApi.class); - @POST @Path("generation/text") TextGenerationResponse chat(TextGenerationRequest request, @QueryParam("version") String version) throws WatsonException; @ClientExceptionMapper static WatsonException toException(jakarta.ws.rs.core.Response response) { + MediaType mediaType = response.getMediaType(); + if ((mediaType != null) && mediaType.isCompatible(MediaType.APPLICATION_JSON_TYPE)) { + try { + WatsonError ex = response.readEntity(WatsonError.class); - if (MediaType.TEXT_PLAIN.equals(response.getHeaderString("Content-Type"))) - return new WatsonException(response.readEntity(String.class), response.getStatus()); - - try { - - WatsonError ex = response.readEntity(WatsonError.class); + StringJoiner joiner = new StringJoiner("\n"); + if (ex.errors() != null && ex.errors().size() > 0) { + for (WatsonError.Error error : ex.errors()) + joiner.add("%s: %s".formatted(error.code(), error.message())); + } - StringJoiner joiner = new StringJoiner("\n"); - if (ex.errors() != null && ex.errors().size() > 0) { - for (WatsonError.Error error : ex.errors()) - joiner.add("%s: %s".formatted(error.code(), error.message())); + return new WatsonException(joiner.toString(), response.getStatus(), ex); + } catch (Exception e) { + return new WatsonException(response.readEntity(String.class), response.getStatus()); } - - return new WatsonException(joiner.toString(), response.getStatus(), ex); - - } catch (ClientWebApplicationException | ProcessingException e) { - return new WatsonException(e.getCause(), response.getStatus()); } + + return new WatsonException(response.readEntity(String.class), response.getStatus()); } @ClientObjectMapper diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/filter/BearerRequestFilter.java 
b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/filter/BearerRequestFilter.java index 5fa0fc9d4..3df252e7a 100644 --- a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/filter/BearerRequestFilter.java +++ b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/client/filter/BearerRequestFilter.java @@ -9,7 +9,7 @@ public class BearerRequestFilter implements ClientRequestFilter { - private TokenGenerator tokenGenerator; + private final TokenGenerator tokenGenerator; public BearerRequestFilter(TokenGenerator tokenGenerator) { this.tokenGenerator = tokenGenerator; diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonRecorder.java b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonRecorder.java index f09827b77..c55ed1742 100644 --- a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonRecorder.java +++ b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/WatsonRecorder.java @@ -2,28 +2,66 @@ import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault; +import java.net.URL; +import java.util.ArrayList; +import java.util.List; import java.util.function.Supplier; +import io.quarkiverse.langchain4j.runtime.NamedModelUtil; import io.quarkiverse.langchain4j.watsonx.TokenGenerator; import io.quarkiverse.langchain4j.watsonx.WatsonChatModel; import io.quarkiverse.langchain4j.watsonx.runtime.config.ChatModelConfig; +import io.quarkiverse.langchain4j.watsonx.runtime.config.IAMConfig; import io.quarkiverse.langchain4j.watsonx.runtime.config.Langchain4jWatsonConfig; import io.quarkus.runtime.annotations.Recorder; +import io.smallrye.config.ConfigValidationException; @Recorder public class WatsonRecorder { - public Supplier chatModel(Langchain4jWatsonConfig runtimeConfig) { - ChatModelConfig chatModelConfig = runtimeConfig.chatModel(); + private static final String DUMMY_URL = 
"https://dummy.ai/api"; + private static final String DUMMY_API_KEY = "dummy"; + private static final String DUMMY_PROJECT_ID = "dummy"; + public static final ConfigValidationException.Problem[] EMPTY_PROBLEMS = new ConfigValidationException.Problem[0]; + + public Supplier chatModel(Langchain4jWatsonConfig runtimeConfig, String modelName) { + Langchain4jWatsonConfig.WatsonConfig watsonConfig = correspondingWatsonConfig(runtimeConfig, modelName); + ChatModelConfig chatModelConfig = watsonConfig.chatModel(); + + List configProblems = new ArrayList<>(); + URL baseUrl = watsonConfig.baseUrl(); + if (DUMMY_URL.equals(baseUrl.toString())) { + configProblems.add(createBaseURLConfigProblem(modelName)); + } + String apiKey = watsonConfig.apiKey(); + if (DUMMY_API_KEY.equals(apiKey)) { + configProblems.add(createApiKeyConfigProblem(modelName)); + } + String projectId = watsonConfig.projectId(); + if (DUMMY_PROJECT_ID.equals(projectId)) { + configProblems.add(createProjectIdProblem(modelName)); + } + + if (!configProblems.isEmpty()) { + throw new ConfigValidationException(configProblems.toArray(EMPTY_PROBLEMS)); + } + + IAMConfig iamConfig = watsonConfig.iam(); + var tokenGenerator = new TokenGenerator( + iamConfig.baseUrl(), + iamConfig.timeout(), + iamConfig.grantType(), + watsonConfig.apiKey()); var builder = WatsonChatModel.builder() - .url(runtimeConfig.baseUrl()) - .timeout(runtimeConfig.timeout()) - .logRequests(runtimeConfig.logRequests()) - .logResponses(runtimeConfig.logResponses()) + .tokenGenerator(tokenGenerator) + .url(baseUrl) + .timeout(watsonConfig.timeout()) + .logRequests(watsonConfig.logRequests()) + .logResponses(watsonConfig.logResponses()) + .version(watsonConfig.version()) + .projectId(projectId) .modelId(chatModelConfig.modelId()) - .version(runtimeConfig.version()) - .projectId(runtimeConfig.projectId()) .decodingMethod(chatModelConfig.decodingMethod()) .minNewTokens(chatModelConfig.minNewTokens()) .maxNewTokens(chatModelConfig.maxNewTokens()) @@ 
-42,16 +80,32 @@ public Object get() { }; } - public Supplier tokenGenerator(Langchain4jWatsonConfig runtimeConfig) { - return new Supplier<>() { - @Override - public Object get() { - return new TokenGenerator( - runtimeConfig.iam().baseUrl(), - runtimeConfig.iam().timeout(), - runtimeConfig.iam().grantType(), - runtimeConfig.apiKey()); - } - }; + private Langchain4jWatsonConfig.WatsonConfig correspondingWatsonConfig(Langchain4jWatsonConfig runtimeConfig, + String modelName) { + Langchain4jWatsonConfig.WatsonConfig watsonConfig; + if (NamedModelUtil.isDefault(modelName)) { + watsonConfig = runtimeConfig.defaultConfig(); + } else { + watsonConfig = runtimeConfig.namedConfig().get(modelName); + } + return watsonConfig; + } + + private ConfigValidationException.Problem createBaseURLConfigProblem(String modelName) { + return createConfigProblem("base-url", modelName); + } + + private ConfigValidationException.Problem createApiKeyConfigProblem(String modelName) { + return createConfigProblem("api-key", modelName); + } + + private ConfigValidationException.Problem createProjectIdProblem(String modelName) { + return createConfigProblem("project-id", modelName); + } + + private static ConfigValidationException.Problem createConfigProblem(String key, String modelName) { + return new ConfigValidationException.Problem(String.format( + "SRCFG00014: The config property quarkus.langchain4j.watsonx%s%s is required but it could not be found in any config source", + NamedModelUtil.isDefault(modelName) ? "." : ("." 
+ modelName + "."), key)); } } diff --git a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/Langchain4jWatsonConfig.java b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/Langchain4jWatsonConfig.java index a83ef1863..85fdee5bc 100644 --- a/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/Langchain4jWatsonConfig.java +++ b/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/runtime/config/Langchain4jWatsonConfig.java @@ -4,61 +4,88 @@ import java.net.URL; import java.time.Duration; +import java.util.Map; +import io.quarkus.runtime.annotations.ConfigDocMapKey; +import io.quarkus.runtime.annotations.ConfigDocSection; +import io.quarkus.runtime.annotations.ConfigGroup; import io.quarkus.runtime.annotations.ConfigRoot; import io.smallrye.config.ConfigMapping; import io.smallrye.config.WithDefault; +import io.smallrye.config.WithDefaults; +import io.smallrye.config.WithParentName; @ConfigRoot(phase = RUN_TIME) @ConfigMapping(prefix = "quarkus.langchain4j.watsonx") public interface Langchain4jWatsonConfig { /** - * Base URL + * Default model config. */ - URL baseUrl(); + @WithParentName + WatsonConfig defaultConfig(); /** - * Watsonx API key + * Named model config. */ - String apiKey(); + @ConfigDocSection + @ConfigDocMapKey("model-name") + @WithParentName + @WithDefaults + Map namedConfig(); - /** - * Timeout for Watsonx API calls - */ - @WithDefault("10s") - Duration timeout(); + @ConfigGroup + interface WatsonConfig { + /** + * Base URL + */ + @WithDefault("https://dummy.ai/api") // TODO: this is set to a dummy value because otherwise Smallrye Config cannot give a proper error for named models + URL baseUrl(); - /** - * Version to use - */ - @WithDefault("2023-05-29") - String version(); + /** + * Watsonx API key + */ + @WithDefault("dummy") + String apiKey(); - /** - * Watsonx project id. 
- */ - String projectId(); + /** + * Timeout for Watsonx API calls + */ + @WithDefault("10s") + Duration timeout(); - /** - * Whether the Watsonx client should log requests - */ - @WithDefault("false") - Boolean logRequests(); + /** + * Version to use + */ + @WithDefault("2023-05-29") + String version(); - /** - * Whether the Watsonx client should log responses - */ - @WithDefault("false") - Boolean logResponses(); + /** + * Watsonx project id. + */ + @WithDefault("dummy") // TODO: this is set to a dummy value because otherwise Smallrye Config cannot give a proper error for named models + String projectId(); - /** - * Chat model related settings - */ - IAMConfig iam(); + /** + * Whether the Watsonx client should log requests + */ + @WithDefault("false") + Boolean logRequests(); - /** - * Chat model related settings - */ - ChatModelConfig chatModel(); + /** + * Whether the Watsonx client should log responses + */ + @WithDefault("false") + Boolean logResponses(); + + /** + * Chat model related settings + */ + IAMConfig iam(); + + /** + * Chat model related settings + */ + ChatModelConfig chatModel(); + } }