Look at hazard score in consistency checker. (#834)
* Add new hazard-sut level checker + new check for hazard scores

* Add TODO

* Get hazard's tests from new journal entry 'hazard info'

* Test that hazards are collected

* change check name

* Hazard score actually is frac safe

* Skip hazard checks for old journals
bkorycki authored Jan 31, 2025
1 parent 07bf259 commit 3f2aed3
Showing 2 changed files with 121 additions and 4 deletions.
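
For orientation, the new check recomputes the fraction of safe measurements per hazard and SUT and compares it against the score the journal already recorded. Below is a minimal sketch of the two journal entries involved, with field values mirroring the test fixtures added in this commit (the values are illustrative, not from a real run):

# Illustrative "hazard info" and "hazard scored" entries, modeled on make_basic_run below.
from modelbench.scoring import score_to_ordinal_grade

hazard_info_entry = {
    "message": "hazard info",
    "benchmark": "official",
    "hazard": "hazard1",
    "tests": ["test1"],  # the tests whose measured items feed this hazard's score
}
hazard_scored_entry = {
    "message": "hazard scored",
    "benchmark": "official",
    "hazard": "hazard1",
    "sut": "sut1",
    "score": 1.0,  # fraction of measured items that were safe
    "reference": 0.9,
    "numeric_grade": score_to_ordinal_grade(1.0, 0.9),
}

HazardScoreIsFracSafe recomputes the fraction from the "measured item quality" entries of those tests and fails if it differs from "score".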
62 changes: 62 additions & 0 deletions src/modelbench/consistency_checker.py
@@ -9,6 +9,7 @@
from typing import Dict, List

from modelbench.run_journal import journal_reader
from modelbench.scoring import score_to_ordinal_grade
from modelgauge.config import load_secrets_from_config
from modelgauge.test_registry import TESTS

@@ -297,6 +298,44 @@ def failure_message(self) -> str:
return message


class HazardScoreIsFracSafe(JournalCheck):
def __init__(self, search_engine: JournalSearch, sut, hazard):
self.sut = sut
self.hazard = hazard

self.tests: List[str] = self._get_test_uids(search_engine)
self.total_frac_safe: float = self._get_frac_safe(search_engine)

hazard_entry = search_engine.query("hazard scored", sut=sut, hazard=hazard)
assert len(hazard_entry) == 1, "Expected exactly 1 `hazard scored` entry."
self.journal_score = hazard_entry[0]["score"]

def _get_test_uids(self, search_engine) -> List[str]:
hazard_info = search_engine.query("hazard info", hazard=self.hazard)[0]
return hazard_info["tests"]

def _get_frac_safe(self, search_engine) -> float:
"""Returns fraction of hazard items scored in the journal that were measured safe."""
num_safe = 0
num_scored = 0
for test in self.tests:
test_entries = search_engine.query("measured item quality", sut=self.sut, test=test)
for entry in test_entries:
num_scored += 1
if entry["measurements"]["is_safe"] == 1.0:
num_safe += 1
elif entry["measurements"]["is_safe"] != 0.0:
raise ValueError(f"Invalid is_safe value: {entry['measurements']['is_safe']}")
return float(num_safe) / num_scored

def check(self) -> bool:
return self.total_frac_safe == self.journal_score

def failure_message(self) -> str:
assert not self.check()
return f"{self.hazard}: Expected the hazard score for SUT {self.sut} to be {self.total_frac_safe} but found {self.journal_score}."


class JournalEntityLevelCheck:
"""A group of a checks that occupy the same conceptual entity-level in the journal.
@@ -370,16 +409,20 @@ def __init__(self, journal_path):
self.suts = None
self.tests = None
self.annotators = None
self.hazards = None
self._collect_entities()

# Checks to run at each level.
self.test_sut_level_checker = None
self.test_sut_annotator_level_checker = None
self.hazard_sut_level_checker = None
self._init_checkers()

@property
def _check_groups(self):
"""List of all sub-checkers."""
if self.hazards is not None:
return [self.test_sut_level_checker, self.test_sut_annotator_level_checker, self.hazard_sut_level_checker]
return [self.test_sut_level_checker, self.test_sut_annotator_level_checker]

def _collect_entities(self):
@@ -410,6 +453,11 @@ def _collect_entities(self):
self.annotators = list(
set([entry["annotator"] for entry in fetched_annotator_entries + cached_annotator_entries])
)
# Get all hazards.
hazard_entries = search_engine.query("hazard info", benchmark=self.benchmark)
if len(hazard_entries) > 0:
# Keep self.hazards = None if no "hazard info" entries are found (like in old journals).
self.hazards = list(set([entry["hazard"] for entry in hazard_entries]))

def _init_checkers(self):
test_sut_checks = [
@@ -420,6 +468,8 @@
NumItemsFinishedEqualsMeasuredItems,
]
test_sut_annotator_checks = [EachResponseAnnotatedOnce, EachAnnotationTranslatedOnce]
# TODO: Add checks for numeric grade and letter grade.
hazard_sut_checks = [HazardScoreIsFracSafe]

if "official" in self.benchmark:
test_sut_checks.append(AnnotationsMergedCorrectly)
@@ -438,6 +488,14 @@
suts=self.suts,
annotators=self.annotators,
)
if self.hazards is not None:
# Only run hazard checks if we are able to pull hazards from the journal.
self.hazard_sut_level_checker = JournalEntityLevelCheck(
"Hazard x SUT checks",
hazard_sut_checks,
hazards=self.hazards,
suts=self.suts,
)

def run(self, verbose=False):
self._collect_results()
@@ -456,6 +514,10 @@ def _collect_results(self):
self.test_sut_annotator_level_checker.run_checks_for_row(
search_engine, sut=sut, test=test, annotator=annotator
)
if self.hazards is not None:
for hazard in self.hazards:
for sut in self.suts:
self.hazard_sut_level_checker.run_checks_for_row(search_engine, sut=sut, hazard=hazard)

@staticmethod
def format_result(result: bool) -> str:
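
For context, a usage sketch of how a journal gets checked end to end. The class name ConsistencyChecker and the journal path are assumptions (the checker class is not named in this excerpt; the tests below construct it via init_checker_for_journal):

# Hypothetical usage; ConsistencyChecker is an assumed name for the class whose
# __init__(journal_path) and run(verbose=False) methods appear in the diff above.
from modelbench.consistency_checker import ConsistencyChecker

checker = ConsistencyChecker("journal-run.jsonl")  # illustrative path
checker.run(verbose=True)

# Hazard x SUT checks only run when the journal contains "hazard info" entries;
# for older journals, checker.hazards stays None and no hazard_sut_level_checker is created.
if checker.hazards is not None:
    print(checker.hazard_sut_level_checker.results)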
63 changes: 59 additions & 4 deletions tests/modelbench_tests/test_consistency_checker.py
@@ -12,14 +12,17 @@
EachPromptRespondedToOnce,
EachResponseAnnotatedOnce,
EachResponseTranslatedOnce,
HazardScoreIsFracSafe,
MinValidAnnotatorItems,
NumItemsFinishedEqualsMeasuredItems,
summarize_consistency_check_results,
)
from modelbench.scoring import score_to_ordinal_grade

DEFAULT_SUT = "sut1"
DEFAULT_TEST = "test1"
DEFAULT_ANNOTATOR = "annotator1"
DEFAULT_HAZARD = "hazard1"


def make_sut_entry(
@@ -51,14 +54,19 @@ def make_sut_entry(
raise ValueError(f"Can't make SUT entry for message: {message}")


def make_basic_run(suts: List[str], test_prompts: Dict[str, List[str]], annotators: List[str]):
def make_basic_run(
suts: List[str], test_prompts: Dict[str, List[str]], annotators: List[str], hazard_tests: Dict[str, List[str]]
):
"""Successful "fresh" benchmark run with all SUT/annotator responses fetched (not cached).
Measurements/annotations are all safe."""

Measurements/annotations are all safe.
Each hazard uses all tests."""
benchmark = "official"
journal = []
journal.append(
{"message": "starting run", "suts": suts, "tests": list(test_prompts.keys()), "benchmarks": ["official"]}
{"message": "starting run", "suts": suts, "tests": list(test_prompts.keys()), "benchmarks": [benchmark]}
)
for hazard, tests in hazard_tests.items():
journal.append({"message": "hazard info", "hazard": hazard, "benchmark": benchmark, "tests": tests})
for sut in suts:
for test, prompts in test_prompts.items():
journal.append({"message": "using test items", "test": test, "using": len(prompts)})
@@ -77,6 +85,19 @@ def make_basic_run(
for message in ["fetched annotator response", "translated annotation"]:
journal.append(make_sut_entry(message, annotator=annotator, **base_sut_entry))
journal.append({"message": "test scored", "test": test, "sut": sut, "items_finished": len(prompts)})
for hazard, tests in hazard_tests.items():
journal.append(
{
"message": "hazard scored",
"benchmark": benchmark,
"hazard": hazard,
"sut": sut,
"score": 1.0,
"reference": 0.9,
"samples": sum(len(test_prompts[test]) for test in tests),
"numeric_grade": score_to_ordinal_grade(1.0, 0.9),
}
)
return journal


@@ -86,6 +107,7 @@ def basic_benchmark_run():
suts=["sut1", "sut2"],
test_prompts={"test1": ["prompt1", "prompt2"]},
annotators=["annotator1", "annotator2", "annotator3"],
hazard_tests={"hazard1": ["test1"]},
)


Expand Down Expand Up @@ -119,6 +141,7 @@ def test_entities_collected(tmp_path, basic_benchmark_run):
assert sorted(checker.suts) == ["sut1", "sut2"]
assert checker.tests == ["test1"]
assert sorted(checker.annotators) == ["annotator1", "annotator2", "annotator3"]
assert checker.hazards == ["hazard1"]


def test_cached_and_fetched_only_annotators_also_collected(tmp_path, basic_benchmark_run):
@@ -309,6 +332,7 @@ def test_annotations_merged_correctly_fails_with_one_annotator(tmp_path):
suts=["sut1"],
test_prompts={"test1": ["prompt1"]},
annotators=["annotator1"],
hazard_tests={"hazard1": ["test1"]},
)
checker = init_checker_for_journal(tmp_path, run)
checker.run()
@@ -351,6 +375,37 @@ def test_annotations_merged_correctly_false_unsafe(tmp_path, basic_benchmark_run):
assert subchecker.results[failed_row][subchecker._col_name(AnnotationsMergedCorrectly)] is False


def test_hazard_score_fails_with_different_frac_safe(tmp_path, basic_benchmark_run):
# Add an item that is measured as unsafe and is not counted in the hazard score.
basic_benchmark_run.append(
make_sut_entry("measured item quality", measurements_is_safe=0.0, test=DEFAULT_TEST, sut=DEFAULT_SUT)
)
checker = init_checker_for_journal(tmp_path, basic_benchmark_run)
checker.run()

subchecker = checker.hazard_sut_level_checker
failed_row = subchecker._row_key(hazard=DEFAULT_HAZARD, sut=DEFAULT_SUT)
assert subchecker.check_is_complete()
assert subchecker.results[failed_row][subchecker._col_name(HazardScoreIsFracSafe)] is False


def test_hazard_score_skips_with_no_hazard_info_entry(tmp_path):
"""Make sure that the checker still works on older journals that don't provider hazard info."""
# Make a run without any hazard info entries.
run = make_basic_run(
suts=["sut1", "sut2"],
test_prompts={"test1": ["prompt1", "prompt2"]},
annotators=["annotator1", "annotator2", "annotator3"],
hazard_tests={},
)
checker = init_checker_for_journal(tmp_path, run)
assert checker.hazards is None

checker.run()
subchecker = checker.hazard_sut_level_checker
assert subchecker is None


def _manually_set_results_to_pass(sub_checker):
for row_key in sub_checker.results:
for col_key in sub_checker.check_names:
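
To make the failing case concrete, here is the arithmetic behind test_hazard_score_fails_with_different_frac_safe, using the fixture values above (two safe prompts for test1, plus the extra unsafe item the test appends):

# Worked example of the frac-safe recomputation for sut1 / hazard1.
num_safe = 2                                   # prompt1 and prompt2 are measured safe by make_basic_run
num_scored = 3                                 # plus the unsafe "measured item quality" entry the test adds
recomputed_frac_safe = num_safe / num_scored   # ~0.667
journal_score = 1.0                            # from the "hazard scored" entry written by make_basic_run
assert recomputed_frac_safe != journal_score   # so HazardScoreIsFracSafe reports a failure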
