Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] Improve high-level API interface and add implementation of ValueChoice #3349

Merged
merged 11 commits into from
Feb 3, 2021
10 changes: 8 additions & 2 deletions docs/en_US/NAS/retiarii/ApiReference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ Inline Mutation APIs
.. autoclass:: nni.retiarii.nn.pytorch.InputChoice
:members:

.. autoclass:: nni.retiarii.nn.pytorch.ValueChoice
:members:

.. autoclass:: nni.retiarii.nn.pytorch.ChosenInputs
:members:

Graph Mutation APIs
-------------------

Expand All @@ -36,10 +42,10 @@ Graph Mutation APIs
Trainers
--------

.. autoclass:: nni.retiarii.trainer.PyTorchImageClassificationTrainer
.. autoclass:: nni.retiarii.trainer.pytorch.PyTorchImageClassificationTrainer
:members:

.. autoclass:: nni.retiarii.trainer.PyTorchMultiModelTrainer
.. autoclass:: nni.retiarii.trainer.pytorch.PyTorchMultiModelTrainer
:members:

Oneshot Trainers
Expand Down
36 changes: 23 additions & 13 deletions nni/retiarii/converter/graph_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,25 +408,33 @@ def refine_graph(self, ir_graph):
self.merge_aten_slices(ir_graph)

def _handle_layerchoice(self, module):
m_attrs = {}
candidates = module.op_candidates
choices = []
for cand in candidates:
assert id(cand) in self.modules_arg, 'id not exist: {}'.format(id(cand))
for cand in list(module):
assert id(cand) in self.modules_arg, \
f'Module not recorded: {id(cand)}. ' \
'Try to import from `retiarii.nn` if you are using torch.nn module or ' \
'annotate your customized module with @blackbox_module.'
assert isinstance(self.modules_arg[id(cand)], dict)
cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
choices.append({'type': cand_type, 'parameters': self.modules_arg[id(cand)]})
m_attrs[f'choices'] = choices
m_attrs['label'] = module.label
return m_attrs
return {
'candidates': choices,
'label': module.label
}

def _handle_inputchoice(self, module):
m_attrs = {}
m_attrs['n_candidates'] = module.n_candidates
m_attrs['n_chosen'] = module.n_chosen
m_attrs['reduction'] = module.reduction
m_attrs['label'] = module.label
return m_attrs
return {
'n_candidates': module.n_candidates,
'n_chosen': module.n_chosen,
'reduction': module.reduction,
'label': module.label
}

def _handle_valuechoice(self, module):
return {
'candidates': module.candidates,
'label': module.label
}

def convert_module(self, script_module, module, module_name, ir_model):
"""
Expand Down Expand Up @@ -461,6 +469,8 @@ def convert_module(self, script_module, module, module_name, ir_model):
m_attrs = self._handle_layerchoice(module)
elif original_type_name == OpTypeName.InputChoice:
m_attrs = self._handle_inputchoice(module)
elif original_type_name == OpTypeName.ValueChoice:
m_attrs = self._handle_valuechoice(module)
elif original_type_name == OpTypeName.Placeholder:
m_attrs = self.modules_arg[id(module)]
elif original_type_name in torch.nn.__dict__:
Expand Down
2 changes: 1 addition & 1 deletion nni/retiarii/execution/logical_optimizer/logical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \
phy_model.training_config.kwargs['model_cls'] = phy_graph.name
phy_model.training_config.kwargs['model_kwargs'] = []
# FIXME: allow user to specify
phy_model.training_config.module = 'nni.retiarii.trainer.PyTorchMultiModelTrainer'
phy_model.training_config.module = 'nni.retiarii.trainer.pytorch.PyTorchMultiModelTrainer'

# merge sub-graphs
for model in multi_model_placement:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
OriginNode, PhysicalDevice)

_supported_training_modules = ['nni.retiarii.trainer.PyTorchImageClassificationTrainer']
_supported_training_modules = ['nni.retiarii.trainer.pytorch.PyTorchImageClassificationTrainer']


class DedupInputNode(AbstractLogicalNode):
Expand Down
32 changes: 6 additions & 26 deletions nni/retiarii/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@
from .utils import get_records
from .integration import RetiariiAdvisor
from .converter import convert_to_graph
from .mutator import Mutator, LayerChoiceMutator, InputChoiceMutator
from .mutator import Mutator
from .trainer.interface import BaseTrainer, BaseOneShotTrainer
from .strategies.strategy import BaseStrategy
from .trainer.pytorch import DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer
from .trainer import BaseOneShotTrainer

_logger = logging.getLogger(__name__)

OneShotTrainers = (DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer)


@dataclass(init=False)
class RetiariiExeConfig(ConfigBase):
Expand Down Expand Up @@ -94,28 +92,10 @@ def __init__(self, base_model: Model, trainer: BaseTrainer,
self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None

def _process_inline_mutation(self, base_model):
"""
the mutators are order independent
"""
lc_nodes = base_model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.nn.LayerChoice')
ic_nodes = base_model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.nn.InputChoice')
if not lc_nodes and not ic_nodes:
return None
applied_mutators = []
for node in lc_nodes:
mutator = LayerChoiceMutator(node.name, node.operation.parameters['choices'])
applied_mutators.append(mutator)
for node in ic_nodes:
mutator = InputChoiceMutator(node.name,
node.operation.parameters['n_candidates'],
node.operation.parameters['n_chosen'],
node.operation.parameters['reduction'])
applied_mutators.append(mutator)
return applied_mutators

def _start_strategy(self):
import torch
from .nn.pytorch.mutator import process_inline_mutation

try:
script_module = torch.jit.script(self.base_model)
except Exception as e:
Expand All @@ -131,7 +111,7 @@ def _start_strategy(self):
base_model_ir.apply_trainer(trainer_config['modulename'], trainer_config['args'])

# handle inline mutations
mutators = self._process_inline_mutation(base_model_ir)
mutators = process_inline_mutation(base_model_ir)
if mutators is not None and self.applied_mutators:
raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, \
do not use mutators when you use LayerChoice/InputChoice')
Expand Down Expand Up @@ -165,7 +145,7 @@ def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool =
Run the experiment.
This function will block until experiment finish or error.
"""
if isinstance(self.trainer, OneShotTrainers):
if isinstance(self.trainer, BaseOneShotTrainer):
self.trainer.fit()
else:
assert config is not None, 'You are using classic search mode, config cannot be None!'
Expand Down
33 changes: 0 additions & 33 deletions nni/retiarii/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,36 +105,3 @@ def __init__(self):
def choice(self, candidates: List[Choice], *args) -> Choice:
self.recorded_candidates.append(candidates)
return candidates[0]


# the following is for inline mutation


class LayerChoiceMutator(Mutator):
def __init__(self, node_name: str, candidates: List):
super().__init__()
self.node_name = node_name
self.candidates = candidates

def mutate(self, model):
target = model.get_node_by_name(self.node_name)
indexes = [i for i in range(len(self.candidates))]
chosen_index = self.choice(indexes)
chosen_cand = self.candidates[chosen_index]
target.update_operation(chosen_cand['type'], chosen_cand['parameters'])


class InputChoiceMutator(Mutator):
def __init__(self, node_name: str, n_candidates: int, n_chosen: int, reduction: str):
super().__init__()
self.node_name = node_name
self.n_candidates = n_candidates
self.n_chosen = n_chosen
self.reduction = reduction

def mutate(self, model):
target = model.get_node_by_name(self.node_name)
candidates = [i for i in range(self.n_candidates)]
chosen = [self.choice(candidates) for _ in range(self.n_chosen)]
target.update_operation('__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs',
{'chosen': chosen, 'reduction': self.reduction})
1 change: 1 addition & 0 deletions nni/retiarii/nn/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .api import *
from .nn import *
Loading