Skip to content

Commit

Permalink
Allow arbitrary aggregation of token log-probs
Browse files Browse the repository at this point in the history
  • Loading branch information
kddubey committed Jul 31, 2024
1 parent 81b3134 commit ca8076d
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 25 deletions.
49 changes: 28 additions & 21 deletions src/cappr/utils/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,16 @@ def _setattr(obj, name: str, value):
########################################################################################


def _avg_then_exp(
log_probs: Sequence[Sequence[float]] | Sequence[Sequence[Sequence[float]]],
axis: int | None = None,
) -> npt.NDArray[np.floating]:
return np.exp(np.mean(log_probs, axis=axis))


def _agg_log_probs_vectorized(
log_probs: Sequence[Sequence[Sequence[float]]],
func: Callable[[Sequence[Sequence[float]]], Sequence[float]] = np.mean,
func: Callable[[Sequence[Sequence[float]]], Sequence[float]] = _avg_then_exp,
) -> npt.NDArray[np.floating]:
"""
Aggregate using a vectorized numpy function `func`.
Expand Down Expand Up @@ -96,28 +103,25 @@ def _agg_log_probs_vectorized(
"non-constant # of tokens. Vectorization is not possible."
) from exception
# Now apply the vectorized function to each array in the list
likelihoods: npt.NDArray[np.floating] = np.exp(
likelihoods: npt.NDArray[np.floating] = np.array(
[func(array, axis=1) for array in array_list]
)
# likelihoods looks like:
# array([[likelihood_a1b1, likelihood_a2b2 ],
# [likelihood_c1d1e1, likelihood_c2d2e2]])
# Transpose it to satisfy likelihoods[i][j] = exp(func(log_probs[i][j]))
# Transpose it to satisfy likelihoods[i][j] = func(log_probs[i][j])
return likelihoods.T


def _agg_log_probs(
    log_probs: Sequence[Sequence[Sequence[float]]],
    func: Callable[[Sequence[float]], float] = _avg_then_exp,
) -> list[list[float]]:
    """
    Aggregate using a slow, nested list comprehension.

    `func` is applied separately to each completion's sequence of token
    log-probabilities, and its return value is stored as-is (no exponentiation is
    applied on top of it). Nothing here requires the number of completions or the
    number of tokens to be constant.

    Parameters
    ----------
    log_probs : Sequence[Sequence[Sequence[float]]]
        `log_probs[i][j]` is the sequence of token log-probabilities for completion
        `j` of prompt `i`
    func : Callable[[Sequence[float]], float], optional
        aggregates one sequence of token log-probabilities into a single number. By
        default, the token log-probabilities are averaged and then exponentiated

    Returns
    -------
    list[list[float]]
        nested list `agg` where `agg[i][j] = func(log_probs[i][j])`
    """
    return [
        [func(log_probs_completion) for log_probs_completion in log_probs_completions]
        for log_probs_completions in log_probs
    ]

Expand Down Expand Up @@ -145,34 +149,37 @@ def _is_sliceable(object) -> bool:

def agg_log_probs(
log_probs: Sequence[Sequence[float]] | Sequence[Sequence[Sequence[float]]],
func: Callable[[Sequence[float]], float] = np.mean,
func: Callable[[Sequence[float]], float] = _avg_then_exp,
) -> npt.NDArray[np.floating] | list[float] | list[list[float]]:
"""
Aggregate token log-probabilities along the last dimension into probabilities.
Aggregate token log-probabilities along the last dimension.
Parameters
----------
log_probs : Sequence[Sequence[float]] | Sequence[Sequence[Sequence[float]]]
nested sequences where token log-probabilities are in the last dimension. A 2-D
sequence corresponds to inputting a single prompt string or
:class:`cappr.Example` object. A 3-D sequence corresponds to inputting multiple
prompt strings or :class:`cappr.Example` objects
:class:`cappr.Example` object with completions. A 3-D sequence corresponds to
inputting multiple prompt strings or :class:`cappr.Example` objects with
completions
func : Callable[[Sequence[float]], float], optional
function which aggregates a sequence of token log-probabilities into a single
log-probability. If the function is vectorized, it must take an ``axis``
argument. By default, `numpy.mean`
a function which aggregates a sequence of token log-probabilities into a single
number, by default a probability. If the function is vectorized, it must take an
``axis`` argument, e.g., ``np.mean`` will efficiently average the token
log-probabilities. By default, token log-probabilities are averaged and then
exponentiated.
Returns
-------
probs: npt.NDArray[np.floating] | list[float] | list[list[float]]
If `log_probs` is 2-D, then `probs` is an array or list of probabilities where::
agg: npt.NDArray[np.floating] | list[float] | list[list[float]]
If `log_probs` is 2-D, then `agg` is an array or list where::
probs[j] = exp(func(log_probs[j]))
agg[j] = func(log_probs[j])
If `log_probs` is 3-D, then `probs` is an array or a list of list of
probabilities where::
If `log_probs` is 3-D, then `agg` is an array or a list of lists of aggregated
values where::
probs[i][j] = exp(func(log_probs[i][j]))
agg[i][j] = func(log_probs[i][j])
Raises
------
Expand Down
31 changes: 27 additions & 4 deletions tests/utils/test_utils_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,23 @@


def test___agg_log_probs_vectorized():
    """
    Check that `_agg_log_probs_vectorized` returns `func`'s output as-is, and that its
    default `func` agrees with the non-vectorized `_agg_log_probs`.
    """
    # There are 2 prompts, each associated with 3 completions. The completions have 2,
    # 1, and 2 tokens, respectively
    log_probs = [
        [[2, 2], [1], [3, 2]],
        [[1 / 2, 1 / 2], [4], [3, 0]],
    ]
    # With func=np.sum, each entry is the plain sum of that completion's token
    # log-probs. No exponentiation is expected on top of func's output
    log_probs_agg_expected = [
        [2 + 2, 1, 3 + 2],
        [1 / 2 + 1 / 2, 4, 3 + 0],
    ]
    log_probs_agg = classify._agg_log_probs_vectorized(log_probs, func=np.sum)
    assert np.allclose(log_probs_agg, log_probs_agg_expected)

    # Test the default _avg_then_exp func
    log_probs_agg = classify._agg_log_probs_vectorized(log_probs)
    log_probs_agg_expected = classify._agg_log_probs(log_probs)
    assert np.allclose(log_probs_agg, log_probs_agg_expected)


@pytest.mark.parametrize(
Expand All @@ -34,14 +48,23 @@ def test__ndim(sequence_and_depth_expected: tuple[Any, int]):


def test_agg_log_probs():
# There are 2 prompts. The first prompt is associated with 2 completions, with 2 and
# 3 tokens each. The second prompt is associated with 3 completions, with 1, 3, and
# 2 tokens each
log_probs = [
[[0, 1], [2, 3, 4]],
[[5], [6, 7, 8], [9, 10]],
]
log_probs_agg_expected = [
[0 + 1, 2 + 3 + 4],
[5, 6 + 7 + 8, 9 + 10],
]
log_probs_agg = classify.agg_log_probs(log_probs, func=sum)
assert len(log_probs_agg) == len(log_probs)
assert np.allclose(log_probs_agg[0], np.exp([0 + 1, 2 + 3 + 4]))
assert np.allclose(log_probs_agg[1], np.exp([5, 6 + 7 + 8, 9 + 10]))
for prompt_idx in range(len(log_probs)):
assert np.allclose(
log_probs_agg[prompt_idx], log_probs_agg_expected[prompt_idx]
)

# Test bad input
with pytest.raises(
Expand Down

0 comments on commit ca8076d

Please sign in to comment.