Skip to content

Commit

Permalink
Use joblib everywhere instead of multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
adrien-berchet committed Sep 17, 2020
1 parent ed8c2b3 commit cc6c792
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 81 deletions.
36 changes: 21 additions & 15 deletions synthesis_workflow/diametrizer_tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Luigi tasks to diametrize cells."""
import multiprocessing
import os
import sys
import traceback
Expand All @@ -13,6 +12,8 @@
from diameter_synthesis.build_diameters import build as build_diameters
from diameter_synthesis.build_models import build as build_diameter_model
from diameter_synthesis.plotting import plot_distribution_fit
from joblib import delayed
from joblib import Parallel
from morphio.mut import Morphology
from tqdm import tqdm

Expand Down Expand Up @@ -66,6 +67,7 @@ class BuildDiameterModels(luigi.Task):
diameter_models_path = luigi.Parameter(default="diameter_models.yaml")
by_mtypes = luigi.BoolParameter()
plot_models = luigi.BoolParameter()
nb_jobs = luigi.IntParameter(default=-1)

def run(self):
"""Run."""
Expand All @@ -88,12 +90,14 @@ def run(self):
config_model=config_model,
morphology_path=self.morphology_path,
)
with multiprocessing.Pool(maxtasksperchild=1) as pool:
for mtype, (params, data) in tqdm(
pool.imap_unordered(build_model, mtypes), total=len(mtypes)
):
models_params[mtype] = params
models_data[mtype] = data
for mtype, (params, data) in Parallel(self.nb_jobs)(
delayed(build_model)(
mtype
)
for mtype in tqdm(mtypes)
):
models_params[mtype] = params
models_data[mtype] = data
else:
morphologies = load_neurons(morphs_df[self.morphology_path].to_list())
models_params["all"], models_data["all"] = build_model(
Expand Down Expand Up @@ -181,14 +185,16 @@ def run(self):
)

exception_count = 0
with multiprocessing.Pool(maxtasksperchild=1) as pool:
for gid, new_path, exception in tqdm(
pool.imap_unordered(diametrizer, morphs_df.index), total=len(morphs_df)
):
morphs_df.loc[gid, self.new_morphology_path] = new_path
morphs_df.loc[gid, "exception"] = exception
if exception is not None:
exception_count += 1
for gid, new_path, exception in Parallel(self.nb_jobs)(
delayed(diametrizer)(
gid
)
for gid in tqdm(morphs_df.index)
):
morphs_df.loc[gid, self.new_morphology_path] = new_path
morphs_df.loc[gid, "exception"] = exception
if exception is not None:
exception_count += 1
L.info("Diametrization terminated, with %s exceptions.", exception_count)

update_morphs_df(self.morphs_df_path, morphs_df).to_csv(
Expand Down
24 changes: 12 additions & 12 deletions synthesis_workflow/synthesis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Functions for synthesis to be used by luigi tasks."""
import json
import logging
import multiprocessing
import os
import re
import sys
Expand Down Expand Up @@ -108,6 +107,7 @@ def build_distributions(
diameter_model_function,
morphology_path,
cortical_thickness,
nb_jobs=-1
):
"""Build tmd_distribution dictionary for synthesis.
Expand All @@ -129,17 +129,17 @@ def build_distributions(
morphology_path=morphology_path,
)

with multiprocessing.Pool(maxtasksperchild=1) as pool:
tmd_distributions = {
"mtypes": {
mtype: distribution
for mtype, distribution in tqdm( # pylint: disable=unnecessary-comprehension
pool.imap_unordered(build_distributions_single_mtype, mtypes),
total=len(mtypes),
)
},
"metadata": {"cortical_thickness": json.loads(cortical_thickness)},
}
tmd_distributions = {
"mtypes": {},
"metadata": {"cortical_thickness": json.loads(cortical_thickness)},
}
for mtype, distribution in Parallel(nb_jobs)(
delayed(build_distributions_single_mtype)(
mtype
)
for mtype in tqdm(mtypes)
):
tmd_distributions["mtypes"][mtype] = distribution
return tmd_distributions


Expand Down
114 changes: 60 additions & 54 deletions synthesis_workflow/validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Functions for validation of synthesis to be used by luigi tasks."""
import multiprocessing
import os
from collections import defaultdict
from functools import partial
Expand All @@ -10,6 +9,8 @@
import numpy as np
import pandas as pd
import seaborn as sns
from joblib import delayed
from joblib import Parallel
from matplotlib.backends.backend_pdf import PdfPages
from scipy.linalg import expm
from scipy.optimize import fmin
Expand All @@ -20,8 +21,8 @@
from atlas_analysis.planes.maths import Plane
from morph_validator.feature_configs import get_feature_configs
from morph_validator.plotting import get_features_df, plot_violin_features
from morph_validator.spatial import (relative_depth_volume,
sample_morph_voxel_values)
from morph_validator.spatial import relative_depth_volume
from morph_validator.spatial import sample_morph_voxel_values
from neurom import viewer

from .circuit_slicing import get_cells_between_planes
Expand Down Expand Up @@ -166,30 +167,32 @@ def _plot_density_profile(
return fig


def plot_density_profiles(circuit, sample, region, sample_distance, output_path):
def plot_density_profiles(circuit, sample, region, sample_distance, output_path, nb_jobs=-1):
"""Plot density profiles for all mtypes.
WIP function, waiting on complete atlas to update.
"""
voxeldata = relative_depth_volume(circuit.atlas, in_region=region, relative=False)
x_pos = 0
with multiprocessing.Pool() as pool:
figures = pool.imap(
partial(
_plot_density_profile,
circuit=circuit,
x_pos=x_pos,
sample=sample,
voxeldata=voxeldata,
sample_distance=sample_distance,
),
sorted(circuit.cells.mtypes),

ensure_dir(output_path)
with PdfPages(output_path) as pdf:
f = partial(
_plot_density_profile,
circuit=circuit,
x_pos=x_pos,
sample=sample,
voxeldata=voxeldata,
sample_distance=sample_distance,
)
ensure_dir(output_path)
with PdfPages(output_path) as pdf:
for fig in list(figures):
pdf.savefig(fig, bbox_inches="tight")
plt.close(fig)
for fig in Parallel(nb_jobs)(
delayed(f)(
mtype
)
for mtype in sorted(circuit.cells.mtypes)
):
pdf.savefig(fig, bbox_inches="tight")
plt.close(fig)


def _plot_cells(circuit, mtype, sample, ax):
Expand Down Expand Up @@ -219,7 +222,7 @@ def _plot_collage_O1(
return fig


def plot_collage_O1(circuit, sample, output_path):
def plot_collage_O1(circuit, sample, output_path, mtypes=None, nb_jobs=-1):
"""Plot collage for all mtypes.
Args:
Expand Down Expand Up @@ -248,23 +251,24 @@ def plot_collage_O1(circuit, sample, output_path):
if mtypes is None:
mtypes = sorted(list(circuit.cells.mtypes))

with multiprocessing.Pool() as pool:
figures = pool.imap(
partial(
_plot_collage_O1,
circuit=circuit,
figsize=figsize,
x_pos=x_pos,
ax_limit=ax_limit,
sample=sample,
),
mtypes,
ensure_dir(output_path)
with PdfPages(output_path) as pdf:
f = partial(
_plot_collage_O1,
circuit=circuit,
figsize=figsize,
x_pos=x_pos,
ax_limit=ax_limit,
sample=sample,
)
ensure_dir(output_path)
with PdfPages(output_path) as pdf:
for fig in tqdm(figures, total=len(mtypes)):
pdf.savefig(fig, bbox_inches="tight", dpi=100)
plt.close(fig)
for fig in Parallel(nb_jobs)(
delayed(f)(
mtype
)
for mtype in sorted(circuit.cells.mtypes)
):
pdf.savefig(fig, bbox_inches="tight", dpi=100)
plt.close(fig)


def get_aligned_basis(plane, target=[0, 0, 1]):
Expand Down Expand Up @@ -384,24 +388,26 @@ def _plot_collage(


def plot_collage(
circuit, planes, layer_annotation, mtype, output_path="collage.pdf", sample=10
circuit, planes, layer_annotation, mtype, output_path="collage.pdf", sample=10, nb_jobs=-1
):
"""Plot collage of an mtype and a list of planes."""
plane_ids = np.arange(int(len(planes) / 2) - 1)
with multiprocessing.Pool() as pool:
figures = pool.imap(
partial(
_plot_collage,
planes=planes,
layer_annotation=layer_annotation,
circuit=circuit,
mtype=mtype,
sample=sample,
),
plane_ids,

ensure_dir(output_path)
with PdfPages(output_path) as pdf:
f = partial(
_plot_collage,
planes=planes,
layer_annotation=layer_annotation,
circuit=circuit,
mtype=mtype,
sample=sample,
)
ensure_dir(output_path)
with PdfPages(output_path) as pdf:
for fig in tqdm(figures, total=len(plane_ids)):
pdf.savefig(fig, bbox_inches="tight", dpi=100)
plt.close(fig)
for fig in Parallel(nb_jobs)(
delayed(f)(
plane_id
)
for plane_id in plane_ids
):
pdf.savefig(fig, bbox_inches="tight", dpi=100)
plt.close(fig)

0 comments on commit cc6c792

Please sign in to comment.