diff --git a/src/transformers/models/convnext/modeling_tf_convnext.py b/src/transformers/models/convnext/modeling_tf_convnext.py index 405aeff6e0bdd5..0be2d291923812 100644 --- a/src/transformers/models/convnext/modeling_tf_convnext.py +++ b/src/transformers/models/convnext/modeling_tf_convnext.py @@ -330,7 +330,8 @@ def call( hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]]) if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] + hidden_states = hidden_states if output_hidden_states else () + return (last_hidden_state, pooled_output) + hidden_states return TFBaseModelOutputWithPooling( last_hidden_state=last_hidden_state,