diff --git a/mmrazor/models/algorithms/nas/darts.py b/mmrazor/models/algorithms/nas/darts.py
index de916bbff..6f321854c 100644
--- a/mmrazor/models/algorithms/nas/darts.py
+++ b/mmrazor/models/algorithms/nas/darts.py
@@ -145,8 +145,7 @@ def train_step(self, data: List[dict],
                 with optim_wrapper['mutator'].optim_context(self):
                     optim_wrapper['mutator'].zero_grad()
                     mutator_log_vars = self._unrolled_backward(
-                        mutator_data, supernet_data,
-                        optim_wrapper['architecture'])
+                        mutator_data, supernet_data, optim_wrapper)
                 optim_wrapper['mutator'].step()
                 log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
             else:
@@ -187,11 +186,12 @@ def _unrolled_backward(self, mutator_data, supernet_data, optim_wrapper):
         backup_params = copy.deepcopy(tuple(self.architecture.parameters()))
 
         # Do virtual step on training data
-        lr = optim_wrapper.param_groups[0]['lr']
-        momentum = optim_wrapper.param_groups[0]['momentum']
-        weight_decay = optim_wrapper.param_groups[0]['weight_decay']
+        lr = optim_wrapper['architecture'].param_groups[0]['lr']
+        momentum = optim_wrapper['architecture'].param_groups[0]['momentum']
+        weight_decay = optim_wrapper['architecture'].param_groups[0][
+            'weight_decay']
         self._compute_virtual_model(supernet_data, lr, momentum, weight_decay,
-                                    optim_wrapper)
+                                    optim_wrapper['architecture'])
 
         # Calculate unrolled loss on validation data
         # Keep gradients for model here for compute hessian
@@ -205,15 +205,15 @@ def _unrolled_backward(self, mutator_data, supernet_data, optim_wrapper):
         # the gradients of mutator loss. The gradients of model and arch
         # can directly obtained. For more information, please refer to
         # https://github.com/open-mmlab/mmengine/blob/main/mmengine/optim/optimizer/optimizer_wrapper.py
-        optim_wrapper.backward(mutator_losses)
+        optim_wrapper['mutator'].backward(mutator_losses)
         d_model = [param.grad for param in self.architecture.parameters()]
         d_arch = [param.grad for param in self.mutator.parameters()]
 
         # compute hessian and final gradients
         hessian = self._compute_hessian(backup_params, d_model, supernet_data,
-                                        optim_wrapper)
+                                        optim_wrapper['architecture'])
 
-        w_arch = tuple(self.architecture.parameters())
+        w_arch = tuple(self.mutator.parameters())
 
         with torch.no_grad():
             for param, d, h in zip(w_arch, d_arch, hessian):
@@ -342,9 +342,10 @@ def train_step(self, data: List[dict],
 
             # Update the parameter of mutator
             if self.module.unroll:
-                optim_wrapper['mutator'].zero_grad()
-                mutator_log_vars = self._unrolled_backward(
-                    mutator_data, supernet_data, optim_wrapper)
+                with optim_wrapper['mutator'].optim_context(self):
+                    optim_wrapper['mutator'].zero_grad()
+                    mutator_log_vars = self._unrolled_backward(
+                        mutator_data, supernet_data, optim_wrapper)
                 optim_wrapper['mutator'].step()
                 log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
             else:
@@ -411,7 +412,6 @@ def _unrolled_backward(self, mutator_data, supernet_data, optim_wrapper):
         # can directly obtained. For more information, please refer to
         # https://github.com/open-mmlab/mmengine/blob/main/mmengine/optim/optimizer/optimizer_wrapper.py
         optim_wrapper['mutator'].backward(mutator_losses)
-
         d_model = [
             param.grad for param in self.module.architecture.parameters()
         ]
@@ -493,6 +493,7 @@ def _compute_hessian(self, backup_params, dw, supernet_data,
             optim_wrapper.backward(supernet_loss)
             dalpha = [param.grad for param in self.module.mutator.parameters()]
             dalphas.append(dalpha)
+        # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
         dalpha_pos, dalpha_neg = dalphas
         hessian = [(p - n) / (2. * eps)
                    for p, n in zip(dalpha_pos, dalpha_neg)]
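
Note (illustration only, not part of the patch): the hunks above thread the whole OptimWrapperDict into _unrolled_backward and index optim_wrapper['architecture'] or optim_wrapper['mutator'] where each sub-wrapper is needed, and they make w_arch iterate self.mutator.parameters() so that the correction d - lr * h lands on the architecture parameters rather than the supernet weights. The sketch below restates the second-order DARTS rule that _unrolled_backward and _compute_hessian implement, in plain PyTorch and without the virtual-step/restore bookkeeping of the real code; model_params, arch_params, train_loss_fn, val_loss_fn and lr are hypothetical stand-ins.

import torch


def darts_unrolled_arch_grads(model_params, arch_params, train_loss_fn,
                              val_loss_fn, lr, eps_scale=0.01):
    """Second-order DARTS gradient for the architecture parameters (sketch)."""
    # Gradients of the validation loss w.r.t. weights (d_model) and arch (d_arch).
    val_loss = val_loss_fn()
    d_model = torch.autograd.grad(val_loss, model_params, retain_graph=True)
    d_arch = torch.autograd.grad(val_loss, arch_params)

    # eps = 0.01 / ||dL_val/dw||, the scaling used in the DARTS paper.
    eps = eps_scale / torch.cat([d.reshape(-1) for d in d_model]).norm()

    # Finite differences: dalpha L_trn(w + eps*dw) and dalpha L_trn(w - eps*dw).
    dalphas = []
    for sign in (1., -1.):
        with torch.no_grad():
            for p, d in zip(model_params, d_model):
                p += sign * eps * d  # perturb the supernet weights
        dalphas.append(torch.autograd.grad(train_loss_fn(), arch_params))
        with torch.no_grad():
            for p, d in zip(model_params, d_model):
                p -= sign * eps * d  # restore them before the next pass

    # dalpha { L_trn(w+) }, dalpha { L_trn(w-) }
    dalpha_pos, dalpha_neg = dalphas
    hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]

    # Final gradient applied to the mutator parameters: dalpha L_val - lr * hessian.
    return [d - lr * h for d, h in zip(d_arch, hessian)]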