integrate c-darts nas algorithm #1955
Merged
Changes from 6 commits
Commits (22):
ab8fc7b integrate c-darts algorithm (penghouwen)
83a7aa4 revise cdarts2nni after code review (penghouwen)
4543d3f update readme (penghouwen)
28b579c new revisions (penghouwen)
5f66784 Merge remote-tracking branch 'upstream/master' into cdarts2nni
022faa1 fix code style
25778f5 fix code style
f959050 disable wrong import order
2e96bdf fix syntax issues (ultmaster)
386a4a9 update nas docs (ultmaster)
457be73 remove trailing whitespace in trainer (ultmaster)
2320cdd mock apex import (ultmaster)
2dcab31 elaborate documentation (ultmaster)
36882f0 update (ultmaster)
19eb69e update docs (ultmaster)
46c229d fix crossref (ultmaster)
bbb8873 fix typo (ultmaster)
3861109 update mutator and trainer (ultmaster)
cf0471f remove trailing whitespace (ultmaster)
7c87a7f fix dosctring format (ultmaster)
1a167ba add license (ultmaster)
5d39705 resolve comments in ops.py (ultmaster)
@@ -0,0 +1,34 @@

# CDARTS
## Introduction

CDARTS builds a cyclic feedback mechanism between the search and evaluation networks. First, the search network generates an initial topology for evaluation, so that the weights of the evaluation network can be optimized. Second, the architecture topology in the search network is further optimized by label supervision from the classification task, as well as by regularization from the evaluation network through feature distillation. Repeating this cycle jointly optimizes the search and evaluation networks, and thus allows the topology to evolve to fit the final evaluation network.
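The cycle can be pictured as an alternating optimization. The snippet below is a minimal sketch, not the NNI trainer: tiny stand-in networks replace the DARTS super-network and the discretized evaluation network, and logit-level KL distillation stands in for the feature distillation used in the paper; all names and hyperparameters here are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins: in CDARTS proper, search_net is the differentiable
# super-network and eval_net is the network built from its current topology.
search_net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                           nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))
eval_net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                         nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))
w_optim = torch.optim.SGD(search_net.parameters(), lr=0.05, momentum=0.9)
eval_optim = torch.optim.SGD(eval_net.parameters(), lr=0.1, momentum=0.9)

def interactive_loss(student_logits, teacher_logits, T=2.0):
    # Soft-target KL divergence; the paper distills intermediate features,
    # logits are used here only to keep the sketch short.
    return F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction='batchmean') * T * T

for step in range(2):  # each iteration is one search/evaluation cycle
    x, y = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))
    # 1) Optimize the evaluation network on the topology proposed by search.
    eval_optim.zero_grad()
    F.cross_entropy(eval_net(x), y).backward()
    eval_optim.step()
    # 2) Optimize the search network with label supervision plus
    #    distillation-based regularization from the evaluation network.
    w_optim.zero_grad()
    logits = search_net(x)
    loss = F.cross_entropy(logits, y) + interactive_loss(logits, eval_net(x).detach())
    loss.backward()
    w_optim.step()
```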
## Reproduction Results

This is CDARTS based on the NNI platform, which currently supports CIFAR10 search and retrain. ImageNet search and retrain should also be supported, and we provide the corresponding interfaces. Our reproduced results on NNI are slightly lower than the paper's, but much higher than those of the original DARTS. Here we show the results of three independent experiments on CIFAR10.

| Runs | Paper | NNI   |
| ---- | :---: | :---: |
| 1    | 97.52 | 97.44 |
| 2    | 97.53 | 97.48 |
| 3    | 97.58 | 97.56 |
## Examples

[Example code](https://github.com/microsoft/nni/tree/master/examples/nas/cdarts)

```bash
# Clone the NNI code if you have not already; otherwise skip this line and enter the code folder.
git clone https://github.com/Microsoft/nni.git

# Install apex for distributed training.
git clone https://github.com/NVIDIA/apex
cd apex
python setup.py install --cpp_ext --cuda_ext

# Search for the best architecture.
cd examples/nas/cdarts
bash run_search_cifar.sh

# Train the best architecture.
bash run_retrain_cifar.sh
```
@@ -0,0 +1,99 @@
import torch.nn as nn


class DistillHeadCIFAR(nn.Module):

    def __init__(self, C, size, num_classes, bn_affine=False):
        """assuming input size 8x8 or 16x16"""
        super(DistillHeadCIFAR, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(),
            nn.AvgPool2d(size, stride=2, padding=0, count_include_pad=False),  # image size = 2 x 2 / 6 x 6
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128, affine=bn_affine),
            nn.ReLU(),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768, affine=bn_affine),
            nn.ReLU()
        )
        self.classifier = nn.Linear(768, num_classes)
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.features(x)
        x = self.gap(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x


class DistillHeadImagenet(nn.Module):

    def __init__(self, C, size, num_classes, bn_affine=False):
        """assuming input size 7x7 or 14x14"""
        super(DistillHeadImagenet, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(),
            nn.AvgPool2d(size, stride=2, padding=0, count_include_pad=False),  # image size = 2 x 2 / 6 x 6
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128, affine=bn_affine),
            nn.ReLU(),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768, affine=bn_affine),
            nn.ReLU()
        )
        self.classifier = nn.Linear(768, num_classes)
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.features(x)
        x = self.gap(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x


class AuxiliaryHeadCIFAR(nn.Module):

    def __init__(self, C, size=5, num_classes=10):
        """assuming input size 8x8"""
        super(AuxiliaryHeadCIFAR, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),  # image size = 2 x 2
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x


class AuxiliaryHeadImageNet(nn.Module):

    def __init__(self, C, size=5, num_classes=1000):
        """assuming input size 7x7"""
        super(AuxiliaryHeadImageNet, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(size, stride=2, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            # NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
            # Commenting it out for consistency with the experiments in the paper.
            # nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x
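To make the expected tensor shapes concrete, here is a hypothetical smoke test for `DistillHeadCIFAR`; the values `C=256` and `size=6` are illustrative assumptions for an 8x8 feature map, not taken from the PR.

```python
import torch

# Illustrative only: an 8x8 map pooled by a 6x6 kernel at stride 2 leaves a
# 2x2 map, which the 1x1 and 2x2 convolutions reduce to a 768-d vector.
head = DistillHeadCIFAR(C=256, size=6, num_classes=10)
features = torch.randn(4, 256, 8, 8)   # (batch, channels, height, width)
logits = head(features)
print(logits.shape)                    # torch.Size([4, 10])
```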
@@ -0,0 +1,135 @@
""" Config class for search/retrain """ | ||
import argparse | ||
from functools import partial | ||
|
||
|
||
def get_parser(name): | ||
""" make default formatted parser """ | ||
parser = argparse.ArgumentParser(name, formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||
# print default value always | ||
parser.add_argument = partial(parser.add_argument, help=' ') | ||
return parser | ||
|
||
|
||
class BaseConfig(argparse.Namespace): | ||
def print_params(self, prtf=print): | ||
prtf("") | ||
prtf("Parameters:") | ||
for attr, value in sorted(vars(self).items()): | ||
prtf("{}={}".format(attr.upper(), value)) | ||
prtf("") | ||
|
||
def as_markdown(self): | ||
""" Return configs as markdown format """ | ||
text = "|name|value| \n|-|-| \n" | ||
for attr, value in sorted(vars(self).items()): | ||
text += "|{}|{}| \n".format(attr, value) | ||
|
||
return text | ||
|
||
|
||
class SearchConfig(BaseConfig): | ||
def build_parser(self): | ||
parser = get_parser("Search config") | ||
########### basic settings ############ | ||
parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100', 'imagenet']) | ||
parser.add_argument('--n_classes', type=int, default=10) | ||
parser.add_argument('--stem_multiplier', type=int, default=3) | ||
parser.add_argument('--init_channels', type=int, default=16) | ||
parser.add_argument('--data_dir', type=str, default='data/cifar', help='cifar dataset') | ||
parser.add_argument('--output_path', type=str, default='./outputs', help='') | ||
parser.add_argument('--batch_size', type=int, default=128, help='batch size') | ||
parser.add_argument('--log_frequency', type=int, default=10, help='print frequency') | ||
parser.add_argument('--seed', type=int, default=0, help='random seed') | ||
parser.add_argument('--workers', type=int, default=4, help='# of workers') | ||
parser.add_argument('--steps_per_epoch', type=int, default=None, help='how many steps per epoch, use None for one pass of dataset') | ||
|
||
########### learning rate ############ | ||
parser.add_argument('--w_lr', type=float, default=0.05, help='lr for weights') | ||
parser.add_argument('--w_momentum', type=float, default=0.9, help='momentum for weights') | ||
parser.add_argument('--w_weight_decay', type=float, default=3e-4, help='weight decay for weights') | ||
parser.add_argument('--grad_clip', type=float, default=5., help='gradient clipping for weights') | ||
parser.add_argument('--alpha_lr', type=float, default=6e-4, help='lr for alpha') | ||
parser.add_argument('--alpha_weight_decay', type=float, default=1e-3, help='weight decay for alpha') | ||
parser.add_argument('--nasnet_lr', type=float, default=0.1, help='lr of nasnet') | ||
|
||
########### alternate training ############ | ||
parser.add_argument('--epochs', type=int, default=32, help='# of search epochs') | ||
parser.add_argument('--warmup_epochs', type=int, default=2, help='# warmup epochs of super model') | ||
parser.add_argument('--loss_alpha', type=float, default=1, help='loss alpha') | ||
parser.add_argument('--loss_T', type=float, default=2, help='loss temperature') | ||
parser.add_argument('--interactive_type', type=str, default='kl', choices=['kl', 'smoothl1']) | ||
parser.add_argument('--sync_bn', action='store_true', default=False, help='whether to sync bn') | ||
parser.add_argument('--use_apex', action='store_true', default=False, help='whether to use apex') | ||
parser.add_argument('--regular_ratio', type=float, default=0.5, help='regular ratio') | ||
parser.add_argument('--regular_coeff', type=float, default=5, help='regular coefficient') | ||
parser.add_argument('--fix_head', action='store_true', default=False, help='whether to fix head') | ||
parser.add_argument('--share_module', action='store_true', default=False, help='whether to share stem and aux head') | ||
|
||
########### data augument ############ | ||
parser.add_argument('--aux_weight', type=float, default=0.4, help='auxiliary loss weight') | ||
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length') | ||
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path prob') | ||
parser.add_argument('--use_aa', action='store_true', default=False, help='whether to use aa') | ||
parser.add_argument('--mixup_alpha', default=1., type=float, help='mixup interpolation coefficient (default: 1)') | ||
|
||
########### distributed ############ | ||
parser.add_argument("--local_rank", default=0, type=int) | ||
parser.add_argument("--world_size", default=1, type=int) | ||
parser.add_argument('--dist_url', default='tcp://127.0.0.1:23456', type=str, help='url used to set up distributed training') | ||
parser.add_argument('--distributed', action='store_true', help='run model distributed mode') | ||
|
||
return parser | ||
|
||
def __init__(self): | ||
parser = self.build_parser() | ||
args = parser.parse_args() | ||
super().__init__(**vars(args)) | ||
|
||
|
||
class RetrainConfig(BaseConfig): | ||
def build_parser(self): | ||
parser = get_parser("Retrain config") | ||
parser.add_argument('--dataset', default="cifar10", choices=['cifar10', 'cifar100', 'imagenet']) | ||
parser.add_argument('--data_dir', type=str, default='data/cifar', help='cifar dataset') | ||
parser.add_argument('--output_path', type=str, default='./outputs', help='') | ||
parser.add_argument("--arc_checkpoint", default="epoch_02.json") | ||
parser.add_argument('--log_frequency', type=int, default=10, help='print frequency') | ||
|
||
########### model settings ############ | ||
parser.add_argument('--n_classes', type=int, default=10) | ||
parser.add_argument('--input_channels', type=int, default=3) | ||
parser.add_argument('--stem_multiplier', type=int, default=3) | ||
parser.add_argument('--batch_size', type=int, default=128, help='batch size') | ||
parser.add_argument('--eval_batch_size', type=int, default=500, help='batch size for validation') | ||
parser.add_argument('--lr', type=float, default=0.025, help='lr for weights') | ||
parser.add_argument('--momentum', type=float, default=0.9, help='momentum') | ||
parser.add_argument('--grad_clip', type=float, default=5., help='gradient clipping for weights') | ||
parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay') | ||
parser.add_argument('--epochs', type=int, default=600, help='# of training epochs') | ||
parser.add_argument('--warmup_epochs', type=int, default=5, help='# warmup') | ||
parser.add_argument('--init_channels', type=int, default=36) | ||
parser.add_argument('--layers', type=int, default=20, help='# of layers') | ||
parser.add_argument('--seed', type=int, default=0, help='random seed') | ||
parser.add_argument('--workers', type=int, default=4, help='# of workers') | ||
parser.add_argument('--aux_weight', type=float, default=0.4, help='auxiliary loss weight') | ||
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length') | ||
parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing') | ||
parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path prob') | ||
|
||
########### data augmentation ############ | ||
parser.add_argument('--use_aa', action='store_true', default=False, help='whether to use aa') | ||
parser.add_argument('--mixup_alpha', default=1., type=float, help='mixup interpolation coefficient') | ||
|
||
########### distributed ############ | ||
parser.add_argument("--local_rank", default=0, type=int) | ||
parser.add_argument("--world_size", default=1, type=int) | ||
parser.add_argument('--dist_url', default='tcp://127.0.0.1:23456', type=str, help='url used to set up distributed training') | ||
parser.add_argument('--distributed', action='store_true', help='run model distributed mode') | ||
|
||
return parser | ||
|
||
def __init__(self): | ||
parser = self.build_parser() | ||
args = parser.parse_args() | ||
super().__init__(**vars(args)) |
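Both config classes parse `sys.argv` at construction time, so a script only needs to instantiate one of them. A minimal sketch of the intended use follows; the `search.py` name and the flag values are assumptions for illustration, not taken from the PR.

```python
import sys

# Hypothetical invocation: simulate a command line, then read the parsed values.
sys.argv = ['search.py', '--batch_size', '64', '--epochs', '32', '--distributed']
config = SearchConfig()        # argparse runs inside __init__
print(config.batch_size)       # 64
print(config.distributed)      # True
config.print_params()          # plain-text dump of every parameter
print(config.as_markdown())    # the same parameters as a markdown table
```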
@@ -0,0 +1,108 @@
import numpy as np
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

from datasets.data_utils import CIFAR10Policy, Cutout
from datasets.data_utils import SubsetDistributedSampler


def data_transforms_cifar(config, cutout=False):
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

    if config.use_aa:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4, fill=128),
            transforms.RandomHorizontalFlip(), CIFAR10Policy(),
            transforms.ToTensor(),
            transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
        ])

    if cutout:
        train_transform.transforms.append(Cutout(config.cutout_length))

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
    return train_transform, valid_transform


def get_search_datasets(config):
    dataset = config.dataset.lower()
    if dataset == 'cifar10':
        dset_cls = dset.CIFAR10
        n_classes = 10
    elif dataset == 'cifar100':
        dset_cls = dset.CIFAR100
        n_classes = 100
    else:
        raise Exception("Unsupported dataset!")

    train_transform, valid_transform = data_transforms_cifar(config, cutout=False)
    train_data = dset_cls(root=config.data_dir, train=True, download=True, transform=train_transform)

    # During search, the training set is split in half: one half trains the
    # network weights, the other half validates the architecture parameters.
    num_train = len(train_data)
    indices = list(range(num_train))
    split_mid = int(np.floor(0.5 * num_train))

    if config.distributed:
        train_sampler = SubsetDistributedSampler(train_data, indices[:split_mid])
        valid_sampler = SubsetDistributedSampler(train_data, indices[split_mid:num_train])
    else:
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split_mid])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split_mid:num_train])

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=config.batch_size,
        sampler=train_sampler,
        pin_memory=False, num_workers=config.workers)

    valid_loader = torch.utils.data.DataLoader(
        train_data, batch_size=config.batch_size,
        sampler=valid_sampler,
        pin_memory=False, num_workers=config.workers)

    return [train_loader, valid_loader], [train_sampler, valid_sampler]


def get_augment_datasets(config):
    dataset = config.dataset.lower()
    if dataset == 'cifar10':
        dset_cls = dset.CIFAR10
    elif dataset == 'cifar100':
        dset_cls = dset.CIFAR100
    else:
        raise Exception("Unsupported dataset!")

    train_transform, valid_transform = data_transforms_cifar(config, cutout=True)
    train_data = dset_cls(root=config.data_dir, train=True, download=True, transform=train_transform)
    test_data = dset_cls(root=config.data_dir, train=False, download=True, transform=valid_transform)

    if config.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
        test_sampler = torch.utils.data.distributed.DistributedSampler(test_data)
    else:
        train_sampler = None
        test_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=config.batch_size,
        sampler=train_sampler,
        shuffle=(train_sampler is None),  # shuffle explicitly when no sampler is supplied
        pin_memory=True, num_workers=config.workers)

    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=config.eval_batch_size,
        sampler=test_sampler,
        pin_memory=True, num_workers=config.workers)

    return [train_loader, test_loader], [train_sampler, test_sampler]
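A sketch of how these loaders might be constructed outside the trainer; the `Namespace` below only fills in the fields `get_search_datasets` actually reads, and all values are illustrative (in the real scripts they come from `SearchConfig`).

```python
from argparse import Namespace

# Hypothetical minimal config, not part of the PR.
config = Namespace(dataset='cifar10', data_dir='data/cifar', use_aa=False,
                   batch_size=128, workers=4, distributed=False)
(train_loader, valid_loader), (train_sampler, valid_sampler) = get_search_datasets(config)
# Each loader draws from a disjoint half of the CIFAR10 training set.
print(len(train_sampler.indices), len(valid_sampler.indices))  # 25000 25000
```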
This PR looks great to me. Only one comment: please introduce the CDARTS trainer and the CDARTS mutators here, for example, the parameters of the CDARTS trainer (here is an example for DARTS, and this is its rendering) and how the CDARTS mutators differ.