diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py index bf229faade9fb9..7ca249eac7c4f0 100644 --- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py @@ -25,7 +25,13 @@ from ...activations_tf import get_tf_activation from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput -from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable +from ...modeling_tf_utils import ( + TFPreTrainedModel, + booleans_processing, + get_initializer, + keras_serializable, + unpack_inputs, +) from ...tf_utils import shape_list, stable_softmax from ...utils import ( ModelOutput, @@ -1580,6 +1586,7 @@ def freeze_feature_encoder(self): """ self.wav2vec2.feature_extractor.trainable = False + @unpack_inputs @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC) def call( @@ -1702,6 +1709,8 @@ def call( loss = tf.reduce_sum(loss) if self.config.ctc_loss_reduction == "mean": loss = tf.reduce_mean(loss) + + loss = tf.reshape(loss, (1,)) else: loss = None