Skip to content

Commit

Permalink
Merge pull request espnet#5523 from Emrys365/espnet2_enh
Browse files Browse the repository at this point in the history
Improve error robustness of unit tests
  • Loading branch information
sw005320 authored Nov 2, 2023
2 parents 1dc9a11 + eb6f93a commit 0440d9b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 4 deletions.
48 changes: 44 additions & 4 deletions test/espnet2/asr/frontend/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,47 @@
import torch

from espnet2.asr.frontend.default import DefaultFrontend
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed

random_speech = torch.tensor(
[
[
[0.026, 0.031],
[0.027, 0.031],
[0.026, 0.030],
[0.024, 0.028],
[0.025, 0.027],
[0.027, 0.026],
[0.028, 0.026],
[0.029, 0.024],
[0.028, 0.024],
[0.029, 0.026],
[0.029, 0.027],
[0.029, 0.031],
[0.030, 0.038],
[0.029, 0.040],
[0.028, 0.040],
[0.028, 0.041],
],
[
[0.015, 0.021],
[0.005, 0.034],
[0.011, 0.029],
[0.031, 0.036],
[0.031, 0.031],
[0.050, 0.038],
[0.025, 0.054],
[0.042, 0.056],
[0.053, 0.048],
[0.054, 0.044],
[0.053, 0.036],
[0.039, 0.025],
[0.053, 0.032],
[0.054, 0.034],
[0.031, 0.025],
[0.008, 0.023],
],
]
)


def test_frontend_repr():
Expand Down Expand Up @@ -39,8 +79,8 @@ def test_frontend_backward_multi_channel(train, use_wpe, use_beamformer):
frontend.train()
else:
frontend.eval()
set_all_random_seed(14)
x = torch.randn(2, 1000, 2, requires_grad=True)
x_lengths = torch.LongTensor([1000, 980])
x = random_speech.repeat(1, 1024, 1)
x.requires_grad = True
x_lengths = torch.LongTensor([1024, 1000])
y, y_lengths = frontend(x, x_lengths)
y.sum().backward()
1 change: 1 addition & 0 deletions test/espnet2/enh/layers/test_complex_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def test_stack(dim):
def test_complex_impl_consistency():
if not is_torch_1_9_plus:
return
torch.random.manual_seed(0)
mat_th = torch.complex(torch.from_numpy(mat_np.real), torch.from_numpy(mat_np.imag))
mat_ct = ComplexTensor(torch.from_numpy(mat_np.real), torch.from_numpy(mat_np.imag))
bs = mat_th.shape[0]
Expand Down

0 comments on commit 0440d9b

Please sign in to comment.