-
Notifications
You must be signed in to change notification settings - Fork 81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add support for tools for the ollama provider #662
Changes from 5 commits
38261d6
55037da
c3c4776
95e3d20
213c0af
e33b92d
9f951e7
19f2157
aba545d
9231922
b4e1d04
d06a779
51c8f39
6f2b938
3fe6b8d
c6d0d18
da70e5b
93ec2cf
85a650c
59a58e4
cc814b0
1768ec5
0424252
d96b632
3aa3806
a6759c8
05fff69
7d99874
61a4343
e00ecea
02e82db
cf1a560
d1e2a53
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
package io.quarkiverse.langchain4j.ollama.deployment; | ||
|
||
import static org.assertj.core.api.Assertions.assertThat; | ||
|
||
import jakarta.enterprise.context.control.ActivateRequestContext; | ||
import jakarta.inject.Inject; | ||
import jakarta.inject.Singleton; | ||
|
||
import org.jboss.shrinkwrap.api.ShrinkWrap; | ||
import org.jboss.shrinkwrap.api.spec.JavaArchive; | ||
import org.junit.jupiter.api.Disabled; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.RegisterExtension; | ||
|
||
import dev.langchain4j.agent.tool.Tool; | ||
import dev.langchain4j.service.SystemMessage; | ||
import dev.langchain4j.service.UserMessage; | ||
import io.quarkiverse.langchain4j.RegisterAiService; | ||
import io.quarkus.logging.Log; | ||
import io.quarkus.test.QuarkusUnitTest; | ||
|
||
@Disabled("Integration tests that need an ollama server running") | ||
public class ToolsTest { | ||
|
||
@RegisterExtension | ||
static final QuarkusUnitTest unitTest = new QuarkusUnitTest() | ||
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)) | ||
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.timeout", "60s") | ||
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-requests", "true") | ||
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-responses", "true") | ||
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.chat-model.temperature", "0") | ||
.overrideRuntimeConfigKey("quarkus.langchain4j.ollama.experimental-tools", "true"); | ||
|
||
@Singleton | ||
@SuppressWarnings("unused") | ||
static class ExpenseService { | ||
@Tool("useful for when you need to lookup condominium expenses for given dates.") | ||
public String getExpenses(String condominium, String fromDate, String toDate) { | ||
String result = String.format(""" | ||
The Expenses for %s from %s to %s are: | ||
- Expense hp12: 2800e | ||
- Expense 2: 15000e | ||
""", condominium, fromDate, toDate); | ||
Log.infof(result); | ||
return result; | ||
} | ||
} | ||
|
||
@RegisterAiService(tools = ExpenseService.class) | ||
public interface Assistant { | ||
@SystemMessage(""" | ||
You are a property manager assistant, answering to co-owners requests. | ||
Format the date as YYYY-MM-DD and the time as HH:MM | ||
Today is {{current_date}} use this date as date time reference | ||
The co-owners is leaving in the following condominium: {condominium} | ||
""") | ||
@UserMessage(""" | ||
{{request}} | ||
""") | ||
String answer(String condominium, String request); | ||
} | ||
|
||
@Inject | ||
Assistant assistant; | ||
|
||
@Test | ||
@ActivateRequestContext | ||
void test_simple_tool() { | ||
String response = assistant.answer("Rives de Marne", | ||
"What are the expenses for this year ?"); | ||
assertThat(response).contains("Expense hp12"); | ||
} | ||
|
||
@Test | ||
@ActivateRequestContext | ||
void test_should_not_calls_tool() { | ||
String response = assistant.answer("Rives de Marne", "What time is it ?"); | ||
assertThat(response).doesNotContain("Expense hp12"); | ||
} | ||
|
||
@Singleton | ||
@SuppressWarnings("unused") | ||
public static class Calculator { | ||
@Tool("Calculates the length of a string") | ||
String stringLengthStr(String s) { | ||
return String.format("The length of the word %s is %d", s, s.length()); | ||
} | ||
|
||
@Tool("Calculates the sum of two numbers") | ||
String addStr(int a, int b) { | ||
return String.format("The sum of %s and %s is %d", a, b, a + b); | ||
} | ||
|
||
@Tool("Calculates the square root of a number") | ||
String sqrtStr(int x) { | ||
return String.format("The square root of %s is %f", x, Math.sqrt(x)); | ||
} | ||
} | ||
|
||
@RegisterAiService(tools = Calculator.class) | ||
public interface MathAssistant { | ||
String chat(String userMessage); | ||
} | ||
|
||
@Inject | ||
MathAssistant mathAssistant; | ||
|
||
@Test | ||
@ActivateRequestContext | ||
void test_multiple_tools() { | ||
String msg = "What is the square root with maximal precision of the sum of the numbers of letters in the words " + | ||
"\"hello\" and \"world\""; | ||
String response = mathAssistant.chat(msg); | ||
assertThat(response).contains("3.162278"); | ||
|
||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
package io.quarkiverse.langchain4j.ollama; | ||
|
||
import java.util.List; | ||
|
||
import dev.langchain4j.agent.tool.ToolSpecification; | ||
import dev.langchain4j.data.message.AiMessage; | ||
|
||
/** | ||
* Does not add any tools, default behavior | ||
*/ | ||
public class EmptyToolsHandler implements ToolsHandler { | ||
@Override | ||
public ChatRequest.Builder enhanceWithTools(ChatRequest.Builder requestBuilder, List<Message> messages, | ||
List<ToolSpecification> toolSpecifications, ToolSpecification toolThatMustBeExecuted) { | ||
return requestBuilder; | ||
} | ||
|
||
@Override | ||
public AiMessage getAiMessageFromResponse(ChatResponse response, List<ToolSpecification> toolSpecifications) { | ||
return AiMessage.from(response.message().content()); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
package io.quarkiverse.langchain4j.ollama; | ||
|
||
import static io.quarkiverse.langchain4j.ollama.ChatRequest.Builder; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.stream.Collectors; | ||
|
||
import dev.langchain4j.agent.tool.ToolExecutionRequest; | ||
import dev.langchain4j.agent.tool.ToolParameters; | ||
import dev.langchain4j.agent.tool.ToolSpecification; | ||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.internal.Json; | ||
import dev.langchain4j.model.input.Prompt; | ||
import dev.langchain4j.model.input.PromptTemplate; | ||
|
||
public class OllamaDefaultToolsHandler implements ToolsHandler { | ||
|
||
static final PromptTemplate DEFAULT_SYSTEM_TEMPLATE = PromptTemplate.from(""" | ||
You have access to the following tools: | ||
|
||
{tools} | ||
|
||
You must always select one of the above tools and respond with a JSON object matching the following schema, | ||
and only this json object: | ||
{ | ||
"tool": <name of the selected tool>, | ||
"tool_input": <parameters for the selected tool, matching the tool's JSON schema> | ||
} | ||
Do not use other tools than the ones from the list above. Always provide the "tool_input" field. | ||
If several tools are necessary, answer them sequentially. | ||
|
||
When the user provides sufficient information, answer with the __conversational_response tool. | ||
"""); | ||
|
||
static final ToolSpecification DEFAULT_RESPONSE_TOOL = ToolSpecification.builder() | ||
.name("__conversational_response") | ||
.description("Respond conversationally if no other tools should be called for a given query and history.") | ||
.parameters(ToolParameters.builder() | ||
.type("object") | ||
.properties( | ||
Map.of("reponse", | ||
Map.of("type", "string", | ||
"description", "Conversational response to the user."))) | ||
.required(Collections.singletonList("response")) | ||
.build()) | ||
.build(); | ||
|
||
@Override | ||
public Builder enhanceWithTools(Builder builder, List<Message> messages, List<ToolSpecification> toolSpecifications, | ||
ToolSpecification toolThatMustBeExecuted) { | ||
if (toolSpecifications == null || toolSpecifications.isEmpty()) { | ||
return builder; | ||
} | ||
// Set Json Format | ||
builder.format("json"); | ||
|
||
// Construct prompt with tools toolSpecifications and DEFAULT_RESPONSE_TOOL | ||
List<ToolSpecification> extendedList = new ArrayList<>(toolSpecifications.size() + 1); | ||
extendedList.addAll(toolSpecifications); | ||
extendedList.add(DEFAULT_RESPONSE_TOOL); | ||
Prompt prompt = DEFAULT_SYSTEM_TEMPLATE.apply( | ||
Map.of("tools", Json.toJson(extendedList))); | ||
|
||
// TODO handle -> toolThatMustBeExecuted skipped for the moment | ||
String initialSystemMessages = messages.stream().filter(cm -> cm.role() == Role.SYSTEM) | ||
.map(Message::content) | ||
.collect(Collectors.joining("\n")); | ||
|
||
Message groupedSystemMessage = Message.builder() | ||
.role(Role.SYSTEM) | ||
.content(initialSystemMessages + "\n" + prompt.text()) | ||
.build(); | ||
|
||
List<Message> otherMessages = messages.stream().filter(cm -> cm.role() != Role.SYSTEM).toList(); | ||
|
||
// Add specific tools message | ||
List<Message> messagesWithTools = new ArrayList<>(messages.size() + 1); | ||
messagesWithTools.add(groupedSystemMessage); | ||
messagesWithTools.addAll(otherMessages); | ||
|
||
builder.messages(messagesWithTools); | ||
|
||
return builder; | ||
} | ||
|
||
@Override | ||
public AiMessage getAiMessageFromResponse(ChatResponse response, List<ToolSpecification> toolSpecifications) { | ||
ToolResponse toolResponse; | ||
try { | ||
// Extract tools | ||
toolResponse = Json.fromJson(response.message().content(), ToolResponse.class); | ||
} catch (Exception e) { | ||
throw new RuntimeException("Ollama server did not respond with valid JSON. Please try again!"); | ||
} | ||
// If the tool is the final result with default response tool | ||
if (toolResponse.tool.equals(DEFAULT_RESPONSE_TOOL.name())) { | ||
return AiMessage.from(toolResponse.tool_input.get("response").toString()); | ||
} | ||
// Check if tool is part of the available tools | ||
List<String> availableTools = toolSpecifications.stream().map(ToolSpecification::name).toList(); | ||
if (!availableTools.contains(toolResponse.tool)) { | ||
return AiMessage.from(String.format( | ||
"Ollama server wants to call a tool '%s' that is not part of the available tools %s", | ||
toolResponse.tool, availableTools)); | ||
} | ||
// Extract tools request from response | ||
List<ToolExecutionRequest> toolExecutionRequests = toToolExecutionRequests(toolResponse, toolSpecifications); | ||
return AiMessage.aiMessage(toolExecutionRequests); | ||
} | ||
|
||
record ToolResponse(String tool, Map<String, Object> tool_input) { | ||
} | ||
|
||
private List<ToolExecutionRequest> toToolExecutionRequests(ToolResponse toolResponse, | ||
List<ToolSpecification> toolSpecifications) { | ||
return toolSpecifications.stream() | ||
.filter(ts -> ts.name().equals(toolResponse.tool)) | ||
.map(ts -> toToolExecutionRequest(toolResponse, ts)) | ||
.toList(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We generally try hard to avoid lambdas in Quarkus code There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why (just curious)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When Quarkus started, the team found that the lambdas had a small (but not zero) impact on memory usage. Mind you, this on Java 8, so things may have changed substantially since then, but we still try to avoid them unless the alternative is just plain terrible. |
||
} | ||
|
||
static ToolExecutionRequest toToolExecutionRequest(ToolResponse toolResponse, ToolSpecification toolSpecification) { | ||
return ToolExecutionRequest.builder() | ||
.name(toolSpecification.name()) | ||
.arguments(Json.toJson(toolResponse.tool_input)) | ||
.build(); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We generally don't write such tests, but instead use Wiremock (see the OpenAI module for tools related tests)