forked from pytorch/audio
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add unit tests for PyTorch Lightning modules of emformer_rnnt recipes (pytorch#2240)
Summary:
- Refactor the current `LibriSpeechRNNTModule`'s unit test.
- Add unit tests for `TEDLIUM3RNNTModule` and `MuSTCRNNTModule`.
- Replace the lambda with `partial` in `TEDLIUM3RNNTModule` to pass the lightning unit test.
Pull Request resolved: pytorch#2240
Reviewed By: mthrok
Differential Revision: D34285195
Pulled By: nateanl
fbshipit-source-id: 4f20749c85ddd25cbb0eafc1733c64212542338f
- Loading branch information
1 parent
6393401
commit 09d7003
Showing
7 changed files
with
248 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
77 changes: 77 additions & 0 deletions
77
test/torchaudio_unittest/example/emformer_rnnt/test_mustc_lightning.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from contextlib import contextmanager | ||
from functools import partial | ||
from unittest.mock import patch | ||
|
||
import torch | ||
from parameterized import parameterized | ||
from torchaudio._internal.module_utils import is_module_available | ||
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule | ||
|
||
from .utils import MockSentencePieceProcessor, MockCustomDataset, MockDataloader | ||
|
||
if is_module_available("pytorch_lightning", "sentencepiece"): | ||
from asr.emformer_rnnt.mustc.lightning import MuSTCRNNTModule | ||
|
||
|
||
class MockMUSTC:
    """Stand-in dataset patched over ``MUSTC`` that serves synthetic samples.

    Constructor arguments are accepted and discarded. Every index yields a
    freshly drawn random tensor of shape ``(1, 32640)`` paired with the
    fixed string ``"sup"``.
    """

    def __init__(self, *args, **kwargs):
        # The mock keeps no state, so there is nothing to initialise.
        return None

    def __len__(self):
        return 10

    def __getitem__(self, n: int):
        # The index is ignored: all samples have the same structure.
        random_tensor = torch.rand(1, 32640)
        text = "sup"
        return random_tensor, text
|
||
|
||
@contextmanager
def get_lightning_module():
    """Yield a ``MuSTCRNNTModule`` constructed while its dependencies are mocked.

    The patches remain active for as long as the caller stays inside the
    ``with`` block, so anything the module creates there (e.g. dataloaders)
    is built from the mocks as well.
    """
    # Patchers resolve their targets only when entered, so they can be
    # prepared up front and entered together, in the order listed below.
    mock_sp_processor = partial(MockSentencePieceProcessor, num_symbols=500)
    sp_patch = patch("sentencepiece.SentencePieceProcessor", new=mock_sp_processor)
    stats_patch = patch("asr.emformer_rnnt.mustc.lightning.GlobalStatsNormalization", new=torch.nn.Identity)
    mustc_patch = patch("asr.emformer_rnnt.mustc.lightning.MUSTC", new=MockMUSTC)
    custom_dataset_patch = patch("asr.emformer_rnnt.mustc.lightning.CustomDataset", new=MockCustomDataset)
    dataloader_patch = patch("torch.utils.data.DataLoader", new=MockDataloader)

    with sp_patch, stats_patch, mustc_patch, custom_dataset_patch, dataloader_patch:
        # The paths are placeholders; the objects that would read them are mocked.
        module_kwargs = {
            "mustc_path": "mustc_path",
            "sp_model_path": "sp_model_path",
            "global_stats_path": "global_stats_path",
        }
        yield MuSTCRNNTModule(**module_kwargs)
|
||
|
||
@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestMuSTCRNNTModule(TorchaudioTestCase):
    """Smoke tests that run ``MuSTCRNNTModule`` steps and forward on mocked data."""

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        # Seed once per class so the random mock inputs start from a fixed state.
        torch.random.manual_seed(31)

    @parameterized.expand(
        [
            ("training_step", "train_dataloader"),
            ("validation_step", "val_dataloader"),
            ("test_step", "test_common_dataloader"),
            ("test_step", "test_he_dataloader"),
        ]
    )
    def test_step(self, step_fname, dataloader_fname):
        # Take the first batch of the named dataloader and pass it to the
        # named step method with batch index 0; the test passes if nothing raises.
        with get_lightning_module() as module:
            make_dataloader = getattr(module, dataloader_fname)
            first_batch = next(iter(make_dataloader()))
            run_step = getattr(module, step_fname)
            run_step(first_batch, 0)

    @parameterized.expand(
        [
            ("val_dataloader",),
        ]
    )
    def test_forward(self, dataloader_fname):
        # Call the module itself on the first batch of the named dataloader.
        with get_lightning_module() as module:
            make_dataloader = getattr(module, dataloader_fname)
            first_batch = next(iter(make_dataloader()))
            module(first_batch)
80 changes: 80 additions & 0 deletions
80
test/torchaudio_unittest/example/emformer_rnnt/test_tedlium3_lightning.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from contextlib import contextmanager | ||
from functools import partial | ||
from unittest.mock import patch | ||
|
||
import torch | ||
from parameterized import parameterized | ||
from torchaudio._internal.module_utils import is_module_available | ||
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule | ||
|
||
from .utils import MockSentencePieceProcessor, MockCustomDataset, MockDataloader | ||
|
||
if is_module_available("pytorch_lightning", "sentencepiece"): | ||
from asr.emformer_rnnt.tedlium3.lightning import TEDLIUM3RNNTModule | ||
|
||
|
||
class MockTEDLIUM:
    """Stand-in dataset patched over ``torchaudio.datasets.TEDLIUM``.

    Constructor arguments are accepted and discarded. Every index yields a
    six-element tuple: a freshly drawn random tensor of shape ``(1, 32640)``
    followed by fixed placeholder values.
    """

    def __init__(self, *args, **kwargs):
        # The mock keeps no state, so there is nothing to initialise.
        return None

    def __len__(self):
        return 10

    def __getitem__(self, n: int):
        # The index is ignored: all samples have the same structure.
        random_tensor = torch.rand(1, 32640)
        sample = (random_tensor, 16000, "sup", 2, 3, 4)
        return sample
|
||
|
||
@contextmanager
def get_lightning_module():
    """Yield a ``TEDLIUM3RNNTModule`` constructed while its dependencies are mocked.

    The patches remain active for as long as the caller stays inside the
    ``with`` block, so anything the module creates there (e.g. dataloaders)
    is built from the mocks as well.
    """
    # Patchers resolve their targets only when entered, so they can be
    # prepared up front and entered together, in the order listed below.
    mock_sp_processor = partial(MockSentencePieceProcessor, num_symbols=500)
    sp_patch = patch("sentencepiece.SentencePieceProcessor", new=mock_sp_processor)
    stats_patch = patch("asr.emformer_rnnt.tedlium3.lightning.GlobalStatsNormalization", new=torch.nn.Identity)
    tedlium_patch = patch("torchaudio.datasets.TEDLIUM", new=MockTEDLIUM)
    custom_dataset_patch = patch("asr.emformer_rnnt.tedlium3.lightning.CustomDataset", new=MockCustomDataset)
    dataloader_patch = patch("torch.utils.data.DataLoader", new=MockDataloader)

    with sp_patch, stats_patch, tedlium_patch, custom_dataset_patch, dataloader_patch:
        # The paths are placeholders; the objects that would read them are mocked.
        module_kwargs = {
            "tedlium_path": "tedlium_path",
            "sp_model_path": "sp_model_path",
            "global_stats_path": "global_stats_path",
        }
        yield TEDLIUM3RNNTModule(**module_kwargs)
|
||
|
||
@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestTEDLIUM3RNNTModule(TorchaudioTestCase):
    """Smoke tests that run ``TEDLIUM3RNNTModule`` steps and forward on mocked data."""

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        # Seed once per class so the random mock inputs start from a fixed state.
        torch.random.manual_seed(31)

    @parameterized.expand(
        [
            ("training_step", "train_dataloader"),
            ("validation_step", "val_dataloader"),
            ("test_step", "test_dataloader"),
        ]
    )
    def test_step(self, step_fname, dataloader_fname):
        # Take the first batch of the named dataloader and pass it to the
        # named step method with batch index 0; the test passes if nothing raises.
        with get_lightning_module() as module:
            make_dataloader = getattr(module, dataloader_fname)
            first_batch = next(iter(make_dataloader()))
            run_step = getattr(module, step_fname)
            run_step(first_batch, 0)

    @parameterized.expand(
        [
            ("val_dataloader",),
        ]
    )
    def test_forward(self, dataloader_fname):
        # Call the module itself on the first batch of the named dataloader.
        with get_lightning_module() as module:
            make_dataloader = getattr(module, dataloader_fname)
            first_batch = next(iter(make_dataloader()))
            module(first_batch)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
class MockSentencePieceProcessor:
    """Minimal substitute for a SentencePiece processor that returns canned values.

    Only ``num_symbols`` is retained; any further constructor arguments are
    accepted and ignored.
    """

    def __init__(self, num_symbols, *args, **kwargs):
        self.num_symbols = num_symbols

    def get_piece_size(self):
        """Return the ``num_symbols`` value supplied at construction."""
        size = self.num_symbols
        return size

    def encode(self, input):
        """Return the same fixed list of ids regardless of ``input``."""
        fixed_ids = [1, 5, 2]
        return fixed_ids

    def decode(self, input):
        """Return the same fixed string regardless of ``input``."""
        fixed_text = "hey"
        return fixed_text

    # The three special-id accessors below each return a constant.

    def unk_id(self):
        return 0

    def eos_id(self):
        return 1

    def pad_id(self):
        return 2
|
||
|
||
class MockCustomDataset:
    """Wrapper that exposes each item of ``base_dataset`` as a one-element list.

    Any constructor arguments after ``base_dataset`` are accepted and ignored.
    """

    def __init__(self, base_dataset, *args, **kwargs):
        self.base_dataset = base_dataset

    def __len__(self):
        wrapped = self.base_dataset
        return len(wrapped)

    def __getitem__(self, n: int):
        # Each lookup is wrapped in a list containing just that one sample.
        item = self.base_dataset[n]
        return [item]
|
||
|
||
class MockDataloader:
    """Lightweight replacement for a data loader that collates one sample at a time.

    Args:
        base_dataset: Iterable of samples to serve.
        batch_size: When equal to 1, each sample is wrapped in a
            single-element list before being collated; otherwise the sample
            is handed to ``collate_fn`` unchanged.
        collate_fn: Callable applied to every (possibly wrapped) sample.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self, base_dataset, batch_size, collate_fn, *args, **kwargs):
        self.base_dataset = base_dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn

    def __iter__(self):
        from itertools import islice

        # Datasets that only define ``__getitem__``/``__len__`` and never raise
        # ``IndexError`` would be iterated forever through the legacy sequence
        # protocol, so stop after ``len(base_dataset)`` samples when a length
        # is available. Iterables without a length are consumed until exhausted.
        try:
            limit = len(self.base_dataset)
        except TypeError:
            limit = None

        for sample in islice(iter(self.base_dataset), limit):
            if self.batch_size == 1:
                sample = [sample]
            yield self.collate_fn(sample)

    def __len__(self):
        return len(self.base_dataset)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters