Fixing Tensor.backward's function signature #1376
Merged
Fixes #692

TL;DR: `Tensor.backward` has a different parameter order compared to PyTorch and also swaps `retain_graph` and `create_graph` in its internal function call.

See https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html for `backward`'s function signature in PyTorch:

`Tensor.backward(gradient=None, retain_graph=None, create_graph=False, inputs=None)`

The current TorchSharp version's function signature is:

`Tensor.backward(grad_tensors=null, create_graph=false, retain_graph=false, inputs=null)`

Note the difference in the ordering of `retain_graph` and `create_graph`. `Tensor.backward` is just a wrapper around `torch.autograd.backward`, which has the function signature:

`autograd.backward(tensors, grad_tensors=null, retain_graph=null, create_graph=false, inputs=null)`
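
To make the swap concrete, here is a self-contained C# sketch of the forwarding described above. The methods below are placeholders with simplified `string`/`bool` parameters, not the actual TorchSharp source; they only model the net behavior this PR describes.

```csharp
using System;

static class SignatureSketch
{
    // Stand-in for torch.autograd.backward's parameter order:
    // (tensors, grad_tensors, retain_graph, create_graph, inputs).
    static void AutogradBackward(
        string tensors, string grad_tensors = null,
        bool retain_graph = false, bool create_graph = false, string inputs = null)
        => Console.WriteLine($"retain_graph={retain_graph}, create_graph={create_graph}");

    // Stand-in for the current (pre-fix) Tensor.backward wrapper: its parameter
    // order already differs from autograd.backward's, and the values are
    // forwarded positionally, so the value passed as retain_graph lands in
    // autograd.backward's create_graph slot and vice versa.
    public static void BackwardCurrent(
        string self, string grad_tensors = null,
        bool create_graph = false, bool retain_graph = false, string inputs = null)
        => AutogradBackward(self, grad_tensors, create_graph, retain_graph, inputs);

    static void Main()
    {
        // Prints "retain_graph=False, create_graph=True": the caller asked to
        // retain the graph but effectively requested create_graph instead.
        BackwardCurrent("loss", retain_graph: true);
    }
}
```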
This means that calling `Tensor.backward(retain_graph: true)` in TorchSharp is actually `Tensor.backward(create_graph: true)` in PyTorch, and likewise `Tensor.backward(create_graph: true)` in TorchSharp is actually `Tensor.backward(retain_graph: true)` in PyTorch.
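
As a usage-level illustration, the standard `retain_graph` pattern is to run `backward` twice over the same graph. This is a hedged sketch assuming the usual TorchSharp factory and tensor APIs (`torch.randn` with a `requires_grad` flag, the `*` operator, `sum()`); exact overload shapes may differ.

```csharp
// Hypothetical usage sketch; the factory overload used here is an assumption.
using TorchSharp;
using static TorchSharp.torch;

var x = randn(new long[] { 3 }, requires_grad: true);
var loss = (x * x).sum();

// Intent (PyTorch semantics): retain the graph so backward can run again.
// Under the pre-fix mapping this flag is forwarded as create_graph instead,
// so the graph is not necessarily retained.
loss.backward(retain_graph: true);

// With PyTorch-compatible semantics this second call is fine; under the
// pre-fix mapping it can fail because the graph was not actually retained.
loss.backward();
```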
The proposed fix is breaking and would change the `Tensor.backward` function signature to match PyTorch's. However, nobody seems to have noticed for about two years, and in my opinion `retain_graph` should actually mean `retain_graph` (and the same for `create_graph`) 🙂.
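
To spell out why the change is breaking, here is a hedged before/after note, reusing `loss` from the sketch above and assuming the fixed signature simply mirrors the PyTorch parameter order quoted earlier:

```csharp
// Parameter order, per the signatures quoted above:
//   pre-fix:  backward(grad_tensors, create_graph, retain_graph, inputs)
//   post-fix: backward(grad_tensors, retain_graph, create_graph, inputs)

// A named call like this keeps compiling, but its runtime meaning changes:
// before the fix the value was forwarded as create_graph; after the fix it
// really retains the graph, matching PyTorch.
loss.backward(retain_graph: true);

// Call sites that passed these two flags positionally are also affected,
// since the second and third parameters trade places.
```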