Skip to content

Commit

Permalink
fix: refactor the brain extraction workflow [wip] [skip ci]
Browse files Browse the repository at this point in the history
This commit:

  - [x] Updates the nodes with pure python interfaces based on nibabel,
    minimizing the need for the new ``copy_header`` of ANTs' nipype
    interfaces.
  - [x] Reorganizes the workflow so that the Atropos refinement is
    completely self contained.

These are the first two steps to address nipreps/smriprep#125.
  • Loading branch information
oesteban committed Sep 16, 2020
1 parent 9c3eb81 commit 07792d8
Showing 1 changed file with 60 additions and 70 deletions.
130 changes: 60 additions & 70 deletions niworkflows/anat/ants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@
# nipype
from nipype.pipeline import engine as pe
from nipype.interfaces import utility as niu
from nipype.interfaces.fsl.maths import ApplyMask
from nipype.interfaces.ants import (
AI,
Atropos,
ImageMath,
MultiplyImages,
N4BiasFieldCorrection,
ResampleImageBySpacing,
ThresholdImage,
)

Expand All @@ -31,8 +29,9 @@
FixHeaderRegistration as Registration,
FixHeaderApplyTransforms as ApplyTransforms,
)
from ..interfaces.images import RegridToZooms
from ..interfaces.nibabel import ApplyMask, Binarize
from ..interfaces.utils import CopyXForm
from ..interfaces.nibabel import Binarize


ATROPOS_MODELS = {
Expand Down Expand Up @@ -203,14 +202,8 @@ def init_brain_extraction_wf(
name="outputnode",
)

copy_xform = pe.Node(
CopyXForm(fields=["out_file", "out_mask", "bias_corrected", "bias_image"]),
name="copy_xform",
run_without_submitting=True,
)

trunc = pe.MapNode(
ImageMath(operation="TruncateImageIntensity", op2="0.01 0.999 256"),
ImageMath(operation="TruncateImageIntensity", op2="0.01 0.999 256", copy_header=True),
name="truncate_images",
iterfield=["op1"],
)
Expand All @@ -229,20 +222,15 @@ def init_brain_extraction_wf(
iterfield=["input_image"],
)

res_tmpl = pe.Node(
ResampleImageBySpacing(out_spacing=(4, 4, 4), apply_smoothing=True),
name="res_tmpl",
)
res_tmpl.inputs.input_image = tpl_target_path
res_target = pe.Node(
ResampleImageBySpacing(out_spacing=(4, 4, 4), apply_smoothing=True),
name="res_target",
)
res_tmpl = pe.Node(RegridToZooms(in_file=tpl_target_path, zooms=(4, 4, 4), smooth=True),
name="res_tmpl")
res_target = pe.Node(RegridToZooms(zooms=(4, 4, 4), smooth=True), name="res_target")

lap_tmpl = pe.Node(ImageMath(operation="Laplacian", op2="1.5 1"), name="lap_tmpl")
lap_tmpl = pe.Node(ImageMath(operation="Laplacian", op2="1.5 1", copy_header=True),
name="lap_tmpl")
lap_tmpl.inputs.op1 = tpl_target_path
lap_target = pe.Node(
ImageMath(operation="Laplacian", op2="1.5 1"), name="lap_target"
ImageMath(operation="Laplacian", op2="1.5 1", copy_header=True), name="lap_target"
)
mrg_tmpl = pe.Node(niu.Merge(2), name="mrg_tmpl")
mrg_tmpl.inputs.in1 = tpl_target_path
Expand Down Expand Up @@ -287,26 +275,20 @@ def init_brain_extraction_wf(
fixed_mask_trait += "s"

map_brainmask = pe.Node(
ApplyTransforms(interpolation="Gaussian", float=True),
ApplyTransforms(interpolation="Gaussian"),
name="map_brainmask",
mem_gb=1,
)
map_brainmask.inputs.input_image = str(tpl_mask_path)

thr_brainmask = pe.Node(
ThresholdImage(
dimension=3, th_low=0.5, th_high=1.0, inside_value=1, outside_value=0
dimension=3, th_low=0.5, th_high=1.0, inside_value=1, outside_value=0,
copy_header=True,
),
name="thr_brainmask",
)

# Morphological dilation, radius=2
dil_brainmask = pe.Node(ImageMath(operation="MD", op2="2"), name="dil_brainmask")
# Get largest connected component
get_brainmask = pe.Node(
ImageMath(operation="GetLargestComponent"), name="get_brainmask"
)

# Refine INU correction
inu_n4_final = pe.MapNode(
N4BiasFieldCorrection(
Expand Down Expand Up @@ -340,47 +322,37 @@ def init_brain_extraction_wf(
# fmt: off
wf.connect([
(inputnode, trunc, [("in_files", "op1")]),
(inputnode, copy_xform, [(("in_files", _pop), "hdr_file")]),
(inputnode, inu_n4_final, [("in_files", "input_image")]),
(inputnode, init_aff, [("in_mask", "fixed_image_mask")]),
(inputnode, norm, [("in_mask", fixed_mask_trait)]),
(inputnode, map_brainmask, [(("in_files", _pop), "reference_image")]),
(trunc, inu_n4, [("output_image", "input_image")]),
(inu_n4, res_target, [(("output_image", _pop), "input_image")]),
(res_tmpl, init_aff, [("output_image", "fixed_image")]),
(res_target, init_aff, [("output_image", "moving_image")]),
(inu_n4, res_target, [(("output_image", _pop), "in_file")]),
(res_tmpl, init_aff, [("out_file", "fixed_image")]),
(res_target, init_aff, [("out_file", "moving_image")]),
(init_aff, norm, [("output_transform", "initial_moving_transform")]),
(norm, map_brainmask, [
("reverse_transforms", "transforms"),
("reverse_invert_flags", "invert_transform_flags"),
]),
(map_brainmask, thr_brainmask, [("output_image", "input_image")]),
(thr_brainmask, dil_brainmask, [("output_image", "op1")]),
(dil_brainmask, get_brainmask, [("output_image", "op1")]),
(map_brainmask, inu_n4_final, [("output_image", "weight_image")]),
(inu_n4_final, apply_mask, [("output_image", "in_file")]),
(get_brainmask, apply_mask, [("output_image", "mask_file")]),
(get_brainmask, copy_xform, [("output_image", "out_mask")]),
(apply_mask, copy_xform, [("out_file", "out_file")]),
(inu_n4_final, copy_xform, [
("output_image", "bias_corrected"),
("bias_image", "bias_image"),
]),
(copy_xform, outputnode, [
("out_file", "out_file"),
("out_mask", "out_mask"),
("bias_corrected", "bias_corrected"),
("bias_image", "bias_image"),
]),
(thr_brainmask, apply_mask, [("output_image", "in_mask")]),
(thr_brainmask, outputnode, [("output_image", "out_mask")]),
(inu_n4_final, outputnode, [("output_image", "bias_corrected"),
("bias_image", "bias_image")]),
(apply_mask, outputnode, [("out_file", "out_file")]),
])
# fmt: on

if use_laplacian:
lap_tmpl = pe.Node(
ImageMath(operation="Laplacian", op2="1.5 1"), name="lap_tmpl"
ImageMath(operation="Laplacian", op2="1.5 1", copy_header=True), name="lap_tmpl"
)
lap_tmpl.inputs.op1 = tpl_target_path
lap_target = pe.Node(
ImageMath(operation="Laplacian", op2="1.5 1"), name="lap_target"
ImageMath(operation="Laplacian", op2="1.5 1", copy_header=True), name="lap_target"
)
mrg_tmpl = pe.Node(niu.Merge(2), name="mrg_tmpl")
mrg_tmpl.inputs.in1 = tpl_target_path
Expand Down Expand Up @@ -412,27 +384,21 @@ def init_brain_extraction_wf(
mem_gb=mem_gb,
in_segmentation_model=atropos_model,
)
sel_wm = pe.Node(
niu.Select(index=atropos_model[-1] - 1),
name="sel_wm",
run_without_submitting=True,
)

# fmt: off
wf.disconnect([
(get_brainmask, apply_mask, [("output_image", "mask_file")]),
(copy_xform, outputnode, [("out_mask", "out_mask")]),
(thr_brainmask, outputnode, [("output_image", "out_mask")]),
(inu_n4_final, outputnode, [("output_image", "bias_corrected"),
("bias_image", "bias_image")]),
(apply_mask, outputnode, [("out_file", "out_file")]),
])
wf.connect([
(inu_n4, atropos_wf, [("output_image", "inputnode.in_files")]),
(inu_n4_final, atropos_wf, [("output_image", "inputnode.in_files")]),
(thr_brainmask, atropos_wf, [("output_image", "inputnode.in_mask")]),
(get_brainmask, atropos_wf, [
("output_image", "inputnode.in_mask_dilated"),
]),
(atropos_wf, sel_wm, [("outputnode.out_tpms", "inlist")]),
(sel_wm, inu_n4_final, [("out", "weight_image")]),
(atropos_wf, apply_mask, [("outputnode.out_mask", "mask_file")]),
(atropos_wf, outputnode, [
("outputnode.out_file", "out_file"),
("outputnode.bias_corrected", "bias_corrected"),
("outputnode.out_mask", "bias_image"),
("outputnode.out_mask", "out_mask"),
("outputnode.out_segm", "out_segm"),
("outputnode.out_tpms", "out_tpms"),
Expand Down Expand Up @@ -515,12 +481,13 @@ def init_atropos_wf(
wf = pe.Workflow(name)

inputnode = pe.Node(
niu.IdentityInterface(fields=["in_files", "in_mask", "in_mask_dilated"]),
niu.IdentityInterface(fields=["in_files", "in_mask"]),
name="inputnode",
)
outputnode = pe.Node(
niu.IdentityInterface(fields=["out_mask", "out_segm", "out_tpms"]),
name="outputnode",
niu.IdentityInterface(fields=[
"out_file", "bias_corrected", "bias_image", "out_mask", "out_segm", "out_tpms"
]), name="outputnode",
)

copy_xform = pe.Node(
Expand All @@ -529,6 +496,22 @@ def init_atropos_wf(
run_without_submitting=True,
)

# Morphological dilation, radius=2
dil_brainmask = pe.Node(ImageMath(operation="MD", op2="2", copy_header=True),
name="dil_brainmask")
# Get largest connected component
get_brainmask = pe.Node(
ImageMath(operation="GetLargestComponent", copy_header=True), name="get_brainmask"
)

# (thr_brainmask, dil_brainmask, [("output_image", "op1")]),
# (dil_brainmask, get_brainmask, [("output_image", "op1")]),
# (get_brainmask, apply_mask, [("output_image", "in_mask")]),
# (get_brainmask, outputnode, [("output_image", "out_mask")]),
# (get_brainmask, atropos_wf, [
# ("output_image", "inputnode.in_mask_dilated"),
# ]),

# Run atropos (core node)
atropos = pe.Node(
Atropos(
Expand All @@ -549,10 +532,10 @@ def init_atropos_wf(

# massage outputs
pad_segm = pe.Node(
ImageMath(operation="PadImage", op2="%d" % padding), name="02_pad_segm"
ImageMath(operation="PadImage", op2=f"{padding}", copy_header=False), name="02_pad_segm"
)
pad_mask = pe.Node(
ImageMath(operation="PadImage", op2="%d" % padding), name="03_pad_mask"
ImageMath(operation="PadImage", op2=f"{padding}", copy_header=False), name="03_pad_mask"
)

# Split segmentation in binary masks
Expand Down Expand Up @@ -649,6 +632,13 @@ def init_atropos_wf(

msk_conform = pe.Node(niu.Function(function=_conform_mask), name="msk_conform")
merge_tpms = pe.Node(niu.Merge(in_segmentation_model[0]), name="merge_tpms")

sel_wm = pe.Node(
niu.Select(index=atropos_model[-1] - 1),
name="sel_wm",
run_without_submitting=True,
)

# fmt: off
wf.connect([
(inputnode, copy_xform, [(("in_files", _pop), "hdr_file")]),
Expand Down

0 comments on commit 07792d8

Please sign in to comment.