Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Bug fix for Cream NAS #3498

Merged
merged 1 commit into from
Apr 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions examples/nas/cream/lib/models/structures/supernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ def rand_parameters(self, architecture, meta=False):
yield param

if not meta:
for layer, layer_arch in zip(self.blocks, architecture):
for blocks, arch in zip(layer, layer_arch):
if arch == -1:
for choice_blocks, choice_name in zip(self.blocks, architecture):
choice_sample = architecture[choice_name]
for block, arch in zip(choice_blocks, choice_sample):
if not arch:
continue
for name, param in blocks[arch].named_parameters(
recurse=True):
for name, param in block.named_parameters(recurse=True):
yield param


Expand Down
4 changes: 1 addition & 3 deletions nni/algorithms/nas/pytorch/cream/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flop
(val_prec1,
prec1,
flops,
self.current_teacher_arch,
self.current_student_arch,
training_data,
torch.nn.functional.softmax(
features,
Expand All @@ -174,8 +174,6 @@ def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flop
self.prioritized_board, reverse=True)

if len(self.prioritized_board) > self.pool_size:
self.prioritized_board = sorted(
self.prioritized_board, reverse=True)
del self.prioritized_board[-1]

# only update student network weights
Expand Down