Skip to content

Commit

Permalink
Prompt Optimisation for Llama3
Browse files Browse the repository at this point in the history
  • Loading branch information
humcqc committed Jun 15, 2024
1 parent b4e1d04 commit d06a779
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ int add(int a, int b) {
@Singleton
@SuppressWarnings("unused")
static class ExpenseService {
@Tool("get condominium expenses for given dates.")
@Tool("Get expenses for a given condominium, from date and to date.")
public String getExpenses(String condominium, String fromDate, String toDate) {
String result = String.format("""
The Expenses for %s from %s to %s are:
Expand All @@ -55,9 +55,4 @@ public void sendAnEmail(String content) {
""");
}
}

/**
 * Fallback tool that hands its argument straight back to the caller.
 * Per its tool description, the model is expected to pick it when no other
 * tool applies, so a plain conversational answer still flows through the
 * tool-calling path.
 *
 * @param response the conversational answer produced by the model
 * @return {@code response}, unchanged
 */
@Tool(name = "__conversational_response", value = "Respond conversationally if no other tools should be called for a given query.")
public String conversation(String response) {
return response;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public interface Assistant {
@SystemMessage("""
You are a property manager assistant, answering to co-owners requests.
Format the date as YYYY-MM-DD and the time as HH:MM
Today is {{current_time}} use this date as date time reference
Today is {{current_date_time}} use this date as date time reference
The co-owners is living in the following condominium: {condominium}
""")
@UserMessage("""
Expand Down Expand Up @@ -86,7 +86,7 @@ public interface PoemService {
@ActivateRequestContext
void send_a_poem() {
String response = poemService.writeAPoem("Condominium Rives de marne", 4);
assertThat(response).contains("he poem has been sent by email.");
assertThat(response).contains("sent by email");
}

@RegisterAiService(modelName = MODEL_NAME)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,33 @@
import dev.langchain4j.model.output.TokenUsage;
import io.quarkus.runtime.annotations.RegisterForReflection;

public class ToolsHandler {
public class ExperimentalToolsChatLM {

private static final Logger log = Logger.getLogger(ToolsHandler.class);
private static final Logger log = Logger.getLogger(ExperimentalToolsChatLM.class);

private static final PromptTemplate DEFAULT_SYSTEM_TEMPLATE = PromptTemplate
.from("""
You are a helpful AI assistant responding to user requests.
--- Context ---
{context}
---------------
You are a helpful AI assistant responding to user requests taking into account the previous context.
You have access to the following tools:
{tools}
Select the most appropriate tools from this list, and respond with a JSON object containing required "tools" and "response" fields:
Create a list of most appropriate tools to call in order to answer to the user request.
If no tools are required respond with response field directly.
Respond with a JSON object containing required "tools" and required not null "response" fields:
- "tools": a list of selected tools in JSON format, each with:
- "name": <selected tool name>
- "inputs": <required parameters matching the tool's JSON schema>
- "inputs": <required parameters using tools result_id matching the tool's JSON schema>
- "result_id": <an id to identify the result of this tool, e.g., id1>
- "response": <answer or description of what have been done. Could use tools result_id>
- "response": < Summary of tools used with your response using tools result_id>
Guidelines:
- Reference previous tools results using the format: $(xxx), where xxx is a result_id.
- Only reference previous tools results using the format: $(xxx), where xxx is a previous result_id.
- Break down complex requests into sequential and necessary tools.
- Use previous results through result_id for inputs response, do not invent them.
""");

@RegisterForReflection
Expand All @@ -53,26 +60,21 @@ record ToolResponse(String name, Map<String, Object> inputs, String result_id) {
public Response<AiMessage> chat(OllamaClient client, Builder builder, List<Message> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted) {
// Test if it's an AI Service request with tool results
// Test if it's an AI request with tools execution response.
boolean hasResultMessages = messages.stream().anyMatch(m -> m.role() == Role.TOOL_EXECUTION_RESULT);
if (hasResultMessages) {
String result = messages.stream().filter(term -> term.role() == Role.ASSISTANT)
.map(Message::content).collect(Collectors.joining("\n"));
return Response.from(AiMessage.from(result));
}

// Creates Chat request
builder.format("json");
Message groupedSystemMessage = createSystemMessageWithTools(messages, toolSpecifications);
Message systemMessage = createSystemMessageWithTools(messages, toolSpecifications);

List<Message> otherMessages = messages.stream().filter(cm -> cm.role() != Role.SYSTEM).toList();
Message initialUserMessage = Message.builder()
.role(Role.USER)
.content("--- " + otherMessages.get(0).content() + " ---").build();

List<Message> messagesWithTools = new ArrayList<>(messages.size() + 1);
messagesWithTools.add(groupedSystemMessage);
messagesWithTools.addAll(otherMessages.subList(1, otherMessages.size()));
messagesWithTools.add(initialUserMessage);
List<Message> messagesWithTools = new ArrayList<>(otherMessages.size() + 1);
messagesWithTools.add(systemMessage);
messagesWithTools.addAll(otherMessages);

builder.messages(messagesWithTools);

Expand All @@ -82,18 +84,16 @@ public Response<AiMessage> chat(OllamaClient client, Builder builder, List<Messa
}

/**
 * Builds the single system message that is sent to the model when tools are available.
 * <p>
 * All system messages found in {@code messages} are joined with new lines and injected
 * into the {@code {context}} placeholder of {@link #DEFAULT_SYSTEM_TEMPLATE}; the tool
 * specifications are serialized to JSON and injected into {@code {tools}}.
 *
 * @param messages           the conversation messages; only those with role {@code SYSTEM} are used here
 * @param toolSpecifications the tools the model may select from
 * @return a message with role {@code SYSTEM} whose content is the rendered template
 */
private Message createSystemMessageWithTools(List<Message> messages, List<ToolSpecification> toolSpecifications) {
    // Collapse the caller's system messages into one block of text.
    String initialSystemMessages = messages.stream()
            .filter(sm -> sm.role() == Role.SYSTEM)
            .map(Message::content)
            .collect(Collectors.joining("\n"));

    // The template carries both placeholders, so both values must be supplied in one
    // apply() call. The context is rendered inside the template only; it is not
    // appended a second time after the prompt text.
    Prompt prompt = DEFAULT_SYSTEM_TEMPLATE.apply(
            Map.of("tools", Json.toJson(toolSpecifications),
                    "context", initialSystemMessages));

    return Message.builder()
            .role(Role.SYSTEM)
            .content(prompt.text())
            .build();
}

private AiMessage handleResponse(ChatResponse response, List<ToolSpecification> toolSpecifications) {
Expand Down Expand Up @@ -123,7 +123,7 @@ private AiMessage handleResponse(ChatResponse response, List<ToolSpecification>
}
}

if (toolResponses.response != null) {
if (toolResponses.response != null && !toolResponses.response().isEmpty()) {
return new AiMessage(toolResponses.response, toolExecutionRequests);
}
return AiMessage.from(toolExecutionRequests);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ public class OllamaChatLanguageModel implements ChatLanguageModel {
private final String model;
private final String format;
private final Options options;
private final ToolsHandler toolsHandler;
private final ExperimentalToolsChatLM experimentalToolsChatLM;

// Creates the model from the values collected by the builder. Private: instances are
// obtained through builder().
private OllamaChatLanguageModel(Builder builder) {
// HTTP client configured with the builder's base URL, timeout and request/response logging flags.
client = new OllamaClient(builder.baseUrl, builder.timeout, builder.logRequests, builder.logResponses);
model = builder.model;
format = builder.format;
options = builder.options;
// Tool support is opt-in: the handler is only created when builder.experimentalTool is true
// and is left null otherwise, so callers must null-check it before use.
// NOTE(review): the two assignments below look like the before/after lines of the
// ToolsHandler -> ExperimentalToolsChatLM rename shown in this diff view - confirm which one the file keeps.
toolsHandler = builder.experimentalTool ? new ToolsHandler() : null;
experimentalToolsChatLM = builder.experimentalTool ? new ExperimentalToolsChatLM() : null;
}

public static Builder builder() {
Expand Down Expand Up @@ -61,9 +61,10 @@ private Response<AiMessage> generate(List<ChatMessage> messages,
.options(options)
.format(format)
.stream(false);
boolean isToolNeeded = toolsHandler != null && toolSpecifications != null && !toolSpecifications.isEmpty();
boolean isToolNeeded = experimentalToolsChatLM != null && toolSpecifications != null && !toolSpecifications.isEmpty();
if (isToolNeeded) {
return toolsHandler.chat(client, requestBuilder, ollamaMessages, toolSpecifications, toolThatMustBeExecuted);
return experimentalToolsChatLM.chat(client, requestBuilder, ollamaMessages, toolSpecifications,
toolThatMustBeExecuted);
} else {
ChatResponse response = client.chat(requestBuilder.build());
AiMessage aiMessage = AiMessage.from(response.message().content());
Expand Down

0 comments on commit d06a779

Please sign in to comment.