Skip to content

Commit

Permalink
Add complex dtype support in functional autograd test (pytorch#2244)
Browse files Browse the repository at this point in the history
Summary:
In autograd tests, to guarantee the precision, the dtype of Tensors are converted to `torch.float64` if they are real. However, the complex dtype is not considered. This PR adds `self.complex_dtype` support to the inputs.

Pull Request resolved: pytorch#2244

Reviewed By: mthrok

Differential Revision: D34272998

Pulled By: nateanl

fbshipit-source-id: e8698a74d7b8d99ee0fcb5f5cb5f2ffc8c80b9b5
  • Loading branch information
nateanl authored and xiaohui-zhang committed May 4, 2022
1 parent 9903752 commit 9ef7891
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/functional/autograd_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def assert_grad(
inputs_ = []
for i in inputs:
if torch.is_tensor(i):
i = i.to(dtype=self.dtype, device=self.device)
i = i.to(dtype=self.complex_dtype if i.is_complex() else self.dtype, device=self.device)
if enable_all_grad:
i.requires_grad = True
inputs_.append(i)
Expand Down

0 comments on commit 9ef7891

Please sign in to comment.