Skip to content

Commit

Permalink
feat: Support batch embeddings (#1186)
Browse files Browse the repository at this point in the history
* handle batched embeddings

* fix normalization issue

* fix type hints, ensure no breaking changes to embed

* Clear kv cache / reset internal state after embedding complete

---------

Co-authored-by: Andrei <abetlen@gmail.com>
  • Loading branch information
iamlemec and abetlen authored Feb 14, 2024
1 parent 36b8432 commit d7a6791
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 34 deletions.
22 changes: 22 additions & 0 deletions llama_cpp/_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,14 @@ def __del__(self):
self._llama_batch_free(self.batch)
self.batch = None

def n_tokens(self) -> int:
assert self.batch is not None
return self.batch.n_tokens

def reset(self):
assert self.batch is not None
self.batch.n_tokens = 0

def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
assert self.batch is not None
n_tokens = len(batch)
Expand All @@ -522,6 +530,20 @@ def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
self.batch.logits[i] = logits_all
self.batch.logits[n_tokens - 1] = True

def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
assert self.batch is not None
n_tokens = len(batch)
n_tokens0 = self.batch.n_tokens
self.batch.n_tokens += n_tokens
for i in range(n_tokens):
j = n_tokens0 + i
self.batch.token[j] = batch[i]
self.batch.pos[j] = i
self.batch.seq_id[j][0] = seq_id
self.batch.n_seq_id[j] = 1
self.batch.logits[j] = logits_all
self.batch.logits[n_tokens - 1] = True


class _LlamaTokenDataArray:
def __init__(self, *, n_vocab: int):
Expand Down
135 changes: 101 additions & 34 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,10 +717,53 @@ def create_embedding(
Returns:
An embedding object.
"""
assert self._ctx.ctx is not None
assert self._model.model is not None
model_name: str = model if model is not None else self.model_path

# get numeric embeddings
embeds: List[List[float]]
total_tokens: int
embeds, total_tokens = self.embed(input, return_count=True) # type: ignore

# convert to CreateEmbeddingResponse
data: List[Embedding] = [
{
"object": "embedding",
"embedding": emb,
"index": idx,
}
for idx, emb in enumerate(embeds)
]

return {
"object": "list",
"data": data,
"model": model_name,
"usage": {
"prompt_tokens": total_tokens,
"total_tokens": total_tokens,
},
}

def embed(
self,
input: Union[str, List[str]],
normalize: bool = True,
truncate: bool = True,
return_count: bool = False,
):
"""Embed a string.
Args:
input: The utf-8 encoded string to embed.
Returns:
A list of embeddings
"""
assert self._ctx.ctx is not None
n_embd = self.n_embd()
n_ctx = self.n_ctx()

if self.context_params.embedding == False:
raise RuntimeError(
"Llama model must be created with embedding=True to call this method"
Expand All @@ -734,48 +777,72 @@ def create_embedding(
else:
inputs = input

data: List[Embedding] = []
# reset batch
self._batch.reset()

# decode and fetch embeddings
data: List[List[float]] = []
def decode_batch(sizes: List[int]):
assert self._ctx.ctx is not None
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
self._ctx.decode(self._batch)
self._batch.reset()

# store embeddings
for i, s in enumerate(sizes):
embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
:n_embd
]
norm = np.linalg.norm(embedding) if normalize else s
embedding: List[float] = [v / float(norm) for v in embedding]
data.append(embedding)

# init state
total_tokens = 0
for index, input in enumerate(inputs):
tokens = self.tokenize(input.encode("utf-8"), special=True)
self.reset()
self.eval(tokens)
t_batch = 0
s_sizes: List[int] = []

# accumulate batches and encode
for text in inputs:
tokens = self.tokenize(text.encode("utf-8"))
if truncate:
tokens = tokens[:n_ctx]

n_tokens = len(tokens)
total_tokens += n_tokens
embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[
: llama_cpp.llama_n_embd(self._model.model)
]

data.append(
{
"object": "embedding",
"embedding": embedding,
"index": index,
}
)
# check for overrun
if n_tokens > n_ctx:
raise ValueError(
f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
)

# time to eval batch
if t_batch + n_tokens > self._n_ctx:
decode_batch(s_sizes)
t_batch = 0
s_sizes = []

# add to batch
self._batch.add_sequence(tokens, len(s_sizes), False)
t_batch += n_tokens
s_sizes.append(n_tokens)

# hanlde last batch
decode_batch(s_sizes)

if self.verbose:
llama_cpp.llama_print_timings(self._ctx.ctx)

return {
"object": "list",
"data": data,
"model": model_name,
"usage": {
"prompt_tokens": total_tokens,
"total_tokens": total_tokens,
},
}

def embed(self, input: str) -> List[float]:
"""Embed a string.
output = data[0] if isinstance(input, str) else data

Args:
input: The utf-8 encoded string to embed.
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
self.reset()

Returns:
A list of embeddings
"""
return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
if return_count:
return output, total_tokens
else:
return output

def _create_completion(
self,
Expand Down

0 comments on commit d7a6791

Please sign in to comment.