
[Cherry-picked 0.9] Add dtype argument for kernel caching precision #1556

Merged · 5 commits · Jun 9, 2021

Conversation

mthrok (Collaborator) commented on Jun 6, 2021

This PR adds a `dtype` argument that can be used to override the kernel cache precision. Providing `torch.float64` is the only use case.

Since 0.9.0, `T.Resample` precomputes and caches the resampling kernel for performance (a 10x improvement).

The previous implementation computed the kernel on the fly, on the same device/dtype as the input Tensor, but in the newer version the kernel is precomputed at construction time and cached as float32. This causes degradation if one wants to resample in float64, because sinc values computed in float32 are not accurate enough for resampling in float64.

We decided to use float32 for the initial cache to keep the UX disruption to a minimum; most of the use cases we are aware of are CPU/float32. We now ask users to call `.to` to move the Module to the appropriate device/dtype, but for users working in float64 this is numerically BC-breaking. For DL applications this might be fine, but for sound engineering it probably is not, and we would like to support audio engineers well too.
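A minimal usage sketch of the new argument (a sketch assuming the API as merged; frequencies and shapes are illustrative):

```python
import torch
import torchaudio.transforms as T

# Cache the resampling kernel in float64 so that float64 waveforms are
# resampled at full precision.
resampler = T.Resample(orig_freq=48000, new_freq=16000, dtype=torch.float64)

waveform = torch.randn(1, 48000, dtype=torch.float64)  # 1 second of noise at 48 kHz
resampled = resampler(waveform)                         # float64 in, float64 out
```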


The first iteration of this PR instead adopted `LazyModuleMixin` and `UninitializedBuffer`, so that the resampling kernel is computed and cached at the first forward call, assuming that the `T.Resample` object has already been moved to the right `device`/`dtype`. That also resolves the numerical BC break for the float64 case, but the adoption has two downsides:
1. To TorchScript the resulting model, one has to perform a dry run so that the kernel buffer is materialized. This is another BC break.
2. It issues `UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.`
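For reference, a rough sketch of what the lazy approach looks like, assuming PyTorch 1.9's `LazyModuleMixin` and `UninitializedBuffer`. The class name and the `_compute_kernel`/`_apply_kernel` helpers are hypothetical stand-ins for the actual sinc-kernel computation and convolution, not the real torchaudio code:

```python
import torch
from torch import Tensor, nn
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedBuffer


def _compute_kernel(orig_freq: int, new_freq: int, *, device, dtype) -> Tensor:
    # Stand-in for the real sinc-interpolation kernel computation.
    return torch.ones(3, 1, 16, device=device, dtype=dtype)


def _apply_kernel(waveform: Tensor, kernel: Tensor) -> Tensor:
    # Stand-in for the real convolution-based resampling.
    return waveform


class LazyResampleSketch(LazyModuleMixin, nn.Module):
    """Kernel buffer stays uninitialized until the first forward call."""

    def __init__(self, orig_freq: int = 16000, new_freq: int = 16000):
        super().__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        # No shape/device/dtype yet; they are inferred from the first input.
        self.register_buffer("kernel", UninitializedBuffer())

    def initialize_parameters(self, waveform: Tensor) -> None:
        # Called once by LazyModuleMixin's forward pre-hook.
        if self.has_uninitialized_params():
            kernel = _compute_kernel(
                self.orig_freq, self.new_freq,
                device=waveform.device, dtype=waveform.dtype,
            )
            self.kernel.materialize(kernel.shape, device=waveform.device, dtype=waveform.dtype)
            with torch.no_grad():
                self.kernel.copy_(kernel)

    def forward(self, waveform: Tensor) -> Tensor:
        return _apply_kernel(waveform, self.kernel)


resampler = LazyResampleSketch(48000, 16000)
out = resampler(torch.randn(1, 4800, dtype=torch.float64))  # kernel materialized as float64 here
```

As per downside 1 above, one such dry run would be required before `torch.jit.script` can be applied, so that the buffer is materialized.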
mthrok (Collaborator, Author) commented on Jun 7, 2021

@carolineechen We could add this `LazyResample` alongside `Resample` and direct users who need float64 to it. Later, once `LazyModuleMixin` becomes stable, we can merge the two.

carolineechen (Contributor) commented:

I also want to note that #1514 was BC-breaking for other dtypes as well. Instead of the entire operation being performed in the final dtype, we now first compute the kernel in float64 to maintain maximal precision, then cast it to float32, and cast it once more when the user calls `resampler.to`. With dtypes smaller than float64, the float64 computation produces a higher-precision intermediate and therefore slightly different numerical results after the cast; with float64, the intermediate cast to float32 loses a small amount of precision.

cc #1487
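The float64 case is the easiest to see; a small illustration of the round trip described above (not the actual torchaudio code path):

```python
import torch

t = torch.linspace(-6, 6, steps=1001, dtype=torch.float64)
kernel_f64 = torch.sinc(t)                    # "ideal" kernel values in float64
cached_f32 = kernel_f64.to(torch.float32)     # what the float32 cache holds
restored = cached_f32.to(torch.float64)       # what `.to(torch.float64)` gives back
print((kernel_f64 - restored).abs().max().item())  # small but non-zero (roughly 1e-8)
```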

ezyang (Contributor) commented on Jun 7, 2021

Didn't read in depth, but dumb question: can you add a `dtype=` parameter to the Module in question?

ezyang (Contributor) commented on Jun 8, 2021

Having read the issue in depth, I still ask: why doesn't an explicit `dtype` arg work?

mthrok (Collaborator, Author) commented on Jun 8, 2021

@ezyang

It totally works. I just did not think of it, because I was stuck on the `Module.to(device, dtype)` semantics. I will update the PR.

@mthrok changed the title from "PoC: Use UninitializedBuffer for resampling kernel cache" to "Add dtype argument for kernel caching precision" on Jun 9, 2021
mthrok (Collaborator, Author) commented on Jun 9, 2021

@carolineechen I updated the PR to add `dtype`. Please review.

```diff
@@ -18,6 +18,9 @@ def _assert_consistency(self, transform, tensor):
         tensor = tensor.to(device=self.device, dtype=self.dtype)
         transform = transform.to(device=self.device, dtype=self.dtype)

+        # Perform dry-run so that UninitializedBuffer/Parameter are materialized
```
carolineechen (Contributor) commented:

I don't think this is necessary if taking the `dtype` initialization approach.

mthrok (Collaborator, Author) replied:

Oops, I forgot to remove it.

```diff
         lowpass_filter_width: int = 6,
         rolloff: float = 0.99,
         beta: Optional[float] = None,
+        *,
```
carolineechen (Contributor) commented:

What is `*` used for here?

mthrok (Collaborator, Author) replied:

It's for keyword-only arguments. For Tensor-related operations, the common pattern in PyTorch is `..., *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False`. See https://pytorch.org/docs/stable/generated/torch.zeros.html#torch.zeros

We can argue whether transforms should follow the same pattern, but unlike the other parameters, which control the overall behavior of the resampling, this `dtype` argument has, in my eyes, a different character. If we ever add more arguments for the resampling algorithm itself, I would like to be able to insert them before `dtype` without breaking BC, which keyword-only arguments allow.
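A minimal illustration of the keyword-only pattern (a hypothetical function, not the actual transform signature):

```python
from typing import Optional

import torch


def make_resampler(
    orig_freq: int = 16000,
    new_freq: int = 16000,
    lowpass_filter_width: int = 6,
    rolloff: float = 0.99,
    beta: Optional[float] = None,
    *,                                    # everything after this is keyword-only
    dtype: Optional[torch.dtype] = None,
):
    ...


make_resampler(48000, 16000, dtype=torch.float64)             # OK
# make_resampler(48000, 16000, 6, 0.99, None, torch.float64)  # TypeError: dtype must be passed by keyword
```

This is what allows new positional arguments to be inserted before `dtype` later without breaking existing call sites.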

```
dtype (torch.dtype, optional):
    Determines the precision at which the resampling kernel is pre-computed and cached.
    If not provided, the kernel is computed with ``torch.float64`` then cached as ``torch.float32``.
    If you need higher precision, provide ``torch.float64``, and the pre-computed kernel is
    computed and cached as ``torch.float64``.
```
carolineechen (Contributor) commented:

A few questions here:

  • If the user provides `dtype=torch.float64`, should they still call `.to(torch.float64)` afterwards, or is that unnecessary?
  • Should we assume it is preferable for the user to have the kernel generated in float64 even if they are using something like float16? Using `.to(dtype)` in these cases will improve precision, but the resulting behavior will differ from the functional, which does the entire computation in the final dtype.

mthrok (Collaborator, Author) replied:

  • If the user provides `dtype=torch.float64`, should they still call `.to(torch.float64)` afterwards, or is that unnecessary?

It depends. If the transform the user is manipulating consists only of `Resample`, it makes no difference. But if the processing pipeline is composed of multiple transforms (say, resample -> spectrogram -> mel-scale), then the user will have to call `.to` on the entire pipeline (see the sketch after this reply). Result-wise, it should be the same.

  • Should we assume it is preferable for the user to have the kernel generated in float64 even if they are using something like float16?

Yes, that's my view. The only reason we are adding `dtype` is that otherwise we cannot provide a means to resample at higher precision. We are not trying to get rid of the need for users to call `.to`. Therefore, in the doc, I only describe the case where users absolutely need it.

  • Using `.to(dtype)` in these cases will improve precision, but the resulting behavior will differ from the functional, which does the entire computation in the final dtype.

I think 1) users can still provide `dtype=float16` if they want the exact same result as `F.resample`, and 2) a small discrepancy is okay here, as it presumably produces a better result (higher-quality resampled audio).
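For the pipeline case above, a minimal sketch (assuming torchaudio 0.9 with the `dtype` argument from this PR; the transform parameters are illustrative):

```python
import torch
import torchaudio.transforms as T

# Resample caches its kernel in float64 via the new argument; .to(torch.float64)
# moves the remaining transforms' buffers (STFT window, mel filterbank) to float64.
pipeline = torch.nn.Sequential(
    T.Resample(orig_freq=48000, new_freq=16000, dtype=torch.float64),
    T.Spectrogram(n_fft=400),
    T.MelScale(n_mels=64, sample_rate=16000, n_stft=400 // 2 + 1),
).to(torch.float64)

waveform = torch.randn(1, 48000, dtype=torch.float64)
features = pipeline(waveform)  # float64 throughout
```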

carolineechen (Contributor) left a review:

Looks good, thanks!

@mthrok merged commit aec0e8c into pytorch:master on Jun 9, 2021
@mthrok deleted the resample-lazy branch on June 9, 2021 at 18:25
mthrok added a commit that referenced this pull request on Jun 9, 2021
Since 0.9.0-RC1, `T.Resample` precomputes and caches the resampling kernel for performance (a 10x improvement).

The implementation from 0.8.0 computed the kernel on the fly, on the same `device`/`dtype` as the input Tensor, but in the newer version the kernel is precomputed at construction time and cached as `float32`. This causes degradation if one wants to resample in `float64`, because `sinc` values computed in `float32` are not accurate enough for resampling in `float64`.

We decided to use `float32` for the initial cache to keep the UX disruption to a minimum, and there was no way to make it work for `float64`. This PR adds a `dtype` argument that can be used to override the cache precision.
@mthrok changed the title from "Add dtype argument for kernel caching precision" to "[Cherry-picked 0.9] Add dtype argument for kernel caching precision" on Jun 10, 2021