Skip to content

Commit

Permalink
Some renaming and new tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adrien-berchet committed Nov 3, 2022
1 parent dccd9e8 commit 64bbe81
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 24 deletions.
2 changes: 1 addition & 1 deletion neurots/generate/algorithms/tmdgrower.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from neurots.generate.algorithms.common import section_data
from neurots.morphmath import sample
from neurots.morphmath.utils import norm
from neurots.preprocess.relevancy_checkers import check_min_bar_length
from neurots.preprocess.relevance_checkers import check_min_bar_length

L = logging.getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions neurots/preprocess/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from neurots.preprocess import relevancy_checkers # noqa
from neurots.preprocess import relevance_checkers # noqa
from neurots.preprocess import validity_checkers # noqa
from neurots.preprocess.utils import preprocess_inputs # noqa
from neurots.preprocess.utils import register_preprocess # noqa
from neurots.preprocess.utils import register_preprocessor # noqa
from neurots.preprocess.utils import register_validator # noqa
20 changes: 20 additions & 0 deletions neurots/preprocess/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Exceptions raised during validation process."""

# Copyright (C) 2022 Blue Brain Project, EPFL
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.


class NeuroTSValidationError(Exception):
"""Exception raised when a configuration set is not valid."""
5 changes: 2 additions & 3 deletions neurots/preprocess/preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,5 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

# TODO: remove the next line and the 'noqa' when a preprocess is added
# pylint: disable=unused-import
from neurots.preprocess.utils import register_preprocess # noqa
# TODO: uncomment the next line when a preprocess is added
# from neurots.preprocess.utils import register_preprocessor
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Functions to check that the given parameters and distributions will given relevant results.
"""Functions to check that the given parameters and distributions will give relevant results.
The functions used as relevancy checkers should have a name like 'check_*' and have the following
The functions used as relevance checkers should have a name like 'check_*' and have the following
signature: `check_something(params, distrs, start_point=None, context=None)`. The `start_point` and
`context` parameters should always be optional, as they will not be known during the preprocessing
step.
These functions can be called either in validity checkers or inside the grower codes.
"""

# Copyright (C) 2022 Blue Brain Project, EPFL
Expand Down
16 changes: 8 additions & 8 deletions neurots/preprocess/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@
from neurots.validator import validate_neuron_distribs
from neurots.validator import validate_neuron_params

_PREPROCESS_FUNCTIONS = {
"preprocess": defaultdict(set),
"validator": defaultdict(set),
_REGISTERED_FUNCTIONS = {
"preprocessors": defaultdict(set),
"validators": defaultdict(set),
}


def register_preprocess(*growth_methods):
def register_preprocessor(*growth_methods):
"""Register a preprocess function."""

def inner(func):
for i in growth_methods:
_PREPROCESS_FUNCTIONS["preprocess"][i].add(func)
_REGISTERED_FUNCTIONS["preprocessors"][i].add(func)
return func

return inner
Expand All @@ -44,7 +44,7 @@ def register_validator(*growth_methods):

def inner(func):
for i in growth_methods:
_PREPROCESS_FUNCTIONS["validator"][i].add(func)
_REGISTERED_FUNCTIONS["validators"][i].add(func)
return func

return inner
Expand All @@ -60,8 +60,8 @@ def preprocess_inputs(params, distrs):
for grow_type in params["grow_types"]:
growth_method = params[grow_type]["growth_method"]
for preprocess_func in chain(
_PREPROCESS_FUNCTIONS["validator"][growth_method],
_PREPROCESS_FUNCTIONS["preprocess"][growth_method],
_REGISTERED_FUNCTIONS["validators"][growth_method],
_REGISTERED_FUNCTIONS["preprocessors"][growth_method],
):
preprocess_func(params[grow_type], distrs[grow_type])

Expand Down
9 changes: 5 additions & 4 deletions neurots/preprocess/validity_checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from neurots.preprocess.relevancy_checkers import check_min_bar_length
from neurots.preprocess.exceptions import NeuroTSValidationError
from neurots.preprocess.relevance_checkers import check_min_bar_length
from neurots.preprocess.utils import register_validator


Expand All @@ -32,7 +33,7 @@ def check_num_seg(params, distrs):
"""Check that params contains a 'num_seg' entry."""
# pylint: disable=unused-argument
if "num_seg" not in params:
raise KeyError(
raise NeuroTSValidationError(
"The parameters must contain a 'num_seg' entry when the "
"'growth_method' entry in parameters is 'trunk'."
)
Expand All @@ -42,12 +43,12 @@ def check_num_seg(params, distrs):
def check_bar_length(params, distrs):
"""Check consistency between parameters and persistence diagram."""
if "min_bar_length" not in distrs:
raise KeyError(
raise NeuroTSValidationError(
"The distributions must contain a 'min_bar_length' entry when the "
"'growth_method' entry in parameters is in ['tmd', 'tmd_apical', 'tmd_gradient']."
)
if "mean" not in params.get("step_size", {}).get("norm", {}):
raise KeyError(
raise NeuroTSValidationError(
"The parameters must contain a 'step_size' entry when the "
"'growth_method' entry in parameters is in ['tmd', 'tmd_apical', 'tmd_gradient']."
)
Expand Down
69 changes: 65 additions & 4 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
"""Test the neurots.preprocess functions."""
# pylint: disable=redefined-outer-name
# pylint: disable=unused-argument
import json
import logging
from collections import defaultdict
from pathlib import Path

import pytest

from neurots import preprocess
from neurots.preprocess.exceptions import NeuroTSValidationError
from neurots.preprocess.utils import register_preprocessor
from neurots.preprocess.utils import register_validator
from neurots.utils import convert_from_legacy_neurite_type

DATA_PATH = Path(__file__).parent / "data"
Expand Down Expand Up @@ -88,7 +94,7 @@ def test_check_num_seg():
params = {}

with pytest.raises(
KeyError,
NeuroTSValidationError,
match=(
"The parameters must contain a 'num_seg' entry when the "
"'growth_method' entry in parameters is 'trunk'."
Expand All @@ -107,7 +113,7 @@ def test_check_min_bar_length(caplog):
"""

with pytest.raises(
KeyError,
NeuroTSValidationError,
match=(
r"The distributions must contain a 'min_bar_length' entry when the "
r"'growth_method' entry in parameters is in \['tmd', 'tmd_apical', 'tmd_gradient'\]\."
Expand All @@ -120,7 +126,7 @@ def test_check_min_bar_length(caplog):
}

with pytest.raises(
KeyError,
NeuroTSValidationError,
match=(
r"The parameters must contain a 'step_size' entry when the "
r"'growth_method' entry in parameters is in \['tmd', 'tmd_apical', 'tmd_gradient'\]\."
Expand All @@ -140,7 +146,7 @@ def test_check_min_bar_length(caplog):
preprocess.validity_checkers.check_bar_length(params, distrs)
assert caplog.record_tuples == [
(
"neurots.preprocess.relevancy_checkers",
"neurots.preprocess.relevance_checkers",
30,
"Selected step size 999.000000 is too big for bars of size 1.000000",
)
Expand All @@ -157,3 +163,58 @@ def test_check_min_bar_length(caplog):
with caplog.at_level(logging.DEBUG):
preprocess.validity_checkers.check_bar_length(params, distrs)
assert caplog.record_tuples == []


@pytest.fixture
def dummy_register(monkeypatch):
"""Monkeypatch the registered functions and reset it at the end of the test."""
monkeypatch.setattr(
preprocess.utils,
"_REGISTERED_FUNCTIONS",
{
"preprocessors": defaultdict(set),
"validators": defaultdict(set),
},
)


def test_register_validator(dummy_register):
"""Test validator registering."""
with (DATA_PATH / "axon_trunk_parameters.json").open(encoding="utf-8") as f:
params = convert_from_legacy_neurite_type(json.load(f))
with (DATA_PATH / "axon_trunk_distribution.json").open(encoding="utf-8") as f:
distrs = convert_from_legacy_neurite_type(json.load(f))

@register_validator("axon_trunk")
def dummy_validator(params, distrs):
assert params["randomness"] == 0
assert distrs["num_trees"]["data"]["bins"] == [1]

assert params["axon"]["randomness"] == 0
assert distrs["axon"]["num_trees"]["data"]["bins"] == [1]

preprocessed_params, preprocessed_distrs = preprocess.preprocess_inputs(params, distrs)

assert preprocessed_params == params
assert preprocessed_distrs == distrs


def test_register_preprocessor(dummy_register):
"""Test preprocessor registering."""
with (DATA_PATH / "axon_trunk_parameters.json").open(encoding="utf-8") as f:
params = convert_from_legacy_neurite_type(json.load(f))
with (DATA_PATH / "axon_trunk_distribution.json").open(encoding="utf-8") as f:
distrs = convert_from_legacy_neurite_type(json.load(f))

@register_preprocessor("axon_trunk")
def dummy_preprocessor(params, distrs):
params["randomness"] = 999
distrs["num_trees"]["data"]["bins"] = [999]

assert params["axon"]["randomness"] == 0
assert distrs["axon"]["num_trees"]["data"]["bins"] == [1]

preprocessed_params, preprocessed_distrs = preprocess.preprocess_inputs(params, distrs)

assert preprocessed_params["axon"]["randomness"] == 999
assert preprocessed_distrs["axon"]["num_trees"]["data"]["bins"] == [999]

0 comments on commit 64bbe81

Please sign in to comment.