Skip to content

Commit

Permalink
Fix: Axon fixing
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudon committed Apr 19, 2023
1 parent 4520023 commit fd5b184
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 134 deletions.
73 changes: 0 additions & 73 deletions src/synthesis_workflow/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
import numpy as np
import pandas as pd
import yaml
from atlas_analysis.planes.planes import _smoothing
from atlas_analysis.planes.planes import create_centerline
from atlas_analysis.planes.planes import create_planes as _create_planes
from brainbuilder.app.cells import _place as place
from neurocollage.planes import get_cells_between_planes
from neurocollage.planes import slice_n_cells
Expand Down Expand Up @@ -184,76 +181,6 @@ def get_local_bbox(annotation):
)


def create_planes(
layer_annotation,
plane_type="aligned",
plane_count=10,
slice_thickness=100,
centerline_first_bound=None,
centerline_last_bound=None,
centerline_axis=0,
):
"""Create planes in an atlas.
We create 3 * plane_count such each triplet of planes define the left, center
and right plane of each slice.
Args:
layer_annotation (VoxelData): annotations with layers
plane_type (str): type of planes creation algorithm, two choices:
* centerline: centerline is computed between _first_bound and _last_bound with
internal algorithm (from atlas-analysis package)
* aligned: centerline is a straight line, along the centerline_axis
plane_count (int): number of planes to create slices of atlas,
slice_thickness (float): thickness of slices (in micrometer)
centerline_first_bound (list): (for plane_type == centerline) location of first bound
for centerline (in voxcell index)
centerline_last_bound (list): (for plane_type == centerline) location of last bound
for centerline (in voxcell index)
centerline_axis (str): (for plane_type = aligned) axis along which to create planes
"""
if plane_type == "centerline_straight":
if centerline_first_bound is None and centerline_last_bound is None:
centerline_first_bound, centerline_last_bound = get_centerline_bounds(layer_annotation)
centerline = np.array(
[
layer_annotation.indices_to_positions(centerline_first_bound),
layer_annotation.indices_to_positions(centerline_last_bound),
]
)
elif plane_type == "centerline_curved":
if centerline_first_bound is None and centerline_last_bound is None:
centerline_first_bound, centerline_last_bound = get_centerline_bounds(layer_annotation)
centerline = create_centerline(
layer_annotation, [centerline_first_bound, centerline_last_bound]
)
centerline = _smoothing(centerline)

elif plane_type == "aligned":
centerline = np.zeros([2, 3])
bbox = get_local_bbox(layer_annotation)
centerline[:, centerline_axis] = np.linspace(
bbox[0, centerline_axis], bbox[1, centerline_axis], 2
)
else:
raise ValueError(f"Please set plane_type to 'aligned' or 'centerline', not {plane_type}.")

# create all planes to match slice_thickness between every two planes
centerline_len = np.linalg.norm(np.diff(centerline, axis=0), axis=1).sum()
total_plane_count = int(centerline_len / slice_thickness) * 2 + 1
planes = _create_planes(centerline, plane_count=total_plane_count)

# select plane_count planes + direct left/right neighbors
planes_all_ids = np.arange(total_plane_count)
id_shift = int(total_plane_count / plane_count)
planes_select_ids = list(planes_all_ids[int(id_shift / 2) :: id_shift])
planes_select_ids += list(planes_all_ids[int(id_shift / 2) - 1 :: id_shift])
planes_select_ids += list(planes_all_ids[int(id_shift / 2) + 1 :: id_shift])
return [planes[i] for i in sorted(planes_select_ids)], centerline


def get_layer_tags(atlas_dir, region_structure_path, region=None):
"""Create a VoxelData with layer tags."""
atlas_helper = AtlasHelper(LocalAtlas(atlas_dir), region_structure_path=region_structure_path)
Expand Down
8 changes: 5 additions & 3 deletions src/synthesis_workflow/tasks/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,10 @@ class SynthesisConfig(luigi.Config):
),
schema={"type": "array", "items": {"type": "string"}},
)
with_axons = luigi.BoolParameter(
default=False, description=":bool: Set to True to synthesize local axons"
axon_method = luigi.ChoiceParameter(
default="no_axon",
description=":str: The method used to handle axons.",
choices=["no_axon", "reconstructed", "synthesis"],
)


Expand Down Expand Up @@ -272,7 +274,7 @@ class PathConfig(luigi.Config):
# Default internal values
ext = ExtParameter(default="asc", description=":str: Default extension used.")
morphology_path = luigi.Parameter(
default="morphology_path",
default="path",
description="Column name in the morphology dataframe to access morphology paths",
)
morphs_df_path = luigi.Parameter(
Expand Down
38 changes: 14 additions & 24 deletions src/synthesis_workflow/tasks/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from luigi_tools.task import ParamRef
from luigi_tools.task import WorkflowTask
from luigi_tools.task import copy_params
from neurom import load_morphology
from neurom.check.morphology_checks import has_apical_dendrite
from neurots import extract_input
from neurots.generate.orientations import fit_3d_angles
from neurots.validator import validate_neuron_distribs
Expand Down Expand Up @@ -154,23 +152,15 @@ def run(self):
morphs_df = pd.read_csv(self.input()["morphologies"].path)
mtypes = sorted(morphs_df.mtype.unique())
neurite_types = get_neurite_types(morphs_df)
if SynthesisConfig().with_axons:
for neurite_type in neurite_types.items():
if SynthesisConfig().axon_method != "no_axon":
for neurite_type in neurite_types.values():
neurite_type.append("axon")

tmd_parameters = {}
for mtype in tqdm(mtypes):
neurite_types = ["basal_dendrite"]
if has_apical_dendrite(
load_morphology(morphs_df.loc[morphs_df.mtype == mtype, "path"].to_list()[0])
):
neurite_types.append("apical_dendrite")
if SynthesisConfig().with_axons:
neurite_types.append("axon")

config = DiametrizerConfig().config_diametrizer
config["neurite_types"] = neurite_types
kwargs = {"neurite_types": neurite_types, "diameter_parameters": config}
config["neurite_types"] = neurite_types[mtype]
kwargs = {"neurite_types": neurite_types[mtype], "diameter_parameters": config}
tmd_parameters[mtype] = extract_input.parameters(**kwargs)

with self.output().open("w") as f:
Expand Down Expand Up @@ -228,8 +218,8 @@ def run(self):
L.debug("mtypes found: %s", mtypes)

neurite_types = get_neurite_types(morphs_df)
if SynthesisConfig().with_axons:
for neurite_type in neurite_types.items():
if SynthesisConfig().axon_method != "no_axon":
for neurite_type in neurite_types.values():
neurite_type.append("axon")
L.debug("neurite_types found: %s", neurite_types)

Expand Down Expand Up @@ -498,11 +488,6 @@ class Synthesize(WorkflowTask):
description=":float: The std value of the scaling jitter to apply (in degrees).",
)
seed = luigi.IntParameter(default=0, description=":int: Pseudo-random generator seed.")
axon_method = luigi.ChoiceParameter(
default="reconstructed",
description=":str: The method used to handle axons.",
choices=["no_axon", "reconstructed", "synthesis"],
)

def requires(self):
"""Required input tasks."""
Expand All @@ -514,7 +499,7 @@ def requires(self):
"circuit": SliceCircuit(),
"composition": GetCellComposition(),
}
if self.axon_method == "reconstructed":
if SynthesisConfig().axon_method == "reconstructed":
tasks["axons"] = BuildAxonMorphologies()
tasks["axon_cells"] = BuildAxonMorphsDF(
neurondb_path=BuildAxonMorphologies().get_neuron_db_path("xml"),
Expand All @@ -537,7 +522,7 @@ def run(self):

axon_morphs_path = None
axon_morphs_base_dir = None
if self.axon_method == "reconstructed":
if SynthesisConfig().axon_method == "reconstructed":
axon_morphs_path = self.input()["axons"]["morphs"].path
if self.axon_morphs_base_dir is None:
axon_morphs_base_dir = get_axon_base_dir(
Expand Down Expand Up @@ -566,7 +551,7 @@ def run(self):
"region_structure": self.input()["synthesis_input"].pathlib_path
/ CircuitConfig().region_structure_path,
}
if self.axon_method == "reconstructed" and self.apply_jitter:
if SynthesisConfig().axon_method == "reconstructed" and self.apply_jitter:
kwargs["scaling_jitter_std"] = self.scaling_jitter_std
kwargs["rotational_jitter_std"] = self.rotational_jitter_std

Expand Down Expand Up @@ -732,6 +717,11 @@ def run(self):
custom_parameters = pd.read_csv(custom_path)
apply_parameter_diff(tmd_parameters, custom_parameters)

# if we are no_axon, ensure tmd_parameters has no axon data, or json schema may crash
if SynthesisConfig().axon_method == "no_axon":
for mtype in tmd_parameters:
tmd_parameters[mtype]["axon"] = {}

with self.output().open("w") as f:
json.dump(tmd_parameters, f, cls=NumpyEncoder, indent=4, sort_keys=True)

Expand Down
5 changes: 0 additions & 5 deletions src/synthesis_workflow/tasks/vacuum_synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,6 @@ class VacuumSynthesize(WorkflowTask):
description=":str: Diametrizer model to use.",
)
n_cells = luigi.IntParameter(default=10, description=":int: Number of cells to synthesize.")
axon_method = luigi.ChoiceParameter(
default="no_axon",
description=":str: The method used to handle axons.",
choices=["no_axon", "synthesis"],
)

def requires(self):
"""Required input tasks."""
Expand Down
4 changes: 2 additions & 2 deletions src/synthesis_workflow/tasks/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class PlotMorphometrics(WorkflowTask):
description=":str: Path to output directory (relative from ``PathConfig.result_path``).",
)
base_key = luigi.Parameter(
default="morphology_path",
default="path",
description=":str: Base key to use in the morphology DataFrame.",
)
comp_key = luigi.Parameter(
Expand Down Expand Up @@ -669,7 +669,7 @@ class TrunkValidation(WorkflowTask):
description=":str: Path to output directory (relative from ``PathConfig.result_path``).",
)
base_key = luigi.Parameter(
default="morphology_path",
default="path",
description=":str: Base key to use in the morphology DataFrame.",
)
comp_key = luigi.Parameter(
Expand Down
2 changes: 1 addition & 1 deletion src/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Package version."""
VERSION = "0.1.2"
VERSION = "0.1.3.dev0"
21 changes: 0 additions & 21 deletions tests/data/in_small_O1/out/morphs_df/axon_morphs.tsv

This file was deleted.

5 changes: 0 additions & 5 deletions tests/data/in_small_O1/out/morphs_df/axon_morphs_df.csv

This file was deleted.

0 comments on commit fd5b184

Please sign in to comment.