Skip to content

Commit

Permalink
Add complex dtype support in functional autograd test (#2244)
Browse files Browse the repository at this point in the history
Summary:
In autograd tests, to guarantee the precision, the dtype of Tensors are converted to `torch.float64` if they are real. However, the complex dtype is not considered. This PR adds `self.complex_dtype` support to the inputs.

Pull Request resolved: #2244

Reviewed By: mthrok

Differential Revision: D34272998

Pulled By: nateanl

fbshipit-source-id: e8698a74d7b8d99ee0fcb5f5cb5f2ffc8c80b9b5
  • Loading branch information
nateanl authored and facebook-github-bot committed Feb 16, 2022
1 parent c2decba commit eeba91d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/functional/autograd_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def assert_grad(
inputs_ = []
for i in inputs:
if torch.is_tensor(i):
i = i.to(dtype=self.dtype, device=self.device)
i = i.to(dtype=self.complex_dtype if i.is_complex() else self.dtype, device=self.device)
if enable_all_grad:
i.requires_grad = True
inputs_.append(i)
Expand Down

0 comments on commit eeba91d

Please sign in to comment.