Prettify the export format of NAS trainer #2389
Changes from 6 commits
File 1 (changes to FixedArchitecture and apply_fixed_architecture):
@@ -3,10 +3,9 @@
 import json

-import torch
-
-from nni.nas.pytorch.mutables import MutableScope
-from nni.nas.pytorch.mutator import Mutator
+from .mutables import InputChoice, LayerChoice, MutableScope
+from .mutator import Mutator
+from .utils import to_list


 class FixedArchitecture(Mutator):
@@ -17,8 +16,8 @@ class FixedArchitecture(Mutator):
     ----------
     model : nn.Module
         A mutable network.
-    fixed_arc : str or dict
-        Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
+    fixed_arc : dict
+        Preloaded architecture object.
     strict : bool
        Force everything that appears in ``fixed_arc`` to be used at least once.
    """
@@ -33,6 +32,33 @@ def __init__(self, model, fixed_arc, strict=True):
             raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
         if mutable_keys - fixed_arc_keys:
             raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
+        self._fixed_arc = self._convert_human_readable_architecture(self._fixed_arc)
Review comment: does this function convert to or from the human-readable architecture format?
+    def _convert_human_readable_architecture(self, human_arc):
Review comment: please provide a docstring for this function, even though it is private.
+        result_arc = {k: to_list(v) for k, v in human_arc.items()}  # there could be tensors, numpy arrays, etc.
+        # First, convert non-lists to lists, because there could be {"op1": 0} or {"op1": "conv"},
+        # which mean {"op1": [0, ]} or {"op1": ["conv", ]} respectively.
+        result_arc = {k: v if isinstance(v, list) else [v] for k, v in result_arc.items()}
+        # Second, infer which entries are multi-hot arrays and which are in human-readable format.
+        # This is non-trivial: if an array is [0, 1], we cannot know for sure whether it means [false, true] or [true, true].
+        # Here, we assume a multi-hot array has to be a boolean array or a float array whose length matches the mutable.
+        for mutable in self.mutables:
+            if mutable.key not in result_arc:
+                continue  # skip silently
+            choice_arr = result_arc[mutable.key]
Review comment: what is the meaning of "arr" in choice_arr?
+            if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr):
+                if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \
+                        (isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)):
+                    # multihot, do nothing
+                    continue
+            if isinstance(mutable, LayerChoice):
+                choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr]
+                choice_arr = [i in choice_arr for i in range(len(mutable))]
+            elif isinstance(mutable, InputChoice):
+                choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr]
+                choice_arr = [i in choice_arr for i in range(mutable.n_candidates)]
+            result_arc[mutable.key] = choice_arr
+        return result_arc
+
     def sample_search(self):
         """
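To make the intent of `_convert_human_readable_architecture` concrete, here is a minimal standalone sketch of the same normalization applied to a single choice. The function name `to_multihot` and the candidate names are invented for illustration; they are not part of NNI.

```python
# Standalone sketch (not NNI code); candidate names below are hypothetical.
def to_multihot(choice, names):
    """Convert a human-readable choice (a name, an index, or a list of either) to a multi-hot list."""
    values = choice if isinstance(choice, list) else [choice]
    # A bool/float list whose length matches the candidates is assumed to be multi-hot already.
    if len(values) == len(names) and all(isinstance(v, (bool, float)) for v in values):
        return values
    indices = [names.index(v) if isinstance(v, str) else v for v in values]
    return [i in indices for i in range(len(names))]

names = ["conv3x3", "conv5x5", "maxpool"]
print(to_multihot("conv5x5", names))            # [False, True, False]
print(to_multihot(0, names))                    # [True, False, False]
print(to_multihot([True, False, True], names))  # kept as-is: [True, False, True]
```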
@@ -47,17 +73,6 @@ def sample_final(self):
         return self._fixed_arc


-def _encode_tensor(data):
-    if isinstance(data, list):
-        if all(map(lambda o: isinstance(o, bool), data)):
-            return torch.tensor(data, dtype=torch.bool)  # pylint: disable=not-callable
-        else:
-            return torch.tensor(data, dtype=torch.float)  # pylint: disable=not-callable
-    if isinstance(data, dict):
-        return {k: _encode_tensor(v) for k, v in data.items()}
-    return data
-
-
 def apply_fixed_architecture(model, fixed_arc):
     """
     Load architecture from `fixed_arc` and apply to model.
@@ -78,7 +93,6 @@ def apply_fixed_architecture(model, fixed_arc):
     if isinstance(fixed_arc, str):
         with open(fixed_arc) as f:
             fixed_arc = json.load(f)
-    fixed_arc = _encode_tensor(fixed_arc)
    architecture = FixedArchitecture(model, fixed_arc)
    architecture.reset()
    return architecture
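For reference, a sketch of the two equivalent formats that `FixedArchitecture` accepts after this change; the keys and candidate names are invented. A dict like either of these, or the path to a JSON file containing it, is what `apply_fixed_architecture` receives.

```python
# Hypothetical architecture checkpoints; keys and candidate names are made up.
human_readable_arc = {
    "layer_1": "conv5x5",                  # LayerChoice selected by name
    "layer_2": 0,                          # LayerChoice selected by index
    "skip_connect": ["in_0", "in_2"],      # InputChoice with two inputs selected by name
}
multi_hot_arc = {
    "layer_1": [False, True, False],
    "layer_2": [True, False, False],
    "skip_connect": [True, False, True, False],
}
```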
File 2 (changes to the Mutator base class: export() and mask handling):
@@ -7,7 +7,9 @@
 import numpy as np
 import torch

-from nni.nas.pytorch.base_mutator import BaseMutator
+from .base_mutator import BaseMutator
+from .mutables import LayerChoice, InputChoice
+from .utils import to_list

 logger = logging.getLogger(__name__)
@@ -58,7 +60,16 @@ def export(self):
         dict
             A mapping from key of mutables to decisions.
         """
-        return self.sample_final()
+        sampled = self.sample_final()
+        result = dict()
+        for mutable in self.mutables:
+            if not isinstance(mutable, (LayerChoice, InputChoice)):
+                # not supported as built-in
+                continue
+            result[mutable.key] = self._convert_mutable_decision(mutable, sampled.pop(mutable.key))
+        if sampled:
+            raise ValueError("Unexpected keys returned from 'sample_final()': %s", list(sampled.keys()))
+        return result

     def status(self):
         """
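To show what "prettify the export format" means in practice, here is a hypothetical before/after of what `export()` might return for the same two mutables. Keys and candidate names are invented, and the old values could also be tensors rather than plain lists.

```python
# Invented keys and names, for illustration only.
exported_before = {
    "layer_1": [0.0, 1.0, 0.0],                  # multi-hot (possibly a tensor) over 3 candidate ops
    "skip_connect": [True, False, True, False],  # multi-hot over 4 candidate inputs
}
exported_after = {
    "layer_1": "conv5x5",                        # the chosen op, by name
    "skip_connect": ["in_0", "in_2"],            # the chosen inputs, by name
}
```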
@@ -159,7 +170,7 @@ def _map_fn(op, args, kwargs):
         mask = self._get_decision(mutable)
         assert len(mask) == len(mutable), \
             "Invalid mask, expected {} to be of length {}.".format(mask, len(mutable))
-        out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
+        out, mask = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
         return self._tensor_reduction(mutable.reduction, out), mask

     def on_forward_input_choice(self, mutable, tensor_list):
@@ -185,17 +196,23 @@ def on_forward_input_choice(self, mutable, tensor_list):
         mask = self._get_decision(mutable)
         assert len(mask) == mutable.n_candidates, \
             "Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)
-        out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
+        out, mask = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
         return self._tensor_reduction(mutable.reduction, out), mask

     def _select_with_mask(self, map_fn, candidates, mask):
Review comment: the new implementation of this function has complex logic; please add a docstring for this function.
if "BoolTensor" in mask.type(): | ||
if (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], bool)) or \ | ||
(isinstance(mask, np.ndarray) and mask.dtype == np.bool) or \ | ||
"BoolTensor" in mask.type(): | ||
out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m] | ||
elif "FloatTensor" in mask.type(): | ||
elif (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], (float, int))) or \ | ||
Review comment: it can be
+                (isinstance(mask, np.ndarray) and mask.dtype in (np.float32, np.float64, np.int32, np.int64)) or \
+                "FloatTensor" in mask.type():
             out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m]
         else:
-            raise ValueError("Unrecognized mask")
-        return out
+            raise ValueError("Unrecognized mask '%s'" % mask)
+        if not torch.is_tensor(mask):
+            mask = torch.tensor(mask)  # pylint: disable=not-callable
+        return out, mask

     def _tensor_reduction(self, reduction_type, tensor_list):
         if reduction_type == "none":
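Since `_select_with_mask` now accepts Python lists and NumPy arrays in addition to tensors, a brief standalone sketch of the mask semantics may help (this is not the NNI implementation): a boolean mask selects candidates, while a float mask selects and scales them.

```python
import torch

candidates = [torch.ones(2) * i for i in range(3)]  # stand-ins for the outputs of 3 candidate ops

bool_mask = [False, True, True]
selected = [c for c, m in zip(candidates, bool_mask) if m]        # keeps candidates 1 and 2

float_mask = [0.0, 0.3, 0.7]
weighted = [c * m for c, m in zip(candidates, float_mask) if m]   # 0.3 * c1 and 0.7 * c2

print(selected)   # [tensor([1., 1.]), tensor([2., 2.])]
print(weighted)   # [tensor([0.3000, 0.3000]), tensor([1.4000, 1.4000])]
```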
@@ -237,3 +254,37 @@ def _get_decision(self, mutable):
         result = self._cache[mutable.key]
         logger.debug("Decision %s: %s", mutable.key, result)
         return result
+
+    def _convert_mutable_decision(self, mutable, sampled):
+        # Assert the existence of mutable.key in returned architecture.
+        # Also check if there is anything extra.
+        multihot_list = to_list(sampled)
+        converted = None
+        # If it's a boolean array, we can do optimization.
+        if all([t == 0 or t == 1 for t in multihot_list]):
+            if isinstance(mutable, LayerChoice):
+                assert len(multihot_list) == len(mutable), \
+                    "Results returned from 'sample_final()' (%s: %s) either too short or too long." \
+                    % (mutable.key, multihot_list)
+                # check if all modules have different names and they indeed have names
+                if len(set(mutable.names)) == len(mutable) and not all(d.isdigit() for d in mutable.names):
+                    converted = [name for i, name in enumerate(mutable.names) if multihot_list[i]]
+                else:
+                    converted = [i for i in range(len(multihot_list)) if multihot_list[i]]
+            if isinstance(mutable, InputChoice):
+                assert len(multihot_list) == mutable.n_candidates, \
+                    "Results returned from 'sample_final()' (%s: %s) either too short or too long." \
+                    % (mutable.key, multihot_list)
+                # check if all input candidates have different names
+                if len(set(mutable.choose_from)) == mutable.n_candidates:
+                    converted = [name for i, name in enumerate(mutable.choose_from) if multihot_list[i]]
+                else:
+                    converted = [i for i in range(len(multihot_list)) if multihot_list[i]]
+        if converted is not None:
+            # if only one element, then remove the bracket
+            if len(converted) == 1:
+                converted = converted[0]
+        else:
+            # do nothing
+            converted = multihot_list
+        return converted
Review comment: what is the meaning of 1, 2? The index?

Review comment: so it could be either index or name? When is it an index and when is it a name?
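Regarding the two questions above: judging from the code in this hunk, names are exported when every candidate has a distinct, non-numeric name, plain indices are used otherwise, and a single selection is unwrapped from its list. A standalone sketch with invented candidate names:

```python
# Standalone sketch mirroring _convert_mutable_decision; names and keys are hypothetical.
def convert_decision(multihot, names):
    # Use names when they are unique and not purely numeric, otherwise fall back to indices.
    if len(set(names)) == len(names) and not all(n.isdigit() for n in names):
        chosen = [name for name, hot in zip(names, multihot) if hot]
    else:
        chosen = [i for i, hot in enumerate(multihot) if hot]
    return chosen[0] if len(chosen) == 1 else chosen

print(convert_decision([False, True, False], ["conv3x3", "conv5x5", "maxpool"]))  # 'conv5x5'
print(convert_decision([True, False, True, False], ["0", "1", "2", "3"]))         # [0, 2]
```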