diff --git a/README.md b/README.md index d8cf4776ae..8ce2d9d828 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,7 @@ Supported datasets: - [x] [LoveDA](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#loveda) - [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isprs-potsdam) - [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isprs-vaihingen) +- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#isaid) ## Installation diff --git a/README_zh-CN.md b/README_zh-CN.md index 12b69a3ba8..eceeebf0b3 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -137,6 +137,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O - [x] [LoveDA](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#loveda) - [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isprs-potsdam) - [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isprs-vaihingen) +- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/zh_cn/dataset_prepare.md#isaid) ## 安装 diff --git a/configs/_base_/datasets/isaid.py b/configs/_base_/datasets/isaid.py new file mode 100644 index 0000000000..8e4c26abb7 --- /dev/null +++ b/configs/_base_/datasets/isaid.py @@ -0,0 +1,62 @@ +# dataset settings +dataset_type = 'iSAIDDataset' +data_root = 'data/iSAID' + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +""" +This crop_size setting follows the implementation of +`PointFlow: Flowing Semantics Through Points for Aerial Image +Segmentation <https://arxiv.org/pdf/2103.06564.pdf>`_. 
+""" + +crop_size = (896, 896) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(896, 896), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(896, 896), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='img_dir/train', + ann_dir='ann_dir/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='img_dir/val', + ann_dir='ann_dir/val', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='img_dir/val', + ann_dir='ann_dir/val', + pipeline=test_pipeline)) diff --git a/configs/deeplabv3plus/README.md b/configs/deeplabv3plus/README.md index 91b66dd504..4fb7d13118 100644 --- a/configs/deeplabv3plus/README.md +++ b/configs/deeplabv3plus/README.md @@ -114,8 +114,16 @@ Spatial pyramid pooling module or encode-decoder structure are used in deep neur | DeepLabV3+ | R-50-D8 | 512x512 | 80000 | 7.36 | 26.91 | 73.97 | 75.05 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen.py) | 
[model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen_20211231_230816-5040938d.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r50-d8_4x4_512x512_80k_vaihingen_20211231_230816.log.json) | | DeepLabV3+ | R-101-D8 | 512x512 | 80000 | 10.83 | 18.59 | 73.06 | 74.14 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen_20211231_230816-8a095afa.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen_20211231_230816.log.json) | +### iSAID + +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | +| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| DeepLabV3+ | R-18-D8 | 896x896 | 80000 | 6.19 | 24.81 | 61.35 | 62.61 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid.py) | 
[model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid_20220110_180526-7059991d.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid_20220110_180526.log.json) | +| DeepLabV3+ | R-50-D8 | 896x896 | 80000 | 21.45 | 8.42 | 67.06 | 68.02 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526-598be439.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526.log.json) | + Note: - `D-8`/`D-16` here corresponding to the output stride 8/16 setting for DeepLab series. - `MG-124` stands for multi-grid dilation in the last stage of ResNet. - `FP16` means Mixed Precision (FP16) is adopted in training. 
+- `896x896` is the Crop Size of iSAID dataset, which is followed by the implementation of [PointFlow: Flowing Semantics Through Points for Aerial Image Segmentation](https://arxiv.org/pdf/2103.06564.pdf) diff --git a/configs/deeplabv3plus/deeplabv3plus.yml b/configs/deeplabv3plus/deeplabv3plus.yml index b68d7e90a5..a587216a49 100644 --- a/configs/deeplabv3plus/deeplabv3plus.yml +++ b/configs/deeplabv3plus/deeplabv3plus.yml @@ -10,6 +10,7 @@ Collections: - LoveDA - Potsdam - Vaihingen + - iSAID Paper: URL: https://arxiv.org/abs/1802.02611 Title: Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation @@ -803,3 +804,47 @@ Models: mIoU(ms+flip): 74.14 Config: configs/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen.py Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen/deeplabv3plus_r101-d8_4x4_512x512_80k_vaihingen_20211231_230816-8a095afa.pth +- Name: deeplabv3plus_r18-d8_4x4_896x896_80k_isaid + In Collection: deeplabv3plus + Metadata: + backbone: R-18-D8 + crop size: (896,896) + lr schd: 80000 + inference time (ms/im): + - value: 40.31 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (896,896) + Training Memory (GB): 6.19 + Results: + - Task: Semantic Segmentation + Dataset: iSAID + Metrics: + mIoU: 61.35 + mIoU(ms+flip): 62.61 + Config: configs/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid_20220110_180526-7059991d.pth +- Name: deeplabv3plus_r50-d8_4x4_896x896_80k_isaid + In Collection: deeplabv3plus + Metadata: + backbone: R-50-D8 + crop size: (896,896) + lr schd: 80000 + inference time (ms/im): + - value: 118.76 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (896,896) + Training Memory (GB): 21.45 + Results: + - Task: 
Semantic Segmentation + Dataset: iSAID + Metrics: + mIoU: 67.06 + mIoU(ms+flip): 68.02 + Config: configs/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526-598be439.pth diff --git a/configs/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid.py b/configs/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid.py new file mode 100644 index 0000000000..892a8a30e9 --- /dev/null +++ b/configs/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid.py @@ -0,0 +1,11 @@ +_base_ = './deeplabv3plus_r50-d8_4x4_896x896_80k_isaid.py' +model = dict( + pretrained='open-mmlab://resnet18_v1c', + backbone=dict(depth=18), + decode_head=dict( + c1_in_channels=64, + c1_channels=12, + in_channels=512, + channels=128, + ), + auxiliary_head=dict(in_channels=256, channels=64)) diff --git a/configs/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid.py b/configs/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid.py new file mode 100644 index 0000000000..a1a8beb82d --- /dev/null +++ b/configs/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid.py @@ -0,0 +1,6 @@ +_base_ = [ + '../_base_/models/deeplabv3plus_r50-d8.py', '../_base_/datasets/isaid.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' +] +model = dict( + decode_head=dict(num_classes=16), auxiliary_head=dict(num_classes=16)) diff --git a/configs/hrnet/README.md b/configs/hrnet/README.md index 885ec19b18..225a06f448 100644 --- a/configs/hrnet/README.md +++ b/configs/hrnet/README.md @@ -107,3 +107,15 @@ High-resolution representations are essential for position-sensitive vision prob | FCN | HRNetV2p-W18-Small | 512x512 | 80000 | 1.58 | 38.11 | 71.81 | 73.1 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18s_4x4_512x512_80k_vaihingen.py) | 
[model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_512x512_80k_vaihingen/fcn_hr18s_4x4_512x512_80k_vaihingen_20211231_230909-b23aae02.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_512x512_80k_vaihingen/fcn_hr18s_4x4_512x512_80k_vaihingen_20211231_230909.log.json) | | FCN | HRNetV2p-W18 | 512x512 | 80000 | 2.76 | 19.55 | 72.57 | 74.09 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_512x512_80k_vaihingen/fcn_hr18_4x4_512x512_80k_vaihingen_20211231_231216-2ec3ae8a.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_512x512_80k_vaihingen/fcn_hr18_4x4_512x512_80k_vaihingen_20211231_231216.log.json) | | FCN | HRNetV2p-W48 | 512x512 | 80000 | 6.20 | 17.25 | 72.50 | 73.52 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen/fcn_hr48_4x4_512x512_80k_vaihingen_20211231_231244-7133cb22.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen/fcn_hr48_4x4_512x512_80k_vaihingen_20211231_231244.log.json) | + +### iSAID + +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | +| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | -------------------------------------------------------------------------------------------------------------------------------------- | 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| FCN | HRNetV2p-W18-Small | 896x896 | 80000 | 4.95 | 13.84 | 62.30 | 62.97 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18s_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_896x896_80k_isaid/fcn_hr18s_4x4_896x896_80k_isaid_20220118_001603-3cc0769b.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_896x896_80k_isaid/fcn_hr18s_4x4_896x896_80k_isaid_20220118_001603.log.json) | +| FCN | HRNetV2p-W18 | 896x896 | 80000 | 8.30 | 7.71 | 65.06 | 65.60 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr18_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_896x896_80k_isaid/fcn_hr18_4x4_896x896_80k_isaid_20220110_182230-49bf752e.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_896x896_80k_isaid/fcn_hr18_4x4_896x896_80k_isaid_20220110_182230.log.json) | +| FCN | HRNetV2p-W48 | 896x896 | 80000 | 16.89 | 7.34 | 67.80 | 68.53 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/hrnet/fcn_hr48_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_896x896_80k_isaid/fcn_hr48_4x4_896x896_80k_isaid_20220114_174643-547fc420.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_896x896_80k_isaid/fcn_hr48_4x4_896x896_80k_isaid_20220114_174643.log.json) | + +Note: + +- `896x896` is the Crop Size of iSAID dataset, which is followed by the implementation of 
[PointFlow: Flowing Semantics Through Points for Aerial Image Segmentation](https://arxiv.org/pdf/2103.06564.pdf) diff --git a/configs/hrnet/fcn_hr18_4x4_896x896_80k_isaid.py b/configs/hrnet/fcn_hr18_4x4_896x896_80k_isaid.py new file mode 100644 index 0000000000..62e6d6bf0e --- /dev/null +++ b/configs/hrnet/fcn_hr18_4x4_896x896_80k_isaid.py @@ -0,0 +1,5 @@ +_base_ = [ + '../_base_/models/fcn_hr18.py', '../_base_/datasets/isaid.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' +] +model = dict(decode_head=dict(num_classes=16)) diff --git a/configs/hrnet/fcn_hr18s_4x4_896x896_80k_isaid.py b/configs/hrnet/fcn_hr18s_4x4_896x896_80k_isaid.py new file mode 100644 index 0000000000..d6f6c657a5 --- /dev/null +++ b/configs/hrnet/fcn_hr18s_4x4_896x896_80k_isaid.py @@ -0,0 +1,9 @@ +_base_ = './fcn_hr18_4x4_896x896_80k_isaid.py' +model = dict( + pretrained='open-mmlab://msra/hrnetv2_w18_small', + backbone=dict( + extra=dict( + stage1=dict(num_blocks=(2, )), + stage2=dict(num_blocks=(2, 2)), + stage3=dict(num_modules=3, num_blocks=(2, 2, 2)), + stage4=dict(num_modules=2, num_blocks=(2, 2, 2, 2))))) diff --git a/configs/hrnet/fcn_hr48_4x4_896x896_80k_isaid.py b/configs/hrnet/fcn_hr48_4x4_896x896_80k_isaid.py new file mode 100644 index 0000000000..55cf1b55bd --- /dev/null +++ b/configs/hrnet/fcn_hr48_4x4_896x896_80k_isaid.py @@ -0,0 +1,10 @@ +_base_ = './fcn_hr18_4x4_896x896_80k_isaid.py' +model = dict( + pretrained='open-mmlab://msra/hrnetv2_w48', + backbone=dict( + extra=dict( + stage2=dict(num_channels=(48, 96)), + stage3=dict(num_channels=(48, 96, 192)), + stage4=dict(num_channels=(48, 96, 192, 384)))), + decode_head=dict( + in_channels=[48, 96, 192, 384], channels=sum([48, 96, 192, 384]))) diff --git a/configs/hrnet/hrnet.yml b/configs/hrnet/hrnet.yml index 2854c15220..cd989dfbab 100644 --- a/configs/hrnet/hrnet.yml +++ b/configs/hrnet/hrnet.yml @@ -10,6 +10,7 @@ Collections: - LoveDA - Potsdam - Vaihingen + - iSAID Paper: URL: 
https://arxiv.org/abs/1908.07919 Title: Deep High-Resolution Representation Learning for Human Pose Estimation @@ -648,3 +649,69 @@ Models: mIoU(ms+flip): 73.52 Config: configs/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen.py Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_512x512_80k_vaihingen/fcn_hr48_4x4_512x512_80k_vaihingen_20211231_231244-7133cb22.pth +- Name: fcn_hr18s_4x4_896x896_80k_isaid + In Collection: hrnet + Metadata: + backbone: HRNetV2p-W18-Small + crop size: (896,896) + lr schd: 80000 + inference time (ms/im): + - value: 72.25 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (896,896) + Training Memory (GB): 4.95 + Results: + - Task: Semantic Segmentation + Dataset: iSAID + Metrics: + mIoU: 62.3 + mIoU(ms+flip): 62.97 + Config: configs/hrnet/fcn_hr18s_4x4_896x896_80k_isaid.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18s_4x4_896x896_80k_isaid/fcn_hr18s_4x4_896x896_80k_isaid_20220118_001603-3cc0769b.pth +- Name: fcn_hr18_4x4_896x896_80k_isaid + In Collection: hrnet + Metadata: + backbone: HRNetV2p-W18 + crop size: (896,896) + lr schd: 80000 + inference time (ms/im): + - value: 129.7 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (896,896) + Training Memory (GB): 8.3 + Results: + - Task: Semantic Segmentation + Dataset: iSAID + Metrics: + mIoU: 65.06 + mIoU(ms+flip): 65.6 + Config: configs/hrnet/fcn_hr18_4x4_896x896_80k_isaid.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr18_4x4_896x896_80k_isaid/fcn_hr18_4x4_896x896_80k_isaid_20220110_182230-49bf752e.pth +- Name: fcn_hr48_4x4_896x896_80k_isaid + In Collection: hrnet + Metadata: + backbone: HRNetV2p-W48 + crop size: (896,896) + lr schd: 80000 + inference time (ms/im): + - value: 136.24 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (896,896) + Training Memory (GB): 16.89 + Results: + - Task: Semantic Segmentation + 
Dataset: iSAID + Metrics: + mIoU: 67.8 + mIoU(ms+flip): 68.53 + Config: configs/hrnet/fcn_hr48_4x4_896x896_80k_isaid.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/hrnet/fcn_hr48_4x4_896x896_80k_isaid/fcn_hr48_4x4_896x896_80k_isaid_20220114_174643-547fc420.pth diff --git a/configs/pspnet/README.md b/configs/pspnet/README.md index ca8bddabb1..6223b5ea7a 100644 --- a/configs/pspnet/README.md +++ b/configs/pspnet/README.md @@ -148,6 +148,14 @@ We support evaluation results on these two datasets using models above trained o | PSPNet | R-50-D8 | 512x512 | 80000 | 6.14 | 30.29 | 72.36 | 73.75 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r50-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_512x512_80k_vaihingen/pspnet_r50-d8_4x4_512x512_80k_vaihingen_20211228_160355-382f8f5b.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_512x512_80k_vaihingen/pspnet_r50-d8_4x4_512x512_80k_vaihingen_20211228_160355.log.json) | | PSPNet | R-101-D8 | 512x512 | 80000 | 9.61 | 19.97 | 72.61 | 74.18 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen/pspnet_r101-d8_4x4_512x512_80k_vaihingen_20211231_230806-8eba0a09.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen/pspnet_r101-d8_4x4_512x512_80k_vaihingen_20211231_230806.log.json) | +### iSAID + +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | +| ---------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | 
-------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| PSPNet | R-18-D8 | 896x896 | 80000 | 4.52 | 26.91 | 60.22 | 61.25 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid/pspnet_r18-d8_4x4_896x896_80k_isaid_20220110_180526-e84c0b6a.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid/pspnet_r18-d8_4x4_896x896_80k_isaid_20220110_180526.log.json) | +| PSPNet | R-50-D8 | 896x896 | 80000 | 16.58 | 8.88 | 65.36 | 66.48 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid/pspnet_r50-d8_4x4_896x896_80k_isaid_20220110_180629-1f21dc32.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid/pspnet_r50-d8_4x4_896x896_80k_isaid_20220110_180629.log.json) | + Note: - `FP16` means Mixed Precision (FP16) is adopted in training. 
+- `896x896` is the Crop Size of iSAID dataset, which is followed by the implementation of [PointFlow: Flowing Semantics Through Points for Aerial Image Segmentation](https://arxiv.org/pdf/2103.06564.pdf) diff --git a/configs/pspnet/pspnet.yml b/configs/pspnet/pspnet.yml index a78f2c8753..087367bf21 100644 --- a/configs/pspnet/pspnet.yml +++ b/configs/pspnet/pspnet.yml @@ -13,6 +13,7 @@ Collections: - LoveDA - Potsdam - Vaihingen + - iSAID Paper: URL: https://arxiv.org/abs/1612.01105 Title: Pyramid Scene Parsing Network @@ -942,3 +943,47 @@ Models: mIoU(ms+flip): 74.18 Config: configs/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen.py Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_4x4_512x512_80k_vaihingen/pspnet_r101-d8_4x4_512x512_80k_vaihingen_20211231_230806-8eba0a09.pth +- Name: pspnet_r18-d8_4x4_896x896_80k_isaid + In Collection: pspnet + Metadata: + backbone: R-18-D8 + crop size: (896,896) + lr schd: 80000 + inference time (ms/im): + - value: 37.16 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (896,896) + Training Memory (GB): 4.52 + Results: + - Task: Semantic Segmentation + Dataset: iSAID + Metrics: + mIoU: 60.22 + mIoU(ms+flip): 61.25 + Config: configs/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid/pspnet_r18-d8_4x4_896x896_80k_isaid_20220110_180526-e84c0b6a.pth +- Name: pspnet_r50-d8_4x4_896x896_80k_isaid + In Collection: pspnet + Metadata: + backbone: R-50-D8 + crop size: (896,896) + lr schd: 80000 + inference time (ms/im): + - value: 112.61 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (896,896) + Training Memory (GB): 16.58 + Results: + - Task: Semantic Segmentation + Dataset: iSAID + Metrics: + mIoU: 65.36 + mIoU(ms+flip): 66.48 + Config: configs/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid.py + Weights: 
https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid/pspnet_r50-d8_4x4_896x896_80k_isaid_20220110_180629-1f21dc32.pth diff --git a/configs/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid.py b/configs/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid.py new file mode 100644 index 0000000000..4f6f9ab253 --- /dev/null +++ b/configs/pspnet/pspnet_r18-d8_4x4_896x896_80k_isaid.py @@ -0,0 +1,9 @@ +_base_ = './pspnet_r50-d8_4x4_896x896_80k_isaid.py' +model = dict( + pretrained='open-mmlab://resnet18_v1c', + backbone=dict(depth=18), + decode_head=dict( + in_channels=512, + channels=128, + ), + auxiliary_head=dict(in_channels=256, channels=64)) diff --git a/configs/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid.py b/configs/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid.py new file mode 100644 index 0000000000..ef7eb99280 --- /dev/null +++ b/configs/pspnet/pspnet_r50-d8_4x4_896x896_80k_isaid.py @@ -0,0 +1,6 @@ +_base_ = [ + '../_base_/models/pspnet_r50-d8.py', '../_base_/datasets/isaid.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' +] +model = dict( + decode_head=dict(num_classes=16), auxiliary_head=dict(num_classes=16)) diff --git a/docs/en/dataset_prepare.md b/docs/en/dataset_prepare.md index c115e86d56..d98b080d41 100644 --- a/docs/en/dataset_prepare.md +++ b/docs/en/dataset_prepare.md @@ -123,6 +123,21 @@ mmsegmentation │ │ ├── ann_dir │ │ │ ├── train │ │ │ ├── val +│ ├── vaihingen +│ │ ├── img_dir +│ │ │ ├── train +│ │ │ ├── val +│ │ ├── ann_dir +│ │ │ ├── train +│ │ │ ├── val +│ ├── iSAID +│ │ ├── img_dir +│ │ │ ├── train +│ │ │ ├── val +│ │ │ ├── test +│ │ ├── ann_dir +│ │ │ ├── train +│ │ │ ├── val ``` ### Cityscapes @@ -325,3 +340,38 @@ python tools/convert_datasets/vaihingen.py /path/to/vaihingen ``` In our default setting (`clip_size` =512, `stride_size`=256), it will generate 344 images for training and 398 images for validation. 
+ +### iSAID +The data images can be downloaded from [DOTA-v1.0](https://captain-whu.github.io/DOTA/dataset.html) (train/val/test) + +The data annotations can be downloaded from [iSAID](https://captain-whu.github.io/iSAID/dataset.html) (train/val) + +The dataset is a Large-scale Dataset for Instance Segmentation (it also provides semantic segmentation annotations) in Aerial Images. + +After downloading the iSAID dataset, you need to arrange the files in the following structure before running the conversion script. + +``` +│ ├── iSAID +│ │ ├── train +│ │ │ ├── images +│ │ │ │ ├── part1.zip +│ │ │ │ ├── part2.zip +│ │ │ │ ├── part3.zip +│ │ │ ├── Semantic_masks +│ │ │ │ ├── images.zip +│ │ ├── val +│ │ │ ├── images +│ │ │ │ ├── part1.zip +│ │ │ ├── Semantic_masks +│ │ │ │ ├── images.zip +│ │ ├── test +│ │ │ ├── images +│ │ │ │ ├── part1.zip +│ │ │ │ ├── part2.zip +``` + +```shell +python tools/convert_datasets/isaid.py /path/to/iSAID +``` + +In our default setting (`patch_width`=896, `patch_height`=896, `overlap_area`=384), it will generate 33978 images for training and 11644 images for validation. diff --git a/docs/zh_cn/dataset_prepare.md b/docs/zh_cn/dataset_prepare.md index 9a8428a64f..5df5881603 100644 --- a/docs/zh_cn/dataset_prepare.md +++ b/docs/zh_cn/dataset_prepare.md @@ -104,6 +104,21 @@ mmsegmentation │ │ ├── ann_dir │ │ │ ├── train │ │ │ ├── val +│ ├── vaihingen +│ │ ├── img_dir +│ │ │ ├── train +│ │ │ ├── val +│ │ ├── ann_dir +│ │ │ ├── train +│ │ │ ├── val +│ ├── iSAID +│ │ ├── img_dir +│ │ │ ├── train +│ │ │ ├── val +│ │ │ ├── test +│ │ ├── ann_dir +│ │ │ ├── train +│ │ │ ├── val ``` ### Cityscapes @@ -265,4 +280,39 @@ python tools/convert_datasets/potsdam.py /path/to/potsdam python tools/convert_datasets/vaihingen.py /path/to/vaihingen ``` -使用我们默认的配置 (`clip_size` =512, `stride_size`=256), 将生成 344 张图片的训练集和 398 张图片的验证集。 +使用我们默认的配置 (`clip_size`=512, `stride_size`=256), 将生成 344 张图片的训练集和 398 张图片的验证集。 + +### iSAID +iSAID 数据集(训练集/验证集/测试集)的图像可以从 [DOTA-v1.0](https://captain-whu.github.io/DOTA/dataset.html) 下载. 
+ +iSAID 数据集(训练集/验证集)的注释可以从 [iSAID](https://captain-whu.github.io/iSAID/dataset.html) 下载. + +该数据集是一个大规模的实例分割(也可以用于语义分割)的遥感数据集. + +下载后,在数据集转换前,您需要将数据集文件夹调整成如下格式. + +``` +│ ├── iSAID +│ │ ├── train +│ │ │ ├── images +│ │ │ │ ├── part1.zip +│ │ │ │ ├── part2.zip +│ │ │ │ ├── part3.zip +│ │ │ ├── Semantic_masks +│ │ │ │ ├── images.zip +│ │ ├── val +│ │ │ ├── images +│ │ │ │ ├── part1.zip +│ │ │ ├── Semantic_masks +│ │ │ │ ├── images.zip +│ │ ├── test +│ │ │ ├── images +│ │ │ │ ├── part1.zip +│ │ │ │ ├── part2.zip +``` + +```shell +python tools/convert_datasets/isaid.py /path/to/iSAID +``` + +使用我们默认的配置 (`patch_width`=896, `patch_height`=896, `overlap_area`=384), 将生成 33978 张图片的训练集和 11644 张图片的验证集。 diff --git a/mmseg/core/evaluation/class_names.py b/mmseg/core/evaluation/class_names.py index cc90517b4c..bc591c6e2f 100644 --- a/mmseg/core/evaluation/class_names.py +++ b/mmseg/core/evaluation/class_names.py @@ -111,6 +111,16 @@ def vaihingen_classes(): ] +def isaid_classes(): + """iSAID class names for external use.""" + return [ + 'background', 'ship', 'store_tank', 'baseball_diamond', 'tennis_court', + 'basketball_court', 'Ground_Track_Field', 'Bridge', 'Large_Vehicle', + 'Small_Vehicle', 'Helicopter', 'Swimming_pool', 'Roundabout', + 'Soccer_ball_field', 'plane', 'Harbor' + ] + + def cityscapes_palette(): """Cityscapes palette for external use.""" return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], @@ -236,6 +246,15 @@ def vaihingen_palette(): [255, 255, 0], [255, 0, 0]] +def isaid_palette(): + """iSAID palette for external use.""" + return [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127], + [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, + 127], [0, 0, 127], + [0, 0, 191], [0, 0, 255], [0, 191, 127], [0, 127, 191], + [0, 127, 255], [0, 100, 155]] + + dataset_aliases = { 'cityscapes': ['cityscapes'], 'ade': ['ade', 'ade20k'], @@ -247,7 +266,8 @@ def vaihingen_palette(): 'cocostuff', 'cocostuff10k', 'cocostuff164k', 'coco-stuff', 
'coco-stuff10k', 'coco-stuff164k', 'coco_stuff', 'coco_stuff10k', 'coco_stuff164k' - ] + ], + 'isaid': ['isaid', 'iSAID'] } diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py index 9f14325fea..5d42a11c26 100644 --- a/mmseg/datasets/__init__.py +++ b/mmseg/datasets/__init__.py @@ -10,6 +10,7 @@ RepeatDataset) from .drive import DRIVEDataset from .hrf import HRFDataset +from .isaid import iSAIDDataset from .isprs import ISPRSDataset from .loveda import LoveDADataset from .night_driving import NightDrivingDataset @@ -25,5 +26,5 @@ 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset', - 'ISPRSDataset', 'PotsdamDataset' + 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset' ] diff --git a/mmseg/datasets/isaid.py b/mmseg/datasets/isaid.py new file mode 100644 index 0000000000..2b63d9273d --- /dev/null +++ b/mmseg/datasets/isaid.py @@ -0,0 +1,82 @@ +import os.path as osp + +import mmcv +from mmcv.utils import print_log + +from ..utils import get_root_logger +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class iSAIDDataset(CustomDataset): + """ iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images + The segmentation map annotation of the iSAID dataset covers 16 + categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` and ``seg_map_suffix`` are both fixed to '.png'; the + annotation of image ``NAME.png`` is ``NAME_instance_color_RGB.png``. 
+ """ + + CLASSES = ('background', 'ship', 'store_tank', 'baseball_diamond', + 'tennis_court', 'basketball_court', 'Ground_Track_Field', + 'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter', + 'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane', + 'Harbor') + + PALETTE = [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127], + [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127], + [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127], + [0, 127, 191], [0, 127, 255], [0, 100, 155]] + + def __init__(self, **kwargs): + super(iSAIDDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='.png', + ignore_index=255, + **kwargs) + assert osp.exists(self.img_dir) + + def load_annotations(self, + img_dir, + img_suffix, + ann_dir, + seg_map_suffix=None, + split=None): + """Load annotation from directory. + + Args: + img_dir (str): Path to image directory + img_suffix (str): Suffix of images. + ann_dir (str|None): Path to annotation directory. + seg_map_suffix (str|None): Suffix of segmentation maps. + split (str|None): Split txt file. If split is specified, only file + with suffix in the splits will be loaded. Otherwise, all images + in img_dir/ann_dir will be loaded. Default: None + + Returns: + list[dict]: All image info of dataset. 
+ """ + + img_infos = [] + if split is not None: + with open(split) as f: + for line in f: + name = line.strip() + img_info = dict(filename=name + img_suffix) + if ann_dir is not None: + ann_name = name + '_instance_color_RGB' + seg_map = ann_name + seg_map_suffix + img_info['ann'] = dict(seg_map=seg_map) + img_infos.append(img_info) + else: + for img in mmcv.scandir(img_dir, img_suffix, recursive=True): + img_info = dict(filename=img) + if ann_dir is not None: + seg_img = img + seg_map = seg_img.replace( + img_suffix, '_instance_color_RGB' + seg_map_suffix) + img_info['ann'] = dict(seg_map=seg_map) + img_infos.append(img_info) + + print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger()) + return img_infos diff --git a/setup.cfg b/setup.cfg index 4839120ab6..23cb09e698 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,4 +16,4 @@ default_section = THIRDPARTY skip = *.po,*.ts,*.ipynb count = quiet-level = 3 -ignore-words-list = formating,sur,hist +ignore-words-list = formating,sur,hist,dota diff --git a/tests/data/pseudo_isaid_dataset/ann_dir/P0000_0_896_1024_1920_instance_color_RGB.png b/tests/data/pseudo_isaid_dataset/ann_dir/P0000_0_896_1024_1920_instance_color_RGB.png new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/data/pseudo_isaid_dataset/ann_dir/P0000_0_896_1536_2432_instance_color_RGB.png b/tests/data/pseudo_isaid_dataset/ann_dir/P0000_0_896_1536_2432_instance_color_RGB.png new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/data/pseudo_isaid_dataset/img_dir/P0000_0_896_1024_1920.png b/tests/data/pseudo_isaid_dataset/img_dir/P0000_0_896_1024_1920.png new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/data/pseudo_isaid_dataset/img_dir/P0000_0_896_1536_2432.png b/tests/data/pseudo_isaid_dataset/img_dir/P0000_0_896_1536_2432.png new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/data/pseudo_isaid_dataset/splits/train.txt b/tests/data/pseudo_isaid_dataset/splits/train.txt new file 
mode 100644 index 0000000000..c310167fe1 --- /dev/null +++ b/tests/data/pseudo_isaid_dataset/splits/train.txt @@ -0,0 +1 @@ +P0000_0_896_1536_2432 diff --git a/tests/data/pseudo_isaid_dataset/splits/val.txt b/tests/data/pseudo_isaid_dataset/splits/val.txt new file mode 100644 index 0000000000..aeff0ee339 --- /dev/null +++ b/tests/data/pseudo_isaid_dataset/splits/val.txt @@ -0,0 +1 @@ +P0000_0_896_1024_1920 diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py index 3d4c40a016..6ea6eb9852 100644 --- a/tests/test_data/test_dataset.py +++ b/tests/test_data/test_dataset.py @@ -16,7 +16,7 @@ COCOStuffDataset, ConcatDataset, CustomDataset, ISPRSDataset, LoveDADataset, MultiImageMixDataset, PascalVOCDataset, PotsdamDataset, RepeatDataset, - build_dataset) + build_dataset, iSAIDDataset) def test_classes(): @@ -25,10 +25,11 @@ def test_classes(): 'pascal_voc') assert list( ADE20KDataset.CLASSES) == get_classes('ade') == get_classes('ade20k') + assert list(COCOStuffDataset.CLASSES) == get_classes('cocostuff') assert list(LoveDADataset.CLASSES) == get_classes('loveda') assert list(PotsdamDataset.CLASSES) == get_classes('potsdam') assert list(ISPRSDataset.CLASSES) == get_classes('vaihingen') - assert list(COCOStuffDataset.CLASSES) == get_classes('cocostuff') + assert list(iSAIDDataset.CLASSES) == get_classes('isaid') with pytest.raises(ValueError): get_classes('unsupported') @@ -73,6 +74,7 @@ def test_palette(): assert LoveDADataset.PALETTE == get_palette('loveda') assert PotsdamDataset.PALETTE == get_palette('potsdam') assert COCOStuffDataset.PALETTE == get_palette('cocostuff') + assert iSAIDDataset.PALETTE == get_palette('isaid') with pytest.raises(ValueError): get_palette('unsupported') @@ -730,6 +732,27 @@ def test_vaihingen(): assert len(test_dataset) == 1 +def test_isaid(): + test_dataset = iSAIDDataset( + pipeline=[], + img_dir=osp.join( + osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'), + ann_dir=osp.join( + 
import argparse
import glob
import os
import os.path as osp
import shutil
import tempfile
import zipfile

import mmcv
import numpy as np
from PIL import Image

# Class index -> RGB colour used by the iSAID semantic masks.
iSAID_palette = \
    {
        0: (0, 0, 0),
        1: (0, 0, 63),
        2: (0, 63, 63),
        3: (0, 63, 0),
        4: (0, 63, 127),
        5: (0, 63, 191),
        6: (0, 63, 255),
        7: (0, 127, 63),
        8: (0, 127, 127),
        9: (0, 0, 127),
        10: (0, 0, 191),
        11: (0, 0, 255),
        12: (0, 191, 127),
        13: (0, 127, 191),
        14: (0, 127, 255),
        15: (0, 100, 155)
    }

# RGB colour -> class index.
iSAID_invert_palette = {v: k for k, v in iSAID_palette.items()}

_SPLITS = ('train', 'val', 'test')


def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette):
    """RGB-color encoding to grayscale labels.

    Args:
        arr_3d (np.ndarray): Colour mask indexed as (row, column, channel)
            with three channels.
        palette (dict): Mapping from an RGB tuple to a class index.

    Returns:
        np.ndarray: ``uint8`` array with the first two dimensions of
            ``arr_3d``. Pixels whose colour is not in ``palette`` stay 0.
    """
    arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)

    for color, index in palette.items():
        mask = np.all(arr_3d == np.array(color).reshape(1, 1, 3), axis=2)
        arr_2d[mask] = index

    return arr_2d


def _pad_to_patch(arr, patch_H, patch_W, pad_val):
    """Pad ``arr`` so that it is at least ``patch_H`` x ``patch_W``.

    Each axis is handled independently, so an image that is too small along
    one axis and exactly patch-sized along the other is padded as well.
    """
    target_H = max(arr.shape[0], patch_H)
    target_W = max(arr.shape[1], patch_W)
    if (target_H, target_W) != tuple(arr.shape[:2]):
        arr = mmcv.impad(arr, shape=(target_H, target_W), pad_val=pad_val)
    return arr


def _window_starts(length, patch, stride):
    """Return the sorted, de-duplicated start offsets of sliding windows.

    Windows that would run past ``length`` are shifted back so that they end
    exactly at ``length``. Requires ``length >= patch`` and ``stride > 0``.
    """
    last_start = length - patch
    return sorted({min(s, last_start) for s in range(0, length, stride)})


def _iter_windows(img_H, img_W, patch_H, patch_W, overlap):
    """Yield ``(y_str, y_end, x_str, x_end)`` for every crop window."""
    for x_str in _window_starts(img_W, patch_W, patch_W - overlap):
        for y_str in _window_starts(img_H, patch_H, patch_H - overlap):
            yield y_str, y_str + patch_H, x_str, x_str + patch_W


def _extract_all(zip_paths, dst_dir):
    """Extract every archive in ``zip_paths`` into ``dst_dir``."""
    for zip_path in zip_paths:
        # The context manager closes the archive handle after extraction.
        with zipfile.ZipFile(zip_path) as zip_file:
            zip_file.extractall(dst_dir)


def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap):
    """Cut an image into overlapping patches and save them as PNG files.

    Patches are written to ``<out_dir>/img_dir/<mode>/`` and named
    ``<stem>_<y_str>_<y_end>_<x_str>_<x_end>.png``.

    Args:
        src_path (str): Path of the source image.
        out_dir (str): Root of the converted dataset.
        mode (str): Name of the split sub-directory, e.g. 'train'.
        patch_H (int): Patch height in pixels.
        patch_W (int): Patch width in pixels.
        overlap (int): Overlap between neighbouring patches in pixels.
    """
    with Image.open(src_path) as src:
        img = np.asarray(src.convert('RGB'))
    img = _pad_to_patch(img, patch_H, patch_W, pad_val=0)
    img_H, img_W = img.shape[:2]

    # ``osp.basename`` also handles non-'/' path separators.
    stem = osp.splitext(osp.basename(src_path))[0]
    for y_str, y_end, x_str, x_end in _iter_windows(img_H, img_W, patch_H,
                                                    patch_W, overlap):
        img_patch = Image.fromarray(
            img[y_str:y_end, x_str:x_end, :].astype(np.uint8))
        image = f'{stem}_{y_str}_{y_end}_{x_str}_{x_end}.png'
        img_patch.save(osp.join(out_dir, 'img_dir', mode, image))


def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap):
    """Cut a colour mask into overlapping label patches and save them.

    The mask is first converted to class indices. Padded pixels get the
    value 255. Patches are written to ``<out_dir>/ann_dir/<mode>/`` and named
    ``<id>_<y_str>_<y_end>_<x_str>_<x_end>_instance_color_RGB.png`` where
    ``<id>`` is the part of the source file name before the first '_'.

    Args:
        src_path (str): Path of the source colour mask.
        out_dir (str): Root of the converted dataset.
        mode (str): Name of the split sub-directory, e.g. 'train'.
        patch_H (int): Patch height in pixels.
        patch_W (int): Patch width in pixels.
        overlap (int): Overlap between neighbouring patches in pixels.
    """
    label = mmcv.imread(src_path, channel_order='rgb')
    label = iSAID_convert_from_color(label)
    label = _pad_to_patch(label, patch_H, patch_W, pad_val=255)
    img_H, img_W = label.shape

    image_id = osp.splitext(osp.basename(src_path))[0].split('_')[0]
    for y_str, y_end, x_str, x_end in _iter_windows(img_H, img_W, patch_H,
                                                    patch_W, overlap):
        lab_patch = label[y_str:y_end, x_str:x_end]
        lab_patch = Image.fromarray(lab_patch.astype(np.uint8), mode='P')
        image = (f'{image_id}_{y_str}_{y_end}_{x_str}_{x_end}'
                 '_instance_color_RGB.png')
        lab_patch.save(osp.join(out_dir, 'ann_dir', mode, image))


def parse_args():
    """Parse the command line arguments of the conversion script."""
    parser = argparse.ArgumentParser(
        description='Convert iSAID dataset to mmsegmentation format')
    parser.add_argument('dataset_path', help='iSAID folder path')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')

    parser.add_argument(
        '--patch_width',
        default=896,
        type=int,
        help='Width of the cropped image patch')
    parser.add_argument(
        '--patch_height',
        default=896,
        type=int,
        help='Height of the cropped image patch')
    parser.add_argument(
        '--overlap_area', default=384, type=int, help='Overlap area')
    args = parser.parse_args()
    return args


def main():
    """Convert the raw iSAID archives into the mmsegmentation layout.

    Train and val images and masks are cut into patches. Test images are
    moved to the output directory unchanged.

    Raises:
        ValueError: If the overlap is not smaller than both patch sides.
        FileNotFoundError: If a split folder is missing in the dataset path.
    """
    args = parse_args()
    dataset_path = args.dataset_path
    # image patch height and width
    patch_H, patch_W = args.patch_height, args.patch_width

    overlap = args.overlap_area  # overlap area
    if overlap >= min(patch_H, patch_W):
        # A non-positive stride would crash ``range`` or yield no patches.
        raise ValueError(
            f'overlap_area ({overlap}) must be smaller than both '
            f'patch_height ({patch_H}) and patch_width ({patch_W})')

    if args.out_dir is None:
        out_dir = osp.join('data', 'iSAID')
    else:
        out_dir = args.out_dir

    # Validate the input before anything is created on disk.
    for dataset_mode in _SPLITS:
        if not os.path.exists(os.path.join(dataset_path, dataset_mode)):
            raise FileNotFoundError(
                '{} is not in {}'.format(dataset_mode, dataset_path))

    print('Making directories...')
    for sub_dir in ('img_dir', 'ann_dir'):
        for dataset_mode in _SPLITS:
            mmcv.mkdir_or_exist(osp.join(out_dir, sub_dir, dataset_mode))

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        for dataset_mode in _SPLITS:
            print('Extracting {}ing.zip...'.format(dataset_mode))
            img_zipp_list = glob.glob(
                os.path.join(dataset_path, dataset_mode, 'images', '*.zip'))
            print('Find the data', img_zipp_list)
            _extract_all(img_zipp_list,
                         os.path.join(tmp_dir, dataset_mode, 'img'))
            src_path_list = glob.glob(
                os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png'))

            src_prog_bar = mmcv.ProgressBar(len(src_path_list))
            for img_path in src_path_list:
                if dataset_mode != 'test':
                    slide_crop_image(img_path, out_dir, dataset_mode, patch_H,
                                     patch_W, overlap)
                else:
                    # An explicit target file lets a re-run overwrite the
                    # result of a previous run instead of failing.
                    shutil.move(
                        img_path,
                        os.path.join(out_dir, 'img_dir', dataset_mode,
                                     osp.basename(img_path)))
                src_prog_bar.update()

            if dataset_mode != 'test':
                label_zipp_list = glob.glob(
                    os.path.join(dataset_path, dataset_mode, 'Semantic_masks',
                                 '*.zip'))
                _extract_all(label_zipp_list,
                             os.path.join(tmp_dir, dataset_mode, 'lab'))

                lab_path_list = glob.glob(
                    os.path.join(tmp_dir, dataset_mode, 'lab', 'images',
                                 '*.png'))
                lab_prog_bar = mmcv.ProgressBar(len(lab_path_list))
                for lab_path in lab_path_list:
                    slide_crop_label(lab_path, out_dir, dataset_mode, patch_H,
                                     patch_W, overlap)
                    lab_prog_bar.update()

        print('Removing the temporary files...')

    print('Done!')


if __name__ == '__main__':
    main()