Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Better TF docstring types #23477

Merged
merged 9 commits into from
May 24, 2023
190 changes: 96 additions & 94 deletions src/transformers/modeling_tf_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple
Expand Down Expand Up @@ -43,8 +45,8 @@ class TFBaseModelOutput(ModelOutput):
"""

last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -96,8 +98,8 @@ class TFBaseModelOutputWithPooling(ModelOutput):

last_hidden_state: tf.Tensor = None
pooler_output: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -164,10 +166,10 @@ class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):

last_hidden_state: tf.Tensor = None
pooler_output: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None
cross_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -201,9 +203,9 @@ class TFBaseModelOutputWithPast(ModelOutput):
"""

last_hidden_state: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -234,9 +236,9 @@ class TFBaseModelOutputWithCrossAttentions(ModelOutput):
"""

last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None
cross_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -276,10 +278,10 @@ class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
"""

last_hidden_state: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None
cross_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -333,13 +335,13 @@ class TFSeq2SeqModelOutput(ModelOutput):
"""

last_hidden_state: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
decoder_hidden_states: Tuple[tf.Tensor] | None = None
decoder_attentions: Tuple[tf.Tensor] | None = None
cross_attentions: Tuple[tf.Tensor] | None = None
encoder_last_hidden_state: tf.Tensor | None = None
encoder_hidden_states: Tuple[tf.Tensor] | None = None
encoder_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand All @@ -365,10 +367,10 @@ class TFCausalLMOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -400,11 +402,11 @@ class TFCausalLMOutputWithPast(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -442,12 +444,12 @@ class TFCausalLMOutputWithCrossAttentions(ModelOutput):
`past_key_values` input) to speed up sequential decoding.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None
cross_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand All @@ -473,10 +475,10 @@ class TFMaskedLMOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -527,15 +529,15 @@ class TFSeq2SeqLMOutput(ModelOutput):
self-attention heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
decoder_hidden_states: Tuple[tf.Tensor] | None = None
decoder_attentions: Tuple[tf.Tensor] | None = None
cross_attentions: Tuple[tf.Tensor] | None = None
encoder_last_hidden_state: tf.Tensor | None = None
encoder_hidden_states: Tuple[tf.Tensor] | None = None
encoder_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand All @@ -562,10 +564,10 @@ class TFNextSentencePredictorOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand All @@ -591,10 +593,10 @@ class TFSequenceClassifierOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -642,15 +644,15 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
self-attention heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
decoder_hidden_states: Tuple[tf.Tensor] | None = None
decoder_attentions: Tuple[tf.Tensor] | None = None
cross_attentions: Tuple[tf.Tensor] | None = None
encoder_last_hidden_state: tf.Tensor | None = None
encoder_hidden_states: Tuple[tf.Tensor] | None = None
encoder_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -684,10 +686,10 @@ class TFSemanticSegmenterOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -716,9 +718,9 @@ class TFSemanticSegmenterOutputWithNoAttention(ModelOutput):
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None


@dataclass
Expand All @@ -742,10 +744,10 @@ class TFImageClassifierOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -773,10 +775,10 @@ class TFMultipleChoiceModelOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand All @@ -802,10 +804,10 @@ class TFTokenClassifierOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -833,11 +835,11 @@ class TFQuestionAnsweringModelOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
start_logits: tf.Tensor = None
end_logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -884,15 +886,15 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
self-attention heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
start_logits: tf.Tensor = None
end_logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
decoder_hidden_states: Tuple[tf.Tensor] | None = None
decoder_attentions: Tuple[tf.Tensor] | None = None
encoder_last_hidden_state: tf.Tensor | None = None
encoder_hidden_states: Tuple[tf.Tensor] | None = None
encoder_attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand Down Expand Up @@ -924,11 +926,11 @@ class TFSequenceClassifierOutputWithPast(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
past_key_values: List[tf.Tensor] | None = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None


@dataclass
Expand All @@ -947,7 +949,7 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput):
feature maps) of the model at the output of each stage.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None

Expand All @@ -974,10 +976,10 @@ class TFMaskedImageModelingOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
loss: tf.Tensor | None = None
reconstruction: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
hidden_states: Tuple[tf.Tensor] | None = None
attentions: Tuple[tf.Tensor] | None = None

@property
def logits(self):
Expand Down
Loading