diff --git a/examples/nas/cream/lib/models/structures/supernet.py b/examples/nas/cream/lib/models/structures/supernet.py index ea09377eb5..f5a84ae1ea 100644 --- a/examples/nas/cream/lib/models/structures/supernet.py +++ b/examples/nas/cream/lib/models/structures/supernet.py @@ -131,12 +131,12 @@ def rand_parameters(self, architecture, meta=False): yield param if not meta: - for layer, layer_arch in zip(self.blocks, architecture): - for blocks, arch in zip(layer, layer_arch): - if arch == -1: + for choice_blocks, choice_name in zip(self.blocks, architecture): + choice_sample = architecture[choice_name] + for block, arch in zip(choice_blocks, choice_sample): + if not arch: continue - for name, param in blocks[arch].named_parameters( - recurse=True): + for name, param in block.named_parameters(recurse=True): yield param diff --git a/nni/algorithms/nas/pytorch/cream/trainer.py b/nni/algorithms/nas/pytorch/cream/trainer.py index 2ac9f53a25..50830ce64b 100644 --- a/nni/algorithms/nas/pytorch/cream/trainer.py +++ b/nni/algorithms/nas/pytorch/cream/trainer.py @@ -165,7 +165,7 @@ def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flop (val_prec1, prec1, flops, - self.current_teacher_arch, + self.current_student_arch, training_data, torch.nn.functional.softmax( features, @@ -174,8 +174,6 @@ def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flop self.prioritized_board, reverse=True) if len(self.prioritized_board) > self.pool_size: - self.prioritized_board = sorted( - self.prioritized_board, reverse=True) del self.prioritized_board[-1] # only update student network weights