Skip to content

Commit

Permalink
remove token requirement to download demo prompt set; make unit tests…
Browse files Browse the repository at this point in the history
… use the demo prompt set rather than a practice prompt set gated behind a token (#861)
  • Loading branch information
rogthefrog authored Feb 12, 2025
1 parent 42e869c commit 8ba3e6b
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 33 deletions.
6 changes: 3 additions & 3 deletions plugins/validation_tests/test_object_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
from modelgauge.sut import PromptResponseSUT, SUTResponse
from modelgauge.sut_capabilities import AcceptsTextPrompt
from modelgauge.sut_registry import SUTS
from modelgauge.test_registry import TESTS
from modelgauge.tests.safe_v1 import BaseSafeTestVersion1 # see "workaround" below

from modelgauge.suts.huggingface_chat_completion import HUGGING_FACE_TIMEOUT
from modelgauge.test_registry import TESTS
from modelgauge.tests.safe_v1 import BaseSafeTestVersion1 # see "workaround" below
from modelgauge_tests.fake_secrets import fake_all_secrets
from modelgauge_tests.utilities import expensive_tests

# Ensure all the plugins are available during testing.
load_plugins()
# All secrets are faked here: tests use the demo prompt set, which can be downloaded without a real modellab auth token.
_FAKE_SECRETS = fake_all_secrets(use_real_secrets_for=("modellab_files",))
_FAKE_SECRETS = fake_all_secrets()


@pytest.mark.parametrize("test_name", [key for key, _ in TESTS.items()])
Expand Down
3 changes: 2 additions & 1 deletion src/modelgauge/instance_factory.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import inspect
import threading
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Sequence, Tuple, Type, TypeVar

from modelgauge.dependency_injection import inject_dependencies
from modelgauge.secret_values import MissingSecretValues, RawSecrets
from modelgauge.tracked_object import TrackedObject
from typing import Any, Dict, Generic, List, Sequence, Tuple, Type, TypeVar

_T = TypeVar("_T", bound=TrackedObject)

Expand Down
15 changes: 13 additions & 2 deletions src/modelgauge/prompt_sets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Optional

from modelgauge.locales import EN_US
from modelgauge.secret_values import RequiredSecret, SecretDescription
from modelgauge.secret_values import OptionalSecret, SecretDescription


class ModellabFileDownloadToken(RequiredSecret):
class ModellabFileDownloadToken(OptionalSecret):
@classmethod
def description(cls) -> SecretDescription:
return SecretDescription(
Expand Down Expand Up @@ -62,3 +64,12 @@ def validate_prompt_set(prompt_set: str, locale: str = EN_US, prompt_sets: dict
def prompt_set_to_filename(prompt_set: str) -> str:
    """Translate a prompt set name into the name used by its file.

    Files for the official, secret prompt set are named with "heldback"
    rather than "official", so every occurrence of "official" in the name
    is swapped for "heldback". Names without "official" come back unchanged.
    """
    official_marker, file_marker = "official", "heldback"
    return prompt_set.replace(official_marker, file_marker)


def validate_token_requirement(prompt_set: str, token=None) -> bool:
    """Check that a token was supplied when the prompt set calls for one.

    Only the presence (truthiness) of the token is checked; the token itself
    is not validated. The "demo" prompt set never needs a token.

    Returns:
        True when the prompt set is "demo" or a truthy token was given.

    Raises:
        ValueError: for any other prompt set when the token is missing/falsy.
    """
    # Guard clause: the only failing case is a non-demo set with no token.
    if prompt_set != "demo" and not token:
        raise ValueError(f"Prompt set {prompt_set} requires a token from MLCommons.")
    return True
2 changes: 2 additions & 0 deletions src/modelgauge/tests/safe_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
PROMPT_SETS,
ModellabFileDownloadToken,
prompt_set_file_base_name,
validate_token_requirement,
validate_prompt_set,
)
from modelgauge.secret_values import InjectSecret
Expand Down Expand Up @@ -111,6 +112,7 @@ def __init__(
persona_types
), f"Must specify a unique set of persona types, but got {persona_types}"
validate_prompt_set(prompt_set, locale)
validate_token_requirement(prompt_set, token)
validate_locale(locale)

self.hazard = hazard
Expand Down
3 changes: 0 additions & 3 deletions tests/config/secrets.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,3 @@ api_key = "fake key"

# [perspective_api]
# api_key = "<your key here>"

[modellab_files]
token = "fake token"
11 changes: 2 additions & 9 deletions tests/modelgauge_tests/fake_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,13 @@ def description(cls) -> SecretDescription:
return SecretDescription(scope="some-scope", key="some-key", instructions="some-instructions")


def fake_all_secrets(value="some-value", use_real_secrets_for: list[str] | None = None) -> RawSecrets:
def fake_all_secrets(value="some-value") -> RawSecrets:
secrets = get_all_secrets()
raw_secrets: Dict[str, Dict[str, str]] = {}
if use_real_secrets_for:
real_secrets = load_secrets_from_config()
else:
real_secrets = {}

for secret in secrets:
if secret.scope not in raw_secrets:
raw_secrets[secret.scope] = {}
if use_real_secrets_for and secret.scope in use_real_secrets_for:
raw_secrets[secret.scope][secret.key] = real_secrets[secret.scope][secret.key]
else:
raw_secrets[secret.scope][secret.key] = value
raw_secrets[secret.scope][secret.key] = value

return raw_secrets
27 changes: 12 additions & 15 deletions tests/modelgauge_tests/test_safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
FAKE_TOGETHER_KEY = TogetherApiKey("some-value")


def _init_safe_test_v1(hazard, persona_types, prompt_set="practice"):
def _init_safe_test_v1(hazard, persona_types, prompt_set="demo"):
return SafeTestVersion1("uid", hazard, EN_US, persona_types, prompt_set)


def _init_safe_test_v1_private(hazard, persona_types, prompt_set="practice"):
def _init_safe_test_v1_private(hazard, persona_types, prompt_set="demo"):
# TODO: Mock the private annotators
try:
return SafeTestVersion1("uid", hazard, EN_US, persona_types, prompt_set)
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_make_test_items_skips_out_of_scope_prompts(self, tmpdir, safe_test):
["3", "a", self.hazard, FR_FR, "normal"],
["4", "a", self.hazard, FR_FR, "skilled"],
]
_pps = PROMPT_SETS["practice"][EN_US]
_pps = PROMPT_SETS["demo"][EN_US]
dependencies = {_pps: make_csv(csv_header, rows)}
dependency_helper = FakeDependencyHelper(tmpdir, dependencies)

Expand All @@ -141,7 +141,7 @@ def test_make_test_items_skips_bad_hazards(self, tmpdir, safe_test):
["1", "a", self.hazard, EN_US, SafePersonasVersion1.NORMAL],
["2", "a", "wrong", EN_US, SafePersonasVersion1.NORMAL],
]
pps = PROMPT_SETS["practice"][EN_US]
pps = PROMPT_SETS["demo"][EN_US]
dependencies = {pps: make_csv(csv_header, rows)}
dependency_helper = FakeDependencyHelper(tmpdir, dependencies)

Expand Down Expand Up @@ -178,7 +178,7 @@ def test_error_instantiating_undefined_annotators(self):
self.hazard,
EN_US,
[SafePersonasVersion1.NORMAL],
"practice",
"demo",
)

def test_error_getting_undefined_annotators(self):
Expand All @@ -197,20 +197,17 @@ def test_registered_tests(self):

def test_uid_construction(self):
# locale/prompt set variants
assert BaseSafeTestVersion1.create_uid("xyz", EN_US, "practice") == "safe-xyz-en_us-practice-1.0"
assert BaseSafeTestVersion1.create_uid("xyz", EN_US, "demo") == "safe-xyz-en_us-demo-1.0"
assert BaseSafeTestVersion1.create_uid("xyz", EN_US, "official") == "safe-xyz-en_us-official-1.0"

# evaluator variants
assert BaseSafeTestVersion1.create_uid("xyz", EN_US, "practice", "default") == "safe-xyz-en_us-practice-1.0"
assert (
BaseSafeTestVersion1.create_uid("xyz", EN_US, "practice", "ensemble")
== "safe-xyz-en_us-practice-1.0-ensemble"
)
assert BaseSafeTestVersion1.create_uid("xyz", EN_US, "demo", "default") == "safe-xyz-en_us-demo-1.0"
assert BaseSafeTestVersion1.create_uid("xyz", EN_US, "demo", "ensemble") == "safe-xyz-en_us-demo-1.0-ensemble"

@pytest.mark.parametrize("prompt_set", PROMPT_SETS.keys())
def test_correct_prompt_set_dependency(self, prompt_set):
practice_test = _init_safe_test_v1(self.hazard, "normal", prompt_set=prompt_set)
dependencies = practice_test.get_dependencies()
def test_correct_prompt_set_dependency(self):
prompt_set = "demo" # using demo because it doesn't require a token to download
demo_test = _init_safe_test_v1(self.hazard, "normal", prompt_set=prompt_set)
dependencies = demo_test.get_dependencies()

assert len(dependencies) == 1

Expand Down

0 comments on commit 8ba3e6b

Please sign in to comment.