From cb1632b5938ec36ed5f46ce929a6681e168ac216 Mon Sep 17 00:00:00 2001 From: Clarissa Miranda Date: Fri, 11 Oct 2024 12:20:48 +1100 Subject: [PATCH 1/6] llama : adds llama-grammar memorization stacks (#4218) --- src/llama-grammar.cpp | 118 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 115 insertions(+), 3 deletions(-) diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 74e9f64b393b2..22c63ebfe7ea9 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -682,6 +682,114 @@ static bool llama_grammar_match_partial_char( return !is_positive_char; } +// transforms a grammar pushdown stack into N possible stacks, all ending +// at a character range (terminal element) +// additionally memorizes the stack to its possible stacks by mapping +// < llama_grammar_stack, llama_grammar_stacks > + +struct VectorPointerHash { + size_t operator()(const llama_grammar_stack & v) const { + size_t seed = v.size(); + for (const auto* ptr : v) { + seed ^= std::hash()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + return seed; + } +}; + +static std::unordered_map< + llama_grammar_stack, + llama_grammar_stacks, + VectorPointerHash> + llama_grammar_stacks_cache = {}; + +static void llama_grammar_advance_stack_memo( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + llama_grammar_stacks & new_stacks); + +static void llama_grammar_advance_stack_memo_impl( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + llama_grammar_stacks & new_stacks) { + if (stack.empty()) { + if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { + new_stacks.emplace_back(stack); + } + return; + } + + const llama_grammar_element * pos = stack.back(); + + switch (pos->type) { + case LLAMA_GRETYPE_RULE_REF: { + const size_t rule_id = static_cast(pos->value); + const llama_grammar_element * subpos = rules[rule_id].data(); + do { + // init new stack without the top (pos) + llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!llama_grammar_is_end_of_sequence(subpos)) { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + llama_grammar_advance_stack_memo(rules, new_stack, new_stacks); + while (!llama_grammar_is_end_of_sequence(subpos)) { + // scan to end of alternate def + subpos++; + } + if (subpos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + subpos++; + } else { + break; + } + } while (true); + break; + } + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_ANY: + if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { + // only add the stack if it's not a duplicate of one we already have + new_stacks.emplace_back(stack); + } + break; + default: + // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range + // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + GGML_ABORT("fatal error"); + } +} + +static void llama_grammar_advance_stack_memo( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + llama_grammar_stacks & new_stacks) { + + llama_grammar_stacks advanced_stacks; + // Look if stack is already in memory + auto it = llama_grammar_stacks_cache.find(stack); + if (it != llama_grammar_stacks_cache.end()) { + advanced_stacks = it->second; + } else { + // Advance stacks with memorization + llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks); + llama_grammar_stacks_cache.insert(make_pair(stack, advanced_stacks)); + } + // Add the advanced stacks to new_stacks avoiding duplicates + for (const auto & new_stack : advanced_stacks) { + if (std::find(new_stacks.begin(), new_stacks.end(), new_stack) == new_stacks.end()) { + new_stacks.emplace_back(new_stack); + } + } + +} + // transforms a grammar pushdown stack into N possible stacks, all ending // at a character range (terminal element) static void llama_grammar_advance_stack( @@ -844,7 +952,7 @@ void llama_grammar_accept( if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack(rules, new_stack, stacks_new); + llama_grammar_advance_stack_memo(rules, new_stack, stacks_new); } } } @@ -911,6 +1019,8 @@ struct llama_grammar * llama_grammar_init_impl( const llama_grammar_element ** rules, size_t n_rules, size_t start_rule_index) { + // Clear stacks cache + llama_grammar_stacks_cache.clear(); const llama_grammar_element * pos; // copy rule definitions into vectors @@ -945,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl( // if alternate is nonempty, add to stack stack.push_back(pos); } - llama_grammar_advance_stack(vec_rules, stack, stacks); + llama_grammar_advance_stack_memo(vec_rules, stack, stacks); while (!llama_grammar_is_end_of_sequence(pos)) { // scan to end of alternate def pos++; @@ -965,6 +1075,8 @@ struct llama_grammar * llama_grammar_init_impl( } struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { + // Clear stacks cache + llama_grammar_stacks_cache.clear(); llama_grammar_parser parser; // if there is a grammar, parse it @@ -1023,7 +1135,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, // if alternate is nonempty, add to stack stack.push_back(pos); } - llama_grammar_advance_stack(vec_rules, stack, stacks); + llama_grammar_advance_stack_memo(vec_rules, stack, stacks); while (!llama_grammar_is_end_of_sequence(pos)) { // scan to end of alternate def pos++; From 901a3479b10c3e71a29b86050bcae25d98102908 Mon Sep 17 00:00:00 2001 From: Clarissa Miranda Date: Mon, 14 Oct 2024 17:13:40 +1100 Subject: [PATCH 2/6] move cache stack to advance stack --- examples/gbnf-validator/gbnf-validator.cpp | 3 +- src/llama-grammar.cpp | 53 ++++++++-------------- src/llama-grammar.h | 16 ++++++- tests/test-grammar-integration.cpp | 3 +- 4 files changed, 39 insertions(+), 36 deletions(-) diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 7493af9d3aec3..2cf3bb0477e63 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -15,10 +15,11 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); size_t pos = 0; + llama_grammar_stacks_cache stacks_cache; for (const auto & cpt : cpts) { const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy - llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur); + llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache); if (stacks_cur.empty()) { error_pos = pos; diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 22c63ebfe7ea9..af72de9e0d0a2 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -687,31 +687,17 @@ static bool llama_grammar_match_partial_char( // additionally memorizes the stack to its possible stacks by mapping // < llama_grammar_stack, llama_grammar_stacks > -struct VectorPointerHash { - size_t operator()(const llama_grammar_stack & v) const { - size_t seed = v.size(); - for (const auto* ptr : v) { - seed ^= std::hash()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2); - } - return seed; - } -}; - -static std::unordered_map< - llama_grammar_stack, - llama_grammar_stacks, - VectorPointerHash> - llama_grammar_stacks_cache = {}; - static void llama_grammar_advance_stack_memo( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks); + llama_grammar_stacks & new_stacks, + llama_grammar_stacks_cache & stacks_cache); static void llama_grammar_advance_stack_memo_impl( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks) { + llama_grammar_stacks & new_stacks, + llama_grammar_stacks_cache & stacks_cache) { if (stack.empty()) { if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { new_stacks.emplace_back(stack); @@ -736,7 +722,7 @@ static void llama_grammar_advance_stack_memo_impl( // if alternate is nonempty, add to stack new_stack.push_back(subpos); } - llama_grammar_advance_stack_memo(rules, new_stack, new_stacks); + llama_grammar_advance_stack_memo(rules, new_stack, new_stacks, stacks_cache); while (!llama_grammar_is_end_of_sequence(subpos)) { // scan to end of alternate def subpos++; @@ -769,17 +755,18 @@ static void llama_grammar_advance_stack_memo_impl( static void llama_grammar_advance_stack_memo( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks) { + llama_grammar_stacks & new_stacks, + llama_grammar_stacks_cache & stacks_cache) { llama_grammar_stacks advanced_stacks; // Look if stack is already in memory - auto it = llama_grammar_stacks_cache.find(stack); - if (it != llama_grammar_stacks_cache.end()) { + auto it = stacks_cache.find(stack); + if (it != stacks_cache.end()) { advanced_stacks = it->second; } else { // Advance stacks with memorization - llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks); - llama_grammar_stacks_cache.insert(make_pair(stack, advanced_stacks)); + llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache); + stacks_cache.insert(make_pair(stack, advanced_stacks)); } // Add the advanced stacks to new_stacks avoiding duplicates for (const auto & new_stack : advanced_stacks) { @@ -934,7 +921,8 @@ void llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, const uint32_t chr, - llama_grammar_stacks & stacks_new) { + llama_grammar_stacks & stacks_new, + llama_grammar_stacks_cache & stacks_cache) { stacks_new.clear(); stacks_new.reserve(stacks.size()); @@ -952,7 +940,7 @@ void llama_grammar_accept( if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack_memo(rules, new_stack, stacks_new); + llama_grammar_advance_stack_memo(rules, new_stack, stacks_new, stacks_cache); } } } @@ -1019,8 +1007,6 @@ struct llama_grammar * llama_grammar_init_impl( const llama_grammar_element ** rules, size_t n_rules, size_t start_rule_index) { - // Clear stacks cache - llama_grammar_stacks_cache.clear(); const llama_grammar_element * pos; // copy rule definitions into vectors @@ -1048,6 +1034,7 @@ struct llama_grammar * llama_grammar_init_impl( // loop over alternates of start rule to build initial stacks llama_grammar_stacks stacks; + llama_grammar_stacks_cache stacks_cache; pos = vec_rules[start_rule_index].data(); do { llama_grammar_stack stack; @@ -1055,7 +1042,7 @@ struct llama_grammar * llama_grammar_init_impl( // if alternate is nonempty, add to stack stack.push_back(pos); } - llama_grammar_advance_stack_memo(vec_rules, stack, stacks); + llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache); while (!llama_grammar_is_end_of_sequence(pos)) { // scan to end of alternate def pos++; @@ -1075,8 +1062,6 @@ struct llama_grammar * llama_grammar_init_impl( } struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { - // Clear stacks cache - llama_grammar_stacks_cache.clear(); llama_grammar_parser parser; // if there is a grammar, parse it @@ -1128,6 +1113,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, // loop over alternates of start rule to build initial stacks llama_grammar_stacks stacks; + llama_grammar_stacks_cache stacks_cache; pos = vec_rules[start_rule_index].data(); do { llama_grammar_stack stack; @@ -1135,7 +1121,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, // if alternate is nonempty, add to stack stack.push_back(pos); } - llama_grammar_advance_stack_memo(vec_rules, stack, stacks); + llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache); while (!llama_grammar_is_end_of_sequence(pos)) { // scan to end of alternate def pos++; @@ -1239,9 +1225,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token const auto & code_points = decoded.first; llama_grammar_stacks stacks_new; + llama_grammar_stacks_cache stacks_cache; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new); + llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, stacks_cache); grammar.stacks = std::move(stacks_new); } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index f529ce351e416..de5e16874034f 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -3,6 +3,7 @@ #include "llama-impl.h" #include +#include struct llama_vocab; @@ -61,6 +62,18 @@ using llama_grammar_candidates = std::vector; const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); +struct VectorPointerHash { + size_t operator()(const llama_grammar_stack & v) const { + size_t seed = v.size(); + for (const auto* ptr : v) { + seed ^= std::hash()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + return seed; + } +}; + +using llama_grammar_stacks_cache = std::unordered_map; + // takes a set of possible pushdown stacks on a grammar, which are required to // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those @@ -69,7 +82,8 @@ void llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, uint32_t chr, - llama_grammar_stacks & stacks_new); + llama_grammar_stacks & stacks_new, + llama_grammar_stacks_cache & stacks_cache); std::vector llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 5cc0cdb04751f..dc260b55a8159 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -35,10 +35,11 @@ static bool match_string(const std::string & input, llama_grammar * grammar) { const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); + llama_grammar_stacks_cache stacks_cache; for (const auto & cpt : cpts) { const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy - llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur); + llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache); if (stacks_cur.empty()) { // no stacks means that the grammar failed to match at this point From 2aa6dd273a4bb453bc03c91e3b5485c2d0c3bdd2 Mon Sep 17 00:00:00 2001 From: Clarissa Miranda Date: Thu, 17 Oct 2024 14:30:07 +1100 Subject: [PATCH 3/6] add stacks cache into llama_grammar --- examples/gbnf-validator/gbnf-validator.cpp | 2 +- src/llama-grammar.cpp | 11 +++++++---- src/llama-grammar.h | 9 ++++++--- tests/test-grammar-integration.cpp | 2 +- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 2cf3bb0477e63..646b8e176e9e3 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -13,9 +13,9 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); + llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar); size_t pos = 0; - llama_grammar_stacks_cache stacks_cache; for (const auto & cpt : cpts) { const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index af72de9e0d0a2..2148207901fdb 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -917,6 +917,10 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) return grammar->stacks; } +llama_grammar_stacks_cache & llama_grammar_get_stacks_cache(struct llama_grammar * grammar) { + return grammar->stacks_cache; +} + void llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, @@ -1058,7 +1062,7 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), }; } struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { @@ -1137,7 +1141,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { @@ -1225,10 +1229,9 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token const auto & code_points = decoded.first; llama_grammar_stacks stacks_new; - llama_grammar_stacks_cache stacks_cache; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, stacks_cache); + llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, grammar.stacks_cache); grammar.stacks = std::move(stacks_new); } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index de5e16874034f..42ab06cd89431 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -59,9 +59,6 @@ using llama_grammar_rules = std::vector; using llama_grammar_stacks = std::vector; using llama_grammar_candidates = std::vector; -const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); - llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); - struct VectorPointerHash { size_t operator()(const llama_grammar_stack & v) const { size_t seed = v.size(); @@ -74,6 +71,10 @@ struct VectorPointerHash { using llama_grammar_stacks_cache = std::unordered_map; +const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); + llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); + llama_grammar_stacks_cache & llama_grammar_get_stacks_cache( struct llama_grammar * grammar); + // takes a set of possible pushdown stacks on a grammar, which are required to // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those @@ -129,6 +130,8 @@ struct llama_grammar { // buffer for partially generated UTF-8 sequence from accepted tokens llama_partial_utf8 partial_utf8; + // cache N possible stacks from a stack + llama_grammar_stacks_cache stacks_cache; }; // diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index dc260b55a8159..0883120d1b805 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -34,8 +34,8 @@ static bool match_string(const std::string & input, llama_grammar * grammar) { const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); + llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar); - llama_grammar_stacks_cache stacks_cache; for (const auto & cpt : cpts) { const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy From 17b3a3e8ccd503b635f4d2a30110989c315afe02 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 17 Oct 2024 12:19:28 +0300 Subject: [PATCH 4/6] llama : minor llama_grammar refactoring ggml-ci --- examples/gbnf-validator/gbnf-validator.cpp | 12 +++---- src/llama-grammar.cpp | 42 ++++++++++------------ src/llama-grammar.h | 14 +++----- tests/test-grammar-integration.cpp | 10 ++---- tests/test-llama-grammar.cpp | 6 ++-- 5 files changed, 33 insertions(+), 51 deletions(-) diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 646b8e176e9e3..17a0e27c444e8 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -11,20 +11,15 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { const auto cpts = unicode_cpts_from_utf8(input_str); - const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); - llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); - llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar); + auto & stacks_cur = llama_grammar_get_stacks(grammar); size_t pos = 0; for (const auto & cpt : cpts) { - const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy - - llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache); + llama_grammar_accept(grammar, cpt); if (stacks_cur.empty()) { error_pos = pos; error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'"; - stacks_cur = stacks_prev; return false; } ++pos; @@ -83,7 +78,8 @@ int main(int argc, char** argv) { llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); if (grammar == nullptr) { - throw std::runtime_error("Failed to initialize llama_grammar"); + fprintf(stdout, "Failed to initialize llama_grammar\n"); + return 1; } // Read the input file std::string input_str; diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 2148207901fdb..1380dfc2abe7f 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -764,7 +764,7 @@ static void llama_grammar_advance_stack_memo( if (it != stacks_cache.end()) { advanced_stacks = it->second; } else { - // Advance stacks with memorization + // Advance stacks with memorization llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache); stacks_cache.insert(make_pair(stack, advanced_stacks)); } @@ -917,20 +917,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) return grammar->stacks; } -llama_grammar_stacks_cache & llama_grammar_get_stacks_cache(struct llama_grammar * grammar) { - return grammar->stacks_cache; -} - -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & stacks_new, - llama_grammar_stacks_cache & stacks_cache) { - stacks_new.clear(); - stacks_new.reserve(stacks.size()); +void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) { + llama_grammar_stacks stacks_new; + stacks_new.reserve(grammar->stacks.size()); - for (const auto & stack : stacks) { + for (const auto & stack : grammar->stacks) { if (stack.empty()) { continue; } @@ -944,9 +935,11 @@ void llama_grammar_accept( if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack_memo(rules, new_stack, stacks_new, stacks_cache); + llama_grammar_advance_stack_memo(grammar->rules, new_stack, stacks_new, grammar->stacks_cache); } } + + grammar->stacks = std::move(stacks_new); } llama_grammar_candidates llama_grammar_reject_candidates_for_stack( @@ -1062,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, }; } struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { @@ -1141,7 +1134,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { @@ -1153,7 +1146,13 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { } struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { - llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, }; + llama_grammar * result = new llama_grammar { + grammar.vocab, + grammar.rules, + grammar.stacks, + grammar.stacks_cache, + grammar.partial_utf8, + }; // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { @@ -1161,7 +1160,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { - result->stacks[is][ie] = &result->rules[ir0][ir1]; + result->stacks[is][ie] = &result->rules[ir0][ir1]; } } } @@ -1228,11 +1227,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; - llama_grammar_stacks stacks_new; - for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, grammar.stacks_cache); - grammar.stacks = std::move(stacks_new); + llama_grammar_accept(&grammar, *it); } grammar.partial_utf8 = decoded.second; diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 42ab06cd89431..3f13fee4ff6f0 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -71,20 +71,15 @@ struct VectorPointerHash { using llama_grammar_stacks_cache = std::unordered_map; +// TODO: remove, needed for tests atm const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); - llama_grammar_stacks_cache & llama_grammar_get_stacks_cache( struct llama_grammar * grammar); // takes a set of possible pushdown stacks on a grammar, which are required to // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those // positions -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - uint32_t chr, - llama_grammar_stacks & stacks_new, - llama_grammar_stacks_cache & stacks_cache); +void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr); std::vector llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, @@ -128,10 +123,11 @@ struct llama_grammar { const llama_grammar_rules rules; // TODO: shared ptr llama_grammar_stacks stacks; - // buffer for partially generated UTF-8 sequence from accepted tokens - llama_partial_utf8 partial_utf8; // cache N possible stacks from a stack llama_grammar_stacks_cache stacks_cache; + + // buffer for partially generated UTF-8 sequence from accepted tokens + llama_partial_utf8 partial_utf8; }; // diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 0883120d1b805..e1bdbb9250fca 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -32,14 +32,10 @@ static bool test_build_grammar_fails(const std::string & grammar_str) { static bool match_string(const std::string & input, llama_grammar * grammar) { const auto cpts = unicode_cpts_from_utf8(input); - const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); - llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); - llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar); + auto & stacks_cur = llama_grammar_get_stacks(grammar); for (const auto & cpt : cpts) { - const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy - - llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache); + llama_grammar_accept(grammar, cpt); if (stacks_cur.empty()) { // no stacks means that the grammar failed to match at this point @@ -64,7 +60,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str, auto * grammar = build_grammar(grammar_str); // Save the original grammar stacks so that we can reset after every new string we want to test - const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); + const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); // copy llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index 6f1374ca8ed58..e2129206be156 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -113,12 +113,10 @@ int main() } } - llama_grammar * grammar = NULL; std::vector grammar_rules(parsed_grammar.c_rules()); - grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); - if (grammar == nullptr) - { + llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + if (grammar == nullptr) { throw std::runtime_error("Failed to initialize llama_grammar"); } From a33fbbe411ccf9125176b99ab1f4cbb754d1245a Mon Sep 17 00:00:00 2001 From: Clarissa Miranda <52264247+clarismiranda@users.noreply.github.com> Date: Thu, 17 Oct 2024 21:44:24 +1100 Subject: [PATCH 5/6] Update spelling in memoize Co-authored-by: Clint Herron --- src/llama-grammar.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 1380dfc2abe7f..06749be8349c3 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -684,7 +684,7 @@ static bool llama_grammar_match_partial_char( // transforms a grammar pushdown stack into N possible stacks, all ending // at a character range (terminal element) -// additionally memorizes the stack to its possible stacks by mapping +// additionally memoizes the stack to its possible stacks by mapping // < llama_grammar_stack, llama_grammar_stacks > static void llama_grammar_advance_stack_memo( From dc68a59064ac77a2088bedf00913a1b8b33ffa5b Mon Sep 17 00:00:00 2001 From: Clarissa Miranda Date: Thu, 17 Oct 2024 21:49:31 +1100 Subject: [PATCH 6/6] update spelling --- src/llama-grammar.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 06749be8349c3..53f49dcbf300b 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -764,7 +764,7 @@ static void llama_grammar_advance_stack_memo( if (it != stacks_cache.end()) { advanced_stacks = it->second; } else { - // Advance stacks with memorization + // Advance stacks with memoization llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache); stacks_cache.insert(make_pair(stack, advanced_stacks)); }