diff --git a/mitosis/__init__.py b/mitosis/__init__.py index a8bf8bd..810ff47 100644 --- a/mitosis/__init__.py +++ b/mitosis/__init__.py @@ -157,7 +157,7 @@ def __init__( self.log_table = Table(table_name, md, *cols) url = "sqlite:///" + str(self.db) self.eng = create_engine(url) - with self.eng.connect() as conn: + with self.eng.begin() as conn: if not inspection.inspect(conn).has_table(table_name): md.create_all(conn) @@ -182,7 +182,7 @@ def emit(self, record: logging.LogRecord): stmt = stmt.values({col: vals[i + 1]}) else: raise ValueError("Cannot parse db message") - with self.eng.connect() as conn: + with self.eng.begin() as conn: conn.execute(stmt) def parse_record(self, msg: str) -> List[str]: @@ -220,21 +220,21 @@ def _verify_variant_name(trial_db: Path, step: str, param: Parameter) -> None: eng = create_engine("sqlite:///" + str(trial_db)) md = MetaData() tb = Table(f"{step}_variant_{param.arg_name}", md, *variant_types()) - vals: Collection[Any] + vals: Collection if isinstance(param.vals, Mapping): vals = StrictlyReproduceableDict({k: v for k, v in sorted(param.vals.items())}) elif isinstance(param.vals, Collection) and not isinstance(param.vals, str): try: vals = StrictlyReproduceableList(sorted(param.vals)) except (ValueError, TypeError): - vals = param.vals + vals = str(param.vals) else: - vals = param.vals + vals = str(param.vals) df = pd.read_sql(select(tb), eng) ind_equal = df.loc[:, "name"] == param.var_name if ind_equal.sum() == 0: stmt = tb.insert().values({"name": param.var_name, "params": str(vals)}) - with eng.connect() as conn: + with eng.begin() as conn: conn.execute(stmt) elif df.loc[ind_equal, "params"].iloc[0] != str(vals): raise RuntimeError( diff --git a/mitosis/_typing.py b/mitosis/_typing.py index 96c3609..4e0e224 100644 --- a/mitosis/_typing.py +++ b/mitosis/_typing.py @@ -1,23 +1,18 @@ -from abc import ABCMeta +from collections.abc import Mapping from dataclasses import dataclass from dataclasses import field -from types import ModuleType -from typing import Any from typing import Callable from typing import NamedTuple from typing import ParamSpec +from typing import TypedDict -P = ParamSpec("P") -ExpRun = Callable[P, dict] +class ExpResults(TypedDict): + main: object -class Experiment(ModuleType, metaclass=ABCMeta): - __name__: str - __file__: str - name: str - lookup_dict: dict[str, dict[str, Any]] - run: ExpRun +P = ParamSpec("P") +ExpRun = Callable[P, ExpResults] @dataclass @@ -33,7 +28,7 @@ class Parameter: var_name: str arg_name: str - vals: Any + vals: object # > 3.10 only: https://stackoverflow.com/a/49911616/534674 evaluate: bool = field(default=False, kw_only=True) @@ -42,7 +37,7 @@ class ExpStep(NamedTuple): name: str action: ExpRun action_ref: str - lookup: dict[str, Any] + lookup: Mapping[str, Mapping[str, object]] lookup_ref: str group: str | None args: list[Parameter] diff --git a/mitosis/tests/mock_paper.py b/mitosis/tests/mock_paper.py index 0fe0493..8d9f975 100644 --- a/mitosis/tests/mock_paper.py +++ b/mitosis/tests/mock_paper.py @@ -1,3 +1,10 @@ +from collections import defaultdict + data_config = {"length": {"test": 5}} meth_config = {"metric": {"test": "len"}} + +# lookup any parameter, any variant: always none +lookup_default: dict[str, dict[str, None]] = defaultdict( + lambda: defaultdict(lambda: None) +) diff --git a/mitosis/tests/mock_part1.py b/mitosis/tests/mock_part1.py index 916cb17..a8a5328 100644 --- a/mitosis/tests/mock_part1.py +++ b/mitosis/tests/mock_part1.py @@ -1,16 +1,23 @@ from logging import getLogger import numpy as np -from numpy.typing import NBitBase -from numpy.typing import NDArray + +from mitosis._typing import ExpResults class Klass: @staticmethod - def gen_data( - length: int, extra: bool = False - ) -> dict[str, NDArray[np.floating[NBitBase]] | bool]: + def gen_data(length: int, extra: bool = False) -> ExpResults: getLogger(__name__).info("This is run every time") getLogger(__name__).debug("This is run in debug mode only") - return {"data": np.ones(length, dtype=np.float_), "extra": extra} + return { + "data": np.ones(length, dtype=np.float_), + "extra": extra, + "main": None, + } # type: ignore + + +def do_nothing(*args, **kwargs) -> ExpResults: + """An experiment step that accepts anything and produces nothing""" + return {"main": None} diff --git a/mitosis/tests/mock_part2.py b/mitosis/tests/mock_part2.py index 8cd3718..ff0d39e 100644 --- a/mitosis/tests/mock_part2.py +++ b/mitosis/tests/mock_part2.py @@ -5,15 +5,17 @@ from numpy.typing import NBitBase from numpy.typing import NDArray +from mitosis._typing import ExpResults + def fit_and_score( data: NDArray[np.floating[NBitBase]], metric: Literal["len"] | Literal["zero"] -) -> dict[str, float]: +) -> ExpResults: if metric == "len": return {"main": len(data)} elif metric == "zero": return {"main": 0} -def bad_runnable(*args: Any, **kwargs: Any): +def bad_runnable(*args: Any, **kwargs: Any) -> int: return 1 # not a dict with key "main" diff --git a/mitosis/tests/test_all.py b/mitosis/tests/test_all.py index aa5bb2a..a000393 100644 --- a/mitosis/tests/test_all.py +++ b/mitosis/tests/test_all.py @@ -148,6 +148,33 @@ def test_mock_experiment(mock_steps, tmp_path): assert (metadata / "experiment").resolve().exists() +@pytest.fixture +def nothing_step(): + # fmt: off + return ExpStep( + "nothing", + mock_part1.do_nothing, "mitosis.tests.mock_part1:do_nothing", + mock_paper.lookup_default, "mitosis.tests.mock_paper:lookup_default", + None, + [], + [] + ) + # fmt: on + + +@pytest.mark.clean +def test_variant_redefinition_disallowed(nothing_step, tmp_path): + # GH 56 + chg_param1 = Parameter("foo_a", "foo", "a", evaluate=False) + chg_param2 = Parameter("foo_a", "foo", "b", evaluate=False) + nothing_step.args.append(chg_param1) + mitosis.run([nothing_step], trials_folder=tmp_path) + nothing_step.args.pop(0) + nothing_step.args.append(chg_param2) + with pytest.raises(RuntimeError, match="stored with different values"): + mitosis.run([nothing_step], trials_folder=tmp_path) + + def test_load_results_order(tmp_path): exp_key = "test_results" (tmp_path / exp_key).mkdir() @@ -224,17 +251,11 @@ def test_malfored_return_experiment(mock_steps, tmp_path): def test_load_toml(): parent = Path(__file__).resolve().parent tomlfile = parent / "test_pyproject.toml" - result = _disk.load_mitosis_steps(tomlfile) - expected = { - "data": ( - "mitosis.tests.mock_part1:Klass.gen_data", - "mitosis.tests.mock_paper:data_config", - ), - "fit_eval": ( - "mitosis.tests.mock_part2:fit_and_score", - "mitosis.tests.mock_paper:meth_config", - ), - } + result = _disk.load_mitosis_steps(tomlfile)["nothing"] + expected = ( + "mitosis.tests.mock_part1:do_nothing", + "mitosis.tests.mock_paper:lookup_default", + ) assert result == expected diff --git a/mitosis/tests/test_pyproject.toml b/mitosis/tests/test_pyproject.toml index 3cbb908..3544bde 100644 --- a/mitosis/tests/test_pyproject.toml +++ b/mitosis/tests/test_pyproject.toml @@ -1,4 +1,7 @@ [tool.mitosis.steps] +nothing = [ + "mitosis.tests.mock_part1:do_nothing", + "mitosis.tests.mock_paper:lookup_default"] data = [ "mitosis.tests.mock_part1:Klass.gen_data", "mitosis.tests.mock_paper:data_config" diff --git a/pyproject.toml b/pyproject.toml index 4f43008..517167f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ dependencies = [ "nbclient", "nbformat", "pandas<2.2", - "sqlalchemy", + "sqlalchemy>=1.4", "toml", "types-toml", ]