Commit
update darts algorithm [untested]
pprp committed Aug 1, 2022
1 parent c6bce2d commit 567125c
Showing 1 changed file with 99 additions and 1 deletion: mmrazor/models/algorithms/nas/darts.py
@@ -336,7 +336,7 @@ def train_step(self, data: List[dict],
# if unroll is True
# TODO does unroll support mixed precision?
optim_wrapper['mutator'].zero_grad()
- mutator_log_vars = self.module._unrolled_backward(
+ mutator_log_vars = self._unrolled_backward(
mutator_data, supernet_data, optim_wrapper['architecture'])
optim_wrapper['mutator'].step()
log_vars.update(add_prefix(mutator_log_vars, 'mutator'))
@@ -375,3 +375,101 @@ def train_step(self, data: List[dict],
optim_wrapper.update_params(parsed_losses)

return log_vars

def _unrolled_backward(self, mutator_data, supernet_data, optim_wrapper):
"""Compute unrolled loss and backward its gradients."""
backup_params = copy.deepcopy(
tuple(self.module.architecture.parameters()))

# do virtual step on training data
lr = optim_wrapper.param_groups[0]['lr']
momentum = optim_wrapper.param_groups[0]['momentum']
weight_decay = optim_wrapper.param_groups[0]['weight_decay']
self._compute_virtual_model(supernet_data, lr, momentum, weight_decay,
optim_wrapper)
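# the supernet parameters now hold the virtual weights
# w' = w - lr * (momentum * m + dL_train/dw + weight_decay * w),
# so the mutator gradients below are evaluated at w'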

# calculate unrolled loss on validation data
# keep gradients w.r.t. the model weights; they are needed to compute the hessian
mutator_batch_inputs, mutator_data_samples = \
self.module.data_preprocessor(mutator_data, True)
mutator_loss = self(
mutator_batch_inputs, mutator_data_samples, mode='loss')
mutator_losses, mutator_log_vars = self.module.parse_losses(
mutator_loss)

w_model, w_arch = tuple(self.module.architecture.parameters()), tuple(
self.module.mutator.parameters())
w_grads = torch.autograd.grad(mutator_losses, w_model + w_arch)
d_model, d_arch = w_grads[:len(w_model)], w_grads[len(w_model):]

# compute hessian and final gradients
hessian = self._compute_hessian(backup_params, d_model, supernet_data)
with torch.no_grad():
for param, d, h in zip(w_arch, d_arch, hessian):
# gradient = dalpha - lr * hessian
param.grad = d - lr * h

# restore weights
self._restore_weights(backup_params)
return mutator_log_vars

def _compute_virtual_model(self, supernet_data, lr, momentum, weight_decay,
optim_wrapper):
"""Compute unrolled weights w`"""
# don't need zero_grad, using autograd to calculate gradients
supernet_batch_inputs, supernet_data_samples = \
self.module.data_preprocessor(supernet_data, True)
supernet_loss = self(
supernet_batch_inputs, supernet_data_samples, mode='loss')
supernet_loss, _ = self.module.parse_losses(supernet_loss)
gradients = torch.autograd.grad(supernet_loss,
self.module.architecture.parameters())
with torch.no_grad():
for w, g in zip(self.module.architecture.parameters(), gradients):
m = optim_wrapper.optimizer.state[w].get('momentum_buffer', 0.)
# update the parameter in place so the supernet actually holds w'
w -= lr * (momentum * m + g + weight_decay * w)

def _restore_weights(self, backup_params):
"""restore weight from backup params."""
with torch.no_grad():
for param, backup in zip(self.module.architecture.parameters(),
backup_params):
param.copy_(backup)

def _compute_hessian(self, backup_params, dw, supernet_data) -> List:
"""compute hession metric
dw = dw` { L_val(w`, alpha) }
w+ = w + eps * dw
w- = w - eps * dw
hessian = (dalpha { L_trn(w+, alpha) } \
- dalpha { L_trn(w-, alpha) }) / (2*eps)
eps = 0.01 / ||dw||
"""
self._restore_weights(backup_params)
norm = torch.cat([w.view(-1) for w in dw]).norm()
eps = 0.01 / norm
if norm < 1E-8:
print('In computing hessian, norm is smaller than 1E-8, '
f'which causes eps to be {eps.item():.6f}.')

dalphas = []
for e in [eps, -2. * eps]:
# w+ = w + eps*dw`, w- = w - eps*dw`
with torch.no_grad():
for p, d in zip(self.module.architecture.parameters(), dw):
p += e * d

supernet_batch_inputs, supernet_data_samples = \
self.module.data_preprocessor(supernet_data, True)
supernet_loss = self(
supernet_batch_inputs, supernet_data_samples, mode='loss')
supernet_loss, _ = self.module.parse_losses(supernet_loss)
dalphas.append(
torch.autograd.grad(supernet_loss,
tuple(self.module.mutator.parameters())))
# dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
dalpha_pos, dalpha_neg = dalphas
hessian = [(p - n) / (2. * eps)
for p, n in zip(dalpha_pos, dalpha_neg)]
return hessian
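For reference, the added methods follow the second-order approximation from the DARTS paper; below is a brief sketch of the math they implement, assuming the paper's step size \(\xi\) is taken to be the supernet learning rate lr read from the optimizer above:

\[
\nabla_\alpha \mathcal{L}_{val}\bigl(w - \xi\nabla_w \mathcal{L}_{train}(w,\alpha),\, \alpha\bigr)
\approx \nabla_\alpha \mathcal{L}_{val}(w',\alpha)
- \xi\, \nabla^2_{\alpha,w}\mathcal{L}_{train}(w,\alpha)\, \nabla_{w'}\mathcal{L}_{val}(w',\alpha)
\]

\[
\nabla^2_{\alpha,w}\mathcal{L}_{train}(w,\alpha)\, \nabla_{w'}\mathcal{L}_{val}(w',\alpha)
\approx \frac{\nabla_\alpha \mathcal{L}_{train}(w^+,\alpha) - \nabla_\alpha \mathcal{L}_{train}(w^-,\alpha)}{2\epsilon},
\qquad w^\pm = w \pm \epsilon\, \nabla_{w'}\mathcal{L}_{val}(w',\alpha),
\quad \epsilon = \frac{0.01}{\lVert \nabla_{w'}\mathcal{L}_{val}(w',\alpha) \rVert}
\]

Here _compute_virtual_model produces w' (with momentum and weight decay folded into the step), _compute_hessian evaluates the finite difference above, and param.grad = d - lr * h applies the correction term to the architecture parameters.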
