diff --git a/dicom_numpy/combine_slices.py b/dicom_numpy/combine_slices.py index c484b59..5270ade 100644 --- a/dicom_numpy/combine_slices.py +++ b/dicom_numpy/combine_slices.py @@ -75,27 +75,27 @@ def combine_slices(slice_datasets, rescale=None): def _merge_slice_pixel_arrays(slice_datasets, rescale=None): - first_dataset = slice_datasets[0] - num_rows = first_dataset.Rows - num_columns = first_dataset.Columns - num_slices = len(slice_datasets) - sorted_slice_datasets = _sort_by_slice_position(slice_datasets) if rescale is None: rescale = any(_requires_rescaling(d) for d in sorted_slice_datasets) - if rescale: - voxels = np.empty((num_columns, num_rows, num_slices), dtype=np.float32, order='F') - for k, dataset in enumerate(sorted_slice_datasets): + first_dataset = sorted_slice_datasets[0] + slice_dtype = first_dataset.pixel_array.dtype + slice_shape = first_dataset.pixel_array.T.shape + num_slices = len(sorted_slice_datasets) + + voxels_shape = slice_shape + (num_slices,) + voxels_dtype = np.float32 if rescale else slice_dtype + voxels = np.empty(voxels_shape, dtype=voxels_dtype, order='F') + + for k, dataset in enumerate(sorted_slice_datasets): + pixel_array = dataset.pixel_array.T + if rescale: slope = float(getattr(dataset, 'RescaleSlope', 1)) intercept = float(getattr(dataset, 'RescaleIntercept', 0)) - voxels[:, :, k] = dataset.pixel_array.T.astype(np.float32) * slope + intercept - else: - dtype = first_dataset.pixel_array.dtype - voxels = np.empty((num_columns, num_rows, num_slices), dtype=dtype, order='F') - for k, dataset in enumerate(sorted_slice_datasets): - voxels[:, :, k] = dataset.pixel_array.T + pixel_array = pixel_array.astype(np.float32) * slope + intercept + voxels[..., k] = pixel_array return voxels @@ -136,6 +136,7 @@ def _validate_slices_form_uniform_grid(slice_datasets): 'SeriesInstanceUID', 'Rows', 'Columns', + 'SamplesPerPixel', 'PixelSpacing', 'PixelRepresentation', 'BitsAllocated', diff --git a/tests/conftest.py b/tests/conftest.py index 35d2a70..f96c5a0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ negative_z_cos = (0, 0, -1) arbitrary_shape = (10, 11) +arbitrary_rgb_shape = (10, 11, 3) class MockSlice: @@ -28,15 +29,21 @@ def __init__(self, pixel_array, slice_position, row_cosine=None, column_cosine=N if column_cosine is None: column_cosine = y_cos - na, nb = pixel_array.shape + shape = pixel_array.shape + if len(shape) == 2: + na, nb = shape + SamplesPerPixel = 1 + else: + na, nb, SamplesPerPixel = shape self.pixel_array = pixel_array self.SeriesInstanceUID = 'arbitrary uid' self.SOPClassUID = 'arbitrary sopclass uid' self.PixelSpacing = [1.0, 1.0] - self.Rows = na - self.Columns = nb + self.Columns = na + self.Rows = nb + self.SamplesPerPixel = SamplesPerPixel self.Modality = 'MR' # assume that the images are centered on the remaining unused axis @@ -63,5 +70,16 @@ def axial_slices(): ] +@pytest.fixture +def axial_rgb_slices(): + # SamplesPerPixel = 3 + return [ + MockSlice(randi(*arbitrary_rgb_shape), 0), + MockSlice(randi(*arbitrary_rgb_shape), 1), + MockSlice(randi(*arbitrary_rgb_shape), 2), + MockSlice(randi(*arbitrary_rgb_shape), 3), + ] + + def randi(*shape): return np.random.randint(1000, size=shape, dtype='uint16') diff --git a/tests/test_combine_slices.py b/tests/test_combine_slices.py index a82e90f..e572b14 100644 --- a/tests/test_combine_slices.py +++ b/tests/test_combine_slices.py @@ -30,6 +30,12 @@ def test_single_slice_spacing(self, axial_slices): assert np.array_equal(array, dataset.pixel_array.T[:, :, None]) assert np.isclose(np.linalg.norm(affine[:, 2]), np.abs(slice_spacing)) + def test_rgb_axial_set(self, axial_rgb_slices): + combined, _ = combine_slices(axial_rgb_slices) + + manually_combined = np.stack([ds.pixel_array for ds in axial_rgb_slices], axis=0).T + assert np.array_equal(combined, manually_combined) + class TestMergeSlicePixelArrays: def test_casting_if_only_rescale_slope(self):