Skip to content

Commit

Permalink
Add tests for debugging flags
Browse files Browse the repository at this point in the history
  • Loading branch information
WarmCyan committed Nov 16, 2023
1 parent 4897f57 commit cfa4929
Showing 1 changed file with 70 additions and 0 deletions.
70 changes: 70 additions & 0 deletions test/test_experiment.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""These are integration and unit tests for the overall experiment calls."""

import json
import os
from test.examples.params import params1

import pytest
from pytest_mock import mocker # noqa: F401 -- flake8 doesn't see it's used as fixture

from curifactory.experiment import run_experiment
from curifactory.hashing import param_set_string_hash_representations
from curifactory.manager import ArtifactManager

# TODO: need to test that specifying no params will default to experiment_name
Expand Down Expand Up @@ -573,3 +576,70 @@ def test_double_run_many_records_are_distinct_agg(
configured_test_manager2.records[0].state["my_agg_output"]
!= configured_test_manager2.records[1].state["my_agg_output"]
)


def test_hashes_only_output(configured_test_manager, capsys):
"""Running experiment with hashes only should just output the hashes and names of the
parameter sets passed into the experiment."""
out, mngr = run_experiment(
"simple_cache",
["simple_cache"],
param_set_names=["thing1", "thing2"],
hashes_only=True,
mngr=configured_test_manager,
)

stdout = capsys.readouterr().out
lines = stdout.split("\n")
last_param_set = mngr.param_file_param_sets["simple_cache"][1]
assert lines[-2] == f"{last_param_set[1]} {last_param_set[0]}"


@pytest.mark.skip
def test_print_params_output(configured_test_manager, capsys):
"""Running experiment with print-params should just output the parameter set representations
of the parameter sets passed into the experiment."""
out, mngr = run_experiment(
"simple_cache",
["simple_cache"],
param_set_names=["thing1"],
print_params=True,
mngr=configured_test_manager,
)

stdout = capsys.readouterr().out
lines = stdout.split("\n")
last_param_set = mngr.param_file_param_sets["simple_cache"][0]
assert lines[7] == f"{last_param_set[1]} {last_param_set[0]}"


@pytest.mark.skip
def test_print_params_for_registry(
configured_test_manager, configured_test_manager2, capsys
):
"""Running experiment with print-params should just output the parameter set representation from the params_registry if a value was passed."""
with capsys.disabled():
out, mngr = run_experiment("basic", ["params1"], mngr=configured_test_manager2)
param_set = mngr.records[-1].params

# just get the first few letters of the query
hash_query = param_set.hash[:6]

actual_params = params1.get_params()[0]
string_to_match = param_set_string_hash_representations(actual_params)

out, mngr = run_experiment(
"simple_cache",
["simple_cache"],
param_set_names=["thing1"],
print_params=hash_query,
mngr=configured_test_manager,
)

stdout = capsys.readouterr().out
lines = stdout.split("\n")
assert lines[7] == f"{param_set.hash} {param_set.name}"

printed_param_set = json.loads("\n".join(lines[8:]))
del printed_param_set["_DRY_REPS"]
assert printed_param_set == string_to_match

0 comments on commit cfa4929

Please sign in to comment.