diff --git a/docs/en_US/AdvancedFeature/MultiPhase.md b/docs/en_US/AdvancedFeature/MultiPhase.md index c9727bcdcc..4cdb3a7a99 100644 --- a/docs/en_US/AdvancedFeature/MultiPhase.md +++ b/docs/en_US/AdvancedFeature/MultiPhase.md @@ -79,7 +79,7 @@ With this information, the tuner could know which trial is requesting a configur ### Tuners support multi-phase experiments: -[TPE](../Tuner/HyperoptTuner.md), [Random](../Tuner/HyperoptTuner.md), [Anneal](../Tuner/HyperoptTuner.md), [Evolution](../Tuner/EvolutionTuner.md), [SMAC](../Tuner/SmacTuner.md), [NetworkMorphism](../Tuner/NetworkmorphismTuner.md), [MetisTuner](../Tuner/MetisTuner.md), [BOHB](../Tuner/BohbAdvisor.md), [Hyperband](../Tuner/HyperbandAdvisor.md), [ENAS tuner](https://github.com/countif/enas_nni/blob/master/nni/examples/tuners/enas/nni_controller_ptb.py). +[TPE](../Tuner/HyperoptTuner.md), [Random](../Tuner/HyperoptTuner.md), [Anneal](../Tuner/HyperoptTuner.md), [Evolution](../Tuner/EvolutionTuner.md), [SMAC](../Tuner/SmacTuner.md), [NetworkMorphism](../Tuner/NetworkmorphismTuner.md), [MetisTuner](../Tuner/MetisTuner.md), [BOHB](../Tuner/BohbAdvisor.md), [Hyperband](../Tuner/HyperbandAdvisor.md). ### Training services support multi-phase experiment: [Local Machine](../TrainingService/LocalMode.md), [Remote Servers](../TrainingService/RemoteMachineMode.md), [OpenPAI](../TrainingService/PaiMode.md) diff --git a/docs/en_US/NAS/Overview.md b/docs/en_US/NAS/Overview.md index 92b06b413f..4e48483df3 100644 --- a/docs/en_US/NAS/Overview.md +++ b/docs/en_US/NAS/Overview.md @@ -1,62 +1,77 @@ -# Neural Architecture Search (NAS) on NNI - -Automatic neural architecture search is taking an increasingly important role on finding better models. Recent research works have proved the feasibility of automatic NAS, and also found some models that could beat manually designed and tuned models. Some of representative works are [NASNet][2], [ENAS][1], [DARTS][3], [Network Morphism][4], and [Evolution][5]. There are new innovations keeping emerging. - -However, it takes great efforts to implement NAS algorithms, and it is hard to reuse code base of existing algorithms in new one. To facilitate NAS innovations (e.g., design and implement new NAS models, compare different NAS models side-by-side), an easy-to-use and flexible programming interface is crucial. - -With this motivation, our ambition is to provide a unified architecture in NNI, to accelerate innovations on NAS, and apply state-of-art algorithms on real world problems faster. - -## Supported algorithms - -NNI supports below NAS algorithms now, and being adding more. User can reproduce an algorithm, or use it on owned dataset. we also encourage user to implement other algorithms with [NNI API](#use-nni-api), to benefit more people. - -Note, these algorithms run standalone without nnictl, and supports PyTorch only. - -### DARTS - -The main contribution of [DARTS: Differentiable Architecture Search][3] on algorithm is to introduce a novel algorithm for differentiable network architecture search on bilevel optimization. - -#### Usage - -```bash -### In case NNI code is not cloned. -git clone https://github.com/Microsoft/nni.git - -cd examples/nas/darts -python search.py -``` - -### P-DARTS - -[Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation](https://arxiv.org/abs/1904.12760) bases on DARTS(#DARTS). 
It main contribution on algorithm is to introduce an efficient algorithm which allows the depth of searched architectures to grow gradually during the training procedure.
-
-#### Usage
-
-```bash
-### In case NNI code is not cloned.
-git clone https://github.com/Microsoft/nni.git
-
-cd examples/nas/pdarts
-python main.py
-```
-
-## Use NNI API
-
-NOTE, we are trying to support various NAS algorithms with unified programming interface, and it's in very experimental stage. It means the current programing interface may be updated significantly.
-
-*previous [NAS annotation](../AdvancedFeature/GeneralNasInterfaces.md) interface will be deprecated soon.*
-
-### Programming interface
-
-The programming interface of designing and searching a model is often demanded in two scenarios.
-
-1. When designing a neural network, there may be multiple operation choices on a layer, sub-model, or connection, and it's undetermined which one or combination performs best. So it needs an easy way to express the candidate layers or sub-models.
-2. When applying NAS on a neural network, it needs an unified way to express the search space of architectures, so that it doesn't need to update trial code for different searching algorithms.
-
-NNI proposed API is [here](https://github.com/microsoft/nni/tree/dev-nas-refactor/src/sdk/pynni/nni/nas/pytorch). And [here](https://github.com/microsoft/nni/tree/dev-nas-refactor/examples/nas/darts) is an example of NAS implementation, which bases on NNI proposed interface.
-
-[1]: https://arxiv.org/abs/1802.03268
-[2]: https://arxiv.org/abs/1707.07012
-[3]: https://arxiv.org/abs/1806.09055
-[4]: https://arxiv.org/abs/1806.10282
-[5]: https://arxiv.org/abs/1703.01041
+# Neural Architecture Search (NAS) on NNI
+
+Automatic neural architecture search is taking an increasingly important role in finding better models. Recent research has proved the feasibility of automatic NAS and has found models that beat manually designed and tuned ones. Representative works include [NASNet][2], [ENAS][1], [DARTS][3], [Network Morphism][4], and [Evolution][5]. New innovations keep emerging.
+
+However, it takes great effort to implement NAS algorithms, and it is hard to reuse the code base of an existing algorithm in a new one. To facilitate NAS innovations (e.g., designing and implementing new NAS models, comparing different NAS models side by side), an easy-to-use and flexible programming interface is crucial.
+
+With this motivation, our ambition is to provide a unified architecture in NNI, to accelerate innovation on NAS, and to apply state-of-the-art algorithms to real-world problems faster.
+
+## Supported algorithms
+
+NNI currently supports the NAS algorithms listed below, and more are being added. Users can reproduce an algorithm or run it on their own dataset. We also encourage users to implement other algorithms with the [NNI API](#use-nni-api), to benefit more people.
+
+Note: these algorithms run standalone without nnictl and support PyTorch only.
+
+### Dependencies
+
+* Install the latest NNI
+* PyTorch 1.2+
+* git
+
+### DARTS
+
+The main contribution of [DARTS: Differentiable Architecture Search][3] is a novel algorithm for differentiable network architecture search based on bilevel optimization.
+
+#### Usage
+
+```bash
+# In case the NNI code is not cloned yet. If it is already cloned, skip this line and enter the code folder.
+git clone https://github.com/Microsoft/nni.git
+
+# search for the best architecture
+cd examples/nas/darts
+python3 search.py
+
+# train the best architecture
+python3 retrain.py --arc-checkpoint ./checkpoints/epoch_49.json
+```
+
+### P-DARTS
+
+[Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation](https://arxiv.org/abs/1904.12760) is based on [DARTS](#darts). Its main contribution is an efficient algorithm that allows the depth of the searched architecture to grow gradually during the training procedure.
+
+#### Usage
+
+```bash
+# In case the NNI code is not cloned yet. If it is already cloned, skip this line and enter the code folder.
+git clone https://github.com/Microsoft/nni.git
+
+# search for the best architecture
+cd examples/nas/pdarts
+python3 search.py
+
+# train the best architecture; this is the same process as DARTS.
+cd examples/nas/darts
+python3 retrain.py --arc-checkpoint ./checkpoints/epoch_2.json
+```
+
+## Use NNI API
+
+NOTE: we are trying to support various NAS algorithms with a unified programming interface, and it is in a very experimental stage. This means the current programming interface may change significantly.
+
+*The previous [NAS annotation](../AdvancedFeature/GeneralNasInterfaces.md) interface will be deprecated soon.*
+
+### Programming interface
+
+A programming interface for designing and searching a model is often demanded in two scenarios.
+
+1. When designing a neural network, there may be multiple operation choices on a layer, sub-model, or connection, and it is undetermined which one or which combination performs best. So, an easy way to express the candidate layers or sub-models is needed.
+2. When applying NAS to a neural network, a unified way to express the architecture search space is needed, so that trial code does not have to be updated for different search algorithms.
+
+The proposed NNI API is [here](https://github.com/microsoft/nni/tree/master/src/sdk/pynni/nni/nas/pytorch), and [here](https://github.com/microsoft/nni/tree/master/examples/nas/darts) is an example NAS implementation based on this interface.
+
+[1]: https://arxiv.org/abs/1802.03268
+[2]: https://arxiv.org/abs/1707.07012
+[3]: https://arxiv.org/abs/1806.09055
+[4]: https://arxiv.org/abs/1806.10282
+[5]: https://arxiv.org/abs/1703.01041
diff --git a/docs/en_US/Tutorial/SearchSpaceSpec.md b/docs/en_US/Tutorial/SearchSpaceSpec.md
index fd1781716f..eb5d39315c 100644
--- a/docs/en_US/Tutorial/SearchSpaceSpec.md
+++ b/docs/en_US/Tutorial/SearchSpaceSpec.md
@@ -73,12 +73,6 @@ All types of sampling strategies and their parameter are listed here:
   * Which means the variable value is a value like `round(exp(normal(mu, sigma)) / q) * q`
   * Suitable for a discrete variable with respect to which the objective is smooth and gets smoother with the size of the variable, which is bounded from one side.
-* `{"_type": "mutable_layer", "_value": {mutable_layer_infomation}}`
-  * Type for [Neural Architecture Search Space][1]. Value is also a dictionary, which contains key-value pairs representing respectively name and search space of each mutable_layer.
-  * For now, users can only use this type of search space with annotation, which means that there is no need to define a json file for search space since it will be automatically generated according to the annotation in trial code.
-  * The following HPO tuners can be adapted to tune this search space: TPE, Random, Anneal, Evolution, Grid Search,
- Hyperband and BOHB.
- * For detailed usage, please refer to [General NAS Interfaces][1]. ## Search Space Types Supported by Each Tuner @@ -105,5 +99,3 @@ Known Limitations: * Only Random Search/TPE/Anneal/Evolution tuner supports nested search space * We do not support nested search space "Hyper Parameter" in visualization now, the enhancement is being considered in [#1110](https://github.com/microsoft/nni/issues/1110), any suggestions or discussions or contributions are warmly welcomed - -[1]: ../AdvancedFeature/GeneralNasInterfaces.md diff --git a/docs/en_US/advanced.rst b/docs/en_US/advanced.rst index d9192cc869..e38f634969 100644 --- a/docs/en_US/advanced.rst +++ b/docs/en_US/advanced.rst @@ -3,5 +3,3 @@ Advanced Features .. toctree:: MultiPhase<./AdvancedFeature/MultiPhase> - AdvancedNas<./AdvancedFeature/AdvancedNas> - NAS Programming Interface<./AdvancedFeature/GeneralNasInterfaces> \ No newline at end of file diff --git a/examples/nas/darts/retrain.py b/examples/nas/darts/retrain.py index 5c8fabf8d0..e3167376f9 100644 --- a/examples/nas/darts/retrain.py +++ b/examples/nas/darts/retrain.py @@ -1,4 +1,5 @@ import logging +import time from argparse import ArgumentParser import torch @@ -10,8 +11,17 @@ from nni.nas.pytorch.fixed import apply_fixed_architecture from nni.nas.pytorch.utils import AverageMeter -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +logger = logging.getLogger() + +fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' +logging.Formatter.converter = time.localtime +formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p') + +std_out_info = logging.StreamHandler() +std_out_info.setFormatter(formatter) +logger.setLevel(logging.INFO) +logger.addHandler(std_out_info) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/examples/nas/darts/search.py b/examples/nas/darts/search.py index 02c720a60c..d9bdf0c7b5 100644 --- a/examples/nas/darts/search.py +++ b/examples/nas/darts/search.py @@ -1,14 +1,27 @@ +import logging +import time from argparse import ArgumentParser -import datasets import torch import torch.nn as nn +import datasets from model import CNN -from nni.nas.pytorch.callbacks import LearningRateScheduler, ArchitectureCheckpoint +from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint, + LearningRateScheduler) from nni.nas.pytorch.darts import DartsTrainer from utils import accuracy +logger = logging.getLogger() + +fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' +logging.Formatter.converter = time.localtime +formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p') + +std_out_info = logging.StreamHandler() +std_out_info.setFormatter(formatter) +logger.setLevel(logging.INFO) +logger.addHandler(std_out_info) if __name__ == "__main__": parser = ArgumentParser("darts") diff --git a/examples/nas/enas/search.py b/examples/nas/enas/search.py index 35bc930333..6fade75164 100644 --- a/examples/nas/enas/search.py +++ b/examples/nas/enas/search.py @@ -1,3 +1,5 @@ +import logging +import time from argparse import ArgumentParser import torch @@ -10,6 +12,17 @@ from nni.nas.pytorch.callbacks import LearningRateScheduler, ArchitectureCheckpoint from utils import accuracy, reward_accuracy +logger = logging.getLogger() + +fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' +logging.Formatter.converter = time.localtime +formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p') + +std_out_info = logging.StreamHandler() +std_out_info.setFormatter(formatter) 
+logger.setLevel(logging.INFO) +logger.addHandler(std_out_info) + if __name__ == "__main__": parser = ArgumentParser("enas") parser.add_argument("--batch-size", default=128, type=int) diff --git a/examples/nas/pdarts/datasets.py b/examples/nas/pdarts/datasets.py deleted file mode 100644 index 8fe0ab0fbf..0000000000 --- a/examples/nas/pdarts/datasets.py +++ /dev/null @@ -1,25 +0,0 @@ -from torchvision import transforms -from torchvision.datasets import CIFAR10 - - -def get_dataset(cls): - MEAN = [0.49139968, 0.48215827, 0.44653124] - STD = [0.24703233, 0.24348505, 0.26158768] - transf = [ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip() - ] - normalize = [ - transforms.ToTensor(), - transforms.Normalize(MEAN, STD) - ] - - train_transform = transforms.Compose(transf + normalize) - valid_transform = transforms.Compose(normalize) - - if cls == "cifar10": - dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform) - dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform) - else: - raise NotImplementedError - return dataset_train, dataset_valid diff --git a/examples/nas/pdarts/main.py b/examples/nas/pdarts/main.py deleted file mode 100644 index 68a59c8856..0000000000 --- a/examples/nas/pdarts/main.py +++ /dev/null @@ -1,65 +0,0 @@ -from argparse import ArgumentParser - -import datasets -import torch -import torch.nn as nn -import nni.nas.pytorch as nas -from nni.nas.pytorch.pdarts import PdartsTrainer -from nni.nas.pytorch.darts import CnnNetwork, CnnCell - - -def accuracy(output, target, topk=(1,)): - """ Computes the precision@k for the specified values of k """ - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - # one-hot case - if target.ndimension() > 1: - target = target.max(1)[1] - - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = dict() - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) - res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item() - return res - - -if __name__ == "__main__": - parser = ArgumentParser("darts") - parser.add_argument("--layers", default=5, type=int) - parser.add_argument('--add_layers', action='append', - default=[0, 6, 12], help='add layers') - parser.add_argument("--nodes", default=4, type=int) - parser.add_argument("--batch-size", default=128, type=int) - parser.add_argument("--log-frequency", default=1, type=int) - args = parser.parse_args() - - dataset_train, dataset_valid = datasets.get_dataset("cifar10") - - def model_creator(layers, n_nodes): - model = CnnNetwork(3, 16, 10, layers, n_nodes=n_nodes, cell_type=CnnCell) - loss = nn.CrossEntropyLoss() - - model_optim = torch.optim.SGD(model.parameters(), 0.025, - momentum=0.9, weight_decay=3.0E-4) - n_epochs = 50 - lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, n_epochs, eta_min=0.001) - return model, loss, model_optim, lr_scheduler - - trainer = PdartsTrainer(model_creator, - metrics=lambda output, target: accuracy(output, target, topk=(1,)), - num_epochs=50, - pdarts_num_layers=[0, 6, 12], - pdarts_num_to_drop=[3, 2, 2], - dataset_train=dataset_train, - dataset_valid=dataset_valid, - layers=args.layers, - n_nodes=args.nodes, - batch_size=args.batch_size, - log_frequency=args.log_frequency) - trainer.train() - trainer.export() diff --git a/examples/nas/pdarts/search.py b/examples/nas/pdarts/search.py new file mode 100644 index 0000000000..5d38fda0db --- /dev/null +++ 
b/examples/nas/pdarts/search.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +import sys +import time +from argparse import ArgumentParser + +import torch +import torch.nn as nn + +from nni.nas.pytorch.callbacks import ArchitectureCheckpoint +from nni.nas.pytorch.pdarts import PdartsTrainer + +# prevent it to be reordered. +if True: + sys.path.append('../darts') + from utils import accuracy + from model import CNN + import datasets + +logger = logging.getLogger() + +fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s' +logging.Formatter.converter = time.localtime +formatter = logging.Formatter(fmt, '%m/%d/%Y, %I:%M:%S %p') + +std_out_info = logging.StreamHandler() +std_out_info.setFormatter(formatter) +logger.setLevel(logging.INFO) +logger.addHandler(std_out_info) + +if __name__ == "__main__": + parser = ArgumentParser("pdarts") + parser.add_argument('--add_layers', action='append', + default=[0, 6, 12], help='add layers') + parser.add_argument("--nodes", default=4, type=int) + parser.add_argument("--layers", default=5, type=int) + parser.add_argument("--batch-size", default=64, type=int) + parser.add_argument("--log-frequency", default=1, type=int) + parser.add_argument("--epochs", default=50, type=int) + args = parser.parse_args() + + logger.info("loading data") + dataset_train, dataset_valid = datasets.get_dataset("cifar10") + + def model_creator(layers): + model = CNN(32, 3, 16, 10, layers, n_nodes=args.nodes) + criterion = nn.CrossEntropyLoss() + + optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001) + + return model, criterion, optim, lr_scheduler + + logger.info("initializing trainer") + trainer = PdartsTrainer(model_creator, + layers=args.layers, + metrics=lambda output, target: accuracy(output, target, topk=(1,)), + pdarts_num_layers=[0, 6, 12], + pdarts_num_to_drop=[3, 2, 2], + num_epochs=args.epochs, + dataset_train=dataset_train, + dataset_valid=dataset_valid, + batch_size=args.batch_size, + log_frequency=args.log_frequency, + callbacks=[ArchitectureCheckpoint("./checkpoints")]) + logger.info("training") + trainer.train() diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py deleted file mode 100644 index 69dc28e8f0..0000000000 --- a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_cell.py +++ /dev/null @@ -1,69 +0,0 @@ - -import torch -import torch.nn as nn - -import nni.nas.pytorch as nas -from nni.nas.pytorch.modules import RankedModule - -from .cnn_ops import OPS, PRIMITIVES, FactorizedReduce, StdConv - - -class CnnCell(RankedModule): - """ - Cell for search. - """ - - def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction): - """ - Initialization a search cell. - - Parameters - ---------- - n_nodes: int - Number of nodes in current DAG. - channels_pp: int - Number of output channels from previous previous cell. - channels_p: int - Number of output channels from previous cell. - channels: int - Number of channels that will be used in the current DAG. - reduction_p: bool - Flag for whether the previous cell is reduction cell or not. - reduction: bool - Flag for whether the current cell is reduction cell or not. 
- """ - super(CnnCell, self).__init__(rank=1, reduction=reduction) - self.n_nodes = n_nodes - - # If previous cell is reduction cell, current input size does not match with - # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing. - if reduction_p: - self.preproc0 = FactorizedReduce(channels_pp, channels, affine=False) - else: - self.preproc0 = StdConv(channels_pp, channels, 1, 1, 0, affine=False) - self.preproc1 = StdConv(channels_p, channels, 1, 1, 0, affine=False) - - # generate dag - self.mutable_ops = nn.ModuleList() - for depth in range(self.n_nodes): - self.mutable_ops.append(nn.ModuleList()) - for i in range(2 + depth): # include 2 input nodes - # reduction should be used only for input node - stride = 2 if reduction and i < 2 else 1 - m_ops = [] - for primitive in PRIMITIVES: - op = OPS[primitive](channels, stride, False) - m_ops.append(op) - op = nas.mutables.LayerChoice(m_ops, key="r{}_d{}_i{}".format(reduction, depth, i)) - self.mutable_ops[depth].append(op) - - def forward(self, s0, s1): - # s0, s1 are the outputs of previous previous cell and previous cell, respectively. - tensors = [self.preproc0(s0), self.preproc1(s1)] - for ops in self.mutable_ops: - assert len(ops) == len(tensors) - cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors)) - tensors.append(cur_tensor) - - output = torch.cat(tensors[2:], dim=1) - return output diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py deleted file mode 100644 index d126e3353e..0000000000 --- a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_network.py +++ /dev/null @@ -1,73 +0,0 @@ - -import torch.nn as nn - -from .cnn_cell import CnnCell - - -class CnnNetwork(nn.Module): - """ - Search CNN model - """ - - def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3, cell_type=CnnCell): - """ - Initializing a search channelsNN. - - Parameters - ---------- - in_channels: int - Number of channels in images. - channels: int - Number of channels used in the network. - n_classes: int - Number of classes. - n_layers: int - Number of cells in the whole network. - n_nodes: int - Number of nodes in a cell. - stem_multiplier: int - Multiplier of channels in STEM. - """ - super().__init__() - self.in_channels = in_channels - self.channels = channels - self.n_classes = n_classes - self.n_layers = n_layers - - c_cur = stem_multiplier * self.channels - self.stem = nn.Sequential( - nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False), - nn.BatchNorm2d(c_cur) - ) - - # for the first cell, stem is used for both s0 and s1 - # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size. - channels_pp, channels_p, c_cur = c_cur, c_cur, channels - - self.cells = nn.ModuleList() - reduction_p, reduction = False, False - for i in range(n_layers): - reduction_p, reduction = reduction, False - # Reduce featuremap size and double channels in 1/3 and 2/3 layer. 
- if i in [n_layers // 3, 2 * n_layers // 3]: - c_cur *= 2 - reduction = True - - cell = cell_type(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction) - self.cells.append(cell) - c_cur_out = c_cur * n_nodes - channels_pp, channels_p = channels_p, c_cur_out - - self.gap = nn.AdaptiveAvgPool2d(1) - self.linear = nn.Linear(channels_p, n_classes) - - def forward(self, x): - s0 = s1 = self.stem(x) - - for cell in self.cells: - s0, s1 = s1, cell(s0, s1) - - out = self.gap(s1) - out = out.view(out.size(0), -1) # flatten - logits = self.linear(out) - return logits diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py b/src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py deleted file mode 100644 index 02b4a3a94c..0000000000 --- a/src/sdk/pynni/nni/nas/pytorch/darts/cnn_ops.py +++ /dev/null @@ -1,189 +0,0 @@ -import torch -import torch.nn as nn - -PRIMITIVES = [ - 'none', - 'max_pool_3x3', - 'avg_pool_3x3', - 'skip_connect', # identity - 'sep_conv_3x3', - 'sep_conv_5x5', - 'dil_conv_3x3', - 'dil_conv_5x5', -] - -OPS = { - 'none': lambda C, stride, affine: Zero(stride), - 'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine), - 'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine), - 'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), - 'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), - 'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), - 'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), - 'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5 - 'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9 - 'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine) -} - - -def drop_path_(x, drop_prob, training): - if training and drop_prob > 0.: - keep_prob = 1. - drop_prob - # per data point mask; assuming x in cuda. - mask = torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob) - x.div_(keep_prob).mul_(mask) - - return x - - -class DropPath_(nn.Module): - def __init__(self, p=0.): - """ [!] DropPath is inplace module - Args: - p: probability of an path to be zeroed. 
- """ - super().__init__() - self.p = p - - def extra_repr(self): - return 'p={}, inplace'.format(self.p) - - def forward(self, x): - drop_path_(x, self.p, self.training) - - return x - - -class PoolBN(nn.Module): - """ - AvgPool or MaxPool - BN - """ - - def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True): - """ - Args: - pool_type: 'max' or 'avg' - """ - super().__init__() - if pool_type.lower() == 'max': - self.pool = nn.MaxPool2d(kernel_size, stride, padding) - elif pool_type.lower() == 'avg': - self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False) - else: - raise ValueError() - - self.bn = nn.BatchNorm2d(C, affine=affine) - - def forward(self, x): - out = self.pool(x) - out = self.bn(out) - return out - - -class StdConv(nn.Module): - """ Standard conv - ReLU - Conv - BN - """ - - def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): - super().__init__() - self.net = nn.Sequential( - nn.ReLU(), - nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False), - nn.BatchNorm2d(C_out, affine=affine) - ) - - def forward(self, x): - return self.net(x) - - -class FacConv(nn.Module): - """ Factorized conv - ReLU - Conv(Kx1) - Conv(1xK) - BN - """ - - def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True): - super().__init__() - self.net = nn.Sequential( - nn.ReLU(), - nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False), - nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False), - nn.BatchNorm2d(C_out, affine=affine) - ) - - def forward(self, x): - return self.net(x) - - -class DilConv(nn.Module): - """ (Dilated) depthwise separable conv - ReLU - (Dilated) depthwise separable - Pointwise - BN - If dilation == 2, 3x3 conv => 5x5 receptive field - 5x5 conv => 9x9 receptive field - """ - - def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): - super().__init__() - self.net = nn.Sequential( - nn.ReLU(), - nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, bias=False), - nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False), - nn.BatchNorm2d(C_out, affine=affine) - ) - - def forward(self, x): - return self.net(x) - - -class SepConv(nn.Module): - """ Depthwise separable conv - DilConv(dilation=1) * 2 - """ - - def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): - super().__init__() - self.net = nn.Sequential( - DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine), - DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine) - ) - - def forward(self, x): - return self.net(x) - - -class Identity(nn.Module): - - def forward(self, x): - return x - - -class Zero(nn.Module): - def __init__(self, stride): - super().__init__() - self.stride = stride - - def forward(self, x): - if self.stride == 1: - return x * 0. - - # re-sizing by stride - return x[:, :, ::self.stride, ::self.stride] * 0. - - -class FactorizedReduce(nn.Module): - """ - Reduce feature map size by factorized pointwise(stride=2). 
- """ - - def __init__(self, C_in, C_out, affine=True): - super().__init__() - self.relu = nn.ReLU() - self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) - self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) - self.bn = nn.BatchNorm2d(C_out, affine=affine) - - def forward(self, x): - x = self.relu(x) - out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1) - out = self.bn(out) - return out diff --git a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py index c6b29de04a..6392962111 100644 --- a/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/darts/trainer.py @@ -1,12 +1,17 @@ import copy +import logging import torch from torch import nn as nn from nni.nas.pytorch.trainer import Trainer from nni.nas.pytorch.utils import AverageMeterGroup + from .mutator import DartsMutator +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + class DartsTrainer(Trainer): def __init__(self, model, loss, metrics, @@ -72,7 +77,8 @@ def train_one_epoch(self, epoch): metrics["loss"] = loss.item() meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: - print("Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, step, len(self.train_loader), meters)) + logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch+1, + self.num_epochs, step+1, len(self.train_loader), meters) def validate_one_epoch(self, epoch): self.model.eval() @@ -86,7 +92,8 @@ def validate_one_epoch(self, epoch): metrics = self.metrics(logits, y) meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: - print("Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, step, len(self.valid_loader), meters)) + logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch+1, + self.num_epochs, step+1, len(self.test_loader), meters) def _unrolled_backward(self, trn_X, trn_y, val_X, val_y, backup_model, lr): """ diff --git a/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py index 1ed302ac7b..49052d6b08 100644 --- a/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/enas/trainer.py @@ -1,3 +1,4 @@ +import logging import torch import torch.optim as optim @@ -6,6 +7,10 @@ from .mutator import EnasMutator +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + class EnasTrainer(Trainer): def __init__(self, model, loss, metrics, reward_function, optimizer, num_epochs, dataset_train, dataset_valid, @@ -70,8 +75,8 @@ def train_one_epoch(self, epoch): meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: - print("Model Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, - step, len(self.train_loader), meters)) + logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch, + self.num_epochs, step, len(self.train_loader), meters) # Train sampler (mutator) self.model.eval() @@ -109,9 +114,8 @@ def train_one_epoch(self, epoch): self.mutator_optim.zero_grad() if self.log_frequency is not None and step % self.log_frequency == 0: - print("RL Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, - mutator_step // self.mutator_steps_aggregate, - self.mutator_steps, meters)) + logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch, self.num_epochs, + mutator_step // self.mutator_steps_aggregate, self.mutator_steps, meters) mutator_step += 1 if mutator_step >= total_mutator_steps: break diff --git 
a/src/sdk/pynni/nni/nas/pytorch/modules.py b/src/sdk/pynni/nni/nas/pytorch/modules.py deleted file mode 100644 index 6570220e13..0000000000 --- a/src/sdk/pynni/nni/nas/pytorch/modules.py +++ /dev/null @@ -1,9 +0,0 @@ - -from torch import nn as nn - - -class RankedModule(nn.Module): - def __init__(self, rank=None, reduction=False): - super(RankedModule, self).__init__() - self.rank = rank - self.reduction = reduction diff --git a/src/sdk/pynni/nni/nas/pytorch/mutables.py b/src/sdk/pynni/nni/nas/pytorch/mutables.py index 79cde1cf3f..4dbf514af8 100644 --- a/src/sdk/pynni/nni/nas/pytorch/mutables.py +++ b/src/sdk/pynni/nni/nas/pytorch/mutables.py @@ -1,7 +1,12 @@ +import logging + import torch.nn as nn from nni.nas.pytorch.utils import global_mutable_counting +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + class Mutable(nn.Module): """ @@ -20,7 +25,7 @@ def __init__(self, key=None): if key is not None: if not isinstance(key, str): key = str(key) - print("Warning: key \"{}\" is not string, converted to string.".format(key)) + logger.warning("Warning: key \"%s\" is not string, converted to string.", key) self._key = key else: self._key = self.__class__.__name__ + str(global_mutable_counting()) diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py index 27dd912ab3..d1d17764ba 100644 --- a/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/__init__.py @@ -1 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + from .trainer import PdartsTrainer diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py index da31b3cc69..5862e9714b 100644 --- a/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py @@ -1,8 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ import copy import numpy as np -import torch -from torch import nn as nn from torch.nn import functional as F from nni.nas.pytorch.darts import DartsMutator @@ -11,24 +12,27 @@ class PdartsMutator(DartsMutator): - def __init__(self, pdarts_epoch_index, pdarts_num_to_drop, switches=None): + def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches={}): self.pdarts_epoch_index = pdarts_epoch_index self.pdarts_num_to_drop = pdarts_num_to_drop - self.switches = switches + if switches is None: + self.switches = {} + else: + self.switches = switches - super(PdartsMutator, self).__init__() + super(PdartsMutator, self).__init__(model) - def before_build(self): - self.choices = nn.ParameterDict() - if self.switches is None: - self.switches = {} + for mutable in self.mutables: + if isinstance(mutable, LayerChoice): + + switches = self.switches.get(mutable.key, [True for j in range(mutable.length)]) + + for index in range(len(switches)-1, -1, -1): + if switches[index] == False: + del(mutable.choices[index]) + mutable.length -= 1 - def named_mutables(self, model): - key2module = dict() - for name, module in model.named_modules(): - if isinstance(module, LayerChoice): - key2module[module.key] = module - yield name, module, True + self.switches[mutable.key] = switches def drop_paths(self): for key in self.switches: @@ -49,22 +53,6 @@ def drop_paths(self): switches[idxs[idx]] = False return self.switches - def on_init_layer_choice(self, mutable: LayerChoice): - switches = self.switches.get( - mutable.key, [True for j in range(mutable.length)]) - - for index in range(len(switches)-1, -1, -1): - if switches[index] == False: - del(mutable.choices[index]) - mutable.length -= 1 - - self.switches[mutable.key] = switches - - self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length)) - - def on_calc_layer_choice_mask(self, mutable: LayerChoice): - return F.softmax(self.choices[mutable.key], dim=-1) - def get_min_k(self, input_in, k): index = [] for _ in range(k): diff --git a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py index d4ef2bbb8e..af31da08fc 100644 --- a/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py @@ -1,17 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+import logging +from nni.nas.pytorch.callbacks import LearningRateScheduler from nni.nas.pytorch.darts import DartsTrainer -from nni.nas.pytorch.trainer import Trainer +from nni.nas.pytorch.trainer import BaseTrainer from .mutator import PdartsMutator +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) -class PdartsTrainer(Trainer): - def __init__(self, model_creator, metrics, num_epochs, dataset_train, dataset_valid, - layers=5, n_nodes=4, pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2], - mutator=None, batch_size=64, workers=4, device=None, log_frequency=None): +class PdartsTrainer(BaseTrainer): + + def __init__(self, model_creator, layers, metrics, + num_epochs, dataset_train, dataset_valid, + pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2], + mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None): + super(PdartsTrainer, self).__init__() self.model_creator = model_creator self.layers = layers - self.n_nodes = n_nodes self.pdarts_num_layers = pdarts_num_layers self.pdarts_num_to_drop = pdarts_num_to_drop self.pdarts_epoch = len(pdarts_num_to_drop) @@ -25,29 +33,41 @@ def __init__(self, model_creator, metrics, num_epochs, dataset_train, dataset_va "device": device, "log_frequency": log_frequency } + self.callbacks = callbacks if callbacks is not None else [] def train(self): layers = self.layers - n_nodes = self.n_nodes switches = None for epoch in range(self.pdarts_epoch): layers = self.layers+self.pdarts_num_layers[epoch] - model, loss, model_optim, _ = self.model_creator( - layers, n_nodes) - mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches) # pylint: disable=too-many-function-args + model, criterion, optim, lr_scheduler = self.model_creator(layers) + self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches) + + for callback in self.callbacks: + callback.build(model, self.mutator, self) + callback.on_epoch_begin(epoch) + + darts_callbacks = [] + if lr_scheduler is not None: + darts_callbacks.append(LearningRateScheduler(lr_scheduler)) - self.trainer = DartsTrainer(model, loss=loss, optimizer=model_optim, - mutator=mutator, **self.darts_parameters) - print("start pdrats training %s..." 
% epoch) + self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim, + callbacks=darts_callbacks, **self.darts_parameters) + logger.info("start pdarts training %s...", epoch) self.trainer.train() - # with open('log/parameters_%d.txt' % epoch, "w") as f: - # f.write(str(model.parameters)) + switches = self.mutator.drop_paths() - switches = mutator.drop_paths() + for callback in self.callbacks: + callback.on_epoch_end(epoch) + + def validate(self): + self.model.validate() def export(self): - if (self.trainer is not None) and hasattr(self.trainer, "export"): - self.trainer.export() + self.mutator.export() + + def checkpoint(self): + raise NotImplementedError("Not implemented yet") diff --git a/src/sdk/pynni/nni/nas/pytorch/trainer.py b/src/sdk/pynni/nni/nas/pytorch/trainer.py index a4954a0747..9195631a60 100644 --- a/src/sdk/pynni/nni/nas/pytorch/trainer.py +++ b/src/sdk/pynni/nni/nas/pytorch/trainer.py @@ -7,6 +7,7 @@ from .base_trainer import BaseTrainer _logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) class TorchTensorEncoder(json.JSONEncoder): @@ -59,12 +60,12 @@ def train(self, validate=True): callback.on_epoch_begin(epoch) # training - print("Epoch {} Training".format(epoch)) + _logger.info("Epoch %d Training", epoch) self.train_one_epoch(epoch) if validate: # validation - print("Epoch {} Validating".format(epoch)) + _logger.info("Epoch %d Validating", epoch) self.validate_one_epoch(epoch) for callback in self.callbacks: diff --git a/src/sdk/pynni/nni/nas/utils.py b/src/sdk/pynni/nni/nas/utils.py deleted file mode 100644 index 5000946e7e..0000000000 --- a/src/sdk/pynni/nni/nas/utils.py +++ /dev/null @@ -1,49 +0,0 @@ -from collections import OrderedDict - -_counter = 0 - - -def global_mutable_counting(): - global _counter - _counter += 1 - return _counter - - -class AverageMeterGroup(object): - - def __init__(self): - self.meters = OrderedDict() - - def update(self, data): - for k, v in data.items(): - if k not in self.meters: - self.meters[k] = AverageMeter(k, ":4f") - self.meters[k].update(v) - - def __str__(self): - return " ".join(str(v) for _, v in self.meters.items()) - - -class AverageMeter(object): - """Computes and stores the average and current value""" - - def __init__(self, name, fmt=':f'): - self.name = name - self.fmt = fmt - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - def __str__(self): - fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' - return fmtstr.format(**self.__dict__)
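
As context for the "Programming interface" section of the updated Overview above: candidate operations for a layer are expressed as mutables, and a trainer such as `DartsTrainer` (see `examples/nas/darts/search.py` in this diff) searches over them. The sketch below only illustrates that idea; the toy model, channel sizes, and key name are illustrative assumptions and are not part of this change set. The `LayerChoice(candidates, key=...)` usage and the `nni.nas.pytorch.mutables` import path follow the code touched in this diff (e.g. the removed `cnn_cell.py`).

```python
import torch.nn as nn

from nni.nas.pytorch.mutables import LayerChoice


class ToyNet(nn.Module):
    """Hypothetical model used only to illustrate LayerChoice; not part of this PR."""

    def __init__(self):
        super().__init__()
        # Instead of hard-coding one operation, list the candidates.
        # The mutator attached by the NAS trainer decides which candidate(s) to activate.
        self.conv_choice = LayerChoice([
            nn.Conv2d(3, 16, 3, padding=1),
            nn.Conv2d(3, 16, 5, padding=2),
        ], key="toy_conv")
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.conv_choice(x)  # candidate selected/weighted by the mutator
        x = self.pool(x).view(x.size(0), -1)
        return self.fc(x)
```

A NAS trainer (for example `DartsTrainer(model, loss=..., metrics=..., optimizer=..., ...)`, as wired up in the new `search.py` examples) then owns the search loop and exports the chosen architecture.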