feat: tune llama metal backend performance #393

Merged · 4 commits · Sep 5, 2023
crates/llama-cpp-bindings/include/engine.h (3 additions, 0 deletions)
@@ -11,6 +11,9 @@ class TextInferenceEngine {

virtual uint32_t start(const rust::Str prompt) const = 0;
virtual uint32_t step(uint32_t next_token_id) const = 0;
virtual void end() const = 0;

virtual uint32_t eos_token() const = 0;
};

std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
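For context, the sketch below (not part of the diff) shows how a caller would be expected to drive the extended interface: `start()` evaluates the prompt and samples the first token, `step()` feeds the previous token back and samples the next one, `eos_token()` supplies the stop id, and `end()` closes out the request. The function name `generate_ids` and the `max_tokens` cap are illustrative, not from the PR.

```cpp
// Illustrative driver loop for TextInferenceEngine (sketch, not part of the PR).
#include <cstdint>
#include <vector>

std::vector<uint32_t> generate_ids(const TextInferenceEngine& engine,
                                   rust::Str prompt,
                                   size_t max_tokens) {
  std::vector<uint32_t> output_ids;
  const uint32_t eos = engine.eos_token();

  // start() evaluates the prompt and returns the first sampled token.
  uint32_t token = engine.start(prompt);
  while (token != eos && output_ids.size() < max_tokens) {
    output_ids.push_back(token);
    // step() feeds the previous token back and samples the next one.
    token = engine.step(token);
  }

  // end() reports llama.cpp timings for this request.
  engine.end();
  return output_ids;
}
```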
crates/llama-cpp-bindings/src/engine.cc (12 additions, 2 deletions)
@@ -37,6 +37,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {

uint32_t start(const rust::Str prompt) const override {
auto* ctx = ctx_.get();
llama_reset_timings(ctx);
std::vector<llama_token> tokens_list = tokenize(ctx, std::string(prompt), /* add_bos = */ true);
eval(tokens_list, /* reset = */ true);
return sample();
@@ -47,6 +48,14 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
return sample();
}

void end() const override {
llama_print_timings(ctx_.get());
}

uint32_t eos_token() const override {
return llama_token_eos(ctx_.get());
}

private:
uint32_t sample() const {
auto* ctx = ctx_.get();
@@ -65,7 +74,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
tokens_list.data(),
tokens_list.size(),
reset ? 0 : llama_get_kv_cache_token_count(ctx),
- /* n_threads = */ 1)) {
+ /* n_threads = */ 4)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
@@ -92,7 +101,8 @@ std::shared_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
static BackendInitializer initializer;

llama_context_params ctx_params = llama_context_default_params();
- ctx_params.n_gpu_layers = 4;
+ ctx_params.n_ctx = 2048;
+ ctx_params.n_gpu_layers = 1;

llama_model* model = llama_load_model_from_file(std::string(model_path).c_str(), ctx_params);

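For readers less familiar with llama.cpp, here is an annotated restatement of the tuned settings. The values are copied from the diff above; the comments are my reading of the llama.cpp API at the time, not text from the PR.

```cpp
// Annotated sketch of the tuning applied in engine.cc (values from this PR).
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 2048;      // context window in tokens (prompt + generated output)
ctx_params.n_gpu_layers = 1;  // was 4; with the Metal build, a non-zero value enables GPU offload

// llama_eval(...) is now called with n_threads = 4 (was 1), so evaluation work
// that runs on the CPU can use four threads.
//
// Per-request timing added by this PR:
//   llama_reset_timings(ctx);  // in start(): clear counters for the new request
//   llama_print_timings(ctx);  // in end(): print the timing report to stderr
```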
crates/llama-cpp-bindings/src/lib.rs (14 additions, 1 deletion)
@@ -19,6 +19,9 @@ mod ffi {

fn start(&self, prompt: &str) -> u32;
fn step(&self, next_token_id: u32) -> u32;
fn end(&self);

fn eos_token(&self) -> u32;
}
}

@@ -62,7 +65,13 @@ impl TextGeneration for LlamaEngine {

let output_ids = tokio::task::spawn_blocking(move || {
let engine = engine.lock().unwrap();
let eos_token = engine.eos_token();

let mut next_token_id = engine.start(&prompt);
if next_token_id == eos_token {
return Vec::new();
}

let mut n_remains = options.max_decoding_length - 1;
let mut output_ids = vec![next_token_id];

@@ -73,18 +82,22 @@ impl TextGeneration for LlamaEngine {
}

next_token_id = engine.step(next_token_id);
if next_token_id == eos_token {
break;
}

if stop_condition.next_token(next_token_id) {
break;
}
output_ids.push(next_token_id);
n_remains -= 1;
}

engine.end();
output_ids
})
.await
.expect("Inference failed");

self.tokenizer.decode(&output_ids, true).unwrap()
}
}