Skip to content

Commit

Permalink
Fix: Make bifurcation angles globally invariant (#95)
Browse files Browse the repository at this point in the history
* Fix: Ensure gorwth in a section is not biased by parent direction

* Fix: Make bifurcation angles globally invariant

* remove history

* fix test

* add test

* lint

* cov

* better impl

* minor

* cov

* revert small thing

* test

* None default

* better impl

* cleanup

* more

* better

* fix astro

* cov

* cov 98

* fix example licence

* fix IPC
  • Loading branch information
arnaudon authored Nov 25, 2024
1 parent 2ef760f commit f34dc8d
Show file tree
Hide file tree
Showing 13 changed files with 235 additions and 53 deletions.
90 changes: 90 additions & 0 deletions examples/synthesize_single_neuron_y_direction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# noqa

# Copyright (C) 2021-2024 Blue Brain Project, EPFL
#
# SPDX-License-Identifier: Apache-2.0

"""
Synthesize a single neuron with global direction
================================================
This example shows how to synthesize a single cell with different y directions
"""

import json
from pathlib import Path

import numpy as np
from morph_tool.transform import rotate

import neurots


def run(output_dir, data_dir):
"""Run the example for generating a single cell."""
np.random.seed(42)
with open(data_dir / "bio_params.json", encoding="utf-8") as p_file:
params = json.load(p_file)
# use trunk angle with y_direction awareness
params["basal_dendrite"]["orientation"] = {
"mode": "pia_constraint",
"values": {"form": "step", "params": [1.5, 0.25]},
}
params["apical_dendrite"]["orientation"] = {
"mode": "normal_pia_constraint",
"values": {"direction": {"mean": [0.0], "std": [0.0]}},
}

# Initialize a neuron
N = neurots.NeuronGrower(
input_distributions=data_dir / "bio_distr.json",
input_parameters=params,
)

# Grow the neuron
neuron = N.grow()

# Export the synthesized cell
neuron.write(output_dir / "generated_cell_orig.asc")

np.random.seed(42)

# Initialize a neuron
N = neurots.NeuronGrower(
input_distributions=data_dir / "bio_distr.json",
input_parameters=params,
# context={"y_direction": [0.0, 1.0, 0.0]},
)

# Grow the neuron
neuron = N.grow()

# Export the synthesized cell
neuron.write(output_dir / "generated_cell_y.asc")

np.random.seed(42)

# Initialize a neuron
N = neurots.NeuronGrower(
input_distributions=data_dir / "bio_distr.json",
input_parameters=params,
context={"y_direction": [1.0, 0.0, 0.0]},
)

# Grow the neuron
neuron = N.grow()

# Export the synthesized cell
neuron.write(output_dir / "generated_cell_x.asc")

# the rotated neuron should be the same as original one

rotate(neuron, [[0, -1, 0], [1, 0, 0], [0, 0, 1]])
neuron.write(output_dir / "generated_cell_x_rot.asc")


if __name__ == "__main__":
result_dir = Path("results_single_neuron")
result_dir.mkdir(parents=True, exist_ok=True)

run(result_dir, Path("data"))
4 changes: 2 additions & 2 deletions neurots/extract_input/from_neurom.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from neurom import stats
from neurom.features.morphology import trunk_vectors

from neurots.utils import PIA_DIRECTION
from neurots.utils import Y_DIRECTION


def transform_distr(opt_distr):
Expand Down Expand Up @@ -130,7 +130,7 @@ def trunk_neurite_3d_angles(pop, neurite_type, bins):
apical_3d_angles = []
for morph in pop.morphologies:
vecs = trunk_vectors(morph, neurite_type=neurite_type)
pia_3d_angles += [nm.morphmath.angle_between_vectors(PIA_DIRECTION, vec) for vec in vecs]
pia_3d_angles += [nm.morphmath.angle_between_vectors(Y_DIRECTION, vec) for vec in vecs]
if neurite_type.name != "apical_dendrite":
apical_ref_vec = trunk_vectors(morph, neurite_type=nm.APICAL_DENDRITE)
if len(apical_ref_vec) > 0:
Expand Down
2 changes: 1 addition & 1 deletion neurots/generate/algorithms/abstractgrower.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class AbstractAlgo:

def __init__(self, input_data, params, start_point, context):
"""The TreeGrower Algorithm initialization."""
self.context = context
self.context = context if context is not None else {}
self.input_data = copy.deepcopy(input_data)
self.params = copy.deepcopy(params)
self.start_point = start_point
Expand Down
2 changes: 1 addition & 1 deletion neurots/generate/algorithms/basicgrower.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def bifurcate(self, current_section):
Returns:
tuple[dict, dict]: Two dictionaries containing the two children sections data.
"""
dir1, dir2 = self.bif_method()
dir1, dir2 = self.bif_method(y_rotation=self.context.get("y_rotation"))
first_point = np.array(current_section.last_point)
stop = current_section.stop_criteria

Expand Down
12 changes: 9 additions & 3 deletions neurots/generate/algorithms/tmdgrower.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def bifurcate(self, current_section):
self.barcode.remove_bif(current_section.stop_criteria["TMD"].bif_id)
ang = self.barcode.angles[current_section.stop_criteria["TMD"].bif_id]

dir1, dir2 = self.bif_method(current_section.history(), angles=ang)
dir1, dir2 = self.bif_method(
current_section.history(), angles=ang, y_rotation=self.context.get("y_rotation")
)
first_point = np.array(current_section.last_point)

stop1, stop2 = self.get_stop_criteria(current_section)
Expand Down Expand Up @@ -249,7 +251,9 @@ def bifurcate(self, current_section):
first_point = np.array(current_section.last_point)

if current_section.process == "major":
dir1, dir2 = bif_methods["directional"](current_section.direction, angles=ang)
dir1, dir2 = bif_methods["directional"](
current_section.direction, angles=ang, y_rotation=self.context.get("y_rotation")
)

if not self._found_last_bif:
self.apical_section = current_section.id
Expand All @@ -264,7 +268,9 @@ def bifurcate(self, current_section):
if not self._found_last_bif:
self._found_last_bif = True
else:
dir1, dir2 = self.bif_method(current_section.history(), angles=ang)
dir1, dir2 = self.bif_method(
current_section.history(), angles=ang, y_rotation=self.context.get("y_rotation")
)
process1 = "secondary"
process2 = "secondary"

Expand Down
24 changes: 18 additions & 6 deletions neurots/generate/grower.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@
from neurots.generate.soma import Soma
from neurots.generate.soma import SomaGrower
from neurots.generate.tree import TreeGrower
from neurots.morphmath import rotation
from neurots.morphmath import sample
from neurots.morphmath.utils import normalize_vectors
from neurots.preprocess import preprocess_inputs
from neurots.utils import Y_DIRECTION
from neurots.utils import NeuroTSError
from neurots.utils import convert_from_legacy_neurite_type
from neurots.utils import point_to_section_segment
Expand Down Expand Up @@ -91,7 +93,7 @@ def __init__(
):
"""Constructor of the NeuronGrower class."""
self.neuron = Morphology()
self.context = context
self.context = self._process_context(context)
if rng_or_seed is None or isinstance(
rng_or_seed, (int, np.integer, SeedSequence, BitGenerator)
):
Expand Down Expand Up @@ -135,6 +137,20 @@ def __init__(

self._trunk_orientations_class = trunk_orientations_class

def _process_context(self, context):
"""Apply some required processing to the context dictionary."""
if context is None:
return {}
if not isinstance(context, dict):
return context

# we ofen need to use the y_direction as a rotation, so we save to it once here
if "y_direction" in context:
context["y_rotation"] = rotation.rotation_matrix_from_vectors(
Y_DIRECTION, context["y_direction"]
)
return context

def next(self):
"""Call the "next" method of each neurite grower."""
for grower in list(self.active_neurites):
Expand Down Expand Up @@ -363,11 +379,7 @@ def _simple_grow_trunks(self):
)

def _3d_angles_grow_trunks(self):
"""Grow trunk with 3d_angles method via :func:`.orientation.OrientationManager` class.
Args:
input_parameters_with_3d_anglles (dict): input_parameters with fits for 3d angles
"""
"""Grow trunk with 3d_angles method via :func:`.orientation.OrientationManager` class."""
trunk_orientations_manager = self._trunk_orientations_class(
soma=self.soma_grower.soma,
parameters=self.input_parameters,
Expand Down
24 changes: 11 additions & 13 deletions neurots/generate/orientations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from neurots.morphmath import rotation
from neurots.morphmath import sample
from neurots.morphmath.utils import normalize_vectors
from neurots.utils import PIA_DIRECTION
from neurots.utils import Y_DIRECTION
from neurots.utils import NeuroTSError
from neurots.utils import accept_reject

Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(self, soma, parameters, distributions, context, rng):
self._soma = soma
self._parameters = parameters
self._distributions = distributions
self._context = context
self._context = context if context is not None else {}
self._rng = rng

self._orientations = {}
Expand Down Expand Up @@ -222,7 +222,7 @@ def _mode_normal_pia_constraint(self, values_dict, tree_type):
the second angle to obtain a 3d direction. For multiple apical trees, `mean` and `std`
should be two lists with lengths equal to number of trees, otherwise it can be a float.
Pia direction can be overwritten by the parameter 'pia_direction' value.
Pia direction can be overwritten by the parameter 'y_direction' value from context.
"""
n_orientations = sample.n_neurites(self._distributions[tree_type]["num_trees"], self._rng)
if (
Expand Down Expand Up @@ -250,7 +250,7 @@ def _mode_normal_pia_constraint(self, values_dict, tree_type):

phis = self._rng.uniform(0, 2 * np.pi, len(means))
angles += spherical_angles_to_pia_orientations(
phis, thetas, self._parameters.get("pia_direction", None)
phis, thetas, self._context.get("y_rotation", None)
).tolist()
return np.array(angles)

Expand All @@ -260,10 +260,10 @@ def _mode_pia_constraint(self, _, tree_type):
See :func:`self._sample_trunk_from_3d_angle` for more details on the algorithm.
"""
n_orientations = sample.n_neurites(self._distributions[tree_type]["num_trees"], self._rng)
pia_direction = self._parameters.get("pia_direction", PIA_DIRECTION)
y_direction = self._context.get("y_direction", Y_DIRECTION)
return np.asarray(
[
self._sample_trunk_from_3d_angle(tree_type, pia_direction)
self._sample_trunk_from_3d_angle(tree_type, y_direction)
for _ in range(n_orientations)
]
)
Expand Down Expand Up @@ -302,9 +302,7 @@ def prob(proposal):
params = self._parameters[tree_type]["orientation"]["values"]["params"]
p = _prob(val, *params)

if self._context is not None and self._context.get(
"constraints", []
): # pragma: no cover
if self._context.get("constraints", []): # pragma: no cover
for constraint in self._context["constraints"]:
if "trunk_prob" in constraint:
p *= constraint["trunk_prob"](proposal, self._soma.center)
Expand Down Expand Up @@ -477,13 +475,13 @@ def compute_interval_n_tree(soma, n_trees, rng=np.random):
return phi_intervals, interval_n_trees


def spherical_angles_to_pia_orientations(phis, thetas, pia_direction=None):
def spherical_angles_to_pia_orientations(phis, thetas, y_rotation=None):
"""Compute orientation from spherical angles where thetas are wrt to pia (default=`[0, 1, 0]`).
Args:
phis (numpy.ndarray): Polar angles.
thetas (numpy.ndarray): Azimuthal angles.
pia_direction (numpy.ndarray): Direction of pia if different from `[0, 1, 0]`.
y_rotation (numpy.ndarray): Rotation of y direction if different from `[0, 1, 0]`.
Returns:
numpy.ndarray: The orientation vectors where each row corresponds to a phi-theta pair.
Expand All @@ -492,8 +490,8 @@ def spherical_angles_to_pia_orientations(phis, thetas, pia_direction=None):
vector = np.column_stack(
(np.cos(phis) * np.sin(thetas), np.cos(thetas), np.sin(phis) * np.sin(thetas))
)
if pia_direction is not None:
vector = vector.dot(rotation.rotation_matrix_from_vectors(PIA_DIRECTION, pia_direction).T)
if y_rotation is not None:
vector = vector.dot(y_rotation.T)
return vector


Expand Down
2 changes: 1 addition & 1 deletion neurots/generate/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def add_section(
"""
SGrower = section_growers[self.params["metric"]]
context = copy.deepcopy(self.context)
if self.context is not None and "constraints" in self.context: # pragma: no cover
if "constraints" in self.context: # pragma: no cover
context["constraints"] = [
constraint
for constraint in self.context["constraints"]
Expand Down
Loading

0 comments on commit f34dc8d

Please sign in to comment.