Skip to content

Commit

Permalink
Fix types in protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
kddubey committed Feb 6, 2024
1 parent deda2a4 commit 7587beb
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions tests/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,27 @@ class classify_module(Protocol):
Protocol for all `classify` modules.
"""

def token_logprobs(self, texts: str | Sequence[str], Any) -> list[list[float]]:
def token_logprobs(
self, texts: str | Sequence[str], *args, **kwargs
) -> list[list[float]]:
"""
For each text, log Pr(token[i] | tokens[:i]) for each token
"""

def cache(self, model: Model, prefix: str) -> _GeneratorContextManager[Model]:
def cache(
self, model: Model, prefix: str, *args, **kwargs
) -> _GeneratorContextManager[Model]:
"""
Optional: context manager which caches the `model` with `prefix`
"""

def cache_model(self, model: Model, prefix: str) -> Model:
def cache_model(self, model: Model, prefix: str, *args, **kwargs) -> Model:
"""
Optional: returns a model cached with `prefix`
"""

def log_probs_conditional(
self, prompts: str | Sequence[str], completions: Sequence[str], Any
self, prompts: str | Sequence[str], completions: Sequence[str], *args, **kwargs
) -> list[list[float]] | list[list[list[float]]]:
"""
list[i][j][k] = log Pr(
Expand All @@ -41,7 +45,7 @@ def log_probs_conditional(
"""

def log_probs_conditional_examples(
self, examples: Example | Sequence[Example], Any
self, examples: Example | Sequence[Example], *args, **kwargs
) -> list[list[float]] | list[list[list[float]]]:
"""
list[i][j][k] = log Pr(
Expand All @@ -51,28 +55,28 @@ def log_probs_conditional_examples(
"""

def predict_proba(
self, prompts: str | Sequence[str], completions: Sequence[str], Any
self, prompts: str | Sequence[str], completions: Sequence[str], *args, **kwargs
) -> npt.NDArray[np.floating]:
"""
array[i, j] = Pr(completions[j] | prompts[i])
"""

def predict_proba_examples(
self, examples: Example | Sequence[Example], Any
self, examples: Example | Sequence[Example], *args, **kwargs
) -> npt.NDArray[np.floating] | list[npt.NDArray[np.floating]]:
"""
list[i][j] = Pr(examples[i].completions[j] | examples[i].prompt)
"""

def predict(
self, prompts: str | Sequence[str], completions: Sequence[str], Any
self, prompts: str | Sequence[str], completions: Sequence[str], *args, **kwargs
) -> str | list[str]:
"""
list[i] = argmax_j Pr(completions[j] | prompts[i])
"""

def predict_examples(
self, examples: Example | Sequence[Example], Any
self, examples: Example | Sequence[Example], *args, **kwargs
) -> str | list[str]:
"""
list[i] = argmax_j Pr(examples[i].completions[j] | examples[i].prompt)
Expand Down

0 comments on commit 7587beb

Please sign in to comment.