This will return the gradient of the whole model. However, I only want the gradients of the last layers (e.g. the second-to-last layer). Although this method can also obtain the required gradients, it causes a lot of unnecessary overhead. Is there any way to turn off requires_grad for all of the earlier layers? Thanks for your answer!
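(For context, a minimal sketch of what "turning off requires_grad for the earlier layers" looks like in plain eager-mode PyTorch; the small Sequential model below is a hypothetical stand-in for the user's net, not code from the thread.)

```python
import torch.nn as nn

# hypothetical stand-in model; in the real code `net` is the user's network
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

for p in net.parameters():
    p.requires_grad_(False)       # freeze every layer...
for p in net[-1].parameters():
    p.requires_grad_(True)        # ...then re-enable only the last layer
```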
functorch.grad computes gradients with respect to the first argument you pass it. Right now that argument is params (all of the model's parameters); the solution is to pass it only the parameters you want gradients of.
Here is some pseudocode:
```python
from functorch import make_functional_with_buffers, vmap, grad

fmodel, params, buffers = make_functional_with_buffers(net, disable_autograd_tracking=True)

def compute_loss_stateless_model(last_layers_params, first_layers_params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    # pseudocode: we need to put the params back together into a single params list
    # that fmodel can understand
    params = (*first_layers_params, *last_layers_params)
    predictions = fmodel(params, buffers, batch)
    loss = criterion(predictions, targets)
    return loss

ft_compute_grad = grad(compute_loss_stateless_model)

# pseudocode: we need to split the params we want to compute gradients of
# from the params we don't want to compute gradients of.
first_layers_params, last_layers_params = partition(params)

gradient = ft_compute_grad(last_layers_params, first_layers_params, buffers, train_poi_set[0][0].cuda(), torch.tensor(train_poi_set[0][1]).cuda())
```
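To make the pattern above concrete, here is a self-contained toy version of the same idea. Everything in it is illustrative rather than taken from the thread: the small Sequential model, the CPU tensors, and the manual params[:-2] / params[-2:] split stand in for the pseudocode's partition helper, and only the final Linear layer's gradients are computed.

```python
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers, grad

# Hypothetical toy model and data, just to illustrate the parameter-splitting idea.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
criterion = nn.CrossEntropyLoss()

fmodel, params, buffers = make_functional_with_buffers(net, disable_autograd_tracking=True)

# Split the flat parameter tuple: everything except the final Linear's weight and
# bias is treated as frozen; only the last two tensors will receive gradients.
first_layers_params, last_layers_params = params[:-2], params[-2:]

def compute_loss(last_layers_params, first_layers_params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    # Stitch the two pieces back into the full parameter tuple fmodel expects.
    all_params = (*first_layers_params, *last_layers_params)
    predictions = fmodel(all_params, buffers, batch)
    return criterion(predictions, targets)

# grad differentiates only with respect to the first argument, so the earlier
# layers' parameters are treated as constants.
compute_last_layers_grad = grad(compute_loss)

sample, target = torch.randn(4), torch.tensor(1)
last_layer_grads = compute_last_layers_grad(last_layers_params, first_layers_params, buffers, sample, target)
print([g.shape for g in last_layer_grads])  # grads for the final Linear's weight and bias
```

Since grad only differentiates with respect to its first argument, the parameters passed in the second argument are treated as constants, and no parameter gradients are computed for the earlier layers, which is the overhead the question is trying to avoid.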