Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CellPose dataset #272

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions scripts/datasets/light_microscopy/check_cellpose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from torch_em.util.debug import check_loader
from torch_em.data.datasets.light_microscopy import get_cellpose_loader


ROOT = "/media/anwai/ANWAI/data/cellpose/"


def check_cellpose():
loader = get_cellpose_loader(
path=ROOT,
split="train",
patch_shape=(512, 512),
batch_size=1,
choice="cyto",
)
check_loader(loader, 8, instance_labels=True)


if __name__ == "__main__":
check_cellpose()
1 change: 1 addition & 0 deletions torch_em/data/datasets/light_microscopy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .cellpose import get_cellpose_loader, get_cellpose_dataset
from .cellseg_3d import get_cellseg_3d_loader, get_cellseg_3d_dataset
from .covid_if import get_covid_if_loader, get_covid_if_dataset
from .ctc import get_ctc_segmentation_loader, get_ctc_segmentation_dataset
Expand Down
118 changes: 118 additions & 0 deletions torch_em/data/datasets/light_microscopy/cellpose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""This dataset contains annotation for cell segmentation in fluorescene microscently-labeled microscopy images.

This dataset is from the publication https://doi.org/10.1038/s41592-020-01018-x.
Please cite it if you use this dataset in your research.
"""


import os
from glob import glob
from natsort import natsorted
from typing import Union, Tuple

import torch_em

from .. import util
from .neurips_cell_seg import to_rgb
from ... import ImageCollectionDataset


URL = "https://www.cellpose.org/dataset"


def _get_cellpose_paths(path, split, choice):
if choice == "cyto":
assert split in ["train", "test"], f"'{split}' is not a valid split in '{choice}'."
elif choice == "cyto2":
assert split == "train", f"'{split}' is not a valid split in '{choice}'."
else:
raise ValueError(f"'{choice}' is not a valid dataset choice.")

image_paths = natsorted(glob(os.path.join(path, choice, split, "*_img.png")))
gt_paths = natsorted(glob(os.path.join(path, choice, split, "*_masks.png")))

return image_paths, gt_paths


def get_cellpose_dataset(
path: Union[os.PathLike, str],
split: str,
patch_shape: Tuple[int, int],
choice: str = "cyto",
download: bool = False,
**kwargs
):
"""Get the CellPose dataset for cell segmentation.

Args:
TODO

Returns:
The segmentation dataset.
"""
assert choice in ["cyto", "cyto2"]
assert split in ["train", "test"]

if download:
assert NotImplementedError(
"The dataset cannot be automatically downloaded. ",
"Please see 'get_cellpose_dataset' in 'torch_em/data/datasets/cellpose.py' for details."
)

image_paths, gt_paths = _get_cellpose_paths(path=path, split=split, choice=choice)

if "raw_transform" not in kwargs:
raw_transform = torch_em.transform.get_raw_transform(augmentation2=to_rgb)

if "transform" not in kwargs:
transform = torch_em.transform.get_augmentations(ndim=2)

dataset = torch_em.default_segmentation_dataset(
raw_paths=image_paths,
raw_key=None,
label_paths=gt_paths,
label_key=None,
patch_shape=patch_shape,
raw_transform=raw_transform,
transform=transform,
**kwargs
)
dataset = ImageCollectionDataset(
raw_image_paths=image_paths,
label_image_paths=gt_paths,
patch_shape=patch_shape,
raw_transform=raw_transform,
transform=transform,
)

return dataset


def get_cellpose_loader(
path: Union[os.PathLike, str],
split: str,
patch_shape: Tuple[int, int],
batch_size: int,
choice: str = "cyto",
download: bool = False,
**kwargs
):
"""Get the CellPose dataloader for cell segmentation.

Args:
TODO

Returns:
The DataLoader.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_cellpose_dataset(
path=path,
split=split,
patch_shape=patch_shape,
choice=choice,
download=download,
**ds_kwargs
)
loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
return loader
Loading