Add ConvEmformer module (#2358)
Summary:
Adds an implementation of the convolution-augmented streaming transformer (effectively Emformer with a convolution block) described in https://arxiv.org/abs/2110.05241.
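
For orientation, a minimal usage sketch distilled from the tests added below; the interpretation of the positional hyperparameters (heads, feed-forward dim, layers, segment length, kernel size) and the chunking scheme are assumptions about the prototype API, not documented guarantees:

import torch
from torchaudio.prototype.models.conv_emformer import ConvEmformer

input_dim = 128
right_context_length = 2

model = ConvEmformer(
    input_dim,
    8,    # attention heads (assumed)
    256,  # feed-forward dimension (assumed)
    3,    # number of layers (assumed)
    4,    # segment length (assumed)
    12,   # convolution kernel size (assumed)
    left_context_length=30,
    right_context_length=right_context_length,
    max_memory_size=1,
).eval()

# Streaming inference: each chunk carries segment + lookahead frames
# (4 + 2 = 6 here, matching the shapes the infer tests below use), and
# the opaque state is threaded through successive calls.
state = None
for _ in range(3):
    chunk = torch.rand(2, 6, input_dim)  # (batch, frames, feature)
    lengths = torch.randint(1, 6 - right_context_length, (2,))
    output, output_lengths, state = model.infer(chunk, lengths, state)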

Continuation of #2324.

Pull Request resolved: #2358

Reviewed By: nateanl, xiaohui-zhang

Differential Revision: D36137992

Pulled By: hwangjeff

fbshipit-source-id: 9c7a7c233944fe9ef15b9ba397d7f0809da1f063
hwangjeff authored and facebook-github-bot committed May 10, 2022
1 parent 2f4eb4a commit 2c79b55
Showing 9 changed files with 306 additions and 106 deletions.
14 changes: 14 additions & 0 deletions docs/source/prototype.models.rst
@@ -13,3 +13,17 @@ conformer_rnnt_base
 ~~~~~~~~~~~~~~~~~~~

 .. autofunction:: conformer_rnnt_base
+
+ConvEmformer
+~~~~~~~~~~~~
+
+.. autoclass:: ConvEmformer
+
+  .. automethod:: forward
+
+  .. automethod:: infer
+
+References
+~~~~~~~~~~
+
+.. footbibliography::
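
(For context: the footbibliography directive and the :footcite: role come from sphinxcontrib-bibtex; citations made with :footcite: in docstrings are rendered here from entries in refs.bib, such as the one added below.)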
9 changes: 9 additions & 0 deletions docs/source/refs.bib
@@ -238,6 +238,15 @@ @inproceedings{shi2021emformer
   pages={6783-6787},
   year={2021}
 }
+@inproceedings{9747706,
+  author={Shi, Yangyang and Wu, Chunyang and Wang, Dilin and Xiao, Alex and Mahadeokar, Jay and Zhang, Xiaohui and Liu, Chunxi and Li, Ke and Shangguan, Yuan and Nagaraja, Varun and Kalinli, Ozlem and Seltzer, Mike},
+  booktitle={ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
+  title={Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution},
+  year={2022},
+  volume={},
+  number={},
+  pages={8277-8281},
+  doi={10.1109/ICASSP43922.2022.9747706}}
 @article{mises1929praktische,
   title={Praktische Verfahren der Gleichungsaufl{\"o}sung.},
   author={Mises, RV and Pollaczek-Geiringer, Hilda},
74 changes: 43 additions & 31 deletions test/torchaudio_unittest/models/emformer/emformer_test_impl.py
@@ -1,28 +1,18 @@
+from abc import ABC, abstractmethod
+
 import torch
 from torchaudio.models import Emformer
 from torchaudio_unittest.common_utils import TestBaseMixin, torch_script


-class EmformerTestImpl(TestBaseMixin):
-    def _gen_model(self, input_dim, right_context_length):
-        emformer = Emformer(
-            input_dim,
-            8,
-            256,
-            3,
-            4,
-            left_context_length=30,
-            right_context_length=right_context_length,
-            max_memory_size=1,
-        ).to(device=self.device, dtype=self.dtype)
-        return emformer
+class EmformerTestMixin(ABC):
+    @abstractmethod
+    def gen_model(self, input_dim, right_context_length):
+        pass

-    def _gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
-        input = torch.rand(batch_size, num_frames, input_dim).to(device=self.device, dtype=self.dtype)
-        lengths = torch.randint(1, num_frames - right_context_length, (batch_size,)).to(
-            device=self.device, dtype=self.dtype
-        )
-        return input, lengths
+    @abstractmethod
+    def gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
+        pass

     def setUp(self):
         super().setUp()
@@ -35,8 +25,8 @@ def test_torchscript_consistency_forward(self):
         num_frames = 400
         right_context_length = 1

-        emformer = self._gen_model(input_dim, right_context_length)
-        input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
+        emformer = self.gen_model(input_dim, right_context_length)
+        input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)
         scripted = torch_script(emformer)

         ref_out, ref_len = emformer(input, lengths)
@@ -52,12 +42,12 @@ def test_torchscript_consistency_infer(self):
         num_frames = 5
         right_context_length = 1

-        emformer = self._gen_model(input_dim, right_context_length).eval()
+        emformer = self.gen_model(input_dim, right_context_length).eval()
         scripted = torch_script(emformer).eval()

         ref_state, scripted_state = None, None
         for _ in range(3):
-            input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
+            input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)
             ref_out, ref_len, ref_state = emformer.infer(input, lengths, ref_state)
             scripted_out, scripted_len, scripted_state = scripted.infer(input, lengths, scripted_state)
             self.assertEqual(ref_out, scripted_out)
@@ -71,8 +61,8 @@ def test_output_shape_forward(self):
         num_frames = 123
         right_context_length = 9

-        emformer = self._gen_model(input_dim, right_context_length)
-        input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
+        emformer = self.gen_model(input_dim, right_context_length)
+        input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)

         output, output_lengths = emformer(input, lengths)

@@ -86,11 +76,11 @@ def test_output_shape_infer(self):
         num_frames = 6
         right_context_length = 2

-        emformer = self._gen_model(input_dim, right_context_length).eval()
+        emformer = self.gen_model(input_dim, right_context_length).eval()

         state = None
         for _ in range(3):
-            input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
+            input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)
             output, output_lengths, state = emformer.infer(input, lengths, state)
             self.assertEqual((batch_size, num_frames - right_context_length, input_dim), output.shape)
             self.assertEqual((batch_size,), output_lengths.shape)
@@ -102,8 +92,8 @@ def test_output_lengths_forward(self):
         num_frames = 123
         right_context_length = 2

-        emformer = self._gen_model(input_dim, right_context_length)
-        input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
+        emformer = self.gen_model(input_dim, right_context_length)
+        input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)
         _, output_lengths = emformer(input, lengths)
         self.assertEqual(lengths, output_lengths)

@@ -114,7 +104,29 @@ def test_output_lengths_infer(self):
         num_frames = 6
         right_context_length = 2

-        emformer = self._gen_model(input_dim, right_context_length).eval()
-        input, lengths = self._gen_inputs(input_dim, batch_size, num_frames, right_context_length)
+        emformer = self.gen_model(input_dim, right_context_length).eval()
+        input, lengths = self.gen_inputs(input_dim, batch_size, num_frames, right_context_length)
         _, output_lengths, _ = emformer.infer(input, lengths)
         self.assertEqual(torch.clamp(lengths - right_context_length, min=0), output_lengths)
+
+
+class EmformerTestImpl(EmformerTestMixin, TestBaseMixin):
+    def gen_model(self, input_dim, right_context_length):
+        emformer = Emformer(
+            input_dim,
+            8,
+            256,
+            3,
+            4,
+            left_context_length=30,
+            right_context_length=right_context_length,
+            max_memory_size=1,
+        ).to(device=self.device, dtype=self.dtype)
+        return emformer
+
+    def gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
+        input = torch.rand(batch_size, num_frames, input_dim).to(device=self.device, dtype=self.dtype)
+        lengths = torch.randint(1, num_frames - right_context_length, (batch_size,)).to(
+            device=self.device, dtype=self.dtype
+        )
+        return input, lengths
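
The new test_output_lengths_infer above pins down the streaming length contract: the right_context_length lookahead frames in each chunk are consumed as context rather than emitted. A quick illustration of the expected arithmetic, reusing the test's values:

import torch

right_context_length = 2
lengths = torch.tensor([3, 1])  # valid frames per batch element, including lookahead
expected = torch.clamp(lengths - right_context_length, min=0)
print(expected)  # tensor([1, 0]) -- frames actually emitted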
13 changes: 13 additions & 0 deletions test/torchaudio_unittest/prototype/conv_emformer_cpu_test.py
@@ -0,0 +1,13 @@
+import torch
+from torchaudio_unittest.common_utils import PytorchTestCase
+from torchaudio_unittest.prototype.conv_emformer_test_impl import ConvEmformerTestImpl
+
+
+class ConvEmformerFloat32CPUTest(ConvEmformerTestImpl, PytorchTestCase):
+    dtype = torch.float32
+    device = torch.device("cpu")
+
+
+class ConvEmformerFloat64CPUTest(ConvEmformerTestImpl, PytorchTestCase):
+    dtype = torch.float64
+    device = torch.device("cpu")
15 changes: 15 additions & 0 deletions test/torchaudio_unittest/prototype/conv_emformer_gpu_test.py
@@ -0,0 +1,15 @@
+import torch
+from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
+from torchaudio_unittest.prototype.conv_emformer_test_impl import ConvEmformerTestImpl
+
+
+@skipIfNoCuda
+class ConvEmformerFloat32GPUTest(ConvEmformerTestImpl, PytorchTestCase):
+    dtype = torch.float32
+    device = torch.device("cuda")
+
+
+@skipIfNoCuda
+class ConvEmformerFloat64GPUTest(ConvEmformerTestImpl, PytorchTestCase):
+    dtype = torch.float64
+    device = torch.device("cuda")
27 changes: 27 additions & 0 deletions test/torchaudio_unittest/prototype/conv_emformer_test_impl.py
@@ -0,0 +1,27 @@
+import torch
+from torchaudio.prototype.models.conv_emformer import ConvEmformer
+from torchaudio_unittest.common_utils import TestBaseMixin
+from torchaudio_unittest.models.emformer.emformer_test_impl import EmformerTestMixin
+
+
+class ConvEmformerTestImpl(EmformerTestMixin, TestBaseMixin):
+    def gen_model(self, input_dim, right_context_length):
+        emformer = ConvEmformer(
+            input_dim,
+            8,
+            256,
+            3,
+            4,
+            12,
+            left_context_length=30,
+            right_context_length=right_context_length,
+            max_memory_size=1,
+        ).to(device=self.device, dtype=self.dtype)
+        return emformer
+
+    def gen_inputs(self, input_dim, batch_size, num_frames, right_context_length):
+        input = torch.rand(batch_size, num_frames, input_dim).to(device=self.device, dtype=self.dtype)
+        lengths = torch.randint(1, num_frames - right_context_length, (batch_size,)).to(
+            device=self.device, dtype=self.dtype
+        )
+        return input, lengths
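
Assuming the usual pytest-based workflow for torchaudio's unit tests, the new suites can be run by file path; a sketch (the invocation below is an assumption about the local setup, not part of this commit):

import pytest

# Equivalent to running `pytest <paths> -v` from the repository root.
pytest.main([
    "test/torchaudio_unittest/prototype/conv_emformer_cpu_test.py",
    "test/torchaudio_unittest/prototype/conv_emformer_gpu_test.py",
    "-v",
])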
(Diffs for the remaining 3 of the 9 changed files are not shown.)