-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Wrap solvers with completion functions for compatibility with pre-solver Evals
Co-authored-by: johny-b <33967107+johny-b@users.noreply.github.com> Co-authored-by: Chan Jun Shern <JunShern@users.noreply.github.com> Co-authored-by: Giulio Starace <giulio.starace@gmail.com>
- Loading branch information
1 parent
063bf4f
commit 49fd9ef
Showing
9 changed files
with
194 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from typing import Any, Union | ||
|
||
from evals.api import CompletionFn, CompletionResult | ||
from evals.prompt.base import OpenAICreateChatPrompt | ||
from evals.solvers.nested.cot_solver import CoTSolver | ||
from evals.solvers.solver import Solver, SolverSpec, create_solver | ||
from evals.task_state import Message, TaskState | ||
|
||
|
||
class SolverCompletionFnResult(CompletionResult):
    """Completion result that carries a single solver output."""

    def __init__(self, msg):
        # Keep the output exactly as given; get_completions() hands it back.
        self.msg = msg

    def get_completions(self):
        """Return the stored output wrapped in a one-element list."""
        completions = [self.msg]
        return completions
|
||
|
||
class SolverCompletionFn(CompletionFn):
    """
    Wraps a solver into a completion function, such that the completion function's
    __call__ method invokes the internal solver, mapping the input completion
    function `prompt` to the solver's `task_state` input.

    Useful for using Solvers with eval.Eval classes, which would normally require a CompletionFn.

    Current limitations:
    - Stateful solvers are not supported: Solver state is not maintained between
      calls.
    - Prompts with more than `role` and `content` keys are not supported. Only
      the first message's keys are checked; for later messages just `role` and
      `content` are read.
    """

    def __init__(self, solver: Union[SolverSpec, Solver], registry: Any = None):
        """
        Args:
            solver: Either a ready `Solver` instance, or a spec that is passed
                to `create_solver` to build one.
            registry: Accepted but not used by this class.
        """
        if isinstance(solver, Solver):
            self.solver = solver
        else:
            self.solver = create_solver(solver)

    def __call__(
        self, prompt: Union[str, OpenAICreateChatPrompt], **kwargs
    ) -> SolverCompletionFnResult:
        """
        Run the wrapped solver on `prompt` and return its output.

        A string prompt becomes the task description. For a list prompt, the
        first message's content becomes the task description and the remaining
        messages become the task state's messages.

        Args:
            prompt: A string, or a non-empty list of `role`/`content` dicts
                whose first entry has role "system".
            **kwargs: Forwarded unchanged to the solver call.

        Raises:
            ValueError: If the solver is a `CoTSolver` with an interaction
                cache, if `prompt` is neither a string nor a list, or if it is
                an empty list.
            AssertionError: If the first message is not a system message or
                has keys other than `role` and `content`.
        """
        # We have this check here rather than __init__ since the solver may be unwrapped and used in a SolverEval
        if isinstance(self.solver, CoTSolver):
            if self.solver.interaction_cache is not None:
                raise ValueError(
                    "`CoTSolver` with persistent memory is incompatible with "
                    "CompletionFn-based `Eval` classes. "
                    "Please set `CoTSolver(persistent_memory=False)` or update the eval to a `SolverEval`."
                )

        if isinstance(prompt, str):
            prompt = [{"role": "system", "content": prompt}]
        elif isinstance(prompt, list):
            if not prompt:
                # Previously surfaced as a bare IndexError on prompt[0].
                raise ValueError("Unexpected empty prompt: at least one message is required")
            # Raised explicitly (not via `assert`) so the check still runs
            # under `python -O`; the exception type is kept as AssertionError
            # for backward compatibility.
            if prompt[0]["role"] != "system":
                raise AssertionError("Unexpected prompt role ordering")
        else:
            raise ValueError(
                f"Unexpected prompt type: "
                f"string or OpenAICreateChatPrompt expected, got {type(prompt)}"
            )

        first_message_keys = set(prompt[0].keys())
        if first_message_keys != {"role", "content"}:
            raise AssertionError(
                "Unexpected keys in prompt: "
                f"expected exactly {{'role', 'content'}}, got {first_message_keys}"
            )

        task_state = TaskState(
            prompt[0]["content"],
            [Message(msg["role"], msg["content"]) for msg in prompt[1:]],
        )

        # use a copy to avoid task state surviving across samples
        pure_solver = self.solver.copy()

        result = pure_solver(task_state, **kwargs)
        return SolverCompletionFnResult(result.output)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from dataclasses import dataclass | ||
from typing import List | ||
|
||
from evals.task_state import Message, TaskState | ||
|
||
|
||
@dataclass
class Interaction:
    """A recorded conversation together with the positions of its private messages."""

    # Every message seen so far, excluding the task_description
    messages: List[Message]

    # Positions within `messages` of the CoT private internal messages
    private_messages_ids: List[int]
|
||
|
||
class PersistentMemoryCache:
    """
    Remembers the most recent interaction, including its private messages, so
    they can be merged back into a later task state that lacks them.
    """

    def __init__(
        self,
        interaction_length: int,
    ):
        """
        Args:
            interaction_length: Number of messages marked private on each call
                to `save_private_interaction`.
        """
        self.private_interaction_length = interaction_length
        self.last_interaction = None

    def save_private_interaction(self, task_state: TaskState):
        """
        Record `task_state.messages` as the latest interaction.

        The `private_interaction_length` messages immediately before the final
        message are marked private; the final message itself is not. Private
        ids recorded by earlier calls are carried over.

        NOTE(review): presumably the final message is the public answer that
        follows the private CoT messages -- confirm against the caller.
        """
        interaction_messages = task_state.messages
        num_interaction_messages = len(interaction_messages)

        # Copy the earlier ids so the previous Interaction's list is not
        # mutated in place.
        if self.last_interaction is None:
            private_messages_ids = []
        else:
            private_messages_ids = list(self.last_interaction.private_messages_ids)

        # Clamp at 0: with fewer messages than the private window the raw
        # bounds go negative, and a negative id never matches a real position.
        first_private = max(num_interaction_messages - self.private_interaction_length - 1, 0)
        end_private = max(num_interaction_messages - 1, 0)
        private_messages_ids.extend(range(first_private, end_private))

        self.last_interaction = Interaction(interaction_messages, private_messages_ids)

    def load_private_interaction(self, task_state: TaskState) -> List[Message]:
        """
        Merge the saved interaction into the messages of `task_state`.

        Returns:
            `task_state.messages` itself when nothing has been saved yet.
            Otherwise a new list: the saved messages (private ones included)
            followed by any messages in `task_state` beyond the saved history.

        Raises:
            ValueError: If the non-private saved messages are not a prefix of
                `task_state.messages` (a message differs, or `task_state` has
                fewer messages than the saved public history).
        """
        if self.last_interaction is None:
            return task_state.messages

        # Check if task_state matches our last interaction
        interaction = self.last_interaction
        private_ids = set(interaction.private_messages_ids)  # O(1) membership in the loop
        incoming_messages = task_state.messages
        task_state_message_ix = 0
        for our_message_ix, our_message in enumerate(interaction.messages):
            if our_message_ix in private_ids:
                # Private messages are absent from task_state; nothing to compare.
                continue
            # A too-short task_state is a mismatch as well; report it with the
            # same descriptive error instead of letting an IndexError escape.
            if (
                task_state_message_ix >= len(incoming_messages)
                or incoming_messages[task_state_message_ix] != our_message
            ):
                raise ValueError(
                    f"task_state message {task_state_message_ix} different than the corresponding message "
                    "in the interaction history.\n"
                    f"task_state.messages:\n{task_state.messages}\n"
                    f"interaction.messages:\n{interaction.messages}\n"
                )
            task_state_message_ix += 1

        return interaction.messages + task_state.messages[task_state_message_ix:]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.