Skip to content

Commit

Permalink
Improve mechanism for global parameters
Browse files Browse the repository at this point in the history
Introduce a decorator similar to luigi.inherits where the parameters
can have different names or default values from the copied ones.
Also, when used on a GlobalParamTask class, these copied parameters are made global,
i.e. they take the source value when their own value is None.

Add missing tests for synthesis_workflow.fit_utils module.

Change-Id: I8e516bc52b3071f31f9bf8fc53c595c4f5307d9b
  • Loading branch information
adrien-berchet committed Oct 19, 2020
1 parent 2b5a492 commit be274b3
Show file tree
Hide file tree
Showing 15 changed files with 429 additions and 92 deletions.
1 change: 0 additions & 1 deletion synthesis_workflow/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import matplotlib
import numpy as np
import pandas as pd
import yaml
from joblib import delayed
from joblib import Parallel
from tqdm import tqdm
Expand Down
9 changes: 7 additions & 2 deletions synthesis_workflow/tasks/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from ..circuit import create_planes
from ..tools import ensure_dir
from .config import CircuitConfig
from .utils import GlobalParamTask
from .config import SynthesisConfig
from .luigi_tools import copy_params
from .luigi_tools import GlobalParamTask
from .luigi_tools import ParamLink


class CreateAtlasLayerAnnotations(GlobalParamTask):
Expand Down Expand Up @@ -123,6 +126,9 @@ def output(self):
return luigi.LocalTarget(self.atlas_planes_path + ".npz")


@copy_params(
mtypes=ParamLink(SynthesisConfig),
)
class SliceCircuit(GlobalParamTask):
"""Create a smaller circuit .mvd3 file for subsampling.
Expand All @@ -134,7 +140,6 @@ class SliceCircuit(GlobalParamTask):
"""

sliced_circuit_path = luigi.Parameter(default="sliced_circuit_somata.mvd3")
mtypes = luigi.ListParameter(default=None)
n_cells = luigi.IntParameter(default=10)
hemisphere = luigi.Parameter(default=None)

Expand Down
2 changes: 1 addition & 1 deletion synthesis_workflow/tasks/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def parse_args(self, argv):

def _setup_logging(log_level, log_file=None, log_file_level=None):
"""Setup logging"""
setup_logging(log_level, log_file_level, log_file_level)
setup_logging(log_level, log_file, log_file_level)


def main(arguments=None):
Expand Down
10 changes: 8 additions & 2 deletions synthesis_workflow/tasks/diametrizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@

from ..tools import update_morphs_df
from .config import DiametrizerConfig
from .config import RunnerConfig
from .luigi_tools import copy_params
from .luigi_tools import GlobalParamTask
from .luigi_tools import ParamLink


matplotlib.use("Agg")
Expand Down Expand Up @@ -60,15 +64,17 @@ def _plot_models(models_params, models_data, fig_folder="figures", ext=".png"):
)


class BuildDiameterModels(luigi.Task):
@copy_params(
nb_jobs=ParamLink(RunnerConfig),
)
class BuildDiameterModels(GlobalParamTask):
"""Task to build diameter models from set of cells."""

morphs_df_path = luigi.Parameter(default="morphs_df.csv")
morphology_path = luigi.Parameter(default="morphology_path")
diameter_models_path = luigi.Parameter(default="diameter_models.yaml")
by_mtypes = luigi.BoolParameter()
plot_models = luigi.BoolParameter()
nb_jobs = luigi.IntParameter(default=-1)

def run(self):
"""Run."""
Expand Down
118 changes: 118 additions & 0 deletions synthesis_workflow/tasks/luigi_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""utils functions for luigi parameters."""
import logging
import re
from copy import deepcopy

import luigi

Expand Down Expand Up @@ -39,10 +40,127 @@ def log_parameters(task):
L.debug("Can't print '%s' attribute for unknown reason", name)


class GlobalParamTask(luigi.Task):
"""Base class used to add customisable global parameters"""

def __getattribute__(self, name):
tmp = super().__getattribute__(name)
if tmp is not None:
return tmp
if hasattr(self, "_global_params"):
global_param = self._global_params.get(name)
if global_param is not None:
return getattr(global_param.cls(), global_param.name)
return tmp

def __setattr__(self, name, value):
if value is None and name in self.get_param_names():
L.warning(
"The Parameter '%s' of the task '%s' is set to None, thus the global "
"value will be taken frow now on",
name,
self.__class__.__name__,
)
return super().__setattr__(name, value)


class BaseWrapperTask(GlobalParamTask, luigi.WrapperTask):
"""Base wrapper class with global parameters"""


class ExtParameter(luigi.Parameter):
"""Class to parse file extension parameters"""

def parse(self, x):
pattern = re.compile(r"\.?(.*)")
match = re.match(pattern, x)
return match.group(1)


class ParamLink:
"""Class to store parameter linking informations"""

def __init__(self, cls, name=None, default=None):
self.cls = cls
self.name = name
self.default = default


class copy_params:
"""
Copy a parameter from another Task.
If no value is given to this parameter, the value from the other task is taken.
**Usage**:
.. code-block:: python
class AnotherTask(luigi.Task):
m = luigi.IntParameter(default=1)
@copy_params(m=ParamLink(AnotherTask))
class MyFirstTask(luigi.Task):
def run(self):
print(self.m) # this will be defined and print 1
# ...
@copy_params(another_m=ParamLink(AnotherTask, "m"))
class MySecondTask(luigi.Task):
def run(self):
print(self.another_m) # this will be defined and print 1
# ...
@copy_params(another_m=ParamLink(AnotherTask, "m", 5))
class MyFirstTask(luigi.Task):
def run(self):
print(self.another_m) # this will be defined and print 5
# ...
@copy_params(another_m=ParamLink(AnotherTask, "m"))
class MyFirstTask(GlobalParamTask):
def run(self):
print(self.another_m) # this will be defined and print 1 if self.another_m is None
# ...
"""

def __init__(self, **params_to_copy):
"""Init."""
super().__init__()
if not params_to_copy:
raise TypeError("params_to_copy cannot be empty")

self.params_to_copy = params_to_copy

def __call__(self, task_that_inherits):
"""Call."""
# Get all parameters
for param_name, attr in self.params_to_copy.items():
# Check if the parameter exists in the inheriting task
if not hasattr(task_that_inherits, param_name):
if attr.name is None:
attr.name = param_name
par = getattr(attr.cls, attr.name)

# Copy param
new_param = deepcopy(par)

# Set default value is required
if attr.default is not None:
new_param._default = attr.default
elif (
issubclass(task_that_inherits, GlobalParamTask)
and attr.default is None
):
new_param._default = None

# Add it to the inheriting task with new default values
setattr(task_that_inherits, param_name, new_param)

# Add link to global parameter
if issubclass(task_that_inherits, GlobalParamTask):
task = task_that_inherits
if not hasattr(task, "_global_params"):
task._global_params = {}
task._global_params[param_name] = attr

return task_that_inherits
40 changes: 27 additions & 13 deletions synthesis_workflow/tasks/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
from .config import CircuitConfig
from .config import DiametrizerConfig
from .config import PathConfig
from .config import RunnerConfig
from .config import SynthesisConfig
from .luigi_tools import ExtParameter
from .utils import GlobalParamTask
from .luigi_tools import copy_params
from .luigi_tools import GlobalParamTask
from .luigi_tools import ParamLink


morphio.set_maximum_warnings(0)
Expand Down Expand Up @@ -60,6 +62,9 @@ def output(self):
return luigi.LocalTarget(PathConfig().substituted_morphs_df_path)


@copy_params(
tmd_parameters_path=ParamLink(SynthesisConfig),
)
class BuildSynthesisParameters(GlobalParamTask):
"""Build the tmd_parameter.json for synthesis.
Expand All @@ -69,7 +74,6 @@ class BuildSynthesisParameters(GlobalParamTask):
"""

input_tmd_parameters_path = luigi.Parameter(default=None)
tmd_parameters_path = luigi.Parameter(default=None)

def requires(self):
""""""
Expand Down Expand Up @@ -112,15 +116,16 @@ def output(self):
return luigi.LocalTarget(self.tmd_parameters_path)


@copy_params(
morphology_path=ParamLink(PathConfig),
)
class BuildSynthesisDistributions(GlobalParamTask):
"""Build the tmd_distribution.json for synthesis.
Args:
morphology_path (str): column name in morphology dataframe to access morphology paths
"""

morphology_path = luigi.Parameter(default=None)

def requires(self):
""""""
return ApplySubstitutionRules()
Expand Down Expand Up @@ -160,6 +165,9 @@ def requires(self):
return [BuildSynthesisParameters(), BuildSynthesisDistributions()]


@copy_params(
nb_jobs=ParamLink(RunnerConfig),
)
class BuildAxonMorphologies(GlobalParamTask):
"""Run choose-morphologies to synthesize axon morphologies.
Expand All @@ -177,7 +185,6 @@ class BuildAxonMorphologies(GlobalParamTask):
placement_alpha = luigi.FloatParameter(default=1.0)
placement_scales = luigi.ListParameter(default=None)
placement_seed = luigi.IntParameter(default=0)
nb_jobs = luigi.IntParameter(default=None)

def requires(self):
""""""
Expand Down Expand Up @@ -220,6 +227,11 @@ def output(self):
return luigi.LocalTarget(self.axon_morphs_path)


@copy_params(
ext=ParamLink(PathConfig),
morphology_path=ParamLink(PathConfig),
nb_jobs=ParamLink(RunnerConfig),
)
class Synthesize(GlobalParamTask):
"""Run placement-algorithm to synthesize morphologies.
Expand All @@ -234,11 +246,8 @@ class Synthesize(GlobalParamTask):
"""

out_circuit_path = luigi.Parameter(default="sliced_circuit_morphologies.mvd3")
ext = ExtParameter(default=None)
axon_morphs_base_dir = luigi.OptionalParameter(default=None)
apical_points_path = luigi.Parameter(default="apical_points.yaml")
morphology_path = luigi.Parameter(default=None)
nb_jobs = luigi.IntParameter(default=None)
debug_region_grower_scales = luigi.BoolParameter(default=False)

def requires(self):
Expand Down Expand Up @@ -307,13 +316,15 @@ def output(self):
return luigi.LocalTarget(self.out_circuit_path)


@copy_params(
morphology_path=ParamLink(PathConfig),
tmd_parameters_path=ParamLink(SynthesisConfig),
nb_jobs=ParamLink(RunnerConfig),
)
class AddScalingRulesToParameters(GlobalParamTask):
"""Add scaling rules to tmd_parameter.json."""

scaling_rules_path = luigi.Parameter(default="scaling_rules.yaml")
tmd_parameters_path = luigi.Parameter(default=None)
morphology_path = luigi.Parameter(default=None)
nb_jobs = luigi.IntParameter(default=None)

def requires(self):
""""""
Expand Down Expand Up @@ -348,10 +359,13 @@ def output(self):
return luigi.LocalTarget(self.tmd_parameters_path)


@copy_params(
morphology_path=ParamLink(PathConfig),
nb_jobs=ParamLink(RunnerConfig),
)
class RescaleMorphologies(GlobalParamTask):
"""Rescale morphologies for synthesis input."""

morphology_path = luigi.Parameter(default=None)
rescaled_morphology_path = luigi.Parameter(default="rescaled_morphology_path")
rescaled_morphology_base_path = luigi.Parameter(default="rescaled_morphologies")
scaling_rules_path = luigi.Parameter(default="scaling_rules.yaml")
Expand Down
49 changes: 0 additions & 49 deletions synthesis_workflow/tasks/utils.py

This file was deleted.

Loading

0 comments on commit be274b3

Please sign in to comment.