Add swap_tensors path to nn.Module._apply #117167
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/117167
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit b16a4af with merge base d444a3b. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Towards fixing #115792 [ghstack-poisoned]
ghstack-source-id: 7c5bfa9cc07c847e00765ac843347f4ba392e7fa Pull Request resolved: #117167
I don't think that falling back to .data= is really an option here. They are doing semantically different things and so we should only switch between the two when the user explicitly asks for it.
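For context, here is a minimal standalone sketch of the two operations being contrasted in this thread. It only uses `torch.utils.swap_tensors` and plain `.data` assignment; the comments paraphrase the semantics discussed here rather than a full specification.

```python
import torch
import torch.utils

# `.data` assignment: mutates only the left-hand tensor in place. `a` keeps its
# Python identity (and, for Parameters, its requires_grad flag) but now points
# at b's storage; `b` itself is untouched.
a = torch.zeros(3)
b = torch.ones(3, dtype=torch.float64)
a.data = b
print(a.dtype)            # torch.float64

# `torch.utils.swap_tensors`: exchanges the full state of the two Python
# objects (TensorImpl included), so the change is symmetric. As discussed
# above, it requires the use_count of both TensorImpls to be 1.
c = torch.zeros(3)
d = torch.ones(3, dtype=torch.float64)
torch.utils.swap_tensors(c, d)
print(c.dtype, d.dtype)   # torch.float64 torch.float32
```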
Towards fixing #115792

Added `torch.__future__.{get/set}_swap_module_params_on_conversion`, which defaults to `False` for now; we probably want to modify `nn.Module._apply.compute_should_use_swap_tensors` to override this if the input is on XLA or is a tensor subclass.

From offline discussion, for now we are **not** allowing `swap_tensors` after the first module forward has been run***. The reason is that `torch.utils.swap_tensors(t1, t2)` requires the `use_count` of both `TensorImpl`s associated with `t1` and `t2` to be 1. The first forward pass will install `AccumulateGrad` nodes on each param, which [bump the refcount of the associated TensorImpl](https://github.com/pytorch/pytorch/blob/6cf1fc66e340132d7e2ed9d42efea42fa7ea0183/torch/csrc/autograd/variable.cpp?fbclid=IwAR2dWDVPoXfWF0QDXhhwJ3U7CIAUcNBCAxptlTX9yDI-0pi_h0FBNsw0ig0#L307). Future work might be to swap `AccumulateGrad` nodes if it is necessary.

***From this, it might seem like we don't need to handle gradients. However, I still handle the grads for the edge case that the grads are set via `p.grad = grad`.

### Question

If the future is set to allow swapping, there are still cases where swapping might not occur (`use_count > 1`, `weak_use_count > 1`, where `weak_use_count` is `(use_count >= 1) + num_weak_refs`). I am wondering what we should do in such cases:

1) Error loudly:
   - Pro: for use cases where `swap_tensors` is necessary for correctness (`XLATensor`, `DTensor`), this will make it very apparent that things have gone wrong.
   - Con: for other use cases where `.data` setting is not semantically correct, this might not preserve BC, especially if we flip the default (before, weakrefs to parameters were ok; now they would not be).
2) Warn and fall back to `.data` setting. I was thinking to warn with a list of param names that were not swapped (since only some might have weakrefs, `use_count > 1`, etc.):
   - Pro: doesn't break BC for the common case, and provides a signal for how to fix the case where `._apply` is currently broken.
   - Con: the warning might be spammy, especially if there are a lot of parameters.
3) Silently fall back to `.data` setting:
   - Pro: not spammy.
   - Con: very hard to debug correctness.

Right now I haven't chosen which to implement yet, so the fallback is just silent. I am wondering which of (1) or (2) (or something else) might be the better solution?

[ghstack-poisoned]
Towards fixing #115792

Added `torch.nn.utils.{get/set}_swap_module_params_on_conversion`, which defaults to `False` for now; we probably want to override this and default to `True` in `nn.Module._apply` if the input is on XLA or is a tensor subclass.

From offline discussion, for now we are **not** allowing `swap_tensors` after the first module forward has been run***. The reason is that `torch.utils.swap_tensors(t1, t2)` requires the `use_count` of both `TensorImpl`s associated with `t1` and `t2` to be 1. The first forward pass will install `AccumulateGrad` nodes on each param, which [bump the refcount of the associated TensorImpl](https://github.com/pytorch/pytorch/blob/6cf1fc66e340132d7e2ed9d42efea42fa7ea0183/torch/csrc/autograd/variable.cpp?fbclid=IwAR2dWDVPoXfWF0QDXhhwJ3U7CIAUcNBCAxptlTX9yDI-0pi_h0FBNsw0ig0#L307). **Future work might be to swap `AccumulateGrad` nodes if it is necessary.**

***From this, it might seem like we don't need to handle gradients. However, I still handle the grads for the edge case that the grads are set via `p.grad = grad`.

If any `swap_tensors` call fails on any of the parameters in the `nn.Module`, we raise an error.

**`RNNBase` overrides `nn.Module._apply()` and installs weakrefs on some parameters. As a result, all modules that inherit from `RNNBase` (`RNN`, `GRU` and `LSTM`) cannot use the `swap_tensors` path as of now.**

[ghstack-poisoned]
ghstack-source-id: b6f4ea5b09cd5133afa20d1ad2840fbb6658c83c Pull Request resolved: #117167
@pytorchbot rebase -s
@mikaylagawarecki Hi there. Just wanted to give a heads-up that I opened #118783, a complementary PR which also fixes #115792. While it fixes the same issue, it does so because of another latent problem with …
else:
    assert param.grad.is_leaf
    out_param.grad = grad_applied.requires_grad_(param.grad.requires_grad)
assert param_grad.is_leaf
That is weird?
This is copied from existing code, I wasn't sure on the rationale for this either 🤔
Towards fixing #115792

Added `torch.nn.utils.{get/set}_swap_module_params_on_conversion`, which defaults to `False` for now; we probably want to override this and default to `True` in `nn.Module._apply` if the input is on XLA or is a tensor subclass.

From offline discussion, for now we are **not** allowing `swap_tensors` after the first module forward has been run***. The reason is that `torch.utils.swap_tensors(t1, t2)` requires the `use_count` of both `TensorImpl`s associated with `t1` and `t2` to be 1. The first forward pass will install `AccumulateGrad` nodes on each param, which [bump the refcount of the associated TensorImpl](https://github.com/pytorch/pytorch/blob/6cf1fc66e340132d7e2ed9d42efea42fa7ea0183/torch/csrc/autograd/variable.cpp?fbclid=IwAR2dWDVPoXfWF0QDXhhwJ3U7CIAUcNBCAxptlTX9yDI-0pi_h0FBNsw0ig0#L307). **Future work might be to swap the refs that the `AccumulateGrad` nodes hold if it is necessary.**

***From this, it might seem like we don't need to handle gradients. However, I still handle the grads for the edge case that the grads are set via `p.grad = grad` OR the autograd graph is no longer alive because the output has been garbage collected.

If any `swap_tensors` call fails on any of the parameters in the `nn.Module`, we raise an error.

**`RNNBase` overrides `nn.Module._apply()` and installs weakrefs on some parameters. As a result, all modules that inherit from `RNNBase` (`RNN`, `GRU` and `LSTM`) cannot use the `swap_tensors` path as of now.**

[ghstack-poisoned]
Towards fixing #115792

Added `torch.__future__.{get/set}_swap_module_params_on_conversion`, which defaults to `False` for now; we probably want to override this and default to `True` in `nn.Module._apply` if the input is a tensor subclass.

From offline discussion, for now we are **not** allowing `swap_tensors` after the first module forward has been run*** if the autograd graph is still alive. The reason is that `torch.utils.swap_tensors(t1, t2)` requires the `use_count` of both `TensorImpl`s associated with `t1` and `t2` to be 1. The first forward pass will install `AccumulateGrad` nodes on each param, which [bump the refcount of the associated TensorImpl](https://github.com/pytorch/pytorch/blob/6cf1fc66e340132d7e2ed9d42efea42fa7ea0183/torch/csrc/autograd/variable.cpp?fbclid=IwAR2dWDVPoXfWF0QDXhhwJ3U7CIAUcNBCAxptlTX9yDI-0pi_h0FBNsw0ig0#L307). **Future work might be to swap the refs that the `AccumulateGrad` nodes hold if it is necessary.**

***From this, it might seem like we don't need to handle gradients. However, I still handle the grads for the edge case that the grads are set via `p.grad = grad` OR the autograd graph is no longer alive because the output has been garbage collected.

If any `swap_tensors` call fails on any of the parameters in the `nn.Module`, we raise an error.

**`RNNBase` overrides `nn.Module._apply()` and installs weakrefs on some parameters. As a result, all modules that inherit from `RNNBase` (`RNN`, `GRU` and `LSTM`) cannot use the `swap_tensors` path as of now.**

[ghstack-poisoned]
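To make the "autograd graph is still alive" restriction concrete, here is a small sketch of the intended behavior. The exact exception type and message are assumptions; the behavior itself is taken from the description above as of this PR and may be relaxed in later releases.

```python
import torch
import torch.nn as nn

torch.__future__.set_swap_module_params_on_conversion(True)

m = nn.Linear(2, 2)
m.half()                               # no forward has run yet: swap path is usable

x = torch.randn(1, 2, dtype=torch.half)
out = m(x)                             # forward installs AccumulateGrad nodes, bumping
                                       # the use_count of each parameter's TensorImpl

try:
    m.float()                          # graph behind `out` is still alive, so the
                                       # swap is expected to fail with an error
except RuntimeError as e:              # assumed error type
    print("swap path unavailable:", e)

del out                                # once the graph is garbage collected, the
                                       # extra references go away
m.float()                              # conversion via the swap path works again
```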
Thanks!
…_on_conversion (#118023)

For above PR to parametrize existing `load_state_dict` tests

Pull Request resolved: #118023
Approved by: https://github.com/albanD
ghstack dependencies: #118028, #117167
Added `torch.__future__.{get/set}_swap_module_params_on_conversion`, which defaults to `False` for now; we probably want to override this and default to `True` in `nn.Module._apply` if the input is a tensor subclass.

From offline discussion, for now we are **not** allowing `swap_tensors` after the first module forward has been run*** if the autograd graph is still alive. The reason is that `torch.utils.swap_tensors(t1, t2)` requires the `use_count` of both `TensorImpl`s associated with `t1` and `t2` to be 1. The first forward pass will install `AccumulateGrad` nodes on each param, which [bump the refcount of the associated TensorImpl](https://github.com/pytorch/pytorch/blob/6cf1fc66e340132d7e2ed9d42efea42fa7ea0183/torch/csrc/autograd/variable.cpp?fbclid=IwAR2dWDVPoXfWF0QDXhhwJ3U7CIAUcNBCAxptlTX9yDI-0pi_h0FBNsw0ig0#L307). **Future work might be to swap the refs that the `AccumulateGrad` nodes hold if it is necessary.**

***From this, it might seem like we don't need to handle gradients. However, I still handle the grads for the edge case that the grads are set via `p.grad = grad` OR the autograd graph is no longer alive because the output has been garbage collected.

If any `swap_tensors` call fails on any of the parameters in the `nn.Module`, we raise an error.

**`RNNBase` overrides `nn.Module._apply()` and installs weakrefs on some parameters. As a result, all modules that inherit from `RNNBase` (`RNN`, `GRU` and `LSTM`) cannot use the `swap_tensors` path as of now.**

Pull Request resolved: #117167
Approved by: https://github.com/albanD
ghstack dependencies: #118028
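For reference, a minimal usage sketch of the flag landed here, assuming a PyTorch build that includes this stack:

```python
import torch
import torch.nn as nn

# Opt in to the swap_tensors path for module conversions (default is False).
torch.__future__.set_swap_module_params_on_conversion(True)
assert torch.__future__.get_swap_module_params_on_conversion()

linear = nn.Linear(2, 3)
w = linear.weight                  # hold a reference to the Parameter object

# With the flag enabled, _apply() converts each parameter via
# torch.utils.swap_tensors rather than `param.data = ...`.
linear.to(torch.float64)

print(w is linear.weight)          # True: the Parameter object is swapped in place
print(w.dtype)                     # torch.float64
```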
Added `torch.__future__.{get/set}_swap_module_params_on_conversion`, which defaults to `False` for now; we probably want to override this and default to `True` in `nn.Module._apply` if the input is a tensor subclass.

From offline discussion, for now we are **not** allowing `swap_tensors` after the first module forward has been run*** if the autograd graph is still alive. The reason is that `torch.utils.swap_tensors(t1, t2)` requires the `use_count` of both `TensorImpl`s associated with `t1` and `t2` to be 1. The first forward pass will install `AccumulateGrad` nodes on each param, which bump the refcount of the associated TensorImpl. **Future work might be to swap the refs that the `AccumulateGrad` nodes hold if it is necessary.**

***From this, it might seem like we don't need to handle gradients. However, I still handle the grads for the edge case that the grads are set via `p.grad = grad` OR the autograd graph is no longer alive because the output has been garbage collected.

If any `swap_tensors` call fails on any of the parameters in the `nn.Module`, we raise an error.

**`RNNBase` overrides `nn.Module._apply()` and installs weakrefs on some parameters. As a result, all modules that inherit from `RNNBase` (`RNN`, `GRU` and `LSTM`) cannot use the `swap_tensors` path as of now.**

Stack from ghstack (oldest at bottom):