Skip to content

Commit

Permalink
fixed calculation of ctc loss in TFWav2Vec2ForCTC (#18014)
Browse files Browse the repository at this point in the history
Co-authored-by: Sreyan-G@NVIDIA <sreyang@nvidia.com>
  • Loading branch information
Sreyan88 and SreyanG-NVIDIA authored Jul 4, 2022
1 parent 96d833b commit e3139ad
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
from ...modeling_tf_utils import (
TFPreTrainedModel,
booleans_processing,
get_initializer,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ModelOutput,
Expand Down Expand Up @@ -1580,6 +1586,7 @@ def freeze_feature_encoder(self):
"""
self.wav2vec2.feature_extractor.trainable = False

@unpack_inputs
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
Expand Down Expand Up @@ -1702,6 +1709,8 @@ def call(
loss = tf.reduce_sum(loss)
if self.config.ctc_loss_reduction == "mean":
loss = tf.reduce_mean(loss)

loss = tf.reshape(loss, (1,))
else:
loss = None

Expand Down

0 comments on commit e3139ad

Please sign in to comment.