Skip to content

Commit

Permalink
collage update
Browse files Browse the repository at this point in the history
Change-Id: I453ef2c4cb94d74c780fe956bc4ec610c2584911
  • Loading branch information
arnaudon committed Sep 29, 2020
1 parent 80b55a8 commit 609fc39
Show file tree
Hide file tree
Showing 7 changed files with 352 additions and 310 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import pandas as pd
from tqdm import tqdm
from voxcell import CellCollection
import numpy as np

from atlas_analysis.planes.maths import Plane
from atlas_analysis.planes.planes import _create_planes, _create_centerline, _smoothing

LEFT = 0
RIGHT = 1
Expand Down Expand Up @@ -63,38 +64,27 @@ def slice_n_cells(cells, n_cells, random_state=0):
return sampled_cells


def slice_x_slice(cells, x_slice):
"""Select cells in x_slice."""
return cells[cells.x.between(x_slice[0], x_slice[1])]


def slice_y_slice(cells, y_slice):
"""Select cells in y_slice."""
return cells[cells.y.between(y_slice[0], y_slice[1])]


def slice_z_slice(cells, z_slice):
"""Select cells in z_slice."""
return cells[cells.z.between(z_slice[0], z_slice[1])]


def slice_atlas_bbox(cells, bbox):
"""Slice cells given a bbox on the atlas."""
cells = slice_x_slice(cells, bbox[0])
cells = slice_y_slice(cells, bbox[1])
return slice_z_slice(cells, bbox[2])
def is_between_planes(point, plane_left, plane_right):
"""Check if a point is between two planes in equation representation."""
eq_left = plane_left.get_equation()
eq_right = plane_right.get_equation()
return (eq_left[:3].dot(point) > eq_left[3]) & (
eq_right[:3].dot(point) < eq_right[3]
)


def generic_slicer_old(cells, n_cells, mtypes=None, bbox=None):
"""Select n_cells mtype in mtypes and within bbox."""
if mtypes is not None:
cells = slice_per_mtype(cells, mtypes)
if bbox is not None:
cells = slice_atlas_bbox(cells, bbox)
return slice_n_cells(cells, n_cells)
def get_cells_between_planes(cells, plane_left, plane_right):
"""Get cells gids between two planes in equation representation."""
cells["selected"] = cells[["x", "y", "z"]].apply(
lambda soma_position: is_between_planes(
soma_position.to_numpy(), plane_left, plane_right
),
axis=1,
)
return cells[cells.selected].drop("selected", axis=1)


def generic_slicer(cells, n_cells, mtypes=None, planes=None, hemisphere=None):
def circuit_slicer(cells, n_cells, mtypes=None, planes=None, hemisphere=None):
"""Select n_cells mtype in mtypes."""
if mtypes is not None:
cells = slice_per_mtype(cells, mtypes)
Expand All @@ -110,39 +100,14 @@ def generic_slicer(cells, n_cells, mtypes=None, planes=None, hemisphere=None):
get_cells_between_planes(cells, plane_left, plane_right), n_cells
)
for plane_left, plane_right in tqdm(
zip(planes[:-1], planes[1:]), total=len(planes) - 1
zip(planes[:-1:3], planes[2::3]), total=int(len(planes) / 3)
)
]
)
return slice_n_cells(cells, n_cells)


def is_between_planes(point, plane_left, plane_right):
"""Check if a point is between two planes in equation representation."""
eq_left = get_plane_equation(plane_left)
eq_right = get_plane_equation(plane_right)
return (eq_left[:3].dot(point) > eq_left[3]) & (
eq_right[:3].dot(point) < eq_right[3]
)


def get_plane_equation(quaternion):
"""Get the plane equation from a quaternion representation"""
return Plane.from_quaternion(quaternion[:3], quaternion[3:]).get_equation()


def get_cells_between_planes(cells, plane_left, plane_right):
"""Get cells gids between two planes in equation representation."""
cells["selected"] = cells[["x", "y", "z"]].apply(
lambda soma_position: is_between_planes(
soma_position.to_numpy(), plane_left, plane_right
),
axis=1,
)
return cells[cells.selected].drop("selected", axis=1)


def slice_circuit(input_mvd3, output_mvd3, slicing_function):
def slice_circuit(input_mvd3, output_mvd3, slicer):
"""Slice an mvd3 file using a slicing function.
Args:
Expand All @@ -151,8 +116,71 @@ def slice_circuit(input_mvd3, output_mvd3, slicing_function):
slicing_function (function): function to slice the cells dataframe
"""
cells = CellCollection.load_mvd3(input_mvd3)
sliced_cells = slicing_function(cells.as_dataframe())
sliced_cells = slicer(cells.as_dataframe())
sliced_cells.reset_index(inplace=True)
sliced_cells.index += 1
CellCollection.from_dataframe(sliced_cells).save_mvd3(output_mvd3)
return sliced_cells


def create_planes(
layer_annotation,
plane_type="aligned",
plane_count=10,
slice_thickness=100,
centerline_first_bound=None,
centerline_last_bound=None,
centerline_axis=0,
):
"""Create planes in an atlas.
We create 3 * plane_count such each triplet of planes define the left, center
and right plane of each slice.
Args:
layer_annotation (VoxelData): annotations with layers
plane_type (str): type of planes creation algorithm, two choices:
- centerline: centerline is computed between _first_bound and _last_bound with
internal algorithm (from atlas-analysis package)
- aligned: centerline is a straight line, along the centerline_axis
plane_count (int): number of planes to create slices of atlas,
slice_thickness (float): thickness of slices (in micrometer)
centerline_first_bound (list): (for plane_type == centerline) location of first bound
for centerline (in voxcell index)
centerline_last_bound (list): (for plane_type == centerline) location of last bound
for centerline (in voxcell index)
centerline_axis (str): (for plane_type = aligned) axis along which to create planes
"""
if plane_type == "centerline":
centerline = _create_centerline(
layer_annotation, [centerline_first_bound, centerline_last_bound]
)
centerline = _smoothing(centerline)

elif plane_type == "aligned":
_n_points = 10
bbox = layer_annotation.bbox
centerline = np.zeros([_n_points, 3])
centerline[:, centerline_axis] = np.linspace(
bbox[0, centerline_axis],
bbox[1, centerline_axis],
_n_points,
)
else:
raise Exception(
f"Please set plane_type to 'aligned' or 'centerline', not {plane_type}."
)

# create all planes to match slice_thickness between every two planes
centerline_len = np.linalg.norm(np.diff(centerline, axis=0), axis=1).sum()
total_plane_count = int(centerline_len / slice_thickness) * 2 + 1
planes = _create_planes(centerline, plane_count=total_plane_count)

# select plane_count planes + direct left/right neighbors
planes_all_ids = np.arange(total_plane_count)
id_shift = int(total_plane_count / plane_count)
planes_select_ids = list(planes_all_ids[int(id_shift / 2) :: id_shift])
planes_select_ids += list(planes_all_ids[int(id_shift / 2) - 1 :: id_shift])
planes_select_ids += list(planes_all_ids[int(id_shift / 2) + 1 :: id_shift])
return [planes[i] for i in sorted(planes_select_ids)], centerline
5 changes: 3 additions & 2 deletions synthesis_workflow/synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def build_distributions(
morphology_path,
cortical_thickness,
nb_jobs=-1,
joblib_verbose=10,
):
"""Build tmd_distribution dictionary for synthesis.
Expand All @@ -128,8 +129,8 @@ def build_distributions(
"mtypes": {},
"metadata": {"cortical_thickness": json.loads(cortical_thickness)},
}
for mtype, distribution in Parallel(nb_jobs)(
delayed(build_distributions_single_mtype)(mtype) for mtype in tqdm(mtypes)
for mtype, distribution in Parallel(nb_jobs, verbose=joblib_verbose)(
delayed(build_distributions_single_mtype)(mtype) for mtype in mtypes
):
tmd_distributions["mtypes"][mtype] = distribution
return tmd_distributions
Expand Down
Loading

0 comments on commit 609fc39

Please sign in to comment.