Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
[tuner] Regularized Evolution (#2802)
Browse files Browse the repository at this point in the history
  • Loading branch information
tabVersion authored Oct 10, 2020
1 parent 9ed545c commit 8d3f444
Show file tree
Hide file tree
Showing 10 changed files with 309 additions and 5 deletions.
19 changes: 17 additions & 2 deletions docs/en_US/NAS/ClassicNas.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,25 @@ At this point, trial code is ready. Then, we can prepare an NNI experiment, i.e.

A file named `nni_auto_gen_search_space.json` is generated by this command. Then put the path of the generated search space in the field `searchSpacePath` of the experiment config file. The other fields of the config file can be filled by referring [this tutorial](../Tutorial/QuickStart.md).

Currently, we only support [PPO Tuner](../Tuner/BuiltinTuner.md) and [random tuner](https://github.com/microsoft/nni/tree/master/examples/tuners/random_nas_tuner) for classic NAS. More classic NAS algorithms will be supported soon.
Currently, we only support [PPO Tuner](../Tuner/BuiltinTuner.md), [Regularized Evolution Tuner](#regulaized-evolution-tuner) and [Random Tuner](https://github.com/microsoft/nni/tree/master/examples/tuners/random_nas_tuner) for classic NAS. More classic NAS algorithms will be supported soon.

The complete examples can be found [here](https://github.com/microsoft/nni/tree/master/examples/nas/classic_nas) for PyTorch and [here](https://github.com/microsoft/nni/tree/master/examples/nas/classic_nas-tf) for TensorFlow.

## Standalone mode for easy debugging

We support a standalone mode for easy debugging, where you can directly run the trial command without launching an NNI experiment. This is for checking whether your trial code can correctly run. The first candidate(s) are chosen for `LayerChoice` and `InputChoice` in this standalone mode.
We support a standalone mode for easy debugging, where you can directly run the trial command without launching an NNI experiment. This is for checking whether your trial code can correctly run. The first candidate(s) are chosen for `LayerChoice` and `InputChoice` in this standalone mode.

<a name="regulaized-evolution-tuner"></a>

## Regularized Evolution Tuner

This is a tuner geared for NNI’s Neural Architecture Search (NAS) interface. It uses the [evolution algorithm](https://arxiv.org/pdf/1802.01548.pdf).

The tuner first randomly initializes the number of `population` models and evaluates them. After that, every time to produce a new architecture, the tuner randomly chooses the number of `sample` architectures from `population`, then mutates the best model in `sample`, the parent model, to produce the child model. The mutation includes the hidden mutation and the op mutation. The hidden state mutation consists of replacing a hidden state with another hidden state from within the cell, subject to the constraint that no loops are formed. The op mutation behaves like the hidden state mutation as far as replacing one op with another op from the op set. Note that keeping the child model the same as its parent is not allowed. After evaluating the child model, it is added to the tail of the `population`, then pops the front one.

Note that **trial concurrency should be less than the population of the model**, otherwise NO_MORE_TRIAL exception will be raised.

The whole procedure is summarized by the pseudocode below.

![](../../img/EvoNasTuner.png)

Binary file added docs/img/EvoNasTuner.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions src/sdk/pynni/nni/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@
'name': 'PBTTuner',
'class_name': 'nni.pbt_tuner.pbt_tuner.PBTTuner',
'class_args_validator': 'nni.pbt_tuner.pbt_tuner.PBTClassArgsValidator'
},
{
'name': 'RegularizedEvolutionTuner',
'class_name': 'nni.regularized_evolution_tuner.regularized_evolution_tuner.RegularizedEvolutionTuner',
'class_args_validator': 'nni.regularized_evolution_tuner.regularized_evolution_tuner.EvolutionClassArgsValidator'
}
],
'assessors': [
Expand Down
1 change: 1 addition & 0 deletions src/sdk/pynni/nni/regularized_evolution_tuner/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .regularized_evolution_tuner import RegularizedEvolutionTuner
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import copy
import logging
import random
from collections import deque

import nni
from schema import Schema, Optional
from nni.tuner import Tuner
from nni import ClassArgsValidator
from nni.utils import OptimizeMode, extract_scalar_reward

logger = logging.getLogger(__name__)


class FinishedIndividual:
def __init__(self, parameter_id, parameters, result):
"""
Parameters
----------
parameter_id: int
the index of the parameter
parameters : dict
chosen architecture and parameters
result : float
final metric of the chosen one
"""
self.parameter_id = parameter_id
self.parameters = parameters
self.result = result


class EvolutionClassArgsValidator(ClassArgsValidator):
def validate_class_args(self, **kwargs):
Schema({
'optimize_mode': self.choices('optimize_mode', 'maximize', 'minimize'),
Optional('population_size'): self.range('population_size', int, 0, 99999),
Optional('sample_size'): self.range('sample_size', int, 0, 9999),
}).validate(kwargs)


class RegularizedEvolutionTuner(Tuner):
"""
RegularizedEvolutionTuner is tuner using Evolution NAS Tuner.
See ``Regularized Evolution for Image Classifier Architecture Search`` for details.
Parameters
---
optimize_mode: str
whether to maximize metric or not. default: 'maximize'
population_size: int
the maximum number of kept models
sample_size: int
the number of models chosen from population each time when evolution
"""
def __init__(self, optimize_mode="maximize", population_size=100, sample_size=25):
super(RegularizedEvolutionTuner, self).__init__()
self.optimize_mode = OptimizeMode(optimize_mode)
self.population_size = population_size
self.sample_size = sample_size
self.initial_population = deque()
self.population = deque()
self.history = {}
self.search_space = None
self._from_initial = {} # whether the parameter is from initial population

def generate_parameters(self, parameter_id, **kwargs):
"""
This function will returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters
---
parameter_id: int
the index of current set of parameters
"""
if self.initial_population:
arch = self.initial_population.popleft()
self.history[parameter_id] = arch
self._from_initial[parameter_id] = True
return arch
elif self.population:
sample = []
while len(sample) < self.sample_size:
sample.append(random.choice(list(self.population)))

candidate = max(sample, key=lambda x: x.result)
arch = self._mutate_model(candidate)
self.history[parameter_id] = arch
self._from_initial[parameter_id] = False
return arch
else:
raise nni.NoMoreTrialError

def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""
Record the result from a trial
Parameters
----------
parameter_id : int
parameters : dict
value : dict/float
if value is dict, it should have "default" key.
value is final metrics of the trial.
"""
reward = extract_scalar_reward(value)
if parameter_id not in self.history:
raise RuntimeError('Received parameter_id not in total_data.')
params = self.history[parameter_id]

if self.optimize_mode == OptimizeMode.Minimize:
reward = -reward

self.population.append(FinishedIndividual(parameter_id, params, reward))
if len(self.population) > self.population_size:
self.population.popleft()

def update_search_space(self, search_space):
"""
Update search space.
Search_space contains the information that user pre-defined.
Parameters
----------
search_space : dict
"""
logger.info('update search space %s', search_space)
assert self.search_space is None
self.search_space = search_space

for _, val in search_space.items():
if val['_type'] != 'layer_choice' and val['_type'] != 'input_choice':
raise ValueError('Unsupported search space type: %s' % (val['_type']))

self._generate_initial_population()

def trial_end(self, parameter_id, success, **kwargs):
if not success:
del self.history[parameter_id]
if self._from_initial[parameter_id]:
self.initial_population.append(self._random_model())
del self._from_initial[parameter_id]

def _mutate(self, key, individual):
mutate_val = self.search_space[key]
if mutate_val['_type'] == 'layer_choice':
idx = random.randint(0, len(mutate_val['_value']) - 1)
individual[key] = {'_value': mutate_val['_value'][idx], '_idx': idx}
elif mutate_val['_type'] == 'input_choice':
candidates = mutate_val['_value']['candidates']
n_chosen = mutate_val['_value']['n_chosen']
idxs = [random.randint(0, len(candidates) - 1) for _ in range(n_chosen)]
vals = [candidates[k] for k in idxs]
individual[key] = {'_value': vals, '_idx': idxs}
else:
raise KeyError

def _random_model(self):
individual = {}
for key in self.search_space.keys():
self._mutate(key, individual)
return individual

def _mutate_model(self, model):
new_individual = copy.deepcopy(model.parameters)
mutate_key = random.choice(list(new_individual.keys()))
self._mutate(mutate_key, new_individual)
return new_individual

def _generate_initial_population(self):
while len(self.initial_population) < self.population_size:
self.initial_population.append(self._random_model())
logger.info('init population done.')
26 changes: 26 additions & 0 deletions src/sdk/pynni/tests/assets/classic_nas_search_space.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"first_conv": {
"_type": "layer_choice",
"_value": [
"conv5x5",
"conv3x3"
]
},
"mid_conv": {
"_type": "layer_choice",
"_value": [
"0",
"1"
]
},
"skip": {
"_type": "input_choice",
"_value": {
"candidates": [
"",
""
],
"n_chosen": 1
}
}
}
41 changes: 38 additions & 3 deletions src/sdk/pynni/tests/test_builtin_tuners.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from nni.metis_tuner.metis_tuner import MetisTuner
from nni.msg_dispatcher import _pack_parameter, MsgDispatcher
from nni.pbt_tuner.pbt_tuner import PBTTuner
from nni.regularized_evolution_tuner.regularized_evolution_tuner import RegularizedEvolutionTuner

try:
from nni.smac_tuner.smac_tuner import SMACTuner
Expand Down Expand Up @@ -57,7 +58,8 @@ def send_trial_result(self, tuner, parameter_id, parameters, metrics):
tuner.receive_trial_result(parameter_id, parameters, metrics)
tuner.trial_end(parameter_id, True)

def search_space_test_one(self, tuner_factory, search_space):
def search_space_test_one(self, tuner_factory, search_space, nas=False):
# nas: whether the test checks classic nas tuner
tuner = tuner_factory()
self.assertIsInstance(tuner, Tuner)
tuner.update_search_space(search_space)
Expand All @@ -68,12 +70,14 @@ def search_space_test_one(self, tuner_factory, search_space):
(i + 1) * self.params_each_round)),
st_callback=self.send_trial_callback(queue))
logger.debug(parameters)
self.check_range(parameters, search_space)
check_range = lambda parameters, search_space: self.nas_check_range(parameters, search_space) \
if nas else self.check_range(parameters, search_space)
check_range(parameters, search_space)
for k in range(min(len(parameters), self.params_each_round)):
self.send_trial_result(tuner, self.params_each_round * i + k, parameters[k], random.uniform(-100, 100))
while queue:
id_, params = queue.popleft()
self.check_range([params], search_space)
check_range([params], search_space)
self.send_trial_result(tuner, id_, params, random.uniform(-100, 100))
if not parameters and not self.exhaustive:
raise ValueError("No parameters generated")
Expand Down Expand Up @@ -123,6 +127,19 @@ def check_range(self, generated_params, search_space):
for layer_name in item["_value"].keys():
self.assertIn(v[layer_name]["chosen_layer"], item["layer_choice"])

def nas_check_range(self, generated_params, search_space):
for params in generated_params:
for k in params:
v = params[k]
items = search_space[k]
if items['_type'] == 'layer_choice':
self.assertIn(v['_value'], items['_value'])
elif items['_type'] == 'input_choice':
for choice in v['_value']:
self.assertIn(choice, items['_value']['candidates'])
else:
raise KeyError

def search_space_test_all(self, tuner_factory, supported_types=None, ignore_types=None, fail_types=None):
# Three types: 1. supported; 2. ignore; 3. fail.
# NOTE(yuge): ignore types
Expand Down Expand Up @@ -163,6 +180,20 @@ def search_space_test_all(self, tuner_factory, supported_types=None, ignore_type
logger.info("Full supported search space: %s", full_supported_search_space)
self.search_space_test_one(tuner_factory, full_supported_search_space)

def nas_search_space_test_all(self, tuner_factory):
# Since classic tuner should support only LayerChoice and InputChoice,
# ignore type and fail type are dismissed here.
with open(os.path.join(os.path.dirname(__file__), "assets/classic_nas_search_space.json"), "r") as fp:
search_space_all = json.load(fp)
full_supported_search_space = dict()
for single in search_space_all:
space = search_space_all[single]
single_search_space = {single: space}
self.search_space_test_one(tuner_factory, single_search_space, nas=True)
full_supported_search_space.update(single_search_space)
logger.info("Full supported search space: %s", full_supported_search_space)
self.search_space_test_one(tuner_factory, full_supported_search_space, nas=True)

def import_data_test_for_pbt(self):
"""
test1: import data with complete epoch
Expand Down Expand Up @@ -368,6 +399,10 @@ def tearDown(self):
else:
os.remove(file)

def test_regularized_evolution_tuner(self):
tuner_fn = lambda: RegularizedEvolutionTuner()
self.nas_search_space_test_all(tuner_fn)


if __name__ == '__main__':
main()
4 changes: 4 additions & 0 deletions test/config/integration_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,12 @@ testCases:
- name: tuner-metis
configFile: test/config/tuners/metis.yml

- name: tuner-regularized_evolution
configFile: test/config/tuners/regularized_evolution_tuner.yml

#########################################################################
# nni customized-tuners test
#########################################################################
- name: customized-tuners-demotuner
configFile: test/config/customized_tuners/demotuner-sklearn-classification.yml

20 changes: 20 additions & 0 deletions test/config/tuners/regularized_evolution_tuner.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
authorName: nni
experimentName: default_test
maxExecDuration: 10m
maxTrialNum: 1
trialConcurrency: 1
searchSpacePath: seach_space_classic_nas.json
tuner:
builtinTunerName: RegularizedEvolutionTuner
classArgs:
optimize_mode: maximize
trial:
codeDir: ../../../examples/nas/classic_nas
command: python3 mnist.py --epochs 1
gpuNum: 0

useAnnotation: false
multiPhase: false
multiThread: false

trainingServicePlatform: local
26 changes: 26 additions & 0 deletions test/config/tuners/seach_space_classic_nas.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"first_conv": {
"_type": "layer_choice",
"_value": [
"conv5x5",
"conv3x3"
]
},
"mid_conv": {
"_type": "layer_choice",
"_value": [
"0",
"1"
]
},
"skip": {
"_type": "input_choice",
"_value": {
"candidates": [
"",
""
],
"n_chosen": 1
}
}
}

0 comments on commit 8d3f444

Please sign in to comment.