[Cherry-picked 0.9] Add dtype argument for kernel caching precision #1556
Conversation
`T.Resample` precomputes and caches the resampling kernel for a performance improvement. Currently, the kernel is computed at construction time and is cached as float32. This causes degradation if one wants to perform resampling on float64. In 0.9.0, we decided to use float32 for the initial caching for the sake of minimum disruption to the user experience. (The original implementation computed the kernel on-the-fly on the same device/dtype as the input Tensor, but most of the use cases we are aware of are CPU/float32.) We are now asking users to use `.to` to move the Module to the appropriate device/dtype. But if users are using float64, this is numerically BC-breaking. For DL applications this might be fine, but for sound engineering it might not be. This PR adopts `LazyModuleMixin` and `UninitializedBuffer` so that the resampling kernel is computed and cached at the first forward call, assuming that the `T.Resample` object has been moved to the right `device`/`dtype`. This resolves the numerical BC-break for the float64 case. The downsides of this adoption are: 1. To TorchScript the resulting model, one has to do a dry-run so that the kernel buffer is materialized. This is another BC-break. 2. It issues `UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.`
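For illustration, here is a minimal sketch of the lazy-caching pattern described above, using `LazyModuleMixin` and `UninitializedBuffer`. This is not torchaudio's actual implementation: the module name, kernel formula, and shapes are hypothetical, and only the mechanism (a buffer materialized at the first forward call, picking up the `device`/`dtype` the module has been moved to) mirrors the description.

```python
import torch
from torch import nn
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedBuffer


class LazyKernelFilter(LazyModuleMixin, nn.Module):
    """Hypothetical filter whose kernel is computed and cached lazily at the first forward call."""

    def __init__(self, width: int = 31):
        super().__init__()
        self.width = width
        # Placeholder that has no shape/dtype/device until materialized.
        self.register_buffer("kernel", UninitializedBuffer())

    def initialize_parameters(self, waveform):
        # Called once by LazyModuleMixin's forward pre-hook, before the first forward.
        if self.has_uninitialized_params():
            with torch.no_grad():
                t = torch.arange(
                    -(self.width // 2), self.width // 2 + 1,
                    device=waveform.device, dtype=waveform.dtype,
                )
                # Stand-in for the real windowed-sinc resampling kernel.
                k = torch.sinc(t / 2.0)
                k = (k / k.sum()).view(1, 1, -1)
                self.kernel.materialize(k.shape, device=k.device, dtype=k.dtype)
                self.kernel.copy_(k)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # waveform: (batch, channel=1, time)
        return torch.nn.functional.conv1d(waveform, self.kernel)


module = LazyKernelFilter().to(torch.float64)
# The dry-run materializes the buffer in float64; only after this can the module be TorchScripted.
out = module(torch.randn(2, 1, 1000, dtype=torch.float64))
```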
@carolineechen We can add this
I also want to note that #1514 was BC-breaking for other dtypes as well. Instead of the entire operation being performed with the final dtype, we first compute and cache the kernel in float32. cc #1487
didn't read in depth, but dumb question: can you add a `dtype` argument?
Having read the issue in depth, I still ask why an explicit `dtype` arg doesn't work.
It totally works. I just did not think of it, because I was constrained on the
@carolineechen I updated the PR to add the `dtype` argument.
@@ -18,6 +18,9 @@ def _assert_consistency(self, transform, tensor):
tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype)

# Perform dry-run so that UninitializedBuffer/Parameter are materialized
I don't think this is necessary if taking the `dtype` initialization approach.
Oops, I forgot to remove it.
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
beta: Optional[float] = None,
*,
What is `*` used for here?
It's for keyword-only arguments. For Tensor-related operations, the common pattern in PyTorch is `(..., *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)`. See https://pytorch.org/docs/stable/generated/torch.zeros.html#torch.zeros

Now we can argue whether we should follow the same pattern in transforms or not, but unlike the other parameters, which control the overall nature of how resampling should run, this `dtype` argument, in my eyes, has different properties. And when we think of the possibility of adding more arguments for the resampling algorithm itself, I would like to keep the option of adding those new arguments before the `dtype` argument without breaking BC.
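For reference, a minimal sketch of what the bare `*` does in a signature like the one above; the function name and return value are made up for the example:

```python
from typing import Optional

import torch


def make_resampler(
    orig_freq: int = 16000,
    new_freq: int = 16000,
    lowpass_filter_width: int = 6,
    rolloff: float = 0.99,
    beta: Optional[float] = None,
    *,  # everything after the bare star is keyword-only
    dtype: Optional[torch.dtype] = None,
):
    return (orig_freq, new_freq, lowpass_filter_width, rolloff, beta, dtype)


make_resampler(48000, 16000, dtype=torch.float64)             # OK: dtype passed by name
# make_resampler(48000, 16000, 6, 0.99, None, torch.float64)  # TypeError: dtype cannot be positional
```

Because `dtype` can never be bound positionally, new positional parameters can later be inserted in front of it without breaking existing calls.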
dtype (torch.dtype, optional):
    Determines the precision in which the resampling kernel is pre-computed and cached. If not provided,
    the kernel is computed with ``torch.float64`` then cached as ``torch.float32``.
    If you need higher precision, provide ``torch.float64``, and the pre-computed kernel is computed and
    cached as ``torch.float64``.
A few questions here:

- If the user provides `dtype=torch.float64`, should they still call `.to(torch.float64)` afterwards, or is that unnecessary?
- Should we assume that it is preferable for the user to have the generation in `float64` even if they are using something like `float16`? Using `.to(dtype)` in these cases will improve precision, but the resulting behavior will be different from the functional, which does the entire computation in the final dtype.
> If the user provides `dtype=torch.float64`, should they still call `.to(torch.float64)` afterwards, or is that unnecessary?

It depends. If the transform object the user is manipulating is composed only of `Resample`, then it makes no difference, but if the processing pipeline is composed of multiple transforms (say, resample -> spectrogram -> mel-scale), then the user will have to call `.to` on the entire pipeline of transforms. Result-wise it should be the same.

> Should we assume that it is preferable for the user to have the generation in `float64` even if they are using something like `float16`?

Yes, that's my view. The only reason we are adding `dtype` is that otherwise we cannot provide a means to resample in higher precision. We are not trying to get rid of the need for users to call `.to`. Therefore, in the doc, I only described the case where users absolutely need to use it.

> Using `.to(dtype)` in these cases will improve precision, but the resulting behavior will be different from the functional, which does the entire computation in the final dtype.

I think 1. users can still provide `dtype=float16` if they want the exact same result as `F.resample`, and 2. it's okay to have a small discrepancy here, as it presumably produces a better result (the quality of the resampled audio).
Looks good, thanks!
Since 0.9.0-RC1, `T.Resample` precomputes and caches the resampling kernel for a performance improvement (10x improvement). The implementation from 0.8.0 computed the kernel on-the-fly on the same `device`/`dtype` as the input Tensor, but in the newer version the kernel is precomputed at construction time and is cached as `float32` first. This causes degradation if one wants to perform resampling on `float64`, because `sinc` values computed on `float32` are not good enough for resampling in `float64`. The reason we decided to use `float32` for the initial caching is to keep the UX disruption minimal, and most of the use cases we are aware of are CPU/float32; but there was no way to make it work for `float64`. We are now asking users to use `.to` to move the Module to the appropriate `device`/`dtype`, but if users are using `float64`, this is numerically BC-breaking. For DL applications this might be fine, but for sound engineering it is probably not, and we would like to support audio engineers well, too.

This PR adds a `dtype` argument that can be used to override the cache precision. Providing `torch.float64` is the only use case.
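A minimal usage sketch of the resulting API (the sample rates and tensor shapes are arbitrary):

```python
import torch
import torchaudio.transforms as T

# Default: the kernel is computed internally in float64 but cached as float32.
resample = T.Resample(orig_freq=44100, new_freq=16000)
out = resample(torch.randn(1, 44100))  # float32 waveform

# Override the cache precision: compute and cache the kernel in float64,
# so float64 audio is resampled without a float32 round-trip.
resample_f64 = T.Resample(orig_freq=44100, new_freq=16000, dtype=torch.float64)
out_f64 = resample_f64(torch.randn(1, 44100, dtype=torch.float64))
```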