add support for tools for the ollama provider #662

Closed
wants to merge 33 commits into from
Changes from 5 commits

Commits
38261d6
Proposition for : #305, tested on llama3
humcqc Jun 10, 2024
55037da
Test class missing in previous commit
humcqc Jun 10, 2024
c3c4776
Fix typo
humcqc Jun 10, 2024
95e3d20
Review 1
humcqc Jun 10, 2024
213c0af
Fix experimentalTool default value
humcqc Jun 10, 2024
e33b92d
New Prompt, but still some issues with send poem. Need to tune it to …
humcqc Jun 12, 2024
9f951e7
Fix Format
humcqc Jun 12, 2024
19f2157
Fix Format 2
humcqc Jun 12, 2024
aba545d
Fix Test model
humcqc Jun 12, 2024
9231922
New approach
humcqc Jun 15, 2024
b4e1d04
small enhancement still need to fix get_Expenses
humcqc Jun 15, 2024
d06a779
Prompt Optimisation for Llama3
humcqc Jun 15, 2024
51c8f39
fix format
humcqc Jun 15, 2024
6f2b938
Merge branch 'quarkiverse:main' into main
humcqc Jun 16, 2024
3fe6b8d
change tests order
humcqc Jun 16, 2024
c6d0d18
Merge branch 'quarkiverse:main' into main
humcqc Jun 24, 2024
da70e5b
Merge branch 'quarkiverse:main' into main
humcqc Jun 30, 2024
93ec2cf
- Depends on https://github.com/langchain4j/langchain4j/pull/1353
humcqc Jun 30, 2024
85a650c
- Add Missing Class
humcqc Jun 30, 2024
59a58e4
Merge branch 'main' into main
humcqc Jul 4, 2024
cc814b0
- Implementation based on released langchain 0.32.0.
humcqc Jul 15, 2024
1768ec5
Merge branch 'quarkiverse:main' into main
humcqc Jul 15, 2024
0424252
- Clean up revert and merge.
humcqc Jul 15, 2024
d96b632
- Fix langchain4j undependency.
humcqc Jul 15, 2024
3aa3806
- Fix langchain4j unwanted dependency 2
humcqc Jul 15, 2024
a6759c8
- Fix native build
humcqc Jul 15, 2024
05fff69
- @geoand review 1
humcqc Jul 16, 2024
7d99874
Merge branch 'main' into main
humcqc Jul 24, 2024
61a4343
- Fix build + rename file to be more explicit + add unit test on Vari…
humcqc Jul 24, 2024
e00ecea
- Fix import Order
humcqc Jul 24, 2024
02e82db
- Fix conflict merge impact
humcqc Jul 24, 2024
cf1a560
- Prompt Optimization for llama3.1
humcqc Jul 24, 2024
d1e2a53
- Fix format
humcqc Jul 24, 2024
@@ -0,0 +1,118 @@
package io.quarkiverse.langchain4j.ollama.deployment;

import static org.assertj.core.api.Assertions.assertThat;

import jakarta.enterprise.context.control.ActivateRequestContext;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkus.logging.Log;
import io.quarkus.test.QuarkusUnitTest;

@Disabled("Integration tests that need an ollama server running")
public class ToolsTest {
Collaborator:
We generally don't write such tests, but instead use Wiremock (see the OpenAI module for tools related tests)
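The review comment above suggests stubbing the HTTP endpoint (WireMock in the OpenAI module) instead of requiring a live Ollama server. As a dependency-free sketch of the same idea using only the JDK's built-in `HttpServer` — the path and canned body here are illustrative stand-ins, not the module's actual wire format:

```java
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import com.sun.net.httpserver.HttpServer;

public class StubOllamaServer {
    public static void main(String[] args) throws Exception {
        // Canned chat response; a test would assert on how the model-side code parses it
        String body = """
                {"tool": "__conversational_response", "tool_input": {"response": "hello"}}
                """;
        HttpServer server = HttpServer.create(new InetSocketAddress(0), 0);
        server.createContext("/api/chat", exchange -> {
            byte[] bytes = body.getBytes(StandardCharsets.UTF_8);
            exchange.getResponseHeaders().add("Content-Type", "application/json");
            exchange.sendResponseHeaders(200, bytes.length);
            try (OutputStream os = exchange.getResponseBody()) {
                os.write(bytes);
            }
        });
        server.start();
        int port = server.getAddress().getPort();

        // A QuarkusUnitTest would point quarkus.langchain4j.ollama.base-url at this address
        HttpResponse<String> resp = HttpClient.newHttpClient().send(
                HttpRequest.newBuilder(URI.create("http://localhost:" + port + "/api/chat"))
                        .POST(HttpRequest.BodyPublishers.ofString("{}"))
                        .build(),
                HttpResponse.BodyHandlers.ofString());
        server.stop(0);
        System.out.println(resp.body().contains("__conversational_response"));
    }
}
```

This keeps the test hermetic: no external process, and the canned body can be varied per test case.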


@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class))
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.timeout", "60s")
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-requests", "true")
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-responses", "true")
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.chat-model.temperature", "0")
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.experimental-tools", "true");

@Singleton
@SuppressWarnings("unused")
static class ExpenseService {
@Tool("useful for when you need to lookup condominium expenses for given dates.")
public String getExpenses(String condominium, String fromDate, String toDate) {
String result = String.format("""
The Expenses for %s from %s to %s are:
- Expense hp12: 2800e
- Expense 2: 15000e
""", condominium, fromDate, toDate);
Log.infof(result);
return result;
}
}

@RegisterAiService(tools = ExpenseService.class)
public interface Assistant {
@SystemMessage("""
You are a property manager assistant, answering co-owners' requests.
Format the date as YYYY-MM-DD and the time as HH:MM
Today is {{current_date}} use this date as date time reference
The co-owner is living in the following condominium: {condominium}
""")
@UserMessage("""
{{request}}
""")
String answer(String condominium, String request);
}

@Inject
Assistant assistant;

@Test
@ActivateRequestContext
void test_simple_tool() {
String response = assistant.answer("Rives de Marne",
"What are the expenses for this year ?");
assertThat(response).contains("Expense hp12");
}

@Test
@ActivateRequestContext
void test_should_not_call_tool() {
String response = assistant.answer("Rives de Marne", "What time is it ?");
assertThat(response).doesNotContain("Expense hp12");
}

@Singleton
@SuppressWarnings("unused")
public static class Calculator {
@Tool("Calculates the length of a string")
String stringLengthStr(String s) {
return String.format("The length of the word %s is %d", s, s.length());
}

@Tool("Calculates the sum of two numbers")
String addStr(int a, int b) {
return String.format("The sum of %s and %s is %d", a, b, a + b);
}

@Tool("Calculates the square root of a number")
String sqrtStr(int x) {
return String.format("The square root of %s is %f", x, Math.sqrt(x));
}
}

@RegisterAiService(tools = Calculator.class)
public interface MathAssistant {
String chat(String userMessage);
}

@Inject
MathAssistant mathAssistant;

@Test
@ActivateRequestContext
void test_multiple_tools() {
String msg = "What is the square root with maximal precision of the sum of the numbers of letters in the words " +
"\"hello\" and \"world\"";
String response = mathAssistant.chat(msg);
assertThat(response).contains("3.162278");

}

}
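A quick check of where the expected value in test_multiple_tools comes from: `%f` formats with six decimal places by default, so `sqrtStr` renders sqrt(5 + 5) as 3.162278. (`Locale.ROOT` is used below to pin the decimal separator; the PR's code relies on the default locale.)

```java
import java.util.Locale;

public class SqrtFormatCheck {
    public static void main(String[] args) {
        int sum = "hello".length() + "world".length(); // 5 + 5 = 10
        // %f defaults to six decimal places, hence the "3.162278" the test asserts
        System.out.println(String.format(Locale.ROOT,
                "The square root of %s is %f", sum, Math.sqrt(sum)));
    }
}
```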
@@ -0,0 +1,22 @@
package io.quarkiverse.langchain4j.ollama;

import java.util.List;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;

/**
* Does not add any tools, default behavior
*/
public class EmptyToolsHandler implements ToolsHandler {
@Override
public ChatRequest.Builder enhanceWithTools(ChatRequest.Builder requestBuilder, List<Message> messages,
List<ToolSpecification> toolSpecifications, ToolSpecification toolThatMustBeExecuted) {
return requestBuilder;
}

@Override
public AiMessage getAiMessageFromResponse(ChatResponse response, List<ToolSpecification> toolSpecifications) {
return AiMessage.from(response.message().content());
}
}
@@ -64,9 +64,8 @@ private static Message otherMessages(ChatMessage chatMessage) {
private static Role toOllamaRole(ChatMessageType chatMessageType) {
return switch (chatMessageType) {
case SYSTEM -> Role.SYSTEM;
case USER, TOOL_EXECUTION_RESULT -> Role.USER;
case AI -> Role.ASSISTANT;
default -> throw new IllegalArgumentException("Unknown ChatMessageType: " + chatMessageType);
};
}
}
@@ -2,10 +2,12 @@

import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static io.quarkiverse.langchain4j.ollama.MessageMapper.toOllamaMessages;
import static java.util.Collections.singletonList;

import java.time.Duration;
import java.util.List;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -18,12 +20,14 @@ public class OllamaChatLanguageModel implements ChatLanguageModel {
private final String model;
private final String format;
private final Options options;
private final ToolsHandler toolsHandler;

private OllamaChatLanguageModel(Builder builder) {
client = new OllamaClient(builder.baseUrl, builder.timeout, builder.logRequests, builder.logResponses);
model = builder.model;
format = builder.format;
options = builder.options;
toolsHandler = builder.experimentalTool ? ToolsHandlerFactory.get(model) : new EmptyToolsHandler();
}

public static Builder builder() {
@@ -32,20 +36,38 @@ public static Builder builder() {

@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
return generate(messages, null, null);
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
return generate(messages, toolSpecifications, null);
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generate(messages, singletonList(toolSpecification), toolSpecification);
}

private Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted) {
ensureNotEmpty(messages, "messages");

List<Message> ollamaMessages = toOllamaMessages(messages);
ChatRequest.Builder requestBuilder = ChatRequest.builder()
.model(model)
.messages(ollamaMessages)
.options(options)
.format(format)
.stream(false);

requestBuilder = toolsHandler.enhanceWithTools(requestBuilder, ollamaMessages,
toolSpecifications, toolThatMustBeExecuted);
ChatResponse response = client.chat(requestBuilder.build());
AiMessage aiMessage = toolsHandler.getAiMessageFromResponse(response, toolSpecifications);
return Response.from(aiMessage, new TokenUsage(response.promptEvalCount(), response.evalCount()));
}

public static final class Builder {
@@ -57,6 +79,7 @@ public static final class Builder {

private boolean logRequests = false;
private boolean logResponses = false;
private boolean experimentalTool = false;

private Builder() {
}
@@ -96,6 +119,11 @@ public Builder logResponses(boolean logResponses) {
return this;
}

public Builder experimentalTool(boolean experimentalTool) {
this.experimentalTool = experimentalTool;
return this;
}

public OllamaChatLanguageModel build() {
return new OllamaChatLanguageModel(this);
}
@@ -0,0 +1,131 @@
package io.quarkiverse.langchain4j.ollama;

import static io.quarkiverse.langchain4j.ollama.ChatRequest.Builder;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.internal.Json;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;

public class OllamaDefaultToolsHandler implements ToolsHandler {

static final PromptTemplate DEFAULT_SYSTEM_TEMPLATE = PromptTemplate.from("""
You have access to the following tools:

{tools}

You must always select one of the above tools and respond with a JSON object matching the following schema,
and only this json object:
{
"tool": <name of the selected tool>,
"tool_input": <parameters for the selected tool, matching the tool's JSON schema>
}
Do not use other tools than the ones from the list above. Always provide the "tool_input" field.
If several tools are necessary, answer them sequentially.

When the user provides sufficient information, answer with the __conversational_response tool.
""");

static final ToolSpecification DEFAULT_RESPONSE_TOOL = ToolSpecification.builder()
.name("__conversational_response")
.description("Respond conversationally if no other tools should be called for a given query and history.")
.parameters(ToolParameters.builder()
.type("object")
.properties(
Map.of("response",
Map.of("type", "string",
"description", "Conversational response to the user.")))
.required(Collections.singletonList("response"))
.build())
.build();

@Override
public Builder enhanceWithTools(Builder builder, List<Message> messages, List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted) {
if (toolSpecifications == null || toolSpecifications.isEmpty()) {
return builder;
}
// Set Json Format
builder.format("json");

// Build the system prompt from the tool specifications plus the DEFAULT_RESPONSE_TOOL
List<ToolSpecification> extendedList = new ArrayList<>(toolSpecifications.size() + 1);
extendedList.addAll(toolSpecifications);
extendedList.add(DEFAULT_RESPONSE_TOOL);
Prompt prompt = DEFAULT_SYSTEM_TEMPLATE.apply(
Map.of("tools", Json.toJson(extendedList)));

// TODO handle -> toolThatMustBeExecuted skipped for the moment
String initialSystemMessages = messages.stream().filter(cm -> cm.role() == Role.SYSTEM)
.map(Message::content)
.collect(Collectors.joining("\n"));

Message groupedSystemMessage = Message.builder()
.role(Role.SYSTEM)
.content(initialSystemMessages + "\n" + prompt.text())
.build();

List<Message> otherMessages = messages.stream().filter(cm -> cm.role() != Role.SYSTEM).toList();

// Add specific tools message
List<Message> messagesWithTools = new ArrayList<>(messages.size() + 1);
messagesWithTools.add(groupedSystemMessage);
messagesWithTools.addAll(otherMessages);

builder.messages(messagesWithTools);

return builder;
}

@Override
public AiMessage getAiMessageFromResponse(ChatResponse response, List<ToolSpecification> toolSpecifications) {
ToolResponse toolResponse;
try {
// Extract tools
toolResponse = Json.fromJson(response.message().content(), ToolResponse.class);
} catch (Exception e) {
throw new RuntimeException("Ollama server did not respond with valid JSON. Please try again!");
}
// If the tool is the final result with default response tool
if (toolResponse.tool.equals(DEFAULT_RESPONSE_TOOL.name())) {
return AiMessage.from(toolResponse.tool_input.get("response").toString());
}
// Check if tool is part of the available tools
List<String> availableTools = toolSpecifications.stream().map(ToolSpecification::name).toList();
if (!availableTools.contains(toolResponse.tool)) {
return AiMessage.from(String.format(
"Ollama server wants to call a tool '%s' that is not part of the available tools %s",
toolResponse.tool, availableTools));
}
// Extract tools request from response
List<ToolExecutionRequest> toolExecutionRequests = toToolExecutionRequests(toolResponse, toolSpecifications);
return AiMessage.aiMessage(toolExecutionRequests);
}

record ToolResponse(String tool, Map<String, Object> tool_input) {
}

private List<ToolExecutionRequest> toToolExecutionRequests(ToolResponse toolResponse,
List<ToolSpecification> toolSpecifications) {
return toolSpecifications.stream()
.filter(ts -> ts.name().equals(toolResponse.tool))
.map(ts -> toToolExecutionRequest(toolResponse, ts))
.toList();
Collaborator:

We generally try hard to avoid lambdas in Quarkus code

Collaborator:

why (just curious)?

Collaborator:

When Quarkus started, the team found that the lambdas had a small (but not zero) impact on memory usage.

Mind you, this was on Java 8, so things may have changed substantially since then, but we still try to avoid them unless the alternative is just plain terrible.
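For reference, the stream pipeline this thread discusses can be written with a plain loop. The record types below are simplified stand-ins for the langchain4j classes, so this is a sketch of the shape rather than a drop-in replacement:

```java
import java.util.ArrayList;
import java.util.List;

public class NoLambdaSketch {
    // Stand-ins for dev.langchain4j.agent.tool.ToolSpecification / ToolExecutionRequest
    record ToolSpecification(String name) {}
    record ToolExecutionRequest(String name, String arguments) {}

    // Same filtering as the stream().filter().map().toList() pipeline, as a loop
    static List<ToolExecutionRequest> toRequests(String calledTool, String argsJson,
            List<ToolSpecification> specs) {
        List<ToolExecutionRequest> result = new ArrayList<>();
        for (ToolSpecification ts : specs) {
            if (ts.name().equals(calledTool)) {
                result.add(new ToolExecutionRequest(ts.name(), argsJson));
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<ToolSpecification> specs = List.of(
                new ToolSpecification("getExpenses"), new ToolSpecification("sqrtStr"));
        System.out.println(toRequests("getExpenses", "{}", specs));
    }
}
```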

}

static ToolExecutionRequest toToolExecutionRequest(ToolResponse toolResponse, ToolSpecification toolSpecification) {
return ToolExecutionRequest.builder()
.name(toolSpecification.name())
.arguments(Json.toJson(toolResponse.tool_input))
.build();
}
}
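The three-way dispatch in getAiMessageFromResponse can be summarized standalone: a `__conversational_response` call becomes plain text, a known tool name becomes a tool-execution request, and anything else is reported back as an error. The ToolResponse record mirrors the one above; the returned strings are illustrative placeholders for AiMessage values:

```java
import java.util.List;
import java.util.Map;

public class ToolDispatchSketch {
    record ToolResponse(String tool, Map<String, Object> toolInput) {}

    static String dispatch(ToolResponse r, List<String> availableTools) {
        // Final conversational answer: unwrap the "response" field
        if (r.tool().equals("__conversational_response")) {
            return "text: " + r.toolInput().get("response");
        }
        // Hallucinated tool name: surface an error instead of executing
        if (!availableTools.contains(r.tool())) {
            return "error: unknown tool " + r.tool();
        }
        // Known tool: hand the arguments to the tool executor
        return "execute: " + r.tool() + " with " + r.toolInput();
    }

    public static void main(String[] args) {
        List<String> tools = List.of("getExpenses");
        System.out.println(dispatch(
                new ToolResponse("__conversational_response", Map.of("response", "hi")), tools));
        System.out.println(dispatch(new ToolResponse("getExpenses",
                Map.of("condominium", "Rives de Marne")), tools));
        System.out.println(dispatch(new ToolResponse("sendPoem", Map.of()), tools));
    }
}
```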
@@ -6,8 +6,8 @@

@JsonDeserialize(using = RoleDeserializer.class)
public enum Role {

SYSTEM,
USER,
ASSISTANT,
TOOL_EXECUTION_RESULT
}