Small things
kddubey committed Oct 20, 2024
1 parent 49fbb01 commit 2ab9230
Showing 5 changed files with 35 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,7 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.0
rev: v0.7.0
hooks:
# Run the linter.
- id: ruff
14 changes: 11 additions & 3 deletions pyproject.toml
@@ -13,6 +13,14 @@ dependencies = [
"tqdm>=4.27.0",
]
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
authors = [
{ name = "Kush Dubey", email = "kushdubey63@gmail.com" },
]
@@ -40,14 +48,14 @@ demos = [
"scikit-learn>=1.2.2",
]
dev = [
"cappr[openai,hf-dev,llama-cpp-dev,demos]",
"docutils<0.19",
"cappr[demos]",
"docutils<0.19", # For readthedocs. TODO: get rid of this specifier
"pre-commit>=3.5.0",
"pydata-sphinx-theme>=0.13.1",
"pytest>=7.2.1",
"pytest-cov>=4.0.0",
"pytest-sugar>=1.0.0",
"ruff>=0.3.0",
"ruff>=0.7.0",
"sphinx>=6.1.3",
"sphinx-copybutton>=0.5.2",
"sphinx-togglebutton>=0.3.2",
33 changes: 11 additions & 22 deletions src/cappr/huggingface/classify.py
@@ -119,7 +119,7 @@ def token_logprobs(


########################################################################################
########################## Attention past_key_values utilities #########################
################################ Slice past_key_values #################################
########################################################################################


@@ -173,7 +173,7 @@ def _select(
for kv_idx in range(len(past_key_values[block_idx])):
batch_size = past_key_values[block_idx][kv_idx].shape[0]
if batch_size == 1:
raise ValueError("Should use _expand_batch_indices") # pragma: no cover
raise ValueError("Should use _expand") # pragma: no cover
kvs.append(past_key_values[block_idx][kv_idx][batch_idxs, ...])
blocks.append(tuple(kvs))
return tuple(blocks)
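
For orientation, the slicing in `_select` operates on the (legacy) Hugging Face `past_key_values` layout: a tuple with one entry per transformer block, each entry a (key, value) pair of tensors shaped (batch, num_heads, seq_len, head_dim). Below is a minimal sketch of that layout and of indexing it with `batch_idxs`; all sizes are made up for illustration.

import torch

# Illustrative sizes; real values come from the model config
num_blocks, batch, num_heads, seq_len, head_dim = 2, 4, 8, 5, 16

past_key_values = tuple(
    (
        torch.randn(batch, num_heads, seq_len, head_dim),  # keys for this block
        torch.randn(batch, num_heads, seq_len, head_dim),  # values for this block
    )
    for _ in range(num_blocks)
)

# Select batch rows 0 and 2, mirroring past_key_values[block_idx][kv_idx][batch_idxs, ...]
batch_idxs = torch.tensor([0, 2])
selected = tuple(
    tuple(kv[batch_idxs, ...] for kv in block) for block in past_key_values
)
print(selected[0][0].shape)  # torch.Size([2, 8, 5, 16])
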
@@ -237,9 +237,11 @@ def __init__(
# This data is in one place to minimize pollution of the inputted model's
# namespace. This object should be treated like a ModelForCausalLM by the user
self._cappr.update_cache = True
_ = self.forward(**encodings_to_cache)
del _
self._cappr.update_cache = False
try:
_ = self.forward(**encodings_to_cache)
del _
finally:
self._cappr.update_cache = False

def forward(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor
@@ -359,7 +361,8 @@ def cache_model(
prefixes : str | Sequence[str]
prefix(es) for all future strings that will be processed, e.g., a string
containing shared prompt instructions, or a string containing instructions and
exemplars for few-shot prompting
exemplars for few-shot prompting. `prefixes` and future strings are assumed to
be separated by a whitespace.
logits_all : bool, optional
whether or not to have the cached model include logits for all tokens (including
the past). By default, past token logits are included
@@ -369,14 +372,6 @@
tuple[ModelForCausalLM, PreTrainedTokenizerBase]
cached model and the (unmodified) tokenizer
Note
----
If you're inputting the cached model and tokenizer to a function in this module,
e.g., :func:`predict`, `prefixes` and future strings are assumed to be separated by
a whitespace. Otherwise, ensure that any strings that are processed by the tokenizer
start correctly. Furthermore, if applicable, set ``tokenizer.add_bos_token = False``
for future computations.
Example
-------
Usage with :func:`predict_proba`::
@@ -466,7 +461,8 @@ def cache(
prefixes : str | Sequence[str]
prefix(es) for all strings that will be processed in this context, e.g., a
string containing shared prompt instructions, or a string containing
instructions and exemplars for few-shot prompting
instructions and exemplars for few-shot prompting. `prefixes` and future strings
are assumed to be separated by a whitespace.
clear_cache_on_exit : bool, optional
whether or not to clear the cache and render the returned model and tokenizer
unusable when we exit the context. This is important because it saves memory,
@@ -475,13 +471,6 @@
whether or not to have the cached model include logits for all tokens (including
the past). By default, past token logits are included
Note
----
If you're inputting the cached model and tokenizer to a function in this module,
e.g., :func:`predict`, `prefixes` and future strings are assumed to be separated by
a whitespace. Otherwise, ensure that any strings that are processed by the tokenizer
start correctly.
Example
-------
Usage with :func:`predict_proba`::
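
To make the whitespace convention in the two docstring updates above concrete, here is a minimal usage sketch of `cache` with `predict_proba`. The model choice, prompt strings, and argument order are illustrative assumptions based on the docstring excerpts, not code taken from this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer
from cappr.huggingface.classify import cache, predict_proba

# Any causal LM works here; "gpt2" is just a small, familiar example
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model_and_tokenizer = (model, tokenizer)

prefixes = "Answer yes or no to the question about the animal."
prompts = ["Can a chicken fly?", "Can an eagle fly?"]
completions = ["yes", "no"]

with cache(model_and_tokenizer, prefixes) as cached_model_and_tokenizer:
    # Per the updated docstrings, `prefixes` and these prompts are treated as
    # separated by a whitespace, so the prompts need no leading space
    pred_probs = predict_proba(prompts, completions, cached_model_and_tokenizer)

print(pred_probs.shape)  # (2, 2): one row per prompt, one column per completion
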
6 changes: 4 additions & 2 deletions src/cappr/huggingface/classify_no_cache.py
@@ -22,7 +22,7 @@
from cappr.utils import _batch, _check, classify
from cappr import Example
from cappr import huggingface as hf
from cappr.huggingface._utils import ModelForCausalLM
from cappr.huggingface._utils import BatchEncodingPT, ModelForCausalLM


def token_logprobs(
@@ -106,8 +106,10 @@ def _prompts_offsets(
prompts = list(prompts)
padding = len(prompts) > 1
with hf._utils.set_up_tokenizer(tokenizer):
encoding = tokenizer(prompts, return_tensors="pt", padding=padding)
encoding = cast(BatchEncodingPT, encoding)
offsets: torch.Tensor = (
tokenizer(prompts, return_tensors="pt", padding=padding)["attention_mask"]
encoding["attention_mask"]
.sum(dim=1)
.repeat_interleave(num_completions_per_prompt, dim=0)
)
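
The new `encoding` variable above feeds the same offsets computation as before: sum the attention mask over the sequence dimension to get each prompt's token count, then repeat each count once per completion. A tiny self-contained sketch of those two operations, with made-up mask values:

import torch

# Two padded prompts with 3 and 5 non-pad tokens, and 2 completions per prompt
attention_mask = torch.tensor(
    [[1, 1, 1, 0, 0],
     [1, 1, 1, 1, 1]]
)
num_completions_per_prompt = 2

offsets = attention_mask.sum(dim=1).repeat_interleave(num_completions_per_prompt, dim=0)
print(offsets)  # tensor([3, 3, 5, 5]): each prompt's token count, repeated per completion
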
18 changes: 8 additions & 10 deletions src/cappr/utils/classify.py
@@ -153,8 +153,8 @@ def _is_sliceable(object) -> bool:

def agg_log_probs(
log_probs: Sequence[Sequence[float]] | Sequence[Sequence[Sequence[float]]],
func: Callable[[Sequence[float]], float] = _avg_then_exp,
) -> npt.NDArray[np.floating] | list[float] | list[list[float]]:
func: Callable = _avg_then_exp,
) -> npt.NDArray[np.floating] | list[list[float]]:
"""
Aggregate token log-probabilities along the last dimension.
@@ -166,7 +166,7 @@
:class:`cappr.Example` object with completions. A 3-D sequence corresponds to
inputting multiple prompt strings or :class:`cappr.Example` objects with
completions
func : Callable[[Sequence[float]], float], optional
func : Callable, optional
a function which aggregates a sequence of token log-probabilities into a single
number, by default a probability. If the function is vectorized, it must take an
``axis`` argument, e.g., ``np.mean`` will efficiently average the token
@@ -175,7 +175,7 @@
Returns
-------
agg: npt.NDArray[np.floating] | list[float] | list[list[float]]
agg: npt.NDArray[np.floating] | list[list[float]]
If `log_probs` is 2-D, then `agg` is a numpy array or a list where::
agg[j] = func(log_probs[j])
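
To illustrate the documented relationship agg[j] = func(log_probs[j]) with the default aggregation, here is a small sketch. The `avg_then_exp` helper below is an assumed reimplementation of the module's `_avg_then_exp` default (average the token log-probabilities, then exponentiate), not the library's own code:

import numpy as np

def avg_then_exp(log_probs, axis=-1):
    # Assumed behavior of the default aggregator: mean log-probability -> probability
    return np.exp(np.mean(log_probs, axis=axis))

# 2-D case: log_probs[j] holds the token log-probabilities of completion j
log_probs = np.array([[-0.5, -1.5], [-2.0, -4.0]])
agg = avg_then_exp(log_probs)
print(agg)  # [exp(-1.0), exp(-3.0)], roughly [0.368, 0.050]
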
@@ -575,16 +575,14 @@ def wrapper(
assert pred_probs.ndim == 1 # double check
pred_class_idx = pred_probs.argmax()
return examples.completions[pred_class_idx]
try:
# If it's an array, we can call .argmax on the whole thing, which is faster

if isinstance(pred_probs, np.ndarray):
pred_class_idxs = pred_probs.argmax(axis=1)
except (
AttributeError, # no argmax attr
TypeError, # no axis kwarg
):
else:
pred_class_idxs = [
example_pred_probs.argmax() for example_pred_probs in pred_probs
]

return [
example.completions[pred_class_idx]
for example, pred_class_idx in zip(examples, pred_class_idxs)
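
The `isinstance` check above replaces the try/except because `pred_probs` may be a 2-D numpy array (one vectorized argmax is faster) or a plain list of per-example arrays, for instance when examples have different numbers of completions, in which case the argmax is taken per example. A minimal sketch of the two paths with made-up probabilities:

import numpy as np

# Constant number of completions: a 2-D array, so one vectorized argmax suffices
pred_probs_array = np.array([[0.2, 0.8], [0.9, 0.1]])
print(pred_probs_array.argmax(axis=1))  # [1 0]

# Variable number of completions: a ragged list, so take the argmax per example
pred_probs_ragged = [np.array([0.2, 0.8]), np.array([0.1, 0.3, 0.6])]
print([p.argmax() for p in pred_probs_ragged])  # [1, 2]
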
