Commit

Little fixes
kddubey committed Nov 7, 2023
1 parent 1c684b0 commit 26d377c
Showing 9 changed files with 58 additions and 51 deletions.
21 changes: 0 additions & 21 deletions src/cappr/_example.py
@@ -32,27 +32,6 @@ class Example:
probability distribution over completions. Set this to `False` if you'd like the
raw completion-after-prompt probability, or if you're solving a multi-label
prediction problem. By default, True
Raises
------
TypeError
if `prompt` is not a string
ValueError
if `prompt` is empty
TypeError
if `completions` is not a sequence
ValueError
if `completions` is empty, or contains an empty string
TypeError
if `end_of_prompt` is not a string
ValueError
if `end_of_prompt` is not a whitespace or empty
TypeError
if `prior` is not None and isn't a sequence or numpy array
ValueError
if `prior` is not a 1-D probability distribution over `completions`
ValueError
if `normalize` is True but there's only one completion in `completions`
"""

prompt: str
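For readers unfamiliar with the class, here is a minimal sketch of constructing an `Example` as the docstring above describes it. The keyword names follow the attributes shown in this diff; the values are illustrative only.

from cappr import Example

# Illustrative values; normalize=False would give raw completion-after-prompt
# probabilities, e.g. for multi-label problems (see the docstring above).
example = Example(
    prompt="The weather today is",
    completions=("sunny", "rainy"),
    prior=(0.7, 0.3),
    normalize=True,
)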
9 changes: 4 additions & 5 deletions src/cappr/huggingface/classify.py
@@ -483,7 +483,7 @@ def cache(


########################################################################################
############################## Logits (from cached model) ##############################
############################### Logits from cached model ###############################
########################################################################################


@@ -534,7 +534,7 @@ def _blessed_helper(
)
num_batches = len(completions_input_ids)

# TODO: put this in the context manager? Little weird.
# TODO: put this in the context manager? Little weird
if not hf._utils.does_tokenizer_need_prepended_space(tokenizer):
start_of_prompt = ""
else:
@@ -890,9 +890,8 @@ def log_probs_conditional_examples(
print(log_probs_completions[1]) # corresponds to examples[1]
# [[-5.0, -1.7]] [[log Pr(1 | a, b, c), log Pr(2 | a, b, c, 1)]]
"""
# Little weird. I want my IDE to know that examples is always a Sequence[Example]
# b/c of the decorator.
examples: Sequence[Example] = examples
# examples is always a Sequence[Example] b/c of the decorator.
examples = cast(Sequence[Example], examples)
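As a side note, `cast` here is purely a static-typing aid; it does nothing at runtime. A self-contained sketch of the pattern, with a hypothetical decorator standing in for the real one:

from __future__ import annotations
from typing import Sequence, cast

def _listify(func):
    # Hypothetical decorator: lets callers pass one string or a sequence of them
    def wrapper(texts, *args, **kwargs):
        if isinstance(texts, str):
            texts = [texts]
        return func(texts, *args, **kwargs)
    return wrapper

@_listify
def shout(texts: str | Sequence[str]) -> list[str]:
    # cast is a no-op at runtime; it only tells the IDE/type checker that,
    # because of the decorator, texts is always a sequence inside the body
    texts = cast(Sequence[str], texts)
    return [text.upper() for text in texts]

print(shout("hi"))          # ['HI']
print(shout(["hi", "yo"]))  # ['HI', 'YO']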

@_batch.flatten
@_batch.batchify(
22 changes: 18 additions & 4 deletions src/cappr/huggingface/classify_no_cache.py
@@ -11,7 +11,7 @@
this module **does not** precompute attention block keys and values for prompts.
"""
from __future__ import annotations
from typing import Literal, Mapping, Sequence
from typing import cast, Literal, Mapping, Sequence

import numpy as np
import numpy.typing as npt
@@ -88,6 +88,11 @@ def token_logprobs(
return hf.classify.token_logprobs(**locals())


########################################################################################
######################################## Logits ########################################
########################################################################################


def _prompts_offsets(
tokenizer: PreTrainedTokenizerBase,
prompts: Sequence[str],
@@ -163,6 +168,11 @@ def _logits_completions_given_prompts_examples(
return logits, encodings


########################################################################################
################################## Logits to log-probs #################################
########################################################################################


def _logits_to_log_probs_completions(
logits: torch.Tensor, encodings: Mapping[str, torch.Tensor]
) -> list[list[float]]:
@@ -179,6 +189,11 @@ def _logits_to_log_probs_completions(
]
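The body of `_logits_to_log_probs_completions` is collapsed in this view. As a generic sketch of the logits-to-log-probs step (not the module's exact code): log-softmax the vocabulary dimension, gather each position's score for the token that actually follows it, and keep only the completion positions.

import torch

def logits_to_completion_log_probs(
    logits: torch.Tensor,     # (batch, seq_len, vocab_size)
    input_ids: torch.Tensor,  # (batch, seq_len)
    completion_start: int,    # index of the first completion token
) -> list[list[float]]:
    log_probs = torch.log_softmax(logits, dim=-1)
    # logits at position i score the token at position i + 1, so shift by one
    next_token_log_probs = log_probs[:, :-1, :].gather(
        2, input_ids[:, 1:, None]
    ).squeeze(-1)  # (batch, seq_len - 1)
    # keep only the completion tokens' log-probabilities
    return [row[completion_start - 1 :].tolist() for row in next_token_log_probs]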


########################################################################################
##################################### Implementation ###################################
########################################################################################


@classify._log_probs_conditional
def log_probs_conditional(
prompts: str | Sequence[str],
@@ -377,9 +392,8 @@ def log_probs_conditional_examples(
print(log_probs_completions[1]) # corresponds to examples[1]
# [[-5.0, -1.7]] [[log Pr(1 | a, b, c), log Pr(2 | a, b, c, 1)]]
"""
# Little weird. I want my IDE to know that examples is always a Sequence[Example]
# b/c of the decorator.
examples: Sequence[Example] = examples
# examples is always a Sequence[Example] b/c of the decorator.
examples = cast(Sequence[Example], examples)

@_batch.flatten
@_batch.batchify(
5 changes: 2 additions & 3 deletions src/cappr/llama_cpp/_utils.py
@@ -29,9 +29,8 @@ def check_logits(logits) -> np.ndarray:
logits = np.array(logits)
if np.any(np.isnan(logits)):
raise TypeError(
"There are nan logits. This can happen if the model is re-loaded too many "
"times in the same session. Please raise this as an issue so that I can "
"investigate: https://github.com/kddubey/cappr/issues"
"There are nan logits. Is there something wrong with the model? This can "
"happen if the model is reloaded many times in the same session."
) # pragma: no cover
return logits

24 changes: 19 additions & 5 deletions src/cappr/llama_cpp/classify.py
@@ -23,7 +23,7 @@
"""
from __future__ import annotations
from contextlib import contextmanager
from typing import Literal, Sequence
from typing import cast, Literal, Sequence

from llama_cpp import Llama
import numpy as np
@@ -119,6 +119,11 @@ def token_logprobs(
return log_probs


########################################################################################
###################################### KV caching ######################################
########################################################################################


@contextmanager
def cache(model: Llama, prefix: str, reset_model: bool = True):
"""
@@ -191,6 +196,11 @@ def cache(model: Llama, prefix: str, reset_model: bool = True):
model.n_tokens = n_tokens
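For context, the restore of `model.n_tokens` above is the tail end of the caching trick. A rough sketch of the overall shape, under assumptions about the elided body (an illustration, not the exact implementation):

from contextlib import contextmanager
from llama_cpp import Llama

@contextmanager
def _cache_prefix_sketch(model: Llama, prefix: str):
    # Evaluate the prefix once so its attention keys/values are in the cache,
    # then roll the token pointer back afterwards so the cached prefix is
    # reused instead of re-processed on the next evaluation.
    model.eval(model.tokenize(prefix.encode("utf-8")))
    n_tokens = model.n_tokens  # number of tokens whose KV cache we keep
    try:
        yield model
    finally:
        model.n_tokens = n_tokens  # discard tokens evaluated inside the block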


########################################################################################
############################## Logprobs from cached model ##############################
########################################################################################


def _log_probs_conditional_prompt(
prompt: str,
completions: Sequence[str],
@@ -199,7 +209,7 @@ def _log_probs_conditional_prompt(
) -> list[list[float]]:
_utils.check_model(model)
# Prepend whitespaces if the tokenizer or context call for it
# TODO: put this in the context manager? Little weird.
# TODO: put this in the context manager? Little weird
if not _utils.does_tokenizer_need_prepended_space(model):
start_of_prompt = ""
end_of_prompt = ""
@@ -258,6 +268,11 @@ def _log_probs_conditional_prompt(
return log_probs_completions


########################################################################################
#################################### Implementation ####################################
########################################################################################


@classify._log_probs_conditional
def log_probs_conditional(
prompts: str | Sequence[str],
@@ -441,9 +456,8 @@ def log_probs_conditional_examples(
print(log_probs_completions[1]) # corresponds to examples[1]
# [[-9.90, -10.0]] [[log Pr(d | a, b, c), log Pr(e | a, b, c, d)]]
"""
# Little weird. I want my IDE to know that examples is always a Sequence[Example]
# b/c of the decorator.
examples: Sequence[Example] = examples
# examples is always a Sequence[Example] b/c of the decorator.
examples = cast(Sequence[Example], examples)
if reset_model:
model.reset()
log_probs_completions = [
6 changes: 3 additions & 3 deletions src/cappr/openai/api.py
@@ -15,7 +15,7 @@
try:
from openai import OpenAI
except ImportError: # pragma: no cover
# openai version < 1.0.0. Many breaking changes need handling
# openai < v1.0.0. Many breaking changes need handling
OpenAI = type("OpenAI", (object,), {}) # pragma: no cover
_ERRORS_MODULE = openai.error # pragma: no cover
else:
@@ -304,7 +304,7 @@ def gpt_complete(
list with the same length as `texts`. Each element is the ``choices`` mapping
"""
_check.ordered(texts, variable_name="texts")
try:
try: # openai < v1.0.0
openai_method = openai.Completion.create # pragma: no cover
except AttributeError:
openai_method = (
@@ -400,7 +400,7 @@ def gpt_chat_complete(
list with the same length as `texts`. Each element is the ``choices`` mapping
"""
_check.ordered(texts, variable_name="texts")
try:
try: # openai < v1.0.0
openai_method = openai.ChatCompletion.create # pragma: no cover
except AttributeError:
openai_method = (
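Both annotated hunks above follow the compatibility pattern set up by the import block at the top of this file: resolve which API to call once, depending on whether the v1 `OpenAI` client is importable. A rough sketch of that idea (the client construction below is an assumption for illustration, not the elided code):

import openai

try:
    from openai import OpenAI  # only exists in openai >= v1.0.0
except ImportError:  # openai < v1.0.0
    chat_create = openai.ChatCompletion.create
else:
    _client = OpenAI()  # reads OPENAI_API_KEY from the environment
    chat_create = _client.chat.completions.create

response = chat_create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hi"}],
)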
5 changes: 3 additions & 2 deletions src/cappr/utils/_check.py
@@ -170,8 +170,9 @@ def remove_bos(tokens: list[int]) -> list[int]:
tokens_concat_correct = tokenize("a") + remove_bos(tokenize(" b"))
if tokens != tokens_concat_correct:
raise ValueError(
"This tokenizer is weird. Please raise this as an issue so that I can "
"investigate: https://github.com/kddubey/cappr/issues"
"This tokenizer is weird. Perhaps it's adding EOS tokens? Please raise "
"this as an issue so that I can investigate: "
"https://github.com/kddubey/cappr/issues"
) # pragma: no cover
return True
return False
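Concretely, the identity being checked above is that tokenizing text in pieces reproduces tokenizing it whole, once any BOS token on the later piece is dropped. A small, self-contained illustration using GPT-2's tokenizer (an assumed example model, not part of this diff); GPT-2 adds no BOS, so there is nothing to strip:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

def tokenize(text: str) -> list[int]:
    return tokenizer(text)["input_ids"]

# Mirrors the check above: "a b" should tokenize to the IDs of "a" followed by
# the IDs of " b". If it didn't, prompt + completion concatenation would not
# be scored the way the user wrote it.
assert tokenize("a b") == tokenize("a") + tokenize(" b")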
6 changes: 3 additions & 3 deletions src/cappr/utils/_no_cache.py
@@ -2,7 +2,7 @@
Utilities for implementations which don't cache.
"""
from __future__ import annotations
from typing import Callable, Literal, Sequence
from typing import Callable, cast, Literal, Sequence

from cappr.utils import _batch

@@ -72,8 +72,8 @@ def log_probs_conditional_examples(
):
from cappr import Example

# Little weird. I want my IDE to know that examples is always a Sequence[Example]
examples: Sequence[Example] = examples
# examples is always a Sequence[Example] b/c of the decorator.
examples = cast(Sequence[Example], examples)

texts = [
example.prompt + example.end_of_prompt + completion
11 changes: 6 additions & 5 deletions src/cappr/utils/classify.py
@@ -501,7 +501,7 @@ def _predict(predict_proba_func):
@wraps(predict_proba_func)
def wrapper(
prompts: str | Sequence[str], completions: Sequence[str], *args, **kwargs
) -> list[str]:
) -> str | list[str]:
if len(completions) == 1:
raise ValueError(
"completions only has one completion. predict will trivially return "
@@ -511,9 +511,8 @@ def wrapper(
pred_probs: npt.NDArray = predict_proba_func(
prompts, completions, *args, **kwargs
)
if not isinstance(completions, Sequence):
# We need completions to support 0-indexed __getitem__
completions = list(completions)
# We need completions to support 0-indexed __getitem__
completions = list(completions)
num_dimensions = pred_probs.ndim
if isinstance(prompts, str):
# User convenience: prompts was a single string, so pred_probs is 1-D
@@ -535,7 +534,9 @@ def _predict_examples(predict_proba_examples_func):
from cappr import Example

@wraps(predict_proba_examples_func)
def wrapper(examples: Example | Sequence[Example], *args, **kwargs) -> list[str]:
def wrapper(
examples: Example | Sequence[Example], *args, **kwargs
) -> str | list[str]:
pred_probs: npt.NDArray[np.floating] | list[
npt.NDArray[np.floating]
] = predict_proba_examples_func(examples, *args, **kwargs)
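The widened return annotations above (`-> str | list[str]`) encode the convenience described in the comments: a single prompt gives back a single predicted completion, a sequence gives a list. A backend-agnostic sketch of that dimension handling (the argmax step is an assumption for illustration, not the elided code):

import numpy as np

def predict_from_probs(prompts, completions, pred_probs: np.ndarray):
    completions = list(completions)  # ensure 0-indexed __getitem__
    if isinstance(prompts, str):
        # single prompt string -> pred_probs is 1-D -> return a single string
        return completions[int(pred_probs.argmax())]
    # sequence of prompts -> pred_probs is 2-D -> return a list of strings
    return [completions[int(idx)] for idx in pred_probs.argmax(axis=1)]

print(predict_from_probs("p", ["yes", "no"], np.array([0.2, 0.8])))
# no
print(predict_from_probs(["p1", "p2"], ["yes", "no"], np.array([[0.9, 0.1], [0.3, 0.7]])))
# ['yes', 'no']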
