This repository provides the official PyTorch implementation of the following paper:
WildNet: Learning Domain Generalized Semantic Segmentation from the Wild
Suhyeon Lee, Hongje Seong, Seongwon Lee, Euntai Kim
Yonsei University
Abstract: We present a new domain generalized semantic segmentation network named WildNet, which learns domain-generalized features by leveraging a variety of contents and styles from the wild. In domain generalization, the low generalization ability for unseen target domains is clearly due to overfitting to the source domain. To address this problem, previous works have focused on generalizing the domain by removing or diversifying the styles of the source domain. These approaches alleviated overfitting to the source style but overlooked overfitting to the source content. In this paper, we propose to diversify both the content and style of the source domain with the help of the wild. Our main idea is for networks to naturally learn domain-generalized semantic information from the wild. To this end, we diversify styles by augmenting source features to resemble wild styles and enable networks to adapt to a variety of styles. Furthermore, we encourage networks to learn class-discriminant features by providing semantic variations borrowed from the wild to source contents in the feature space. Finally, we regularize networks to capture consistent semantic information even when both the content and style of the source domain are extended to the wild. Extensive experiments on five different datasets validate the effectiveness of our WildNet, and we significantly outperform state-of-the-art methods.
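For intuition, the style diversification step can be pictured as re-styling source feature maps with the channel-wise statistics of wild (ImageNet) features, in the spirit of AdaIN. The snippet below is only an illustrative sketch of that general idea, not the exact formulation used in WildNet; all names in it are placeholders.

import torch

def stylize(source_feat, wild_feat, eps=1e-5):
    # source_feat, wild_feat: feature maps of shape (N, C, H, W)
    # channel-wise mean/std over the spatial dimensions
    s_mean = source_feat.mean(dim=(2, 3), keepdim=True)
    s_std = source_feat.std(dim=(2, 3), keepdim=True) + eps
    w_mean = wild_feat.mean(dim=(2, 3), keepdim=True)
    w_std = wild_feat.std(dim=(2, 3), keepdim=True) + eps
    # normalize away the source style, then apply the wild style
    return (source_feat - s_mean) / s_std * w_std + w_mean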
Our PyTorch implementation is heavily derived from RobustNet (CVPR 2021). If you use this code in your research, please also cite their work.
Clone this repository.
git clone https://github.com/suhyeonlee/WildNet.git
cd WildNet
Install the following packages.
conda create --name wildnet python=3.7
conda activate wildnet
conda install pytorch==1.9.1 torchvision==0.10.1 torchaudio==0.9.1 cudatoolkit=11.1 -c pytorch
conda install scipy==1.1.0
conda install tqdm==4.46.0
conda install scikit-image==0.16.2
pip install tensorboardX
pip install thop
pip install kmeans1d
imageio_download_bin freeimage
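You can quickly check that the environment matches the versions above:

import torch, torchvision
print(torch.__version__)        # expect 1.9.1
print(torchvision.__version__)  # expect 0.10.1
print(torch.cuda.is_available())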
We trained our model with the source domain (GTAV or Cityscapes) and the wild domain (ImageNet), and then evaluated it on Cityscapes, BDD-100K, Synthia (SYNTHIA-RAND-CITYSCAPES), GTAV, and Mapillary Vistas.
We adopt class uniform sampling, proposed in this paper, to handle the class imbalance problem.
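In short, class uniform sampling indexes, for each class, the images that contain that class, and then draws training samples uniformly over classes instead of uniformly over images, so rare classes are seen more often. A simplified sketch of the idea (names and data layout are hypothetical, not the repo's actual code):

import random
from collections import defaultdict

def build_class_index(samples):
    # samples: list of (image_path, set_of_class_ids) pairs
    index = defaultdict(list)
    for path, classes in samples:
        for c in classes:
            index[c].append(path)
    return index

def sample_uniform_over_classes(index):
    c = random.choice(list(index.keys()))  # pick a class first...
    return random.choice(index[c])         # ...then an image containing it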
- For the Cityscapes dataset, download "leftImg8bit_trainvaltest.zip" and "gtFine_trainvaltest.zip" from https://www.cityscapes-dataset.com/downloads/
Unzip the files and arrange the directory structure as follows (a quick layout check is sketched after these trees).
cityscapes
 └ leftImg8bit_trainvaltest
   └ leftImg8bit
     └ train
     └ val
     └ test
 └ gtFine_trainvaltest
   └ gtFine
     └ train
     └ val
     └ test
bdd-100k
 └ images
   └ train
   └ val
   └ test
 └ labels
   └ train
   └ val
mapillary
 └ training
   └ images
   └ labels
 └ validation
   └ images
   └ labels
 └ test
   └ images
   └ labels
imagenet
 └ data
   └ train
   └ val
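A minimal layout check for the Cityscapes tree above (the root path is a placeholder; the other datasets can be checked the same way):

import os

root = '/path/to/cityscapes'  # placeholder: set to your actual path
for sub in ('leftImg8bit_trainvaltest/leftImg8bit/train',
            'leftImg8bit_trainvaltest/leftImg8bit/val',
            'gtFine_trainvaltest/gtFine/train',
            'gtFine_trainvaltest/gtFine/val'):
    assert os.path.isdir(os.path.join(root, sub)), 'missing: ' + sub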
- We used GTAV_Split to split the GTAV dataset into training/validation/test sets. Please refer to the txt files in split_data.
GTAV
 └ images
   └ train
     └ folder
   └ valid
     └ folder
   └ test
     └ folder
 └ labels
   └ train
     └ folder
   └ valid
     └ folder
   └ test
     └ folder
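The split files are plain text lists of image ids, one per line; a snippet like the following can read them (the file name below is illustrative, check the actual names in split_data):

def read_split(txt_path):
    # one image id per line; skip blank lines
    with open(txt_path) as f:
        return [line.strip() for line in f if line.strip()]

train_ids = read_split('split_data/gtav_split_train.txt')  # illustrative name
print(len(train_ids), 'training ids')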
- We split the Synthia dataset into train/val sets following RobustNet. Please refer to the txt files in split_data.
synthia
 └ RGB
   └ train
   └ val
 └ GT
   └ COLOR
     └ train
     └ val
   └ LABELS
     └ train
     └ val
- You should modify the paths in "<path_to_wildnet>/config.py" according to your dataset locations.
# Cityscapes Dataset Dir Location
__C.DATASET.CITYSCAPES_DIR = <YOUR_CITYSCAPES_PATH>
# Mapillary Dataset Dir Location
__C.DATASET.MAPILLARY_DIR = <YOUR_MAPILLARY_PATH>
# GTAV Dataset Dir Location
__C.DATASET.GTAV_DIR = <YOUR_GTAV_PATH>
# BDD-100K Dataset Dir Location
__C.DATASET.BDD_DIR = <YOUR_BDD_PATH>
# Synthia Dataset Dir Location
__C.DATASET.SYNTHIA_DIR = <YOUR_SYNTHIA_PATH>
# ImageNet Dataset Dir Location
__C.DATASET.ImageNet_DIR = <YOUR_ImageNet_PATH>
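For example, with hypothetical local paths the block might look like:

__C.DATASET.CITYSCAPES_DIR = '/data/cityscapes'
__C.DATASET.MAPILLARY_DIR = '/data/mapillary'
__C.DATASET.GTAV_DIR = '/data/GTAV'
__C.DATASET.BDD_DIR = '/data/bdd-100k'
__C.DATASET.SYNTHIA_DIR = '/data/synthia'
__C.DATASET.ImageNet_DIR = '/data/imagenet'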
- You can train WildNet with the following command.
<path_to_wildnet>$ CUDA_VISIBLE_DEVICES=0,1 ./scripts/train_wildnet_r50os16_gtav.sh
- You can download our ResNet-50 model from Google Drive and validate the pretrained model with the following command.
<path_to_wildnet>$ CUDA_VISIBLE_DEVICES=0,1 ./scripts/valid_wildnet_r50os16_gtav.sh <weight_file_location>
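If you want to sanity-check the downloaded weight file before running validation, you can inspect it with PyTorch (the file name below is a placeholder):

import torch

ckpt = torch.load('wildnet_r50os16_gtav.pth', map_location='cpu')  # placeholder name
keys = list(ckpt.keys()) if isinstance(ckpt, dict) else []
print(len(keys), 'top-level entries:', keys[:5])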
If you find this work useful in your research, please cite our paper:
@inproceedings{lee2022wildnet,
title={WildNet: Learning Domain Generalized Semantic Segmentation from the Wild},
author={Lee, Suhyeon and Seong, Hongje and Lee, Seongwon and Kim, Euntai},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}
This software is for non-commercial use only. The source code is released under the Attribution-NonCommercial-ShareAlike (CC BY-NC-SA) License.