Skip to content

Commit

Permalink
tool-call: force printing of lazy grammar trigger tokens to regularize function call parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
ochafik committed Oct 29, 2024
1 parent fa4c111 commit 773ff91
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
8 changes: 3 additions & 5 deletions common/tool-call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,20 +455,18 @@ llama_tool_call_handler llama_tool_call_handler_init(
if (!parallel) {
schema["maxItems"] = 1;
}
builder.add_schema("root", schema);
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
});
if (allow_content) {
handler.grammar_trigger_words.push_back("[TOOL_CALLS]");
handler.grammar_trigger_words.push_back("[{\"");
handler.grammar_trigger_words.push_back("[ { \"");
}
// auto tweaked_messages = add_system(messages, "You are a helpful AI with tool calling capabilities. Prefix any tool calls with [TOOL_CALLS]");
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
break;
}
case llama_tool_call_style::Llama31:
case llama_tool_call_style::Llama32: {
static auto builtin_tools = json {"wolfram_alpha", "brave_search"};
static auto builtin_tools = json {"wolfram_alpha", "brave_search", "code_interpreter"};

auto uses_python_tag = style == llama_tool_call_style::Llama31;

Expand Down Expand Up @@ -569,7 +567,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
if (name == "python") {
if (name == "python" || name == "ipython") {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
if (allow_content) {
handler.grammar_trigger_words.push_back("<|python_tag|>");
Expand Down
5 changes: 3 additions & 2 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1062,11 +1062,12 @@ struct server_context {
}

bool process_token(completion_token_output & result, server_slot & slot) {
auto match = slot.antiprompts.findSingleTokenMatch(result.tok);

// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = common_token_to_piece(ctx, result.tok, params.special);
const std::string token_str = common_token_to_piece(ctx, result.tok, params.special || (match.pos != std::string::npos && match.is_grammar_trigger));
slot.sampled = result.tok;

auto match = slot.antiprompts.findSingleTokenMatch(result.tok);
if (match.pos != std::string::npos && !match.is_partial) {
if (match.is_grammar_trigger) {
common_sampler_trigger_grammar(model, slot.smpl, common_token_to_piece(ctx, result.tok, params.special));
Expand Down

0 comments on commit 773ff91

Please sign in to comment.