Skip to content

Commit

Permalink
renaming omnimath to omni_math
Browse files Browse the repository at this point in the history
  • Loading branch information
liamjxu committed Dec 16, 2024
1 parent 7b2d7b6 commit a3ab0a6
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 18 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ recursive-include src/helm/benchmark/ *.json
recursive-include src/helm/benchmark/static/ *.css *.html *.js *.png *.yaml
recursive-include src/helm/benchmark/static_build/ *.css *.html *.js *.png *.yaml
recursive-include src/helm/config/ *.yaml
- recursive-include src/helm/benchmark/annotation/omnimath/ *.txt
+ recursive-include src/helm/benchmark/annotation/omni_math/ *.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ def parse_report(report):


class OmniMATHAnnotator(Annotator):
- """The OmniMATH autograder."""
+ """The Omni-MATH autograder."""

- name = "omnimath"
+ name = "omni_math"

def __init__(self, auto_client: AutoClient):
self._auto_client = auto_client
- template_path = files("src.helm.benchmark.annotation.omnimath").joinpath("gpt_evaluation_template.txt")
+ template_path = files("src.helm.benchmark.annotation.omni_math").joinpath("gpt_evaluation_template.txt")
with template_path.open("r") as file:
self._score_template = file.read()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class OmniMATHMetric(Metric):
- """Score metrics for OmniMATH."""
+ """Score metrics for Omni-MATH."""

def evaluate_generation(
self,
Expand All @@ -19,7 +19,7 @@ def evaluate_generation(
eval_cache_path: str,
) -> List[Stat]:
assert request_state.annotations
- score = request_state.annotations["omnimath"]["correctness"]
+ score = request_state.annotations["omni_math"]["correctness"]
return [
- Stat(MetricName("omnimath_accuracy")).add(score),
+ Stat(MetricName("omni_math_accuracy")).add(score),
]
14 changes: 7 additions & 7 deletions src/helm/benchmark/run_specs/lite_run_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,24 +513,24 @@ def get_bigcodebench_spec(version: str) -> RunSpec:
)


- @run_spec_function("omnimath")
- def get_omnimath_spec() -> RunSpec:
+ @run_spec_function("omni_math")
+ def get_omni_math_spec() -> RunSpec:

scenario_spec = ScenarioSpec(
- class_name="helm.benchmark.scenarios.omnimath_scenario.OmniMATHScenario"
+ class_name="helm.benchmark.scenarios.omni_math_scenario.OmniMATHScenario"
)

adapter_spec = AdapterSpec(
method=ADAPT_GENERATION, input_prefix="", output_prefix="", max_tokens=1000, num_outputs=1, temperature=0.0,
)
- annotator_specs = [AnnotatorSpec(class_name="helm.benchmark.annotation.omnimath_annotator.OmniMATHAnnotator")]
- metric_specs = get_basic_metric_specs([]) + [MetricSpec(class_name="helm.benchmark.metrics.omnimath_metrics.OmniMATHMetric")]
+ annotator_specs = [AnnotatorSpec(class_name="helm.benchmark.annotation.omni_math_annotator.OmniMATHAnnotator")]
+ metric_specs = get_basic_metric_specs([]) + [MetricSpec(class_name="helm.benchmark.metrics.omni_math_metrics.OmniMATHMetric")]

return RunSpec(
- name="omnimath",
+ name="omni_math",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
annotators=annotator_specs,
metric_specs=metric_specs,
- groups=["omnimath"],
+ groups=["omni_math"],
)
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import pytest
from tempfile import TemporaryDirectory

- from helm.benchmark.scenarios.omnimath_scenario import OmniMATHScenario
+ from helm.benchmark.scenarios.omni_math_scenario import OmniMATHScenario
from helm.benchmark.scenarios.scenario import Input, TEST_SPLIT


@pytest.mark.scenarios
- def test_omnimath_scenario_get_instances():
- omnimath_scenario = OmniMATHScenario()
+ def test_omni_math_scenario_get_instances():
+ omni_math_scenario = OmniMATHScenario()
with TemporaryDirectory() as tmpdir:
- instances = omnimath_scenario.get_instances(tmpdir)
+ instances = omni_math_scenario.get_instances(tmpdir)
assert len(instances) == 4428
assert instances[0].input == Input(
text=(
Expand Down

0 comments on commit a3ab0a6

Please sign in to comment.