Have no cache util handle BPE vs SP
kddubey committed Nov 8, 2023
1 parent 7f52843 commit b6269e1
Showing 3 changed files with 67 additions and 52 deletions.
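Background (not part of the diff): whether the end-of-prompt space should be kept when slicing out completion tokens depends on the tokenizer. BPE tokenizers (e.g., tiktoken's for OpenAI models) encode "a" followed by " b" the same way they encode "a b", so the space belongs to the completion's tokens; SentencePiece tokenizers (e.g., Llama's) insert their own prefix space, so prepending one would change the tokenization. Below is a minimal sketch of the kind of check this commit centralizes in cappr.utils._no_cache via cappr.utils._check.does_tokenizer_need_prepended_space; the real implementation may differ.

# Illustrative sketch, not part of the commit. The argument names mirror the
# diff below; the logic is an assumption about how the check works.
from typing import Callable, Optional


def does_tokenizer_need_prepended_space_sketch(
    tokenize: Callable[[str], list[int]], bos_token_id: Optional[int]
) -> bool:
    def content_tokens(text: str) -> list[int]:
        tokens = tokenize(text)
        # Drop the BOS token, if the tokenizer adds one, so that only content
        # tokens are compared
        if bos_token_id is not None and tokens and tokens[0] == bos_token_id:
            return tokens[1:]
        return tokens

    # BPE: tokenizing "a" and " b" separately reproduces the tokens of "a b",
    # so the prepended space must be kept when slicing completions.
    # SentencePiece: it doesn't, so the space should be dropped.
    return content_tokens("a b") == content_tokens("a") + content_tokens(" b")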
21 changes: 7 additions & 14 deletions src/cappr/llama_cpp/_classify_no_cache.py
@@ -12,7 +12,6 @@
from cappr.utils import _no_cache, classify
from cappr import Example
from cappr.llama_cpp.classify import token_logprobs
from cappr.llama_cpp import _utils


def _tokenize(model: Llama, texts: Sequence[str]) -> list[list[int]]:
@@ -26,18 +25,15 @@ def log_probs_conditional(
model: Llama,
end_of_prompt: Literal[" ", ""] = " ",
**kwargs,
) -> list[list[list[float]]]:
end_of_prompt_for_slicing = (
end_of_prompt if _utils.does_tokenizer_need_prepended_space(model) else ""
)
) -> list[list[float]] | list[list[list[float]]]:
return _no_cache.log_probs_conditional(
token_logprobs,
partial(_tokenize, model),
prompts,
completions,
end_of_prompt,
token_logprobs,
partial(_tokenize, model),
model.token_bos(),
model,
end_of_prompt=end_of_prompt,
end_of_prompt_for_slicing=end_of_prompt_for_slicing,
add_bos=True,
)

@@ -47,15 +43,12 @@ def log_probs_conditional_examples(
examples: Example | Sequence[Example],
model: Llama,
) -> list[list[float]] | list[list[list[float]]]:
should_end_of_prompt_be_empty = not _utils.does_tokenizer_need_prepended_space(
model
)
return _no_cache.log_probs_conditional_examples(
examples,
token_logprobs,
partial(_tokenize, model),
examples,
model.token_bos(),
model,
should_end_of_prompt_be_empty=should_end_of_prompt_be_empty,
add_bos=True,
)

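For orientation (not part of the commit), a hypothetical call to the public wrapper that the internal functions above back, cappr.llama_cpp.classify.log_probs_conditional. The model path is a placeholder, and logits_all=True is assumed to be required for token-level log-probabilities:

# Hypothetical usage; the model path is a placeholder.
from llama_cpp import Llama

from cappr.llama_cpp.classify import log_probs_conditional

model = Llama("/path/to/model.gguf", logits_all=True, verbose=False)
log_probs_completions = log_probs_conditional(
    prompts=["x y", "p q"],
    completions=["z", "r s"],
    model=model,
)
# log_probs_completions[i][j] is the list of log-probabilities for the tokens
# of completions[j] conditional on prompts[i]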
13 changes: 8 additions & 5 deletions src/cappr/openai/classify.py
@@ -198,13 +198,15 @@ def log_probs_conditional(
# [[-11.6], [[log Pr(z | a, b, c)],
# [-0.3, -1.2]] [log Pr(d | a, b, c), log Pr(e | a, b, c, d)]]
"""
bos_token_id = None # no OpenAI BPE tokenizer adds a BOS token
return _no_cache.log_probs_conditional(
token_logprobs,
tiktoken.encoding_for_model(model).encode_batch,
prompts,
completions,
end_of_prompt,
token_logprobs,
tiktoken.encoding_for_model(model).encode_batch,
bos_token_id,
model,
end_of_prompt=end_of_prompt,
client=client,
show_progress_bar=show_progress_bar,
ask_if_ok=ask_if_ok,
@@ -301,12 +303,13 @@ def log_probs_conditional_examples(
print(log_probs_completions[1]) # corresponds to examples[1]
# [[-11.2, -4.7]] [[log Pr(1 | a, b, c), log Pr(2 | a, b, c, 1)]]
"""
bos_token_id = None # no OpenAI BPE tokenizer adds a BOS token
return _no_cache.log_probs_conditional_examples(
examples,
token_logprobs,
tiktoken.encoding_for_model(model).encode_batch,
examples,
bos_token_id,
model,
should_end_of_prompt_be_empty=False,
client=client,
show_progress_bar=show_progress_bar,
ask_if_ok=ask_if_ok,
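A quick, illustrative check (not part of the commit) of why bos_token_id is None and why the end-of-prompt space is kept for OpenAI models; the model name is only an example:

# tiktoken's BPE encodings add no BOS token, and splitting off the leading
# space reproduces the same tokens, so the completion can be sliced out with
# the space included.
import tiktoken

enc = tiktoken.encoding_for_model("gpt-3.5-turbo-instruct")
tokens_together = enc.encode("a b")
tokens_separate = enc.encode("a") + enc.encode(" b")
print(tokens_together == tokens_separate)  # expected: True for a BPE tokenizer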
85 changes: 52 additions & 33 deletions src/cappr/utils/_no_cache.py
@@ -2,27 +2,39 @@
Utilities for implementations which don't cache.
"""
from __future__ import annotations
from typing import Callable, cast, Literal, Sequence
from functools import lru_cache
from typing import Any, Callable, cast, Literal, Sequence

from cappr.utils import _batch
from cappr.utils import _batch, _check


@lru_cache()
def _does_tokenizer_need_prepended_space(
tokenize: Callable[[Sequence[str]], list[list[int]]], bos_token_id: int | None
):
tokenize_single_text = lambda text: tokenize([text])[0]
return _check.does_tokenizer_need_prepended_space(
tokenize_single_text, bos_token_id
)


def _slice_completions(
completions: Sequence[str],
end_of_prompt: str,
log_probs: Sequence[Sequence[float]],
tokenize: Callable[[Sequence[str]], list[list[int]]],
bos_token_id: int | None,
) -> list[list[float]]:
"""
Returns a list `log_probs_completions` where `log_probs_completions[i]` is a list of
conditional log-probabilities for each token in `end_of_prompt + completions[i]`,
extracted by slicing `log_probs[i]`.
Slice the completion's tokens from each list of log-probabilities in `log_probs`.
"""
if len(completions) != len(log_probs):
raise ValueError(
"Different number of completions and log_probs: "
f"{len(completions)}, {len(log_probs)}."
f"Different numbers of completions and log_probs: {len(completions)}, "
f"{len(log_probs)}, likely due to an issue with the token_logprobs function"
) # pragma: no cover
if not _does_tokenizer_need_prepended_space(tokenize, bos_token_id):
end_of_prompt = ""
completions = [end_of_prompt + completion for completion in completions]
completion_lengths = [len(tokens) for tokens in tokenize(completions)]
return [
@@ -32,15 +44,15 @@ def _slice_completions(


def log_probs_conditional(
token_logprobs: Callable[[Sequence[str]], list[list[float]]],
tokenize: Callable[[list[str]], list[list[int]]],
prompts: str | Sequence[str],
completions: Sequence[str],
end_of_prompt: Literal[" ", ""],
token_logprobs: Callable[[Sequence[str], Any], list[list[float]]],
tokenize: Callable[[list[str]], list[list[int]]],
bos_token_id: int | None,
*token_logprobs_args,
end_of_prompt: Literal[" ", ""] = " ",
end_of_prompt_for_slicing: Literal[" ", ""] = " ",
**token_logprobs_kwargs,
):
) -> list[list[list[float]]]:
texts = [
prompt + end_of_prompt + completion
for prompt in prompts
@@ -52,49 +64,56 @@
end_of_prompt="",
**token_logprobs_kwargs,
)
# Since log_probs is a flat list, we'll need to batch them by the size and order of
# completions to fulfill the spec
# log_probs is a 2-D list. Batch it by the size and order of completions to fulfill
# the spec
return [
_slice_completions(
completions, end_of_prompt_for_slicing, log_probs_batch, tokenize
completions, end_of_prompt, log_probs_batch, tokenize, bos_token_id
)
for log_probs_batch in _batch.constant(log_probs, size=len(completions))
]


def log_probs_conditional_examples(
token_logprobs: Callable[[Sequence[str]], list[list[float]]],
tokenize: Callable[[Sequence[str]], list[list[int]]],
examples,
token_logprobs: Callable[[Sequence[str], Any], list[list[float]]],
tokenize: Callable[[Sequence[str]], list[list[int]]],
bos_token_id: int | None,
*token_logprobs_args,
should_end_of_prompt_be_empty: bool,
**token_logprobs_kwargs,
):
) -> list[list[list[float]]]:
from cappr import Example

# examples is always a Sequence[Example] b/c of the decorator.
# examples is always a Sequence[Example] b/c of the decorator
examples = cast(Sequence[Example], examples)

texts = [
example.prompt + example.end_of_prompt + completion
for example in examples
for completion in example.completions
]
log_probs_all = token_logprobs(
log_probs = token_logprobs(
texts, *token_logprobs_args, end_of_prompt="", **token_logprobs_kwargs
)
# Slice out completion tokens
num_completions_per_prompt = []
completions_all = []
for example in examples:
num_completions_per_prompt.append(len(example.completions))
end_of_prompt = "" if should_end_of_prompt_be_empty else example.end_of_prompt
for completion in example.completions:
completions_all.append(end_of_prompt + completion)
log_probs_completions_all = _slice_completions(
completions_all, end_of_prompt="", log_probs=log_probs_all, tokenize=tokenize
should_end_of_prompt_be_empty = not _does_tokenizer_need_prepended_space(
tokenize, bos_token_id
)
end_of_prompts = [
"" if should_end_of_prompt_be_empty else example.end_of_prompt
for example in examples
]
completions = [
end_of_prompt + completion
for end_of_prompt, example in zip(end_of_prompts, examples)
for completion in example.completions
]
end_of_prompt = "" # we already added it in completions
log_probs_completions = _slice_completions(
completions, end_of_prompt, log_probs, tokenize, bos_token_id
)
# Batch by completions to fulfill the spec
# log_probs is a 2-D list. Batch it by the size and order of completions to fulfill
# the spec
num_completions_per_prompt = [len(example.completions) for example in examples]
return list(
_batch.variable(log_probs_completions_all, sizes=num_completions_per_prompt)
_batch.variable(log_probs_completions, sizes=num_completions_per_prompt)
)
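
A small worked example (made-up numbers, not from the diff) of the slicing and batching that _slice_completions and _batch.constant perform in log_probs_conditional, assuming each word is one token:

# Worked example (made-up numbers) of how completion log-probs are sliced and
# batched. Assume 2 prompts, 2 completions, and one token per word.
prompts = ["x y", "p q"]
completions = ["z", "r s"]
texts = [prompt + " " + completion for prompt in prompts for completion in completions]
# texts == ["x y z", "x y r s", "p q z", "p q r s"]

# token_logprobs would return one list of token log-probabilities per text:
log_probs = [
    [-1.0, -2.0, -3.0],        # "x y z"
    [-1.0, -2.0, -4.0, -5.0],  # "x y r s"
    [-1.5, -2.5, -3.5],        # "p q z"
    [-1.5, -2.5, -4.5, -5.5],  # "p q r s"
]

# The completions " z" and " r s" tokenize to 1 and 2 tokens, so the last 1
# and 2 log-probs are sliced out of each text's list, batched per prompt:
completion_lengths = [1, 2]
log_probs_completions = [
    [chunk[i][-length:] for i, length in enumerate(completion_lengths)]
    for chunk in (log_probs[:2], log_probs[2:])
]
print(log_probs_completions)
# [[[-3.0], [-4.0, -5.0]], [[-3.5], [-4.5, -5.5]]]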
