diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..a16ce8b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +* linguist-vendored +*.py linguist-vendored=false \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..eefd481 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +.vscode/* +checkpoints/* +logs/* +**/__pycache__/* +**/*.tfrecord +**/*.zip +data/*/**/*.png +data/*/**/*.jpg +data/widerface/**/*.txt +widerface_evaluate/* +results/* +normal +normal.pdf +reduction +reduction.pdf diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..9f0028d --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Kuan-Yu Huang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..ab0a43c --- /dev/null +++ b/README.md @@ -0,0 +1,218 @@ +# [pcdarts-tf2](https://github.com/peteryuX/pcdarts-tf2) + +![Star](https://img.shields.io/github/stars/peteryuX/pcdarts-tf2) +![Fork](https://img.shields.io/github/forks/peteryuX/pcdarts-tf2) +![License](https://img.shields.io/github/license/peteryuX/pcdarts-tf2) + +:fire: PC-DARTS (PC-DARTS: Partial Channel Connections for Memory-Efficient Differentiable Architecture Search, published in ICLR 2020) implemented in Tensorflow 2.0+. This is an unofficial implementation. :fire: + +> PC-DARTS is a memory efficient differentiable architecture search method, which can be trained with a larger batch size and, consequently, enjoys both faster speed and higher training stability. Experimental results achieve an error rate of **2.57%** on CIFAR10 with merely **0.1 GPU-days** for architecture search. + +Original Paper:   [Arxiv](https://arxiv.org/abs/1907.05737)   [OpenReview](https://openreview.net/forum?id=BJlS634tPr) + +Offical Implementation:   [PyTorch](https://github.com/yuhuixu1993/PC-DARTS) + +

+ +

+ +**** + +## Contents +:bookmark_tabs: + +* [Installation](#Installation) +* [Usage](#Training-and-Testing) +* [Benchmark](#Benchmark) +* [Models](#Models) +* [References](#References) + +*** + +## Installation +:pizza: + +Create a new python virtual environment by [Anaconda](https://www.anaconda.com/) or just use pip in your python environment and then clone this repository as following. + +### Clone this repo +```bash +git clone https://github.com/peteryuX/pcdarts-tf2.git +cd pcdarts-tf2 +``` + +### Conda +```bash +conda env create -f environment.yml +conda activate pcdarts-tf2 +``` + +### Pip + +```bash +pip install -r requirements.txt +``` + +**** + +## Usage +:lollipop: + +### Config File +You can modify your own dataset path or other settings of model in [./configs/*.yaml](https://github.com/peteryuX/pcdarts-tf2/tree/master/configs) for training and testing, which would like below. + +```python +# general setting +batch_size: 128 +input_size: 32 +init_channels: 36 +layers: 20 +num_classes: 10 +auxiliary_weight: 0.4 +drop_path_prob: 0.3 +arch: PCDARTS +sub_name: 'pcdarts_cifar10' +using_normalize: True + +# training dataset +dataset_len: 50000 # number of training samples +using_crop: True +using_flip: True +using_cutout: True +cutout_length: 16 + +# training setting +epoch: 600 +init_lr: 0.025 +lr_min: 0.0 +momentum: 0.9 +weights_decay: !!float 3e-4 +grad_clip: 5.0 + +val_steps: 1000 +save_steps: 1000 +``` + +Note: +- The `sub_name` is the name of outputs directory used in checkpoints and logs folder. (make sure of setting it unique to other models) +- The `save_steps` is the number interval steps of saving checkpoint file. +- The [./configs/pcdarts_cifar10_search.yaml](https://github.com/peteryuX/pcdarts-tf2/tree/master/configs/pcdarts_cifar10_search.yaml) and [./configs/pcdarts_cifar10.yaml](https://github.com/peteryuX/pcdarts-tf2/tree/master/configs/pcdarts_cifar10.yaml) are used by [train_search.py](https://github.com/peteryuX/pcdarts-tf2/tree/master/train_search.py) and [train.py](https://github.com/peteryuX/pcdarts-tf2/tree/master/train.py) respectively, which have different settings for small proxy model training(architecture searching) and full-size model training. Please make sure you use the correct config file in related script. (The example yaml script above is [./configs/pcdarts_cifar10.yaml](https://github.com/peteryuX/pcdarts-tf2/tree/master/configs/pcdarts_cifar10.yaml).) + +### Architecture Searching on CIFAR-10 (using small proxy model) + +**Step1**: Search cell architecture on CIFAR-10 using small proxy model. + +```bash +python train_search.py --cfg_path="./configs/pcdarts_cifar10_search.yaml" --gpu=0 +``` + +Note: +- The `--gpu` is used to choose the id of your avaliable GPU devices with `CUDA_VISIBLE_DEVICES` system varaible. +- You can visualize the training status on tensorboard by running "`tensorboard --logdir=./logs/`" +- You can visualize the learning rate scheduling by running "`python ./modules/lr_scheduler.py`". +- You can visualize the dataset augmantation by running "`python ./dataset_checker.py`". + +**Step2**: After the searching completed, you can find the result genotypes in `./logs/{sub_name}/search_arch_genotype.py`. Open it and copy the latest genotype into the [./modules/genotypes.py](https://github.com/peteryuX/pcdarts-tf2/tree/master/modules/genotypes.py), which will be used for further training later. The genotype like bellow: + +```python +TheNameYouWantToCall = Genotype( + normal=[ + ('sep_conv_3x3', 1), + ('skip_connect', 0), + ('sep_conv_3x3', 0), + ('dil_conv_3x3', 1), + ('sep_conv_5x5', 0), + ('sep_conv_3x3', 1), + ('avg_pool_3x3', 0), + ('dil_conv_3x3', 1)], + normal_concat=range(2, 6), + reduce=[ + ('sep_conv_5x5', 1), + ('max_pool_3x3', 0), + ('sep_conv_5x5', 1), + ('sep_conv_5x5', 2), + ('sep_conv_3x3', 0), + ('sep_conv_3x3', 3), + ('sep_conv_3x3', 1), + ('sep_conv_3x3', 2)], + reduce_concat=range(2, 6)) +``` + +Note: +- You can visualize the genotype by running "`python ./visualize_genotype.py TheNameYouWantToCall`". +

+ + +

+ +### Training on CIFAR-10 (using full-sized model) + +**Step1**: Make sure that you already modifed the flag `arch` in [./configs/pcdarts_cifar10.yaml](https://github.com/peteryuX/pcdarts-tf2/tree/master/configs/pcdarts_cifar10.yaml) to match the genotype you want to use in [./modules/genotypes.py](https://github.com/peteryuX/pcdarts-tf2/tree/master/modules/genotypes.py). + +Note: +- The default flag `arch` (`PCDARTS`) is the genotype proposed by official paper. You can train this model by yourself, or use dowload it from [BenchmarkModels](#Models). + +**Step2**: Train the full-sized model on CIFAR-10 with specific genotype. + +```bash +python train.py --cfg_path="./configs/pcdarts_cifar10.yaml" --gpu=0 +``` + +### Testing on CIFAR-10 (using full-sized model) + +To evaluate the full-sized model with the corresponding cfg file on the testing dataset. You can also download my trained model for testing from [Models](#Models) without training it yourself, which default `arch` (`PCDARTS`) is the best cell proposed in paper. + +```bash +python test.py --cfg_path="./configs/pcdarts_cifar10.yaml" --gpu=0 +``` + +**** + +## Benchmark +:coffee: + +### Results on CIFAR-10 +| Method | Search Method | Params(M) | Test Error(%)| Search-Cost(GPU-days) | +| ------ | ------------- | --------- | ------------ | --------------------- | +| [NASNet-A](https://arxiv.org/abs/1611.01578) | RL | 3.3 | 2.65 | 1800 | +| [AmoebaNet-B](https://arxiv.org/abs/1802.01548) | Evolution | 2.8 | 2.55 | 3150 | +| [ENAS](https://arxiv.org/abs/1802.03268) | RL | 4.6 | 2.89 | 0.5 | +| [DARTSV1](https://arxiv.org/abs/1806.09055) | gradient-based | 3.3 | 3.00 | 0.4 | +| [DARTSV2](https://arxiv.org/abs/1806.09055) | gradient-based | 3.3 | 2.76 | 1.0 | +| [SNAS](https://arxiv.org/abs/1812.09926) | gradient-based | 2.8 | 2.85 | 1.5 | +| [PC-DARTS](https://github.com/yuhuixu1993/PC-DARTS) (official PyTorch version) | gradient-based | 3.63 | **2.57** | **0.1** | +| PC-DARTS TF2 (paper architecture) | gradient-based | 3.63 | 2.73 | - | +| PC-DARTS TF2 (searched by myself) | gradient-based | 3.56 | 2.88 | 0.12 | + +Note: +- Above results are referenced from [official repository](https://github.com/yuhuixu1993/PC-DARTS) and [orignal paper](https://arxiv.org/abs/1907.05737). +- There still have a slight performance gap between my PC-DARTS TF2 and official version. In both cases, we used Nvidia 1080ti (11G memory). My PC-DARTS TF2 pre-trained model can be found in [Models](#Models). +- If you get unsatisfactory results with the archecture searched by yourself, you might try to search it more than one time. (see the discussions [here](https://github.com/yuhuixu1993/PC-DARTS/issues/7)) + +**** + +## Models +:doughnut: + +Dowload these models bellow, then extract them into `./checkpoints/` for restoring. + +| Model Name | Config File | `arch` | Download Link | +|---------------------|-------------|--------|---------------| +| PC-DARTS (CIFAR-10, paper architecture) | [pcdarts_cifar10.yaml](https://github.com/peteryuX/pcdarts-tf2/tree/master//configs/pcdarts_cifar10.yaml) | `PCDARTS` | [GoogleDrive](https://drive.google.com/file/d/1BhLlktX78z90yOaORXvch_GAnIWWkYrX/view?usp=sharing) | +| PC-DARTS (CIFAR-10, searched by myself) | [pcdarts_cifar10_TF2.yaml](https://github.com/peteryuX/pcdarts-tf2/tree/master//configs/pcdarts_cifar10_TF2.yaml) | `PCDARTS_TF2_SEARCH` | [GoogleDrive](https://drive.google.com/file/d/1UgeZzEnQZ6oeMKpr01rEDflVi9X1dzeq/view?usp=sharing) | + +Note: +- You can find the training settings of the models in the corresponding [./configs/*.yaml](https://github.com/peteryuX/pcdarts-tf2/tree/master/configs) files, and make sure that the `arch` flag in it is matched with the genotypes name in [./modules/genotypes.py](https://github.com/peteryuX/pcdarts-tf2/tree/master/modules/genotypes.py). + +**** + +## References +:hamburger: + +Thanks for these source codes porviding me with knowledges to complete this repository. + +- https://github.com/yuhuixu1993/PC-DARTS (Official) + - PC-DARTS:Partial Channel Connections for Memory-Efficient Differentiable Architecture Search +- https://github.com/quark0/darts + - Differentiable architecture search for convolutional and recurrent networks https://arxiv.org/abs/1806.09055 +- https://github.com/zzh8829/yolov3-tf2 + - YoloV3 Implemented in TensorFlow 2.0 diff --git a/configs/pcdarts_cifar10.yaml b/configs/pcdarts_cifar10.yaml new file mode 100644 index 0000000..47a969f --- /dev/null +++ b/configs/pcdarts_cifar10.yaml @@ -0,0 +1,30 @@ +# general setting +batch_size: 96 +val_batch_size: 512 +input_size: 32 +init_channels: 36 +layers: 20 +num_classes: 10 +auxiliary_weight: 0.4 +drop_path_prob: 0.3 +arch: PCDARTS +sub_name: 'pcdarts_cifar10' +using_normalize: True + +# training dataset +dataset_len: 50000 # number of training samples +using_crop: True +using_flip: True +using_cutout: True +cutout_length: 16 + +# training setting +epoch: 600 +init_lr: 0.025 +lr_min: 0.0 +momentum: 0.9 +weights_decay: !!float 3e-4 +grad_clip: 5.0 + +val_steps: 520 +save_steps: 520 diff --git a/configs/pcdarts_cifar10_TF2.yaml b/configs/pcdarts_cifar10_TF2.yaml new file mode 100644 index 0000000..bf2f161 --- /dev/null +++ b/configs/pcdarts_cifar10_TF2.yaml @@ -0,0 +1,30 @@ +# general setting +batch_size: 96 +val_batch_size: 512 +input_size: 32 +init_channels: 36 +layers: 20 +num_classes: 10 +auxiliary_weight: 0.4 +drop_path_prob: 0.3 +arch: PCDARTS_TF2_SEARCH +sub_name: 'pcdarts_cifar10_tf2_arch' +using_normalize: True + +# training dataset +dataset_len: 50000 # number of training samples +using_crop: True +using_flip: True +using_cutout: True +cutout_length: 16 + +# training setting +epoch: 600 +init_lr: 0.025 +lr_min: 0.0 +momentum: 0.9 +weights_decay: !!float 3e-4 +grad_clip: 5.0 + +val_steps: 520 +save_steps: 520 diff --git a/configs/pcdarts_cifar10_search.yaml b/configs/pcdarts_cifar10_search.yaml new file mode 100644 index 0000000..411024d --- /dev/null +++ b/configs/pcdarts_cifar10_search.yaml @@ -0,0 +1,30 @@ +# general setting +batch_size: 256 +input_size: 32 +init_channels: 16 +layers: 8 +num_classes: 10 +sub_name: 'pcdarts_cifar10_search' +using_normalize: True + +# training dataset +dataset_len: 50000 # number of training samples +train_portion: 0.5 +using_crop: True +using_flip: True +using_cutout: False +cutout_length: 16 + +# training setting +epoch: 50 +start_search_epoch: 15 +init_lr: 0.1 +lr_min: 0.0 +momentum: 0.9 +weights_decay: !!float 3e-4 +grad_clip: 5.0 + +arch_learning_rate: !!float 6e-4 +arch_weight_decay: !!float 1e-3 + +save_steps: 97 diff --git a/dataset_checker.py b/dataset_checker.py new file mode 100644 index 0000000..39c94e9 --- /dev/null +++ b/dataset_checker.py @@ -0,0 +1,20 @@ +import cv2 +import numpy as np +import tensorflow as tf +import tensorflow_datasets as tfds +from modules.utils import set_memory_growth +from modules.dataset import load_cifar10_dataset + + +set_memory_growth() + +dataset = load_cifar10_dataset(batch_size=1, split='train', shuffle=False, + using_normalize=False) + +for (img, labels)in dataset: + img = img.numpy()[0] + print(img.shape, labels.shape, labels.numpy()) + + cv2.imshow('img', cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) + if cv2.waitKey(0) == ord('q'): + exit() diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..4c5c229 --- /dev/null +++ b/environment.yml @@ -0,0 +1,15 @@ +name: pcdarts-tf2 +channels: + - conda-forge + +dependencies: + - python==3.7 + - pip + - pip: + - tensorflow-gpu==2.1.0 + - tensorflow_datasets + - numpy + - opencv-python + - PyYAML + - ipython + - python-graphviz diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/dataset.py b/modules/dataset.py new file mode 100644 index 0000000..0c3320f --- /dev/null +++ b/modules/dataset.py @@ -0,0 +1,96 @@ +import tensorflow as tf +import tensorflow_datasets as tfds + + +_CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] +_CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] + + +def _meshgrid_tf(x, y): + """ workaround solution of the tf.meshgrid() issue: + https://github.com/tensorflow/tensorflow/issues/34470""" + grid_shape = [tf.shape(y)[0], tf.shape(x)[0]] + grid_x = tf.broadcast_to(tf.reshape(x, [1, -1]), grid_shape) + grid_y = tf.broadcast_to(tf.reshape(y, [-1, 1]), grid_shape) + return grid_x, grid_y + + +def _cutout(img, length, pad_values): + """coutout""" + h, w = tf.shape(img)[0], tf.shape(img)[1] + y = tf.random.uniform([], 0, h, dtype=tf.int32) + x = tf.random.uniform([], 0, w, dtype=tf.int32) + + y1 = tf.clip_by_value(y - length // 2, 0, h) + y2 = tf.clip_by_value(y + length // 2, 0, h) + x1 = tf.clip_by_value(x - length // 2, 0, w) + x2 = tf.clip_by_value(x + length // 2, 0, w) + + grid_x, grid_y = _meshgrid_tf(tf.range(h), tf.range(w)) + cond = tf.stack([grid_x > y1, grid_x < y2, grid_y > x1, grid_y < x2], -1) + mask = 1 - tf.cast(tf.math.reduce_all(cond, axis=-1, keepdims=True), + tf.float32) + img = mask * img + (1 - mask) * pad_values + + return img + + +def _transform_data_cifar10(using_normalize, using_crop, using_flip, + using_cutout, cutout_length): + def transform_data(features): + img = features['image'] + labels = features['label'] + img = tf.cast(img, tf.float32) + + pad_values = tf.reduce_mean(img, axis=[0, 1]) + + # randomly crop + if using_crop: + img = tf.pad(img, [[4, 4], [4, 4], [0, 0]], constant_values=-1) + img = tf.where(img == -1, pad_values, img) + img = tf.image.random_crop(img, [32, 32, 3]) + + # randomly left-right flip + if using_flip: + img = tf.image.random_flip_left_right(img) + + # cutout + if using_cutout: + img = _cutout(img, cutout_length, pad_values) + + # rescale 0. ~ 1. + img = img / 255. + + # normalize + if using_normalize: + mean = tf.constant(_CIFAR_MEAN)[tf.newaxis, tf.newaxis] + std = tf.constant(_CIFAR_STD)[tf.newaxis, tf.newaxis] + img = (img - mean) / std + + return img, labels + return transform_data + + +def load_cifar10_dataset(batch_size, split='train', using_normalize=True, + using_crop=True, using_flip=True, using_cutout=True, + cutout_length=16, shuffle=True, buffer_size=10240, + drop_remainder=True): + """load dataset from tfrecord""" + dataset = tfds.load(name="cifar10", split=split) + + if 'train' in split: + dataset = dataset.repeat() + + if shuffle: + dataset = dataset.shuffle(buffer_size=buffer_size) + + dataset = dataset.map( + _transform_data_cifar10(using_normalize, using_crop, using_flip, + using_cutout, cutout_length), + num_parallel_calls=tf.data.experimental.AUTOTUNE) + + dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) + + dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) + + return dataset diff --git a/modules/genotypes.py b/modules/genotypes.py new file mode 100644 index 0000000..9565e8f --- /dev/null +++ b/modules/genotypes.py @@ -0,0 +1,180 @@ +from collections import namedtuple + + +Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') + + +PRIMITIVES = [ + 'none', + 'max_pool_3x3', + 'avg_pool_3x3', + 'skip_connect', + 'sep_conv_3x3', + 'sep_conv_5x5', + 'dil_conv_3x3', + 'dil_conv_5x5'] + + +NASNet = Genotype( + normal=[ + ('sep_conv_5x5', 1), + ('sep_conv_3x3', 0), + ('sep_conv_5x5', 0), + ('sep_conv_3x3', 0), + ('avg_pool_3x3', 1), + ('skip_connect', 0), + ('avg_pool_3x3', 0), + ('avg_pool_3x3', 0), + ('sep_conv_3x3', 1), + ('skip_connect', 1)], + normal_concat=[2, 3, 4, 5, 6], + reduce=[ + ('sep_conv_5x5', 1), + ('sep_conv_7x7', 0), + ('max_pool_3x3', 1), + ('sep_conv_7x7', 0), + ('avg_pool_3x3', 1), + ('sep_conv_5x5', 0), + ('skip_connect', 3), + ('avg_pool_3x3', 2), + ('sep_conv_3x3', 2), + ('max_pool_3x3', 1)], + reduce_concat=[4, 5, 6]) + +AmoebaNet = Genotype( + normal=[ + ('avg_pool_3x3', 0), + ('max_pool_3x3', 1), + ('sep_conv_3x3', 0), + ('sep_conv_5x5', 2), + ('sep_conv_3x3', 0), + ('avg_pool_3x3', 3), + ('sep_conv_3x3', 1), + ('skip_connect', 1), + ('skip_connect', 0), + ('avg_pool_3x3', 1)], + normal_concat=[4, 5, 6], + reduce=[ + ('avg_pool_3x3', 0), + ('sep_conv_3x3', 1), + ('max_pool_3x3', 0), + ('sep_conv_7x7', 2), + ('sep_conv_7x7', 0), + ('avg_pool_3x3', 1), + ('max_pool_3x3', 0), + ('max_pool_3x3', 1), + ('conv_7x1_1x7', 0), + ('sep_conv_3x3', 5)], + reduce_concat=[3, 4, 6]) + +DARTS_V1 = Genotype( + normal=[ + ('sep_conv_3x3', 1), + ('sep_conv_3x3', 0), + ('skip_connect', 0), + ('sep_conv_3x3', 1), + ('skip_connect', 0), + ('sep_conv_3x3', 1), + ('sep_conv_3x3', 0), + ('skip_connect', 2)], + normal_concat=[2, 3, 4, 5], + reduce=[ + ('max_pool_3x3', 0), + ('max_pool_3x3', 1), + ('skip_connect', 2), + ('max_pool_3x3', 0), + ('max_pool_3x3', 0), + ('skip_connect', 2), + ('skip_connect', 2), + ('avg_pool_3x3', 0)], + reduce_concat=[2, 3, 4, 5]) + +DARTS_V2 = Genotype( + normal=[ + ('sep_conv_3x3', 0), + ('sep_conv_3x3', 1), + ('sep_conv_3x3', 0), + ('sep_conv_3x3', 1), + ('sep_conv_3x3', 1), + ('skip_connect', 0), + ('skip_connect', 0), + ('dil_conv_3x3', 2)], + normal_concat=[2, 3, 4, 5], + reduce=[ + ('max_pool_3x3', 0), + ('max_pool_3x3', 1), + ('skip_connect', 2), + ('max_pool_3x3', 1), + ('max_pool_3x3', 0), + ('skip_connect', 2), + ('skip_connect', 2), + ('max_pool_3x3', 1)], + reduce_concat=[2, 3, 4, 5]) + +PC_DARTS_cifar = Genotype( + normal=[ + ('sep_conv_3x3', 1), + ('skip_connect', 0), + ('sep_conv_3x3', 0), + ('dil_conv_3x3', 1), + ('sep_conv_5x5', 0), + ('sep_conv_3x3', 1), + ('avg_pool_3x3', 0), + ('dil_conv_3x3', 1)], + normal_concat=range(2, 6), + reduce=[ + ('sep_conv_5x5', 1), + ('max_pool_3x3', 0), + ('sep_conv_5x5', 1), + ('sep_conv_5x5', 2), + ('sep_conv_3x3', 0), + ('sep_conv_3x3', 3), + ('sep_conv_3x3', 1), + ('sep_conv_3x3', 2)], + reduce_concat=range(2, 6)) + +PC_DARTS_image = Genotype( + normal=[ + ('skip_connect', 1), + ('sep_conv_3x3', 0), + ('sep_conv_3x3', 0), + ('skip_connect', 1), + ('sep_conv_3x3', 1), + ('sep_conv_3x3', 3), + ('sep_conv_3x3', 1), + ('dil_conv_5x5', 4)], + normal_concat=range(2, 6), + reduce=[ + ('sep_conv_3x3', 0), + ('skip_connect', 1), + ('dil_conv_5x5', 2), + ('max_pool_3x3', 1), + ('sep_conv_3x3', 2), + ('sep_conv_3x3', 1), + ('sep_conv_5x5', 0), + ('sep_conv_3x3', 3)], + reduce_concat=range(2, 6)) + +PCDARTS_TF2_SEARCH = Genotype( + normal=[ + ('skip_connect', 0), + ('dil_conv_3x3', 1), + ('sep_conv_3x3', 2), + ('sep_conv_5x5', 0), + ('dil_conv_5x5', 2), + ('sep_conv_3x3', 1), + ('dil_conv_3x3', 2), + ('dil_conv_5x5', 1)], + normal_concat=range(2, 6), + reduce=[ + ('avg_pool_3x3', 1), + ('sep_conv_5x5', 0), + ('dil_conv_3x3', 2), + ('sep_conv_3x3', 1), + ('dil_conv_3x3', 3), + ('dil_conv_3x3', 2), + ('dil_conv_3x3', 1), + ('sep_conv_5x5', 0)], + reduce_concat=range(2, 6)) + +PCDARTS = PC_DARTS_cifar diff --git a/modules/losses.py b/modules/losses.py new file mode 100644 index 0000000..7b27a62 --- /dev/null +++ b/modules/losses.py @@ -0,0 +1,12 @@ +import tensorflow as tf + + +def CrossEntropyLoss(): + """"cross entropy loss""" + def cross_entropy_loss(y_true, y_pred): + y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32) + ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true, + logits=y_pred) + + return tf.reduce_mean(ce) + return cross_entropy_loss diff --git a/modules/lr_scheduler.py b/modules/lr_scheduler.py new file mode 100644 index 0000000..37e5d5c --- /dev/null +++ b/modules/lr_scheduler.py @@ -0,0 +1,68 @@ +import tensorflow as tf +import math + + +def MultiStepLR(initial_learning_rate, lr_steps, lr_rate, name='MultiStepLR'): + """Multi-steps learning rate scheduler.""" + lr_steps_value = [initial_learning_rate] + for _ in range(len(lr_steps)): + lr_steps_value.append(lr_steps_value[-1] * lr_rate) + return tf.keras.optimizers.schedules.PiecewiseConstantDecay( + boundaries=lr_steps, values=lr_steps_value) + + +def CosineAnnealingLR_Restart(initial_learning_rate, t_period, lr_min=0.): + """Cosine annealing learning rate scheduler with restart.""" + return tf.keras.experimental.CosineDecayRestarts( + initial_learning_rate=initial_learning_rate, + first_decay_steps=t_period, t_mul=1.0, m_mul=1.0, + alpha=lr_min / initial_learning_rate) + + +def CosineAnnealingLR(initial_learning_rate, t_period, lr_min=0.): + """Cosine annealing learning rate scheduler with restart.""" + return tf.keras.experimental.CosineDecay( + initial_learning_rate=initial_learning_rate, + decay_steps=t_period, alpha=lr_min / initial_learning_rate) + + +if __name__ == "__main__": + # lr_scheduler = MultiStepLR(1e-4, [500, 1000, 2000, 3000], 0.5) + # lr_scheduler = CosineAnnealingLR_Restart(2e-4, 2500, 1e-7) + lr_scheduler = CosineAnnealingLR(0.025, 10000, 0) + + ############################## + # Draw figure + ############################## + N_iter = 10000 + step_list = list(range(0, N_iter, 10)) + lr_list = [] + for i in step_list: + current_lr = lr_scheduler(i).numpy() + lr_list.append(current_lr) + + import matplotlib as mpl + from matplotlib import pyplot as plt + import matplotlib.ticker as mtick + mpl.style.use('default') + import seaborn + seaborn.set(style='whitegrid') + seaborn.set_context('paper') + + plt.figure(1) + plt.subplot(111) + plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) + plt.title('Title', fontsize=16, color='k') + plt.plot(step_list, lr_list, linewidth=1.5, label='learning rate scheme') + legend = plt.legend(loc='upper right', shadow=False) + ax = plt.gca() + labels = ax.get_xticks().tolist() + for k, v in enumerate(labels): + labels[k] = str(int(v / 1000)) + 'K' + ax.set_xticklabels(labels) + ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) + + ax.set_ylabel('Learning rate') + ax.set_xlabel('Iteration') + fig = plt.gcf() + plt.show() diff --git a/modules/models.py b/modules/models.py new file mode 100644 index 0000000..8e60693 --- /dev/null +++ b/modules/models.py @@ -0,0 +1,150 @@ +import functools +import tensorflow as tf +from absl import logging +from tensorflow.keras import Model, Sequential +from tensorflow.keras.layers import (Input, Dense, Flatten, Dropout, Conv2D, + AveragePooling2D, GlobalAveragePooling2D, + ReLU) +from modules.operations import (OPS, FactorizedReduce, ReLUConvBN, + BatchNormalization, Identity, drop_path, + kernel_init, regularizer) +import modules.genotypes as genotypes + + +class Cell(tf.keras.layers.Layer): + """Cell Layer""" + def __init__(self, genotype, ch, reduction, reduction_prev, wd, + name='Cell', **kwargs): + super(Cell, self).__init__(name=name, **kwargs) + + self.wd = wd + + if reduction_prev: + self.preprocess0 = FactorizedReduce(ch, wd=wd) + else: + self.preprocess0 = ReLUConvBN(ch, k=1, s=1, wd=wd) + self.preprocess1 = ReLUConvBN(ch, k=1, s=1, wd=wd) + + if reduction: + op_names, indices = zip(*genotype.reduce) + concat = genotype.reduce_concat + else: + op_names, indices = zip(*genotype.normal) + concat = genotype.normal_concat + + self._compile(ch, op_names, indices, concat, reduction) + + def _compile(self, ch, op_names, indices, concat, reduction): + assert len(op_names) == len(indices) + self._steps = len(op_names) // 2 + self._concat = concat + + self._ops = [] + for name, index in zip(op_names, indices): + strides = 2 if reduction and index < 2 else 1 + op = OPS[name](ch, strides, self.wd, True) + self._ops.append(op) + self._indices = indices + + def call(self, s0, s1, drop_path_prob): + s0 = self.preprocess0(s0) + s1 = self.preprocess1(s1) + + states = [s0, s1] + for step_index in range(self._steps): + op1 = self._ops[2 * step_index] + op2 = self._ops[2 * step_index + 1] + h1 = op1(states[self._indices[2 * step_index]]) + h2 = op2(states[self._indices[2 * step_index + 1]]) + + if drop_path_prob is not None: + if not isinstance(op1, Identity): + h1 = drop_path(h1, drop_path_prob, name='drop_path_h1') + if not isinstance(op2, Identity): + h2 = drop_path(h2, drop_path_prob, name='drop_path_h2') + + s = h1 + h2 + states += [s] + + return tf.concat([states[i] for i in self._concat], axis=-1) + + +class AuxiliaryHeadCIFAR(tf.keras.layers.Layer): + """Auxiliary Head Cifar""" + def __init__(self, num_classes, wd, name='AuxiliaryHeadCIFAR', **kwargs): + super(AuxiliaryHeadCIFAR, self).__init__(name=name, **kwargs) + self.features = Sequential([ + ReLU(), + AveragePooling2D(5, strides=3, padding='valid'), + Conv2D(filters=128, kernel_size=1, strides=1, padding='valid', + kernel_initializer=kernel_init(), + kernel_regularizer=regularizer(wd), use_bias=False), + BatchNormalization(affine=True), + ReLU(), + Conv2D(filters=768, kernel_size=2, strides=1, padding='valid', + kernel_initializer=kernel_init(), + kernel_regularizer=regularizer(wd), use_bias=False), + BatchNormalization(affine=True), + ReLU()]) + self.classifier = Dense(num_classes, kernel_initializer=kernel_init(), + kernel_regularizer=regularizer(wd)) + + def call(self, x): + x = self.features(x) + x = self.classifier(Flatten()(x)) + return x + + +def CifarModel(cfg, training=True, stem_multiplier=3, name='CifarModel'): + """Cifar Model""" + logging.info(f"buliding {name}...") + + input_size = cfg['input_size'] + ch_init = cfg['init_channels'] + layers = cfg['layers'] + num_cls = cfg['num_classes'] + wd = cfg['weights_decay'] + genotype = eval("genotypes.%s" % cfg['arch']) + + # define model + inputs = Input([input_size, input_size, 3], name='input_image') + if training: + drop_path_prob = Input([], name='drop_prob') + else: + drop_path_prob = None + + ch_curr = stem_multiplier * ch_init + s0 = s1 = Sequential([ + Conv2D(filters=ch_curr, kernel_size=3, strides=1, padding='same', + kernel_initializer=kernel_init(), + kernel_regularizer=regularizer(wd), use_bias=False), + BatchNormalization(affine=True)], name='stem')(inputs) + + ch_curr = ch_init + reduction_prev = False + logits_aux = None + for layer_index in range(layers): + if layer_index in [layers // 3, 2 * layers // 3]: + ch_curr *= 2 + reduction = True + else: + reduction = False + + cell = Cell(genotype, ch_curr, reduction, reduction_prev, wd, + name=f'Cell_{layer_index}') + s0, s1 = s1, cell(s0, s1, drop_path_prob) + + reduction_prev = reduction + + if layer_index == 2 * layers // 3 and training: + logits_aux = AuxiliaryHeadCIFAR(num_cls, wd=wd)(s1) + + fea = GlobalAveragePooling2D()(s1) + + logits = Dense(num_cls, kernel_initializer=kernel_init(), + kernel_regularizer=regularizer(wd))(Flatten()(fea)) + + if training: + return Model((inputs, drop_path_prob), (logits, logits_aux), name=name) + else: + return Model(inputs, logits, name=name) diff --git a/modules/models_search.py b/modules/models_search.py new file mode 100644 index 0000000..27d5d8b --- /dev/null +++ b/modules/models_search.py @@ -0,0 +1,254 @@ +import functools +import tensorflow as tf +from absl import logging +from tensorflow.keras import Model, Sequential +from tensorflow.keras.layers import (Input, Dense, Flatten, Dropout, Conv2D, + MaxPool2D, GlobalAveragePooling2D, + ReLU, Softmax) +from modules.operations import (OPS, FactorizedReduce, ReLUConvBN, + BatchNormalization, Identity, drop_path, + kernel_init, regularizer) +from modules.genotypes import PRIMITIVES, Genotype + + +def channel_shuffle(x, groups): + _, h, w, num_channels = x.shape + + assert num_channels % groups == 0 + channels_per_group = num_channels // groups + + x = tf.reshape(x, [-1, h, w, groups, channels_per_group]) + x = tf.transpose(x, [0, 1, 2, 4, 3]) + x = tf.reshape(x, [-1, h, w, num_channels]) + + return x + + +class MixedOP(tf.keras.layers.Layer): + """Mixed OP""" + def __init__(self, ch, strides, wd, name='MixedOP', **kwargs): + super(MixedOP, self).__init__(name=name, **kwargs) + + self._ops = [] + self.mp = MaxPool2D(2, strides=2, padding='valid') + + for primitive in PRIMITIVES: + op = OPS[primitive](ch // 4, strides, wd, False) + + if 'pool' in primitive: + op = Sequential([op, BatchNormalization(affine=False)]) + + self._ops.append(op) + + def call(self, x, weights): + # channel proportion k = 4 + x_1 = x[:, :, :, :x.shape[3] // 4] + x_2 = x[:, :, :, x.shape[3] // 4:] + + x_1 = tf.add_n([w * op(x_1) for w, op in + zip(tf.split(weights, len(PRIMITIVES)), self._ops)]) + + # reduction cell needs pooling before concat + if x_1.shape[2] == x.shape[2]: + ans = tf.concat([x_1, x_2], axis=3) + else: + ans = tf.concat([x_1, self.mp(x_2)], axis=3) + + return channel_shuffle(ans, 4) + + +class Cell(tf.keras.layers.Layer): + """Cell Layer""" + def __init__(self, steps, multiplier, ch, reduction, reduction_prev, wd, + name='Cell', **kwargs): + super(Cell, self).__init__(name=name, **kwargs) + + self.wd = wd + self.steps = steps + self.multiplier = multiplier + + if reduction_prev: + self.preprocess0 = FactorizedReduce(ch, wd=wd, affine=False) + else: + self.preprocess0 = ReLUConvBN(ch, k=1, s=1, wd=wd, affine=False) + self.preprocess1 = ReLUConvBN(ch, k=1, s=1, wd=wd, affine=False) + + self._ops = [] + for i in range(self.steps): + for j in range(2 + i): + strides = 2 if reduction and j < 2 else 1 + op = MixedOP(ch, strides=strides, wd=wd) + self._ops.append(op) + + def call(self, s0, s1, weights, edge_weights): + s0 = self.preprocess0(s0) + s1 = self.preprocess1(s1) + + states = [s0, s1] + offset = 0 + for _ in range(self.steps): + s = 0 + for j, h in enumerate(states): + branch = self._ops[offset + j](h, weights[offset + j]) + s += edge_weights[offset + j] * branch + offset += len(states) + states.append(s) + + return tf.concat(states[-self.multiplier:], axis=-1) + + +class SplitSoftmax(tf.keras.layers.Layer): + """Split Softmax Layer""" + def __init__(self, size_splits, name='SplitSoftmax', **kwargs): + super(SplitSoftmax, self).__init__(name=name, **kwargs) + self.size_splits = size_splits + self.soft_max_func = Softmax(axis=-1) + + def call(self, value): + return tf.concat( + [self.soft_max_func(t) for t in tf.split(value, self.size_splits)], + axis=0) + + +class SearchNetArch(object): + """Search Network Architecture""" + def __init__(self, cfg, steps=4, multiplier=4, stem_multiplier=3, + name='SearchModel'): + self.cfg = cfg + self.steps = steps + self.multiplier = multiplier + self.stem_multiplier = stem_multiplier + self.name = name + + self.arch_parameters = self._initialize_alphas() + self.model = self._build_model() + + def _initialize_alphas(self): + k = sum(range(2, 2 + self.steps)) + num_ops = len(PRIMITIVES) + w_init = tf.random_normal_initializer() + self.alphas_normal = tf.Variable( + initial_value=1e-3 * w_init(shape=[k, num_ops], dtype='float32'), + trainable=True, name='alphas_normal') + self.alphas_reduce = tf.Variable( + initial_value=1e-3 * w_init(shape=[k, num_ops], dtype='float32'), + trainable=True, name='alphas_reduce') + self.betas_normal = tf.Variable( + initial_value=1e-3 * w_init(shape=[k], dtype='float32'), + trainable=True, name='betas_normal') + self.betas_reduce = tf.Variable( + initial_value=1e-3 * w_init(shape=[k], dtype='float32'), + trainable=True, name='betas_reduce') + + return [self.alphas_normal, self.alphas_reduce, self.betas_normal, + self.betas_reduce] + + def _build_model(self): + """Model""" + logging.info(f"buliding {self.name}...") + + input_size = self.cfg['input_size'] + ch_init = self.cfg['init_channels'] + layers = self.cfg['layers'] + num_cls = self.cfg['num_classes'] + wd = self.cfg['weights_decay'] + + # define model + inputs = Input([input_size, input_size, 3], name='input_image') + alphas_normal = Input([None], name='alphas_normal') + alphas_reduce = Input([None], name='alphas_reduce') + betas_normal = Input([], name='betas_normal') + betas_reduce = Input([], name='betas_reduce') + + alphas_reduce_weights = Softmax( + name='AlphasReduceSoftmax')(alphas_reduce) + alphas_normal_weights = Softmax( + name='AlphasNormalSoftmax')(alphas_normal) + betas_reduce_weights = SplitSoftmax( + range(2, 2 + self.steps), name='BetasReduceSoftmax')(betas_reduce) + betas_normal_weights = SplitSoftmax( + range(2, 2 + self.steps), name='BetasNormalSoftmax')(betas_normal) + + ch_curr = self.stem_multiplier * ch_init + s0 = s1 = Sequential([ + Conv2D(filters=ch_curr, kernel_size=3, strides=1, padding='same', + kernel_initializer=kernel_init(), + kernel_regularizer=regularizer(wd), use_bias=False), + BatchNormalization(affine=True)], name='stem')(inputs) + + ch_curr = ch_init + reduction_prev = False + for layer_index in range(layers): + if layer_index in [layers // 3, 2 * layers // 3]: + ch_curr *= 2 + reduction = True + weights = alphas_reduce_weights + edge_weights = betas_reduce_weights + else: + reduction = False + weights = alphas_normal_weights + edge_weights = betas_normal_weights + + cell = Cell(self.steps, self.multiplier, ch_curr, reduction, + reduction_prev, wd, name=f'Cell_{layer_index}') + s0, s1 = s1, cell(s0, s1, weights, edge_weights) + + reduction_prev = reduction + + fea = GlobalAveragePooling2D()(s1) + + logits = Dense(num_cls, kernel_initializer=kernel_init(), + kernel_regularizer=regularizer(wd))(Flatten()(fea)) + + return Model( + (inputs, alphas_normal, alphas_reduce, betas_normal, betas_reduce), + logits, name=self.name) + + def get_genotype(self): + """get genotype""" + def _parse(weights, edge_weights): + n = 2 + start = 0 + gene = [] + for i in range(self.steps): + end = start + n + w = weights[start:end].copy() + ew = edge_weights[start:end].copy() + + # fused weights + for j in range(n): + w[j, :] = w[j, :] * ew[j] + + # pick the top 2 edges (k = 2). + edges = sorted( + range(i + 2), + key=lambda x: -max(w[x][k] for k in range(len(w[x])) + if k != PRIMITIVES.index('none')) + )[:2] + + # pick the top best op, and append into genotype. + for j in edges: + k_best = None + for k in range(len(w[j])): + if k != PRIMITIVES.index('none'): + if k_best is None or w[j][k] > w[j][k_best]: + k_best = k + gene.append((PRIMITIVES[k_best], j)) + + start = end + n += 1 + + return gene + + gene_reduce = _parse( + Softmax()(self.alphas_reduce).numpy(), + SplitSoftmax(range(2, 2 + self.steps))(self.betas_reduce).numpy()) + gene_normal = _parse( + Softmax()(self.alphas_normal).numpy(), + SplitSoftmax(range(2, 2 + self.steps))(self.betas_normal).numpy()) + + concat = range(2 + self.steps - self.multiplier, self.steps + 2) + genotype = Genotype(normal=gene_normal, normal_concat=concat, + reduce=gene_reduce, reduce_concat=concat) + + return genotype diff --git a/modules/operations.py b/modules/operations.py new file mode 100644 index 0000000..5202046 --- /dev/null +++ b/modules/operations.py @@ -0,0 +1,180 @@ +import tensorflow as tf +from tensorflow.keras import Sequential +from tensorflow.keras.layers import (Conv2D, SeparableConv2D, MaxPool2D, + AveragePooling2D, ReLU, Dropout) + + +OPS = {'none': lambda f, s, wd, affine: Zero(s), + 'avg_pool_3x3': lambda f, s, wd, affine: + AveragePooling2D(3, strides=s, padding='same'), + 'max_pool_3x3': lambda f, s, wd, affine: + MaxPool2D(3, strides=s, padding='same'), + 'skip_connect': lambda f, s, wd, affine: + Identity() if s == 1 else FactorizedReduce(f, wd, affine=affine), + 'sep_conv_3x3': lambda f, s, wd, affine: + SepConv(f, 3, s, wd, affine=affine), + 'sep_conv_5x5': lambda f, s, wd, affine: + SepConv(f, 5, s, wd, affine=affine), + 'sep_conv_7x7': lambda f, s, wd, affine: + SepConv(f, 7, s, wd, affine=affine), + 'dil_conv_3x3': lambda f, s, wd, affine: + DilConv(f, 3, s, 2, wd, affine=affine), + 'dil_conv_5x5': lambda f, s, wd, affine: + DilConv(f, 5, s, 2, wd, affine=affine), + 'conv_7x1_1x7': lambda f, s, wd, affine: + Sequential([ReLU(), + Conv2D(filters=f, kernel_size=(1, 7), strides=(1, s), + kernel_initializer=kernel_init(), + kernel_regularizer=regularizer(wd), + padding='same', use_bias=False), + Conv2D(filters=f, kernel_size=(7, 1), strides=(s, 1), + kernel_initializer=kernel_init(), + kernel_regularizer=regularizer(wd), + padding='same', use_bias=False), + BatchNormalization(affine=affine)])} + + +def regularizer(weights_decay): + """l2 regularizer""" + return tf.keras.regularizers.l2(weights_decay) + + +def kernel_init(seed=None): + """He normal initializer""" + return tf.keras.initializers.he_normal(seed) + + +def drop_path(x, drop_rate, name='drop_path'): + """drop path from https://arxiv.org/abs/1605.07648""" + random_values = tf.random.uniform([tf.shape(x)[0], 1, 1, 1], name=name) + mask = tf.cast(random_values > drop_rate, tf.float32) / (1 - drop_rate) + x = mask * x + + return x + + +class BatchNormalization(tf.keras.layers.BatchNormalization): + """Make trainable=False freeze BN for real (the og version is sad). + ref: https://github.com/zzh8829/yolov3-tf2 + """ + def __init__(self, axis=-1, momentum=0.9, epsilon=1e-5, affine=True, + name=None, **kwargs): + super(BatchNormalization, self).__init__( + axis=axis, momentum=momentum, epsilon=epsilon, center=affine, + scale=affine, name=name, **kwargs) + + def call(self, x, training=False): + if training is None: + training = tf.constant(False) + training = tf.logical_and(training, self.trainable) + + return super().call(x, training) + + +class ReLUConvBN(tf.keras.layers.Layer): + """ReLu + Conv + BN""" + def __init__(self, ch_out, k, s, wd, padding='valid', affine=True, + name='ReLUConvBN', **kwargs): + super(ReLUConvBN, self).__init__(name=name, **kwargs) + self.op = Sequential([ + ReLU(), + Conv2D(filters=ch_out, kernel_size=k, strides=s, + kernel_initializer=kernel_init(), + kernel_regularizer=regularizer(wd), + padding=padding, use_bias=False), + BatchNormalization(affine=affine)]) + + def call(self, x): + return self.op(x) + + +class DilConv(tf.keras.layers.Layer): + """Dilated Conv""" + def __init__(self, ch_out, k, s, d, wd, affine=True, name='DilConv', + **kwargs): + super(DilConv, self).__init__(name=name, **kwargs) + self.op = Sequential([ + ReLU(), + SeparableConv2D(filters=ch_out, kernel_size=k, strides=s, + depthwise_initializer=kernel_init(), + pointwise_initializer=kernel_init(), + depthwise_regularizer=regularizer(wd), + pointwise_regularizer=regularizer(wd), + dilation_rate=d, padding='same', use_bias=False), + BatchNormalization(affine=affine)]) + + def call(self, x): + return self.op(x) + + +class SepConv(tf.keras.layers.Layer): + """Separable Conv""" + def __init__(self, ch_out, k, s, wd, affine=True, name='SepConv', + **kwargs): + super(SepConv, self).__init__(name=name, **kwargs) + self.op = Sequential([ + ReLU(), + SeparableConv2D(filters=ch_out, kernel_size=k, strides=s, + depthwise_initializer=kernel_init(), + pointwise_initializer=kernel_init(), + depthwise_regularizer=regularizer(wd), + pointwise_regularizer=regularizer(wd), + padding='same', use_bias=False), + BatchNormalization(affine=affine), + ReLU(), + SeparableConv2D(filters=ch_out, kernel_size=k, strides=1, + depthwise_initializer=kernel_init(), + pointwise_initializer=kernel_init(), + depthwise_regularizer=regularizer(wd), + pointwise_regularizer=regularizer(wd), + padding='same', use_bias=False), + BatchNormalization(affine=affine)]) + + def call(self, x): + return self.op(x) + + +class Identity(tf.keras.layers.Layer): + """Identity""" + def __init__(self, name='Identity', **kwargs): + super(Identity, self).__init__(name=name, **kwargs) + + def call(self, x): + return x + + +class Zero(tf.keras.layers.Layer): + """Zero""" + def __init__(self, strides, name='Zero', **kwargs): + super(Zero, self).__init__(name=name, **kwargs) + self.strides = strides + + def call(self, x): + if self.strides == 1: + return x * 0. + return x[:, ::self.strides, ::self.strides, :] * 0 + + +class FactorizedReduce(tf.keras.layers.Layer): + """Factorized Reduce Layer""" + def __init__(self, ch_out, wd, affine=True, name='FactorizedReduce', + **kwargs): + super(FactorizedReduce, self).__init__(name=name, **kwargs) + assert ch_out % 2 == 0 + self.relu = ReLU() + self.conv_1 = Conv2D(filters=ch_out // 2, kernel_size=1, strides=2, + kernel_initializer=kernel_init(), + kernel_regularizer=regularizer(wd), + padding='valid', use_bias=False) + self.conv_2 = Conv2D(filters=ch_out // 2, kernel_size=1, strides=2, + kernel_initializer=kernel_init(), + kernel_regularizer=regularizer(wd), + padding='valid', use_bias=False) + self.bn = BatchNormalization(affine=affine) + + def call(self, x): + x = self.relu(x) + out = tf.concat([self.conv_1(x), self.conv_2(x)], axis=-1) + out = self.bn(out) + + return out diff --git a/modules/utils.py b/modules/utils.py new file mode 100644 index 0000000..4ec8d99 --- /dev/null +++ b/modules/utils.py @@ -0,0 +1,129 @@ +import cv2 +import yaml +import sys +import time +import numpy as np +import tensorflow as tf +from absl import logging + + +def load_yaml(load_path): + """load yaml file""" + with open(load_path, 'r') as f: + loaded = yaml.load(f, Loader=yaml.Loader) + + return loaded + + +def set_memory_growth(): + """set memory growth in tensorflow""" + gpus = tf.config.experimental.list_physical_devices('GPU') + if gpus: + try: + # Currently, memory growth needs to be the same across GPUs + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + logical_gpus = tf.config.experimental.list_logical_devices( + 'GPU') + logging.info( + "Detect {} Physical GPUs, {} Logical GPUs.".format( + len(gpus), len(logical_gpus))) + except RuntimeError as e: + # Memory growth must be set before GPUs have been initialized + logging.info(e) + + +def count_parameters_in_MB(model): + """count parameters in MB""" + return np.sum( + [tf.keras.backend.count_params(w) for w in model.trainable_weights + if 'Auxiliary' not in w.name]) / 1e6 + + +class ProgressBar(object): + """A progress bar which can print the progress modified from + https://github.com/hellock/cvbase/blob/master/cvbase/progress.py""" + def __init__(self, task_num=0, completed=0, bar_width=15): + self.task_num = task_num + max_bar_width = self._get_max_bar_width() + self.bar_width = ( + bar_width if bar_width <= max_bar_width else max_bar_width) + self.completed = completed + self.first_step = completed + self.warm_up = False + + def _get_max_bar_width(self): + if sys.version_info > (3, 3): + from shutil import get_terminal_size + else: + from backports.shutil_get_terminal_size import get_terminal_size + terminal_width, _ = get_terminal_size() + max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) + if max_bar_width < 10: + logging.info('terminal width is too small ({}), please consider ' + 'widen the terminal for better progressbar ' + 'visualization'.format(terminal_width)) + max_bar_width = 10 + return max_bar_width + + def reset(self): + """reset""" + self.completed = 0 + self.fps = 0 + + def update(self, inf_str=''): + """update""" + self.completed += 1 + + if not self.warm_up: + self.start_time = time.time() - 1e-1 + self.warm_up = True + + if self.completed > self.task_num: + self.completed = self.completed % self.task_num + self.start_time = time.time() - 1 / self.fps + self.first_step = self.completed - 1 + sys.stdout.write('\n') + + elapsed = time.time() - self.start_time + self.fps = (self.completed - self.first_step) / elapsed + percentage = self.completed / float(self.task_num) + mark_width = int(self.bar_width * percentage) + bar_chars = '>' * mark_width + ' ' * (self.bar_width - mark_width) + stdout_str = '\rtrain [{}] {}/{}, {} {:.1f} step/sec' + sys.stdout.write(stdout_str.format( + bar_chars, self.completed, self.task_num, inf_str, self.fps)) + + sys.stdout.flush() + + +class AvgrageMeter(object): + """Average meter""" + def __init__(self): + self.reset() + + def reset(self): + self.avg = 0 + self.sum = 0 + self.cnt = 0 + + def update(self, val, n=1): + self.sum += val * n + self.cnt += n + self.avg = self.sum / self.cnt + + +def accuracy(outputs, labels, topk=(1,)): + """caculate accuracy""" + maxk = max(topk) + batch_size = labels.shape[0] + + pred = np.argsort(outputs)[:, ::-1][:, :maxk] + correct = (pred == np.reshape(labels, [batch_size, 1])) + + results = [] + for k in topk: + correct_k = np.sum(correct[:, :k]) + results.append(correct_k * 100.0 / batch_size) + + return results diff --git a/photo/architecture.jpg b/photo/architecture.jpg new file mode 100644 index 0000000..cc68f07 Binary files /dev/null and b/photo/architecture.jpg differ diff --git a/photo/genotype_normal.jpg b/photo/genotype_normal.jpg new file mode 100644 index 0000000..8123b58 Binary files /dev/null and b/photo/genotype_normal.jpg differ diff --git a/photo/genotype_reduction.jpg b/photo/genotype_reduction.jpg new file mode 100644 index 0000000..9f0b26c Binary files /dev/null and b/photo/genotype_reduction.jpg differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2a9a7bb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +tensorflow-gpu==2.1.0 +tensorflow_datasets +numpy +opencv-python +PyYAML +ipython +graphviz diff --git a/test.py b/test.py new file mode 100644 index 0000000..7c2a72b --- /dev/null +++ b/test.py @@ -0,0 +1,75 @@ +from absl import app, flags, logging +from absl.flags import FLAGS +import cv2 +import os +import numpy as np +import tensorflow as tf +import time + +from modules.models import CifarModel +from modules.dataset import load_cifar10_dataset +from modules.utils import ( + set_memory_growth, load_yaml, count_parameters_in_MB, AvgrageMeter, + accuracy) + + +flags.DEFINE_string('cfg_path', './configs/pcdarts_cifar10.yaml', + 'config file path') +flags.DEFINE_string('gpu', '0', 'which gpu to use') + + +def main(_argv): + # init + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu + + logger = tf.get_logger() + logger.disabled = True + logger.setLevel(logging.FATAL) + set_memory_growth() + + cfg = load_yaml(FLAGS.cfg_path) + + # load dataset + test_dataset = load_cifar10_dataset( + cfg['val_batch_size'], split='test', shuffle=False, + drop_remainder=False, using_crop=False, using_flip=False, + using_cutout=False) + + # define network + model = CifarModel(cfg, training=False) + model.summary(line_length=80) + print("param size = {:f}MB".format(count_parameters_in_MB(model))) + + # load checkpoint + checkpoint_path = './checkpoints/' + cfg['sub_name'] + '/best.ckpt' + try: + model.load_weights('./checkpoints/' + cfg['sub_name'] + '/best.ckpt') + print("[*] load ckpt from {}.".format(checkpoint_path)) + except: + print("[*] Cannot find ckpt from {}.".format(checkpoint_path)) + exit() + + # inference + top1 = AvgrageMeter() + top5 = AvgrageMeter() + for step, (inputs, labels) in enumerate(test_dataset): + # run model + logits = model(inputs) + + # cacludate top1, top5 acc + prec1, prec5 = accuracy(logits.numpy(), labels.numpy(), topk=(1, 5)) + n = inputs.shape[0] + top1.update(prec1, n) + top5.update(prec5, n) + + print(" {:03d}: top1 {:f}, top5 {:f}".format(step, top1.avg, top5.avg)) + + print("Test Acc: top1 {:.2f}%, top5 {:.2f}%".format(top1.avg, top5.avg)) + + +if __name__ == '__main__': + try: + app.run(main) + except SystemExit: + pass diff --git a/train.py b/train.py new file mode 100644 index 0000000..1d3bd42 --- /dev/null +++ b/train.py @@ -0,0 +1,162 @@ +from absl import app, flags, logging +from absl.flags import FLAGS +import os +import numpy as np +import tensorflow as tf + +from modules.models import CifarModel +from modules.dataset import load_cifar10_dataset +from modules.lr_scheduler import CosineAnnealingLR +from modules.losses import CrossEntropyLoss +from modules.utils import ( + set_memory_growth, load_yaml, count_parameters_in_MB, ProgressBar, + AvgrageMeter, accuracy) + + +flags.DEFINE_string('cfg_path', './configs/pcdarts_cifar10.yaml', + 'config file path') +flags.DEFINE_string('gpu', '0', 'which gpu to use') + + +def main(_): + # init + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu + + logger = tf.get_logger() + logger.disabled = True + logger.setLevel(logging.FATAL) + set_memory_growth() + + cfg = load_yaml(FLAGS.cfg_path) + + # define network + model = CifarModel(cfg, training=True) + model.summary(line_length=80) + print("param size = {:f}MB".format(count_parameters_in_MB(model))) + + # load dataset + train_dataset = load_cifar10_dataset( + cfg['batch_size'], split='train', shuffle=True, drop_remainder=True, + using_normalize=cfg['using_normalize'], using_crop=cfg['using_crop'], + using_flip=cfg['using_flip'], using_cutout=cfg['using_cutout'], + cutout_length=cfg['cutout_length']) + val_dataset = load_cifar10_dataset( + cfg['val_batch_size'], split='test', shuffle=False, + drop_remainder=False, using_normalize=cfg['using_normalize'], + using_crop=False, using_flip=False, using_cutout=False) + + # define optimizer + steps_per_epoch = cfg['dataset_len'] // cfg['batch_size'] + learning_rate = CosineAnnealingLR( + initial_learning_rate=cfg['init_lr'], + t_period=cfg['epoch'] * steps_per_epoch, lr_min=cfg['lr_min']) + optimizer = tf.keras.optimizers.SGD( + learning_rate=learning_rate, momentum=cfg['momentum']) + + # define losses function + criterion = CrossEntropyLoss() + + # load checkpoint + checkpoint_dir = './checkpoints/' + cfg['sub_name'] + checkpoint = tf.train.Checkpoint(step=tf.Variable(0, name='step'), + optimizer=optimizer, + model=model) + manager = tf.train.CheckpointManager(checkpoint=checkpoint, + directory=checkpoint_dir, + max_to_keep=3) + if manager.latest_checkpoint: + checkpoint.restore(manager.latest_checkpoint) + print('[*] load ckpt from {} at step {}.'.format( + manager.latest_checkpoint, checkpoint.step.numpy())) + else: + print("[*] training from scratch.") + + # define training step function + @tf.function + def train_step(inputs, labels, drop_path_prob): + with tf.GradientTape() as tape: + logits, logits_aux = model((inputs, drop_path_prob), training=True) + + losses = {} + losses['reg'] = tf.reduce_sum(model.losses) + losses['ce'] = criterion(labels, logits) + losses['ce_auxiliary'] = \ + cfg['auxiliary_weight'] * criterion(labels, logits_aux) + total_loss = tf.add_n([l for l in losses.values()]) + + grads = tape.gradient(total_loss, model.trainable_variables) + grads = [(tf.clip_by_norm(grad, cfg['grad_clip'])) for grad in grads] + optimizer.apply_gradients(zip(grads, model.trainable_variables)) + + return logits, total_loss, losses + + # training loop + summary_writer = tf.summary.create_file_writer('./logs/' + cfg['sub_name']) + total_steps = steps_per_epoch * cfg['epoch'] + remain_steps = max(total_steps - checkpoint.step.numpy(), 0) + prog_bar = ProgressBar(steps_per_epoch, + checkpoint.step.numpy() % steps_per_epoch) + + train_acc = AvgrageMeter() + val_acc = AvgrageMeter() + best_acc = 0. + for inputs, labels in train_dataset.take(remain_steps): + checkpoint.step.assign_add(1) + drop_path_prob = cfg['drop_path_prob'] * ( + tf.cast(checkpoint.step, tf.float32) / total_steps) + steps = checkpoint.step.numpy() + epochs = ((steps - 1) // steps_per_epoch) + 1 + + logits, total_loss, losses = train_step(inputs, labels, drop_path_prob) + train_acc.update( + accuracy(logits.numpy(), labels.numpy())[0], cfg['batch_size']) + + prog_bar.update( + "epoch={}/{}, loss={:.4f}, acc={:.2f}, lr={:.2e}".format( + epochs, cfg['epoch'], total_loss.numpy(), train_acc.avg, + optimizer.lr(steps).numpy())) + + if steps % cfg['val_steps'] == 0 and steps > 1: + print("\n[*] validate...", end='') + val_acc.reset() + for inputs_val, labels_val in val_dataset: + logits_val, _ = model((inputs_val, tf.constant([0.]))) + val_acc.update( + accuracy(logits_val.numpy(), labels_val.numpy())[0], + inputs_val.shape[0]) + + if val_acc.avg > best_acc: + best_acc = val_acc.avg + model.save_weights(f"checkpoints/{cfg['sub_name']}/best.ckpt") + + val_str = " val acc {:.2f}%, best acc {:.2f}%" + print(val_str.format(val_acc.avg, best_acc), end='') + + if steps % 10 == 0: + with summary_writer.as_default(): + tf.summary.scalar('acc/train', train_acc.avg, step=steps) + tf.summary.scalar('acc/val', val_acc.avg, step=steps) + + tf.summary.scalar( + 'loss/total_loss', total_loss, step=steps) + for k, l in losses.items(): + tf.summary.scalar('loss/{}'.format(k), l, step=steps) + tf.summary.scalar( + 'learning_rate', optimizer.lr(steps), step=steps) + + if steps % cfg['save_steps'] == 0: + manager.save() + print("\n[*] save ckpt file at {}".format( + manager.latest_checkpoint)) + + if steps % steps_per_epoch == 0: + train_acc.reset() + + manager.save() + print("\n[*] training done! save ckpt file at {}".format( + manager.latest_checkpoint)) + + +if __name__ == '__main__': + app.run(main) diff --git a/train_search.py b/train_search.py new file mode 100644 index 0000000..0e8013a --- /dev/null +++ b/train_search.py @@ -0,0 +1,186 @@ +from absl import app, flags, logging +from absl.flags import FLAGS +import os +import numpy as np +import tensorflow as tf + +from modules.models_search import SearchNetArch +from modules.dataset import load_cifar10_dataset +from modules.lr_scheduler import CosineAnnealingLR +from modules.losses import CrossEntropyLoss +from modules.utils import ( + set_memory_growth, load_yaml, count_parameters_in_MB, ProgressBar, + AvgrageMeter, accuracy) + + +flags.DEFINE_string('cfg_path', './configs/pcdarts_cifar10_search.yaml', + 'config file path') +flags.DEFINE_string('gpu', '0', 'which gpu to use') + + +def main(_): + # init + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu + + logger = tf.get_logger() + logger.disabled = True + logger.setLevel(logging.FATAL) + set_memory_growth() + + cfg = load_yaml(FLAGS.cfg_path) + + # define network + sna = SearchNetArch(cfg) + sna.model.summary(line_length=80) + print("param size = {:f}MB".format(count_parameters_in_MB(sna.model))) + + # load dataset + t_split = f"train[0%:{int(cfg['train_portion'] * 100)}%]" + v_split = f"train[{int(cfg['train_portion'] * 100)}%:100%]" + train_dataset = load_cifar10_dataset( + cfg['batch_size'], split=t_split, shuffle=True, drop_remainder=True, + using_normalize=cfg['using_normalize'], using_crop=cfg['using_crop'], + using_flip=cfg['using_flip'], using_cutout=cfg['using_cutout'], + cutout_length=cfg['cutout_length']) + val_dataset = load_cifar10_dataset( + cfg['batch_size'], split=v_split, shuffle=True, drop_remainder=True, + using_normalize=cfg['using_normalize'], using_crop=cfg['using_crop'], + using_flip=cfg['using_flip'], using_cutout=cfg['using_cutout'], + cutout_length=cfg['cutout_length']) + + # define optimizer + steps_per_epoch = int( + cfg['dataset_len'] * cfg['train_portion'] // cfg['batch_size']) + learning_rate = CosineAnnealingLR( + initial_learning_rate=cfg['init_lr'], + t_period=cfg['epoch'] * steps_per_epoch, lr_min=cfg['lr_min']) + optimizer = tf.keras.optimizers.SGD( + learning_rate=learning_rate, momentum=cfg['momentum']) + optimizer_arch = tf.keras.optimizers.Adam( + learning_rate=cfg['arch_learning_rate'], beta_1=0.5, beta_2=0.999) + + # define losses function + criterion = CrossEntropyLoss() + + # load checkpoint + checkpoint_dir = './checkpoints/' + cfg['sub_name'] + checkpoint = tf.train.Checkpoint(step=tf.Variable(0, name='step'), + optimizer=optimizer, + optimizer_arch=optimizer_arch, + model=sna.model, + alphas_normal=sna.alphas_normal, + alphas_reduce=sna.alphas_reduce, + betas_normal=sna.betas_normal, + betas_reduce=sna.betas_reduce) + manager = tf.train.CheckpointManager(checkpoint=checkpoint, + directory=checkpoint_dir, + max_to_keep=3) + if manager.latest_checkpoint: + checkpoint.restore(manager.latest_checkpoint) + print('[*] load ckpt from {} at step {}.'.format( + manager.latest_checkpoint, checkpoint.step.numpy())) + else: + print("[*] training from scratch.") + print(f"[*] searching model after {cfg['start_search_epoch']} epochs.") + + # define training step function for model + @tf.function + def train_step(inputs, labels): + with tf.GradientTape() as tape: + logits = sna.model((inputs, *sna.arch_parameters), training=True) + + losses = {} + losses['reg'] = tf.reduce_sum(sna.model.losses) + losses['ce'] = criterion(labels, logits) + total_loss = tf.add_n([l for l in losses.values()]) + + grads = tape.gradient(total_loss, sna.model.trainable_variables) + grads = [(tf.clip_by_norm(grad, cfg['grad_clip'])) for grad in grads] + optimizer.apply_gradients(zip(grads, sna.model.trainable_variables)) + + return logits, total_loss, losses + + # define training step function for arch_parameters + @tf.function + def train_step_arch(inputs, labels): + with tf.GradientTape() as tape: + logits = sna.model((inputs, *sna.arch_parameters), training=True) + + losses = {} + losses['reg'] = cfg['arch_weight_decay'] * tf.add_n( + [tf.reduce_sum(p**2) for p in sna.arch_parameters]) + losses['ce'] = criterion(labels, logits) + total_loss = tf.add_n([l for l in losses.values()]) + + grads = tape.gradient(total_loss, sna.arch_parameters) + optimizer_arch.apply_gradients(zip(grads, sna.arch_parameters)) + + return losses + + # training loop + summary_writer = tf.summary.create_file_writer('./logs/' + cfg['sub_name']) + total_steps = steps_per_epoch * cfg['epoch'] + remain_steps = max(total_steps - checkpoint.step.numpy(), 0) + prog_bar = ProgressBar(steps_per_epoch, + checkpoint.step.numpy() % steps_per_epoch) + + train_acc = AvgrageMeter() + for inputs, labels in train_dataset.take(remain_steps): + checkpoint.step.assign_add(1) + steps = checkpoint.step.numpy() + epochs = ((steps - 1) // steps_per_epoch) + 1 + + if epochs > cfg['start_search_epoch']: + inputs_val, labels_val = next(iter(val_dataset)) + arch_losses = train_step_arch(inputs_val, labels_val) + + logits, total_loss, losses = train_step(inputs, labels) + train_acc.update( + accuracy(logits.numpy(), labels.numpy())[0], cfg['batch_size']) + + prog_bar.update( + "epoch={:d}/{:d}, loss={:.4f}, acc={:.2f}, lr={:.2e}".format( + epochs, cfg['epoch'], total_loss.numpy(), train_acc.avg, + optimizer.lr(steps).numpy())) + + if steps % 10 == 0: + with summary_writer.as_default(): + tf.summary.scalar('acc/train', train_acc.avg, step=steps) + + tf.summary.scalar( + 'loss/total_loss', total_loss, step=steps) + for k, l in losses.items(): + tf.summary.scalar('loss/{}'.format(k), l, step=steps) + tf.summary.scalar( + 'learning_rate', optimizer.lr(steps), step=steps) + + if epochs > cfg['start_search_epoch']: + for k, l in arch_losses.items(): + tf.summary.scalar( + 'arch_losses/{}'.format(k), l, step=steps) + tf.summary.scalar('arch_learning_rate', + cfg['arch_learning_rate'], step=steps) + + if steps % cfg['save_steps'] == 0: + manager.save() + print("\n[*] save ckpt file at {}".format( + manager.latest_checkpoint)) + + if steps % steps_per_epoch == 0: + train_acc.reset() + if epochs > cfg['start_search_epoch']: + genotype = sna.get_genotype() + print(f"\nsearch arch: {genotype}") + f = open(os.path.join( + './logs', cfg['sub_name'], 'search_arch_genotype.py'), 'a') + f.write(f"\n{cfg['sub_name']}_{epochs} = {genotype}\n") + f.close() + + manager.save() + print("\n[*] training done! save ckpt file at {}".format( + manager.latest_checkpoint)) + + +if __name__ == '__main__': + app.run(main) diff --git a/visualize_genotype.py b/visualize_genotype.py new file mode 100644 index 0000000..591d531 --- /dev/null +++ b/visualize_genotype.py @@ -0,0 +1,55 @@ +import sys +from modules import genotypes +from graphviz import Digraph + + +def plot(genotype, filename): + g = Digraph(format='pdf', + edge_attr=dict(fontsize='20', fontname="times"), + node_attr=dict(style='filled', shape='rect', align='center', + fontsize='20', height='0.5', width='0.5', + penwidth='2', fontname="times"), + engine='dot') + g.body.extend(['rankdir=LR']) + + g.node("c_{k-2}", fillcolor='darkseagreen2') + g.node("c_{k-1}", fillcolor='darkseagreen2') + assert len(genotype) % 2 == 0 + steps = len(genotype) // 2 + + for i in range(steps): + g.node(str(i), fillcolor='lightblue') + + for i in range(steps): + for k in [2 * i, 2 * i + 1]: + op, j = genotype[k] + if j == 0: + u = "c_{k-2}" + elif j == 1: + u = "c_{k-1}" + else: + u = str(j - 2) + v = str(i) + g.edge(u, v, label=op, fillcolor="gray") + + g.node("c_{k}", fillcolor='palegoldenrod') + for i in range(steps): + g.edge(str(i), "c_{k}", fillcolor="gray") + + g.render(filename, view=True) + + +if __name__ == '__main__': + if len(sys.argv) != 2: + print("usage:\n python {} ARCH_NAME".format(sys.argv[0])) + sys.exit(1) + + genotype_name = sys.argv[1] + try: + genotype = eval('genotypes.{}'.format(genotype_name)) + except AttributeError: + print("{} is not specified in genotypes.py".format(genotype_name)) + sys.exit(1) + + plot(genotype.normal, "normal") + plot(genotype.reduce, "reduction")