From 7587beb6547044cb3851f2659c64a29f13c1fa90 Mon Sep 17 00:00:00 2001 From: kddubey Date: Tue, 6 Feb 2024 15:47:37 -0800 Subject: [PATCH] Fix types in protocol --- tests/_protocol.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/_protocol.py b/tests/_protocol.py index dfc140d..a08db24 100644 --- a/tests/_protocol.py +++ b/tests/_protocol.py @@ -16,23 +16,27 @@ class classify_module(Protocol): Protocol for all `classify` modules. """ - def token_logprobs(self, texts: str | Sequence[str], Any) -> list[list[float]]: + def token_logprobs( + self, texts: str | Sequence[str], *args, **kwargs + ) -> list[list[float]]: """ For each text, log Pr(token[i] | tokens[:i]) for each token """ - def cache(self, model: Model, prefix: str) -> _GeneratorContextManager[Model]: + def cache( + self, model: Model, prefix: str, *args, **kwargs + ) -> _GeneratorContextManager[Model]: """ Optional: context manager which caches the `model` with `prefix` """ - def cache_model(self, model: Model, prefix: str) -> Model: + def cache_model(self, model: Model, prefix: str, *args, **kwargs) -> Model: """ Optional: returns a model cached with `prefix` """ def log_probs_conditional( - self, prompts: str | Sequence[str], completions: Sequence[str], Any + self, prompts: str | Sequence[str], completions: Sequence[str], *args, **kwargs ) -> list[list[float]] | list[list[list[float]]]: """ list[i][j][k] = log Pr( @@ -41,7 +45,7 @@ def log_probs_conditional( """ def log_probs_conditional_examples( - self, examples: Example | Sequence[Example], Any + self, examples: Example | Sequence[Example], *args, **kwargs ) -> list[list[float]] | list[list[list[float]]]: """ list[i][j][k] = log Pr( @@ -51,28 +55,28 @@ def log_probs_conditional_examples( """ def predict_proba( - self, prompts: str | Sequence[str], completions: Sequence[str], Any + self, prompts: str | Sequence[str], completions: Sequence[str], *args, **kwargs ) -> npt.NDArray[np.floating]: """ array[i, j] = Pr(completions[j] | prompts[i]) """ def predict_proba_examples( - self, examples: Example | Sequence[Example], Any + self, examples: Example | Sequence[Example], *args, **kwargs ) -> npt.NDArray[np.floating] | list[npt.NDArray[np.floating]]: """ list[i][j] = Pr(examples[i].completions[j] | examples[i].prompt) """ def predict( - self, prompts: str | Sequence[str], completions: Sequence[str], Any + self, prompts: str | Sequence[str], completions: Sequence[str], *args, **kwargs ) -> str | list[str]: """ list[i] = argmax_j Pr(completions[j] | prompts[i]) """ def predict_examples( - self, examples: Example | Sequence[Example], Any + self, examples: Example | Sequence[Example], *args, **kwargs ) -> str | list[str]: """ list[i] = argmax_j Pr(examples[i].completions[j] | examples[i].prompt)