diff --git a/README.md b/README.md index f1bacd054e..e03b339b80 100644 --- a/README.md +++ b/README.md @@ -135,6 +135,7 @@ Within the following table, we summarized the current NNI capabilities, we are g
  • ProxylessNAS
  • Network Morphism
  • TextNAS
  • +
  • Cream
  • Model Compression diff --git a/docs/en_US/NAS/CDARTS.md b/docs/en_US/NAS/CDARTS.md index 4152b8efa2..07f8faf22d 100644 --- a/docs/en_US/NAS/CDARTS.md +++ b/docs/en_US/NAS/CDARTS.md @@ -1,16 +1,17 @@ + # CDARTS ## Introduction -CDARTS builds a cyclic feedback mechanism between the search and evaluation networks. First, the search network generates an initial topology for evaluation, so that the weights of the evaluation network can be optimized. Second, the architecture topology in the search network is further optimized by the label supervision in classification, as well as the regularization from the evaluation network through feature distillation. Repeating the above cycle results in a joint optimization of the search and evaluation networks, and thus enables the evolution of the topology to fit the final evaluation network. +[CDARTS](https://arxiv.org/pdf/2006.10724.pdf) builds a cyclic feedback mechanism between the search and evaluation networks. First, the search network generates an initial topology for evaluation, so that the weights of the evaluation network can be optimized. Second, the architecture topology in the search network is further optimized by the label supervision in classification, as well as the regularization from the evaluation network through feature distillation. Repeating the above cycle results in a joint optimization of the search and evaluation networks, and thus enables the evolution of the topology to fit the final evaluation network. -In implementation of `CdartsTrainer`, it first instantiates two models and two mutators (one for each). The first model is the so-called "search network", which is mutated with a `RegularizedDartsMutator` -- a mutator with subtle differences with `DartsMutator`. The second model is the "evaluation network", which is mutated with a discrete mutator that leverages the previous search network mutator, to sample a single path each time. Trainers train models and mutators alternatively. Users can refer to [references](#reference) if they are interested in more details on these trainers and mutators. +In implementation of `CdartsTrainer`, it first instantiates two models and two mutators (one for each). The first model is the so-called "search network", which is mutated with a `RegularizedDartsMutator` -- a mutator with subtle differences with `DartsMutator`. The second model is the "evaluation network", which is mutated with a discrete mutator that leverages the previous search network mutator, to sample a single path each time. Trainers train models and mutators alternatively. Users can refer to [paper](https://arxiv.org/pdf/2006.10724.pdf) if they are interested in more details on these trainers and mutators. ## Reproduction Results This is CDARTS based on the NNI platform, which currently supports CIFAR10 search and retrain. ImageNet search and retrain should also be supported, and we provide corresponding interfaces. Our reproduced results on NNI are slightly lower than the paper, but much higher than the original DARTS. Here we show the results of three independent experiments on CIFAR10. -| Runs | Paper | NNI | +| Runs | Paper | NNI | | ---- |:-------------:| :-----:| | 1 | 97.52 | 97.44 | | 2 | 97.53 | 97.48 | @@ -19,7 +20,7 @@ This is CDARTS based on the NNI platform, which currently supports CIFAR10 searc ## Examples -[Example code](https://github.com/microsoft/nni/tree/v1.9/examples/nas/cdarts) +[Example code](https://github.com/microsoft/nni/tree/master/examples/nas/cdarts) ```bash # In case NNI code is not cloned. If the code is cloned already, ignore this line and enter code folder. @@ -55,3 +56,4 @@ bash run_retrain_cifar.sh .. autoclass:: nni.algorithms.nas.pytorch.cdarts.RegularizedMutatorParallel :members: ``` + diff --git a/docs/en_US/NAS/Cream.md b/docs/en_US/NAS/Cream.md new file mode 100644 index 0000000000..beb232c085 --- /dev/null +++ b/docs/en_US/NAS/Cream.md @@ -0,0 +1,127 @@ +# Cream of the Crop: Distilling Prioritized Paths For One-Shot Neural Architecture Search + +**[[Paper]](https://papers.nips.cc/paper/2020/file/d072677d210ac4c03ba046120f0802ec-Paper.pdf) [[Models-Google Drive]](https://drive.google.com/drive/folders/1NLGAbBF9bA1IUAxKlk2VjgRXhr6RHvRW?usp=sharing)[[Models-Baidu Disk (PWD: wqw6)]](https://pan.baidu.com/s/1TqQNm2s14oEdyNPimw3T9g) [[BibTex]](https://scholar.googleusercontent.com/scholar.bib?q=info:ICWVXc_SsKAJ:scholar.google.com/&output=citation&scisdr=CgUmooXfEMfTi0cV5aU:AAGBfm0AAAAAX7sQ_aXoamdKRaBI12tAVN8REq1VKNwM&scisig=AAGBfm0AAAAAX7sQ_RdYtp6BSro3zgbXVJU2MCgsG730&scisf=4&ct=citation&cd=-1&hl=ja)**
    + +In this work, we present a simple yet effective architecture distillation method. The central idea is that subnetworks can learn collaboratively and teach each other throughout the training process, aiming to boost the convergence of individual models. We introduce the concept of prioritized path, which refers to the architecture candidates exhibiting superior performance during training. Distilling knowledge from the prioritized paths is able to boost the training of subnetworks. Since the prioritized paths are changed on the fly depending on their performance and complexity, the final obtained paths are the cream of the crop. The discovered architectures achieve superior performance compared to the recent [MobileNetV3](https://arxiv.org/abs/1905.02244) and [EfficientNet](https://arxiv.org/abs/1905.11946) families under aligned settings. + +
    + +
    + + +## Reproduced Results +Top-1 Accuracy on ImageNet. The top-1 accuracy of Cream search algorithm surpasses MobileNetV3 and EfficientNet-B0/B1 on ImageNet. +The training with 16 Gpus is a little bit superior than 8 Gpus, as below. + +| Model (M Flops) | 8Gpus | 16Gpus | +| ---- |:-------------:| :-----:| +| 14M | 53.7 | 53.8 | +| 43M | 65.8 | 66.5 | +| 114M | 72.1 | 72.8 | +| 287M | 76.7 | 77.6 | +| 481M | 78.9 | 79.2 | +| 604M | 79.4 | 80.0 | + + + + +
    drawingdrawing
    + +## Examples + +[Example code](https://github.com/microsoft/nni/tree/master/examples/nas/cream) + +Please run the following scripts in the example folder. + +## Data Preparation + +You need to first download the [ImageNet-2012](http://www.image-net.org/) to the folder `./data/imagenet` and move the validation set to the subfolder `./data/imagenet/val`. To move the validation set, you cloud use the following script: + +Put the imagenet data in `./data`. It should be like following: + +``` +./data/imagenet/train +./data/imagenet/val +... +``` + +## Quick Start + +### I. Search + +First, build environments for searching. + +``` +pip install -r ./requirements + +git clone https://github.com/NVIDIA/apex.git +cd apex +python setup.py install --cpp_ext --cuda_ext +``` + +To search for an architecture, you need to configure the parameters `FLOPS_MINIMUM` and `FLOPS_MAXIMUM` to specify the desired model flops, such as [0,600]MB flops. You can specify the flops interval by changing these two parameters in `./configs/train.yaml` + +``` +FLOPS_MINIMUM: 0 # Minimum Flops of Architecture +FLOPS_MAXIMUM: 600 # Maximum Flops of Architecture +``` + +For example, if you expect to search an architecture with model flops <= 200M, please set the `FLOPS_MINIMUM` and `FLOPS_MAXIMUM` to be `0` and `200`. + +After you specify the flops of the architectures you would like to search, you can search an architecture now by running: + +``` +python -m torch.distributed.launch --nproc_per_node=8 ./train.py --cfg ./configs/train.yaml +``` + +The searched architectures need to be retrained and obtain the final model. The final model is saved in `.pth.tar` format. Retraining code will be released soon. + +### II. Retrain + +To train searched architectures, you need to configure the parameter `MODEL_SELECTION` to specify the model Flops. To specify which model to train, you should add `MODEL_SELECTION` in `./configs/retrain.yaml`. You can select one from [14,43,112,287,481,604], which stands for different Flops(MB). + +``` +MODEL_SELECTION: 43 # Retrain 43m model +MODEL_SELECTION: 481 # Retrain 481m model +...... +``` + +To train random architectures, you need specify `MODEL_SELECTION` to `-1` and configure the parameter `INPUT_ARCH`: + +``` +MODEL_SELECTION: -1 # Train random architectures +INPUT_ARCH: [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]] # Random Architectures +...... +``` + +After adding `MODEL_SELECTION` in `./configs/retrain.yaml`, you need to use the following command to train the model. + +``` +python -m torch.distributed.launch --nproc_per_node=8 ./retrain.py --cfg ./configs/retrain.yaml +``` + +### III. Test + +To test our trained of models, you need to use `MODEL_SELECTION` in `./configs/test.yaml` to specify which model to test. + +``` +MODEL_SELECTION: 43 # test 43m model +MODEL_SELECTION: 481 # test 470m model +...... +``` + +After specifying the flops of the model, you need to write the path to the resume model in `./test.sh`. + +``` +RESUME_PATH: './43.pth.tar' +RESUME_PATH: './481.pth.tar' +...... +``` + +We provide 14M/43M/114M/287M/481M/604M pretrained models in [google drive](https://drive.google.com/drive/folders/1CQjyBryZ4F20Rutj7coF8HWFcedApUn2) or [[Models-Baidu Disk (password: wqw6)]](https://pan.baidu.com/s/1TqQNm2s14oEdyNPimw3T9g) . + +After downloading the pretrained models and adding `MODEL_SELECTION` and `RESUME_PATH` in './configs/test.yaml', you need to use the following command to test the model. + +``` +python -m torch.distributed.launch --nproc_per_node=8 ./test.py --cfg ./configs/test.yaml +``` diff --git a/docs/en_US/NAS/one_shot_nas.rst b/docs/en_US/NAS/one_shot_nas.rst index cc7fa688b6..77b3cfcc94 100644 --- a/docs/en_US/NAS/one_shot_nas.rst +++ b/docs/en_US/NAS/one_shot_nas.rst @@ -14,4 +14,5 @@ One-shot NAS algorithms leverage weight sharing among models in neural architect SPOS CDARTS ProxylessNAS - TextNAS \ No newline at end of file + TextNAS + Cream diff --git a/docs/img/cream.png b/docs/img/cream.png new file mode 100644 index 0000000000..99a24840a7 Binary files /dev/null and b/docs/img/cream.png differ diff --git a/docs/img/cream_flops100.jpg b/docs/img/cream_flops100.jpg new file mode 100644 index 0000000000..a31078dd8f Binary files /dev/null and b/docs/img/cream_flops100.jpg differ diff --git a/docs/img/cream_flops600.jpg b/docs/img/cream_flops600.jpg new file mode 100644 index 0000000000..e9f7a5a6d0 Binary files /dev/null and b/docs/img/cream_flops600.jpg differ diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/nas/__init__.py b/examples/nas/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/nas/cream/Cream.md b/examples/nas/cream/Cream.md new file mode 100644 index 0000000000..a871bddf78 --- /dev/null +++ b/examples/nas/cream/Cream.md @@ -0,0 +1 @@ +[Documentation](https://nni.readthedocs.io/en/latest/NAS/Cream.html) diff --git a/examples/nas/cream/__init__.py b/examples/nas/cream/__init__.py new file mode 100755 index 0000000000..e69de29bb2 diff --git a/examples/nas/cream/configs/retrain.yaml b/examples/nas/cream/configs/retrain.yaml new file mode 100644 index 0000000000..2339dea982 --- /dev/null +++ b/examples/nas/cream/configs/retrain.yaml @@ -0,0 +1,52 @@ +AUTO_RESUME: False +DATA_DIR: './data/imagenet' +MODEL: '604m_retrain' +RESUME_PATH: './experiments/workspace/retrain/resume.pth.tar' +SAVE_PATH: './' +SEED: 42 +LOG_INTERVAL: 50 +RECOVERY_INTERVAL: 0 +WORKERS: 4 +NUM_GPU: 2 +SAVE_IMAGES: False +AMP: False +OUTPUT: 'None' +EVAL_METRICS: 'prec1' +TTA: 0 +LOCAL_RANK: 0 + +DATASET: + NUM_CLASSES: 1000 + IMAGE_SIZE: 224 # image patch size + INTERPOLATION: 'random' # Image resize interpolation type + BATCH_SIZE: 32 # batch size + NO_PREFECHTER: False + +NET: + GP: 'avg' + DROPOUT_RATE: 0.0 + SELECTION: 42 + + EMA: + USE: True + FORCE_CPU: False # force model ema to be tracked on CPU + DECAY: 0.9998 + +OPT: 'sgd' +OPT_EPS: 1e-2 +MOMENTUM: 0.9 +DECAY_RATE: 0.1 + +SCHED: 'sgd' +LR_NOISE: None +LR_NOISE_PCT: 0.67 +LR_NOISE_STD: 1.0 +WARMUP_LR: 1e-4 +MIN_LR: 1e-5 +EPOCHS: 200 +START_EPOCH: None +DECAY_EPOCHS: 30.0 +WARMUP_EPOCHS: 3 +COOLDOWN_EPOCHS: 10 +PATIENCE_EPOCHS: 10 +LR: 1e-2 \ No newline at end of file diff --git a/examples/nas/cream/configs/test.yaml b/examples/nas/cream/configs/test.yaml new file mode 100644 index 0000000000..4bf568517f --- /dev/null +++ b/examples/nas/cream/configs/test.yaml @@ -0,0 +1,37 @@ +AUTO_RESUME: True +DATA_DIR: './data/imagenet' +MODEL: 'Childnet_Testing' +RESUME_PATH: './experiments/workspace/ckps/42.pth.tar' +SAVE_PATH: './' +SEED: 42 +LOG_INTERVAL: 50 +RECOVERY_INTERVAL: 0 +WORKERS: 4 +NUM_GPU: 2 +SAVE_IMAGES: False +AMP: False +OUTPUT: 'None' +EVAL_METRICS: 'prec1' +TTA: 0 +LOCAL_RANK: 0 + +DATASET: + NUM_CLASSES: 1000 + IMAGE_SIZE: 224 # image patch size + INTERPOLATION: 'bilinear' # Image resize interpolation type + BATCH_SIZE: 32 # batch size + NO_PREFECHTER: False + +NET: + GP: 'avg' + DROPOUT_RATE: 0.0 + SELECTION: 42 + + EMA: + USE: True + FORCE_CPU: False # force model ema to be tracked on CPU + DECAY: 0.9998 + +OPTIMIZER: + MOMENTUM: 0.9 + WEIGHT_DECAY: 1e-3 \ No newline at end of file diff --git a/examples/nas/cream/configs/train.yaml b/examples/nas/cream/configs/train.yaml new file mode 100644 index 0000000000..85164e0eda --- /dev/null +++ b/examples/nas/cream/configs/train.yaml @@ -0,0 +1,53 @@ +AUTO_RESUME: False +DATA_DIR: './data/imagenet' +MODEL: 'Supernet_Training' +RESUME_PATH: './experiments/workspace/train/resume.pth.tar' +SAVE_PATH: './' +SEED: 42 +LOG_INTERVAL: 50 +RECOVERY_INTERVAL: 0 +WORKERS: 8 +NUM_GPU: 8 +SAVE_IMAGES: False +AMP: False +OUTPUT: 'None' +EVAL_METRICS: 'prec1' +TTA: 0 +LOCAL_RANK: 0 + +DATASET: + NUM_CLASSES: 1000 + IMAGE_SIZE: 224 # image patch size + INTERPOLATION: 'bilinear' # Image resize interpolation type + BATCH_SIZE: 128 # batch size + +NET: + GP: 'avg' + DROPOUT_RATE: 0.0 + + EMA: + USE: True + FORCE_CPU: False # force model ema to be tracked on CPU + DECAY: 0.9998 + +OPT: 'sgd' +LR: 1.0 +EPOCHS: 120 +META_LR: 1e-4 + +BATCHNORM: + SYNC_BN: False + +SUPERNET: + UPDATE_ITER: 200 + SLICE: 4 + POOL_SIZE: 10 + RESUNIT: False + DIL_CONV: False + UPDATE_2ND: True + FLOPS_MINIMUM: 0 + FLOPS_MAXIMUM: 600 + PICK_METHOD: 'meta' + META_STA_EPOCH: 20 + HOW_TO_PROB: 'pre_prob' + PRE_PROB: (0.05,0.2,0.05,0.5,0.05,0.15) \ No newline at end of file diff --git a/examples/nas/cream/lib/config.py b/examples/nas/cream/lib/config.py new file mode 100644 index 0000000000..fd50b4a9a5 --- /dev/null +++ b/examples/nas/cream/lib/config.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from yacs.config import CfgNode as CN + +DEFAULT_CROP_PCT = 0.875 +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + +__C = CN() + +cfg = __C + +__C.AUTO_RESUME = True +__C.DATA_DIR = './data/imagenet' +__C.MODEL = 'cream' +__C.RESUME_PATH = './experiments/ckps/resume.pth.tar' +__C.SAVE_PATH = './experiments/ckps/' +__C.SEED = 42 +__C.LOG_INTERVAL = 50 +__C.RECOVERY_INTERVAL = 0 +__C.WORKERS = 4 +__C.NUM_GPU = 1 +__C.SAVE_IMAGES = False +__C.AMP = False +__C.ACC_GAP = 5 +__C.OUTPUT = 'output/path/' +__C.EVAL_METRICS = 'prec1' +__C.TTA = 0 # Test or inference time augmentation +__C.LOCAL_RANK = 0 +__C.VERBOSE = False + +# dataset configs +__C.DATASET = CN() +__C.DATASET.NUM_CLASSES = 1000 +__C.DATASET.IMAGE_SIZE = 224 # image patch size +__C.DATASET.INTERPOLATION = 'bilinear' # Image resize interpolation type +__C.DATASET.BATCH_SIZE = 32 # batch size +__C.DATASET.NO_PREFECHTER = False +__C.DATASET.PIN_MEM = True +__C.DATASET.VAL_BATCH_MUL = 4 + + +# model configs +__C.NET = CN() +__C.NET.SELECTION = 14 +__C.NET.GP = 'avg' # type of global pool ["avg", "max", "avgmax", "avgmaxc"] +__C.NET.DROPOUT_RATE = 0.0 # dropout rate +__C.NET.INPUT_ARCH = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]] + +# model ema parameters +__C.NET.EMA = CN() +__C.NET.EMA.USE = True +__C.NET.EMA.FORCE_CPU = False # force model ema to be tracked on CPU +__C.NET.EMA.DECAY = 0.9998 + +# optimizer configs +__C.OPT = 'sgd' +__C.OPT_EPS = 1e-2 +__C.MOMENTUM = 0.9 +__C.WEIGHT_DECAY = 1e-4 +__C.OPTIMIZER = CN() +__C.OPTIMIZER.NAME = 'sgd' +__C.OPTIMIZER.MOMENTUM = 0.9 +__C.OPTIMIZER.WEIGHT_DECAY = 1e-3 + +# scheduler configs +__C.SCHED = 'sgd' +__C.LR_NOISE = None +__C.LR_NOISE_PCT = 0.67 +__C.LR_NOISE_STD = 1.0 +__C.WARMUP_LR = 1e-4 +__C.MIN_LR = 1e-5 +__C.EPOCHS = 200 +__C.START_EPOCH = None +__C.DECAY_EPOCHS = 30.0 +__C.WARMUP_EPOCHS = 3 +__C.COOLDOWN_EPOCHS = 10 +__C.PATIENCE_EPOCHS = 10 +__C.DECAY_RATE = 0.1 +__C.LR = 1e-2 +__C.META_LR = 1e-4 + +# data augmentation parameters +__C.AUGMENTATION = CN() +__C.AUGMENTATION.AA = 'rand-m9-mstd0.5' +__C.AUGMENTATION.COLOR_JITTER = 0.4 +__C.AUGMENTATION.RE_PROB = 0.2 # random erase prob +__C.AUGMENTATION.RE_MODE = 'pixel' # random erase mode +__C.AUGMENTATION.MIXUP = 0.0 # mixup alpha +__C.AUGMENTATION.MIXUP_OFF_EPOCH = 0 # turn off mixup after this epoch +__C.AUGMENTATION.SMOOTHING = 0.1 # label smoothing parameters + +# batch norm parameters (only works with gen_efficientnet based models +# currently) +__C.BATCHNORM = CN() +__C.BATCHNORM.SYNC_BN = True +__C.BATCHNORM.BN_TF = False +__C.BATCHNORM.BN_MOMENTUM = 0.1 # batchnorm momentum override +__C.BATCHNORM.BN_EPS = 1e-5 # batchnorm eps override + +# supernet training hyperparameters +__C.SUPERNET = CN() +__C.SUPERNET.UPDATE_ITER = 1300 +__C.SUPERNET.SLICE = 4 +__C.SUPERNET.POOL_SIZE = 10 +__C.SUPERNET.RESUNIT = False +__C.SUPERNET.DIL_CONV = False +__C.SUPERNET.UPDATE_2ND = True +__C.SUPERNET.FLOPS_MAXIMUM = 600 +__C.SUPERNET.FLOPS_MINIMUM = 0 +__C.SUPERNET.PICK_METHOD = 'meta' # pick teacher method +__C.SUPERNET.META_STA_EPOCH = 20 # start using meta picking method +__C.SUPERNET.HOW_TO_PROB = 'pre_prob' # sample method +__C.SUPERNET.PRE_PROB = (0.05, 0.2, 0.05, 0.5, 0.05, + 0.15) # sample prob in 'pre_prob' diff --git a/examples/nas/cream/lib/core/retrain.py b/examples/nas/cream/lib/core/retrain.py new file mode 100644 index 0000000000..7468db2bb5 --- /dev/null +++ b/examples/nas/cream/lib/core/retrain.py @@ -0,0 +1,135 @@ +import os +import time +import torch +import torchvision + +from collections import OrderedDict + +from lib.utils.util import AverageMeter, accuracy, reduce_tensor + +def train_epoch( + epoch, model, loader, optimizer, loss_fn, cfg, + lr_scheduler=None, saver=None, output_dir='', use_amp=False, + model_ema=None, logger=None, writer=None, local_rank=0): + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + losses_m = AverageMeter() + prec1_m = AverageMeter() + prec5_m = AverageMeter() + + model.train() + + end = time.time() + last_idx = len(loader) - 1 + num_updates = epoch * len(loader) + optimizer.zero_grad() + for batch_idx, (input, target) in enumerate(loader): + last_batch = batch_idx == last_idx + data_time_m.update(time.time() - end) + + input = input.cuda() + target = target.cuda() + output = model(input) + + loss = loss_fn(output, target) + + prec1, prec5 = accuracy(output, target, topk=(1, 5)) + + if cfg.NUM_GPU > 1: + reduced_loss = reduce_tensor(loss.data, cfg.NUM_GPU) + prec1 = reduce_tensor(prec1, cfg.NUM_GPU) + prec5 = reduce_tensor(prec5, cfg.NUM_GPU) + else: + reduced_loss = loss.data + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + torch.cuda.synchronize() + + losses_m.update(reduced_loss.item(), input.size(0)) + prec1_m.update(prec1.item(), output.size(0)) + prec5_m.update(prec5.item(), output.size(0)) + + if model_ema is not None: + model_ema.update(model) + num_updates += 1 + + batch_time_m.update(time.time() - end) + if last_batch or batch_idx % cfg.LOG_INTERVAL == 0: + lrl = [param_group['lr'] for param_group in optimizer.param_groups] + lr = sum(lrl) / len(lrl) + + if local_rank == 0: + logger.info( + 'Train: {} [{:>4d}/{}] ' + 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) ' + 'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' + 'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f}) ' + 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' + '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' + 'LR: {lr:.3e}' + 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( + epoch, + batch_idx, + len(loader), + loss=losses_m, + top1=prec1_m, + top5=prec5_m, + batch_time=batch_time_m, + rate=input.size(0) * + cfg.NUM_GPU / + batch_time_m.val, + rate_avg=input.size(0) * + cfg.NUM_GPU / + batch_time_m.avg, + lr=lr, + data_time=data_time_m)) + + writer.add_scalar( + 'Loss/train', + prec1_m.avg, + epoch * + len(loader) + + batch_idx) + writer.add_scalar( + 'Accuracy/train', + prec1_m.avg, + epoch * + len(loader) + + batch_idx) + writer.add_scalar( + 'Learning_Rate', + optimizer.param_groups[0]['lr'], + epoch * len(loader) + batch_idx) + + if cfg.SAVE_IMAGES and output_dir: + torchvision.utils.save_image( + input, os.path.join( + output_dir, 'train-batch-%d.jpg' % + batch_idx), padding=0, normalize=True) + + if saver is not None and cfg.RECOVERY_INTERVAL and ( + last_batch or (batch_idx + 1) % cfg.RECOVERY_INTERVAL == 0): + saver.save_recovery( + model, + optimizer, + cfg, + epoch, + model_ema=model_ema, + use_amp=use_amp, + batch_idx=batch_idx) + + if lr_scheduler is not None: + lr_scheduler.step_update( + num_updates=num_updates, + metric=losses_m.avg) + + end = time.time() + # end for + + if hasattr(optimizer, 'sync_lookahead'): + optimizer.sync_lookahead() + + return OrderedDict([('loss', losses_m.avg)]) diff --git a/examples/nas/cream/lib/core/test.py b/examples/nas/cream/lib/core/test.py new file mode 100644 index 0000000000..7ab69b57c0 --- /dev/null +++ b/examples/nas/cream/lib/core/test.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com + +import time +import torch + +from collections import OrderedDict +from lib.utils.util import AverageMeter, accuracy, reduce_tensor + + +def validate(epoch, model, loader, loss_fn, cfg, log_suffix='', logger=None, writer=None, local_rank=0): + batch_time_m = AverageMeter() + losses_m = AverageMeter() + prec1_m = AverageMeter() + prec5_m = AverageMeter() + + model.eval() + + end = time.time() + last_idx = len(loader) - 1 + with torch.no_grad(): + for batch_idx, (input, target) in enumerate(loader): + last_batch = batch_idx == last_idx + + output = model(input) + if isinstance(output, (tuple, list)): + output = output[0] + + # augmentation reduction + reduce_factor = cfg.TTA + if reduce_factor > 1: + output = output.unfold( + 0, + reduce_factor, + reduce_factor).mean( + dim=2) + target = target[0:target.size(0):reduce_factor] + + loss = loss_fn(output, target) + prec1, prec5 = accuracy(output, target, topk=(1, 5)) + + if cfg.NUM_GPU > 1: + reduced_loss = reduce_tensor(loss.data, cfg.NUM_GPU) + prec1 = reduce_tensor(prec1, cfg.NUM_GPU) + prec5 = reduce_tensor(prec5, cfg.NUM_GPU) + else: + reduced_loss = loss.data + + torch.cuda.synchronize() + + losses_m.update(reduced_loss.item(), input.size(0)) + prec1_m.update(prec1.item(), output.size(0)) + prec5_m.update(prec5.item(), output.size(0)) + + batch_time_m.update(time.time() - end) + end = time.time() + if local_rank == 0 and (last_batch or batch_idx % cfg.LOG_INTERVAL == 0): + log_name = 'Test' + log_suffix + logger.info( + '{0}: [{1:>4d}/{2}] ' + 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' + 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' + 'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' + 'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( + log_name, batch_idx, last_idx, + batch_time=batch_time_m, loss=losses_m, + top1=prec1_m, top5=prec5_m)) + + writer.add_scalar( + 'Loss' + log_suffix + '/vaild', + prec1_m.avg, + epoch * len(loader) + batch_idx) + writer.add_scalar( + 'Accuracy' + + log_suffix + + '/vaild', + prec1_m.avg, + epoch * + len(loader) + + batch_idx) + + metrics = OrderedDict( + [('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)]) + + return metrics diff --git a/examples/nas/cream/lib/models/blocks/__init__.py b/examples/nas/cream/lib/models/blocks/__init__.py new file mode 100644 index 0000000000..83a19f2b91 --- /dev/null +++ b/examples/nas/cream/lib/models/blocks/__init__.py @@ -0,0 +1,2 @@ +from lib.models.blocks.residual_block import get_Bottleneck, get_BasicBlock +from lib.models.blocks.inverted_residual_block import InvertedResidual \ No newline at end of file diff --git a/examples/nas/cream/lib/models/blocks/inverted_residual_block.py b/examples/nas/cream/lib/models/blocks/inverted_residual_block.py new file mode 100644 index 0000000000..2f501b561b --- /dev/null +++ b/examples/nas/cream/lib/models/blocks/inverted_residual_block.py @@ -0,0 +1,113 @@ +# This file is downloaded from +# https://github.com/rwightman/pytorch-image-models + +import torch.nn as nn + +from timm.models.layers import create_conv2d +from timm.models.efficientnet_blocks import make_divisible, resolve_se_args, \ + SqueezeExcite, drop_path + + +class InvertedResidual(nn.Module): + """ Inverted residual block w/ optional SE and CondConv routing""" + + def __init__( + self, + in_chs, + out_chs, + dw_kernel_size=3, + stride=1, + dilation=1, + pad_type='', + act_layer=nn.ReLU, + noskip=False, + exp_ratio=1.0, + exp_kernel_size=1, + pw_kernel_size=1, + se_ratio=0., + se_kwargs=None, + norm_layer=nn.BatchNorm2d, + norm_kwargs=None, + conv_kwargs=None, + drop_path_rate=0.): + super(InvertedResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + conv_kwargs = conv_kwargs or {} + mid_chs = make_divisible(in_chs * exp_ratio) + has_se = se_ratio is not None and se_ratio > 0. + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_path_rate = drop_path_rate + + # Point-wise expansion + self.conv_pw = create_conv2d( + in_chs, + mid_chs, + exp_kernel_size, + padding=pad_type, + **conv_kwargs) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Depth-wise convolution + self.conv_dw = create_conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, + padding=pad_type, depthwise=True, **conv_kwargs) + self.bn2 = norm_layer(mid_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) + + # Squeeze-and-excitation + if has_se: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = None + + # Point-wise linear projection + self.conv_pwl = create_conv2d( + mid_chs, + out_chs, + pw_kernel_size, + padding=pad_type, + **conv_kwargs) + self.bn3 = norm_layer(out_chs, **norm_kwargs) + + def feature_info(self, location): + if location == 'expansion': # after SE, input to PWL + info = dict( + module='conv_pwl', + hook_type='forward_pre', + num_chs=self.conv_pwl.in_channels) + else: # location == 'bottleneck', block output + info = dict( + module='', + hook_type='', + num_chs=self.conv_pwl.out_channels) + return info + + def forward(self, x): + residual = x + + # Point-wise expansion + x = self.conv_pw(x) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + if self.se is not None: + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn3(x) + + if self.has_residual: + if self.drop_path_rate > 0.: + x = drop_path(x, self.drop_path_rate, self.training) + x += residual + + return x diff --git a/examples/nas/cream/lib/models/blocks/residual_block.py b/examples/nas/cream/lib/models/blocks/residual_block.py new file mode 100644 index 0000000000..75892eee79 --- /dev/null +++ b/examples/nas/cream/lib/models/blocks/residual_block.py @@ -0,0 +1,105 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=True) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + + def __init__(self, inplanes, planes, stride=1, expansion=4): + super(Bottleneck, self).__init__() + planes = int(planes / expansion) + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=True) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d( + planes, + planes * expansion, + kernel_size=1, + bias=True) + self.bn3 = nn.BatchNorm2d(planes * expansion) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + self.expansion = expansion + if inplanes != planes * self.expansion: + self.downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * self.expansion, + kernel_size=1, stride=stride, bias=True), + nn.BatchNorm2d(planes * self.expansion), + ) + else: + self.downsample = None + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +def get_Bottleneck(in_c, out_c, stride): + return Bottleneck(in_c, out_c, stride=stride) + + +def get_BasicBlock(in_c, out_c, stride): + return BasicBlock(in_c, out_c, stride=stride) diff --git a/examples/nas/cream/lib/models/builders/build_childnet.py b/examples/nas/cream/lib/models/builders/build_childnet.py new file mode 100755 index 0000000000..8ddfb40024 --- /dev/null +++ b/examples/nas/cream/lib/models/builders/build_childnet.py @@ -0,0 +1,181 @@ +from lib.utils.util import * + +from timm.models.efficientnet_blocks import * + + +class ChildNetBuilder: + def __init__( + self, + channel_multiplier=1.0, + channel_divisor=8, + channel_min=None, + output_stride=32, + pad_type='', + act_layer=None, + se_kwargs=None, + norm_layer=nn.BatchNorm2d, + norm_kwargs=None, + drop_path_rate=0., + feature_location='', + verbose=False, + logger=None): + self.channel_multiplier = channel_multiplier + self.channel_divisor = channel_divisor + self.channel_min = channel_min + self.output_stride = output_stride + self.pad_type = pad_type + self.act_layer = act_layer + self.se_kwargs = se_kwargs + self.norm_layer = norm_layer + self.norm_kwargs = norm_kwargs + self.drop_path_rate = drop_path_rate + self.feature_location = feature_location + assert feature_location in ('pre_pwl', 'post_exp', '') + self.verbose = verbose + self.in_chs = None + self.features = OrderedDict() + self.logger = logger + + def _round_channels(self, chs): + return round_channels( + chs, + self.channel_multiplier, + self.channel_divisor, + self.channel_min) + + def _make_block(self, ba, block_idx, block_count): + drop_path_rate = self.drop_path_rate * block_idx / block_count + bt = ba.pop('block_type') + ba['in_chs'] = self.in_chs + ba['out_chs'] = self._round_channels(ba['out_chs']) + if 'fake_in_chs' in ba and ba['fake_in_chs']: + ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) + ba['norm_layer'] = self.norm_layer + ba['norm_kwargs'] = self.norm_kwargs + ba['pad_type'] = self.pad_type + # block act fn overrides the model default + ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer + assert ba['act_layer'] is not None + if bt == 'ir': + ba['drop_path_rate'] = drop_path_rate + ba['se_kwargs'] = self.se_kwargs + if self.verbose: + self.logger.info( + ' InvertedResidual {}, Args: {}'.format( + block_idx, str(ba))) + block = InvertedResidual(**ba) + elif bt == 'ds' or bt == 'dsa': + ba['drop_path_rate'] = drop_path_rate + ba['se_kwargs'] = self.se_kwargs + if self.verbose: + self.logger.info( + ' DepthwiseSeparable {}, Args: {}'.format( + block_idx, str(ba))) + block = DepthwiseSeparableConv(**ba) + elif bt == 'cn': + if self.verbose: + self.logger.info( + ' ConvBnAct {}, Args: {}'.format( + block_idx, str(ba))) + block = ConvBnAct(**ba) + else: + assert False, 'Uknkown block type (%s) while building model.' % bt + self.in_chs = ba['out_chs'] # update in_chs for arg of next block + + return block + + def __call__(self, in_chs, model_block_args): + """ Build the blocks + Args: + in_chs: Number of input-channels passed to first block + model_block_args: A list of lists, outer list defines stages, inner + list contains strings defining block configuration(s) + Return: + List of block stacks (each stack wrapped in nn.Sequential) + """ + if self.verbose: + self.logger.info( + 'Building model trunk with %d stages...' % + len(model_block_args)) + self.in_chs = in_chs + total_block_count = sum([len(x) for x in model_block_args]) + total_block_idx = 0 + current_stride = 2 + current_dilation = 1 + feature_idx = 0 + stages = [] + # outer list of block_args defines the stacks ('stages' by some + # conventions) + for stage_idx, stage_block_args in enumerate(model_block_args): + last_stack = stage_idx == (len(model_block_args) - 1) + if self.verbose: + self.logger.info('Stack: {}'.format(stage_idx)) + assert isinstance(stage_block_args, list) + + blocks = [] + # each stack (stage) contains a list of block arguments + for block_idx, block_args in enumerate(stage_block_args): + last_block = block_idx == (len(stage_block_args) - 1) + extract_features = '' # No features extracted + if self.verbose: + self.logger.info(' Block: {}'.format(block_idx)) + + # Sort out stride, dilation, and feature extraction details + assert block_args['stride'] in (1, 2) + if block_idx >= 1: + # only the first block in any stack can have a stride > 1 + block_args['stride'] = 1 + + do_extract = False + if self.feature_location == 'pre_pwl': + if last_block: + next_stage_idx = stage_idx + 1 + if next_stage_idx >= len(model_block_args): + do_extract = True + else: + do_extract = model_block_args[next_stage_idx][0]['stride'] > 1 + elif self.feature_location == 'post_exp': + if block_args['stride'] > 1 or (last_stack and last_block): + do_extract = True + if do_extract: + extract_features = self.feature_location + + next_dilation = current_dilation + if block_args['stride'] > 1: + next_output_stride = current_stride * block_args['stride'] + if next_output_stride > self.output_stride: + next_dilation = current_dilation * block_args['stride'] + block_args['stride'] = 1 + if self.verbose: + self.logger.info( + ' Converting stride to dilation to maintain output_stride=={}'.format( + self.output_stride)) + else: + current_stride = next_output_stride + block_args['dilation'] = current_dilation + if next_dilation != current_dilation: + current_dilation = next_dilation + + # create the block + block = self._make_block( + block_args, total_block_idx, total_block_count) + blocks.append(block) + + # stash feature module name and channel info for model feature + # extraction + if extract_features: + feature_module = block.feature_module(extract_features) + if feature_module: + feature_module = 'blocks.{}.{}.'.format( + stage_idx, block_idx) + feature_module + feature_channels = block.feature_channels(extract_features) + self.features[feature_idx] = dict( + name=feature_module, + num_chs=feature_channels + ) + feature_idx += 1 + + # incr global block idx (across all stacks) + total_block_idx += 1 + stages.append(nn.Sequential(*blocks)) + return stages diff --git a/examples/nas/cream/lib/models/builders/build_supernet.py b/examples/nas/cream/lib/models/builders/build_supernet.py new file mode 100644 index 0000000000..37d9c575c8 --- /dev/null +++ b/examples/nas/cream/lib/models/builders/build_supernet.py @@ -0,0 +1,214 @@ +from copy import deepcopy + +from lib.utils.builder_util import modify_block_args +from lib.models.blocks import get_Bottleneck, InvertedResidual + +from timm.models.efficientnet_blocks import * + +from nni.nas.pytorch import mutables + +class SuperNetBuilder: + """ Build Trunk Blocks + """ + + def __init__( + self, + choices, + channel_multiplier=1.0, + channel_divisor=8, + channel_min=None, + output_stride=32, + pad_type='', + act_layer=None, + se_kwargs=None, + norm_layer=nn.BatchNorm2d, + norm_kwargs=None, + drop_path_rate=0., + feature_location='', + verbose=False, + resunit=False, + dil_conv=False, + logger=None): + + # dict + # choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]} + self.choices = [[x, y] for x in choices['kernel_size'] + for y in choices['exp_ratio']] + self.choices_num = len(self.choices) - 1 + self.channel_multiplier = channel_multiplier + self.channel_divisor = channel_divisor + self.channel_min = channel_min + self.output_stride = output_stride + self.pad_type = pad_type + self.act_layer = act_layer + self.se_kwargs = se_kwargs + self.norm_layer = norm_layer + self.norm_kwargs = norm_kwargs + self.drop_path_rate = drop_path_rate + self.feature_location = feature_location + assert feature_location in ('pre_pwl', 'post_exp', '') + self.verbose = verbose + self.resunit = resunit + self.dil_conv = dil_conv + self.logger = logger + + # state updated during build, consumed by model + self.in_chs = None + + def _round_channels(self, chs): + return round_channels( + chs, + self.channel_multiplier, + self.channel_divisor, + self.channel_min) + + def _make_block( + self, + ba, + choice_idx, + block_idx, + block_count, + resunit=False, + dil_conv=False): + drop_path_rate = self.drop_path_rate * block_idx / block_count + bt = ba.pop('block_type') + ba['in_chs'] = self.in_chs + ba['out_chs'] = self._round_channels(ba['out_chs']) + if 'fake_in_chs' in ba and ba['fake_in_chs']: + # FIXME this is a hack to work around mismatch in origin impl input + # filters + ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) + ba['norm_layer'] = self.norm_layer + ba['norm_kwargs'] = self.norm_kwargs + ba['pad_type'] = self.pad_type + # block act fn overrides the model default + ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer + assert ba['act_layer'] is not None + if bt == 'ir': + ba['drop_path_rate'] = drop_path_rate + ba['se_kwargs'] = self.se_kwargs + if self.verbose: + self.logger.info( + ' InvertedResidual {}, Args: {}'.format( + block_idx, str(ba))) + block = InvertedResidual(**ba) + elif bt == 'ds' or bt == 'dsa': + ba['drop_path_rate'] = drop_path_rate + ba['se_kwargs'] = self.se_kwargs + if self.verbose: + self.logger.info( + ' DepthwiseSeparable {}, Args: {}'.format( + block_idx, str(ba))) + block = DepthwiseSeparableConv(**ba) + elif bt == 'cn': + if self.verbose: + self.logger.info( + ' ConvBnAct {}, Args: {}'.format( + block_idx, str(ba))) + block = ConvBnAct(**ba) + else: + assert False, 'Uknkown block type (%s) while building model.' % bt + if choice_idx == self.choice_num - 1: + self.in_chs = ba['out_chs'] # update in_chs for arg of next block + + return block + + def __call__(self, in_chs, model_block_args): + """ Build the blocks + Args: + in_chs: Number of input-channels passed to first block + model_block_args: A list of lists, outer list defines stages, inner + list contains strings defining block configuration(s) + Return: + List of block stacks (each stack wrapped in nn.Sequential) + """ + if self.verbose: + logging.info('Building model trunk with %d stages...' % len(model_block_args)) + self.in_chs = in_chs + total_block_count = sum([len(x) for x in model_block_args]) + total_block_idx = 0 + current_stride = 2 + current_dilation = 1 + feature_idx = 0 + stages = [] + # outer list of block_args defines the stacks ('stages' by some conventions) + for stage_idx, stage_block_args in enumerate(model_block_args): + last_stack = stage_idx == (len(model_block_args) - 1) + if self.verbose: + self.logger.info('Stack: {}'.format(stage_idx)) + assert isinstance(stage_block_args, list) + + # blocks = [] + # each stack (stage) contains a list of block arguments + for block_idx, block_args in enumerate(stage_block_args): + last_block = block_idx == (len(stage_block_args) - 1) + if self.verbose: + self.logger.info(' Block: {}'.format(block_idx)) + + # Sort out stride, dilation, and feature extraction details + assert block_args['stride'] in (1, 2) + if block_idx >= 1: + # only the first block in any stack can have a stride > 1 + block_args['stride'] = 1 + + next_dilation = current_dilation + if block_args['stride'] > 1: + next_output_stride = current_stride * block_args['stride'] + if next_output_stride > self.output_stride: + next_dilation = current_dilation * block_args['stride'] + block_args['stride'] = 1 + else: + current_stride = next_output_stride + block_args['dilation'] = current_dilation + if next_dilation != current_dilation: + current_dilation = next_dilation + + + if stage_idx==0 or stage_idx==6: + self.choice_num = 1 + else: + self.choice_num = len(self.choices) + + if self.dil_conv: + self.choice_num += 2 + + choice_blocks = [] + block_args_copy = deepcopy(block_args) + if self.choice_num == 1: + # create the block + block = self._make_block(block_args, 0, total_block_idx, total_block_count) + choice_blocks.append(block) + else: + for choice_idx, choice in enumerate(self.choices): + # create the block + block_args = deepcopy(block_args_copy) + block_args = modify_block_args(block_args, choice[0], choice[1]) + block = self._make_block(block_args, choice_idx, total_block_idx, total_block_count) + choice_blocks.append(block) + if self.dil_conv: + block_args = deepcopy(block_args_copy) + block_args = modify_block_args(block_args, 3, 0) + block = self._make_block(block_args, self.choice_num - 2, total_block_idx, total_block_count, + resunit=self.resunit, dil_conv=self.dil_conv) + choice_blocks.append(block) + + block_args = deepcopy(block_args_copy) + block_args = modify_block_args(block_args, 5, 0) + block = self._make_block(block_args, self.choice_num - 1, total_block_idx, total_block_count, + resunit=self.resunit, dil_conv=self.dil_conv) + choice_blocks.append(block) + + if self.resunit: + block = get_Bottleneck(block.conv_pw.in_channels, + block.conv_pwl.out_channels, + block.conv_dw.stride[0]) + choice_blocks.append(block) + + choice_block = mutables.LayerChoice(choice_blocks) + stages.append(choice_block) + # create the block + # block = self._make_block(block_args, total_block_idx, total_block_count) + total_block_idx += 1 # incr global block idx (across all stacks) + + # stages.append(blocks) + return stages diff --git a/examples/nas/cream/lib/models/structures/childnet.py b/examples/nas/cream/lib/models/structures/childnet.py new file mode 100755 index 0000000000..668b92e157 --- /dev/null +++ b/examples/nas/cream/lib/models/structures/childnet.py @@ -0,0 +1,145 @@ +from lib.utils.builder_util import * +from lib.models.builders.build_childnet import * + +from timm.models.layers import SelectAdaptivePool2d +from timm.models.layers.activations import hard_sigmoid + + +class ChildNet(nn.Module): + + def __init__( + self, + block_args, + num_classes=1000, + in_chans=3, + stem_size=16, + num_features=1280, + head_bias=True, + channel_multiplier=1.0, + pad_type='', + act_layer=nn.ReLU, + drop_rate=0., + drop_path_rate=0., + se_kwargs=None, + norm_layer=nn.BatchNorm2d, + norm_kwargs=None, + global_pool='avg', + logger=None, + verbose=False): + super(ChildNet, self).__init__() + + self.num_classes = num_classes + self.num_features = num_features + self.drop_rate = drop_rate + self._in_chs = in_chans + self.logger = logger + + # Stem + stem_size = round_channels(stem_size, channel_multiplier) + self.conv_stem = create_conv2d( + self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + self._in_chs = stem_size + + # Middle stages (IR/ER/DS Blocks) + builder = ChildNetBuilder( + channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs, + norm_layer, norm_kwargs, drop_path_rate, verbose=verbose) + self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) + # self.blocks = builder(self._in_chs, block_args) + self._in_chs = builder.in_chs + + # Head + Pooling + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.conv_head = create_conv2d( + self._in_chs, + self.num_features, + 1, + padding=pad_type, + bias=head_bias) + self.act2 = act_layer(inplace=True) + + # Classifier + self.classifier = nn.Linear( + self.num_features * + self.global_pool.feat_mult(), + self.num_classes) + + efficientnet_init_weights(self) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.num_classes = num_classes + self.classifier = nn.Linear( + self.num_features * self.global_pool.feat_mult(), + num_classes) if self.num_classes else None + + def forward_features(self, x): + # architecture = [[0], [], [], [], [], [0]] + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classifier(x) + return x + + +def gen_childnet(arch_list, arch_def, **kwargs): + # arch_list = [[0], [], [], [], [], [0]] + choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]} + choices_list = [[x, y] for x in choices['kernel_size'] + for y in choices['exp_ratio']] + + num_features = 1280 + + # act_layer = HardSwish + act_layer = Swish + + new_arch = [] + # change to child arch_def + for i, (layer_choice, layer_arch) in enumerate(zip(arch_list, arch_def)): + if len(layer_arch) == 1: + new_arch.append(layer_arch) + continue + else: + new_layer = [] + for j, (block_choice, block_arch) in enumerate( + zip(layer_choice, layer_arch)): + kernel_size, exp_ratio = choices_list[block_choice] + elements = block_arch.split('_') + block_arch = block_arch.replace( + elements[2], 'k{}'.format(str(kernel_size))) + block_arch = block_arch.replace( + elements[4], 'e{}'.format(str(exp_ratio))) + new_layer.append(block_arch) + new_arch.append(new_layer) + + model_kwargs = dict( + block_args=decode_arch_def(new_arch), + num_features=num_features, + stem_size=16, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=act_layer, + se_kwargs=dict( + act_layer=nn.ReLU, + gate_fn=hard_sigmoid, + reduce_mid=True, + divisor=8), + **kwargs, + ) + model = ChildNet(**model_kwargs) + return model diff --git a/examples/nas/cream/lib/models/structures/supernet.py b/examples/nas/cream/lib/models/structures/supernet.py new file mode 100644 index 0000000000..ea09377eb5 --- /dev/null +++ b/examples/nas/cream/lib/models/structures/supernet.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com + +from lib.utils.builder_util import * +from lib.utils.search_structure_supernet import * +from lib.models.builders.build_supernet import * +from lib.utils.op_by_layer_dict import flops_op_dict + +from timm.models.layers import SelectAdaptivePool2d +from timm.models.layers.activations import hard_sigmoid + + +class SuperNet(nn.Module): + + def __init__( + self, + block_args, + choices, + num_classes=1000, + in_chans=3, + stem_size=16, + num_features=1280, + head_bias=True, + channel_multiplier=1.0, + pad_type='', + act_layer=nn.ReLU, + drop_rate=0., + drop_path_rate=0., + slice=4, + se_kwargs=None, + norm_layer=nn.BatchNorm2d, + logger=None, + norm_kwargs=None, + global_pool='avg', + resunit=False, + dil_conv=False, + verbose=False): + super(SuperNet, self).__init__() + + self.num_classes = num_classes + self.num_features = num_features + self.drop_rate = drop_rate + self._in_chs = in_chans + self.logger = logger + + # Stem + stem_size = round_channels(stem_size, channel_multiplier) + self.conv_stem = create_conv2d( + self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + self._in_chs = stem_size + + # Middle stages (IR/ER/DS Blocks) + builder = SuperNetBuilder( + choices, + channel_multiplier, + 8, + None, + 32, + pad_type, + act_layer, + se_kwargs, + norm_layer, + norm_kwargs, + drop_path_rate, + verbose=verbose, + resunit=resunit, + dil_conv=dil_conv, + logger=self.logger) + blocks = builder(self._in_chs, block_args) + self.blocks = nn.Sequential(*blocks) + self._in_chs = builder.in_chs + + # Head + Pooling + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.conv_head = create_conv2d( + self._in_chs, + self.num_features, + 1, + padding=pad_type, + bias=head_bias) + self.act2 = act_layer(inplace=True) + + # Classifier + self.classifier = nn.Linear( + self.num_features * + self.global_pool.feat_mult(), + self.num_classes) + + self.meta_layer = nn.Linear(self.num_classes * slice, 1) + efficientnet_init_weights(self) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.num_classes = num_classes + self.classifier = nn.Linear( + self.num_features * self.global_pool.feat_mult(), + num_classes) if self.num_classes else None + + def forward_features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + def forward_meta(self, features): + return self.meta_layer(features.view(1, -1)) + + def rand_parameters(self, architecture, meta=False): + for name, param in self.named_parameters(recurse=True): + if 'meta' in name and meta: + yield param + elif 'blocks' not in name and 'meta' not in name and (not meta): + yield param + + if not meta: + for layer, layer_arch in zip(self.blocks, architecture): + for blocks, arch in zip(layer, layer_arch): + if arch == -1: + continue + for name, param in blocks[arch].named_parameters( + recurse=True): + yield param + + +class Classifier(nn.Module): + def __init__(self, num_classes=1000): + super(Classifier, self).__init__() + self.classifier = nn.Linear(num_classes, num_classes) + + def forward(self, x): + return self.classifier(x) + + +def gen_supernet(flops_minimum=0, flops_maximum=600, **kwargs): + choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]} + + num_features = 1280 + + # act_layer = HardSwish + act_layer = Swish + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_se0.25'], + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25', + 'ir_r1_k3_s1_e4_c24_se0.25'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s1_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25', + 'ir_r1_k5_s2_e4_c40_se0.25'], + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25', + 'ir_r2_k3_s1_e4_c80_se0.25'], + # stage 4, 14x14in + ['ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', + 'ir_r1_k3_s1_e6_c96_se0.25'], + # stage 5, 14x14in + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25', + 'ir_r1_k5_s2_e6_c192_se0.25'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c320_se0.25'], + ] + + sta_num, arch_def, resolution = search_for_layer( + flops_op_dict, arch_def, flops_minimum, flops_maximum) + + if sta_num is None or arch_def is None or resolution is None: + raise ValueError('Invalid FLOPs Settings') + + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + choices=choices, + num_features=num_features, + stem_size=16, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=act_layer, + se_kwargs=dict( + act_layer=nn.ReLU, + gate_fn=hard_sigmoid, + reduce_mid=True, + divisor=8), + **kwargs, + ) + model = SuperNet(**model_kwargs) + return model, sta_num, resolution diff --git a/examples/nas/cream/lib/utils/builder_util.py b/examples/nas/cream/lib/utils/builder_util.py new file mode 100644 index 0000000000..138e08299c --- /dev/null +++ b/examples/nas/cream/lib/utils/builder_util.py @@ -0,0 +1,273 @@ +import math +import torch.nn as nn + +from timm.utils import * +from timm.models.layers.activations import Swish +from timm.models.layers import CondConv2d, get_condconv_initializer + + +def parse_ksize(ss): + if ss.isdigit(): + return int(ss) + else: + return [int(k) for k in ss.split('.')] + + +def decode_arch_def( + arch_def, + depth_multiplier=1.0, + depth_trunc='ceil', + experts_multiplier=1): + arch_args = [] + for stack_idx, block_strings in enumerate(arch_def): + assert isinstance(block_strings, list) + stack_args = [] + repeats = [] + for block_str in block_strings: + assert isinstance(block_str, str) + ba, rep = decode_block_str(block_str) + if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: + ba['num_experts'] *= experts_multiplier + stack_args.append(ba) + repeats.append(rep) + arch_args.append( + scale_stage_depth( + stack_args, + repeats, + depth_multiplier, + depth_trunc)) + return arch_args + + +def modify_block_args(block_args, kernel_size, exp_ratio): + block_type = block_args['block_type'] + if block_type == 'cn': + block_args['kernel_size'] = kernel_size + elif block_type == 'er': + block_args['exp_kernel_size'] = kernel_size + else: + block_args['dw_kernel_size'] = kernel_size + + if block_type == 'ir' or block_type == 'er': + block_args['exp_ratio'] = exp_ratio + return block_args + + +def decode_block_str(block_str): + """ Decode block definition string + Gets a list of block arg (dicts) through a string notation of arguments. + E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip + All args can exist in any order with the exception of the leading string which + is assumed to indicate the block type. + leading string - block type ( + ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct) + r - number of repeat blocks, + k - kernel size, + s - strides (1-9), + e - expansion ratio, + c - output channels, + se - squeeze/excitation ratio + n - activation fn ('re', 'r6', 'hs', or 'sw') + Args: + block_str: a string representation of block arguments. + Returns: + A list of block args (dicts) + Raises: + ValueError: if the string def not properly specified (TODO) + """ + assert isinstance(block_str, str) + ops = block_str.split('_') + block_type = ops[0] # take the block type off the front + ops = ops[1:] + options = {} + noskip = False + for op in ops: + # string options being checked on individual basis, combine if they + # grow + if op == 'noskip': + noskip = True + elif op.startswith('n'): + # activation fn + key = op[0] + v = op[1:] + if v == 're': + value = nn.ReLU + elif v == 'r6': + value = nn.ReLU6 + elif v == 'sw': + value = Swish + else: + continue + options[key] = value + else: + # all numeric options + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # if act_layer is None, the model default (passed to model init) will be + # used + act_layer = options['n'] if 'n' in options else None + exp_kernel_size = parse_ksize(options['a']) if 'a' in options else 1 + pw_kernel_size = parse_ksize(options['p']) if 'p' in options else 1 + # FIXME hack to deal with in_chs issue in TPU def + fake_in_chs = int(options['fc']) if 'fc' in options else 0 + + num_repeat = int(options['r']) + # each type of block has different valid arguments, fill accordingly + if block_type == 'ir': + block_args = dict( + block_type=block_type, + dw_kernel_size=parse_ksize(options['k']), + exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + if 'cc' in options: + block_args['num_experts'] = int(options['cc']) + elif block_type == 'ds' or block_type == 'dsa': + block_args = dict( + block_type=block_type, + dw_kernel_size=parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + pw_act=block_type == 'dsa', + noskip=block_type == 'dsa' or noskip, + ) + elif block_type == 'cn': + block_args = dict( + block_type=block_type, + kernel_size=int(options['k']), + out_chs=int(options['c']), + stride=int(options['s']), + act_layer=act_layer, + ) + else: + assert False, 'Unknown block type (%s)' % block_type + + return block_args, num_repeat + + +def scale_stage_depth( + stack_args, + repeats, + depth_multiplier=1.0, + depth_trunc='ceil'): + """ Per-stage depth scaling + Scales the block repeats in each stage. This depth scaling impl maintains + compatibility with the EfficientNet scaling method, while allowing sensible + scaling for other models that may have multiple block arg definitions in each stage. + """ + + # We scale the total repeat count for each stage, there may be multiple + # block arg defs per stage so we need to sum. + num_repeat = sum(repeats) + if depth_trunc == 'round': + # Truncating to int by rounding allows stages with few repeats to remain + # proportionally smaller for longer. This is a good choice when stage definitions + # include single repeat stages that we'd prefer to keep that way as + # long as possible + num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) + else: + # The default for EfficientNet truncates repeats to int via 'ceil'. + # Any multiplier > 1.0 will result in an increased depth for every + # stage. + num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) + + # Proportionally distribute repeat count scaling to each block definition in the stage. + # Allocation is done in reverse as it results in the first block being less likely to be scaled. + # The first block makes less sense to repeat in most of the arch + # definitions. + repeats_scaled = [] + for r in repeats[::-1]: + rs = max(1, round((r / num_repeat * num_repeat_scaled))) + repeats_scaled.append(rs) + num_repeat -= r + num_repeat_scaled -= rs + repeats_scaled = repeats_scaled[::-1] + + # Apply the calculated scaling to each block arg in the stage + sa_scaled = [] + for ba, rep in zip(stack_args, repeats_scaled): + sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) + return sa_scaled + + +def init_weight_goog(m, n='', fix_group_fanout=True, last_bn=None): + """ Weight initialization as per Tensorflow official implementations. + Args: + m (nn.Module): module to init + n (str): module name + fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs + Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc: + * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py + * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + """ + if isinstance(m, CondConv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + init_weight_fn = get_condconv_initializer(lambda w: w.data.normal_( + 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) + init_weight_fn(m.weight) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + if n in last_bn: + m.weight.data.zero_() + m.bias.data.zero_() + else: + m.weight.data.fill_(1.0) + m.bias.data.zero_() + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + fan_out = m.weight.size(0) # fan-out + fan_in = 0 + if 'routing_fn' in n: + fan_in = m.weight.size(1) + init_range = 1.0 / math.sqrt(fan_in + fan_out) + m.weight.data.uniform_(-init_range, init_range) + m.bias.data.zero_() + + +def efficientnet_init_weights( + model: nn.Module, + init_fn=None, + zero_gamma=False): + last_bn = [] + if zero_gamma: + prev_n = '' + for n, m in model.named_modules(): + if isinstance(m, nn.BatchNorm2d): + if ''.join( + prev_n.split('.')[ + :- + 1]) != ''.join( + n.split('.')[ + :- + 1]): + last_bn.append(prev_n) + prev_n = n + last_bn.append(prev_n) + + init_fn = init_fn or init_weight_goog + for n, m in model.named_modules(): + init_fn(m, n, last_bn=last_bn) + init_fn(m, n, last_bn=last_bn) diff --git a/examples/nas/cream/lib/utils/flops_table.py b/examples/nas/cream/lib/utils/flops_table.py new file mode 100644 index 0000000000..254241a075 --- /dev/null +++ b/examples/nas/cream/lib/utils/flops_table.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com + +import torch + +from ptflops import get_model_complexity_info + + +class FlopsEst(object): + def __init__(self, model, input_shape=(2, 3, 224, 224), device='cpu'): + self.block_num = len(model.blocks) + self.choice_num = len(model.blocks[0]) + self.flops_dict = {} + self.params_dict = {} + + if device == 'cpu': + model = model.cpu() + else: + model = model.cuda() + + self.params_fixed = 0 + self.flops_fixed = 0 + + input = torch.randn(input_shape) + + flops, params = get_model_complexity_info( + model.conv_stem, (3, 224, 224), as_strings=False, print_per_layer_stat=False) + self.params_fixed += params / 1e6 + self.flops_fixed += flops / 1e6 + + input = model.conv_stem(input) + + for block_id, block in enumerate(model.blocks): + self.flops_dict[block_id] = {} + self.params_dict[block_id] = {} + for module_id, module in enumerate(block): + flops, params = get_model_complexity_info(module, tuple( + input.shape[1:]), as_strings=False, print_per_layer_stat=False) + # Flops(M) + self.flops_dict[block_id][module_id] = flops / 1e6 + # Params(M) + self.params_dict[block_id][module_id] = params / 1e6 + + input = module(input) + + # conv_last + flops, params = get_model_complexity_info(model.global_pool, tuple( + input.shape[1:]), as_strings=False, print_per_layer_stat=False) + self.params_fixed += params / 1e6 + self.flops_fixed += flops / 1e6 + + input = model.global_pool(input) + + # globalpool + flops, params = get_model_complexity_info(model.conv_head, tuple( + input.shape[1:]), as_strings=False, print_per_layer_stat=False) + self.params_fixed += params / 1e6 + self.flops_fixed += flops / 1e6 + + # return params (M) + def get_params(self, arch): + params = 0 + for block_id, block in enumerate(arch): + if block == -1: + continue + params += self.params_dict[block_id][block] + return params + self.params_fixed + + # return flops (M) + def get_flops(self, arch): + flops = 0 + for block_id, block in enumerate(arch): + if block == 'LayerChoice1' or block_id == 'LayerChoice23': + continue + for idx, choice in enumerate(arch[block]): + flops += self.flops_dict[block_id][idx] * (1 if choice else 0) + return flops + self.flops_fixed diff --git a/examples/nas/cream/lib/utils/op_by_layer_dict.py b/examples/nas/cream/lib/utils/op_by_layer_dict.py new file mode 100644 index 0000000000..47ca509ce4 --- /dev/null +++ b/examples/nas/cream/lib/utils/op_by_layer_dict.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com + +# This dictionary is generated from calculating each operation of each layer to quickly search for layers. +# flops_op_dict[which_stage][which_operation] = +# (flops_of_operation_with_stride1, flops_of_operation_with_stride2) + +flops_op_dict = {} +for i in range(5): + flops_op_dict[i] = {} +flops_op_dict[0][0] = (21.828704, 18.820752) +flops_op_dict[0][1] = (32.669328, 28.16048) +flops_op_dict[0][2] = (25.039968, 23.637648) +flops_op_dict[0][3] = (37.486224, 35.385824) +flops_op_dict[0][4] = (29.856864, 30.862992) +flops_op_dict[0][5] = (44.711568, 46.22384) +flops_op_dict[1][0] = (11.808656, 11.86712) +flops_op_dict[1][1] = (17.68624, 17.780848) +flops_op_dict[1][2] = (13.01288, 13.87416) +flops_op_dict[1][3] = (19.492576, 20.791408) +flops_op_dict[1][4] = (14.819216, 16.88472) +flops_op_dict[1][5] = (22.20208, 25.307248) +flops_op_dict[2][0] = (8.198, 10.99632) +flops_op_dict[2][1] = (12.292848, 16.5172) +flops_op_dict[2][2] = (8.69976, 11.99984) +flops_op_dict[2][3] = (13.045488, 18.02248) +flops_op_dict[2][4] = (9.4524, 13.50512) +flops_op_dict[2][5] = (14.174448, 20.2804) +flops_op_dict[3][0] = (12.006112, 15.61632) +flops_op_dict[3][1] = (18.028752, 23.46096) +flops_op_dict[3][2] = (13.009632, 16.820544) +flops_op_dict[3][3] = (19.534032, 25.267296) +flops_op_dict[3][4] = (14.514912, 18.62688) +flops_op_dict[3][5] = (21.791952, 27.9768) +flops_op_dict[4][0] = (11.307456, 15.292416) +flops_op_dict[4][1] = (17.007072, 23.1504) +flops_op_dict[4][2] = (11.608512, 15.894528) +flops_op_dict[4][3] = (17.458656, 24.053568) +flops_op_dict[4][4] = (12.060096, 16.797696) +flops_op_dict[4][5] = (18.136032, 25.40832) \ No newline at end of file diff --git a/examples/nas/cream/lib/utils/search_structure_supernet.py b/examples/nas/cream/lib/utils/search_structure_supernet.py new file mode 100644 index 0000000000..b13491c2c7 --- /dev/null +++ b/examples/nas/cream/lib/utils/search_structure_supernet.py @@ -0,0 +1,47 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com + +def search_for_layer(flops_op_dict, arch_def, flops_minimum, flops_maximum): + sta_num = [1, 1, 1, 1, 1] + order = [2, 3, 4, 1, 0, 2, 3, 4, 1, 0] + limits = [3, 3, 3, 2, 2, 4, 4, 4, 4, 4] + size_factor = 224 // 32 + base_min_flops = sum([flops_op_dict[i][0][0] for i in range(5)]) + base_max_flops = sum([flops_op_dict[i][5][0] for i in range(5)]) + + if base_min_flops > flops_maximum: + while base_min_flops > flops_maximum and size_factor >= 2: + size_factor = size_factor - 1 + flops_minimum = flops_minimum * (7. / size_factor) + flops_maximum = flops_maximum * (7. / size_factor) + if size_factor < 2: + return None, None, None + elif base_max_flops < flops_minimum: + cur_ptr = 0 + while base_max_flops < flops_minimum and cur_ptr <= 9: + if sta_num[order[cur_ptr]] >= limits[cur_ptr]: + cur_ptr += 1 + continue + base_max_flops = base_max_flops + \ + flops_op_dict[order[cur_ptr]][5][1] + sta_num[order[cur_ptr]] += 1 + if cur_ptr > 7 and base_max_flops < flops_minimum: + return None, None, None + + cur_ptr = 0 + while cur_ptr <= 9: + if sta_num[order[cur_ptr]] >= limits[cur_ptr]: + cur_ptr += 1 + continue + base_max_flops = base_max_flops + flops_op_dict[order[cur_ptr]][5][1] + if base_max_flops <= flops_maximum: + sta_num[order[cur_ptr]] += 1 + else: + break + + arch_def = [item[:i] for i, item in zip([1] + sta_num + [1], arch_def)] + # print(arch_def) + + return sta_num, arch_def, size_factor * 32 diff --git a/examples/nas/cream/lib/utils/util.py b/examples/nas/cream/lib/utils/util.py new file mode 100644 index 0000000000..9324a003cc --- /dev/null +++ b/examples/nas/cream/lib/utils/util.py @@ -0,0 +1,178 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com + +import sys +import argparse +import torch.nn as nn + +from torch import optim as optim +from thop import profile, clever_format + +from timm.utils import * + +from lib.config import cfg + + +def get_path_acc(model, path, val_loader, args, val_iters=50): + prec1_m = AverageMeter() + prec5_m = AverageMeter() + with torch.no_grad(): + for batch_idx, (input, target) in enumerate(val_loader): + if batch_idx >= val_iters: + break + if not args.prefetcher: + input = input.cuda() + target = target.cuda() + + output = model(input, path) + if isinstance(output, (tuple, list)): + output = output[0] + + # augmentation reduction + reduce_factor = args.tta + if reduce_factor > 1: + output = output.unfold( + 0, + reduce_factor, + reduce_factor).mean( + dim=2) + target = target[0:target.size(0):reduce_factor] + + prec1, prec5 = accuracy(output, target, topk=(1, 5)) + + torch.cuda.synchronize() + + prec1_m.update(prec1.item(), output.size(0)) + prec5_m.update(prec5.item(), output.size(0)) + + return (prec1_m.avg, prec5_m.avg) + + +def get_logger(file_path): + """ Make python logger """ + log_format = '%(asctime)s | %(message)s' + logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=log_format, datefmt='%m/%d %I:%M:%S %p') + logger = logging.getLogger('') + + formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p') + file_handler = logging.FileHandler(file_path) + file_handler.setFormatter(formatter) + + logger.addHandler(file_handler) + + return logger + + +def add_weight_decay_supernet(model, args, weight_decay=1e-5, skip_list=()): + decay = [] + no_decay = [] + meta_layer_no_decay = [] + meta_layer_decay = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith( + ".bias") or name in skip_list: + if 'meta_layer' in name: + meta_layer_no_decay.append(param) + else: + no_decay.append(param) + else: + if 'meta_layer' in name: + meta_layer_decay.append(param) + else: + decay.append(param) + return [ + {'params': no_decay, 'weight_decay': 0., 'lr': args.lr}, + {'params': decay, 'weight_decay': weight_decay, 'lr': args.lr}, + {'params': meta_layer_no_decay, 'weight_decay': 0., 'lr': args.meta_lr}, + {'params': meta_layer_decay, 'weight_decay': 0, 'lr': args.meta_lr}, + ] + + +def create_optimizer_supernet(args, model, has_apex, filter_bias_and_bn=True): + opt_lower = args.opt.lower() + weight_decay = args.weight_decay + if 'adamw' in opt_lower or 'radam' in opt_lower: + weight_decay /= args.lr + if weight_decay and filter_bias_and_bn: + parameters = add_weight_decay_supernet(model, args, weight_decay) + weight_decay = 0. + else: + parameters = model.parameters() + + if 'fused' in opt_lower: + assert has_apex and torch.cuda.is_available( + ), 'APEX and CUDA required for fused optimizers' + + opt_split = opt_lower.split('_') + opt_lower = opt_split[-1] + if opt_lower == 'sgd' or opt_lower == 'nesterov': + optimizer = optim.SGD( + parameters, + momentum=args.momentum, + weight_decay=weight_decay, + nesterov=True) + elif opt_lower == 'momentum': + optimizer = optim.SGD( + parameters, + momentum=args.momentum, + weight_decay=weight_decay, + nesterov=False) + elif opt_lower == 'adam': + optimizer = optim.Adam( + parameters, weight_decay=weight_decay, eps=args.opt_eps) + else: + assert False and "Invalid optimizer" + raise ValueError + + return optimizer + + +def convert_lowercase(cfg): + keys = cfg.keys() + lowercase_keys = [key.lower() for key in keys] + values = [cfg.get(key) for key in keys] + for lowercase_key, value in zip(lowercase_keys, values): + cfg.setdefault(lowercase_key, value) + return cfg + + +def parse_config_args(exp_name): + parser = argparse.ArgumentParser(description=exp_name) + parser.add_argument( + '--cfg', + type=str, + default='../experiments/workspace/retrain/retrain.yaml', + help='configuration of cream') + parser.add_argument('--local_rank', type=int, default=0, + help='local_rank') + args = parser.parse_args() + + cfg.merge_from_file(args.cfg) + converted_cfg = convert_lowercase(cfg) + + return args, converted_cfg + + +def get_model_flops_params(model, input_size=(1, 3, 224, 224)): + input = torch.randn(input_size) + macs, params = profile(deepcopy(model), inputs=(input,), verbose=False) + macs, params = clever_format([macs, params], "%.3f") + return macs, params + + +def cross_entropy_loss_with_soft_target(pred, soft_target): + logsoftmax = nn.LogSoftmax() + return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1)) + + +def create_supernet_scheduler(cfg, optimizer): + ITERS = cfg.EPOCHS * \ + (1280000 / (cfg.NUM_GPU * cfg.DATASET.BATCH_SIZE)) + lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: ( + cfg.LR - step / ITERS) if step <= ITERS else 0, last_epoch=-1) + return lr_scheduler, cfg.EPOCHS diff --git a/examples/nas/cream/requirements b/examples/nas/cream/requirements new file mode 100644 index 0000000000..5ddae72e4c --- /dev/null +++ b/examples/nas/cream/requirements @@ -0,0 +1,12 @@ +yacs +numpy==1.17 +opencv-python==4.0.1.24 +torchvision==0.2.1 +thop +git+https://github.com/sovrasov/flops-counter.pytorch.git +pillow==6.1.0 +torch==1.2 +timm==0.1.20 +tensorboardx==1.2 +tensorboard +future \ No newline at end of file diff --git a/examples/nas/cream/retrain.py b/examples/nas/cream/retrain.py new file mode 100644 index 0000000000..566c929b8c --- /dev/null +++ b/examples/nas/cream/retrain.py @@ -0,0 +1,321 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com + +import os +import warnings +import datetime +import torch +import numpy as np +import torch.nn as nn + +from torchscope import scope +from torch.utils.tensorboard import SummaryWriter + +# import timm packages +from timm.optim import create_optimizer +from timm.models import resume_checkpoint +from timm.scheduler import create_scheduler +from timm.data import Dataset, create_loader +from timm.utils import ModelEma, update_summary +from timm.loss import LabelSmoothingCrossEntropy + +# import apex as distributed package +try: + from apex import amp + from apex.parallel import DistributedDataParallel as DDP + from apex.parallel import convert_syncbn_model + HAS_APEX = True +except ImportError: + from torch.nn.parallel import DistributedDataParallel as DDP + HAS_APEX = False + +# import models and training functions +from lib.core.test import validate +from lib.core.retrain import train_epoch +from lib.models.structures.childnet import gen_childnet +from lib.utils.util import parse_config_args, get_logger, get_model_flops_params +from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def main(): + args, cfg = parse_config_args('nni.cream.childnet') + + # resolve logging + output_dir = os.path.join(cfg.SAVE_PATH, + "{}-{}".format(datetime.date.today().strftime('%m%d'), + cfg.MODEL)) + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + if args.local_rank == 0: + logger = get_logger(os.path.join(output_dir, 'retrain.log')) + writer = SummaryWriter(os.path.join(output_dir, 'runs')) + else: + writer, logger = None, None + + # retrain model selection + if cfg.NET.SELECTION == 481: + arch_list = [ + [0], [ + 3, 4, 3, 1], [ + 3, 2, 3, 0], [ + 3, 3, 3, 1], [ + 3, 3, 3, 3], [ + 3, 3, 3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 224 + elif cfg.NET.SELECTION == 43: + arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 96 + elif cfg.NET.SELECTION == 14: + arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]] + cfg.DATASET.IMAGE_SIZE = 64 + elif cfg.NET.SELECTION == 112: + arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 160 + elif cfg.NET.SELECTION == 287: + arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 224 + elif cfg.NET.SELECTION == 604: + arch_list = [ + [0], [ + 3, 3, 2, 3, 3], [ + 3, 2, 3, 2, 3], [ + 3, 2, 3, 2, 3], [ + 3, 3, 2, 2, 3, 3], [ + 3, 3, 2, 3, 3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 224 + elif cfg.NET.SELECTION == -1: + arch_list = cfg.NET.INPUT_ARCH + cfg.DATASET.IMAGE_SIZE = 224 + else: + raise ValueError("Model Retrain Selection is not Supported!") + + # define childnet architecture from arch_list + stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25'] + choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25', + 'ir_r1_k5_s2_e4_c40_se0.25', + 'ir_r1_k3_s2_e6_c80_se0.25', + 'ir_r1_k3_s1_e6_c96_se0.25', + 'ir_r1_k3_s2_e6_c192_se0.25'] + arch_def = [[stem[0]]] + [[choice_block_pool[idx] + for repeat_times in range(len(arch_list[idx + 1]))] + for idx in range(len(choice_block_pool))] + [[stem[1]]] + + # generate childnet + model = gen_childnet( + arch_list, + arch_def, + num_classes=cfg.DATASET.NUM_CLASSES, + drop_rate=cfg.NET.DROPOUT_RATE, + global_pool=cfg.NET.GP) + + # initialize training parameters + eval_metric = cfg.EVAL_METRICS + best_metric, best_epoch, saver = None, None, None + + # initialize distributed parameters + distributed = cfg.NUM_GPU > 1 + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group(backend='nccl', init_method='env://') + if args.local_rank == 0: + logger.info( + 'Training on Process {} with {} GPUs.'.format( + args.local_rank, cfg.NUM_GPU)) + + # fix random seeds + torch.manual_seed(cfg.SEED) + torch.cuda.manual_seed_all(cfg.SEED) + np.random.seed(cfg.SEED) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + # get parameters and FLOPs of model + if args.local_rank == 0: + macs, params = get_model_flops_params(model, input_size=( + 1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE)) + logger.info( + '[Model-{}] Flops: {} Params: {}'.format(cfg.NET.SELECTION, macs, params)) + + # create optimizer + optimizer = create_optimizer(cfg, model) + model = model.cuda() + + # optionally resume from a checkpoint + resume_state, resume_epoch = {}, None + if cfg.AUTO_RESUME: + resume_state, resume_epoch = resume_checkpoint(model, cfg.RESUME_PATH) + optimizer.load_state_dict(resume_state['optimizer']) + del resume_state + + model_ema = None + if cfg.NET.EMA.USE: + model_ema = ModelEma( + model, + decay=cfg.NET.EMA.DECAY, + device='cpu' if cfg.NET.EMA.FORCE_CPU else '', + resume=cfg.RESUME_PATH if cfg.AUTO_RESUME else None) + + if distributed: + if cfg.BATCHNORM.SYNC_BN: + try: + if HAS_APEX: + model = convert_syncbn_model(model) + else: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( + model) + if args.local_rank == 0: + logger.info( + 'Converted model to use Synchronized BatchNorm.') + except Exception as e: + if args.local_rank == 0: + logger.error( + 'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1 with exception {}'.format(e)) + if HAS_APEX: + model = DDP(model, delay_allreduce=True) + else: + if args.local_rank == 0: + logger.info( + "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.") + # can use device str in Torch >= 1.1 + model = DDP(model, device_ids=[args.local_rank]) + + # imagenet train dataset + train_dir = os.path.join(cfg.DATA_DIR, 'train') + if not os.path.exists(train_dir) and args.local_rank == 0: + logger.error('Training folder does not exist at: {}'.format(train_dir)) + exit(1) + dataset_train = Dataset(train_dir) + loader_train = create_loader( + dataset_train, + input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE), + batch_size=cfg.DATASET.BATCH_SIZE, + is_training=True, + color_jitter=cfg.AUGMENTATION.COLOR_JITTER, + auto_augment=cfg.AUGMENTATION.AA, + num_aug_splits=0, + crop_pct=DEFAULT_CROP_PCT, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_workers=cfg.WORKERS, + distributed=distributed, + collate_fn=None, + pin_memory=cfg.DATASET.PIN_MEM, + interpolation='random', + re_mode=cfg.AUGMENTATION.RE_MODE, + re_prob=cfg.AUGMENTATION.RE_PROB + ) + + # imagenet validation dataset + eval_dir = os.path.join(cfg.DATA_DIR, 'val') + if not os.path.exists(eval_dir) and args.local_rank == 0: + logger.error( + 'Validation folder does not exist at: {}'.format(eval_dir)) + exit(1) + dataset_eval = Dataset(eval_dir) + loader_eval = create_loader( + dataset_eval, + input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE), + batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE, + is_training=False, + interpolation=cfg.DATASET.INTERPOLATION, + crop_pct=DEFAULT_CROP_PCT, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_workers=cfg.WORKERS, + distributed=distributed, + pin_memory=cfg.DATASET.PIN_MEM + ) + + # whether to use label smoothing + if cfg.AUGMENTATION.SMOOTHING > 0.: + train_loss_fn = LabelSmoothingCrossEntropy( + smoothing=cfg.AUGMENTATION.SMOOTHING).cuda() + validate_loss_fn = nn.CrossEntropyLoss().cuda() + else: + train_loss_fn = nn.CrossEntropyLoss().cuda() + validate_loss_fn = train_loss_fn + + # create learning rate scheduler + lr_scheduler, num_epochs = create_scheduler(cfg, optimizer) + start_epoch = resume_epoch if resume_epoch is not None else 0 + if start_epoch > 0: + lr_scheduler.step(start_epoch) + if args.local_rank == 0: + logger.info('Scheduled epochs: {}'.format(num_epochs)) + + try: + best_record, best_ep = 0, 0 + for epoch in range(start_epoch, num_epochs): + if distributed: + loader_train.sampler.set_epoch(epoch) + + train_metrics = train_epoch( + epoch, + model, + loader_train, + optimizer, + train_loss_fn, + cfg, + lr_scheduler=lr_scheduler, + saver=saver, + output_dir=output_dir, + model_ema=model_ema, + logger=logger, + writer=writer, + local_rank=args.local_rank) + + eval_metrics = validate( + epoch, + model, + loader_eval, + validate_loss_fn, + cfg, + logger=logger, + writer=writer, + local_rank=args.local_rank) + + if model_ema is not None and not cfg.NET.EMA.FORCE_CPU: + ema_eval_metrics = validate( + epoch, + model_ema.ema, + loader_eval, + validate_loss_fn, + cfg, + log_suffix='_EMA', + logger=logger, + writer=writer) + eval_metrics = ema_eval_metrics + + if lr_scheduler is not None: + lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) + + update_summary(epoch, train_metrics, eval_metrics, os.path.join( + output_dir, 'summary.csv'), write_header=best_metric is None) + + if saver is not None: + # save proper checkpoint with eval metric + save_metric = eval_metrics[eval_metric] + best_metric, best_epoch = saver.save_checkpoint( + model, optimizer, cfg, + epoch=epoch, model_ema=model_ema, metric=save_metric) + + if best_record < eval_metrics[eval_metric]: + best_record = eval_metrics[eval_metric] + best_ep = epoch + + if args.local_rank == 0: + logger.info( + '*** Best metric: {0} (epoch {1})'.format(best_record, best_ep)) + + except KeyboardInterrupt: + pass + + if best_metric is not None: + logger.info( + '*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) + + +if __name__ == '__main__': + main() diff --git a/examples/nas/cream/test.py b/examples/nas/cream/test.py new file mode 100644 index 0000000000..67ee822853 --- /dev/null +++ b/examples/nas/cream/test.py @@ -0,0 +1,158 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com + +import os +import warnings +import datetime +import torch +import torch.nn as nn + +from torch.utils.tensorboard import SummaryWriter + +# import timm packages +from timm.utils import ModelEma +from timm.models import resume_checkpoint +from timm.data import Dataset, create_loader + +# import apex as distributed package +try: + from apex.parallel import convert_syncbn_model + from apex.parallel import DistributedDataParallel as DDP + HAS_APEX = True +except ImportError: + from torch.nn.parallel import DistributedDataParallel as DDP + HAS_APEX = False + +# import models and training functions +from lib.core.test import validate +from lib.models.structures.childnet import gen_childnet +from lib.utils.util import parse_config_args, get_logger, get_model_flops_params +from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def main(): + args, cfg = parse_config_args('child net testing') + + # resolve logging + output_dir = os.path.join(cfg.SAVE_PATH, + "{}-{}".format(datetime.date.today().strftime('%m%d'), + cfg.MODEL)) + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + if args.local_rank == 0: + logger = get_logger(os.path.join(output_dir, 'test.log')) + writer = SummaryWriter(os.path.join(output_dir, 'runs')) + else: + writer, logger = None, None + + # retrain model selection + if cfg.NET.SELECTION == 481: + arch_list = [ + [0], [ + 3, 4, 3, 1], [ + 3, 2, 3, 0], [ + 3, 3, 3, 1], [ + 3, 3, 3, 3], [ + 3, 3, 3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 224 + elif cfg.NET.SELECTION == 43: + arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 96 + elif cfg.NET.SELECTION == 14: + arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]] + cfg.DATASET.IMAGE_SIZE = 64 + elif cfg.NET.SELECTION == 112: + arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 160 + elif cfg.NET.SELECTION == 287: + arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 224 + elif cfg.NET.SELECTION == 604: + arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3], + [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]] + cfg.DATASET.IMAGE_SIZE = 224 + else: + raise ValueError("Model Test Selection is not Supported!") + + # define childnet architecture from arch_list + stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25'] + choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25', + 'ir_r1_k5_s2_e4_c40_se0.25', + 'ir_r1_k3_s2_e6_c80_se0.25', + 'ir_r1_k3_s1_e6_c96_se0.25', + 'ir_r1_k3_s2_e6_c192_se0.25'] + arch_def = [[stem[0]]] + [[choice_block_pool[idx] + for repeat_times in range(len(arch_list[idx + 1]))] + for idx in range(len(choice_block_pool))] + [[stem[1]]] + + # generate childnet + model = gen_childnet( + arch_list, + arch_def, + num_classes=cfg.DATASET.NUM_CLASSES, + drop_rate=cfg.NET.DROPOUT_RATE, + global_pool=cfg.NET.GP) + + if args.local_rank == 0: + macs, params = get_model_flops_params(model, input_size=( + 1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE)) + logger.info( + '[Model-{}] Flops: {} Params: {}'.format(cfg.NET.SELECTION, macs, params)) + + # initialize distributed parameters + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group(backend='nccl', init_method='env://') + if args.local_rank == 0: + logger.info( + "Training on Process {} with {} GPUs.".format( + args.local_rank, cfg.NUM_GPU)) + + # resume model from checkpoint + assert cfg.AUTO_RESUME is True and os.path.exists(cfg.RESUME_PATH) + _, __ = resume_checkpoint(model, cfg.RESUME_PATH) + + model = model.cuda() + + model_ema = None + if cfg.NET.EMA.USE: + # Important to create EMA model after cuda(), DP wrapper, and AMP but + # before SyncBN and DDP wrapper + model_ema = ModelEma( + model, + decay=cfg.NET.EMA.DECAY, + device='cpu' if cfg.NET.EMA.FORCE_CPU else '', + resume=cfg.RESUME_PATH) + + # imagenet validation dataset + eval_dir = os.path.join(cfg.DATA_DIR, 'val') + if not os.path.exists(eval_dir) and args.local_rank == 0: + logger.error( + 'Validation folder does not exist at: {}'.format(eval_dir)) + exit(1) + + dataset_eval = Dataset(eval_dir) + loader_eval = create_loader( + dataset_eval, + input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE), + batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE, + is_training=False, + num_workers=cfg.WORKERS, + distributed=True, + pin_memory=cfg.DATASET.PIN_MEM, + crop_pct=DEFAULT_CROP_PCT, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD + ) + + # only test accuracy of model-EMA + validate_loss_fn = nn.CrossEntropyLoss().cuda() + validate(0, model_ema.ema, loader_eval, validate_loss_fn, cfg, + log_suffix='_EMA', logger=logger, + writer=writer, local_rank=args.local_rank) + + +if __name__ == '__main__': + main() diff --git a/examples/nas/cream/train.py b/examples/nas/cream/train.py new file mode 100644 index 0000000000..50d340c1ef --- /dev/null +++ b/examples/nas/cream/train.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Written by Hao Du and Houwen Peng +# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com + +import os +import sys +import datetime +import torch +import numpy as np +import torch.nn as nn + +# import timm packages +from timm.loss import LabelSmoothingCrossEntropy +from timm.data import Dataset, create_loader +from timm.models import resume_checkpoint + +# import apex as distributed package +try: + from apex.parallel import DistributedDataParallel as DDP + from apex.parallel import convert_syncbn_model + USE_APEX = True +except ImportError: + from torch.nn.parallel import DistributedDataParallel as DDP + USE_APEX = False + +# import models and training functions +from lib.utils.flops_table import FlopsEst +from lib.models.structures.supernet import gen_supernet +from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from lib.utils.util import parse_config_args, get_logger, \ + create_optimizer_supernet, create_supernet_scheduler + +from nni.nas.pytorch.callbacks import LRSchedulerCallback +from nni.nas.pytorch.callbacks import ModelCheckpoint +from nni.algorithms.nas.pytorch.cream import CreamSupernetTrainer +from nni.algorithms.nas.pytorch.random import RandomMutator + +def main(): + args, cfg = parse_config_args('nni.cream.supernet') + + # resolve logging + output_dir = os.path.join(cfg.SAVE_PATH, + "{}-{}".format(datetime.date.today().strftime('%m%d'), + cfg.MODEL)) + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + if args.local_rank == 0: + logger = get_logger(os.path.join(output_dir, "train.log")) + else: + logger = None + + # initialize distributed parameters + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group(backend='nccl', init_method='env://') + if args.local_rank == 0: + logger.info( + 'Training on Process %d with %d GPUs.', + args.local_rank, cfg.NUM_GPU) + + # fix random seeds + torch.manual_seed(cfg.SEED) + torch.cuda.manual_seed_all(cfg.SEED) + np.random.seed(cfg.SEED) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + # generate supernet + model, sta_num, resolution = gen_supernet( + flops_minimum=cfg.SUPERNET.FLOPS_MINIMUM, + flops_maximum=cfg.SUPERNET.FLOPS_MAXIMUM, + num_classes=cfg.DATASET.NUM_CLASSES, + drop_rate=cfg.NET.DROPOUT_RATE, + global_pool=cfg.NET.GP, + resunit=cfg.SUPERNET.RESUNIT, + dil_conv=cfg.SUPERNET.DIL_CONV, + slice=cfg.SUPERNET.SLICE, + verbose=cfg.VERBOSE, + logger=logger) + + # number of choice blocks in supernet + choice_num = len(model.blocks[7]) + if args.local_rank == 0: + logger.info('Supernet created, param count: %d', ( + sum([m.numel() for m in model.parameters()]))) + logger.info('resolution: %d', (resolution)) + logger.info('choice number: %d', (choice_num)) + + # initialize flops look-up table + model_est = FlopsEst(model) + flops_dict, flops_fixed = model_est.flops_dict, model_est.flops_fixed + + # optionally resume from a checkpoint + optimizer_state = None + resume_epoch = None + if cfg.AUTO_RESUME: + optimizer_state, resume_epoch = resume_checkpoint( + model, cfg.RESUME_PATH) + + # create optimizer and resume from checkpoint + optimizer = create_optimizer_supernet(cfg, model, USE_APEX) + if optimizer_state is not None: + optimizer.load_state_dict(optimizer_state['optimizer']) + model = model.cuda() + + # convert model to distributed mode + if cfg.BATCHNORM.SYNC_BN: + try: + if USE_APEX: + model = convert_syncbn_model(model) + else: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + if args.local_rank == 0: + logger.info('Converted model to use Synchronized BatchNorm.') + except Exception as exception: + logger.info( + 'Failed to enable Synchronized BatchNorm. ' + 'Install Apex or Torch >= 1.1 with Exception %s', exception) + if USE_APEX: + model = DDP(model, delay_allreduce=True) + else: + if args.local_rank == 0: + logger.info( + "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.") + # can use device str in Torch >= 1.1 + model = DDP(model, device_ids=[args.local_rank]) + + # create learning rate scheduler + lr_scheduler, num_epochs = create_supernet_scheduler(cfg, optimizer) + + start_epoch = resume_epoch if resume_epoch is not None else 0 + if start_epoch > 0: + lr_scheduler.step(start_epoch) + + if args.local_rank == 0: + logger.info('Scheduled epochs: %d', num_epochs) + + # imagenet train dataset + train_dir = os.path.join(cfg.DATA_DIR, 'train') + if not os.path.exists(train_dir): + logger.info('Training folder does not exist at: %s', train_dir) + sys.exit() + + dataset_train = Dataset(train_dir) + loader_train = create_loader( + dataset_train, + input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE), + batch_size=cfg.DATASET.BATCH_SIZE, + is_training=True, + use_prefetcher=True, + re_prob=cfg.AUGMENTATION.RE_PROB, + re_mode=cfg.AUGMENTATION.RE_MODE, + color_jitter=cfg.AUGMENTATION.COLOR_JITTER, + interpolation='random', + num_workers=cfg.WORKERS, + distributed=True, + collate_fn=None, + crop_pct=DEFAULT_CROP_PCT, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD + ) + + # imagenet validation dataset + eval_dir = os.path.join(cfg.DATA_DIR, 'val') + if not os.path.isdir(eval_dir): + logger.info('Validation folder does not exist at: %s', eval_dir) + sys.exit() + dataset_eval = Dataset(eval_dir) + loader_eval = create_loader( + dataset_eval, + input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE), + batch_size=4 * cfg.DATASET.BATCH_SIZE, + is_training=False, + use_prefetcher=True, + num_workers=cfg.WORKERS, + distributed=True, + crop_pct=DEFAULT_CROP_PCT, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + interpolation=cfg.DATASET.INTERPOLATION + ) + + # whether to use label smoothing + if cfg.AUGMENTATION.SMOOTHING > 0.: + train_loss_fn = LabelSmoothingCrossEntropy( + smoothing=cfg.AUGMENTATION.SMOOTHING).cuda() + validate_loss_fn = nn.CrossEntropyLoss().cuda() + else: + train_loss_fn = nn.CrossEntropyLoss().cuda() + validate_loss_fn = train_loss_fn + + mutator = RandomMutator(model) + + trainer = CreamSupernetTrainer(model, train_loss_fn, validate_loss_fn, + optimizer, num_epochs, loader_train, loader_eval, + mutator=mutator, batch_size=cfg.DATASET.BATCH_SIZE, + log_frequency=cfg.LOG_INTERVAL, + meta_sta_epoch=cfg.SUPERNET.META_STA_EPOCH, + update_iter=cfg.SUPERNET.UPDATE_ITER, + slices=cfg.SUPERNET.SLICE, + pool_size=cfg.SUPERNET.POOL_SIZE, + pick_method=cfg.SUPERNET.PICK_METHOD, + choice_num=choice_num, sta_num=sta_num, acc_gap=cfg.ACC_GAP, + flops_dict=flops_dict, flops_fixed=flops_fixed, local_rank=args.local_rank, + callbacks=[LRSchedulerCallback(lr_scheduler), + ModelCheckpoint(output_dir)]) + + trainer.train() + + +if __name__ == '__main__': + main() diff --git a/nni/algorithms/nas/pytorch/cream/__init__.py b/nni/algorithms/nas/pytorch/cream/__init__.py new file mode 100755 index 0000000000..43a038b467 --- /dev/null +++ b/nni/algorithms/nas/pytorch/cream/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .trainer import CreamSupernetTrainer diff --git a/nni/algorithms/nas/pytorch/cream/trainer.py b/nni/algorithms/nas/pytorch/cream/trainer.py new file mode 100644 index 0000000000..0c5136d1b4 --- /dev/null +++ b/nni/algorithms/nas/pytorch/cream/trainer.py @@ -0,0 +1,406 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import torch +import logging + +from copy import deepcopy +from nni.nas.pytorch.trainer import Trainer +from nni.nas.pytorch.utils import AverageMeterGroup + +from .utils import accuracy, reduce_metrics + +logger = logging.getLogger(__name__) + + +class CreamSupernetTrainer(Trainer): + """ + This trainer trains a supernet and output prioritized architectures that can be used for other tasks. + + Parameters + ---------- + model : nn.Module + Model with mutables. + loss : callable + Called with logits and targets. Returns a loss tensor. + val_loss : callable + Called with logits and targets for validation only. Returns a loss tensor. + optimizer : Optimizer + Optimizer that optimizes the model. + num_epochs : int + Number of epochs of training. + train_loader : iterablez + Data loader of training. Raise ``StopIteration`` when one epoch is exhausted. + valid_loader : iterablez + Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted. + mutator : Mutator + A mutator object that has been initialized with the model. + batch_size : int + Batch size. + log_frequency : int + Number of mini-batches to log metrics. + meta_sta_epoch : int + start epoch of using meta matching network to pick teacher architecture + update_iter : int + interval of updating meta matching networks + slices : int + batch size of mini training data in the process of training meta matching network + pool_size : int + board size + pick_method : basestring + how to pick teacher network + choice_num : int + number of operations in supernet + sta_num : int + layer number of each stage in supernet (5 stage in supernet) + acc_gap : int + maximum accuracy improvement to omit the limitation of flops + flops_dict : Dict + dictionary of each layer's operations in supernet + flops_fixed : int + flops of fixed part in supernet + local_rank : int + index of current rank + callbacks : list of Callback + Callbacks to plug into the trainer. See Callbacks. + """ + + def __init__(self, model, loss, val_loss, + optimizer, num_epochs, train_loader, valid_loader, + mutator=None, batch_size=64, log_frequency=None, + meta_sta_epoch=20, update_iter=200, slices=2, + pool_size=10, pick_method='meta', choice_num=6, + sta_num=(4, 4, 4, 4, 4), acc_gap=5, + flops_dict=None, flops_fixed=0, local_rank=0, callbacks=None): + assert torch.cuda.is_available() + super(CreamSupernetTrainer, self).__init__(model, mutator, loss, None, + optimizer, num_epochs, None, None, + batch_size, None, None, log_frequency, callbacks) + self.model = model + self.loss = loss + self.val_loss = val_loss + self.train_loader = train_loader + self.valid_loader = valid_loader + self.log_frequency = log_frequency + self.batch_size = batch_size + self.optimizer = optimizer + self.model = model + self.loss = loss + self.num_epochs = num_epochs + self.meta_sta_epoch = meta_sta_epoch + self.update_iter = update_iter + self.slices = slices + self.pick_method = pick_method + self.pool_size = pool_size + self.local_rank = local_rank + self.choice_num = choice_num + self.sta_num = sta_num + self.acc_gap = acc_gap + self.flops_dict = flops_dict + self.flops_fixed = flops_fixed + + self.current_student_arch = None + self.current_teacher_arch = None + self.main_proc = (local_rank == 0) + self.current_epoch = 0 + + self.prioritized_board = [] + + # size of prioritized board + def _board_size(self): + return len(self.prioritized_board) + + # select teacher architecture according to the logit difference + def _select_teacher(self): + self._replace_mutator_cand(self.current_student_arch) + + if self.pick_method == 'top1': + meta_value, teacher_cand = 0.5, sorted( + self.prioritized_board, reverse=True)[0][3] + elif self.pick_method == 'meta': + meta_value, cand_idx, teacher_cand = -1000000000, -1, None + for now_idx, item in enumerate(self.prioritized_board): + inputx = item[4] + output = torch.nn.functional.softmax(self.model(inputx), dim=1) + weight = self.model.module.forward_meta(output - item[5]) + if weight > meta_value: + meta_value = weight + cand_idx = now_idx + teacher_cand = self.prioritized_board[cand_idx][3] + assert teacher_cand is not None + meta_value = torch.nn.functional.sigmoid(-weight) + else: + raise ValueError('Method Not supported') + + return meta_value, teacher_cand + + # check whether to update prioritized board + def _isUpdateBoard(self, prec1, flops): + if self.current_epoch <= self.meta_sta_epoch: + return False + + if len(self.prioritized_board) < self.pool_size: + return True + + if prec1 > self.prioritized_board[-1][1] + self.acc_gap: + return True + + if prec1 > self.prioritized_board[-1][1] and flops < self.prioritized_board[-1][2]: + return True + + return False + + # update prioritized board + def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flops): + if self._isUpdateBoard(prec1, flops): + val_prec1 = prec1 + training_data = deepcopy(inputs[:self.slices].detach()) + if len(self.prioritized_board) == 0: + features = deepcopy(outputs[:self.slices].detach()) + else: + features = deepcopy( + teacher_output[:self.slices].detach()) + self.prioritized_board.append( + (val_prec1, + prec1, + flops, + self.current_teacher_arch, + training_data, + torch.nn.functional.softmax( + features, + dim=1))) + self.prioritized_board = sorted( + self.prioritized_board, reverse=True) + + if len(self.prioritized_board) > self.pool_size: + self.prioritized_board = sorted( + self.prioritized_board, reverse=True) + del self.prioritized_board[-1] + + # only update student network weights + def _update_student_weights_only(self, grad_1): + for weight, grad_item in zip( + self.model.module.rand_parameters(self.current_student_arch), grad_1): + weight.grad = grad_item + torch.nn.utils.clip_grad_norm_( + self.model.module.rand_parameters(self.current_student_arch), 1) + self.optimizer.step() + for weight, grad_item in zip( + self.model.module.rand_parameters(self.current_student_arch), grad_1): + del weight.grad + + # only update meta networks weights + def _update_meta_weights_only(self, teacher_cand, grad_teacher): + for weight, grad_item in zip(self.model.module.rand_parameters( + teacher_cand, self.pick_method == 'meta'), grad_teacher): + weight.grad = grad_item + + # clip gradients + torch.nn.utils.clip_grad_norm_( + self.model.module.rand_parameters( + self.current_student_arch, self.pick_method == 'meta'), 1) + + self.optimizer.step() + for weight, grad_item in zip(self.model.module.rand_parameters( + teacher_cand, self.pick_method == 'meta'), grad_teacher): + del weight.grad + + # simulate sgd updating + def _simulate_sgd_update(self, w, g, optimizer): + return g * optimizer.param_groups[-1]['lr'] + w + + # split training images into several slices + def _get_minibatch_input(self, input): + slice = self.slices + x = deepcopy(input[:slice].clone().detach()) + return x + + # calculate 1st gradient of student architectures + def _calculate_1st_gradient(self, kd_loss): + self.optimizer.zero_grad() + grad = torch.autograd.grad( + kd_loss, + self.model.module.rand_parameters(self.current_student_arch), + create_graph=True) + return grad + + # calculate 2nd gradient of meta networks + def _calculate_2nd_gradient(self, validation_loss, teacher_cand, students_weight): + self.optimizer.zero_grad() + grad_student_val = torch.autograd.grad( + validation_loss, + self.model.module.rand_parameters(self.random_cand), + retain_graph=True) + + grad_teacher = torch.autograd.grad( + students_weight[0], + self.model.module.rand_parameters( + teacher_cand, + self.pick_method == 'meta'), + grad_outputs=grad_student_val) + return grad_teacher + + # forward training data + def _forward_training(self, x, meta_value): + self._replace_mutator_cand(self.current_student_arch) + output = self.model(x) + + with torch.no_grad(): + self._replace_mutator_cand(self.current_teacher_arch) + teacher_output = self.model(x) + soft_label = torch.nn.functional.softmax(teacher_output, dim=1) + + kd_loss = meta_value * \ + self._cross_entropy_loss_with_soft_target(output, soft_label) + return kd_loss + + # calculate soft target loss + def _cross_entropy_loss_with_soft_target(self, pred, soft_target): + logsoftmax = torch.nn.LogSoftmax() + return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1)) + + # forward validation data + def _forward_validation(self, input, target): + slice = self.slices + x = input[slice:slice * 2].clone() + + self._replace_mutator_cand(self.current_student_arch) + output_2 = self.model(x) + + validation_loss = self.loss(output_2, target[slice:slice * 2]) + return validation_loss + + def _isUpdateMeta(self, batch_idx): + isUpdate = True + isUpdate &= (self.current_epoch > self.meta_sta_epoch) + isUpdate &= (batch_idx > 0) + isUpdate &= (batch_idx % self.update_iter == 0) + isUpdate &= (self._board_size() > 0) + return isUpdate + + def _replace_mutator_cand(self, cand): + self.mutator._cache = cand + + # update meta matching networks + def _run_update(self, input, target, batch_idx): + if self._isUpdateMeta(batch_idx): + x = self._get_minibatch_input(input) + + meta_value, teacher_cand = self._select_teacher() + + kd_loss = self._forward_training(x, meta_value) + + # calculate 1st gradient + grad_1st = self._calculate_1st_gradient(kd_loss) + + # simulate updated student weights + students_weight = [ + self._simulate_sgd_update( + p, grad_item, self.optimizer) for p, grad_item in zip( + self.model.module.rand_parameters(self.current_student_arch), grad_1st)] + + # update student weights + self._update_student_weights_only(grad_1st) + + validation_loss = self._forward_validation(input, target) + + # calculate 2nd gradient + grad_teacher = self._calculate_2nd_gradient(validation_loss, teacher_cand, students_weight) + + # update meta matching networks + self._update_meta_weights_only(teacher_cand, grad_teacher) + + # delete internal variants + del grad_teacher, grad_1st, x, validation_loss, kd_loss, students_weight + + def _get_cand_flops(self, cand): + flops = 0 + for block_id, block in enumerate(cand): + if block == 'LayerChoice1' or block_id == 'LayerChoice23': + continue + for idx, choice in enumerate(cand[block]): + flops += self.flops_dict[block_id][idx] * (1 if choice else 0) + return flops + self.flops_fixed + + def train_one_epoch(self, epoch): + self.current_epoch = epoch + meters = AverageMeterGroup() + self.steps_per_epoch = len(self.train_loader) + for step, (input_data, target) in enumerate(self.train_loader): + self.mutator.reset() + self.current_student_arch = self.mutator._cache + + input_data, target = input_data.cuda(), target.cuda() + + # calculate flops of current architecture + cand_flops = self._get_cand_flops(self.mutator._cache) + + # update meta matching network + self._run_update(input_data, target, step) + + if self._board_size() > 0: + # select teacher architecture + meta_value, teacher_cand = self._select_teacher() + self.current_teacher_arch = teacher_cand + + # forward supernet + if self._board_size() == 0 or epoch <= self.meta_sta_epoch: + self._replace_mutator_cand(self.current_student_arch) + output = self.model(input_data) + + loss = self.loss(output, target) + kd_loss, teacher_output, teacher_cand = None, None, None + else: + self._replace_mutator_cand(self.current_student_arch) + output = self.model(input_data) + + gt_loss = self.loss(output, target) + + with torch.no_grad(): + self._replace_mutator_cand(self.current_teacher_arch) + teacher_output = self.model(input_data).detach() + + soft_label = torch.nn.functional.softmax(teacher_output, dim=1) + kd_loss = self._cross_entropy_loss_with_soft_target(output, soft_label) + + loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2 + + # update network + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # update metrics + prec1, prec5 = accuracy(output, target, topk=(1, 5)) + metrics = {"prec1": prec1, "prec5": prec5, "loss": loss} + metrics = reduce_metrics(metrics) + meters.update(metrics) + + # update prioritized board + self._update_prioritized_board(input_data, teacher_output, output, metrics['prec1'], cand_flops) + + if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch): + logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, self.num_epochs, + step + 1, len(self.train_loader), meters) + + if self.main_proc and self.num_epochs == epoch + 1: + for idx, i in enumerate(self.best_children_pool): + logger.info("No.%s %s", idx, i[:4]) + + def validate_one_epoch(self, epoch): + self.model.eval() + meters = AverageMeterGroup() + with torch.no_grad(): + for step, (x, y) in enumerate(self.valid_loader): + self.mutator.reset() + logits = self.model(x) + loss = self.val_loss(logits, y) + prec1, prec5 = self.accuracy(logits, y, topk=(1, 5)) + metrics = {"prec1": prec1, "prec5": prec5, "loss": loss} + metrics = self.reduce_metrics(metrics, self.distributed) + meters.update(metrics) + + if self.log_frequency is not None and step % self.log_frequency == 0: + logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1, + self.num_epochs, step + 1, len(self.valid_loader), meters) diff --git a/nni/algorithms/nas/pytorch/cream/utils.py b/nni/algorithms/nas/pytorch/cream/utils.py new file mode 100644 index 0000000000..e0542b2f3e --- /dev/null +++ b/nni/algorithms/nas/pytorch/cream/utils.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + + +import os +import torch.distributed as dist + + +def accuracy(output, target, topk=(1,)): + """ Computes the precision@k for the specified values of k """ + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + # one-hot case + if target.ndimension() > 1: + target = target.max(1)[1] + + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(1.0 / batch_size)) + return res + + +def reduce_metrics(metrics): + return {k: reduce_tensor(v).item() for k, v in metrics.items()} + + +def reduce_tensor(tensor): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= float(os.environ["WORLD_SIZE"]) + return rt