Skip to content

Commit

Permalink
Add unit tests for PyTorch Lightning modules of emformer_rnnt recipes (
Browse files Browse the repository at this point in the history
…#2240)

Summary:
- Refactor the current `LibriSpeechRNNTModule`'s unit test.
- Add unit tests for `TEDLIUM3RNNTModule` and `MuSTCRNNTModule`
- Replace the lambda with partial in `TEDLIUM3RNNTModule` to pass the lightning unit test.

Pull Request resolved: #2240

Reviewed By: mthrok

Differential Revision: D34285195

Pulled By: nateanl

fbshipit-source-id: 4f20749c85ddd25cbb0eafc1733c64212542338f
  • Loading branch information
nateanl authored and facebook-github-bot committed Feb 17, 2022
1 parent c5c4bbf commit b5d77b1
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 67 deletions.
17 changes: 9 additions & 8 deletions examples/asr/emformer_rnnt/tedlium3/lightning.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from functools import partial
from typing import List

import sentencepiece as spm
Expand Down Expand Up @@ -86,20 +87,20 @@ def __init__(
self.train_data_pipeline = torch.nn.Sequential(
FunctionalModule(piecewise_linear_log),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(lambda x: x.transpose(1, 2)),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.TimeMasking(100, p=0.2),
torchaudio.transforms.TimeMasking(100, p=0.2),
FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
FunctionalModule(lambda x: x.transpose(1, 2)),
FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
)
self.valid_data_pipeline = torch.nn.Sequential(
FunctionalModule(piecewise_linear_log),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(lambda x: x.transpose(1, 2)),
FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
FunctionalModule(lambda x: x.transpose(1, 2)),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
)

self.tedlium_path = tedlium_path
Expand Down Expand Up @@ -197,8 +198,8 @@ def training_step(self, batch: Batch, batch_idx):
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")

def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
def test_step(self, batch_tuple, batch_idx):
return self._step(batch_tuple[0], batch_idx, "test")

def train_dataloader(self):
dataset = CustomDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="train"), 100)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,37 +1,18 @@
from contextlib import contextmanager
from functools import partial
from unittest.mock import patch

import torch
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule

from .utils import MockSentencePieceProcessor, MockCustomDataset, MockDataloader

if is_module_available("pytorch_lightning", "sentencepiece"):
from asr.emformer_rnnt.librispeech.lightning import LibriSpeechRNNTModule


class MockSentencePieceProcessor:
def __init__(self, *args, **kwargs):
pass

def get_piece_size(self):
return 4096

def encode(self, input):
return [1, 5, 2]

def decode(self, input):
return "hey"

def unk_id(self):
return 0

def eos_id(self):
return 1

def pad_id(self):
return 2


class MockLIBRISPEECH:
def __init__(self, *args, **kwargs):
pass
Expand All @@ -50,23 +31,16 @@ def __len__(self):
return 10


class MockCustomDataset:
def __init__(self, base_dataset, *args, **kwargs):
self.base_dataset = base_dataset

def __getitem__(self, n: int):
return [self.base_dataset[n]]

def __len__(self):
return len(self.base_dataset)


@contextmanager
def get_lightning_module():
with patch("sentencepiece.SentencePieceProcessor", new=MockSentencePieceProcessor), patch(
"asr.emformer_rnnt.librispeech.lightning.GlobalStatsNormalization", new=torch.nn.Identity
), patch("torchaudio.datasets.LIBRISPEECH", new=MockLIBRISPEECH), patch(
with patch(
"sentencepiece.SentencePieceProcessor", new=partial(MockSentencePieceProcessor, num_symbols=4096)
), patch("asr.emformer_rnnt.librispeech.lightning.GlobalStatsNormalization", new=torch.nn.Identity), patch(
"torchaudio.datasets.LIBRISPEECH", new=MockLIBRISPEECH
), patch(
"asr.emformer_rnnt.librispeech.lightning.CustomDataset", new=MockCustomDataset
), patch(
"torch.utils.data.DataLoader", new=MockDataloader
):
yield LibriSpeechRNNTModule(
librispeech_path="librispeech_path",
Expand All @@ -80,28 +54,29 @@ def get_lightning_module():
class TestLibriSpeechRNNTModule(TorchaudioTestCase):
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
torch.random.manual_seed(31)

def test_training_step(self):
@parameterized.expand(
[
("training_step", "train_dataloader"),
("validation_step", "val_dataloader"),
("test_step", "test_dataloader"),
]
)
def test_step(self, step_fname, dataloader_fname):
with get_lightning_module() as lightning_module:
train_dataloader = lightning_module.train_dataloader()
batch = next(iter(train_dataloader))
lightning_module.training_step(batch, 0)

def test_validation_step(self):
with get_lightning_module() as lightning_module:
val_dataloader = lightning_module.val_dataloader()
batch = next(iter(val_dataloader))
lightning_module.validation_step(batch, 0)

def test_test_step(self):
with get_lightning_module() as lightning_module:
test_dataloader = lightning_module.test_dataloader()
batch = next(iter(test_dataloader))
lightning_module.test_step(batch, 0)

def test_forward(self):
dataloader = getattr(lightning_module, dataloader_fname)()
batch = next(iter(dataloader))
getattr(lightning_module, step_fname)(batch, 0)

@parameterized.expand(
[
("val_dataloader",),
]
)
def test_forward(self, dataloader_fname):
with get_lightning_module() as lightning_module:
val_dataloader = lightning_module.val_dataloader()
batch = next(iter(val_dataloader))
dataloader = getattr(lightning_module, dataloader_fname)()
batch = next(iter(dataloader))
lightning_module(batch)
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from contextlib import contextmanager
from functools import partial
from unittest.mock import patch

import torch
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule

from .utils import MockSentencePieceProcessor, MockCustomDataset, MockDataloader

if is_module_available("pytorch_lightning", "sentencepiece"):
from asr.emformer_rnnt.mustc.lightning import MuSTCRNNTModule


class MockMUSTC:
def __init__(self, *args, **kwargs):
pass

def __getitem__(self, n: int):
return (
torch.rand(1, 32640),
"sup",
)

def __len__(self):
return 10


@contextmanager
def get_lightning_module():
with patch("sentencepiece.SentencePieceProcessor", new=partial(MockSentencePieceProcessor, num_symbols=500)), patch(
"asr.emformer_rnnt.mustc.lightning.GlobalStatsNormalization", new=torch.nn.Identity
), patch("asr.emformer_rnnt.mustc.lightning.MUSTC", new=MockMUSTC), patch(
"asr.emformer_rnnt.mustc.lightning.CustomDataset", new=MockCustomDataset
), patch(
"torch.utils.data.DataLoader", new=MockDataloader
):
yield MuSTCRNNTModule(
mustc_path="mustc_path",
sp_model_path="sp_model_path",
global_stats_path="global_stats_path",
)


@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestMuSTCRNNTModule(TorchaudioTestCase):
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
torch.random.manual_seed(31)

@parameterized.expand(
[
("training_step", "train_dataloader"),
("validation_step", "val_dataloader"),
("test_step", "test_common_dataloader"),
("test_step", "test_he_dataloader"),
]
)
def test_step(self, step_fname, dataloader_fname):
with get_lightning_module() as lightning_module:
dataloader = getattr(lightning_module, dataloader_fname)()
batch = next(iter(dataloader))
getattr(lightning_module, step_fname)(batch, 0)

@parameterized.expand(
[
("val_dataloader",),
]
)
def test_forward(self, dataloader_fname):
with get_lightning_module() as lightning_module:
dataloader = getattr(lightning_module, dataloader_fname)()
batch = next(iter(dataloader))
lightning_module(batch)
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from contextlib import contextmanager
from functools import partial
from unittest.mock import patch

import torch
from parameterized import parameterized
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import TorchaudioTestCase, skipIfNoModule

from .utils import MockSentencePieceProcessor, MockCustomDataset, MockDataloader

if is_module_available("pytorch_lightning", "sentencepiece"):
from asr.emformer_rnnt.tedlium3.lightning import TEDLIUM3RNNTModule


class MockTEDLIUM:
def __init__(self, *args, **kwargs):
pass

def __getitem__(self, n: int):
return (
torch.rand(1, 32640),
16000,
"sup",
2,
3,
4,
)

def __len__(self):
return 10


@contextmanager
def get_lightning_module():
with patch("sentencepiece.SentencePieceProcessor", new=partial(MockSentencePieceProcessor, num_symbols=500)), patch(
"asr.emformer_rnnt.tedlium3.lightning.GlobalStatsNormalization", new=torch.nn.Identity
), patch("torchaudio.datasets.TEDLIUM", new=MockTEDLIUM), patch(
"asr.emformer_rnnt.tedlium3.lightning.CustomDataset", new=MockCustomDataset
), patch(
"torch.utils.data.DataLoader", new=MockDataloader
):
yield TEDLIUM3RNNTModule(
tedlium_path="tedlium_path",
sp_model_path="sp_model_path",
global_stats_path="global_stats_path",
)


@skipIfNoModule("pytorch_lightning")
@skipIfNoModule("sentencepiece")
class TestTEDLIUM3RNNTModule(TorchaudioTestCase):
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
torch.random.manual_seed(31)

@parameterized.expand(
[
("training_step", "train_dataloader"),
("validation_step", "val_dataloader"),
("test_step", "test_dataloader"),
]
)
def test_step(self, step_fname, dataloader_fname):
with get_lightning_module() as lightning_module:
dataloader = getattr(lightning_module, dataloader_fname)()
batch = next(iter(dataloader))
getattr(lightning_module, step_fname)(batch, 0)

@parameterized.expand(
[
("val_dataloader",),
]
)
def test_forward(self, dataloader_fname):
with get_lightning_module() as lightning_module:
dataloader = getattr(lightning_module, dataloader_fname)()
batch = next(iter(dataloader))
lightning_module(batch)
48 changes: 48 additions & 0 deletions test/torchaudio_unittest/example/emformer_rnnt/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
class MockSentencePieceProcessor:
def __init__(self, num_symbols, *args, **kwargs):
self.num_symbols = num_symbols

def get_piece_size(self):
return self.num_symbols

def encode(self, input):
return [1, 5, 2]

def decode(self, input):
return "hey"

def unk_id(self):
return 0

def eos_id(self):
return 1

def pad_id(self):
return 2


class MockCustomDataset:
def __init__(self, base_dataset, *args, **kwargs):
self.base_dataset = base_dataset

def __getitem__(self, n: int):
return [self.base_dataset[n]]

def __len__(self):
return len(self.base_dataset)


class MockDataloader:
def __init__(self, base_dataset, batch_size, collate_fn, *args, **kwargs):
self.base_dataset = base_dataset
self.batch_size = batch_size
self.collate_fn = collate_fn

def __iter__(self):
for sample in iter(self.base_dataset):
if self.batch_size == 1:
sample = [sample]
yield self.collate_fn(sample)

def __len__(self):
return len(self.base_dataset)
2 changes: 1 addition & 1 deletion torchaudio/pipelines/rnnt_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def get_token_processor(self) -> TokenProcessor:
The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
and utilizes weights trained on LibriSpeech using training script ``train.py``
`here <https://github.com/pytorch/audio/tree/main/examples/asr/librispeech_emformer_rnnt>`__ with default arguments.
`here <https://github.com/pytorch/audio/tree/main/examples/asr/emformer_rnnt>`__ with default arguments.
Please refer to :py:class:`RNNTBundle` for usage instructions.
"""
2 changes: 1 addition & 1 deletion torchaudio/prototype/pipelines/rnnt_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
and utilizes weights trained on TED-LIUM Release 3 dataset using training script ``train.py``
`here <https://github.com/pytorch/audio/tree/main/examples/asr/tedlium3_emformer_rnnt>`__ with ``num_symbols=501``.
`here <https://github.com/pytorch/audio/tree/main/examples/asr/emformer_rnnt>`__ with ``num_symbols=501``.
Please refer to :py:class:`torchaudio.pipelines.RNNTBundle` for usage instructions.
"""

0 comments on commit b5d77b1

Please sign in to comment.