Skip to content

Commit

Permalink
Add piping to pass arguments to load_data methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexzwanenburg committed Apr 11, 2024
1 parent 5c753ec commit 72efd75
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
2 changes: 1 addition & 1 deletion mirp/_data_import/generic_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def update_image_data(self):

def to_object(self, **kwargs) -> GenericImage:

self.load_data()
self.load_data(**kwargs)
self.complete()
self.stack_slices()
self.update_image_data()
Expand Down
12 changes: 7 additions & 5 deletions mirp/_data_import/read_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

def read_image(
image: ImageFile,
to_numpy=False
to_numpy=False,
**kwargs
) -> np.ndarray | GenericImage:
image = image.to_object().promote()
image = image.to_object(**kwargs).promote()

if to_numpy:
image = image.get_voxel_grid()
Expand All @@ -20,22 +21,23 @@ def read_image(

def read_image_and_masks(
image: ImageFile,
to_numpy=False
to_numpy=False,
**kwargs
) -> tuple[np.ndarray | GenericImage, list[np.ndarray] | list[BaseMask]]:
mask_list = []
if image.associated_masks is not None:
mask_list = image.associated_masks

# Read masks from file.
if mask_list is not None:
mask_list = [mask.to_object(image=image) for mask in mask_list]
mask_list = [mask.to_object(image=image, **kwargs) for mask in mask_list]
mask_list = flatten_list(mask_list)

# Remove None entries.
mask_list = [mask for mask in mask_list if mask is not None]

# Read image from file.
image = image.to_object().promote()
image = image.to_object(**kwargs).promote()

if to_numpy:
image = image.get_voxel_grid()
Expand Down
5 changes: 4 additions & 1 deletion mirp/_workflows/standardWorkflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ def standard_image_processing(self) -> tuple[GenericImage, list[BaseMask]]:
logging.info(self._message_start())

# Read image and masks.
image, masks = read_image_and_masks(self.image_file, to_numpy=False)
image, masks = read_image_and_masks(
self.image_file,
to_numpy=False
)

if masks is None or len(masks) == 0:
warnings.warn("No segmentation masks were read.")
Expand Down

0 comments on commit 72efd75

Please sign in to comment.