
Add swap_tensors path to nn.Module._apply #117167

Conversation

@mikaylagawarecki (Contributor) commented on Jan 10, 2024

Added `torch.__future__.{get/set}_swap_module_params_on_conversion`, which defaults to `False` for now; we probably want to override this and default to `True` in `nn.Module._apply` if the input is a tensor subclass.
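To illustrate the intended usage (a minimal sketch; the `nn.Linear` module and the `.half()` conversion are just examples):

```python
import torch
import torch.nn as nn

# Opt in to the swap_tensors path for module conversions (defaults to False).
torch.__future__.set_swap_module_params_on_conversion(True)

m = nn.Linear(4, 4)
weight_before = m.weight  # hold a reference to the Parameter object

# Conversions that go through nn.Module._apply (e.g. .to(), .half()) now use
# torch.utils.swap_tensors instead of setting param.data.
m.half()

print(weight_before is m.weight)  # True: the Parameter object itself is preserved
print(m.weight.dtype)             # torch.float16
```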

From offline discussion, for now we are **not** allowing `swap_tensors` after the first module forward has been run*** if the autograd graph is still alive. The reason is that `torch.utils.swap_tensors(t1, t2)` requires the `use_count` of both `TensorImpl`s associated with `t1` and `t2` to be 1. The first forward pass installs an `AccumulateGrad` node on each param, which bumps the refcount of the associated `TensorImpl`. Future work might be to swap the refs that the `AccumulateGrad` nodes hold, if necessary.
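A rough sketch of what this restriction looks like in practice (the error type raised by `_apply` is assumed to be `RuntimeError` here):

```python
import torch
import torch.nn as nn

torch.__future__.set_swap_module_params_on_conversion(True)

m = nn.Linear(2, 2)
m.to(torch.float64)  # fine: nothing else references the parameters' TensorImpls yet

out = m(torch.randn(1, 2, dtype=torch.float64))
# The forward pass installed an AccumulateGrad node per parameter; those nodes
# hold an extra reference to each TensorImpl while out's graph is alive.

try:
    m.half()  # swap_tensors needs use_count == 1, so the conversion is refused
except RuntimeError as e:
    print("could not swap:", e)

del out    # drop the autograd graph (and with it the AccumulateGrad references)...
m.half()   # ...and the conversion succeeds again
```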

***From this, it might seem like we don't need to handle gradients. However, I still handle the grads for the edge cases where the grads are set via `p.grad = grad`, or the autograd graph is no longer alive because the output has been garbage collected.

If `swap_tensors` fails for any of the parameters in the `nn.Module`, we raise an error.

`RNNBase` overrides `nn.Module._apply()` and installs weakrefs on some parameters. As a result, modules that inherit from `RNNBase` (`RNN`, `GRU`, and `LSTM`) cannot use the `swap_tensors` path for now.
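For reference, the weakref limitation can be reproduced directly with `torch.utils.swap_tensors` (a sketch; `ref` below just mimics the kind of weakref that `RNNBase._apply` keeps to its flat weights, and the exact error type is an assumption):

```python
import weakref
import torch

t1, t2 = torch.randn(3), torch.randn(3)
torch.utils.swap_tensors(t1, t2)  # fine: no outstanding weak references

t3 = torch.randn(3)
ref = weakref.ref(t3)             # mimics RNNBase keeping weakrefs to its flat weights
try:
    torch.utils.swap_tensors(t3, torch.randn(3))
except RuntimeError as e:         # swapping is refused while the weakref is alive
    print("could not swap:", e)
```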

Stack from ghstack (oldest at bottom):


pytorch-bot bot commented Jan 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/117167

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b16a4af with merge base d444a3b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@mikaylagawarecki added the `release notes: nn` (release notes category) and `topic: new features` (topic category) labels on Jan 10, 2024
@mikaylagawarecki changed the title from "Add swap_tensors path to nn.Module._apply" to "[WIP] Add swap_tensors path to nn.Module._apply" on Jan 10, 2024
mikaylagawarecki added a commit that referenced this pull request Jan 11, 2024
ghstack-source-id: 7c5bfa9cc07c847e00765ac843347f4ba392e7fa
Pull Request resolved: #117167
@albanD (Collaborator) left a comment

I don't think that falling back to .data= is really an option here. They are doing semantically different things and so we should only switch between the two when the user explicitly asks for it.
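For concreteness, a rough standalone sketch of that semantic difference (expected behavior only; the printed values are indicative):

```python
import torch
import torch.nn as nn

# `.data =` keeps the Python object (its class and autograd flags) and only
# re-points the underlying data:
q = nn.Parameter(torch.ones(2))
q.data = torch.zeros(2, dtype=torch.float64)
print(type(q).__name__, q.dtype, q.requires_grad)  # Parameter torch.float64 True

# `torch.utils.swap_tensors` exchanges everything between the two objects,
# including their classes and autograd metadata:
p = nn.Parameter(torch.ones(2))
t = torch.zeros(2, dtype=torch.float64)
torch.utils.swap_tensors(p, t)
print(type(p).__name__, p.dtype, p.requires_grad)  # Tensor torch.float64 False
print(type(t).__name__, t.dtype, t.requires_grad)  # Parameter torch.float32 True
```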

Towards fixing #115792

Added `torch.__future__.{get/set}_swap_module_params_on_conversion` that defaults to `False` for now, but we probably want to modify `nn.Module._apply.compute_should_use_swap_tensors` to override this if the input is on XLA or is a tensor subclass.

From offline discussion, for now we are **not** allowing `swap_tensor` after the first module forward has been run***. The reason being that `torch.utils.swap_tensors(t1, t2)` requires the `use_count` of both `TensorImpl`s associated with `t1` and `t2` to be 1.  The first forward pass will install `AccumulateGrad` nodes on each param, which [bump the refcount of the associated TensorImpl](https://github.com/pytorch/pytorch/blob/6cf1fc66e340132d7e2ed9d42efea42fa7ea0183/torch/csrc/autograd/variable.cpp?fbclid=IwAR2dWDVPoXfWF0QDXhhwJ3U7CIAUcNBCAxptlTX9yDI-0pi_h0FBNsw0ig0#L307). Future work might be to swap `AccumulateGrad` nodes if it is necessary.

***From this, it might seem like we don't need to handle gradients. However, I still handle the grads for the edge case that the grads are set via `p.grad = grad`.

### Question
If the future is set to allow swapping, there are still cases where swapping might not occur: `use_count > 1`, or `weak_use_count > 1` (where `weak_use_count` is `(use_count >= 1) + num_weak_refs`).

I am wondering what we should do in such cases:
1) Error loudly:
    - Pro: for use cases where `swap_tensors` is necessary for correctness (`XLATensor`, `DTensor`), this makes it very apparent that things have gone wrong
    - Con: for other use cases where `.data` setting is not semantically correct, this might not preserve BC, especially if we flip the default (previously weakrefs to parameters were OK; now they are not)
2) Warn and fall back to `.data` setting: I was thinking of warning with a list of the `param_names` that were not swapped (since only some might have weakrefs, `use_count > 1`, etc.); a rough sketch is shown after this list
    - Pro: doesn't break BC for the common case, and provides a signal for how to fix the cases where `._apply` is currently broken
    - Con: the warning might be spammy, especially if there are a lot of parameters
3) Silently fall back to `.data` setting
    - Pro: not spammy
    - Con: very hard to debug correctness issues

Right now I haven't chosen which to implement, so the fallback is just silent. Which of (1) or (2) (or something else) would be the better solution?
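A purely hypothetical sketch of what option (2) could look like (`swap_params_or_fallback` is a made-up helper for illustration, not the actual `_apply` code):

```python
import warnings
import torch
import torch.nn as nn

def swap_params_or_fallback(module: nn.Module, transform) -> None:
    """Try to swap every parameter; fall back to .data for those that cannot be
    swapped and emit a single summary warning listing their names."""
    not_swapped = []
    for name, param in module.named_parameters(recurse=False):
        with torch.no_grad():
            new_value = transform(param)
        new_param = nn.Parameter(new_value, requires_grad=param.requires_grad)
        try:
            torch.utils.swap_tensors(param, new_param)
        except RuntimeError:
            param.data = new_value  # option (3) would do this silently
            not_swapped.append(name)
    if not_swapped:
        warnings.warn(
            f"swap_tensors failed for parameters {not_swapped}; fell back to .data setting."
        )

# e.g. swap_params_or_fallback(nn.Linear(2, 2), lambda p: p.half())
```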





[ghstack-poisoned]
@mikaylagawarecki added the `ciflow/trunk` (trigger trunk jobs on your pull request) and `keep-going` (don't stop on first failure, keep running tests until the end) labels on Jan 19, 2024
mikaylagawarecki added a commit that referenced this pull request Jan 19, 2024
ghstack-source-id: b6f4ea5b09cd5133afa20d1ad2840fbb6658c83c
Pull Request resolved: #117167
@mikaylagawarecki (Author) commented:
@pytorchbot rebase -s

@mikaylagawarecki marked this pull request as ready for review on January 23, 2024, 22:23
@mikaylagawarecki changed the title from "[WIP] Add swap_tensors path to nn.Module._apply" to "Add swap_tensors path to nn.Module._apply" on Jan 24, 2024
@ysiraichi (Collaborator) commented:
@mikaylagawarecki Hi there, just wanted to give a heads-up that I opened #118783, a complementary PR which also fixes #115792. While it fixes the same issue, it does so because of another latent problem with FunctionalTensorWrapper. Check this comment to read more about it. That said, I believe this PR is still relevant, since it also addresses other issues.

Resolved review threads:
- .lintrunner.toml (outdated)
- docs/source/conf.py (outdated)
- docs/source/nn.rst (outdated)
- test/test_modules.py
- test/test_modules.py (outdated)
- torch/nn/modules/module.py (outdated)
- torch/nn/modules/module.py (outdated)
- torch/nn/modules/module.py
else:
assert param.grad.is_leaf
out_param.grad = grad_applied.requires_grad_(param.grad.requires_grad)
assert param_grad.is_leaf
Collaborator:
That is weird?

mikaylagawarecki (Author):
This is copied from existing code, I wasn't sure on the rationale for this either 🤔

Resolved review thread: torch/nn/utils/__init__.py (outdated)
@albanD (Collaborator) left a comment

Thanks!

pytorchmergebot pushed a commit that referenced this pull request Feb 7, 2024
…_on_conversion (#118023)

For above PR to parametrize existing `load_state_dict` tests

Pull Request resolved: #118023
Approved by: https://github.com/albanD
ghstack dependencies: #118028, #117167
pytorch-bot bot pushed a commit that referenced this pull request Feb 8, 2024

Pull Request resolved: #117167
Approved by: https://github.com/albanD
ghstack dependencies: #118028
pytorch-bot bot pushed a commit that referenced this pull request Feb 8, 2024
…_on_conversion (#118023)

For above PR to parametrize existing `load_state_dict` tests

Pull Request resolved: #118023
Approved by: https://github.com/albanD
ghstack dependencies: #118028, #117167
clee2000 pushed a commit that referenced this pull request Feb 14, 2024

Pull Request resolved: #117167
Approved by: https://github.com/albanD
ghstack dependencies: #118028
clee2000 pushed a commit that referenced this pull request Feb 14, 2024
…_on_conversion (#118023)

For above PR to parametrize existing `load_state_dict` tests

Pull Request resolved: #118023
Approved by: https://github.com/albanD
ghstack dependencies: #118028, #117167
Labels
- ciflow/trunk (trigger trunk jobs on your pull request)
- keep-going (don't stop on first failure, keep running tests until the end)
- Merged
- release notes: nn (release notes category)
- topic: new features (topic category)