modify optim_context of dartsddp
pprp committed Aug 3, 2022
1 parent 1eeb7fa commit c72aa30
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions mmrazor/models/algorithms/nas/darts.py
@@ -145,8 +145,7 @@ def train_step(self, data: List[dict],
                 with optim_wrapper['mutator'].optim_context(self):
                     optim_wrapper['mutator'].zero_grad()
                     mutator_log_vars = self._unrolled_backward(
-                        mutator_data, supernet_data,
-                        optim_wrapper['architecture'])
+                        mutator_data, supernet_data, optim_wrapper)
                 optim_wrapper['mutator'].step()
                 log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
             else:
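
After this hunk, _unrolled_backward is handed the whole optimizer-wrapper dict instead of only the architecture wrapper, so it can address both sub-optimizers by key. A minimal sketch of the assumed two-optimizer setup (module names, sizes, and hyper-parameters below are illustrative, not taken from this commit):

import torch
from torch import nn
from mmengine.optim import OptimWrapper, OptimWrapperDict

# Stand-ins: `supernet` plays the role of self.architecture (weights w),
# `alphas` plays the role of the mutator's architecture parameters.
supernet = nn.Linear(8, 8)
alphas = nn.Parameter(torch.zeros(4))

optim_wrapper = OptimWrapperDict(
    architecture=OptimWrapper(
        torch.optim.SGD(supernet.parameters(), lr=0.025, momentum=0.9,
                        weight_decay=3e-4)),
    mutator=OptimWrapper(torch.optim.Adam([alphas], lr=3e-4)))

# _unrolled_backward can now pick each sub-wrapper by key, e.g. to read
# the hyper-parameters of the virtual step (see the next hunk):
lr = optim_wrapper['architecture'].param_groups[0]['lr']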
@@ -187,11 +186,12 @@ def _unrolled_backward(self, mutator_data, supernet_data, optim_wrapper):
         backup_params = copy.deepcopy(tuple(self.architecture.parameters()))

         # Do virtual step on training data
-        lr = optim_wrapper.param_groups[0]['lr']
-        momentum = optim_wrapper.param_groups[0]['momentum']
-        weight_decay = optim_wrapper.param_groups[0]['weight_decay']
+        lr = optim_wrapper['architecture'].param_groups[0]['lr']
+        momentum = optim_wrapper['architecture'].param_groups[0]['momentum']
+        weight_decay = optim_wrapper['architecture'].param_groups[0][
+            'weight_decay']
         self._compute_virtual_model(supernet_data, lr, momentum, weight_decay,
-                                    optim_wrapper)
+                                    optim_wrapper['architecture'])

         # Calculate unrolled loss on validation data
         # Keep gradients for model here for compute hessian
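
The lr, momentum, and weight_decay read here feed the DARTS "virtual step": one simulated SGD update of the supernet weights on training data before the architecture gradient is evaluated on validation data. A rough sketch of what _compute_virtual_model is assumed to do (variable names are illustrative):

import torch

def virtual_step(weights, grads, momentum_buffers, lr, momentum, weight_decay):
    # w' = w - lr * (momentum * v + dL_train/dw + weight_decay * w)
    with torch.no_grad():
        for w, dw, v in zip(weights, grads, momentum_buffers):
            m = momentum * v if v is not None else 0.
            w -= lr * (m + dw + weight_decay * w)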
@@ -205,15 +205,15 @@ def _unrolled_backward(self, mutator_data, supernet_data, optim_wrapper):
             # the gradients of mutator loss. The gradients of model and arch
             # can directly obtained. For more information, please refer to
             # https://github.com/open-mmlab/mmengine/blob/main/mmengine/optim/optimizer/optimizer_wrapper.py
-            optim_wrapper.backward(mutator_losses)
+            optim_wrapper['mutator'].backward(mutator_losses)
         d_model = [param.grad for param in self.architecture.parameters()]
         d_arch = [param.grad for param in self.mutator.parameters()]

         # compute hessian and final gradients
         hessian = self._compute_hessian(backup_params, d_model, supernet_data,
-                                        optim_wrapper)
+                                        optim_wrapper['architecture'])

-        w_arch = tuple(self.architecture.parameters())
+        w_arch = tuple(self.mutator.parameters())

         with torch.no_grad():
             for param, d, h in zip(w_arch, d_arch, hessian):
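
The loop closing this hunk is assumed to write the second-order DARTS gradient back onto the architecture parameters, i.e. grad_alpha = dL_val/dalpha - lr * H, where H approximates the mixed second derivative times the validation gradient. Switching w_arch from self.architecture.parameters() to self.mutator.parameters() makes it line up with d_arch, so the update lands on the architecture parameters as intended. A compact restatement of that update (sketch under the assumption above):

import torch

def apply_unrolled_grads(arch_params, d_arch, hessian, lr):
    # grad_alpha = dL_val/dalpha - lr * H, written into .grad so the
    # mutator optimizer's step() consumes it.
    with torch.no_grad():
        for param, d, h in zip(arch_params, d_arch, hessian):
            param.grad = d - lr * h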
@@ -342,9 +342,10 @@ def train_step(self, data: List[dict],

             # Update the parameter of mutator
             if self.module.unroll:
-                optim_wrapper['mutator'].zero_grad()
-                mutator_log_vars = self._unrolled_backward(
-                    mutator_data, supernet_data, optim_wrapper)
+                with optim_wrapper['mutator'].optim_context(self):
+                    optim_wrapper['mutator'].zero_grad()
+                    mutator_log_vars = self._unrolled_backward(
+                        mutator_data, supernet_data, optim_wrapper)
                 optim_wrapper['mutator'].step()
                 log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
             else:
@@ -370,6 +371,8 @@ def train_step(self, data: List[dict],

             supernet_losses, supernet_log_vars = self.module.parse_losses(
                 supernet_loss)
+
+            # import ipdb; ipdb.set_trace()
             optim_wrapper['architecture'].update_params(supernet_losses)
             log_vars.update(add_prefix(supernet_log_vars, 'supernet'))

@@ -411,7 +414,6 @@ def _unrolled_backward(self, mutator_data, supernet_data, optim_wrapper):
             # can directly obtained. For more information, please refer to
             # https://github.com/open-mmlab/mmengine/blob/main/mmengine/optim/optimizer/optimizer_wrapper.py
             optim_wrapper['mutator'].backward(mutator_losses)
-
         d_model = [
             param.grad for param in self.module.architecture.parameters()
         ]
@@ -493,6 +495,7 @@ def _compute_hessian(self, backup_params, dw, supernet_data,
                 optim_wrapper.backward(supernet_loss)
             dalpha = [param.grad for param in self.module.mutator.parameters()]
             dalphas.append(dalpha)
+
         # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
         dalpha_pos, dalpha_neg = dalphas
         hessian = [(p - n) / (2. * eps)
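
The (p - n) / (2. * eps) expression is the standard DARTS finite-difference approximation of the Hessian-vector product: perturb the supernet weights by +/- eps along dw, take the architecture gradient of the training loss at both points, and divide the difference by 2 * eps (the original paper uses eps = 0.01 / ||dw||). A self-contained sketch of that procedure, under the assumption that loss_fn rebuilds the training loss from the current weights (helper name and eps handling are illustrative, not the repository implementation):

import torch

def hessian_vector_product(loss_fn, weights, arch_params, dw):
    eps = 0.01 / torch.cat([d.reshape(-1) for d in dw]).norm()

    def arch_grad(sign):
        with torch.no_grad():
            for w, d in zip(weights, dw):
                w += sign * eps * d          # perturb w by +/- eps * dw
        grads = torch.autograd.grad(loss_fn(), arch_params)
        with torch.no_grad():
            for w, d in zip(weights, dw):
                w -= sign * eps * d          # restore the original w
        return grads

    dalpha_pos, dalpha_neg = arch_grad(+1.), arch_grad(-1.)
    return [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]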