Skip to content

Commit

Permalink
Fix return dtype in MVDR module (#2376)
Browse files Browse the repository at this point in the history
Summary:
Address #2375
The MVDR module internally casts complex tensors to `torch.complex128` for numerical stability during computation, and is meant to cast the result back to the original dtype before returning. However, the cast back was silently dropped: `Tensor.to()` returns a new tensor rather than modifying the tensor in place, so the line `specgram_enhanced.to(dtype)` discarded its result. It should read `specgram_enhanced = specgram_enhanced.to(dtype)`. Fix it so the output dtype is consistent with the original input dtype.

Pull Request resolved: #2376

Reviewed By: hwangjeff

Differential Revision: D36280851

Pulled By: nateanl

fbshipit-source-id: 553d1b98f899547209a4e3ebc59920c7ef1f3112
  • Loading branch information
nateanl authored and facebook-github-bot committed May 10, 2022
1 parent eab2f39 commit 2f4eb4a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
17 changes: 17 additions & 0 deletions test/torchaudio_unittest/transforms/transforms_test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,20 @@ def test_psd(self, duration, channel, mask, multi_mask):
psd_np = psd_numpy(spectrogram.detach().numpy(), mask, multi_mask)
psd = transform(spectrogram, mask)
self.assertEqual(psd, psd_np, atol=1e-5, rtol=1e-5)

@parameterized.expand(
    [
        param(torch.complex64),
        param(torch.complex128),
    ]
)
def test_mvdr(self, dtype):
    """Verify that the MVDR output dtype matches the input spectrogram dtype."""
    transform = T.MVDR()
    noise = get_whitenoise(sample_rate=8000, duration=0.5, n_channels=3)
    # (channel, freq, time) complex spectrogram, cast to the dtype under test
    specgram = get_spectrogram(noise, n_fft=400).to(dtype)
    freq_time = specgram.shape[-2:]
    mask_speech = torch.rand(freq_time)
    mask_noise = torch.rand(freq_time)
    enhanced = transform(specgram, mask_speech, mask_noise)
    assert enhanced.dtype == dtype
3 changes: 1 addition & 2 deletions torchaudio/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2087,8 +2087,7 @@ def forward(
# unpack batch
specgram_enhanced = specgram_enhanced.reshape(shape[:-3] + shape[-2:])

specgram_enhanced.to(dtype)
return specgram_enhanced
return specgram_enhanced.to(dtype)


class RTFMVDR(torch.nn.Module):
Expand Down

0 comments on commit 2f4eb4a

Please sign in to comment.