Commit

- @geoand review 1
humcqc committed Jul 16, 2024
1 parent a6759c8 commit 05fff69
Showing 11 changed files with 48 additions and 209 deletions.
AiStatsMessage.java
@@ -7,6 +7,13 @@
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.output.TokenUsage;

+/**
+ * This class is the equivalent of the Langchain4j AiMessage.
+ * It carries the token usage from the response that produced this AiMessage,
+ * and adds the possibility to update the text when it contains tool result variables.
+ * Needed for {@link io.quarkiverse.langchain4j.runtime.aiservice.ToolsResultMemory}.
+ * Example of usage: ExperimentalParallelToolsDelegate in the Ollama model provider.
+ */
public class AiStatsMessage extends AiMessage {
private String updatableText;

@@ -15,12 +22,12 @@ public class AiStatsMessage extends AiMessage {
public AiStatsMessage(String text, TokenUsage tokenUsage) {
super(text);
this.updatableText = text;
-this.tokenUsage = ValidationUtils.ensureNotNull(tokenUsage, "tokeUsage");
+this.tokenUsage = ValidationUtils.ensureNotNull(tokenUsage, "tokenUsage");
}

AiStatsMessage(List<ToolExecutionRequest> toolExecutionRequests, TokenUsage tokenUsage) {
super(toolExecutionRequests);
-this.tokenUsage = ValidationUtils.ensureNotNull(tokenUsage, "tokeUsage");
+this.tokenUsage = ValidationUtils.ensureNotNull(tokenUsage, "tokenUsage");
}

AiStatsMessage(String text, List<ToolExecutionRequest> toolExecutionRequests, TokenUsage tokenUsage) {
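For context, a minimal sketch of how the new class can be used (the class name and values here are made up for illustration; the two-argument TokenUsage constructor is the standard langchain4j one, and getTokenUsage() appears later in this diff):

    import dev.langchain4j.model.output.TokenUsage;
    import io.quarkiverse.langchain4j.data.AiStatsMessage;

    class AiStatsMessageExample {
        public static void main(String[] args) {
            // an AiMessage that also carries the token usage of the response that produced it
            AiStatsMessage message = new AiStatsMessage("It is sunny today.", new TokenUsage(42, 7));
            System.out.println(message.text());          // regular AiMessage accessor
            System.out.println(message.getTokenUsage()); // extra stats carried by this subclass
        }
    }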
AiServiceMethodImplementationSupport.java
@@ -27,7 +27,11 @@
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.agent.tool.ToolSpecification;
-import dev.langchain4j.data.message.*;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.SystemMessage;
+import dev.langchain4j.data.message.ToolExecutionResultMessage;
+import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
@@ -200,6 +204,7 @@ public void accept(Response<AiMessage> message) {
AiMessage aiMessage = response.content();

if (!aiMessage.hasToolExecutionRequests()) {
+// If there is no tool execution request, add the AiMessage directly to the chat memory
if (context.hasChatMemory()) {
context.chatMemory(memoryId).add(aiMessage);
}
@@ -228,6 +233,8 @@ public void accept(Response<AiMessage> message) {
}
tmpToolExecutionResultMessages.add(toolExecutionResultMessage);
}
+// In case of tool execution requests, the AiMessage must be updated with the tool results
+// before being added to the chat memory
aiMessage = toolsResultMemory.substituteAiMessage(aiMessage);
if (context.hasChatMemory()) {
chatMemory.add(aiMessage);
ToolsResultMemory.java
@@ -10,6 +10,13 @@
import dev.langchain4j.data.message.AiMessage;
import io.quarkiverse.langchain4j.data.AiStatsMessage;

+/**
+ * This class associates each ToolExecutionResult with a variable
+ * identified by {@link ToolExecutionRequest#id()}.
+ * These variables are used later, when tool inputs ({@link ToolExecutionRequest#arguments()})
+ * are based on a previous result variable, and when the AiMessage text contains such variables.
+ * See usage in {@link AiServiceMethodImplementationSupport}.
+ */
@Experimental
public class ToolsResultMemory {

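To make the idea concrete, a rough sketch of the substitution step (the variable syntax and the surrounding setup are assumptions for illustration; only substituteAiMessage is taken from this diff):

    // Illustrative only: suppose the model emitted a tool request whose id is "call_1"
    // and an answer text that refers to the not-yet-known result through a variable:
    //   "The temperature is call_1 degrees."
    // After the tool has executed and returned "21.5", substituting the recorded
    // result variables into the message produces:
    //   "The temperature is 21.5 degrees."
    AiMessage updated = toolsResultMemory.substituteAiMessage(aiMessage);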
OllamaProcessor.java
@@ -23,7 +23,6 @@
import io.quarkiverse.langchain4j.ollama.runtime.config.LangChain4jOllamaConfig;
import io.quarkiverse.langchain4j.ollama.runtime.config.LangChain4jOllamaFixedRuntimeConfig;
import io.quarkiverse.langchain4j.ollama.tool.ExperimentalParallelToolsDelegate;
-import io.quarkiverse.langchain4j.ollama.tool.ExperimentalSequentialToolsDelegate;
import io.quarkiverse.langchain4j.ollama.tool.NoToolsDelegate;
import io.quarkiverse.langchain4j.runtime.NamedConfigUtil;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
@@ -155,8 +154,7 @@ private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator

@BuildStep
public void handleDelegates(BuildProducer<RuntimeInitializedClassBuildItem> runtimeInitializedClassProducer) {
-Stream.of(ExperimentalSequentialToolsDelegate.class.getName(),
-        ExperimentalParallelToolsDelegate.class.getName(),
+Stream.of(ExperimentalParallelToolsDelegate.class.getName(),
NoToolsDelegate.class.getName())
.map(RuntimeInitializedClassBuildItem::new)
.forEach(runtimeInitializedClassProducer::produce);
OllamaChatLanguageModel.java
@@ -12,7 +12,6 @@
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import io.quarkiverse.langchain4j.ollama.tool.ExperimentalParallelToolsDelegate;
-import io.quarkiverse.langchain4j.ollama.tool.ExperimentalSequentialToolsDelegate;
import io.quarkiverse.langchain4j.ollama.tool.ExperimentalTools;
import io.quarkiverse.langchain4j.ollama.tool.NoToolsDelegate;

@@ -36,7 +35,6 @@ public class OllamaChatLanguageModel implements ChatLanguageModel {
private ChatLanguageModel getDelegate(ExperimentalTools toolsEnum) {
return switch (toolsEnum) {
case NONE -> new NoToolsDelegate(this.client, this.model, this.options, this.format);
-case SEQUENTIAL -> new ExperimentalSequentialToolsDelegate(this.client, this.model, this.options);
case PARALLEL -> new ExperimentalParallelToolsDelegate(this.client, this.model, this.options);
};
}
OllamaMessagesUtils.java
@@ -8,7 +8,13 @@
import java.util.function.Predicate;
import java.util.stream.Collectors;

-import dev.langchain4j.data.message.*;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.ChatMessageType;
+import dev.langchain4j.data.message.Content;
+import dev.langchain4j.data.message.ContentType;
+import dev.langchain4j.data.message.ImageContent;
+import dev.langchain4j.data.message.TextContent;
+import dev.langchain4j.data.message.UserMessage;

public class OllamaMessagesUtils {

OllamaConfig.java
@@ -74,7 +74,6 @@ interface OllamaConfig {
* NONE: for no tools, works as before without handling tools.
* PARALLEL: Tools will be used and simulated with one call to the LLM, which will answer with all tool requests to execute
* and the response using the results of those tool requests.
-* SEQUENTIAL: Tools will be call sequentially and llm will call next tool following tool request result.
*/
@WithDefault("NONE")
Optional<String> experimentalTools();
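For illustration, selecting the parallel mode would then be a single configuration entry; the exact property name below is an assumption derived from the config mapping, not something shown in this diff:

    # application.properties (property name assumed from the Ollama config mapping)
    quarkus.langchain4j.ollama.experimental-tools=PARALLEL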
ExperimentalMessagesUtils.java
@@ -17,6 +17,10 @@
import io.quarkiverse.langchain4j.ollama.Message;
import io.quarkiverse.langchain4j.ollama.Role;

+/**
+ * Same as {@link io.quarkiverse.langchain4j.ollama.OllamaMessagesUtils} but with AiStatsMessage support,
+ * and allows grouping messages for the experimental tools LLM request.
+ */
class ExperimentalMessagesUtils {

final static Predicate<ChatMessage> isUserMessage = chatMessage -> chatMessage instanceof UserMessage;
@@ -42,7 +46,7 @@ static AiStatsMessage toAiStatsMessage(List<ChatMessage> messages) {
.findFirst().orElseThrow();
}

-static AiStatsMessage withoutRequests(AiStatsMessage aiStatsMessage) {
+static AiStatsMessage withoutToolRequests(AiStatsMessage aiStatsMessage) {
return new AiStatsMessage(aiStatsMessage.text(), aiStatsMessage.getTokenUsage());
}

ExperimentalParallelToolsDelegate.java
@@ -70,7 +70,16 @@ record ToolResponse(String name, Map<String, Object> inputs, String result_id) {
public ExperimentalParallelToolsDelegate(OllamaClient client, String modelName, Options options) {
this.client = client;
this.modelName = modelName;
-this.options = options;
+this.options = Options.builder()
+        .topP(options.topP())
+        .topK(options.topK())
+        .numCtx(options.numCtx())
+        .numPredict(options.numPredict())
+        .seed(options.seed())
+        .stop(options.stop())
+        .repeatPenalty(options.repeatPenalty())
+        .temperature(0.0)
+        .build();
}

@Override
@@ -102,7 +111,7 @@ public Response<AiMessage> generate(List<ChatMessage> messages,
if (aiStatsMessage.text() == null) {
throw new RuntimeException("Conclusion cannot be null or empty!");
}
-return Response.from(withoutRequests(aiStatsMessage), aiStatsMessage.getTokenUsage(), FinishReason.STOP);
+return Response.from(withoutToolRequests(aiStatsMessage), aiStatsMessage.getTokenUsage(), FinishReason.STOP);
}

Message systemMessage = createSystemMessageWithTools(ollamaMessages, toolSpecifications);
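A note on the constructor change above: instead of storing the caller's Options directly, the delegate now copies every sampling option and pins temperature to 0.0. Presumably the intent is to make the tool-planning response as deterministic and parseable as possible, since tool requests are structured output rather than creative text.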

ExperimentalSequentialToolsDelegate.java: this file was deleted.

ExperimentalTools.java
@@ -2,6 +2,5 @@

public enum ExperimentalTools {
NONE,
-SEQUENTIAL,
PARALLEL
}
