Extract controller from mutator to make offline decisions #1758
Conversation
```diff
@@ -6,7 +6,7 @@
 class ENASLayer(mutables.MutableScope):

-    def __init__(self, key, num_prev_layers, in_filters, out_filters):
+    def __init__(self, key, prev_labels, in_filters, out_filters):
```
Please add a docstring.
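For reference, a docstring along these lines would cover the new signature (the wording is a sketch, not from the PR; the import path is assumed from the NNI examples):

```python
from nni.nas.pytorch import mutables

class ENASLayer(mutables.MutableScope):
    def __init__(self, key, prev_labels, in_filters, out_filters):
        """
        One ENAS search layer.

        Parameters
        ----------
        key : str
            Scope key identifying this layer.
        prev_labels : list of str
            Keys of earlier layers whose outputs this layer may take
            skip connections from (may be empty for the first layer).
        in_filters : int
            Number of input channels.
        out_filters : int
            Number of output channels.
        """
        super().__init__(key)
```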
```diff
-        key2module = dict()
-        for name, module in root.named_modules():
+        self.model = model
+        self._structured_mutables = StructuredMutableTreeNode(None)
```
Do I still have `named_mutables`?
No. It's now called `mutables.traverse()`. I can override `__iter__` to get `for mutable in mutables:`, but with `traverse`, I can do more.
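A minimal, self-contained sketch of that idea (illustrative only, not the PR's implementation): `traverse()` is a generator with room for extra parameters, while `__iter__` just delegates to it, so plain `for mutable in mutables:` keeps working.

```python
class StructuredMutableTreeNode:
    """Tree node holding a mutable; children are nested mutables/scopes."""

    def __init__(self, mutable):
        self.mutable = mutable
        self.children = []

    def add_child(self, mutable):
        node = StructuredMutableTreeNode(mutable)
        self.children.append(node)
        return node

    def traverse(self, order="pre"):
        # Depth-first walk. The `order` parameter is illustrative of why
        # traverse() "can do more" than bare iteration.
        if order == "pre" and self.mutable is not None:
            yield self.mutable
        for child in self.children:
            yield from child.traverse(order)
        if order == "post" and self.mutable is not None:
            yield self.mutable

    def __iter__(self):
        # `for mutable in mutables:` falls back to the default traversal.
        return self.traverse()
```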
examples/nas/darts/retrain.py (outdated)
```diff
     dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16)

     model = CNN(32, 3, 36, 10, args.layers, auxiliary=True)
+    archit = FixedArchitecture(model, "./checkpoints/epoch_0.json")
```
Please remove the return value `archit`.
Architecture might need to be sent to GPU.
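A sketch of the usage the reply has in mind (names are taken from the diff above; whether `FixedArchitecture` actually carries device-bound state is an assumption):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN(32, 3, 36, 10, args.layers, auxiliary=True)
archit = FixedArchitecture(model, "./checkpoints/epoch_0.json")
# If the mutator owns tensors of its own (e.g. cached choice masks), it has
# to follow the model to the target device, hence the return value is kept.
archit.to(device)
model.to(device)
```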
```diff
-        if num_prev_layers > 0:
-            self.skipconnect = mutables.InputChoice(num_prev_layers, n_selected=None, reduction="sum")
+        if len(prev_labels) > 0:
+            self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None, reduction="sum")
```
We should think about the case where the inputs are not all from a LayerChoice but from normal layers; how would `choose_from` be specified then?
Normal layers should be wrapped with a mutable scope to gain the power of keys. Maybe we can do some annotations.
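A sketch of what "wrapped with a mutable scope" could look like (illustrative only; `PlainConv` is a made-up name, and the import path is assumed from the NNI examples):

```python
import torch.nn as nn
from nni.nas.pytorch import mutables

class PlainConv(mutables.MutableScope):
    """An ordinary layer wrapped in a scope so it has a referable key."""

    def __init__(self, key, in_filters, out_filters):
        super().__init__(key)  # the key is what InputChoice's choose_from points at
        self.conv = nn.Conv2d(in_filters, out_filters, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)
```

With a key attached this way, a normal layer's output can be listed in `choose_from` alongside the outputs of real mutables.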
```diff
@@ -4,13 +4,13 @@
 class BaseTrainer(ABC):

     @abstractmethod
-    def train(self):
+    def train(self, validate=True):
```
does not make sense, please remove
You mean `validate=True` doesn't make sense? Why?
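For context, a sketch of what the flag would control in a concrete trainer (an illustration, not the PR's code):

```python
from abc import ABC, abstractmethod

class BaseTrainer(ABC):
    @abstractmethod
    def train(self, validate=True):
        raise NotImplementedError

class MyTrainer(BaseTrainer):
    def __init__(self, num_epochs=10):
        self.num_epochs = num_epochs

    def train(self, validate=True):
        for epoch in range(self.num_epochs):
            ...  # one epoch of weight/architecture training
            if validate:
                ...  # per-epoch validation; callers may skip it to save time
```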
```diff
+        Number of inputs to choose from.
+    choose_from : list of str
+        List of source keys to choose from. At least one of `choose_from` and `n_candidates` must be provided.
+        If `n_candidates` has a value but `choose_from` is None, it will be automatically treated as `n_candidates`
```
This doesn't make sense...
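One way to read the contested sentence is as the following fallback logic (a guess at the intent, not the PR's code):

```python
def resolve_choice_args(n_candidates=None, choose_from=None):
    # At least one of the two arguments must be given.
    if n_candidates is None and choose_from is None:
        raise ValueError("at least one of `n_candidates` and `choose_from` is required")
    if choose_from is None:
        # `n_candidates` given alone: fall back to that many anonymous inputs.
        choose_from = [None] * n_candidates
    if n_candidates is None:
        n_candidates = len(choose_from)
    if n_candidates != len(choose_from):
        raise ValueError("`n_candidates` must match the length of `choose_from`")
    return n_candidates, choose_from
```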