Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

support classic nas mode: each chosen arch as a separate trial job #1775

Merged
merged 14 commits into from
Nov 26, 2019
17 changes: 17 additions & 0 deletions examples/nas/classic_nas/config_nas.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
authorName: default
experimentName: example_mnist
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
#choice: local, remote, pai
trainingServicePlatform: local
searchSpacePath: NNI_AUTO_GEN
useAnnotation: False
tuner:
codeDir: ../../tuners/random_nas_tuner
classFileName: random_nas_tuner.py
className: RandomNASTuner
trial:
command: python3 mnist.py
codeDir: .
gpuNum: 0
17 changes: 17 additions & 0 deletions examples/nas/classic_nas/config_ppo.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
authorName: default
experimentName: example_mnist
trialConcurrency: 1
maxExecDuration: 100h
maxTrialNum: 1000
#choice: local, remote, pai
trainingServicePlatform: local
searchSpacePath: NNI_AUTO_GEN
useAnnotation: False
tuner:
builtinTunerName: PPOTuner
classArgs:
optimize_mode: maximize
trial:
command: python3 mnist.py
codeDir: .
gpuNum: 0
181 changes: 181 additions & 0 deletions examples/nas/classic_nas/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
"""
A deep MNIST classifier using convolutional layers.

This file is a modification of the official pytorch mnist example:
https://github.com/pytorch/examples/blob/master/mnist/main.py
"""

import os
import argparse
import logging
import nni
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from nni.nas.pytorch.mutables import LayerChoice, InputChoice
from nni.nas.pytorch.classic_nas import get_apply_next_architecture


# Module-level logger for this trial script; NNI captures its output in the trial log.
logger = logging.getLogger('mnist_AutoML')


class Net(nn.Module):
    """CNN for MNIST with a classic-NAS search space.

    The mutables are:
      * ``first_conv`` (LayerChoice): 5x5 vs 3x3 kernel for the first conv.
      * ``mid_conv`` (LayerChoice): padded 3x3 vs padded 5x5 middle conv
        (padding keeps the spatial size unchanged).
      * ``skip`` (InputChoice): picks either a zero tensor or the activation
        saved before ``mid_conv`` — i.e. an optional skip connection.
    """

    def __init__(self, hidden_size):
        super(Net, self).__init__()
        # Two options of conv1.
        self.conv1 = LayerChoice([nn.Conv2d(1, 20, 5, 1),
                                  nn.Conv2d(1, 20, 3, 1)],
                                 key='first_conv')
        # Two options of mid_conv; both preserve the 20-channel feature map size.
        self.mid_conv = LayerChoice([nn.Conv2d(20, 20, 3, 1, padding=1),
                                     nn.Conv2d(20, 20, 5, 1, padding=2)],
                                    key='mid_conv')
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 10)
        # Skip connection over mid_conv: '' selects the zero tensor (no skip),
        # 'mid_conv' selects the pre-mid_conv activation.
        self.input_switch = InputChoice(choose_from=['', 'mid_conv'],
                                        n_chosen=1,
                                        key='skip')

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        old_x = x
        x = F.relu(self.mid_conv(x))
        # Candidate inputs for the switch: a zero tensor (skip disabled)
        # or the activation from before mid_conv (skip enabled).
        zero_x = torch.zeros_like(old_x).float()
        skip_x = self.input_switch([zero_x, old_x])
        x = torch.add(x, skip_x)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch over ``train_loader``.

    Logs progress every ``args['log_interval']`` batches through the
    module-level logger.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()
        if batch_idx % args['log_interval'] == 0:
            logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader):
    """Evaluate ``model`` on ``test_loader``; return accuracy in percent."""
    model.eval()
    total_loss = 0
    num_correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            # Accumulate the summed (not averaged) batch loss.
            total_loss += F.nll_loss(output, target, reduction='sum').item()
            # The index of the max log-probability is the predicted class.
            prediction = output.argmax(dim=1, keepdim=True)
            num_correct += prediction.eq(target.view_as(prediction)).sum().item()

    total_loss /= len(test_loader.dataset)

    accuracy = 100. * num_correct / len(test_loader.dataset)

    logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        total_loss, num_correct, len(test_loader.dataset), accuracy))

    return accuracy


def main(args):
    """Train and evaluate one architecture sampled by the tuner.

    ``get_apply_next_architecture`` fetches the architecture chosen for this
    trial and fixes the mutables in ``model`` before training starts; the
    resulting test accuracy is reported back to NNI per epoch.

    Args:
        args: dict of hyper-parameters produced by ``get_params``.
    """
    use_cuda = not args['no_cuda'] and torch.cuda.is_available()

    torch.manual_seed(args['seed'])

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    data_dir = os.path.join(args['data_dir'], 'data')

    # Both loaders use the standard MNIST normalization constants.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True, transform=transform),
        batch_size=args['batch_size'], shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, transform=transform),
        batch_size=1000, shuffle=True, **kwargs)

    hidden_size = args['hidden_size']

    model = Net(hidden_size=hidden_size).to(device)
    # Classic NAS mode: receive this trial's architecture and apply it.
    get_apply_next_architecture(model)
    optimizer = optim.SGD(model.parameters(), lr=args['lr'],
                          momentum=args['momentum'])

    for epoch in range(1, args['epochs'] + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test_acc = test(args, model, device, test_loader)

        if epoch < args['epochs']:
            # report intermediate result
            nni.report_intermediate_result(test_acc)
            logger.debug('test accuracy %g', test_acc)
            logger.debug('Pipe send intermediate result done.')
        else:
            # report final result
            nni.report_final_result(test_acc)
            logger.debug('Final result is %g', test_acc)
            logger.debug('Send final result done.')


def get_params():
    """Parse the command-line hyper-parameters for this trial.

    Uses ``parse_known_args`` so arguments injected by the platform that this
    script does not recognize are silently ignored.
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    # (flag, options) table keeps each setting on one visual row.
    arg_specs = [
        ("--data_dir", dict(type=str, default='/tmp/tensorflow/mnist/input_data',
                            help="data directory")),
        ('--batch_size', dict(type=int, default=64, metavar='N',
                              help='input batch size for training (default: 64)')),
        ("--hidden_size", dict(type=int, default=512, metavar='N',
                               help='hidden layer size (default: 512)')),
        ('--lr', dict(type=float, default=0.01, metavar='LR',
                      help='learning rate (default: 0.01)')),
        ('--momentum', dict(type=float, default=0.5, metavar='M',
                            help='SGD momentum (default: 0.5)')),
        ('--epochs', dict(type=int, default=10, metavar='N',
                          help='number of epochs to train (default: 10)')),
        ('--seed', dict(type=int, default=1, metavar='S',
                        help='random seed (default: 1)')),
        ('--no_cuda', dict(action='store_true', default=False,
                           help='disables CUDA training')),
        ('--log_interval', dict(type=int, default=1000, metavar='N',
                                help='how many batches to wait before logging training status')),
    ]
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)

    known_args, _ = parser.parse_known_args()
    return known_args


if __name__ == '__main__':
    try:
        main(vars(get_params()))
    except Exception as exception:
        # Record the failure in the trial log, then let it propagate so the
        # trial exits with a non-zero status.
        logger.exception(exception)
        raise
51 changes: 21 additions & 30 deletions examples/tuners/random_nas_tuner/random_nas_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,27 @@
def random_archi_generator(nas_ss, random_state):
    '''Uniformly sample one architecture from a classic-NAS search space.

    Args:
        nas_ss: search space dict; every value must carry a ``_type`` of
            ``'layer_choice'`` or ``'input_choice'``.
        random_state: numpy ``RandomState``-like object whose ``randint(n)``
            returns an int in ``[0, n)``.

    Returns:
        dict mapping each mutable key to ``{'_value': ..., '_idx': ...}``;
        for ``input_choice`` both fields are lists of length ``n_chosen``.
        Note: input candidates are drawn with replacement, so duplicates
        are possible when ``n_chosen > 1``.
    '''
    chosen_arch = {}
    for key, val in nas_ss.items():
        assert val['_type'] in ['layer_choice', 'input_choice'], \
            "Random NAS Tuner only receives NAS search space whose _type is 'layer_choice' or 'input_choice'"
        if val['_type'] == 'layer_choice':
            candidates = val['_value']
            pick = random_state.randint(len(candidates))
            chosen_arch[key] = {'_value': candidates[pick], '_idx': pick}
        elif val['_type'] == 'input_choice':
            candidates = val['_value']['candidates']
            picked_values = []
            picked_idxs = []
            for _ in range(val['_value']['n_chosen']):
                pick = random_state.randint(len(candidates))
                picked_values.append(candidates[pick])
                picked_idxs.append(pick)
            chosen_arch[key] = {'_value': picked_values, '_idx': picked_idxs}
        else:
            # Unreachable given the assert above; kept as a safety net.
            raise ValueError('Unknown key %s and value %s' % (key, val))
    return chosen_arch


class RandomNASTuner(Tuner):
Expand Down
1 change: 1 addition & 0 deletions src/sdk/pynni/nni/nas/pytorch/classic_nas/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .mutator import get_apply_next_architecture
Loading