Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Mathias Goncalves <goncalves.mathias@gmail.com>
  • Loading branch information
oesteban and mgxd committed Nov 26, 2020
1 parent c6200b8 commit c3647e7
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
18 changes: 9 additions & 9 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class _BSplineApproxInputSpec(BaseInterfaceInputSpec):
"mode",
"median",
"mean",
"no",
False,
usedefault=True,
desc="strategy to recenter the distribution of the input fieldmap",
)
Expand Down Expand Up @@ -143,14 +143,14 @@ def _run_interface(self, runtime):
)
hdr = fmapnii.header.copy()
hdr.set_data_dtype("float32")
nb.Nifti1Image(interp_data, fmapnii.affine, hdr).to_filename(out_name)
fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(out_name)
self._results["out_field"] = out_name

index = 0
self._results["out_coeff"] = []
for i, (n, bsl) in enumerate(zip(ncoeff, bs_levels)):
out_level = out_name.replace("_field.", f"_coeff{i:03}.")
nb.Nifti1Image(
bsl.__class__(
np.array(model.coef_, dtype="float32")[index:index + n].reshape(
bsl.shape
),
Expand All @@ -162,7 +162,7 @@ def _run_interface(self, runtime):

# Write out fitting-error map
self._results["out_error"] = out_name.replace("_field.", "_error.")
nb.Nifti1Image(
fmapnii.__class__(
data * mask - interp_data, fmapnii.affine, fmapnii.header
).to_filename(self._results["out_error"])

Expand All @@ -177,7 +177,7 @@ def _run_interface(self, runtime):
)
interp_data[~mask] = np.array(model.coef_) @ extrapolators # Extrapolation
self._results["out_extrapolated"] = out_name.replace("_field.", "_extra.")
nb.Nifti1Image(interp_data, fmapnii.affine, hdr).to_filename(
fmapnii.__class__(interp_data, fmapnii.affine, hdr).to_filename(
self._results["out_extrapolated"]
)
return runtime
Expand All @@ -190,7 +190,7 @@ class _Coefficients2WarpInputSpec(BaseInterfaceInputSpec):
mandatory=True,
desc="input coefficients, after alignment to the EPI data",
)
ro_time = traits.Float(1.0, usedefault=True, desc="EPI readout time (s).")
ro_time = traits.Float(mandatory=True, desc="EPI readout time (s).")
pe_dir = traits.Enum(
"i",
"i-",
Expand Down Expand Up @@ -260,7 +260,7 @@ def _run_interface(self, runtime):
self._results["out_field"] = fname_presuffix(
self.inputs.in_target, suffix="_field", newpath=runtime.cwd
)
nb.Nifti1Image(data, targetnii.affine, hdr).to_filename(
targetnii.__class__(data, targetnii.affine, hdr).to_filename(
self._results["out_field"]
)

Expand All @@ -280,7 +280,7 @@ def _run_interface(self, runtime):
aff = targetnii.affine.copy()
aff[:3, 3] = 0.0
field = nb.affines.apply_affine(aff, field).reshape(fieldshape)
warpnii = nb.Nifti1Image(
warpnii = targetnii.__class__(
field[:, :, :, np.newaxis, :].astype("float32"), targetnii.affine, None
)
warpnii.header.set_intent("vector", (), "")
Expand Down Expand Up @@ -314,7 +314,7 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
bs_affine, 0.5 * (bs_shape - 1)
)

return nb.Nifti1Image(np.zeros(bs_shape, dtype="float32"), bs_affine)
return img.__class__(np.zeros(bs_shape, dtype="float32"), bs_affine)


def bspline_weights(points, ctrl_nii):
Expand Down
2 changes: 1 addition & 1 deletion sdcflows/interfaces/reportlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _generate_report(self):
abs(np.percentile(fmapdata[maskdata], 0.2)))
if self.inputs.apply_mask:
fmapdata[~maskdata] = 0
fmapnii = nb.Nifti1Image(fmapdata, fmapnii.affine, fmapnii.header)
fmapnii = fmapnii.__class__(fmapdata, fmapnii.affine, fmapnii.header)

fmap_overlay = [{
'overlay': fmapnii,
Expand Down
3 changes: 2 additions & 1 deletion sdcflows/interfaces/tests/test_bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@ def test_bsplines(tmp_path, testnum):
in_target=str(tmp_path / "target.nii.gz"),
in_coeff=str(tmp_path / "coeffs.nii.gz"),
pe_dir="j-",
ro_time=1.0,
).run()

# Approximate the interpolated target
test2 = BSplineApprox(
in_data=test1.outputs.out_field,
in_mask=str(tmp_path / "target.nii.gz"),
bs_spacing=[(4, 6, 8)],
recenter="no",
recenter=False,
ridge_alpha=1e-4,
).run()

Expand Down

0 comments on commit c3647e7

Please sign in to comment.