This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support proxylessnas with NNI NAS APIs (#1863)
- Loading branch information
1 parent
fdcd877
commit affb211
Showing
15 changed files
with
2,128 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# ProxylessNAS on NNI | ||
|
||
## Introduction | ||
|
||
The paper [ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware](https://arxiv.org/pdf/1812.00332.pdf) removes proxy, it directly learns the architectures for large-scale target tasks and target hardware platforms. They address high memory consumption issue of differentiable NAS and reduce the computational cost to the same level of regular training while still allowing a large candidate set. Please refer to the paper for the details. | ||
|
||
## Usage | ||
|
||
To use ProxylessNAS training/searching approach, users need to specify search space in their model using [NNI NAS interface](NasGuide.md), e.g., `LayerChoice`, `InputChoice`. After defining and instantiating the model, the following work can be leaved to ProxylessNasTrainer by instantiating the trainer and passing the model to it. | ||
```python | ||
trainer = ProxylessNasTrainer(model, | ||
model_optim=optimizer, | ||
train_loader=data_provider.train, | ||
valid_loader=data_provider.valid, | ||
device=device, | ||
warmup=True, | ||
ckpt_path=args.checkpoint_path, | ||
arch_path=args.arch_path) | ||
trainer.train() | ||
trainer.export(args.arch_path) | ||
``` | ||
The complete example code can be found [here](https://github.com/microsoft/nni/tree/master/examples/nas/proxylessnas). | ||
|
||
**Input arguments of ProxylessNasTrainer** | ||
|
||
* **model** (*PyTorch model, required*) - The model that users want to tune/search. It has mutables to specify search space. | ||
* **model_optim** (*PyTorch optimizer, required*) - The optimizer users want to train the model. | ||
* **device** (*device, required*) - The devices that users provide to do the train/search. The trainer applies data parallel on the model for users. | ||
* **train_loader** (*PyTorch data loader, required*) - The data loader for training set. | ||
* **valid_loader** (*PyTorch data loader, required*) - The data loader for validation set. | ||
* **label_smoothing** (*float, optional, default = 0.1*) - The degree of label smoothing. | ||
* **n_epochs** (*int, optional, default = 120*) - The number of epochs to train/search. | ||
* **init_lr** (*float, optional, default = 0.025*) - The initial learning rate for training the model. | ||
* **binary_mode** (*'two', 'full', or 'full_v2', optional, default = 'full_v2'*) - The forward/backward mode for the binary weights in mutator. 'full' means forward all the candidate ops, 'two' means only forward two sampled ops, 'full_v2' means recomputing the inactive ops during backward. | ||
* **arch_init_type** (*'normal' or 'uniform', optional, default = 'normal'*) - The way to init architecture parameters. | ||
* **arch_init_ratio** (*float, optional, default = 1e-3*) - The ratio to init architecture parameters. | ||
* **arch_optim_lr** (*float, optional, default = 1e-3*) - The learning rate of the architecture parameters optimizer. | ||
* **arch_weight_decay** (*float, optional, default = 0*) - Weight decay of the architecture parameters optimizer. | ||
* **grad_update_arch_param_every** (*int, optional, default = 5*) - Update architecture weights every this number of minibatches. | ||
* **grad_update_steps** (*int, optional, default = 1*) - During each update of architecture weights, the number of steps to train architecture weights. | ||
* **warmup** (*bool, optional, default = True*) - Whether to do warmup. | ||
* **warmup_epochs** (*int, optional, default = 25*) - The number of epochs to do during warmup. | ||
* **arch_valid_frequency** (*int, optional, default = 1*) - The frequency of printing validation result. | ||
* **load_ckpt** (*bool, optional, default = False*) - Whether to load checkpoint. | ||
* **ckpt_path** (*str, optional, default = None*) - checkpoint path, if load_ckpt is True, ckpt_path cannot be None. | ||
* **arch_path** (*str, optional, default = None*) - The path to store chosen architecture. | ||
|
||
|
||
## Implementation | ||
|
||
The implementation on NNI is based on the [offical implementation](https://github.com/mit-han-lab/ProxylessNAS). The official implementation supports two training approaches: gradient descent and RL based, and support different targeted hardware, including 'mobile', 'cpu', 'gpu8', 'flops'. In our current implementation on NNI, gradient descent training approach is supported, but has not supported different hardwares. The complete support is ongoing. | ||
|
||
Below we will describe implementation details. Like other one-shot NAS algorithms on NNI, ProxylessNAS is composed of two parts: *search space* and *training approach*. For users to flexibly define their own search space and use built-in ProxylessNAS training approach, we put the specified search space in [example code](https://github.com/microsoft/nni/tree/master/examples/nas/proxylessnas) using [NNI NAS interface](NasGuide.md), and put the training approach in [SDK](https://github.com/microsoft/nni/tree/master/src/sdk/pynni/nni/nas/pytorch/proxylessnas). | ||
|
||
![](../../img/proxylessnas.png) | ||
|
||
ProxylessNAS training approach is composed of ProxylessNasMutator and ProxylessNasTrainer. ProxylessNasMutator instantiates MixedOp for each mutable (i.e., LayerChoice), and manage architecture weights in MixedOp. **For DataParallel**, architecture weights should be included in user model. Specifically, in ProxylessNAS implementation, we add MixedOp to the corresponding mutable (i.e., LayerChoice) as a member variable. The mutator also exposes two member functions, i.e., `arch_requires_grad`, `arch_disable_grad`, for the trainer to control the training of architecture weights. | ||
|
||
ProxylessNasMutator also implements the forward logic of the mutables (i.e., LayerChoice). | ||
|
||
## Reproduce Results | ||
|
||
Ongoing... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
import os | ||
import numpy as np | ||
import torch.utils.data | ||
import torchvision.transforms as transforms | ||
import torchvision.datasets as datasets | ||
|
||
def get_split_list(in_dim, child_num): | ||
in_dim_list = [in_dim // child_num] * child_num | ||
for _i in range(in_dim % child_num): | ||
in_dim_list[_i] += 1 | ||
return in_dim_list | ||
|
||
class DataProvider: | ||
VALID_SEED = 0 # random seed for the validation set | ||
|
||
@staticmethod | ||
def name(): | ||
""" Return name of the dataset """ | ||
raise NotImplementedError | ||
|
||
@property | ||
def data_shape(self): | ||
""" Return shape as python list of one data entry """ | ||
raise NotImplementedError | ||
|
||
@property | ||
def n_classes(self): | ||
""" Return `int` of num classes """ | ||
raise NotImplementedError | ||
|
||
@property | ||
def save_path(self): | ||
""" local path to save the data """ | ||
raise NotImplementedError | ||
|
||
@property | ||
def data_url(self): | ||
""" link to download the data """ | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
def random_sample_valid_set(train_labels, valid_size, n_classes): | ||
train_size = len(train_labels) | ||
assert train_size > valid_size | ||
|
||
g = torch.Generator() | ||
g.manual_seed(DataProvider.VALID_SEED) # set random seed before sampling validation set | ||
rand_indexes = torch.randperm(train_size, generator=g).tolist() | ||
|
||
train_indexes, valid_indexes = [], [] | ||
per_class_remain = get_split_list(valid_size, n_classes) | ||
|
||
for idx in rand_indexes: | ||
label = train_labels[idx] | ||
if isinstance(label, float): | ||
label = int(label) | ||
elif isinstance(label, np.ndarray): | ||
label = np.argmax(label) | ||
else: | ||
assert isinstance(label, int) | ||
if per_class_remain[label] > 0: | ||
valid_indexes.append(idx) | ||
per_class_remain[label] -= 1 | ||
else: | ||
train_indexes.append(idx) | ||
return train_indexes, valid_indexes | ||
|
||
|
||
class ImagenetDataProvider(DataProvider): | ||
|
||
def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, | ||
n_worker=32, resize_scale=0.08, distort_color=None): | ||
|
||
self._save_path = save_path | ||
train_transforms = self.build_train_transform(distort_color, resize_scale) | ||
train_dataset = datasets.ImageFolder(self.train_path, train_transforms) | ||
|
||
if valid_size is not None: | ||
if isinstance(valid_size, float): | ||
valid_size = int(valid_size * len(train_dataset)) | ||
else: | ||
assert isinstance(valid_size, int), 'invalid valid_size: %s' % valid_size | ||
train_indexes, valid_indexes = self.random_sample_valid_set( | ||
[cls for _, cls in train_dataset.samples], valid_size, self.n_classes, | ||
) | ||
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes) | ||
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes) | ||
|
||
valid_dataset = datasets.ImageFolder(self.train_path, transforms.Compose([ | ||
transforms.Resize(self.resize_value), | ||
transforms.CenterCrop(self.image_size), | ||
transforms.ToTensor(), | ||
self.normalize, | ||
])) | ||
|
||
self.train = torch.utils.data.DataLoader( | ||
train_dataset, batch_size=train_batch_size, sampler=train_sampler, | ||
num_workers=n_worker, pin_memory=True, | ||
) | ||
self.valid = torch.utils.data.DataLoader( | ||
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler, | ||
num_workers=n_worker, pin_memory=True, | ||
) | ||
else: | ||
self.train = torch.utils.data.DataLoader( | ||
train_dataset, batch_size=train_batch_size, shuffle=True, | ||
num_workers=n_worker, pin_memory=True, | ||
) | ||
self.valid = None | ||
|
||
self.test = torch.utils.data.DataLoader( | ||
datasets.ImageFolder(self.valid_path, transforms.Compose([ | ||
transforms.Resize(self.resize_value), | ||
transforms.CenterCrop(self.image_size), | ||
transforms.ToTensor(), | ||
self.normalize, | ||
])), batch_size=test_batch_size, shuffle=False, num_workers=n_worker, pin_memory=True, | ||
) | ||
|
||
if self.valid is None: | ||
self.valid = self.test | ||
|
||
@staticmethod | ||
def name(): | ||
return 'imagenet' | ||
|
||
@property | ||
def data_shape(self): | ||
return 3, self.image_size, self.image_size # C, H, W | ||
|
||
@property | ||
def n_classes(self): | ||
return 1000 | ||
|
||
@property | ||
def save_path(self): | ||
if self._save_path is None: | ||
self._save_path = '/dataset/imagenet' | ||
return self._save_path | ||
|
||
@property | ||
def data_url(self): | ||
raise ValueError('unable to download ImageNet') | ||
|
||
@property | ||
def train_path(self): | ||
return os.path.join(self.save_path, 'train') | ||
|
||
@property | ||
def valid_path(self): | ||
return os.path.join(self._save_path, 'val') | ||
|
||
@property | ||
def normalize(self): | ||
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | ||
|
||
def build_train_transform(self, distort_color, resize_scale): | ||
print('Color jitter: %s' % distort_color) | ||
if distort_color == 'strong': | ||
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1) | ||
elif distort_color == 'normal': | ||
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5) | ||
else: | ||
color_transform = None | ||
if color_transform is None: | ||
train_transforms = transforms.Compose([ | ||
transforms.RandomResizedCrop(self.image_size, scale=(resize_scale, 1.0)), | ||
transforms.RandomHorizontalFlip(), | ||
transforms.ToTensor(), | ||
self.normalize, | ||
]) | ||
else: | ||
train_transforms = transforms.Compose([ | ||
transforms.RandomResizedCrop(self.image_size, scale=(resize_scale, 1.0)), | ||
transforms.RandomHorizontalFlip(), | ||
color_transform, | ||
transforms.ToTensor(), | ||
self.normalize, | ||
]) | ||
return train_transforms | ||
|
||
@property | ||
def resize_value(self): | ||
return 256 | ||
|
||
@property | ||
def image_size(self): | ||
return 224 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import os | ||
import sys | ||
import logging | ||
from argparse import ArgumentParser | ||
import torch | ||
import datasets | ||
|
||
from putils import get_parameters | ||
from model import SearchMobileNet | ||
from nni.nas.pytorch.proxylessnas import ProxylessNasTrainer | ||
from retrain import Retrain | ||
|
||
logger = logging.getLogger('nni_proxylessnas') | ||
|
||
if __name__ == "__main__": | ||
parser = ArgumentParser("proxylessnas") | ||
# configurations of the model | ||
parser.add_argument("--n_cell_stages", default='4,4,4,4,4,1', type=str) | ||
parser.add_argument("--stride_stages", default='2,2,2,1,2,1', type=str) | ||
parser.add_argument("--width_stages", default='24,40,80,96,192,320', type=str) | ||
parser.add_argument("--bn_momentum", default=0.1, type=float) | ||
parser.add_argument("--bn_eps", default=1e-3, type=float) | ||
parser.add_argument("--dropout_rate", default=0, type=float) | ||
parser.add_argument("--no_decay_keys", default='bn', type=str, choices=[None, 'bn', 'bn#bias']) | ||
# configurations of imagenet dataset | ||
parser.add_argument("--data_path", default='/data/imagenet/', type=str) | ||
parser.add_argument("--train_batch_size", default=256, type=int) | ||
parser.add_argument("--test_batch_size", default=500, type=int) | ||
parser.add_argument("--n_worker", default=32, type=int) | ||
parser.add_argument("--resize_scale", default=0.08, type=float) | ||
parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None']) | ||
# configurations for training mode | ||
parser.add_argument("--train_mode", default='search', type=str, choices=['search', 'retrain']) | ||
# configurations for search | ||
parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str) | ||
parser.add_argument("--arch_path", default='./arch_path.pt', type=str) | ||
# configurations for retrain | ||
parser.add_argument("--exported_arch_path", default=None, type=str) | ||
|
||
args = parser.parse_args() | ||
if args.train_mode == 'retrain' and args.exported_arch_path is None: | ||
logger.error('When --train_mode is retrain, --exported_arch_path must be specified.') | ||
sys.exit(-1) | ||
|
||
model = SearchMobileNet(width_stages=[int(i) for i in args.width_stages.split(',')], | ||
n_cell_stages=[int(i) for i in args.n_cell_stages.split(',')], | ||
stride_stages=[int(i) for i in args.stride_stages.split(',')], | ||
n_classes=1000, | ||
dropout_rate=args.dropout_rate, | ||
bn_param=(args.bn_momentum, args.bn_eps)) | ||
logger.info('SearchMobileNet model create done') | ||
model.init_model() | ||
logger.info('SearchMobileNet model init done') | ||
|
||
# move network to GPU if available | ||
if torch.cuda.is_available(): | ||
device = torch.device('cuda:0') | ||
else: | ||
device = torch.device('cpu') | ||
|
||
logger.info('Creating data provider...') | ||
data_provider = datasets.ImagenetDataProvider(save_path=args.data_path, | ||
train_batch_size=args.train_batch_size, | ||
test_batch_size=args.test_batch_size, | ||
valid_size=None, | ||
n_worker=args.n_worker, | ||
resize_scale=args.resize_scale, | ||
distort_color=args.distort_color) | ||
logger.info('Creating data provider done') | ||
|
||
if args.no_decay_keys: | ||
keys = args.no_decay_keys | ||
momentum, nesterov = 0.9, True | ||
optimizer = torch.optim.SGD([ | ||
{'params': get_parameters(model, keys, mode='exclude'), 'weight_decay': 4e-5}, | ||
{'params': get_parameters(model, keys, mode='include'), 'weight_decay': 0}, | ||
], lr=0.05, momentum=momentum, nesterov=nesterov) | ||
else: | ||
optimizer = torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum, nesterov=nesterov, weight_decay=4e-5) | ||
|
||
if args.train_mode == 'search': | ||
# this is architecture search | ||
logger.info('Creating ProxylessNasTrainer...') | ||
trainer = ProxylessNasTrainer(model, | ||
model_optim=optimizer, | ||
train_loader=data_provider.train, | ||
valid_loader=data_provider.valid, | ||
device=device, | ||
warmup=True, | ||
ckpt_path=args.checkpoint_path, | ||
arch_path=args.arch_path) | ||
|
||
logger.info('Start to train with ProxylessNasTrainer...') | ||
trainer.train() | ||
logger.info('Training done') | ||
trainer.export(args.arch_path) | ||
logger.info('Best architecture exported in %s', args.arch_path) | ||
elif args.train_mode == 'retrain': | ||
# this is retrain | ||
from nni.nas.pytorch.fixed import apply_fixed_architecture | ||
assert os.path.isfile(args.exported_arch_path), \ | ||
"exported_arch_path {} should be a file.".format(args.exported_arch_path) | ||
apply_fixed_architecture(model, args.exported_arch_path, device=device) | ||
trainer = Retrain(model, optimizer, device, data_provider, n_epochs=300) | ||
trainer.run() |
Oops, something went wrong.