Merge pull request #1769 from microsoft/dev-nas-refactor
NAS refactor merge back to master (DO NOT SQUASH)
chicm-ms authored Nov 25, 2019
2 parents 503a357 + 17ea5f0 commit d07f728
Showing 42 changed files with 2,717 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -80,6 +80,7 @@ venv.bak/

# VSCode
.vscode
.vs

# In case you place source code in ~/nni/
/experiments
2 changes: 1 addition & 1 deletion docs/en_US/AdvancedFeature/MultiPhase.md
@@ -79,7 +79,7 @@ With this information, the tuner could know which trial is requesting a configur

### Tuners support multi-phase experiments:

[TPE](../Tuner/HyperoptTuner.md), [Random](../Tuner/HyperoptTuner.md), [Anneal](../Tuner/HyperoptTuner.md), [Evolution](../Tuner/EvolutionTuner.md), [SMAC](../Tuner/SmacTuner.md), [NetworkMorphism](../Tuner/NetworkmorphismTuner.md), [MetisTuner](../Tuner/MetisTuner.md), [BOHB](../Tuner/BohbAdvisor.md), [Hyperband](../Tuner/HyperbandAdvisor.md), [ENAS tuner](https://github.com/countif/enas_nni/blob/master/nni/examples/tuners/enas/nni_controller_ptb.py).
[TPE](../Tuner/HyperoptTuner.md), [Random](../Tuner/HyperoptTuner.md), [Anneal](../Tuner/HyperoptTuner.md), [Evolution](../Tuner/EvolutionTuner.md), [SMAC](../Tuner/SmacTuner.md), [NetworkMorphism](../Tuner/NetworkmorphismTuner.md), [MetisTuner](../Tuner/MetisTuner.md), [BOHB](../Tuner/BohbAdvisor.md), [Hyperband](../Tuner/HyperbandAdvisor.md).

### Training services support multi-phase experiment:
[Local Machine](../TrainingService/LocalMode.md), [Remote Servers](../TrainingService/RemoteMachineMode.md), [OpenPAI](../TrainingService/PaiMode.md)
77 changes: 77 additions & 0 deletions docs/en_US/NAS/Overview.md
@@ -0,0 +1,77 @@
# Neural Architecture Search (NAS) on NNI

Automatic neural architecture search is playing an increasingly important role in finding better models. Recent research has proved the feasibility of automatic NAS and has produced models that beat manually designed and tuned ones. Representative works include [NASNet][2], [ENAS][1], [DARTS][3], [Network Morphism][4], and [Evolution][5], and new innovations keep emerging.

However, it takes great effort to implement a NAS algorithm, and it is hard to reuse the code base of existing algorithms in a new one. To facilitate NAS innovation (e.g., designing and implementing new NAS models, comparing different NAS models side by side), an easy-to-use and flexible programming interface is crucial.

With this motivation, our ambition is to provide a unified architecture in NNI to accelerate innovation on NAS and to apply state-of-the-art algorithms to real-world problems faster.

## Supported algorithms

NNI currently supports the NAS algorithms listed below, and more are being added. Users can reproduce an algorithm or run it on their own dataset. We also encourage users to implement other algorithms with the [NNI API](#use-nni-api) to benefit more people.

Note that these algorithms run standalone without nnictl and support PyTorch only.

### Dependencies

* Install the latest NNI
* PyTorch 1.2+
* git

### DARTS

The main algorithmic contribution of [DARTS: Differentiable Architecture Search][3] is a differentiable method for network architecture search based on bilevel optimization: architecture parameters are optimized on validation data while network weights are optimized on training data.
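
To make the bilevel optimization concrete, below is a minimal, illustrative sketch (not NNI's or the paper's exact implementation) of first-order DARTS-style alternation on a toy model: architecture parameters `alpha` are updated on validation batches, while ordinary weights are updated on training batches. `TinyNet` and the synthetic batches are hypothetical placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        # two candidate ops for a single "choice" position; alpha weights their softmax mixture
        self.ops = nn.ModuleList([nn.Linear(8, 8), nn.Identity()])
        self.alpha = nn.Parameter(torch.zeros(2))  # architecture parameters
        self.head = nn.Linear(8, 2)                # ordinary network weights

    def forward(self, x):
        weights = F.softmax(self.alpha, dim=0)
        mixed = sum(w * op(x) for w, op in zip(weights, self.ops))
        return self.head(mixed)

model = TinyNet()
w_params = [p for n, p in model.named_parameters() if n != "alpha"]
w_optim = torch.optim.SGD(w_params, lr=0.05, momentum=0.9)
alpha_optim = torch.optim.Adam([model.alpha], lr=3e-4)

for step in range(100):
    x_train, y_train = torch.randn(16, 8), torch.randint(0, 2, (16,))
    x_valid, y_valid = torch.randn(16, 8), torch.randint(0, 2, (16,))

    # step 1: update architecture parameters (alpha) on a validation batch
    alpha_optim.zero_grad()
    F.cross_entropy(model(x_valid), y_valid).backward()
    alpha_optim.step()

    # step 2: update network weights on a training batch
    w_optim.zero_grad()
    F.cross_entropy(model(x_train), y_train).backward()
    w_optim.step()
```

In the `examples/nas/darts` example, this alternation is handled by NNI's trainer code invoked from `search.py` rather than written by hand.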

#### Usage

```bash
# If the NNI code has not been cloned yet, clone it; otherwise skip this line and enter the code folder.
git clone https://github.com/Microsoft/nni.git

# search the best architecture
cd examples/nas/darts
python3 search.py

# train the best architecture
python3 retrain.py --arc-checkpoint ./checkpoints/epoch_49.json
```

### P-DARTS

[Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation](https://arxiv.org/abs/1904.12760) is based on [DARTS](#DARTS). Its algorithmic contribution is an efficient method that allows the depth of searched architectures to grow gradually during the training procedure.

#### Usage

```bash
# If the NNI code has not been cloned yet, clone it; otherwise skip this line and enter the code folder.
git clone https://github.com/Microsoft/nni.git

# search the best architecture
cd examples/nas/pdarts
python3 search.py

# train the best architecture; the process is the same as for DARTS
cd ../darts
python3 retrain.py --arc-checkpoint ./checkpoints/epoch_2.json
```

## Use NNI API

NOTE: we are trying to support various NAS algorithms with a unified programming interface, which is currently in a very experimental stage. This means the programming interface may change significantly.

*The previous [NAS annotation](../AdvancedFeature/GeneralNasInterfaces.md) interface will be deprecated soon.*

### Programming interface

A programming interface for designing and searching a model is typically needed in two scenarios.

1. When designing a neural network, there may be multiple operation choices for a layer, sub-model, or connection, and it is undetermined which one, or which combination, performs best. So an easy way to express the candidate layers or sub-models is needed (see the sketch after this list).
2. When applying NAS to a neural network, a unified way to express the architecture search space is needed, so that trial code does not have to be rewritten for different search algorithms.
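
As a concrete illustration of the first scenario, here is a minimal sketch using the `mutables.LayerChoice` construct that appears in `examples/nas/darts/model.py` below; the module name `SearchBlock`, the candidate operations, and the key are illustrative, and the search/trainer wiring is omitted.

```python
import torch.nn as nn
from nni.nas.pytorch import mutables  # same import as in examples/nas/darts/model.py

class SearchBlock(nn.Module):
    """One position in the network with several candidate operations (illustrative)."""
    def __init__(self, channels):
        super().__init__()
        # the search algorithm decides which of these candidates to use at this position
        self.op = mutables.LayerChoice([
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Identity(),
        ], key="block_op")

    def forward(self, x):
        return self.op(x)
```

The companion `mutables.InputChoice` (also used in `model.py`) expresses which of several candidate inputs a node should connect to, covering the connection case mentioned above.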

The API proposed by NNI is [here](https://github.com/microsoft/nni/tree/master/src/sdk/pynni/nni/nas/pytorch), and [here](https://github.com/microsoft/nni/tree/master/examples/nas/darts) is an example NAS implementation based on this interface.

[1]: https://arxiv.org/abs/1802.03268
[2]: https://arxiv.org/abs/1707.07012
[3]: https://arxiv.org/abs/1806.09055
[4]: https://arxiv.org/abs/1806.10282
[5]: https://arxiv.org/abs/1703.01041
8 changes: 0 additions & 8 deletions docs/en_US/Tutorial/SearchSpaceSpec.md
@@ -73,12 +73,6 @@ All types of sampling strategies and their parameter are listed here:
* Which means the variable value is a value like `round(exp(normal(mu, sigma)) / q) * q`
* Suitable for a discrete variable with respect to which the objective is smooth and gets smoother with the size of the variable, which is bounded from one side.

* `{"_type": "mutable_layer", "_value": {mutable_layer_infomation}}`
* Type for [Neural Architecture Search Space][1]. Value is also a dictionary, which contains key-value pairs representing respectively name and search space of each mutable_layer.
* For now, users can only use this type of search space with annotation, which means that there is no need to define a json file for search space since it will be automatically generated according to the annotation in trial code.
* The following HPO tuners can be adapted to tune this search space: TPE, Random, Anneal, Evolution, Grid Search,
Hyperband and BOHB.
* For detailed usage, please refer to [General NAS Interfaces][1].

## Search Space Types Supported by Each Tuner

@@ -105,5 +99,3 @@ Known Limitations:
* Only Random Search/TPE/Anneal/Evolution tuner supports nested search space

* We do not support nested search space "Hyper Parameter" in visualization now, the enhancement is being considered in [#1110](https://github.com/microsoft/nni/issues/1110), any suggestions or discussions or contributions are warmly welcomed

[1]: ../AdvancedFeature/GeneralNasInterfaces.md
2 changes: 0 additions & 2 deletions docs/en_US/advanced.rst
@@ -3,5 +3,3 @@ Advanced Features

.. toctree::
MultiPhase<./AdvancedFeature/MultiPhase>
AdvancedNas<./AdvancedFeature/AdvancedNas>
NAS Programming Interface<./AdvancedFeature/GeneralNasInterfaces>
3 changes: 3 additions & 0 deletions examples/nas/.gitignore
@@ -0,0 +1,3 @@
data
checkpoints
runs
53 changes: 53 additions & 0 deletions examples/nas/darts/datasets.py
@@ -0,0 +1,53 @@
import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10


class Cutout(object):
"""Zero out a random square patch of side `length` in an image tensor (standard cutout augmentation)."""
def __init__(self, length):
self.length = length

def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)

y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)

mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask

return img


def get_dataset(cls, cutout_length=0):
MEAN = [0.49139968, 0.48215827, 0.44653124]
STD = [0.24703233, 0.24348505, 0.26158768]
transf = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip()
]
normalize = [
transforms.ToTensor(),
transforms.Normalize(MEAN, STD)
]
cutout = []
if cutout_length > 0:
cutout.append(Cutout(cutout_length))

train_transform = transforms.Compose(transf + normalize + cutout)
valid_transform = transforms.Compose(normalize)

if cls == "cifar10":
dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
else:
raise NotImplementedError
return dataset_train, dataset_valid
157 changes: 157 additions & 0 deletions examples/nas/darts/model.py
@@ -0,0 +1,157 @@
import torch
import torch.nn as nn

import ops
from nni.nas.pytorch import mutables


class AuxiliaryHead(nn.Module):
""" Auxiliary head in 2/3 place of network to let the gradient flow well """

def __init__(self, input_size, C, n_classes):
""" assuming input size 7x7 or 8x8 """
assert input_size in [7, 8]
super().__init__()
self.net = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False), # 2x2 out
nn.Conv2d(C, 128, kernel_size=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, kernel_size=2, bias=False), # 1x1 out
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.linear = nn.Linear(768, n_classes)

def forward(self, x):
out = self.net(x)
out = out.view(out.size(0), -1) # flatten
logits = self.linear(out)
return logits


class Node(nn.Module):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
super().__init__()
self.ops = nn.ModuleList()
choice_keys = []
for i in range(num_prev_nodes):
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(
mutables.LayerChoice(
[
ops.PoolBN('max', channels, 3, stride, 1, affine=False),
ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
ops.SepConv(channels, channels, 3, stride, 1, affine=False),
ops.SepConv(channels, channels, 5, stride, 2, affine=False),
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
],
key=choice_keys[-1]))
self.drop_path = ops.DropPath_()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))

def forward(self, prev_nodes):
assert len(self.ops) == len(prev_nodes)
out = [op(node) for op, node in zip(self.ops, prev_nodes)]
out = [self.drop_path(o) if o is not None else None for o in out]
return self.input_switch(out)


class Cell(nn.Module):

def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
super().__init__()
self.reduction = reduction
self.n_nodes = n_nodes

# If previous cell is reduction cell, current input size does not match with
# output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
if reduction_p:
self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
else:
self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

# generate dag
self.mutable_ops = nn.ModuleList()
for depth in range(2, self.n_nodes + 2):
self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
depth, channels, 2 if reduction else 0))

def forward(self, s0, s1):
# s0, s1 are the outputs of previous previous cell and previous cell, respectively.
tensors = [self.preproc0(s0), self.preproc1(s1)]
for node in self.mutable_ops:
cur_tensor = node(tensors)
tensors.append(cur_tensor)

output = torch.cat(tensors[2:], dim=1)
return output


class CNN(nn.Module):

def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
stem_multiplier=3, auxiliary=False):
super().__init__()
self.in_channels = in_channels
self.channels = channels
self.n_classes = n_classes
self.n_layers = n_layers
self.aux_pos = 2 * n_layers // 3 if auxiliary else -1

c_cur = stem_multiplier * self.channels
self.stem = nn.Sequential(
nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
nn.BatchNorm2d(c_cur)
)

# for the first cell, stem is used for both s0 and s1
# [!] channels_pp and channels_p are output channel sizes, but c_cur is the input channel size.
channels_pp, channels_p, c_cur = c_cur, c_cur, channels

self.cells = nn.ModuleList()
reduction_p, reduction = False, False
for i in range(n_layers):
reduction_p, reduction = reduction, False
# Reduce feature map size and double channels at 1/3 and 2/3 of the total layers.
if i in [n_layers // 3, 2 * n_layers // 3]:
c_cur *= 2
reduction = True

cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
self.cells.append(cell)
c_cur_out = c_cur * n_nodes
channels_pp, channels_p = channels_p, c_cur_out

if i == self.aux_pos:
self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)

self.gap = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(channels_p, n_classes)

def forward(self, x):
s0 = s1 = self.stem(x)

aux_logits = None
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1)
if i == self.aux_pos and self.training:
aux_logits = self.aux_head(s1)

out = self.gap(s1)
out = out.view(out.size(0), -1) # flatten
logits = self.linear(out)

if aux_logits is not None:
return logits, aux_logits
return logits

def drop_path_prob(self, p):
for module in self.modules():
if isinstance(module, ops.DropPath_):
module.p = p