Proper build() methods for TF #27794

Merged Dec 14, 2023 · 83 commits

Commits
778e8d4
Add a convenience method for building in your own name scope
Rocketknight1 Dec 1, 2023
078d6d1
Second attempt at auto layer building
Rocketknight1 Dec 1, 2023
560f1fb
Revert "Second attempt at auto layer building"
Rocketknight1 Dec 4, 2023
f226cb8
Attempt #3
Rocketknight1 Dec 4, 2023
b85e78e
Revert "Attempt #3"
Rocketknight1 Dec 4, 2023
fcb197c
Add missing attributes that we're going to need later
Rocketknight1 Dec 4, 2023
80ab7a2
Add some attributes we're going to need later
Rocketknight1 Dec 4, 2023
23a01d6
A fourth attempt! Feel the power flow through you!
Rocketknight1 Dec 4, 2023
3512ad8
Revert "A fourth attempt! Feel the power flow through you!"
Rocketknight1 Dec 4, 2023
2fc9c52
Add more values we'll need later
Rocketknight1 Dec 4, 2023
1598499
TF refactor that we'll need later
Rocketknight1 Dec 4, 2023
08f6920
Revert "TF refactor that we'll need later"
Rocketknight1 Dec 4, 2023
fda421c
Revert "Revert "TF refactor that we'll need later""
Rocketknight1 Dec 4, 2023
6d3d21b
make fixup
Rocketknight1 Dec 4, 2023
4e05867
Attempt five!
Rocketknight1 Dec 4, 2023
74fa85e
Revert "Attempt five!"
Rocketknight1 Dec 4, 2023
ff32320
Attempt six - this time don't add empty methods
Rocketknight1 Dec 4, 2023
eaa2657
Revert "Attempt six - this time don't add empty methods"
Rocketknight1 Dec 4, 2023
fb341eb
Attempt seven - better base model class detection!
Rocketknight1 Dec 4, 2023
454b4e9
Revert "Attempt seven - better base model class detection!"
Rocketknight1 Dec 4, 2023
3acfb9b
Another attribute we'll need later
Rocketknight1 Dec 4, 2023
6107b79
Try again with the missing attribute!
Rocketknight1 Dec 4, 2023
0d1f682
Revert "Try again with the missing attribute!"
Rocketknight1 Dec 5, 2023
e3cc28e
This is the attempt that will pierce the heavens!
Rocketknight1 Dec 5, 2023
01de371
Revert "This is the attempt that will pierce the heavens!"
Rocketknight1 Dec 5, 2023
2de4263
Attempt seven - snag list is steadily decreasing
Rocketknight1 Dec 5, 2023
f4216fb
Revert "Attempt seven - snag list is steadily decreasing"
Rocketknight1 Dec 5, 2023
d58b13d
Attempt eight - will an empty snag list do it?
Rocketknight1 Dec 5, 2023
0f1fdb3
Revert "Attempt eight - will an empty snag list do it?"
Rocketknight1 Dec 5, 2023
3086783
Fixes to Hubert issues that cause problems later
Rocketknight1 Dec 5, 2023
f632a59
Trying again with Conv1D/SeparableConv fixes
Rocketknight1 Dec 5, 2023
6a33350
Revert "Trying again with Conv1D/SeparableConv fixes"
Rocketknight1 Dec 5, 2023
f3c5145
Apply the build shape fixes to Wav2Vec2 as well
Rocketknight1 Dec 5, 2023
a4cd293
One more attempt!
Rocketknight1 Dec 5, 2023
8e49790
Revert "One more attempt!"
Rocketknight1 Dec 5, 2023
a169fd2
Another attempt!
Rocketknight1 Dec 5, 2023
c5175d9
Revert "Another attempt!"
Rocketknight1 Dec 5, 2023
7cb3679
Let's see how many failures we get without the internal build method
Rocketknight1 Dec 5, 2023
3da8e67
Fix OpenAI
Rocketknight1 Dec 5, 2023
04264f4
Fix MobileBERT
Rocketknight1 Dec 5, 2023
cbcd61e
(Mostly) fix GroupVIT
Rocketknight1 Dec 5, 2023
a814073
Fix BLIP
Rocketknight1 Dec 5, 2023
355de20
One more BLIP fix
Rocketknight1 Dec 5, 2023
2ea0f87
One more BLIP fix!
Rocketknight1 Dec 5, 2023
c771a90
Fix Regnet
Rocketknight1 Dec 5, 2023
6ec68e0
Finally fully fix GroupViT
Rocketknight1 Dec 5, 2023
a421239
Fix Data2Vec and add the new AdaptivePool
Rocketknight1 Dec 6, 2023
d619185
Fix Segformer
Rocketknight1 Dec 6, 2023
790222c
Fix Albert
Rocketknight1 Dec 6, 2023
7e6f7ce
Fix Deberta/DebertaV2
Rocketknight1 Dec 6, 2023
a858347
Fix XLM
Rocketknight1 Dec 6, 2023
d73fbbe
Actually fix XLM
Rocketknight1 Dec 6, 2023
46d3e66
Fix Flaubert
Rocketknight1 Dec 6, 2023
79b5c42
Fix lxmert
Rocketknight1 Dec 6, 2023
6591d93
Fix Resnet
Rocketknight1 Dec 6, 2023
f58a28f
Fix ConvBERT
Rocketknight1 Dec 6, 2023
47fd55e
Fix ESM
Rocketknight1 Dec 6, 2023
53f8c3e
Fix Convnext / ConvnextV2
Rocketknight1 Dec 6, 2023
2b3badc
Fix SAM
Rocketknight1 Dec 6, 2023
c6d784d
Fix Efficientformer
Rocketknight1 Dec 6, 2023
8aca4d6
Fix LayoutLMv3
Rocketknight1 Dec 6, 2023
843b021
Fix speech_to_text
Rocketknight1 Dec 6, 2023
eba42e5
Fix mpnet and mobilevit
Rocketknight1 Dec 6, 2023
73a71c2
Fix Swin
Rocketknight1 Dec 6, 2023
8dc4808
Fix CTRL
Rocketknight1 Dec 6, 2023
54ce3a9
Fix CVT
Rocketknight1 Dec 6, 2023
a560a33
Fix DPR
Rocketknight1 Dec 6, 2023
d59aaed
Fix Wav2Vec2
Rocketknight1 Dec 6, 2023
08a72b6
Fix T5
Rocketknight1 Dec 6, 2023
1855891
Fix Hubert
Rocketknight1 Dec 6, 2023
6c25b54
Fix GPT2
Rocketknight1 Dec 6, 2023
3a2c834
Fix Whisper
Rocketknight1 Dec 6, 2023
b7553a2
Fix DeiT
Rocketknight1 Dec 6, 2023
5e6524e
Fix the encoder-decoder / dual-encoder classes
Rocketknight1 Dec 6, 2023
25550fb
make fix-copies
Rocketknight1 Dec 6, 2023
a130cbc
build in name scope
Rocketknight1 Dec 7, 2023
7aebcb4
Merge branch 'main' into proper_tf_weight_building
Rocketknight1 Dec 7, 2023
ef281aa
Fix summarization test
Rocketknight1 Dec 7, 2023
5adb89b
Fix tied weight names for BART + Blenderbot
Rocketknight1 Dec 7, 2023
deff359
Fix tied weight name building
Rocketknight1 Dec 7, 2023
b13ed89
Fix to TFESM weight building
Rocketknight1 Dec 13, 2023
568f9aa
Update TF SAM
Rocketknight1 Dec 13, 2023
5e31b77
Expand all the shapes out into Big Boy Shapes
Rocketknight1 Dec 13, 2023
43 changes: 25 additions & 18 deletions src/transformers/modeling_tf_utils.py
@@ -35,7 +35,6 @@
 from huggingface_hub import Repository, list_repo_files
 from keras import backend as K
 from packaging.version import parse
-from tensorflow.python.util.keras_deps import get_call_context_function
 
 from . import DataCollatorWithPadding, DefaultDataCollator
 from .activations_tf import get_tf_activation
@@ -1122,6 +1121,10 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
         )
         return dummies
 
+    def build_in_name_scope(self):
+        with tf.name_scope(self.name):
+            self.build(input_shape=None)
+
     @property
     def framework(self) -> str:
         """
@@ -1130,15 +1133,7 @@ def framework(self) -> str:
         return "tf"
 
     def build(self, input_shape=None):
-        call_context = get_call_context_function()
-        if self.built or call_context().in_call:
-            self.built = True
-        else:
-            self.built = True
-            # Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
-            # Setting it in build() allows users to override the shape when loading a non-pretrained model from config
-            self._set_save_spec(self.input_signature)
-            self(self.dummy_inputs, training=False)
+        pass  # This is just here to make sure we don't call the superclass build()
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
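
With the override reduced to a no-op, build() no longer triggers a dummy forward pass; each model class now defines its own explicit build(), and code that needs eager weight creation goes through the new helper. A hedged usage sketch (TFBertModel is just one example of a config-initialized model):

from transformers import BertConfig, TFBertModel

model = TFBertModel(BertConfig())  # instantiating creates no variables yet
model.build_in_name_scope()        # runs the model's explicit build(), weights get scoped names
assert len(model.weights) > 0
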
@@ -1869,7 +1864,7 @@ def set_input_embeddings(self, value):
             main_layer.set_input_embeddings(value)
         except AttributeError:
             logger.info("Building the model")
-            self.build()
+            self.build_in_name_scope()
             main_layer.set_input_embeddings(value)
 
     def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
@@ -1886,7 +1881,7 @@ def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]:
                 return lm_head.get_output_embeddings()
             except AttributeError:
                 logger.info("Building the model")
-                self.build()
+                self.build_in_name_scope()
 
             return lm_head().get_output_embeddings()
 
@@ -1906,7 +1901,7 @@ def set_output_embeddings(self, value):
                 lm_head.set_output_embeddings(value)
             except AttributeError:
                 logger.info("Building the model")
-                self.build()
+                self.build_in_name_scope()
                 lm_head.set_output_embeddings(value)
 
     def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]:
@@ -1944,7 +1939,7 @@ def get_bias(self) -> Union[None, Dict[str, tf.Variable]]:
             try:
                 return lm_head.get_bias()
             except AttributeError:
-                self.build()
+                self.build_in_name_scope()
 
                 return lm_head.get_bias()
         return None
@@ -1962,7 +1957,7 @@ def set_bias(self, value):
             try:
                 lm_head.set_bias(value)
             except AttributeError:
-                self.build()
+                self.build_in_name_scope()
                 lm_head.set_bias(value)
 
     def get_lm_head(self) -> tf.keras.layers.Layer:
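
Every accessor above follows the same lazy-build pattern: attempt the lookup, and if the weights don't exist yet, build once and retry. A generic sketch of that pattern (the helper name is illustrative, not part of the PR):

def get_bias_lazily(model):
    # Mirrors the try/except flow used by get_bias() and its siblings.
    lm_head = model.get_lm_head()
    try:
        return lm_head.get_bias()
    except AttributeError:
        # An unbuilt layer raises AttributeError for weight attributes;
        # building the model once makes the retry succeed.
        model.build_in_name_scope()
        return lm_head.get_bias()
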
@@ -2049,7 +2044,7 @@ def _get_word_embedding_weight(model, embedding_layer):
         # The reason why the attributes don't exist might be
         # because the model is not built, so retry getting
         # the argument after building the model
-        model.build()
+        model.build_in_name_scope()
 
         embeds = getattr(embedding_layer, "weight", None)
         if embeds is not None:
@@ -2914,9 +2909,9 @@ def from_pretrained(
             # we might need to extend the variable scope for composite models
             if load_weight_prefix is not None:
                 with tf.compat.v1.variable_scope(load_weight_prefix):
-                    model.build()  # build the network with dummy inputs
+                    model.build_in_name_scope()  # build the network with dummy inputs
             else:
-                model.build()  # build the network with dummy inputs
+                model.build_in_name_scope()  # build the network with dummy inputs
 
         if safetensors_from_pt:
             from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
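
Building before loading matters because the TF loading code matches checkpoint entries to variables by name. A much-simplified sketch of the flow around this hunk (checkpoint_items is an illustrative placeholder, not the real loading API):

model.build_in_name_scope()                      # every variable now exists with its final name
name_to_var = {v.name: v for v in model.weights}
for name, value in checkpoint_items:             # e.g. tensors read from an H5/safetensors file
    if name in name_to_var:
        name_to_var[name].assign(value)
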
@@ -3215,6 +3210,9 @@ def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
         self.initializer_range = initializer_range
 
     def build(self, input_shape):
+        if self.built:
+            return
+        self.built = True
         self.weight = self.add_weight(
             "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
         )
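
The `if self.built: return` guard makes build() idempotent, which matters now that a layer can be built either by a real forward pass or by an explicit build_in_name_scope() call. A minimal sketch of the guard in isolation (the layer is illustrative):

import tensorflow as tf

class GuardedLayer(tf.keras.layers.Layer):
    def build(self, input_shape):
        if self.built:
            return  # a second call is a no-op, so no duplicate weights
        self.built = True
        self.w = self.add_weight("w", shape=(3, 3))

layer = GuardedLayer()
layer.build(None)
layer.build(None)          # safe: the guard short-circuits
print(len(layer.weights))  # 1
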
@@ -3398,6 +3396,7 @@ def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs):
         self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
         if self.has_last_dropout:
             self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
+        self.hidden_size = config.hidden_size
 
     def call(self, inputs, cls_index=None, training=False):
         if not isinstance(inputs, (dict, tuple, list)):
@@ -3450,6 +3449,14 @@ def call(self, inputs, cls_index=None, training=False):
 
         return output
 
+    def build(self, input_shape):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "summary", None) is not None:
+            with tf.name_scope("summary"):
+                self.summary.build(self.hidden_size)
+
 
 def get_initializer(initializer_range: float = 0.02) -> tf.keras.initializers.TruncatedNormal:
     """

Collaborator comment on the `with tf.name_scope("summary"):` line: What's the significance of "summary" here?
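
`summary` here is the optional Dense projection that TFSequenceSummary creates with name="summary" when the config enables a projection, so the explicit scope keeps its weights named `.../summary/kernel` as checkpoints expect. Passing the integer hidden size works because Dense.build accepts anything convertible to a TensorShape; a small sketch (the sizes are illustrative):

import tensorflow as tf

dense = tf.keras.layers.Dense(2, name="summary")
dense.build(768)           # equivalent to dense.build((768,))
print(dense.kernel.shape)  # (768, 2)
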