Fix SEW-D implementation differences (#14191)
* Fix SEW-D

* Update tests

* isort
anton-l authored Oct 28, 2021
1 parent 78b6a2e commit 1251072
Showing 5 changed files with 22 additions and 16 deletions.
5 changes: 3 additions & 2 deletions src/transformers/activations.py
@@ -24,7 +24,7 @@
logger = logging.get_logger(__name__)


-def _gelu_python(x):
+def gelu_python(x):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
@@ -43,7 +43,7 @@ def gelu_new(x):


if version.parse(torch.__version__) < version.parse("1.4"):
-gelu = _gelu_python
+gelu = gelu_python
else:
gelu = nn.functional.gelu

@@ -97,6 +97,7 @@ def linear_act(x):
"swish": silu,
"gelu": gelu,
"tanh": torch.tanh,
"gelu_python": gelu_python,
"gelu_new": gelu_new,
"gelu_fast": gelu_fast,
"quick_gelu": quick_gelu,
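With gelu_python now public and registered under its own key in ACT2FN, it can be requested by name through get_activation and should match PyTorch's built-in erf-based GELU, which is exactly what the updated test in tests/test_activations.py below asserts. A minimal sketch, assuming a transformers version that already contains this change:

import torch

from transformers.activations import gelu_python, get_activation

x = torch.tensor([-1.0, -0.1, 0.0, 0.1, 1.0])
act = get_activation("gelu_python")  # resolves the new ACT2FN entry

# Same callable, so the outputs are identical ...
print(torch.allclose(act(x), gelu_python(x)))
# ... and they agree with PyTorch's native (exact, erf-based) GELU.
print(torch.allclose(act(x), torch.nn.functional.gelu(x)))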
16 changes: 10 additions & 6 deletions src/transformers/models/sew_d/configuration_sew_d.py
@@ -67,9 +67,9 @@ class SEWDConfig(PretrainedConfig):
:obj:`("p2c")`, :obj:`("p2c", "c2p")`, :obj:`("p2c", "c2p", 'p2p")`.
norm_rel_ebd (:obj:`str`, `optional`, defaults to :obj:`"layer_norm"`):
Whether to use layer norm in relative embedding (:obj:`"layer_norm"` if yes)
-hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
+hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_python"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"`, :obj:`"gelu_python"` and :obj:`"gelu_new"` are supported.
hidden_dropout (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, `optional`, defaults to 0.1):
@@ -78,8 +78,10 @@ class SEWDConfig(PretrainedConfig):
The dropout probability for the final projection layer of :class:`SEWDForCTC`.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
-The epsilon used by the layer normalization layers.
+layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-7):
+The epsilon used by the layer normalization layers in the transformer encoder.
+feature_layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-5):
+The epsilon used by the layer normalization after the feature extractor.
feat_extract_norm (:obj:`str`, `optional`, defaults to :obj:`"group"`):
The norm to be applied to 1D convolutional layers in feature extractor. One of :obj:`"group"` for group
normalization of only the first 1D convolutional layer or :obj:`"layer"` for layer normalization of all 1D
@@ -167,15 +169,16 @@ def __init__(
position_biased_input=False,
pos_att_type=("p2c", "c2p"),
norm_rel_ebd="layer_norm",
hidden_act="gelu",
hidden_act="gelu_python",
hidden_dropout=0.1,
activation_dropout=0.1,
attention_dropout=0.1,
feat_proj_dropout=0.0,
final_dropout=0.1,
layerdrop=0.1,
initializer_range=0.02,
-layer_norm_eps=1e-5,
+layer_norm_eps=1e-7,
+feature_layer_norm_eps=1e-5,
feat_extract_norm="group",
feat_extract_activation="gelu",
conv_dim=(64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512),
@@ -228,6 +231,7 @@ def __init__(
self.final_dropout = final_dropout
self.layerdrop = layerdrop
self.layer_norm_eps = layer_norm_eps
+self.feature_layer_norm_eps = feature_layer_norm_eps
self.initializer_range = initializer_range
self.vocab_size = vocab_size

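The configuration now carries two separate epsilons: layer_norm_eps for the layer norms inside the transformer encoder (default 1e-7) and the new feature_layer_norm_eps for the layer norm applied after the convolutional feature extractor (default 1e-5), with hidden_act defaulting to the exact "gelu_python" activation. A short sketch of the resulting defaults, assuming a transformers release that includes this commit:

from transformers import SEWDConfig

config = SEWDConfig()
print(config.hidden_act)              # "gelu_python" (previously "gelu")
print(config.layer_norm_eps)          # 1e-7, used inside the transformer encoder
print(config.feature_layer_norm_eps)  # 1e-5, used after the feature extractor

# Both epsilons can still be overridden independently when needed:
custom = SEWDConfig(layer_norm_eps=1e-9, feature_layer_norm_eps=1e-6)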
4 changes: 2 additions & 2 deletions src/transformers/models/sew_d/modeling_sew_d.py
@@ -1310,13 +1310,13 @@ def _set_gradient_checkpointing(self, module, value=False):
"The bare SEW-D Model transformer outputting raw hidden-states without any specific head on top.",
SEWD_START_DOCSTRING,
)
-# Copied from transformers.models.sew.modeling_sew.SEWModel with SEW->SEWD
+# Copied from transformers.models.sew.modeling_sew.SEWModel with SEW->SEWD, layer_norm_eps->feature_layer_norm_eps
class SEWDModel(SEWDPreTrainedModel):
def __init__(self, config: SEWDConfig):
super().__init__(config)
self.config = config
self.feature_extractor = SEWDFeatureExtractor(config)
-self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.feature_layer_norm_eps)

self.project_features = config.conv_dim[-1] != config.hidden_size
if self.project_features:
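In the model, the LayerNorm applied to the feature-extractor output therefore reads its epsilon from config.feature_layer_norm_eps, while the encoder keeps using config.layer_norm_eps. A minimal sketch of how this could be checked on a randomly initialized model, assuming the attribute layout shown in the diff above:

from transformers import SEWDConfig, SEWDModel

config = SEWDConfig()     # default config; weights are randomly initialized, no download
model = SEWDModel(config)

# The post-feature-extractor LayerNorm picks up the new epsilon (1e-5 by default),
# while the encoder's layer norms continue to use config.layer_norm_eps (1e-7).
assert model.layer_norm.eps == config.feature_layer_norm_eps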
7 changes: 4 additions & 3 deletions tests/test_activations.py
@@ -21,16 +21,16 @@
if is_torch_available():
import torch

-from transformers.activations import _gelu_python, gelu_new, get_activation
+from transformers.activations import gelu_new, gelu_python, get_activation


@require_torch
class TestActivations(unittest.TestCase):
def test_gelu_versions(self):
x = torch.tensor([-100, -1, -0.1, 0, 0.1, 1.0, 100])
torch_builtin = get_activation("gelu")
-self.assertTrue(torch.allclose(_gelu_python(x), torch_builtin(x)))
-self.assertFalse(torch.allclose(_gelu_python(x), gelu_new(x)))
+self.assertTrue(torch.allclose(gelu_python(x), torch_builtin(x)))
+self.assertFalse(torch.allclose(gelu_python(x), gelu_new(x)))

def test_get_activation(self):
get_activation("swish")
@@ -39,6 +39,7 @@ def test_get_activation(self):
get_activation("tanh")
get_activation("gelu_new")
get_activation("gelu_fast")
get_activation("gelu_python")
with self.assertRaises(KeyError):
get_activation("bogus")
with self.assertRaises(KeyError):
6 changes: 3 additions & 3 deletions tests/test_modeling_sew_d.py
@@ -540,9 +540,9 @@ def test_inference_pretrained_batched(self):
)
expected_output_sum = 54201.0469

-self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
-self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
-self.assertTrue(abs(outputs.sum() - expected_output_sum) < 5)
+self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=1e-3))
+self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=1e-3))
+self.assertTrue(abs(outputs.sum() - expected_output_sum) < 1)

def test_inference_ctc_batched(self):
model = SEWDForCTC.from_pretrained("asapp/sew-d-tiny-100k-ft-ls100h").to(torch_device)
