diff --git a/crates/llama-cpp-bindings/include/engine.h b/crates/llama-cpp-bindings/include/engine.h
index 45a4f9f57b35..2110c9ebdd6f 100644
--- a/crates/llama-cpp-bindings/include/engine.h
+++ b/crates/llama-cpp-bindings/include/engine.h
@@ -11,6 +11,9 @@ class TextInferenceEngine {
 
   virtual uint32_t start(const rust::Str prompt) const = 0;
   virtual uint32_t step(uint32_t next_token_id) const = 0;
+  virtual void end() const = 0;
+
+  virtual uint32_t eos_token() const = 0;
 };
 
 std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
diff --git a/crates/llama-cpp-bindings/llama.cpp b/crates/llama-cpp-bindings/llama.cpp
index bce1fef32894..06fc4020de0b 160000
--- a/crates/llama-cpp-bindings/llama.cpp
+++ b/crates/llama-cpp-bindings/llama.cpp
@@ -1 +1 @@
-Subproject commit bce1fef328941499dc0acb76cc7fd7ac90449c2f
+Subproject commit 06fc4020de0b92ee13407fdabca7870f53c75de5
diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index b1b0dfd4f71d..a5148954149a 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -37,6 +37,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
 
   uint32_t start(const rust::Str prompt) const override {
     auto* ctx = ctx_.get();
+    llama_reset_timings(ctx);
     std::vector<llama_token> tokens_list = tokenize(ctx, std::string(prompt), /* add_bos = */ true);
     eval(tokens_list, /* reset = */ true);
     return sample();
@@ -47,6 +48,14 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
     return sample();
   }
 
+  void end() const override {
+    llama_print_timings(ctx_.get());
+  }
+
+  uint32_t eos_token() const override {
+    return llama_token_eos(ctx_.get());
+  }
+
  private:
   uint32_t sample() const {
     auto* ctx = ctx_.get();
@@ -65,7 +74,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
           tokens_list.data(),
           tokens_list.size(),
          reset ? 0 : llama_get_kv_cache_token_count(ctx),
-          /* n_threads = */ 1)) {
+          /* n_threads = */ 4)) {
       fprintf(stderr, "%s : failed to eval\n", __func__);
       return false;
     }
@@ -92,7 +101,8 @@ std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
   static BackendInitializer initializer;
 
   llama_context_params ctx_params = llama_context_default_params();
-  ctx_params.n_gpu_layers = 4;
+  ctx_params.n_ctx = 2048;
+  ctx_params.n_gpu_layers = 1;
   llama_model* model = llama_load_model_from_file(std::string(model_path).c_str(), ctx_params);
 
diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs
index 81dd40ceb93c..1f144aca9f83 100644
--- a/crates/llama-cpp-bindings/src/lib.rs
+++ b/crates/llama-cpp-bindings/src/lib.rs
@@ -19,6 +19,9 @@ mod ffi {
 
         fn start(&self, prompt: &str) -> u32;
         fn step(&self, next_token_id: u32) -> u32;
+        fn end(&self);
+
+        fn eos_token(&self) -> u32;
     }
 }
 
@@ -62,7 +65,13 @@ impl TextGeneration for LlamaEngine {
 
         let output_ids = tokio::task::spawn_blocking(move || {
             let engine = engine.lock().unwrap();
+            let eos_token = engine.eos_token();
+
             let mut next_token_id = engine.start(&prompt);
+            if next_token_id == eos_token {
+                return Vec::new();
+            }
+
             let mut n_remains = options.max_decoding_length - 1;
             let mut output_ids = vec![next_token_id];
 
@@ -73,6 +82,10 @@ impl TextGeneration for LlamaEngine {
                 }
 
                 next_token_id = engine.step(next_token_id);
+                if next_token_id == eos_token {
+                    break;
+                }
+
                 if stop_condition.next_token(next_token_id) {
                     break;
                 }
@@ -80,11 +93,11 @@ impl TextGeneration for LlamaEngine {
                 n_remains -= 1;
             }
 
+            engine.end();
             output_ids
         })
         .await
         .expect("Inference failed");
-
         self.tokenizer.decode(&output_ids, true).unwrap()
     }
 }