Add an option to use Tanh instead of ReLU in RNNT joiner (pytorch#2319)
Summary:
Add an option to use Tanh instead of ReLU in the RNNT joiner, which can sometimes improve training performance.

 ---

Pull Request resolved: pytorch#2319

Reviewed By: hwangjeff

Differential Revision: D35422122

Pulled By: xiaohui-zhang

fbshipit-source-id: c6a0f8b25936e47081110af046b57d0e8751f9a2
xiaohui-zhang committed May 4, 2022
1 parent cccbf51 commit 4075644
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions torchaudio/models/rnnt.py
@@ -377,12 +377,20 @@ class _Joiner(torch.nn.Module):
     Args:
         input_dim (int): source and target input dimension.
         output_dim (int): output dimension.
+        activation (str, optional): activation function to use in the joiner
+            Must be one of ("relu", "tanh"). (Default: "relu")
     """
 
-    def __init__(self, input_dim: int, output_dim: int) -> None:
+    def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None:
         super().__init__()
         self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
-        self.relu = torch.nn.ReLU()
+        if activation == "relu":
+            self.activation = torch.nn.ReLU()
+        elif activation == "tanh":
+            self.activation = torch.nn.Tanh()
+        else:
+            raise ValueError(f"Unsupported activation {activation}")
 
     def forward(
         self,
@@ -419,8 +427,8 @@ def forward(
                     number of valid elements along dim 2 for i-th batch element in joint network output.
         """
         joint_encodings = source_encodings.unsqueeze(2).contiguous() + target_encodings.unsqueeze(1).contiguous()
-        relu_out = self.relu(joint_encodings)
-        output = self.linear(relu_out)
+        activation_out = self.activation(joint_encodings)
+        output = self.linear(activation_out)
         return output, source_lengths, target_lengths
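
For context, here is a minimal standalone sketch of the joiner pattern this change configures. The class name ToyJoiner, the dimensions, and the simplified forward signature are illustrative assumptions, not part of the torchaudio public API; the real _Joiner also tracks source/target lengths as shown in the diff above.

import torch


class ToyJoiner(torch.nn.Module):
    """Minimal RNN-T joiner sketch with a configurable activation (illustrative only)."""

    def __init__(self, input_dim: int, output_dim: int, activation: str = "relu") -> None:
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
        # Mirror the activation-selection logic added in this commit.
        if activation == "relu":
            self.activation = torch.nn.ReLU()
        elif activation == "tanh":
            self.activation = torch.nn.Tanh()
        else:
            raise ValueError(f"Unsupported activation {activation}")

    def forward(self, source_encodings: torch.Tensor, target_encodings: torch.Tensor) -> torch.Tensor:
        # Broadcast-add (B, T, D) source and (B, U, D) target encodings to (B, T, U, D),
        # apply the chosen activation, then project to output_dim.
        joint_encodings = source_encodings.unsqueeze(2) + target_encodings.unsqueeze(1)
        return self.linear(self.activation(joint_encodings))


# Example: select Tanh instead of the default ReLU.
joiner = ToyJoiner(input_dim=256, output_dim=128, activation="tanh")
out = joiner(torch.randn(2, 10, 256), torch.randn(2, 5, 256))
print(out.shape)  # torch.Size([2, 10, 5, 128])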

