Skip to content

Commit

Permalink
Wrap solvers with completion functions for compatibility with pre-sol…
Browse files Browse the repository at this point in the history
…ver Evals

Co-authored-by: johny-b <33967107+johny-b@users.noreply.github.com>
Co-authored-by: Chan Jun Shern <JunShern@users.noreply.github.com>
Co-authored-by: Giulio Starace <giulio.starace@gmail.com>
  • Loading branch information
4 people committed Mar 27, 2024
1 parent 063bf4f commit 49fd9ef
Show file tree
Hide file tree
Showing 9 changed files with 194 additions and 98 deletions.
73 changes: 73 additions & 0 deletions evals/completion_fns/solver_completion_fn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Any, Union

from evals.api import CompletionFn, CompletionResult
from evals.prompt.base import OpenAICreateChatPrompt
from evals.solvers.nested.cot_solver import CoTSolver
from evals.solvers.solver import Solver, SolverSpec, create_solver
from evals.task_state import Message, TaskState


class SolverCompletionFnResult(CompletionResult):
    """Completion result that carries the single output produced by a solver."""

    def __init__(self, msg):
        # The solver output is kept as-is; no conversion is applied here.
        self.msg = msg

    def get_completions(self):
        """Return the stored solver output wrapped in a one-element list."""
        completions = [self.msg]
        return completions


class SolverCompletionFn(CompletionFn):
    """
    Wraps a solver into a completion function, such that the completion function's
    __call__ method invokes (a copy of) the internal solver, mapping the input
    completion function `prompt` to the solver's `task_state` input.
    Useful for using Solvers with eval.Eval classes, which would normally require a CompletionFn.
    Current limitations:
    - Stateful solvers are not supported: Solver state is not maintained between
    calls.
    - Prompts with more than `role` and `content` keys are not supported.
    """

    def __init__(self, solver: Union[SolverSpec, Solver], registry: Any = None):
        """
        Args:
            solver: Either a ready-made `Solver` instance, or a solver spec that
                is turned into a solver via `create_solver`.
            registry: Accepted for signature compatibility; not used by this class.
        """
        if isinstance(solver, Solver):
            self.solver = solver
        else:
            self.solver = create_solver(solver)

    def __call__(
        self, prompt: Union[str, OpenAICreateChatPrompt], **kwargs
    ) -> SolverCompletionFnResult:
        """
        Run the wrapped solver on `prompt` and return its output.

        A string prompt becomes the task description. For a list prompt, the
        first message must be a `system` message whose content becomes the task
        description; the remaining messages become the task state's messages.

        Args:
            prompt: A plain string or a non-empty list of chat messages.
            **kwargs: Forwarded unchanged to the solver call.

        Raises:
            ValueError: If the wrapped solver is a `CoTSolver` with persistent
                memory, or if the prompt has an unsupported type, is empty,
                does not start with a `system` message, or its first message
                has keys other than `role` and `content`.
        """
        # We have this check here rather than __init__ since the solver may be unwrapped and used in a SolverEval
        if isinstance(self.solver, CoTSolver) and self.solver.interaction_cache is not None:
            raise ValueError(
                "`CoTSolver` with persistent memory is incompatible with "
                "CompletionFn-based `Eval` classes. "
                "Please set `CoTSolver(persistent_memory=False)` or update the eval to a `SolverEval`."
            )

        if isinstance(prompt, str):
            messages = [{"role": "system", "content": prompt}]
        elif isinstance(prompt, list):
            messages = prompt
        else:
            raise ValueError(
                f"Unexpected prompt type: "
                f"string or OpenAICreateChatPrompt expected, got {type(prompt)}"
            )

        # Explicit raises (rather than asserts) so the validation still runs
        # when Python is started with -O.
        if not messages:
            raise ValueError("Unexpected prompt: expected at least one message, got an empty list")

        first_message = messages[0]
        # Validate the keys before reading `role`, so a message without a
        # `role` key reports the key mismatch instead of a bare KeyError.
        if set(first_message.keys()) != {"role", "content"}:
            raise ValueError(
                "Unexpected keys in prompt: "
                f"expected exactly {{'role', 'content'}}, got {set(first_message.keys())}"
            )
        if first_message["role"] != "system":
            raise ValueError("Unexpected prompt role ordering")

        task_state = TaskState(
            first_message["content"],
            [Message(msg["role"], msg["content"]) for msg in messages[1:]],
        )

        # use a copy to avoid task state surviving across samples
        pure_solver = self.solver.copy()

        result = pure_solver(task_state, **kwargs)
        return SolverCompletionFnResult(result.output)
4 changes: 2 additions & 2 deletions evals/elsuite/bluff/strategy_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Optional

from evals.elsuite.bluff.bluff.cards import get_bluff_move
from evals.solvers.memory import PersistentMemoryCache
from evals.solvers.solver import Solver, SolverResult
from evals.solvers.utils import PersistentMemoryCache
from evals.task_state import Message, TaskState


Expand All @@ -28,7 +28,7 @@ def __init__(

# interaction_length=1 to store reasoning step in private memory
self.interaction_cache = PersistentMemoryCache(interaction_length=1)

def _generate_response(self, task_state: TaskState):
"""
Calls base solver. Modifies task state to remove all non-reasoning messages
Expand Down
21 changes: 8 additions & 13 deletions evals/elsuite/make_me_say/eval.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,26 @@
import numpy as np

import evals
from evals.api import CompletionFn, DummyCompletionFn
from evals.api import DummyCompletionFn
from evals.elsuite.make_me_say.autoeval import run as run_auto_eval
from evals.elsuite.make_me_say.core import Game
from evals.record import RecorderBase


class MakeMeSay(evals.Eval):
def __init__(
self,
completion_fns: list[CompletionFn],
*args,
**kwargs,
):
super().__init__(completion_fns, *args, **kwargs)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

if len(completion_fns) == 1 and isinstance(completion_fns[0], DummyCompletionFn):
completion_fn = completion_fns[0]
completion_fns = [completion_fn for _ in range(3)]
if len(self.completion_fns) == 1 and isinstance(self.completion_fns[0], DummyCompletionFn):
completion_fn = self.completion_fns[0]
self.completion_fns = [completion_fn for _ in range(3)]

assert len(completion_fns) == 3, "MakeMeSay only supports three completion fns"
assert len(self.completion_fns) == 3, "MakeMeSay only supports three completion fns"
(
self.manipulator_completion_fn,
self.manipulatee_completion_fn,
self.judge_completion_fn,
) = completion_fns
) = self.completion_fns

def eval_sample(self, sample: dict, rng) -> None:
del rng
Expand Down
8 changes: 4 additions & 4 deletions evals/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import random
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union

from tqdm import tqdm

Expand All @@ -18,7 +18,7 @@
from .record import RecorderBase
from .registry import Registry
from .solvers.solver import Solver
from .solvers.utils import maybe_wrap_with_solver
from .solvers.utils import maybe_wrap_with_compl_fn, maybe_wrap_with_solver

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,7 +55,7 @@ class Eval(abc.ABC):

def __init__(
self,
completion_fns: list[CompletionFn],
completion_fns: list[Union[CompletionFn, Solver]],
eval_registry_path: Path,
seed: int = 20220722,
name: str = "no_name_eval.default",
Expand All @@ -66,7 +66,7 @@ def __init__(
if len(splits) < 2:
raise ValueError(f"Eval name must at least have <base_eval>.<split>. Got name {name}")

self.completion_fns = completion_fns
self.completion_fns = [maybe_wrap_with_compl_fn(fn) for fn in completion_fns]
self.eval_registry_path = eval_registry_path
self.seed = seed
self.name = name
Expand Down
64 changes: 64 additions & 0 deletions evals/solvers/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from dataclasses import dataclass
from typing import List

from evals.task_state import Message, TaskState


@dataclass
class Interaction:
    """
    Snapshot of one interaction as stored by the memory cache.

    Attributes:
        messages: All messages we've seen (except for the task_description).
        private_messages_ids: IDs of the CoT private internal messages.
    """

    messages: List[Message]
    private_messages_ids: List[int]


class PersistentMemoryCache:
    """
    Remembers the messages of the last interaction, including private ones.

    `save_private_interaction` stores the full message list of a task state and
    records which message positions are private. `load_private_interaction`
    checks a later task state against the stored non-private messages and
    returns the stored history (private messages included) followed by any
    messages that are new in that task state.
    """

    def __init__(
        self,
        interaction_length: int,
    ):
        """
        Args:
            interaction_length: Number of messages, immediately preceding the
                last message of a saved task state, that are marked private.
        """
        self.private_interaction_length = interaction_length
        self.last_interaction = None

    def save_private_interaction(self, task_state: TaskState):
        """
        Store `task_state.messages` as the latest interaction.

        The `interaction_length` messages directly before the final message are
        added to the private ids carried over from the previous interaction.
        """
        interaction_messages = task_state.messages
        num_interaction_messages = len(interaction_messages)

        # Copy the previous ids instead of extending them in place, so the
        # previously stored Interaction is never mutated behind its back.
        if self.last_interaction is None:
            private_messages_ids = []
        else:
            private_messages_ids = list(self.last_interaction.private_messages_ids)

        private_messages_ids.extend(
            range(
                num_interaction_messages - self.private_interaction_length - 1,
                num_interaction_messages - 1,
            )
        )
        self.last_interaction = Interaction(interaction_messages, private_messages_ids)

    def load_private_interaction(self, task_state: TaskState) -> List[Message]:
        """
        Merge the stored interaction with the messages of `task_state`.

        Returns:
            `task_state.messages` unchanged when nothing has been saved yet.
            Otherwise the stored messages followed by the messages of
            `task_state` that come after the stored non-private ones.

        Raises:
            ValueError: If the leading messages of `task_state` do not match
                the stored non-private messages, or if `task_state` has fewer
                messages than the stored non-private history.
        """
        if self.last_interaction is None:
            return task_state.messages

        # Check if task_state matches our last interaction
        interaction = self.last_interaction
        private_ids = set(interaction.private_messages_ids)  # O(1) membership in the loop
        task_state_message_ix = 0
        for our_message_ix, our_message in enumerate(interaction.messages):
            if our_message_ix in private_ids:
                continue

            if task_state_message_ix >= len(task_state.messages):
                # A truncated task_state is a history mismatch too; report it
                # with the same exception type instead of a bare IndexError.
                raise ValueError(
                    (
                        f"task_state has only {len(task_state.messages)} messages, fewer than the "
                        "non-private messages in the interaction history.\n"
                        f"task_state.messages:\n{task_state.messages}\n"
                        f"interaction.messages:\n{interaction.messages}\n"
                    )
                )

            if task_state.messages[task_state_message_ix] != our_message:
                raise ValueError(
                    (
                        f"task_state message {task_state_message_ix} different than the corresponding message "
                        "in the interaction history.\n"
                        f"task_state.messages:\n{task_state.messages}\n"
                        f"interaction.messages:\n{interaction.messages}\n"
                    )
                )
            task_state_message_ix += 1

        return interaction.messages + task_state.messages[task_state_message_ix:]
2 changes: 1 addition & 1 deletion evals/solvers/nested/cot_solver.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Any

from evals.solvers.memory import PersistentMemoryCache
from evals.solvers.prompts.cot import DEFAULT_COT_TEMPLATE, DEFAULT_EXTRACT_ANSWER_TEMPLATE
from evals.solvers.solver import NestedSolver, Solver, SolverResult, SolverSpec
from evals.solvers.utils import PersistentMemoryCache
from evals.task_state import Message, TaskState


Expand Down
2 changes: 1 addition & 1 deletion evals/solvers/nested/self_consistency_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from collections import Counter
from typing import Any, Optional

from evals.solvers.memory import PersistentMemoryCache
from evals.solvers.solver import NestedSolver, Solver, SolverResult, SolverSpec
from evals.solvers.utils import PersistentMemoryCache
from evals.task_state import Message, TaskState

DEFAULT_COT_TEMPLATE = """Before answering, reason in a step-by-step manner as to get the right answer, then conclude with the answer. Format your output as {prefix} <answer>"""
Expand Down
12 changes: 8 additions & 4 deletions evals/solvers/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,7 @@ def _solve(
return self._solver_cache[solver_name]

def _create_solver(self, solver_spec: SolverSpec) -> Solver:
module_name, class_name = solver_spec["class"].split(":")
module = import_module(module_name)
cls = getattr(module, class_name)
return cls(**solver_spec["args"])
return create_solver(solver_spec)

def copy(self: SolverType) -> SolverType:
# The NestedSolver needs to manually copy the sub-solvers, otherwise we will miss any
Expand All @@ -210,3 +207,10 @@ def model_version(self) -> Union[str, dict]:
model_versions[solver_name] = solver_model_version

return model_versions


def create_solver(solver_spec: dict) -> Solver:
    """
    Instantiate the class named by a solver spec.

    Args:
        solver_spec: Mapping with a ``"class"`` entry of the form
            ``"<module path>:<class name>"`` and an optional ``"args"``
            mapping of keyword arguments for the class constructor. Missing
            or ``None`` args mean the class is constructed with no arguments.

    Returns:
        The object produced by calling the resolved class with the given args.

    Raises:
        ValueError: If the ``"class"`` entry is not of the form
            ``"<module>:<ClassName>"``.
    """
    class_path = solver_spec["class"]
    module_name, separator, class_name = class_path.partition(":")
    if not separator or ":" in class_name:
        # Same exception type as the tuple-unpacking error this replaces, but
        # with a message that names the offending value.
        raise ValueError(
            f"Invalid solver class path {class_path!r}: expected '<module>:<ClassName>'"
        )
    module = import_module(module_name)
    cls = getattr(module, class_name)
    return cls(**(solver_spec.get("args") or {}))
Loading

0 comments on commit 49fd9ef

Please sign in to comment.