This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

support proxylessnas with NNI NAS APIs #1863

Merged
merged 64 commits on Feb 10, 2020

Changes from all commits

Commits (64)
bd7c0f0
update doc
zhangql08hit Nov 5, 2019
d9f3afb
update
zhangql08hit Nov 5, 2019
b5c295c
update
zhangql08hit Nov 5, 2019
8f9c7bc
update
zhangql08hit Nov 5, 2019
0e7f6b9
update
zhangql08hit Nov 5, 2019
c7c218f
Merge branch 'dev-nas-refactor' of github.com:Microsoft/nni into dev-…
zhangql08hit Nov 11, 2019
bccb536
init commit
zhangql08hit Nov 13, 2019
5647dd0
update
zhangql08hit Nov 14, 2019
5b7cb43
update
zhangql08hit Nov 14, 2019
366b793
debug
zhangql08hit Nov 16, 2019
088a56c
update
zhangql08hit Nov 17, 2019
0a47184
update
zhangql08hit Nov 17, 2019
52dd740
update
zhangql08hit Nov 18, 2019
95b1974
update
zhangql08hit Nov 18, 2019
44145e4
update
zhangql08hit Nov 18, 2019
a0febf9
update
zhangql08hit Nov 19, 2019
7b92588
Merge branch 'dev-nas-refactor' of github.com:Microsoft/nni into dev-…
zhangql08hit Nov 19, 2019
cc8a1fb
update
zhangql08hit Nov 19, 2019
dacbdf7
update
zhangql08hit Nov 19, 2019
007e043
update
zhangql08hit Nov 20, 2019
098fe3d
fix bug
zhangql08hit Dec 10, 2019
ca9ec6c
update
zhangql08hit Dec 11, 2019
3d2159e
update
zhangql08hit Dec 11, 2019
181f9c0
update
zhangql08hit Dec 12, 2019
5578542
update
zhangql08hit Dec 12, 2019
55c75f5
update
zhangql08hit Dec 12, 2019
5a403ec
update
zhangql08hit Dec 12, 2019
3e2ee56
update
zhangql08hit Dec 12, 2019
ed27d47
update
zhangql08hit Dec 12, 2019
80eafc4
update
zhangql08hit Dec 12, 2019
b8e29e8
update
zhangql08hit Dec 12, 2019
1354025
update
zhangql08hit Dec 13, 2019
4b611db
update
zhangql08hit Dec 13, 2019
a624c12
fix bug
zhangql08hit Dec 13, 2019
393d837
update
zhangql08hit Dec 13, 2019
f768b5a
update
zhangql08hit Dec 13, 2019
8bc69a8
update
zhangql08hit Dec 13, 2019
640103d
update
zhangql08hit Dec 13, 2019
810ea95
update
zhangql08hit Dec 13, 2019
b890fce
update
zhangql08hit Dec 13, 2019
5996d4f
update
zhangql08hit Dec 13, 2019
51128bb
update
zhangql08hit Dec 16, 2019
3b3aba4
update
zhangql08hit Dec 16, 2019
14f3f1d
add retrain
zhangql08hit Dec 16, 2019
346e5a4
update
zhangql08hit Dec 16, 2019
5ff1ccd
Merge branch 'master' of github.com:Microsoft/nni into dev-plnas
zhangql08hit Dec 16, 2019
0eddd52
update
zhangql08hit Dec 16, 2019
8d499ec
retrain tested
zhangql08hit Dec 17, 2019
cb0c2e9
update
zhangql08hit Dec 18, 2019
38fab2d
update
zhangql08hit Dec 18, 2019
eab6e22
update
zhangql08hit Dec 19, 2019
a7f59f0
update
zhangql08hit Dec 19, 2019
8ef5f6d
add doc string
zhangql08hit Dec 22, 2019
477af83
update
zhangql08hit Dec 22, 2019
aab28e2
add docstring
zhangql08hit Dec 23, 2019
d9a778d
update
zhangql08hit Dec 23, 2019
e9c7603
add doc
zhangql08hit Dec 23, 2019
4f7c662
resolve comments
QuanluZhang Dec 23, 2019
0b8cb1e
update
QuanluZhang Dec 24, 2019
b462b25
Merge branch 'master' of https://github.com/microsoft/nni into dev-plnas
Feb 10, 2020
927ab9e
update doc
Feb 10, 2020
e32bb72
update doc
Feb 10, 2020
61d2944
update toctree
Feb 10, 2020
fba009e
fix broken link
Feb 10, 2020
1 change: 1 addition & 0 deletions docs/en_US/NAS/Overview.md
@@ -19,6 +19,7 @@ NNI supports below NAS algorithms now and is adding more. User can reproduce an
| [P-DARTS](PDARTS.md) | [Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation](https://arxiv.org/abs/1904.12760) is based on DARTS. It introduces an efficient algorithm which allows the depth of searched architectures to grow gradually during the training procedure. |
| [SPOS](SPOS.md) | [Single Path One-Shot Neural Architecture Search with Uniform Sampling](https://arxiv.org/abs/1904.00420) constructs a simplified supernet trained with an uniform path sampling method, and applies an evolutionary algorithm to efficiently search for the best-performing architectures. |
| [CDARTS](CDARTS.md) | [Cyclic Differentiable Architecture Search](https://arxiv.org/abs/****) builds a cyclic feedback mechanism between the search and evaluation networks. It introduces a cyclic differentiable architecture search framework which integrates the two networks into a unified architecture.|
| [ProxylessNAS](Proxylessnas.md) | [ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware](https://arxiv.org/abs/1812.00332).|

One-shot algorithms run **standalone without nnictl**. Only the PyTorch version has been implemented. TensorFlow 2.x will be supported in a future release.

63 changes: 63 additions & 0 deletions docs/en_US/NAS/Proxylessnas.md
@@ -0,0 +1,63 @@
# ProxylessNAS on NNI

## Introduction

The paper [ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware](https://arxiv.org/pdf/1812.00332.pdf) removes the proxy: it directly learns architectures for large-scale target tasks and target hardware platforms. The authors address the high memory consumption issue of differentiable NAS and reduce the computational cost to the same level as regular training, while still allowing a large candidate set. Please refer to the paper for details.

## Usage

To use the ProxylessNAS training/searching approach, users need to specify the search space in their model using the [NNI NAS interface](NasGuide.md), e.g., `LayerChoice` and `InputChoice`. After defining and instantiating the model, the remaining work can be left to the ProxylessNasTrainer by instantiating the trainer and passing the model to it.
```python
trainer = ProxylessNasTrainer(model,
                              model_optim=optimizer,
                              train_loader=data_provider.train,
                              valid_loader=data_provider.valid,
                              device=device,
                              warmup=True,
                              ckpt_path=args.checkpoint_path,
                              arch_path=args.arch_path)
trainer.train()
trainer.export(args.arch_path)
```
The complete example code can be found [here](https://github.com/microsoft/nni/tree/master/examples/nas/proxylessnas).

**Input arguments of ProxylessNasTrainer**

* **model** (*PyTorch model, required*) - The model that users want to tune/search. It contains mutables that specify the search space.
* **model_optim** (*PyTorch optimizer, required*) - The optimizer used to train the model.
* **device** (*device, required*) - The device(s) on which to run the training/search. The trainer applies DataParallel on the model for users.
* **train_loader** (*PyTorch data loader, required*) - The data loader for training set.
* **valid_loader** (*PyTorch data loader, required*) - The data loader for validation set.
* **label_smoothing** (*float, optional, default = 0.1*) - The degree of label smoothing.
* **n_epochs** (*int, optional, default = 120*) - The number of epochs to train/search.
* **init_lr** (*float, optional, default = 0.025*) - The initial learning rate for training the model.
* **binary_mode** (*'two', 'full', or 'full_v2', optional, default = 'full_v2'*) - The forward/backward mode for the binary weights in the mutator. 'full' means forwarding all candidate ops, 'two' means forwarding only two sampled ops, and 'full_v2' means recomputing the inactive ops during backward.
* **arch_init_type** (*'normal' or 'uniform', optional, default = 'normal'*) - The way to initialize architecture parameters.
* **arch_init_ratio** (*float, optional, default = 1e-3*) - The ratio used to initialize architecture parameters.
* **arch_optim_lr** (*float, optional, default = 1e-3*) - The learning rate of the architecture parameters optimizer.
* **arch_weight_decay** (*float, optional, default = 0*) - Weight decay of the architecture parameters optimizer.
* **grad_update_arch_param_every** (*int, optional, default = 5*) - Update architecture weights every this many minibatches.
* **grad_update_steps** (*int, optional, default = 1*) - The number of training steps for architecture weights in each update.
* **warmup** (*bool, optional, default = True*) - Whether to do warmup.
* **warmup_epochs** (*int, optional, default = 25*) - The number of warmup epochs.
* **arch_valid_frequency** (*int, optional, default = 1*) - The frequency of printing validation results.
* **load_ckpt** (*bool, optional, default = False*) - Whether to load checkpoint.
* **ckpt_path** (*str, optional, default = None*) - The checkpoint path; if `load_ckpt` is True, `ckpt_path` cannot be None.
* **arch_path** (*str, optional, default = None*) - The path at which to store the chosen architecture.
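
For reference, the sketch below makes several of the optional arguments above explicit. It is illustrative only: `model`, `optimizer`, `device`, and `data_provider` are assumed to be constructed as in the usage example, and the values shown are simply the documented defaults.

```python
# Illustrative sketch: all names below are assumed to be defined as in the
# usage example above; the values are the documented defaults made explicit.
trainer = ProxylessNasTrainer(model,
                              model_optim=optimizer,
                              train_loader=data_provider.train,
                              valid_loader=data_provider.valid,
                              device=device,
                              n_epochs=120,
                              init_lr=0.025,
                              binary_mode='full_v2',
                              arch_optim_lr=1e-3,
                              grad_update_arch_param_every=5,
                              warmup=True,
                              warmup_epochs=25,
                              ckpt_path='./search_mobile_net.pt',
                              arch_path='./arch_path.pt')
```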


## Implementation

The implementation on NNI is based on the [official implementation](https://github.com/mit-han-lab/ProxylessNAS). The official implementation supports two training approaches, gradient descent and RL-based, and supports different target hardware, including 'mobile', 'cpu', 'gpu8', and 'flops'. Our current implementation on NNI supports the gradient descent training approach but does not yet support different hardware targets; complete support is ongoing.

Below we describe the implementation details. Like other one-shot NAS algorithms on NNI, ProxylessNAS is composed of two parts: *search space* and *training approach*. So that users can flexibly define their own search space while using the built-in ProxylessNAS training approach, we specify the search space in the [example code](https://github.com/microsoft/nni/tree/master/examples/nas/proxylessnas) using the [NNI NAS interface](NasGuide.md), and put the training approach in the [SDK](https://github.com/microsoft/nni/tree/master/src/sdk/pynni/nni/nas/pytorch/proxylessnas).
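
To make the division of labor concrete, here is a minimal, hypothetical sketch of the user-defined side: an ordinary PyTorch module whose operator is declared as a `LayerChoice`. The candidate ops and the key are placeholders, not the ones used in the example's SearchMobileNet.

```python
import torch.nn as nn
from nni.nas.pytorch import mutables

class SearchBlock(nn.Module):
    """A hypothetical block whose operator is searched by ProxylessNAS."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Each candidate is a plain nn.Module; the ProxylessNAS mutator
        # wraps this LayerChoice in a MixedOp and learns which one to keep.
        self.op = mutables.LayerChoice([
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.Conv2d(in_ch, out_ch, 5, padding=2),
        ], key='block_op')

    def forward(self, x):
        return self.op(x)
```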

![](../../img/proxylessnas.png)

The ProxylessNAS training approach is composed of ProxylessNasMutator and ProxylessNasTrainer. ProxylessNasMutator instantiates a MixedOp for each mutable (i.e., LayerChoice) and manages the architecture weights in each MixedOp. **For DataParallel**, architecture weights must be included in the user model; specifically, the ProxylessNAS implementation adds each MixedOp to its corresponding mutable (i.e., LayerChoice) as a member variable. The mutator also exposes two member functions, `arch_requires_grad` and `arch_disable_grad`, for the trainer to enable and disable the training of architecture weights.

ProxylessNasMutator also implements the forward logic of the mutables (i.e., LayerChoice).
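
The sketch below illustrates this design under stated simplifications: architecture weights are registered as parameters of the module itself (so DataParallel replicates them with the model), and the forward pass is shown as a plain softmax-weighted sum rather than the binary-gate sampling the real MixedOp performs. It is not the actual SDK code.

```python
import torch
import torch.nn as nn

class MixedOpSketch(nn.Module):
    """Simplified stand-in for MixedOp; not the actual SDK code."""
    def __init__(self, candidate_ops):
        super().__init__()
        self.ops = nn.ModuleList(candidate_ops)
        # Architecture weights live on the module, so DataParallel
        # replicates them together with the model weights.
        self.alpha = nn.Parameter(torch.randn(len(candidate_ops)) * 1e-3)

    def arch_requires_grad(self, flag=True):
        # The trainer toggles this to alternate between training model
        # weights and training architecture weights.
        self.alpha.requires_grad_(flag)

    def forward(self, x):
        # Shown as a weighted sum for brevity; the real implementation
        # samples ops through binary gates to save memory.
        weights = torch.softmax(self.alpha, dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self.ops))
```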

## Reproduce Results

Ongoing...
1 change: 1 addition & 0 deletions docs/en_US/nas.rst
@@ -24,4 +24,5 @@ For details, please refer to the following tutorials:
P-DARTS <NAS/PDARTS>
SPOS <NAS/SPOS>
CDARTS <NAS/CDARTS>
ProxylessNAS <NAS/Proxylessnas>
API Reference <NAS/NasReference>
Binary file added docs/img/proxylessnas.png
188 changes: 188 additions & 0 deletions examples/nas/proxylessnas/datasets.py
@@ -0,0 +1,188 @@
import os
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

def get_split_list(in_dim, child_num):
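    """Split `in_dim` into `child_num` near-equal parts; earlier parts absorb the remainder."""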
in_dim_list = [in_dim // child_num] * child_num
for _i in range(in_dim % child_num):
in_dim_list[_i] += 1
return in_dim_list

class DataProvider:
VALID_SEED = 0 # random seed for the validation set

@staticmethod
def name():
""" Return name of the dataset """
raise NotImplementedError

@property
def data_shape(self):
""" Return shape as python list of one data entry """
raise NotImplementedError

@property
def n_classes(self):
""" Return `int` of num classes """
raise NotImplementedError

@property
def save_path(self):
""" local path to save the data """
raise NotImplementedError

@property
def data_url(self):
""" link to download the data """
raise NotImplementedError

@staticmethod
def random_sample_valid_set(train_labels, valid_size, n_classes):
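        # Stratified split: reserve `valid_size` samples spread as evenly as
        # possible across classes, with a fixed seed for reproducibility.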
train_size = len(train_labels)
assert train_size > valid_size

g = torch.Generator()
g.manual_seed(DataProvider.VALID_SEED) # set random seed before sampling validation set
rand_indexes = torch.randperm(train_size, generator=g).tolist()

train_indexes, valid_indexes = [], []
per_class_remain = get_split_list(valid_size, n_classes)

for idx in rand_indexes:
label = train_labels[idx]
if isinstance(label, float):
label = int(label)
elif isinstance(label, np.ndarray):
label = np.argmax(label)
else:
assert isinstance(label, int)
if per_class_remain[label] > 0:
valid_indexes.append(idx)
per_class_remain[label] -= 1
else:
train_indexes.append(idx)
return train_indexes, valid_indexes


class ImagenetDataProvider(DataProvider):

def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None,
n_worker=32, resize_scale=0.08, distort_color=None):

self._save_path = save_path
train_transforms = self.build_train_transform(distort_color, resize_scale)
train_dataset = datasets.ImageFolder(self.train_path, train_transforms)

if valid_size is not None:
if isinstance(valid_size, float):
valid_size = int(valid_size * len(train_dataset))
else:
assert isinstance(valid_size, int), 'invalid valid_size: %s' % valid_size
train_indexes, valid_indexes = self.random_sample_valid_set(
[cls for _, cls in train_dataset.samples], valid_size, self.n_classes,
)
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

valid_dataset = datasets.ImageFolder(self.train_path, transforms.Compose([
transforms.Resize(self.resize_value),
transforms.CenterCrop(self.image_size),
transforms.ToTensor(),
self.normalize,
]))

self.train = torch.utils.data.DataLoader(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
self.train = torch.utils.data.DataLoader(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None

self.test = torch.utils.data.DataLoader(
datasets.ImageFolder(self.valid_path, transforms.Compose([
transforms.Resize(self.resize_value),
transforms.CenterCrop(self.image_size),
transforms.ToTensor(),
self.normalize,
])), batch_size=test_batch_size, shuffle=False, num_workers=n_worker, pin_memory=True,
)

if self.valid is None:
self.valid = self.test

@staticmethod
def name():
return 'imagenet'

@property
def data_shape(self):
return 3, self.image_size, self.image_size # C, H, W

@property
def n_classes(self):
return 1000

@property
def save_path(self):
if self._save_path is None:
self._save_path = '/dataset/imagenet'
return self._save_path

@property
def data_url(self):
raise ValueError('unable to download ImageNet')

@property
def train_path(self):
return os.path.join(self.save_path, 'train')

@property
def valid_path(self):
        return os.path.join(self.save_path, 'val')  # use the property so the default path applies when _save_path is unset

@property
def normalize(self):
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

def build_train_transform(self, distort_color, resize_scale):
print('Color jitter: %s' % distort_color)
if distort_color == 'strong':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif distort_color == 'normal':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if color_transform is None:
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(self.image_size, scale=(resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
self.normalize,
])
else:
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(self.image_size, scale=(resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
color_transform,
transforms.ToTensor(),
self.normalize,
])
return train_transforms

@property
def resize_value(self):
return 256

@property
def image_size(self):
return 224
105 changes: 105 additions & 0 deletions examples/nas/proxylessnas/main.py
@@ -0,0 +1,105 @@
import os
import sys
import logging
from argparse import ArgumentParser
import torch
import datasets

from putils import get_parameters
from model import SearchMobileNet
from nni.nas.pytorch.proxylessnas import ProxylessNasTrainer
from retrain import Retrain

logger = logging.getLogger('nni_proxylessnas')

if __name__ == "__main__":
parser = ArgumentParser("proxylessnas")
# configurations of the model
parser.add_argument("--n_cell_stages", default='4,4,4,4,4,1', type=str)
parser.add_argument("--stride_stages", default='2,2,2,1,2,1', type=str)
parser.add_argument("--width_stages", default='24,40,80,96,192,320', type=str)
parser.add_argument("--bn_momentum", default=0.1, type=float)
parser.add_argument("--bn_eps", default=1e-3, type=float)
parser.add_argument("--dropout_rate", default=0, type=float)
parser.add_argument("--no_decay_keys", default='bn', type=str, choices=[None, 'bn', 'bn#bias'])
# configurations of imagenet dataset
parser.add_argument("--data_path", default='/data/imagenet/', type=str)
parser.add_argument("--train_batch_size", default=256, type=int)
parser.add_argument("--test_batch_size", default=500, type=int)
parser.add_argument("--n_worker", default=32, type=int)
parser.add_argument("--resize_scale", default=0.08, type=float)
parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None'])
# configurations for training mode
parser.add_argument("--train_mode", default='search', type=str, choices=['search', 'retrain'])
# configurations for search
parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
# configurations for retrain
parser.add_argument("--exported_arch_path", default=None, type=str)

args = parser.parse_args()
if args.train_mode == 'retrain' and args.exported_arch_path is None:
logger.error('When --train_mode is retrain, --exported_arch_path must be specified.')
sys.exit(-1)

model = SearchMobileNet(width_stages=[int(i) for i in args.width_stages.split(',')],
n_cell_stages=[int(i) for i in args.n_cell_stages.split(',')],
stride_stages=[int(i) for i in args.stride_stages.split(',')],
n_classes=1000,
dropout_rate=args.dropout_rate,
bn_param=(args.bn_momentum, args.bn_eps))
logger.info('SearchMobileNet model create done')
model.init_model()
logger.info('SearchMobileNet model init done')

# move network to GPU if available
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')

logger.info('Creating data provider...')
data_provider = datasets.ImagenetDataProvider(save_path=args.data_path,
train_batch_size=args.train_batch_size,
test_batch_size=args.test_batch_size,
valid_size=None,
n_worker=args.n_worker,
resize_scale=args.resize_scale,
distort_color=args.distort_color)
logger.info('Creating data provider done')

    momentum, nesterov = 0.9, True  # used by both branches below
    if args.no_decay_keys:
        keys = args.no_decay_keys
        optimizer = torch.optim.SGD([
            {'params': get_parameters(model, keys, mode='exclude'), 'weight_decay': 4e-5},
            {'params': get_parameters(model, keys, mode='include'), 'weight_decay': 0},
        ], lr=0.05, momentum=momentum, nesterov=nesterov)
    else:
        optimizer = torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum,
                                    nesterov=nesterov, weight_decay=4e-5)

if args.train_mode == 'search':
# this is architecture search
logger.info('Creating ProxylessNasTrainer...')
trainer = ProxylessNasTrainer(model,
model_optim=optimizer,
train_loader=data_provider.train,
valid_loader=data_provider.valid,
device=device,
warmup=True,
ckpt_path=args.checkpoint_path,
arch_path=args.arch_path)

logger.info('Start to train with ProxylessNasTrainer...')
trainer.train()
logger.info('Training done')
trainer.export(args.arch_path)
logger.info('Best architecture exported in %s', args.arch_path)
elif args.train_mode == 'retrain':
# this is retrain
from nni.nas.pytorch.fixed import apply_fixed_architecture
assert os.path.isfile(args.exported_arch_path), \
"exported_arch_path {} should be a file.".format(args.exported_arch_path)
apply_fixed_architecture(model, args.exported_arch_path, device=device)
trainer = Retrain(model, optimizer, device, data_provider, n_epochs=300)
trainer.run()