Skip to content

Commit

Permalink
add doctests for example 2/n segmentation (#5083)
Browse files Browse the repository at this point in the history
* draft

* fix

* drop folder

Co-authored-by: chaton <thomas@grid.ai>
  • Loading branch information
Borda and tchaton committed Jan 5, 2021
1 parent 12d6437 commit 2438d74
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 13 deletions.
36 changes: 36 additions & 0 deletions pl_examples/domain_templates/semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)


def _create_synth_kitti_dataset(path_dir: str, image_dims: tuple = (1024, 512)):
"""Create synthetic dataset with random images, just to simulate that the dataset have been already downloaded."""
path_dir_images = os.path.join(path_dir, KITTI.IMAGE_PATH)
path_dir_masks = os.path.join(path_dir, KITTI.MASK_PATH)
for p_dir in (path_dir_images, path_dir_masks):
os.makedirs(p_dir, exist_ok=True)
for i in range(3):
path_img = os.path.join(path_dir_images, f'dummy_kitti_{i}.png')
Image.new('RGB', image_dims).save(path_img)
path_mask = os.path.join(path_dir_masks, f'dummy_kitti_{i}.png')
Image.new('L', image_dims).save(path_mask)


class KITTI(Dataset):
"""
Class for KITTI Semantic Segmentation Benchmark dataset
Expand All @@ -53,6 +66,12 @@ class KITTI(Dataset):
In the `get_item` function, images and masks are resized to the given `img_size`, masks are
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
(mask does not usually require transforms, but they can be implemented in a similar way).
>>> from pl_examples import DATASETS_PATH
>>> dataset_path = os.path.join(DATASETS_PATH, "Kitti")
>>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512))
>>> KITTI(dataset_path, 'train') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<...semantic_segmentation.KITTI object at ...>
"""
IMAGE_PATH = os.path.join('training', 'image_2')
MASK_PATH = os.path.join('training', 'semantic')
Expand Down Expand Up @@ -141,6 +160,23 @@ class SegModel(pl.LightningModule):
It uses the FCN ResNet50 model as an example.
Adam optimizer is used along with Cosine Annealing learning rate scheduler.
>>> from pl_examples import DATASETS_PATH
>>> dataset_path = os.path.join(DATASETS_PATH, "Kitti")
>>> _create_synth_kitti_dataset(dataset_path, image_dims=(1024, 512))
>>> SegModel(dataset_path) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
SegModel(
(net): UNet(
(layers): ModuleList(
(0): DoubleConv(...)
(1): Down(...)
(2): Down(...)
(3): Up(...)
(4): Up(...)
(5): Conv2d(64, 19, kernel_size=(1, 1), stride=(1, 1))
)
)
)
"""
def __init__(
self,
Expand Down
13 changes: 0 additions & 13 deletions pl_examples/pytorch_ecosystem/__init__.py

This file was deleted.

0 comments on commit 2438d74

Please sign in to comment.