-
Notifications
You must be signed in to change notification settings - Fork 657
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Adds an implementation of the convolution-augmented streaming transformer (effectively Emformer with convolution block) described in https://arxiv.org/abs/2110.05241. Continuation of #2324. Pull Request resolved: #2358 Reviewed By: nateanl, xiaohui-zhang Differential Revision: D36137992 Pulled By: hwangjeff fbshipit-source-id: 9c7a7c233944fe9ef15b9ba397d7f0809da1f063
- Loading branch information
1 parent
2f4eb4a
commit 2c79b55
Showing
9 changed files
with
306 additions
and
106 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
13 changes: 13 additions & 0 deletions
13
test/torchaudio_unittest/prototype/conv_emformer_cpu_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import torch | ||
from torchaudio_unittest.common_utils import PytorchTestCase | ||
from torchaudio_unittest.prototype.conv_emformer_test_impl import ConvEmformerTestImpl | ||
|
||
|
||
class ConvEmformerFloat32CPUTest(ConvEmformerTestImpl, PytorchTestCase): | ||
dtype = torch.float32 | ||
device = torch.device("cpu") | ||
|
||
|
||
class ConvEmformerFloat64CPUTest(ConvEmformerTestImpl, PytorchTestCase): | ||
dtype = torch.float64 | ||
device = torch.device("cpu") |
15 changes: 15 additions & 0 deletions
15
test/torchaudio_unittest/prototype/conv_emformer_gpu_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import torch | ||
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase | ||
from torchaudio_unittest.prototype.conv_emformer_test_impl import ConvEmformerTestImpl | ||
|
||
|
||
@skipIfNoCuda | ||
class ConvEmformerFloat32GPUTest(ConvEmformerTestImpl, PytorchTestCase): | ||
dtype = torch.float32 | ||
device = torch.device("cuda") | ||
|
||
|
||
@skipIfNoCuda | ||
class ConvEmformerFloat64GPUTest(ConvEmformerTestImpl, PytorchTestCase): | ||
dtype = torch.float64 | ||
device = torch.device("cuda") |
27 changes: 27 additions & 0 deletions
27
test/torchaudio_unittest/prototype/conv_emformer_test_impl.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import torch | ||
from torchaudio.prototype.models.conv_emformer import ConvEmformer | ||
from torchaudio_unittest.common_utils import TestBaseMixin | ||
from torchaudio_unittest.models.emformer.emformer_test_impl import EmformerTestMixin | ||
|
||
|
||
class ConvEmformerTestImpl(EmformerTestMixin, TestBaseMixin): | ||
def gen_model(self, input_dim, right_context_length): | ||
emformer = ConvEmformer( | ||
input_dim, | ||
8, | ||
256, | ||
3, | ||
4, | ||
12, | ||
left_context_length=30, | ||
right_context_length=right_context_length, | ||
max_memory_size=1, | ||
).to(device=self.device, dtype=self.dtype) | ||
return emformer | ||
|
||
def gen_inputs(self, input_dim, batch_size, num_frames, right_context_length): | ||
input = torch.rand(batch_size, num_frames, input_dim).to(device=self.device, dtype=self.dtype) | ||
lengths = torch.randint(1, num_frames - right_context_length, (batch_size,)).to( | ||
device=self.device, dtype=self.dtype | ||
) | ||
return input, lengths |
Oops, something went wrong.