Merge pull request #245 from danbev/embeddings-plus-updated-llama.cpp
llama: add Embeddings for llama
williamhogman authored Dec 17, 2023
2 parents cb03333 + 0e29158 commit e6e02fb
Showing 11 changed files with 365 additions and 23 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions crates/llm-chain-llama/Cargo.toml
@@ -22,6 +22,7 @@ serde = { version = "1.0.163", features = ["derive"] }
thiserror.workspace = true
lazy_static = "1.4.0"
tokio.workspace = true
futures = "0.3.29"

[dev-dependencies]
tokio = { version = "1.28.2", features = ["macros", "rt"] }
26 changes: 26 additions & 0 deletions crates/llm-chain-llama/examples/simple_embeddings.rs
@@ -0,0 +1,26 @@
use llm_chain::options;
use llm_chain::traits::Embeddings;

/// This example demonstrates using llm-chain-llama for generating
/// embeddings.
///
/// Usage:
/// env LLM_CHAIN_MODEL=<path_to_model> cargo run --example simple_embeddings
///
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let opts = options!(
NThreads: 4_usize,
MaxTokens: 2048_usize
);
let embeddings = llm_chain_llama::embeddings::Embeddings::new_with_options(opts)?;
let embedded_vecs = embeddings
.embed_texts(vec![
"This is an amazing way of writing LLM-powered applications".to_string(),
])
.await
.unwrap();
println!("Embedded text: {:?}", embedded_vecs[0]);

Ok(())
}
45 changes: 31 additions & 14 deletions crates/llm-chain-llama/src/batch.rs
@@ -5,7 +5,7 @@ use std::ptr::null_mut;
#[allow(dead_code)]
pub struct LlamaBatch {
n_tokens: i32,
token: Vec<i32>,
tokens: Vec<i32>,
embd: Vec<f32>,
pos: Vec<i32>,
n_seq_id: Vec<i32>,
@@ -29,7 +29,7 @@ impl LlamaBatch {

Self {
n_tokens: tokens.len() as i32,
token: tokens,
tokens,
embd,
pos,
n_seq_id,
@@ -44,7 +44,7 @@
pub fn new_with_token(token: i32, pos: i32) -> Self {
Self {
n_tokens: 1,
token: vec![token],
tokens: vec![token],
embd: vec![],
pos: vec![pos],
n_seq_id: vec![1],
@@ -75,22 +75,39 @@ impl Drop for LlamaBatch {

fn convert_llama_batch(batch: &LlamaBatch) -> llama_batch {
let n_tokens = batch.n_tokens;
let token_ptr = Box::leak(batch.token.clone().into_boxed_slice()).as_mut_ptr();
let token_ptr = Box::leak(batch.tokens.clone().into_boxed_slice()).as_mut_ptr();
let embd_ptr = if batch.embd.is_empty() {
null_mut()
} else {
Box::leak(batch.embd.clone().into_boxed_slice()).as_mut_ptr()
};
let pos_ptr = Box::leak(batch.pos.clone().into_boxed_slice()).as_mut_ptr();
let n_seq_id_ptr = Box::leak(batch.n_seq_id.clone().into_boxed_slice()).as_mut_ptr();
let raw_pointers = batch
.seq_id
.clone()
.into_iter()
.map(|inner_vec| Box::leak(inner_vec.into_boxed_slice()).as_mut_ptr())
.collect::<Vec<*mut llama_seq_id>>();
let seq_id_ptr = Box::leak(raw_pointers.into_boxed_slice()).as_mut_ptr();
let logits_ptr = Box::leak(batch.logits.clone().into_boxed_slice()).as_mut_ptr();
let pos_ptr = if batch.pos.is_empty() {
null_mut()
} else {
Box::leak(batch.pos.clone().into_boxed_slice()).as_mut_ptr()
};
let n_seq_id_ptr = if batch.n_seq_id.is_empty() {
null_mut()
} else {
Box::leak(batch.n_seq_id.clone().into_boxed_slice()).as_mut_ptr()
};

let seq_id_ptr = if batch.seq_id.is_empty() {
null_mut()
} else {
let raw_pointers = batch
.seq_id
.clone()
.into_iter()
.map(|inner_vec| Box::leak(inner_vec.into_boxed_slice()).as_mut_ptr())
.collect::<Vec<*mut llama_seq_id>>();
Box::leak(raw_pointers.into_boxed_slice()).as_mut_ptr()
};
let logits_ptr = if batch.logits.is_empty() {
null_mut()
} else {
Box::leak(batch.logits.clone().into_boxed_slice()).as_mut_ptr()
};
llama_batch {
n_tokens,
token: token_ptr,
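
Note: the change above extends the existing empty-embd null check to pos, n_seq_id, seq_id, and logits, mapping empty vectors to null pointers instead of leaking zero-length allocations. A minimal standalone sketch of the same idiom follows; the leak_or_null helper is hypothetical and not part of this commit.

use std::ptr::null_mut;

// Hypothetical helper illustrating the idiom used in convert_llama_batch:
// an empty slice becomes a null pointer, while a non-empty slice is cloned,
// boxed, and intentionally leaked so the pointer stays valid for the C side
// of the FFI call.
fn leak_or_null<T: Clone>(v: &[T]) -> *mut T {
    if v.is_empty() {
        null_mut()
    } else {
        Box::leak(v.to_vec().into_boxed_slice()).as_mut_ptr()
    }
}
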
27 changes: 20 additions & 7 deletions crates/llm-chain-llama/src/context.rs
@@ -6,13 +6,13 @@ use crate::options::LlamaInvocation;
use anyhow::Result;
use llm_chain_llama_sys::{
llama_context, llama_context_default_params, llama_context_params, llama_decode, llama_eval,
llama_free, llama_get_logits, llama_get_logits_ith, llama_load_model_from_file, llama_model,
llama_n_vocab, llama_new_context_with_model, llama_sample_repetition_penalties,
llama_sample_tail_free, llama_sample_temperature, llama_sample_token,
llama_sample_token_greedy, llama_sample_token_mirostat, llama_sample_token_mirostat_v2,
llama_sample_top_k, llama_sample_top_p, llama_sample_typical, llama_token_data,
llama_token_data_array, llama_token_eos, llama_token_get_text, llama_token_nl,
llama_token_to_piece,
llama_free, llama_get_embeddings, llama_get_logits, llama_get_logits_ith, llama_kv_cache_clear,
llama_load_model_from_file, llama_model, llama_n_embd, llama_n_vocab,
llama_new_context_with_model, llama_sample_repetition_penalties, llama_sample_tail_free,
llama_sample_temperature, llama_sample_token, llama_sample_token_greedy,
llama_sample_token_mirostat, llama_sample_token_mirostat_v2, llama_sample_top_k,
llama_sample_top_p, llama_sample_typical, llama_token_data, llama_token_data_array,
llama_token_eos, llama_token_get_text, llama_token_nl, llama_token_to_piece,
};

pub use batch::LlamaBatch;
@@ -161,6 +161,15 @@ impl LLamaContext {
Vec::from(unsafe { std::slice::from_raw_parts(float_ptr, self.llama_n_vocab() as usize) })
}

pub fn llama_get_embeddings(&self) -> Vec<f32> {
unsafe {
let len = llama_n_embd(self.model);
let ptr = llama_get_embeddings(self.ctx);
let slice = std::slice::from_raw_parts_mut(ptr, len as usize);
slice.to_vec()
}
}

// Executes the LLama sampling process with the specified configuration.
pub fn llama_sample(
&self,
@@ -301,6 +310,10 @@ impl LLamaContext {
unsafe { llama_token_nl(self.model) }
}

pub fn llama_kv_cache_clear(&self) {
unsafe { llama_kv_cache_clear(self.ctx) }
}

pub fn llama_token_to_piece(
&self,
token_id: i32,
159 changes: 159 additions & 0 deletions crates/llm-chain-llama/src/embeddings.rs
@@ -0,0 +1,159 @@
use crate::batch::LlamaBatch;
use crate::context::ContextParams;
use crate::context::LLamaContext;
use crate::model::ModelParams;
use crate::options::{LlamaInvocation, DEFAULT_OPTIONS};
use crate::tokenizer;
use async_trait::async_trait;
use futures::future::try_join_all;
use llm_chain::options::{options_from_env, Opt, OptDiscriminants, Options, OptionsCascade};
use llm_chain::prompt::Data;
use llm_chain::traits::EmbeddingsCreationError;
use llm_chain::traits::{self, EmbeddingsError};
use std::sync::Arc;
use std::{error::Error, fmt::Debug};
use tokio::sync::Mutex;

/// Generate embeddings using the llama model.
///
/// This is intended to be similar to running the embedding example in llama.cpp:
/// ./embedding -m <path_to_model> --log-disable -p "Hello world" 2>/dev/null
///
pub struct Embeddings {
context: Arc<Mutex<LLamaContext>>,
options: Options,
}

#[async_trait]
impl traits::Embeddings for Embeddings {
type Error = LlamaEmbeddingsError;

async fn embed_texts(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, Self::Error> {
let futures = texts.into_iter().map(|text| self.embed_query(text));
let embeddings = try_join_all(futures).await?;
Ok(embeddings)
}

async fn embed_query(&self, query: String) -> Result<Vec<f32>, Self::Error> {
let options = vec![&DEFAULT_OPTIONS, &self.options];
let invocation =
LlamaInvocation::new(OptionsCascade::from_vec(options), &Data::Text(query)).unwrap();
let embeddings = self.generate_embeddings(invocation).await?;
Ok(embeddings)
}
}

#[allow(dead_code)]
impl Embeddings {
pub fn new_with_options(opt: Options) -> Result<Self, EmbeddingsCreationError> {
//TODO(danbev) This is pretty much a duplication of the code in
// llm_chain::executor::Executor::new_with_options. Find a good place
// to share this code.
let opts_from_env =
options_from_env().map_err(|err| EmbeddingsCreationError::InnerError(err.into()))?;
let cas = OptionsCascade::new()
.with_options(&DEFAULT_OPTIONS)
.with_options(&opts_from_env)
.with_options(&opt);

let Some(Opt::Model(model)) = cas.get(OptDiscriminants::Model) else {
return Err(EmbeddingsCreationError::FieldRequiredError(
"model_path".to_string(),
));
};

let mut mp = ModelParams::new();
if let Some(Opt::NGpuLayers(value)) = cas.get(OptDiscriminants::NGpuLayers) {
mp.n_gpu_layers = *value;
}
if let Some(Opt::MainGpu(value)) = cas.get(OptDiscriminants::MainGpu) {
mp.main_gpu = *value;
}
if let Some(Opt::TensorSplit(values)) = cas.get(OptDiscriminants::TensorSplit) {
mp.tensor_split = values.clone();
}
// Currently, setting vocab_only is not allowed as it will cause
// a crash when using the llama executor, which needs to have weights loaded
// in order to work.
mp.vocab_only = false;

if let Some(Opt::UseMmap(value)) = cas.get(OptDiscriminants::UseMmap) {
mp.use_mmap = *value;
}
if let Some(Opt::UseMlock(value)) = cas.get(OptDiscriminants::UseMlock) {
mp.use_mlock = *value;
}

let mut cp = ContextParams::new();
if let Some(Opt::NThreads(value)) = cas.get(OptDiscriminants::NThreads) {
cp.n_threads = *value as u32;
}

if let Some(Opt::MaxContextSize(value)) = cas.get(OptDiscriminants::MaxContextSize) {
cp.n_ctx = *value as u32;
}

if let Some(Opt::MaxBatchSize(value)) = cas.get(OptDiscriminants::MaxBatchSize) {
cp.n_batch = *value as u32;
}
cp.embedding = true;

Ok(Self {
context: Arc::new(Mutex::new(LLamaContext::from_file_and_params(
&model.to_path(),
Some(&mp),
Some(&cp),
)?)),
options: opt,
})
}

fn get_model_path(options: &Options) -> Result<String, EmbeddingsCreationError> {
let opts_from_env =
options_from_env().map_err(|err| EmbeddingsCreationError::InnerError(err.into()))?;
let cas = OptionsCascade::new()
.with_options(&DEFAULT_OPTIONS)
.with_options(&opts_from_env)
.with_options(&options);
let model_path = cas
.get(OptDiscriminants::Model)
.and_then(|x| match x {
Opt::Model(m) => Some(m),
_ => None,
})
.ok_or(EmbeddingsCreationError::FieldRequiredError(
"model_path".to_string(),
))?;
Ok(model_path.to_path())
}

async fn generate_embeddings(
&self,
input: LlamaInvocation,
) -> Result<Vec<f32>, LlamaEmbeddingsError> {
let context = self.context.clone();
let embeddings = tokio::task::spawn_blocking(move || {
let context = context.blocking_lock();
let prompt_text = input.prompt.to_text();
let tokens = tokenizer::tokenize(&context, prompt_text.as_str(), true, false);
//TODO(danbev) Handle the case where the number of tokens
// is larger than the n_batch size.
let batch = LlamaBatch::new_with_tokens(tokens.clone(), 1);
let _ = context
.llama_decode(&batch)
.map_err(|e| LlamaEmbeddingsError::InnerError(e.into()));
context.llama_get_embeddings()
});
embeddings
.await
.map_err(|e| LlamaEmbeddingsError::InnerError(e.into()))
}
}

#[derive(thiserror::Error, Debug)]
pub enum LlamaEmbeddingsError {
#[error("error when trying to generate embeddings: {0}")]
InnerError(#[from] Box<dyn Error + Send + Sync>),
}

impl EmbeddingsError for LlamaEmbeddingsError {}
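
Note: as a rough usage sketch (not part of this commit), the embed_query method added above can be called through the llm_chain::traits::Embeddings trait in much the same way as the simple_embeddings.rs example earlier in the diff, for instance to compare two texts. The cosine_similarity helper and the example texts below are illustrative assumptions; only new_with_options and embed_query come from this file.

use llm_chain::options;
use llm_chain::traits::Embeddings;

// Hypothetical helper (not part of this commit) for comparing two
// embedding vectors by cosine similarity.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let opts = options!(NThreads: 4_usize, MaxTokens: 2048_usize);
    let embeddings = llm_chain_llama::embeddings::Embeddings::new_with_options(opts)?;
    // Embed two texts separately and compare the resulting vectors.
    let a = embeddings
        .embed_query("This is an amazing way of writing LLM-powered applications".to_string())
        .await
        .unwrap();
    let b = embeddings
        .embed_query("Writing LLM-powered applications this way is amazing".to_string())
        .await
        .unwrap();
    println!("cosine similarity: {}", cosine_similarity(&a, &b));
    Ok(())
}
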
15 changes: 13 additions & 2 deletions crates/llm-chain-llama/src/executor.rs
@@ -38,8 +38,8 @@ macro_rules! must_send {

/// Executor is responsible for running the LLAMA model and managing its context.
pub struct Executor {
context: Arc<Mutex<LLamaContext>>,
options: Options,
pub(crate) context: Arc<Mutex<LLamaContext>>,
pub(crate) options: Options,
context_params: ContextParams,
}

@@ -62,6 +62,17 @@ impl Executor {
let context_size = context_size;
let context = context.blocking_lock();

// The following clears the key-value cache so that conversational
// (chat) applications can call run_model multiple times using the
// same context. Without this, and because the same sequence id is
// used below, the cache can contain tokens from a previous
// interaction, which may cause the model to generate a response
// that is not appropriate for the current prompt.
//
// TODO(danbev) Is there a better way to do this, perhaps by using
// sequence ids in some way?
context.llama_kv_cache_clear();

let tokenized_stop_prompt = tokenize(
&context,
input
1 change: 1 addition & 0 deletions crates/llm-chain-llama/src/lib.rs
@@ -23,6 +23,7 @@
mod batch;
mod context;
pub mod embeddings;
mod executor;
mod model;
mod options;
1 change: 1 addition & 0 deletions crates/llm-chain-qdrant/Cargo.toml
@@ -24,5 +24,6 @@ uuid = "1.3.3"

[dev-dependencies]
llm-chain-openai = { path = "../llm-chain-openai" }
llm-chain-llama = { path = "../llm-chain-llama" }
tokio.workspace = true
serde_yaml.workspace = true