diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index b2a1f08615d..98644c92d84 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -654,7 +654,7 @@ def func(tensor): def test_compute_power_spectral_density_matrix_with_mask(self): def func(tensor): - mask = torch.rand(201, 100) + mask = torch.rand(201, 100, device=tensor.device) return F.compute_power_spectral_density_matrix(tensor, mask=mask) tensor = torch.rand(2, 201, 100, dtype=torch.cfloat)