Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Prettify the export format of NAS trainer #2389

Merged
merged 8 commits into from
May 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions docs/en_US/NAS/NasGuide.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,23 @@ model = Net()
apply_fixed_architecture(model, "model_dir/final_architecture.json")
```

The JSON is simply a mapping from mutable keys to one-hot or multi-hot representation of choices. For example
The JSON is simply a mapping from mutable keys to choices. Choices can be expressed in:

* A string: select the candidate with corresponding name.
* A number: select the candidate with corresponding index.
* A list of string: select the candidates with corresponding names.
* A list of number: select the candidates with corresponding indices.
* A list of boolean values: a multi-hot array.

For example,

```json
{
"LayerChoice1": [false, true, false, false],
"InputChoice2": [true, true, false]
"LayerChoice1": "conv5x5",
"LayerChoice2": 6,
"InputChoice3": ["layer1", "layer3"],
"InputChoice4": [1, 2],
"InputChoice5": [false, true, false, false, true]
}
```

Expand Down
51 changes: 33 additions & 18 deletions src/sdk/pynni/nni/nas/pytorch/fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

import json

import torch

from nni.nas.pytorch.mutables import MutableScope
from nni.nas.pytorch.mutator import Mutator
from .mutables import InputChoice, LayerChoice, MutableScope
from .mutator import Mutator
from .utils import to_list


class FixedArchitecture(Mutator):
Expand All @@ -17,8 +16,8 @@ class FixedArchitecture(Mutator):
----------
model : nn.Module
A mutable network.
fixed_arc : str or dict
Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
fixed_arc : dict
Preloaded architecture object.
strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once.
"""
Expand All @@ -33,6 +32,34 @@ def __init__(self, model, fixed_arc, strict=True):
raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
if mutable_keys - fixed_arc_keys:
raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
self._fixed_arc = self._from_human_readable_architecture(self._fixed_arc)

def _from_human_readable_architecture(self, human_arc):
# convert from an exported architecture
result_arc = {k: to_list(v) for k, v in human_arc.items()} # there could be tensors, numpy arrays, etc.
# First, convert non-list to list, because there could be {"op1": 0} or {"op1": "conv"},
# which means {"op1": [0, ]} ir {"op1": ["conv", ]}
result_arc = {k: v if isinstance(v, list) else [v] for k, v in result_arc.items()}
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved
# Second, infer which ones are multi-hot arrays and which ones are in human-readable format.
# This is non-trivial, since if an array in [0, 1], we cannot know for sure it means [false, true] or [true, true].
# Here, we assume an multihot array has to be a boolean array or a float array and matches the length.
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved
for mutable in self.mutables:
if mutable.key not in result_arc:
continue # skip silently
choice_arr = result_arc[mutable.key]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the meaning of "arr"?

if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr):
if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \
(isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)):
# multihot, do nothing
continue
if isinstance(mutable, LayerChoice):
choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr]
choice_arr = [i in choice_arr for i in range(len(mutable))]
elif isinstance(mutable, InputChoice):
choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr]
choice_arr = [i in choice_arr for i in range(mutable.n_candidates)]
result_arc[mutable.key] = choice_arr
return result_arc

def sample_search(self):
"""
Expand All @@ -47,17 +74,6 @@ def sample_final(self):
return self._fixed_arc


def _encode_tensor(data):
if isinstance(data, list):
if all(map(lambda o: isinstance(o, bool), data)):
return torch.tensor(data, dtype=torch.bool) # pylint: disable=not-callable
else:
return torch.tensor(data, dtype=torch.float) # pylint: disable=not-callable
if isinstance(data, dict):
return {k: _encode_tensor(v) for k, v in data.items()}
return data


def apply_fixed_architecture(model, fixed_arc):
"""
Load architecture from `fixed_arc` and apply to model.
Expand All @@ -78,7 +94,6 @@ def apply_fixed_architecture(model, fixed_arc):
if isinstance(fixed_arc, str):
with open(fixed_arc) as f:
fixed_arc = json.load(f)
fixed_arc = _encode_tensor(fixed_arc)
architecture = FixedArchitecture(model, fixed_arc)
architecture.reset()
return architecture
85 changes: 77 additions & 8 deletions src/sdk/pynni/nni/nas/pytorch/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import numpy as np
import torch

from nni.nas.pytorch.base_mutator import BaseMutator
from .base_mutator import BaseMutator
from .mutables import LayerChoice, InputChoice
from .utils import to_list

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,7 +60,16 @@ def export(self):
dict
A mapping from key of mutables to decisions.
"""
return self.sample_final()
sampled = self.sample_final()
result = dict()
for mutable in self.mutables:
if not isinstance(mutable, (LayerChoice, InputChoice)):
# not supported as built-in
continue
result[mutable.key] = self._convert_mutable_decision_to_human_readable(mutable, sampled.pop(mutable.key))
if sampled:
raise ValueError("Unexpected keys returned from 'sample_final()': %s", list(sampled.keys()))
return result

def status(self):
"""
Expand Down Expand Up @@ -159,7 +170,7 @@ def _map_fn(op, args, kwargs):
mask = self._get_decision(mutable)
assert len(mask) == len(mutable), \
"Invalid mask, expected {} to be of length {}.".format(mask, len(mutable))
out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
out, mask = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
return self._tensor_reduction(mutable.reduction, out), mask

def on_forward_input_choice(self, mutable, tensor_list):
Expand All @@ -185,17 +196,41 @@ def on_forward_input_choice(self, mutable, tensor_list):
mask = self._get_decision(mutable)
assert len(mask) == mutable.n_candidates, \
"Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)
out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
out, mask = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
return self._tensor_reduction(mutable.reduction, out), mask

def _select_with_mask(self, map_fn, candidates, mask):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the new implementation of this function has complex logic, please add docstring for this function.

if "BoolTensor" in mask.type():
"""
Select masked tensors and return a list of tensors.

Parameters
----------
map_fn : function
Convert candidates to target candidates. Can be simply identity.
candidates : list of torch.Tensor
Tensor list to apply the decision on.
mask : list-like object
Can be a list, an numpy array or a tensor (recommended). Needs to
have the same length as ``candidates``.

Returns
-------
tuple of list of torch.Tensor and torch.Tensor
Output and mask.
"""
if (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], bool)) or \
(isinstance(mask, np.ndarray) and mask.dtype == np.bool) or \
"BoolTensor" in mask.type():
out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
elif "FloatTensor" in mask.type():
elif (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], (float, int))) or \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it can be int?

(isinstance(mask, np.ndarray) and mask.dtype in (np.float32, np.float64, np.int32, np.int64)) or \
"FloatTensor" in mask.type():
out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m]
else:
raise ValueError("Unrecognized mask")
return out
raise ValueError("Unrecognized mask '%s'" % mask)
if not torch.is_tensor(mask):
mask = torch.tensor(mask) # pylint: disable=not-callable
return out, mask

def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == "none":
Expand Down Expand Up @@ -237,3 +272,37 @@ def _get_decision(self, mutable):
result = self._cache[mutable.key]
logger.debug("Decision %s: %s", mutable.key, result)
return result

def _convert_mutable_decision_to_human_readable(self, mutable, sampled):
# Assert the existence of mutable.key in returned architecture.
# Also check if there is anything extra.
multihot_list = to_list(sampled)
converted = None
# If it's a boolean array, we can do optimization.
if all([t == 0 or t == 1 for t in multihot_list]):
if isinstance(mutable, LayerChoice):
assert len(multihot_list) == len(mutable), \
"Results returned from 'sample_final()' (%s: %s) either too short or too long." \
% (mutable.key, multihot_list)
# check if all modules have different names and they indeed have names
if len(set(mutable.names)) == len(mutable) and not all(d.isdigit() for d in mutable.names):
converted = [name for i, name in enumerate(mutable.names) if multihot_list[i]]
else:
converted = [i for i in range(len(multihot_list)) if multihot_list[i]]
if isinstance(mutable, InputChoice):
assert len(multihot_list) == mutable.n_candidates, \
"Results returned from 'sample_final()' (%s: %s) either too short or too long." \
% (mutable.key, multihot_list)
# check if all input candidates have different names
if len(set(mutable.choose_from)) == mutable.n_candidates:
converted = [name for i, name in enumerate(mutable.choose_from) if multihot_list[i]]
else:
converted = [i for i in range(len(multihot_list)) if multihot_list[i]]
if converted is not None:
# if only one element, then remove the bracket
if len(converted) == 1:
converted = converted[0]
else:
# do nothing
converted = multihot_list
return converted
11 changes: 11 additions & 0 deletions src/sdk/pynni/nni/nas/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from collections import OrderedDict

import numpy as np
import torch

_counter = 0
Expand Down Expand Up @@ -45,6 +46,16 @@ def to_device(obj, device):
raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))


def to_list(arr):
if torch.is_tensor(arr):
return arr.cpu().numpy().tolist()
if isinstance(arr, np.ndarray):
return arr.tolist()
if isinstance(arr, (list, tuple)):
return list(arr)
return arr


class AverageMeterGroup:
"""
Average meter group for multiple average meters.
Expand Down