Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
hwangjeff committed May 3, 2022
1 parent 51628db commit 9c984f8
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod

import torch
from torchaudio.models import Emformer
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from torchaudio.prototype.models.conv_emformer import ConvEmformer
from torchaudio_unittest.models.emformer.emformer_test_impl import EmformerTestMixin
from torchaudio_unittest.common_utils import TestBaseMixin
from torchaudio_unittest.models.emformer.emformer_test_impl import EmformerTestMixin


class ConvEmformerTestImpl(EmformerTestMixin, TestBaseMixin):
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/prototype/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .rnnt import conformer_rnnt_base, conformer_rnnt_model
from .conv_emformer import ConvEmformer
from .rnnt import conformer_rnnt_base, conformer_rnnt_model

__all__ = [
"conformer_rnnt_base",
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/prototype/models/conv_emformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _split_right_context(self, utterance: torch.Tensor, right_context: torch.Ten
for seg_idx in range(num_segments):
end_idx = min(self.state_size + (seg_idx + 1) * self.segment_length, utterance.size(0))
start_idx = end_idx - self.state_size
pad_segments.append(utterance[start_idx: end_idx, :, :])
pad_segments.append(utterance[start_idx:end_idx, :, :])

pad_segments = torch.cat(pad_segments, dim=1).permute(1, 0, 2) # (num_segments * B, kernel_size - 1, D)
return torch.cat([pad_segments, right_context_segments], dim=1).permute(0, 2, 1)
Expand Down

0 comments on commit 9c984f8

Please sign in to comment.