I was using torch.func in PyTorch 2.0 to compute the Hessian-vector product (hvp) of a neural network.
I first used torch.func.functional_call to define a functional version of the model, and then used torch.func.jvp and torch.func.grad to compute the hvp.
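Roughly, the setup looks like this (a simplified sketch; the small Linear model, the MSE loss, and the random data are stand-ins for my actual network and objective):

```python
import torch
from torch.func import functional_call, grad, jvp

model = torch.nn.Linear(4, 1)           # stand-in for the real network
params = dict(model.named_parameters())
x, y = torch.randn(8, 4), torch.randn(8, 1)

def loss_fn(p):
    # functional version of the model: parameters are explicit inputs
    out = functional_call(model, p, (x,))
    return torch.nn.functional.mse_loss(out, y)

# tangent vector with the same structure as the parameters
v = {k: torch.randn_like(t) for k, t in params.items()}

# forward-over-reverse hvp: jvp of the gradient function
_, hvp_out = jvp(grad(loss_fn), (params,), (v,))
```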
The above works on a single GPU without parallel processing. However, when I wrapped the model with DistributedDataParallel (DDP), it gave the following error:
*** RuntimeError: During a grad (vjp, jvp, grad, etc) transform, the function provided attempted to call in-place operation (aten::copy_) that would mutate a captured Tensor. This is not supported; please rewrite the function being transformed to explicitly accept the mutated Tensor(s) as inputs.
I am confused by this error: if there really were such in-place operations (and I couldn't find any in my model.forward() code), I'd expect the error to occur regardless of DDP. Given this inconsistent behaviour, can I still trust the hvp result computed without DDP?
My torch version is 2.0.0.dev20230119+cu117.
@XuchanBao do you have a script that reproduces the problem that we could take a look at?
DistributedDataParallel does some extra things to the model, so it's likely that your non-DDP hvp result is correct and that this extra machinery is what is interacting badly with vmap.
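If it helps, something along these lines could be a starting point for a repro (an untested guess at your setup, not your actual code; it uses the gloo backend with world_size=1 just to stand up a process group in a single process):

```python
import os
import torch
import torch.distributed as dist
from torch.func import functional_call, grad, jvp
from torch.nn.parallel import DistributedDataParallel as DDP

# single-process process group, just enough to construct DDP
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(torch.nn.Linear(4, 1))      # stand-in model wrapped in DDP
params = dict(model.named_parameters())  # keys get the "module." prefix
x, y = torch.randn(8, 4), torch.randn(8, 1)

def loss_fn(p):
    out = functional_call(model, p, (x,))
    return torch.nn.functional.mse_loss(out, y)

v = {k: torch.randn_like(t) for k, t in params.items()}
# expected to raise the aten::copy_ RuntimeError under DDP
_, hvp_out = jvp(grad(loss_fn), (params,), (v,))

dist.destroy_process_group()
```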