From 64bbe817a2978e8098a76c70242ff6e81e47cb41 Mon Sep 17 00:00:00 2001 From: Adrien Berchet Date: Thu, 3 Nov 2022 12:37:55 +0100 Subject: [PATCH] Some renaming and new tests --- neurots/generate/algorithms/tmdgrower.py | 2 +- neurots/preprocess/__init__.py | 4 +- neurots/preprocess/exceptions.py | 20 ++++++ neurots/preprocess/preprocessors.py | 5 +- ...ancy_checkers.py => relevance_checkers.py} | 6 +- neurots/preprocess/utils.py | 16 ++--- neurots/preprocess/validity_checkers.py | 9 +-- tests/test_preprocess.py | 69 +++++++++++++++++-- 8 files changed, 107 insertions(+), 24 deletions(-) create mode 100644 neurots/preprocess/exceptions.py rename neurots/preprocess/{relevancy_checkers.py => relevance_checkers.py} (89%) diff --git a/neurots/generate/algorithms/tmdgrower.py b/neurots/generate/algorithms/tmdgrower.py index 68f771c6..41c9bf13 100644 --- a/neurots/generate/algorithms/tmdgrower.py +++ b/neurots/generate/algorithms/tmdgrower.py @@ -27,7 +27,7 @@ from neurots.generate.algorithms.common import section_data from neurots.morphmath import sample from neurots.morphmath.utils import norm -from neurots.preprocess.relevancy_checkers import check_min_bar_length +from neurots.preprocess.relevance_checkers import check_min_bar_length L = logging.getLogger(__name__) diff --git a/neurots/preprocess/__init__.py b/neurots/preprocess/__init__.py index 79e67028..8919e6c9 100644 --- a/neurots/preprocess/__init__.py +++ b/neurots/preprocess/__init__.py @@ -15,8 +15,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from neurots.preprocess import relevancy_checkers # noqa +from neurots.preprocess import relevance_checkers # noqa from neurots.preprocess import validity_checkers # noqa from neurots.preprocess.utils import preprocess_inputs # noqa -from neurots.preprocess.utils import register_preprocess # noqa +from neurots.preprocess.utils import register_preprocessor # noqa from neurots.preprocess.utils import register_validator # noqa diff --git a/neurots/preprocess/exceptions.py b/neurots/preprocess/exceptions.py new file mode 100644 index 00000000..d9da5c8b --- /dev/null +++ b/neurots/preprocess/exceptions.py @@ -0,0 +1,20 @@ +"""Exceptions raised during validation process.""" + +# Copyright (C) 2022 Blue Brain Project, EPFL +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +class NeuroTSValidationError(Exception): + """Exception raised when a configuration set is not valid.""" diff --git a/neurots/preprocess/preprocessors.py b/neurots/preprocess/preprocessors.py index c473f6d0..bc4a8cb3 100644 --- a/neurots/preprocess/preprocessors.py +++ b/neurots/preprocess/preprocessors.py @@ -23,6 +23,5 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -# TODO: remove the next line and the 'noqa' when a preprocess is added -# pylint: disable=unused-import -from neurots.preprocess.utils import register_preprocess # noqa +# TODO: uncomment the next line when a preprocess is added +# from neurots.preprocess.utils import register_preprocessor diff --git a/neurots/preprocess/relevancy_checkers.py b/neurots/preprocess/relevance_checkers.py similarity index 89% rename from neurots/preprocess/relevancy_checkers.py rename to neurots/preprocess/relevance_checkers.py index 1be20e2b..519a26b5 100644 --- a/neurots/preprocess/relevancy_checkers.py +++ b/neurots/preprocess/relevance_checkers.py @@ -1,9 +1,11 @@ -"""Functions to check that the given parameters and distributions will given relevant results. +"""Functions to check that the given parameters and distributions will give relevant results. -The functions used as relevancy checkers should have a name like 'check_*' and have the following +The functions used as relevance checkers should have a name like 'check_*' and have the following signature: `check_something(params, distrs, start_point=None, context=None)`. The `start_point` and `context` parameters should always be optional, as they will not be known during the preprocessing step. + +These functions can be called either in validity checkers or inside the grower codes. """ # Copyright (C) 2022 Blue Brain Project, EPFL diff --git a/neurots/preprocess/utils.py b/neurots/preprocess/utils.py index 3ed7cebc..5dc38cb6 100644 --- a/neurots/preprocess/utils.py +++ b/neurots/preprocess/utils.py @@ -22,18 +22,18 @@ from neurots.validator import validate_neuron_distribs from neurots.validator import validate_neuron_params -_PREPROCESS_FUNCTIONS = { - "preprocess": defaultdict(set), - "validator": defaultdict(set), +_REGISTERED_FUNCTIONS = { + "preprocessors": defaultdict(set), + "validators": defaultdict(set), } -def register_preprocess(*growth_methods): +def register_preprocessor(*growth_methods): """Register a preprocess function.""" def inner(func): for i in growth_methods: - _PREPROCESS_FUNCTIONS["preprocess"][i].add(func) + _REGISTERED_FUNCTIONS["preprocessors"][i].add(func) return func return inner @@ -44,7 +44,7 @@ def register_validator(*growth_methods): def inner(func): for i in growth_methods: - _PREPROCESS_FUNCTIONS["validator"][i].add(func) + _REGISTERED_FUNCTIONS["validators"][i].add(func) return func return inner @@ -60,8 +60,8 @@ def preprocess_inputs(params, distrs): for grow_type in params["grow_types"]: growth_method = params[grow_type]["growth_method"] for preprocess_func in chain( - _PREPROCESS_FUNCTIONS["validator"][growth_method], - _PREPROCESS_FUNCTIONS["preprocess"][growth_method], + _REGISTERED_FUNCTIONS["validators"][growth_method], + _REGISTERED_FUNCTIONS["preprocessors"][growth_method], ): preprocess_func(params[grow_type], distrs[grow_type]) diff --git a/neurots/preprocess/validity_checkers.py b/neurots/preprocess/validity_checkers.py index 87bc233a..b2928a88 100644 --- a/neurots/preprocess/validity_checkers.py +++ b/neurots/preprocess/validity_checkers.py @@ -23,7 +23,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from neurots.preprocess.relevancy_checkers import check_min_bar_length +from neurots.preprocess.exceptions import NeuroTSValidationError +from neurots.preprocess.relevance_checkers import check_min_bar_length from neurots.preprocess.utils import register_validator @@ -32,7 +33,7 @@ def check_num_seg(params, distrs): """Check that params contains a 'num_seg' entry.""" # pylint: disable=unused-argument if "num_seg" not in params: - raise KeyError( + raise NeuroTSValidationError( "The parameters must contain a 'num_seg' entry when the " "'growth_method' entry in parameters is 'trunk'." ) @@ -42,12 +43,12 @@ def check_num_seg(params, distrs): def check_bar_length(params, distrs): """Check consistency between parameters and persistence diagram.""" if "min_bar_length" not in distrs: - raise KeyError( + raise NeuroTSValidationError( "The distributions must contain a 'min_bar_length' entry when the " "'growth_method' entry in parameters is in ['tmd', 'tmd_apical', 'tmd_gradient']." ) if "mean" not in params.get("step_size", {}).get("norm", {}): - raise KeyError( + raise NeuroTSValidationError( "The parameters must contain a 'step_size' entry when the " "'growth_method' entry in parameters is in ['tmd', 'tmd_apical', 'tmd_gradient']." ) diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index 751378f7..ab09a81d 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -1,11 +1,17 @@ """Test the neurots.preprocess functions.""" +# pylint: disable=redefined-outer-name +# pylint: disable=unused-argument import json import logging +from collections import defaultdict from pathlib import Path import pytest from neurots import preprocess +from neurots.preprocess.exceptions import NeuroTSValidationError +from neurots.preprocess.utils import register_preprocessor +from neurots.preprocess.utils import register_validator from neurots.utils import convert_from_legacy_neurite_type DATA_PATH = Path(__file__).parent / "data" @@ -88,7 +94,7 @@ def test_check_num_seg(): params = {} with pytest.raises( - KeyError, + NeuroTSValidationError, match=( "The parameters must contain a 'num_seg' entry when the " "'growth_method' entry in parameters is 'trunk'." @@ -107,7 +113,7 @@ def test_check_min_bar_length(caplog): """ with pytest.raises( - KeyError, + NeuroTSValidationError, match=( r"The distributions must contain a 'min_bar_length' entry when the " r"'growth_method' entry in parameters is in \['tmd', 'tmd_apical', 'tmd_gradient'\]\." @@ -120,7 +126,7 @@ def test_check_min_bar_length(caplog): } with pytest.raises( - KeyError, + NeuroTSValidationError, match=( r"The parameters must contain a 'step_size' entry when the " r"'growth_method' entry in parameters is in \['tmd', 'tmd_apical', 'tmd_gradient'\]\." @@ -140,7 +146,7 @@ def test_check_min_bar_length(caplog): preprocess.validity_checkers.check_bar_length(params, distrs) assert caplog.record_tuples == [ ( - "neurots.preprocess.relevancy_checkers", + "neurots.preprocess.relevance_checkers", 30, "Selected step size 999.000000 is too big for bars of size 1.000000", ) @@ -157,3 +163,58 @@ def test_check_min_bar_length(caplog): with caplog.at_level(logging.DEBUG): preprocess.validity_checkers.check_bar_length(params, distrs) assert caplog.record_tuples == [] + + +@pytest.fixture +def dummy_register(monkeypatch): + """Monkeypatch the registered functions and reset it at the end of the test.""" + monkeypatch.setattr( + preprocess.utils, + "_REGISTERED_FUNCTIONS", + { + "preprocessors": defaultdict(set), + "validators": defaultdict(set), + }, + ) + + +def test_register_validator(dummy_register): + """Test validator registering.""" + with (DATA_PATH / "axon_trunk_parameters.json").open(encoding="utf-8") as f: + params = convert_from_legacy_neurite_type(json.load(f)) + with (DATA_PATH / "axon_trunk_distribution.json").open(encoding="utf-8") as f: + distrs = convert_from_legacy_neurite_type(json.load(f)) + + @register_validator("axon_trunk") + def dummy_validator(params, distrs): + assert params["randomness"] == 0 + assert distrs["num_trees"]["data"]["bins"] == [1] + + assert params["axon"]["randomness"] == 0 + assert distrs["axon"]["num_trees"]["data"]["bins"] == [1] + + preprocessed_params, preprocessed_distrs = preprocess.preprocess_inputs(params, distrs) + + assert preprocessed_params == params + assert preprocessed_distrs == distrs + + +def test_register_preprocessor(dummy_register): + """Test preprocessor registering.""" + with (DATA_PATH / "axon_trunk_parameters.json").open(encoding="utf-8") as f: + params = convert_from_legacy_neurite_type(json.load(f)) + with (DATA_PATH / "axon_trunk_distribution.json").open(encoding="utf-8") as f: + distrs = convert_from_legacy_neurite_type(json.load(f)) + + @register_preprocessor("axon_trunk") + def dummy_preprocessor(params, distrs): + params["randomness"] = 999 + distrs["num_trees"]["data"]["bins"] = [999] + + assert params["axon"]["randomness"] == 0 + assert distrs["axon"]["num_trees"]["data"]["bins"] == [1] + + preprocessed_params, preprocessed_distrs = preprocess.preprocess_inputs(params, distrs) + + assert preprocessed_params["axon"]["randomness"] == 999 + assert preprocessed_distrs["axon"]["num_trees"]["data"]["bins"] == [999]