diff --git a/mmrazor/models/algorithms/nas/darts.py b/mmrazor/models/algorithms/nas/darts.py
index c17450826..352dec8f1 100644
--- a/mmrazor/models/algorithms/nas/darts.py
+++ b/mmrazor/models/algorithms/nas/darts.py
@@ -140,9 +140,6 @@ def train_step(self, data: List[dict],
             supernet_data, mutator_data = data
 
             log_vars = dict()
-            # TODO support unroll
-
-            # TODO 1. adapter to 2.0; 2. ddp
 
             # part1: update the parameter of mutator
             if self.unroll:
@@ -188,16 +185,16 @@ def train_step(self, data: List[dict],
             optim_wrapper.update_params(parsed_losses)
         return log_vars
 
-    def _unrolled_backward(self, mutator_data, supernet_data, optimizer):
+    def _unrolled_backward(self, mutator_data, supernet_data, optim_wrapper):
         """Compute unrolled loss and backward its gradients."""
         backup_params = copy.deepcopy(tuple(self.architecture.parameters()))
 
         # do virtual step on training data
-        lr = optimizer.param_groups[0]['lr']
-        momentum = optimizer.param_groups[0]['momentum']
-        weight_decay = optimizer.param_groups[0]['weight_decay']
+        lr = optim_wrapper.param_groups[0]['lr']
+        momentum = optim_wrapper.param_groups[0]['momentum']
+        weight_decay = optim_wrapper.param_groups[0]['weight_decay']
         self._compute_virtual_model(supernet_data, lr, momentum, weight_decay,
-                                    optimizer)
+                                    optim_wrapper)
 
         # calculate unrolled loss on validation data
         # keep gradients for model here for compute hessian
@@ -224,7 +221,7 @@ def _unrolled_backward(self, mutator_data, supernet_data, optimizer):
         return mutator_log_vars
 
     def _compute_virtual_model(self, supernet_data, lr, momentum, weight_decay,
-                               optimizer):
+                               optim_wrapper):
         """Compute unrolled weights w`"""
         # don't need zero_grad, using autograd to calculate gradients
         supernet_batch_inputs, supernet_data_samples = \
@@ -236,17 +233,18 @@ def _compute_virtual_model(self, supernet_data, lr, momentum, weight_decay,
                                         self.architecture.parameters())
         with torch.no_grad():
             for w, g in zip(self.architecture.parameters(), gradients):
-                m = optimizer.state[w].get('momentum_buffer', 0.)
+                m = optim_wrapper.optimizer.state[w].get('momentum_buffer', 0.)
                 w = w - lr * (momentum * m + g + weight_decay * w)
 
     def _restore_weights(self, backup_params):
+        """Restore weights from backup params."""
         with torch.no_grad():
             for param, backup in zip(self.architecture.parameters(),
                                      backup_params):
                 param.copy_(backup)
 
     def _compute_hessian(self, backup_params, dw, supernet_data) -> List:
-        """
+        """Compute hessian metric.
         dw = dw` { L_val(w`, alpha) }
         w+ = w + eps * dw
         w- = w - eps * dw
@@ -286,6 +284,7 @@ def _compute_hessian(self, backup_params, dw, supernet_data) -> List:
 
 @MODEL_WRAPPERS.register_module()
 class DartsDDP(MMDistributedDataParallel):
+    """DDP wrapper for Darts that rewrites ``train_step`` of MMDDP."""
 
     def __init__(self,
                  *,
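For context on the unrolled path this patch touches: `_compute_hessian` approximates the second-order DARTS term with the finite-difference formula shown in its docstring (eps = 0.01 / ||dw||, hessian = (dalpha L_trn(w+) - dalpha L_trn(w-)) / (2*eps)). The following is a minimal sketch of that approximation in plain PyTorch, not the mmrazor implementation; `model_params`, `arch_params`, and `train_loss_fn` are hypothetical stand-ins for `self.architecture.parameters()`, the mutator's architecture parameters, and a closure over `supernet_data`.

```python
import torch


def compute_hessian(model_params, arch_params, dw, train_loss_fn, eps_base=0.01):
    """Finite-difference approximation of the DARTS second-order term.

    dw holds the gradients of the validation loss w.r.t. the (virtual)
    weights; train_loss_fn re-evaluates the training loss with the current
    weights so that d(L_trn)/d(alpha) can be taken at w+ and w-.
    """
    norm = torch.cat([d.reshape(-1) for d in dw]).norm()
    eps = eps_base / (norm + 1e-12)

    def perturb_and_grad(sign):
        # Shift every weight by +/- eps * dw, then take d(L_trn)/d(alpha).
        with torch.no_grad():
            for p, d in zip(model_params, dw):
                p.add_(sign * eps * d)
        grads = torch.autograd.grad(train_loss_fn(), arch_params)
        # Undo the perturbation so the original weights are restored.
        with torch.no_grad():
            for p, d in zip(model_params, dw):
                p.sub_(sign * eps * d)
        return grads

    dalpha_pos = perturb_and_grad(+1.0)  # dalpha L_trn(w+, alpha)
    dalpha_neg = perturb_and_grad(-1.0)  # dalpha L_trn(w-, alpha)
    return [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
```

In the patched method the same restore step is handled by `_restore_weights(backup_params)` instead of undoing the perturbation in place, and the momentum buffer for the virtual step is now read through `optim_wrapper.optimizer.state`, matching the mmengine OptimWrapper API.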