Skip to content

Commit

Permalink
Merge 1cb4748 into 103e9cf
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink authored Sep 10, 2024
2 parents 103e9cf + 1cb4748 commit 38a1b7e
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions dacapo/experiments/datasplits/datasplit_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import zarr
from zarr.n5 import N5FSStore
import numpy as np
from dacapo.experiments.datasplits.datasets.arrays import (
ZarrArrayConfig,
ZarrArray,
Expand All @@ -15,11 +16,13 @@
ConcatArrayConfig,
LogicalOrArrayConfig,
ConstantArrayConfig,
CropArrayConfig,
)
from dacapo.experiments.datasplits import TrainValidateDataSplitConfig
from dacapo.experiments.datasplits.datasets import RawGTDatasetConfig
import logging


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -92,6 +95,37 @@ def resize_if_needed(
else:
return array_config

def limit_validation_crop_size(gt_config, mask_config, max_size):
gt_array = gt_config.array_type(gt_config)
voxel_shape = gt_array.roi.shape / gt_array.voxel_size
crop = False
while np.prod(voxel_shape) > max_size:
crop = True
max_idx = np.argmax(voxel_shape)
voxel_shape = Coordinate(
s if i != max_idx else s // 2 for i, s in enumerate(voxel_shape)
)
if crop:
crop_roi_shape = voxel_shape * gt_array.voxel_size
context = (gt_array.roi.shape - crop_roi_shape) / 2
crop_roi = gt_array.roi.grow(-context, -context)
crop_roi = crop_roi.snap_to_grid(gt_array.voxel_size, mode="shrink")

logger.debug(
f"Cropped {gt_config.name}: original roi: {gt_array.roi}, new_roi: {crop_roi}"
)

gt_config = CropArrayConfig(
name=gt_config.name + "_cropped",
source_array_config=gt_config,
roi=crop_roi,
)
mask_config = CropArrayConfig(
name=mask_config.name + "_cropped",
source_array_config=gt_config,
roi=crop_roi,
)
return gt_config, mask_config

def get_right_resolution_array_config(
container: Path, dataset, target_resolution, extra_str=""
Expand Down Expand Up @@ -441,6 +475,10 @@ class DataSplitGenerator:
The maximum raw value.
classes_separator_caracter : str
The classes separator character.
max_validation_volume_size : int
The maximum validation volume size. Default is None. If None, the validation volume size is not limited.
else, the validation volume size is limited to the specified value.
e.g. 600**3 for 600^3 voxels = 216_000_000 voxels.
Methods:
__init__(name, datasets, input_resolution, output_resolution, targets, segmentation_type, max_gt_downsample, max_gt_upsample, max_raw_training_downsample, max_raw_training_upsample, max_raw_validation_downsample, max_raw_validation_upsample, min_training_volume_size, raw_min, raw_max, classes_separator_caracter)
Initializes the DataSplitGenerator class with the specified name, datasets, input resolution, output resolution, targets, segmentation type, maximum ground truth downsample, maximum ground truth upsample, maximum raw training downsample, maximum raw training upsample, maximum raw validation downsample, maximum raw validation upsample, minimum training volume size, minimum raw value, maximum raw value, and classes separator character.
Expand Down Expand Up @@ -484,6 +522,7 @@ def __init__(
raw_max=255,
classes_separator_caracter="&",
use_negative_class=False,
max_validation_volume_size=None,
):
"""
Initializes the DataSplitGenerator class with the specified:
Expand Down Expand Up @@ -573,6 +612,7 @@ def __init__(
self.raw_max = raw_max
self.classes_separator_caracter = classes_separator_caracter
self.use_negative_class = use_negative_class
self.max_validation_volume_size = max_validation_volume_size
if use_negative_class:
if targets is None:
raise ValueError(
Expand Down Expand Up @@ -749,6 +789,10 @@ def __generate_semantic_seg_datasplit(self):
)
)
else:
if self.max_validation_volume_size is not None:
gt_config, mask_config = limit_validation_crop_size(
gt_config, mask_config, self.max_validation_volume_size
)
validation_dataset_configs.append(
RawGTDatasetConfig(
name=f"{dataset}_{gt_config.name}_{classes}_{self.output_resolution[0]}nm",
Expand Down

0 comments on commit 38a1b7e

Please sign in to comment.