Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ConvEmformer module #2358

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/source/prototype.models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,17 @@ conformer_rnnt_base
~~~~~~~~~~~~~~~~~~~

.. autofunction:: conformer_rnnt_base

ConvEmformer
~~~~~~~~~~~~

.. autoclass:: ConvEmformer

  .. automethod:: forward

  .. automethod:: infer

References
~~~~~~~~~~

.. footbibliography::
9 changes: 9 additions & 0 deletions docs/source/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,15 @@ @inproceedings{shi2021emformer
pages={6783-6787},
year={2021}
}
@inproceedings{9747706,
author={Shi, Yangyang and Wu, Chunyang and Wang, Dilin and Xiao, Alex and Mahadeokar, Jay and Zhang, Xiaohui and Liu, Chunxi and Li, Ke and Shangguan, Yuan and Nagaraja, Varun and Kalinli, Ozlem and Seltzer, Mike},
booktitle={ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
title={Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution},
year={2022},
volume={},
number={},
pages={8277-8281},
doi={10.1109/ICASSP43922.2022.9747706}}
@article{mises1929praktische,
title={Praktische Verfahren der Gleichungsaufl{\"o}sung.},
author={Mises, RV and Pollaczek-Geiringer, Hilda},
Expand Down
74 changes: 43 additions & 31 deletions test/torchaudio_unittest/models/emformer/emformer_test_impl.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,18 @@
from abc import ABC, abstractmethod

import torch
from torchaudio.models import Emformer
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script


class EmformerTestImpl(TestBaseMixin):
def _gen_model(self, input_dim, right_context_length):
emformer = Emformer(
input_dim,
8,
256,
3,
4,
left_context_length=30,
right_context_length=right_context_length,
max_memory_size=1,
).to(device=self.device, dtype=self.dtype)
return emformer
class EmformerTestMixin(ABC):
@abstractmethod
def gen_model(self, input_dim, right_context_length):
pass

def _gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
input = torch.rand(batch_size, num_frames, input_dim).to(device=self.device, dtype=self.dtype)
lengths = torch.randint(1, num_frames - right_context_length, (batch_size,)).to(
device=self.device, dtype=self.dtype
)
return input, lengths
@abstractmethod
def gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
pass

def setUp(self):
super().setUp()
Expand All @@ -35,8 +25,8 @@ def test_torchscript_consistency_forward(self):
num_frames = 400
right_context_length = 1

emformer = self._gen_model(input_dim, right_context_length)
input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
emformer = self.gen_model(input_dim, right_context_length)
input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)
scripted = torch_script(emformer)

ref_out, ref_len = emformer(input, lengths)
Expand All @@ -52,12 +42,12 @@ def test_torchscript_consistency_infer(self):
num_frames = 5
right_context_length = 1

emformer = self._gen_model(input_dim, right_context_length).eval()
emformer = self.gen_model(input_dim, right_context_length).eval()
scripted = torch_script(emformer).eval()

ref_state, scripted_state = None, None
for _ in range(3):
input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)
ref_out, ref_len, ref_state = emformer.infer(input, lengths, ref_state)
scripted_out, scripted_len, scripted_state = scripted.infer(input, lengths, scripted_state)
self.assertEqual(ref_out, scripted_out)
Expand All @@ -71,8 +61,8 @@ def test_output_shape_forward(self):
num_frames = 123
right_context_length = 9

emformer = self._gen_model(input_dim, right_context_length)
input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
emformer = self.gen_model(input_dim, right_context_length)
input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)

output, output_lengths = emformer(input, lengths)

Expand All @@ -86,11 +76,11 @@ def test_output_shape_infer(self):
num_frames = 6
right_context_length = 2

emformer = self._gen_model(input_dim, right_context_length).eval()
emformer = self.gen_model(input_dim, right_context_length).eval()

state = None
for _ in range(3):
input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)
output, output_lengths, state = emformer.infer(input, lengths, state)
self.assertEqual((batch_size, num_frames - right_context_length, input_dim), output.shape)
self.assertEqual((batch_size,), output_lengths.shape)
Expand All @@ -102,8 +92,8 @@ def test_output_lengths_forward(self):
num_frames = 123
right_context_length = 2

emformer = self._gen_model(input_dim, right_context_length)
input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
emformer = self.gen_model(input_dim, right_context_length)
input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)
_, output_lengths = emformer(input, lengths)
self.assertEqual(lengths, output_lengths)

Expand All @@ -114,7 +104,29 @@ def test_output_lengths_infer(self):
num_frames = 6
right_context_length = 2

emformer = self._gen_model(input_dim, right_context_length).eval()
input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
emformer = self.gen_model(input_dim, right_context_length).eval()
input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)
_, output_lengths, _ = emformer.infer(input, lengths)
self.assertEqual(torch.clamp(lengths - right_context_length, min=0), output_lengths)


class EmformerTestImpl(EmformerTestMixin, TestBaseMixin):
    """Supplies ``Emformer`` models and random inputs to the tests inherited from ``EmformerTestMixin``."""

    def gen_model(self, input_dim, right_context_length):
        """Return an ``Emformer`` moved to ``self.device`` and cast to ``self.dtype``.

        Args:
            input_dim: forwarded as the first positional constructor argument.
            right_context_length: forwarded as the keyword argument of the same name.
        """
        # The remaining positional hyperparameters (8, 256, 3, 4) are hard-coded for this fixture.
        model = Emformer(
            input_dim,
            8,
            256,
            3,
            4,
            left_context_length=30,
            right_context_length=right_context_length,
            max_memory_size=1,
        )
        return model.to(device=self.device, dtype=self.dtype)

    def gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
        """Return a random ``(input, lengths)`` pair on ``self.device`` in ``self.dtype``.

        ``input`` has shape ``(batch_size, num_frames, input_dim)``. ``lengths`` has shape
        ``(batch_size,)``; its values are integers drawn from
        ``[1, num_frames - right_context_length)`` and then cast to ``self.dtype``.
        """
        placement = {"device": self.device, "dtype": self.dtype}
        # Draw the features before the lengths so the RNG stream is consumed in a fixed order.
        features = torch.rand(batch_size, num_frames, input_dim).to(**placement)
        upper_bound = num_frames - right_context_length  # exclusive limit for torch.randint
        lengths = torch.randint(1, upper_bound, (batch_size,)).to(**placement)
        return features, lengths
13 changes: 13 additions & 0 deletions test/torchaudio_unittest/prototype/conv_emformer_cpu_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from torchaudio_unittest.prototype.conv_emformer_test_impl import ConvEmformerTestImpl


class ConvEmformerFloat32CPUTest(ConvEmformerTestImpl, PytorchTestCase):
    """Runs the ``ConvEmformerTestImpl`` test cases with float32 tensors on the CPU."""

    device = torch.device("cpu")
    dtype = torch.float32


class ConvEmformerFloat64CPUTest(ConvEmformerTestImpl, PytorchTestCase):
    """Runs the ``ConvEmformerTestImpl`` test cases with float64 tensors on the CPU."""

    device = torch.device("cpu")
    dtype = torch.float64
15 changes: 15 additions & 0 deletions test/torchaudio_unittest/prototype/conv_emformer_gpu_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from torchaudio_unittest.prototype.conv_emformer_test_impl import ConvEmformerTestImpl


@skipIfNoCuda
class ConvEmformerFloat32GPUTest(ConvEmformerTestImpl, PytorchTestCase):
    """Runs the ``ConvEmformerTestImpl`` test cases with float32 tensors on a CUDA device."""

    device = torch.device("cuda")
    dtype = torch.float32


@skipIfNoCuda
class ConvEmformerFloat64GPUTest(ConvEmformerTestImpl, PytorchTestCase):
    """Runs the ``ConvEmformerTestImpl`` test cases with float64 tensors on a CUDA device."""

    device = torch.device("cuda")
    dtype = torch.float64
27 changes: 27 additions & 0 deletions test/torchaudio_unittest/prototype/conv_emformer_test_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import torch
from torchaudio.prototype.models.conv_emformer import ConvEmformer
from torchaudio_unittest.common_utils import TestBaseMixin
from torchaudio_unittest.models.emformer.emformer_test_impl import EmformerTestMixin


class ConvEmformerTestImpl(EmformerTestMixin, TestBaseMixin):
    """Supplies ``ConvEmformer`` models and random inputs to the tests inherited from ``EmformerTestMixin``."""

    def gen_model(self, input_dim, right_context_length):
        """Return a ``ConvEmformer`` moved to ``self.device`` and cast to ``self.dtype``.

        Args:
            input_dim: forwarded as the first positional constructor argument.
            right_context_length: forwarded as the keyword argument of the same name.
        """
        # The remaining positional hyperparameters (8, 256, 3, 4, 12) are hard-coded for this fixture.
        model = ConvEmformer(
            input_dim,
            8,
            256,
            3,
            4,
            12,
            left_context_length=30,
            right_context_length=right_context_length,
            max_memory_size=1,
        )
        return model.to(device=self.device, dtype=self.dtype)

    def gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
        """Return a random ``(input, lengths)`` pair on ``self.device`` in ``self.dtype``.

        ``input`` has shape ``(batch_size, num_frames, input_dim)``. ``lengths`` has shape
        ``(batch_size,)``; its values are integers drawn from
        ``[1, num_frames - right_context_length)`` and then cast to ``self.dtype``.
        """
        placement = {"device": self.device, "dtype": self.dtype}
        # Draw the features before the lengths so the RNG stream is consumed in a fixed order.
        features = torch.rand(batch_size, num_frames, input_dim).to(**placement)
        upper_bound = num_frames - right_context_length  # exclusive limit for torch.randint
        lengths = torch.randint(1, upper_bound, (batch_size,)).to(**placement)
        return features, lengths
Loading