From cfa492953ccb6dcba3890843877f2efc0d2c6a6b Mon Sep 17 00:00:00 2001 From: "Martindale, Nathan" Date: Thu, 16 Nov 2023 12:55:49 -0500 Subject: [PATCH] Add tests for debugging flags --- test/test_experiment.py | 70 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/test/test_experiment.py b/test/test_experiment.py index d81298f..84ee324 100644 --- a/test/test_experiment.py +++ b/test/test_experiment.py @@ -1,11 +1,14 @@ """These are integration and unit tests for the overall experiment calls.""" +import json import os +from test.examples.params import params1 import pytest from pytest_mock import mocker # noqa: F401 -- flake8 doesn't see it's used as fixture from curifactory.experiment import run_experiment +from curifactory.hashing import param_set_string_hash_representations from curifactory.manager import ArtifactManager # TODO: need to test that specifying no params will default to experiment_name @@ -573,3 +576,70 @@ def test_double_run_many_records_are_distinct_agg( configured_test_manager2.records[0].state["my_agg_output"] != configured_test_manager2.records[1].state["my_agg_output"] ) + + +def test_hashes_only_output(configured_test_manager, capsys): + """Running experiment with hashes only should just output the hashes and names of the + parameter sets passed into the experiment.""" + out, mngr = run_experiment( + "simple_cache", + ["simple_cache"], + param_set_names=["thing1", "thing2"], + hashes_only=True, + mngr=configured_test_manager, + ) + + stdout = capsys.readouterr().out + lines = stdout.split("\n") + last_param_set = mngr.param_file_param_sets["simple_cache"][1] + assert lines[-2] == f"{last_param_set[1]} {last_param_set[0]}" + + +@pytest.mark.skip +def test_print_params_output(configured_test_manager, capsys): + """Running experiment with print-params should just output the parameter set representations + of the parameter sets passed into the experiment.""" + out, mngr = run_experiment( + "simple_cache", + ["simple_cache"], + param_set_names=["thing1"], + print_params=True, + mngr=configured_test_manager, + ) + + stdout = capsys.readouterr().out + lines = stdout.split("\n") + last_param_set = mngr.param_file_param_sets["simple_cache"][0] + assert lines[7] == f"{last_param_set[1]} {last_param_set[0]}" + + +@pytest.mark.skip +def test_print_params_for_registry( + configured_test_manager, configured_test_manager2, capsys +): + """Running experiment with print-params should just output the parameter set representation from the params_registry if a value was passed.""" + with capsys.disabled(): + out, mngr = run_experiment("basic", ["params1"], mngr=configured_test_manager2) + param_set = mngr.records[-1].params + + # just get the first few letters of the query + hash_query = param_set.hash[:6] + + actual_params = params1.get_params()[0] + string_to_match = param_set_string_hash_representations(actual_params) + + out, mngr = run_experiment( + "simple_cache", + ["simple_cache"], + param_set_names=["thing1"], + print_params=hash_query, + mngr=configured_test_manager, + ) + + stdout = capsys.readouterr().out + lines = stdout.split("\n") + assert lines[7] == f"{param_set.hash} {param_set.name}" + + printed_param_set = json.loads("\n".join(lines[8:])) + del printed_param_set["_DRY_REPS"] + assert printed_param_set == string_to_match