Skip to content

Commit

Permalink
Specify resolution with sequences.
Browse files Browse the repository at this point in the history
  • Loading branch information
vaxenburg committed May 6, 2024
1 parent d88faf8 commit aaf46b6
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions dacapo/experiments/datasplits/datasplit_generator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from dacapo.experiments.tasks import TaskConfig
from pathlib import Path
from typing import List
from typing import List, Union, Optional, Sequence
from enum import Enum, EnumMeta
from funlib.geometry import Coordinate
from typing import Union, Optional

import zarr
from dacapo.experiments.datasplits.datasets.arrays import (
Expand Down Expand Up @@ -159,7 +158,10 @@ def generate_dataspec_from_csv(csv_path: Path):

class DataSplitGenerator:
"""Generates DataSplitConfig for a given task config and datasets.
class names in gt_dataset shoulb be within [] e.g. [mito&peroxisome&er] for mutiple classes or [mito] for one class
Class names in gt_dataset should be within [] e.g. [mito&peroxisome&er] for
multiple classes or [mito] for one class.
Currently only supports:
- semantic segmentation.
Supports:
Expand All @@ -172,8 +174,8 @@ def __init__(
self,
name: str,
datasets: List[DatasetSpec],
input_resolution: Coordinate,
output_resolution: Coordinate,
input_resolution: Union[Sequence[int], Coordinate],
output_resolution: Union[Sequence[int], Coordinate],
targets: Optional[List[str]] = None,
segmentation_type: Union[str, SegmentationType] = "semantic",
max_gt_downsample=32,
Expand All @@ -187,16 +189,19 @@ def __init__(
raw_max=255,
classes_separator_caracter="&",
):
if not isinstance(input_resolution, Coordinate):
input_resolution = Coordinate(input_resolution)
if not isinstance(output_resolution, Coordinate):
output_resolution = Coordinate(output_resolution)
if isinstance(segmentation_type, str):
segmentation_type = SegmentationType[segmentation_type.lower()]

self.name = name
self.datasets = datasets
self.input_resolution = input_resolution
self.output_resolution = output_resolution
self.targets = targets
self._class_name = None

if isinstance(segmentation_type, str):
segmentation_type = SegmentationType[segmentation_type.lower()]

self.segmentation_type = segmentation_type
self.max_gt_downsample = max_gt_downsample
self.max_gt_upsample = max_gt_upsample
Expand Down Expand Up @@ -369,8 +374,8 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
@staticmethod
def generate_from_csv(
csv_path: Path,
input_resolution: Coordinate,
output_resolution: Coordinate,
input_resolution: Union[Sequence[int], Coordinate],
output_resolution: Union[Sequence[int], Coordinate],
name: Optional[str] = None,
**kwargs,
):
Expand Down

0 comments on commit aaf46b6

Please sign in to comment.