Skip to content

Commit

Permalink
better datasplit generator naming
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink committed Jul 23, 2024
1 parent fb1bd5d commit 926bb2d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(self, array_config):
`like` method to create a new OnesArray with the same metadata as
another array.
"""
logger.warning("OnesArray is deprecated. Use ConstantArray instead.")
self._source_array = array_config.source_array_config.array_type(
array_config.source_array_config
)
Expand Down Expand Up @@ -406,5 +407,4 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
specified by the region of interest. This method returns a subarray
of the array with all values set to 1.
"""
logger.warning("OnesArray is deprecated. Use ConstantArray instead.")
return np.ones_like(self.source_array.__getitem__(roi), dtype=bool)
20 changes: 13 additions & 7 deletions dacapo/experiments/datasplits/datasplit_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,11 @@ def class_name(self):
Notes:
This function is used to get the class name.
"""
if self._class_name is None:
if self.targets is None:
logger.warning("Both targets and class name are None.")
return None
self._class_name = self.targets
return self._class_name

# Goal is to force class_name to be set only once, so we have the same classes for all datasets
Expand Down Expand Up @@ -730,10 +735,14 @@ def __generate_semantic_seg_datasplit(self):
gt_config,
mask_config,
) = self.__generate_semantic_seg_dataset_crop(dataset)
if type(self.class_name) == list:
classes = self.classes_separator_caracter.join(self.class_name)
else:
classes = self.class_name
if dataset.dataset_type == DatasetType.train:
train_dataset_configs.append(
RawGTDatasetConfig(
name=f"{dataset}_{self.class_name}_{self.output_resolution[0]}nm",
name=f"{dataset}_{gt_config.name}_{classes}_{self.output_resolution[0]}nm",
raw_config=raw_config,
gt_config=gt_config,
mask_config=mask_config,
Expand All @@ -742,16 +751,13 @@ def __generate_semantic_seg_datasplit(self):
else:
validation_dataset_configs.append(
RawGTDatasetConfig(
name=f"{dataset}_{self.class_name}_{self.output_resolution[0]}nm",
name=f"{dataset}_{gt_config.name}_{classes}_{self.output_resolution[0]}nm",
raw_config=raw_config,
gt_config=gt_config,
mask_config=mask_config,
)
)
if type(self.class_name) == list:
classes = self.classes_separator_caracter.join(self.class_name)
else:
classes = self.class_name

return TrainValidateDataSplitConfig(
name=f"{self.name}_{self.segmentation_type}_{classes}_{self.output_resolution[0]}nm",
train_configs=train_dataset_configs,
Expand Down Expand Up @@ -815,7 +821,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
organelle_arrays = {}
# classes_datasets, classes = self.check_class_name(gt_dataset)
classes_datasets, classes = format_class_name(
gt_dataset, self.classes_separator_caracter
gt_dataset, self.classes_separator_caracter, self.targets
)
for current_class_dataset, current_class_name in zip(classes_datasets, classes):
if not (gt_path / current_class_dataset).exists():
Expand Down

0 comments on commit 926bb2d

Please sign in to comment.