diff --git a/src/cappr/llama_cpp/_classify_no_cache.py b/src/cappr/llama_cpp/_classify_no_cache.py
index 98bf31c..a91fc85 100644
--- a/src/cappr/llama_cpp/_classify_no_cache.py
+++ b/src/cappr/llama_cpp/_classify_no_cache.py
@@ -12,7 +12,6 @@
 from cappr.utils import _no_cache, classify
 from cappr import Example
 from cappr.llama_cpp.classify import token_logprobs
-from cappr.llama_cpp import _utils
 
 
 def _tokenize(model: Llama, texts: Sequence[str]) -> list[list[int]]:
@@ -26,18 +25,15 @@ def log_probs_conditional(
     model: Llama,
     end_of_prompt: Literal[" ", ""] = " ",
     **kwargs,
-) -> list[list[list[float]]]:
-    end_of_prompt_for_slicing = (
-        end_of_prompt if _utils.does_tokenizer_need_prepended_space(model) else ""
-    )
+) -> list[list[float]] | list[list[list[float]]]:
     return _no_cache.log_probs_conditional(
-        token_logprobs,
-        partial(_tokenize, model),
         prompts,
         completions,
+        end_of_prompt,
+        token_logprobs,
+        partial(_tokenize, model),
+        model.token_bos(),
         model,
-        end_of_prompt=end_of_prompt,
-        end_of_prompt_for_slicing=end_of_prompt_for_slicing,
         add_bos=True,
     )
 
@@ -47,15 +43,12 @@ def log_probs_conditional_examples(
     examples: Example | Sequence[Example],
     model: Llama,
 ) -> list[list[float]] | list[list[list[float]]]:
-    should_end_of_prompt_be_empty = not _utils.does_tokenizer_need_prepended_space(
-        model
-    )
     return _no_cache.log_probs_conditional_examples(
+        examples,
         token_logprobs,
         partial(_tokenize, model),
-        examples,
+        model.token_bos(),
         model,
-        should_end_of_prompt_be_empty=should_end_of_prompt_be_empty,
         add_bos=True,
     )
 
diff --git a/src/cappr/openai/classify.py b/src/cappr/openai/classify.py
index 62cbee2..58f516f 100644
--- a/src/cappr/openai/classify.py
+++ b/src/cappr/openai/classify.py
@@ -198,13 +198,15 @@ def log_probs_conditional(
         # [[-11.6],       [[log Pr(z | a, b, c)],
         #  [-0.3, -1.2]]   [log Pr(d | a, b, c), log Pr(e | a, b, c, d)]]
     """
+    bos_token_id = None  # no OpenAI BPE tokenizer adds a BOS token
     return _no_cache.log_probs_conditional(
-        token_logprobs,
-        tiktoken.encoding_for_model(model).encode_batch,
         prompts,
         completions,
+        end_of_prompt,
+        token_logprobs,
+        tiktoken.encoding_for_model(model).encode_batch,
+        bos_token_id,
         model,
-        end_of_prompt=end_of_prompt,
         client=client,
         show_progress_bar=show_progress_bar,
         ask_if_ok=ask_if_ok,
@@ -301,12 +303,13 @@ def log_probs_conditional_examples(
         print(log_probs_completions[1])  # corresponds to examples[1]
         # [[-11.2, -4.7]]  [[log Pr(1 | a, b, c)], log Pr(2 | a, b, c, 1)]]
     """
+    bos_token_id = None  # no OpenAI BPE tokenizer adds a BOS token
     return _no_cache.log_probs_conditional_examples(
+        examples,
         token_logprobs,
         tiktoken.encoding_for_model(model).encode_batch,
-        examples,
+        bos_token_id,
         model,
-        should_end_of_prompt_be_empty=False,
         client=client,
         show_progress_bar=show_progress_bar,
         ask_if_ok=ask_if_ok,
diff --git a/src/cappr/utils/_no_cache.py b/src/cappr/utils/_no_cache.py
index 6c5f5db..943f67b 100644
--- a/src/cappr/utils/_no_cache.py
+++ b/src/cappr/utils/_no_cache.py
@@ -2,9 +2,20 @@
 Utilities for implementations which don't cache.
 """
 from __future__ import annotations
-from typing import Callable, cast, Literal, Sequence
+from functools import lru_cache
+from typing import Any, Callable, cast, Literal, Sequence
 
-from cappr.utils import _batch
+from cappr.utils import _batch, _check
+
+
+@lru_cache()
+def _does_tokenizer_need_prepended_space(
+    tokenize: Callable[[Sequence[str]], list[list[int]]], bos_token_id: int | None
+):
+    tokenize_single_text = lambda text: tokenize([text])[0]
+    return _check.does_tokenizer_need_prepended_space(
+        tokenize_single_text, bos_token_id
+    )
 
 
 def _slice_completions(
@@ -12,17 +23,18 @@ def _slice_completions(
     end_of_prompt: str,
     log_probs: Sequence[Sequence[float]],
     tokenize: Callable[[Sequence[str]], list[list[int]]],
+    bos_token_id: int | None,
 ) -> list[list[float]]:
     """
-    Returns a list `log_probs_completions` where `log_probs_completions[i]` is a list of
-    conditional log-probablities for each token in `end_of_prompt + completions[i]`,
-    extracted by slicing `log_probs[i]`.
+    Slice the completion's tokens from each list of log-probabilities in `log_probs`.
     """
     if len(completions) != len(log_probs):
         raise ValueError(
-            "Different number of completions and log_probs: "
-            f"{len(completions)}, {len(log_probs)}."
+            f"Different numbers of completions and log_probs: {len(completions)}, "
+            f"{len(log_probs)}, likely due to an issue with the token_logprobs function"
         )  # pragma: no cover
+    if not _does_tokenizer_need_prepended_space(tokenize, bos_token_id):
+        end_of_prompt = ""
     completions = [end_of_prompt + completion for completion in completions]
     completion_lengths = [len(tokens) for tokens in tokenize(completions)]
     return [
@@ -32,15 +44,15 @@
 
 
 def log_probs_conditional(
-    token_logprobs: Callable[[Sequence[str]], list[list[float]]],
-    tokenize: Callable[[list[str]], list[list[int]]],
     prompts: str | Sequence[str],
     completions: Sequence[str],
+    end_of_prompt: Literal[" ", ""],
+    token_logprobs: Callable[[Sequence[str], Any], list[list[float]]],
+    tokenize: Callable[[list[str]], list[list[int]]],
+    bos_token_id: int | None,
     *token_logprobs_args,
-    end_of_prompt: Literal[" ", ""] = " ",
-    end_of_prompt_for_slicing: Literal[" ", ""] = " ",
     **token_logprobs_kwargs,
-):
+) -> list[list[list[float]]]:
     texts = [
         prompt + end_of_prompt + completion
         for prompt in prompts
@@ -52,27 +64,27 @@
         end_of_prompt="",
         **token_logprobs_kwargs,
     )
-    # Since log_probs is a flat list, we'll need to batch them by the size and order of
-    # completions to fulfill the spec
+    # log_probs is a 2-D list. Batch it by the size and order of completions to fulfill
+    # the spec
     return [
         _slice_completions(
-            completions, end_of_prompt_for_slicing, log_probs_batch, tokenize
+            completions, end_of_prompt, log_probs_batch, tokenize, bos_token_id
         )
         for log_probs_batch in _batch.constant(log_probs, size=len(completions))
     ]
 
 
 def log_probs_conditional_examples(
-    token_logprobs: Callable[[Sequence[str]], list[list[float]]],
-    tokenize: Callable[[Sequence[str]], list[list[int]]],
     examples,
+    token_logprobs: Callable[[Sequence[str], Any], list[list[float]]],
+    tokenize: Callable[[Sequence[str]], list[list[int]]],
+    bos_token_id: int | None,
     *token_logprobs_args,
-    should_end_of_prompt_be_empty: bool,
     **token_logprobs_kwargs,
-):
+) -> list[list[list[float]]]:
     from cappr import Example
 
-    # examples is always a Sequence[Example] b/c of the decorator.
+    # examples is always a Sequence[Example] b/c of the decorator
    examples = cast(Sequence[Example], examples)
 
     texts = [
@@ -80,21 +92,28 @@
         for example in examples
         for completion in example.completions
     ]
-    log_probs_all = token_logprobs(
+    log_probs = token_logprobs(
         texts, *token_logprobs_args, end_of_prompt="", **token_logprobs_kwargs
     )
-    # Slice out completion tokens
-    num_completions_per_prompt = []
-    completions_all = []
-    for example in examples:
-        num_completions_per_prompt.append(len(example.completions))
-        end_of_prompt = "" if should_end_of_prompt_be_empty else example.end_of_prompt
-        for completion in example.completions:
-            completions_all.append(end_of_prompt + completion)
-    log_probs_completions_all = _slice_completions(
-        completions_all, end_of_prompt="", log_probs=log_probs_all, tokenize=tokenize
+    should_end_of_prompt_be_empty = not _does_tokenizer_need_prepended_space(
+        tokenize, bos_token_id
+    )
+    end_of_prompts = [
+        "" if should_end_of_prompt_be_empty else example.end_of_prompt
+        for example in examples
+    ]
+    completions = [
+        end_of_prompt + completion
+        for end_of_prompt, example in zip(end_of_prompts, examples)
+        for completion in example.completions
+    ]
+    end_of_prompt = ""  # we already added it in completions
+    log_probs_completions = _slice_completions(
+        completions, end_of_prompt, log_probs, tokenize, bos_token_id
     )
-    # Batch by completions to fulfill the spec
+    # log_probs is a 2-D list. Batch it by the size and order of completions to fulfill
+    # the spec
+    num_completions_per_prompt = [len(example.completions) for example in examples]
     return list(
-        _batch.variable(log_probs_completions_all, sizes=num_completions_per_prompt)
+        _batch.variable(log_probs_completions, sizes=num_completions_per_prompt)
     )
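A minimal, self-contained sketch of the slicing rule that the patched _slice_completions centralizes: when the tokenizer does not need a prepended space, end_of_prompt is dropped before counting the completion's tokens, and that many trailing log-probabilities are sliced off. The whitespace tokenizer and the uniform fake log-probabilities are hypothetical stand-ins, not cappr code.

from typing import Callable, Sequence


def slice_completions(
    completions: Sequence[str],
    end_of_prompt: str,
    log_probs: Sequence[Sequence[float]],
    tokenize: Callable[[Sequence[str]], list[list[str]]],
    tokenizer_needs_prepended_space: bool,
) -> list[list[float]]:
    # Same idea as the patched _slice_completions: only keep end_of_prompt when
    # the tokenizer needs it to reproduce the completion's in-context tokens
    if not tokenizer_needs_prepended_space:
        end_of_prompt = ""
    completions = [end_of_prompt + completion for completion in completions]
    completion_lengths = [len(tokens) for tokens in tokenize(completions)]
    return [
        list(log_probs_text[-length:]) if length else []
        for log_probs_text, length in zip(log_probs, completion_lengths)
    ]


def whitespace_tokenize(texts: Sequence[str]) -> list[list[str]]:
    # Toy tokenizer: splitting on whitespace never needs a prepended " "
    return [text.split() for text in texts]


texts = ["the cat sat on the mat", "the cat sat on a soft chair"]
completions = ["the mat", "a soft chair"]
# One fake log-probability per token of each full prompt + completion text
log_probs = [[-1.0] * len(tokens) for tokens in whitespace_tokenize(texts)]
sliced = slice_completions(
    completions,
    end_of_prompt=" ",
    log_probs=log_probs,
    tokenize=whitespace_tokenize,
    tokenizer_needs_prepended_space=False,
)
print(sliced)  # [[-1.0, -1.0], [-1.0, -1.0, -1.0]]: 2 tokens for "the mat", 3 for "a soft chair"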