Skip to content

Commit

Permalink
Rework input validation process
Browse files Browse the repository at this point in the history
  • Loading branch information
adrien-berchet committed Oct 4, 2022
1 parent a17e5a0 commit 7dd592e
Show file tree
Hide file tree
Showing 10 changed files with 166 additions and 35 deletions.
19 changes: 19 additions & 0 deletions neurots/generate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,22 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from neurots.generate.algorithms import basicgrower
from neurots.generate.algorithms import tmdgrower
from neurots.generate.section import SectionGrower
from neurots.generate.section import SectionGrowerPath
from neurots.generate.section import SectionGrowerTMD

growth_algorithms = {
"tmd": tmdgrower.TMDAlgo,
"tmd_apical": tmdgrower.TMDApicalAlgo,
"tmd_gradient": tmdgrower.TMDGradientAlgo,
"axon_trunk": basicgrower.AxonAlgo,
"trunk": basicgrower.TrunkAlgo,
}

section_growers = {
"radial_distances": SectionGrowerTMD,
"path_distances": SectionGrowerPath,
"trunk_length": SectionGrower,
}
20 changes: 19 additions & 1 deletion neurots/generate/algorithms/abstractgrower.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@

import abc
import copy
from enum import Enum


class SkipValidationLevel(Enum):
SKIP_NONE = 0
SKIP_OPTIONAL_ONLY = 10
SKIP_ALL = 100


class AbstractAlgo:
Expand All @@ -32,13 +39,24 @@ class AbstractAlgo:
# meta class is used to define other classes
__metaclass__ = abc.ABCMeta

def __init__(self, input_data, params, start_point, context):
def __init__(self, input_data, params, start_point, context=None, skip_validation=False):
"""The TreeGrower Algorithm initialization."""
self.context = context
self.input_data = copy.deepcopy(input_data)
self.params = copy.deepcopy(params)
self.start_point = start_point

if not skip_validation:
if skip_validation is True:
skip_validation = SkipValidationLevel.SKIP_ALL
elif skip_validation is False:
skip_validation = SkipValidationLevel.SKIP_NONE
self.preprocess_inputs(params, input_data, skip_validation)

@abc.abstractclassmethod
def preprocess_inputs(cls, params, distrs, skip_validation_level=False):
"""Preprocess all inputs for the given class."""

@abc.abstractmethod
def initialize(self):
"""Abstract TreeGrower Algorithm initialization.
Expand Down
24 changes: 22 additions & 2 deletions neurots/generate/algorithms/basicgrower.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from neurots.generate.algorithms.abstractgrower import AbstractAlgo
from neurots.generate.algorithms.common import bif_methods
from neurots.generate.algorithms.common import growth_algorithms
from neurots.generate.algorithms.common import section_data

logger = logging.getLogger(__name__)
Expand All @@ -36,11 +37,22 @@ class TrunkAlgo(AbstractAlgo):
context (Any): An object containing contextual information.
"""

def __init__(self, input_data, params, start_point, context=None, **_):
def __init__(self, input_data, params, start_point, context=None, skip_validation=False, **_):
"""Constructor of the TrunkAlgo class."""
super().__init__(input_data, params, start_point, context)
super().__init__(
input_data, params, start_point, context=context, skip_validation=skip_validation
)
self.bif_method = bif_methods[params["branching_method"]]

@classmethod
def preprocess_inputs(cls, params, distrs, skip_validation_level=False):
"""Preprocess all inputs for the given class."""
if "num_seg" not in params:
raise KeyError(
"The parameters must contain a 'num_seg' entry when the "
"'growth_method' entry in parameters is 'trunk'."
)

def initialize(self):
"""Generates the data to be used for the initialization of the first section to be grown.
Expand Down Expand Up @@ -100,3 +112,11 @@ def __init__(self, *args, **kwargs):
params["num_seg"] = 1

super().__init__(*args, **kwargs)


growth_algorithms.update(
{
"axon_trunk": AxonAlgo,
"trunk": TrunkAlgo,
}
)
9 changes: 9 additions & 0 deletions neurots/generate/algorithms/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from neurots.morphmath import bifurcation as _bif

growth_algorithms = {}

bif_methods = {
"bio_oriented": _bif.bio_oriented,
"symmetric": _bif.symmetric,
Expand All @@ -27,6 +29,13 @@
}


def get_grower_name(grower):
grower_name = [k for k, v in growth_algorithms.items() if v == grower]
if not grower_name:
raise ValueError(f"No name can be found for the grower {grower}.")
return grower_name[0]


def checks_bif_term(ref, bif, term, target_length):
"""Check bif/term.
Expand Down
54 changes: 42 additions & 12 deletions neurots/generate/algorithms/tmdgrower.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
import numpy as np

from neurots.generate.algorithms.abstractgrower import AbstractAlgo
from neurots.generate.algorithms.abstractgrower import SkipValidationLevel
from neurots.generate.algorithms.barcode import Barcode
from neurots.generate.algorithms.common import TMDStop
from neurots.generate.algorithms.common import bif_methods
from neurots.generate.algorithms.common import get_grower_name
from neurots.generate.algorithms.common import growth_algorithms
from neurots.generate.algorithms.common import section_data
from neurots.morphmath import sample
from neurots.morphmath.utils import norm
Expand Down Expand Up @@ -56,7 +59,9 @@ def __init__(
**_,
):
"""TMD basic grower."""
super().__init__(input_data, params, start_point, context)
super().__init__(
input_data, params, start_point, context=context, skip_validation=skip_validation
)
self.bif_method = bif_methods[params["branching_method"]]
self.params = copy.deepcopy(params)
self.ph_angles = self.select_persistence(input_data, random_generator)
Expand All @@ -65,17 +70,33 @@ def __init__(
self.apical_section = None
self.apical_point_distance_from_soma = 0.0
self.persistence_length = self.barcode.get_persistence_length()
# Validate parameters and distributions
if not skip_validation:
# Consistency check between parameters - persistence diagram
barSZ = input_data["min_bar_length"]
stepSZ = params["step_size"]["norm"]["mean"]
if stepSZ >= barSZ:
L.warning(
"Selected step size %f is too big for bars of size %f",
stepSZ,
barSZ,
)

@classmethod
def preprocess_inputs(cls, params, distrs, skip_validation_level=False):
"""Preprocess all inputs for the given class."""
# Check consistency between parameters and persistence diagram.
try:
barSZ = distrs["min_bar_length"]
except KeyError as exc:
if skip_validation_level == SkipValidationLevel.SKIP_OPTIONAL_ONLY:
# Here we just raise a warning
barSZ = -1
elif skip_validation_level == SkipValidationLevel.SKIP_ALL:
# Here we do nothing
barSZ = float("inf")
else:
grower_name = get_grower_name(cls)
raise KeyError(
"The distributions must contain a 'min_bar_length' entry when the "
f"'growth_method' entry in parameters is '{grower_name}'."
) from exc
stepSZ = params["step_size"]["norm"]["mean"]
if stepSZ >= barSZ:
L.warning(
"Selected step size %f is too big for bars of size %f",
stepSZ,
barSZ,
)

def select_persistence(self, input_data, random_generator=np.random):
"""Select the persistence.
Expand Down Expand Up @@ -342,3 +363,12 @@ def bifurcate(self, current_section):
current_section, s2["stop"]["TMD"], s2["process"], s2["direction"]
)
return s1, s2


growth_algorithms.update(
{
"tmd": TMDAlgo,
"tmd_apical": TMDApicalAlgo,
"tmd_gradient": TMDGradientAlgo,
}
)
7 changes: 7 additions & 0 deletions neurots/generate/section.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,10 @@ class SectionGrowerPath(SectionGrowerExponentialProba):
def get_val(self):
"""Returns path distance."""
return self.pathlength


section_growers = {
"radial_distances": SectionGrowerTMD,
"path_distances": SectionGrowerPath,
"trunk_length": SectionGrower,
}
31 changes: 12 additions & 19 deletions neurots/generate/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@
from morphio import SectionType

from neurots.generate.algorithms import basicgrower
from neurots.generate.algorithms import tmdgrower
from neurots.generate.section import SectionGrower
from neurots.generate.section import SectionGrowerPath
from neurots.generate.section import SectionGrowerTMD
from neurots.generate.algorithms.common import growth_algorithms
from neurots.generate.section import section_growers

# from neurots.generate.algorithms import tmdgrower
# from neurots.generate.section import SectionGrower
# from neurots.generate.section import SectionGrowerPath
# from neurots.generate.section import SectionGrowerTMD
from neurots.morphmath import sample
from neurots.utils import NeuroTSError

Expand All @@ -37,20 +40,6 @@
# LAMBDA: parameter that defines the slope of exponential probability
LAMBDA = 1.0

growth_algorithms = {
"tmd": tmdgrower.TMDAlgo,
"tmd_apical": tmdgrower.TMDApicalAlgo,
"tmd_gradient": tmdgrower.TMDGradientAlgo,
"axon_trunk": basicgrower.AxonAlgo,
"trunk": basicgrower.TrunkAlgo,
}

section_growers = {
"radial_distances": SectionGrowerTMD,
"path_distances": SectionGrowerPath,
"trunk_length": SectionGrower,
}


# Section grower parameters
SectionParameters = namedtuple(
Expand Down Expand Up @@ -129,9 +118,13 @@ def __init__(
self._section_parameters = _create_section_parameters(parameters)
self.growth_algo = self._initialize_algorithm()

@staticmethod
def select_grow_method(params):
return growth_algorithms[params["growth_method"]]

def _initialize_algorithm(self):
"""Initialization steps for TreeGrower."""
grow_meth = growth_algorithms[self.params["growth_method"]]
grow_meth = self.select_grow_method(self.params)

growth_algo = grow_meth(
input_data=self.distr,
Expand Down
11 changes: 11 additions & 0 deletions neurots/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import jsonschema
import pkg_resources

from neurots.generate.tree import TreeGrower

SCHEMA_PATH = pkg_resources.resource_filename("neurots", "schemas")

with Path(SCHEMA_PATH, "parameters.json").open(encoding="utf-8") as f:
Expand Down Expand Up @@ -65,3 +67,12 @@ def validate_neuron_params(data):
def validate_neuron_distribs(data):
"""Validate distribution dictionary."""
validate(data, DISTRIBS_SCHEMA)


def preprocess_inputs(params, distrs, skip_validation_level=False):
"""Validate and preprocess all inputs."""
validate_neuron_params(params)
validate_neuron_distribs(distrs)

grow_method = TreeGrower.select_grow_method(params)
grow_method.preprocess_inputs(params, distrs, skip_validation_level=skip_validation_level)
24 changes: 24 additions & 0 deletions tests/test_neuron_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,30 @@ def test_skip_rng_and_validation(self):
rng_or_seed=build_random_generator(0),
)

def test_min_bar_length_missing(self, tmpdir):
"""Check that the proper exception is raised when the min_bar_length entry is missing."""
with open(join(_path, "bio_distribution.json"), encoding="utf-8") as f:
distributions = json.load(f)
del distributions["apical_dendrite"]["min_bar_length"]
broken_distribution = tmpdir / "bio_distribution.json"
with open(broken_distribution, mode="w", encoding="utf-8") as f:
json.dump(distributions, f)

with pytest.raises(
KeyError,
match=(
"The distributions must contain a 'min_bar_length' entry when the "
"'growth_method' entry in parameters is 'tmd_apical'."
),
):
_test_full(
"path_distances",
broken_distribution,
"bio_path_params.json",
"path_grower.h5",
"bio_path_persistence_diagram.json",
)


class TestGradientPathGrower:
"""test tmd_path"""
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ setenv =
passenv =
PIP_EXTRA_INDEX_URL
commands = pytest \
-n {env:PYTEST_NPROCS:'auto'} \
-n {env:PYTEST_NPROCS:4} \
--basetemp={envtmpdir} \
--cov={[base]name} \
--cov-branch \
Expand Down

0 comments on commit 7dd592e

Please sign in to comment.