Skip to content

Commit

Permalink
- Clean up revert and merge.
Browse files Browse the repository at this point in the history
  • Loading branch information
humcqc committed Jul 15, 2024
1 parent 1768ec5 commit 0424252
Show file tree
Hide file tree
Showing 16 changed files with 60 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ private static String generateInvoker(MethodInfo methodInfo, ClassOutput classOu

boolean toolReturnsVoid = methodInfo.returnType().kind() == Type.Kind.VOID;
if (toolReturnsVoid) {
invokeMc.returnValue(invokeMc.load("Success")); // TODO: To change
invokeMc.returnValue(invokeMc.load("Success"));
} else {
invokeMc.returnValue(result);
}
Expand Down
26 changes: 0 additions & 26 deletions integration-tests/ollama/src/test/resources/langchain-log

This file was deleted.

28 changes: 0 additions & 28 deletions integration-tests/ollama/src/test/resources/quarkus-langchain-log

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.STREAMING_CHAT_MODEL;

import java.util.List;
import java.util.stream.Stream;

import jakarta.enterprise.context.ApplicationScoped;

Expand All @@ -22,7 +21,6 @@
import io.quarkiverse.langchain4j.ollama.runtime.OllamaRecorder;
import io.quarkiverse.langchain4j.ollama.runtime.config.LangChain4jOllamaConfig;
import io.quarkiverse.langchain4j.ollama.runtime.config.LangChain4jOllamaFixedRuntimeConfig;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.deployment.IsNormal;
Expand All @@ -31,7 +29,6 @@
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.FeatureBuildItem;
import io.quarkus.deployment.builditem.nativeimage.RuntimeInitializedClassBuildItem;
import io.quarkus.runtime.configuration.ConfigUtils;

public class OllamaProcessor {
Expand Down Expand Up @@ -150,16 +147,4 @@ private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigur
builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", configName).build());
}
}

@BuildStep
@Record(ExecutionTime.STATIC_INIT)
public void handleDelegates(AiServicesRecorder recorder,
BuildProducer<RuntimeInitializedClassBuildItem> runtimeInitializedClassProducer) {
// since we only need reflection to the constructor of the class, we can specify `false` for both the methods and the fields arguments.
Stream.of("dev.langchain4j.model.ollama.tool.ExperimentalParallelToolsDelegate",
"dev.langchain4j.model.ollama.tool.ExperimentalSequentialToolsDelegate",
"dev.langchain4j.model.ollama.tool.NoToolsDelegate")
.map(RuntimeInitializedClassBuildItem::new)
.forEach(runtimeInitializedClassProducer::produce);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
public record ChatResponse(String model, String createdAt, Message message, Boolean done, Integer promptEvalCount,
Integer evalCount) {

public static ChatResponse emptyDone() {
public static ChatResponse emptyNotDone() {
return new ChatResponse(null, null, new Message(Role.ASSISTANT, "", null), true, null, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
import java.util.function.Predicate;
import java.util.stream.Collectors;

import dev.langchain4j.data.message.*;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ContentType;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.UserMessage;

final class MessageMapper {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
import java.io.InputStream;
import java.util.List;

import jakarta.ws.rs.*;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.MultivaluedMap;
Expand Down Expand Up @@ -89,18 +93,24 @@ public Object aroundReadFrom(ReaderInterceptorContext context) throws IOExceptio
throw e;
}

// This piece of code deals with is the case where the last message from Ollama is not sent as entire line
// but in pieces. There is nothing we can do in this case except for returning empty responses.
// We have to keep track of when "done": true has been recorded in order to make sure that subsequent pieces
// are dealt with instead of throwing an exception. We keep track of this by using Vert.x duplicated context

if (chunk.contains("\"done\":true")) {
ctx.putLocal("done", true);
return ChatResponse.emptyDone();
} else {
if (Boolean.TRUE.equals(ctx.getLocal("done"))) {
return ChatResponse.emptyDone();
// This piece of code deals with the case where a message from Ollama is not received as an entire line
// but in pieces (my guess is that it is a Vertx bug).
// There is nothing we can do in this case except for returning empty responses and in the meantime buffer the pieces
// by storing them in the Vertx Duplicated Context
String existingBuffer = ctx.getLocal("buffer");
if ((existingBuffer != null) && !existingBuffer.isEmpty()) {
if (chunk.endsWith("}")) {
ctx.putLocal("buffer", "");
String entireLine = existingBuffer + chunk;
return QuarkusJsonCodecFactory.SnakeCaseObjectMapperHolder.MAPPER.readValue(entireLine,
ChatResponse.class);
} else {
ctx.putLocal("buffer", existingBuffer + chunk);
return ChatResponse.emptyNotDone();
}
} else {
ctx.putLocal("buffer", chunk);
return ChatResponse.emptyNotDone();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.quarkiverse.langchain4j.ollama.runtime;

import java.time.Duration;
import java.util.function.Supplier;

import dev.langchain4j.model.chat.ChatLanguageModel;
Expand Down Expand Up @@ -37,9 +38,11 @@ public Supplier<ChatLanguageModel> chatModel(LangChain4jOllamaConfig runtimeConf
Options.Builder optionsBuilder = Options.builder()
.temperature(chatModelConfig.temperature())
.topK(chatModelConfig.topK())
.topP(chatModelConfig.topP())
.numPredict(chatModelConfig.numPredict());
.topP(chatModelConfig.topP());

if (chatModelConfig.numPredict().isPresent()) {
optionsBuilder.numPredict(chatModelConfig.numPredict().getAsInt());
}
if (chatModelConfig.stop().isPresent()) {
optionsBuilder.stop(chatModelConfig.stop().get());
}
Expand All @@ -48,7 +51,7 @@ public Supplier<ChatLanguageModel> chatModel(LangChain4jOllamaConfig runtimeConf
}
var builder = OllamaChatLanguageModel.builder()
.baseUrl(ollamaConfig.baseUrl().orElse(DEFAULT_BASE_URL))
.timeout(ollamaConfig.timeout())
.timeout(ollamaConfig.timeout().orElse(Duration.ofSeconds(10)))
.logRequests(chatModelConfig.logRequests().orElse(false))
.logResponses(chatModelConfig.logResponses().orElse(false))
.model(ollamaFixedConfig.chatModel().modelId())
Expand Down Expand Up @@ -92,7 +95,7 @@ public Supplier<EmbeddingModel> embeddingModel(LangChain4jOllamaConfig runtimeCo

var builder = OllamaEmbeddingModel.builder()
.baseUrl(ollamaConfig.baseUrl().orElse(DEFAULT_BASE_URL))
.timeout(ollamaConfig.timeout())
.timeout(ollamaConfig.timeout().orElse(Duration.ofSeconds(10)))
.model(ollamaFixedConfig.embeddingModel().modelId())
.logRequests(embeddingModelConfig.logRequests().orElse(false))
.logResponses(embeddingModelConfig.logResponses().orElse(false));
Expand Down Expand Up @@ -125,9 +128,11 @@ public Supplier<StreamingChatLanguageModel> streamingChatModel(LangChain4jOllama
Options.Builder optionsBuilder = Options.builder()
.temperature(chatModelConfig.temperature())
.topK(chatModelConfig.topK())
.topP(chatModelConfig.topP())
.numPredict(chatModelConfig.numPredict());
.topP(chatModelConfig.topP());

if (chatModelConfig.numPredict().isPresent()) {
optionsBuilder.numPredict(chatModelConfig.numPredict().getAsInt());
}
if (chatModelConfig.stop().isPresent()) {
optionsBuilder.stop(chatModelConfig.stop().get());
}
Expand All @@ -136,7 +141,7 @@ public Supplier<StreamingChatLanguageModel> streamingChatModel(LangChain4jOllama
}
var builder = OllamaStreamingChatLanguageModel.builder()
.baseUrl(ollamaConfig.baseUrl().orElse(DEFAULT_BASE_URL))
.timeout(ollamaConfig.timeout())
.timeout(ollamaConfig.timeout().orElse(Duration.ofSeconds(10)))
.logRequests(ollamaConfig.logRequests().orElse(false))
.logResponses(ollamaConfig.logResponses().orElse(false))
.model(ollamaFixedConfig.chatModel().modelId())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigGroup;
Expand All @@ -20,8 +21,7 @@ public interface ChatModelConfig {
/**
* Maximum number of tokens to predict when generating text
*/
@WithDefault("128")
Integer numPredict();
OptionalInt numPredict();

/**
* Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ interface OllamaConfig {
/**
* Timeout for Ollama calls
*/
@WithDefault("10s")
Duration timeout();
@ConfigDocDefault("10s")
@WithDefault("${quarkus.langchain4j.timeout}")
Optional<Duration> timeout();

/**
* Whether the Ollama client should log requests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.io.IOException;
import java.util.Locale;

import com.fasterxml.jackson.core.JacksonException;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
Expand All @@ -16,7 +17,7 @@ public RoleDeserializer() {

@Override
public Role deserialize(JsonParser jp, DeserializationContext deserializationContext)
throws IOException {
throws IOException, JacksonException {
return Role.valueOf(jp.getValueAsString().toUpperCase(Locale.ROOT));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ public RoleSerializer() {
}

@Override
public void serialize(Role role, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException {
public void serialize(Role role, JsonGenerator jsonGenerator, SerializerProvider serializerProvider)
throws IOException {
jsonGenerator.writeString(role.toString().toLowerCase(Locale.ROOT));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -978,7 +978,7 @@ void should_execute_tool_then_answer() throws IOException {
assertMultipleRequestMessage(getRequestAsMap(getRequestBody(wiremock().getServeEvents().get(0))),
List.of(
new MessageContent("user", "What is the square root of 485906798473894056 in scientific notation?"),
new MessageContent("function", "6.97070153193991E8"),
new MessageContent("assistant", null)));
new MessageContent("assistant", null),
new MessageContent("function", "6.97070153193991E8")));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ void should_execute_tool_then_answer() throws IOException {
List.of(
new MessageContent("user",
"What is the square root of 485906798473894056 in scientific notation?"),
new MessageContent("function", "6.97070153193991E8"),
new MessageContent("assistant", null)));
new MessageContent("assistant", null),
new MessageContent("function", "6.97070153193991E8")));

InstanceHandle<SimpleAuditService> auditServiceInstance = Arc.container().instance(SimpleAuditService.class);
assertTrue(auditServiceInstance.isAvailable());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ void should_execute_tool_then_answer() throws IOException {
List.of(
new MessageContent("user",
"What is the square root of 485906798473894056 in scientific notation?"),
new MessageContent("function", "6.97070153193991E8"),
new MessageContent("assistant", null)));
new MessageContent("assistant", null),
new MessageContent("function", "6.97070153193991E8")));
}

@RegisterAiService
Expand Down
Loading

0 comments on commit 0424252

Please sign in to comment.