Reduce memory usage in TF building (#24046)
* Make the default dummies (2, 2) instead of (3, 3)

* Fix for Funnel

* Actually fix Funnel
Rocketknight1 authored Jun 6, 2023
1 parent 072188d commit 7203ea6
Showing 2 changed files with 13 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/transformers/modeling_tf_utils.py
@@ -1116,16 +1116,16 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
         dummies = {}
         sig = self._prune_signature(self.input_signature)
         for key, spec in sig.items():
-            # 3 is the most correct arbitrary size. I will not be taking questions
-            dummies[key] = tf.ones(shape=[dim if dim is not None else 3 for dim in spec.shape], dtype=spec.dtype)
+            # 2 is the most correct arbitrary size. I will not be taking questions
+            dummies[key] = tf.ones(shape=[dim if dim is not None else 2 for dim in spec.shape], dtype=spec.dtype)
             if key == "token_type_ids":
                 # Some models have token_type_ids but with a vocab_size of 1
                 dummies[key] = tf.zeros_like(dummies[key])
         if self.config.add_cross_attention and "encoder_hidden_states" in inspect.signature(self.call).parameters:
             if "encoder_hidden_states" not in dummies:
                 if self.main_input_name == "input_ids":
                     dummies["encoder_hidden_states"] = tf.ones(
-                        shape=(3, 3, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
+                        shape=(2, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
                     )
                 else:
                     raise NotImplementedError(
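The net effect of this hunk: any dimension left unspecified (None) in a model's input signature is now filled with 2 instead of 3 when the dummy build pass runs. A minimal sketch of what the builder produces, assuming the usual (None, None) input_ids spec; spec_shape is an illustrative stand-in for spec.shape:

import tensorflow as tf

# Usual input_ids spec: batch and sequence dims are both unspecified (None).
spec_shape = (None, None)

# Each None dim is filled with the arbitrary size 2, so the forward pass that
# builds the model weights now runs on a (2, 2) tensor instead of (3, 3) --
# fewer tokens, lower peak memory during model instantiation.
dummy = tf.ones(shape=[dim if dim is not None else 2 for dim in spec_shape], dtype=tf.int32)
print(dummy.shape)  # (2, 2)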
10 changes: 10 additions & 0 deletions src/transformers/models/funnel/modeling_tf_funnel.py
@@ -242,6 +242,7 @@ def get_position_embeds(self, seq_len, training=False):
             # rel_pos = tf.broadcast_to(rel_pos, (rel_pos.shape[0], self.d_model))
             rel_pos = tf.cast(rel_pos, dtype=zero_offset.dtype)
             rel_pos = rel_pos + zero_offset
+            tf.debugging.assert_less(rel_pos, tf.shape(pos_embed)[0])
             position_embeds_no_pooling = tf.gather(pos_embed, rel_pos, axis=0)

             position_embeds_list.append([position_embeds_no_pooling, position_embeds_pooling])
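Why the new assertion is useful: tf.gather does not fail uniformly on out-of-range indices (on GPU it silently writes zeros rather than raising), so an undersized pos_embed table could corrupt outputs with no error. A small sketch with made-up sizes:

import tensorflow as tf

pos_embed = tf.random.normal((10, 4))  # 10 position embeddings (illustrative)
rel_pos = tf.constant([3, 7, 12])      # 12 is out of range for a table of 10

# On GPU, tf.gather returns zeros for out-of-range indices instead of raising;
# the assertion turns that silent corruption into an immediate, loud failure.
tf.debugging.assert_less(rel_pos, tf.shape(pos_embed)[0])  # raises InvalidArgumentError here
position_embeds = tf.gather(pos_embed, rel_pos, axis=0)    # never reached in this sketch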
@@ -974,6 +975,11 @@ class TFFunnelPreTrainedModel(TFPreTrainedModel):
     config_class = FunnelConfig
     base_model_prefix = "funnel"

+    @property
+    def dummy_inputs(self):
+        # Funnel misbehaves with very small inputs, so we override and make them a bit bigger
+        return {"input_ids": tf.ones((3, 3), dtype=tf.int32)}
+

 @dataclass
 class TFFunnelForPreTrainingOutput(ModelOutput):
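Funnel pools its sequence between blocks, so building on the new (2, 2) default can shrink the sequence to a length its position-embedding math cannot handle; the override keeps Funnel's build inputs at the old (3, 3) size. A usage sketch, assuming a real checkpoint such as funnel-transformer/small:

import tensorflow as tf
from transformers import TFFunnelModel

# from_pretrained triggers a build pass on dummy_inputs; thanks to the
# override, Funnel builds with (3, 3) input_ids while other TF models
# now use the smaller (2, 2) default.
model = TFFunnelModel.from_pretrained("funnel-transformer/small")
print(model.dummy_inputs["input_ids"].shape)  # (3, 3)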
@@ -1424,6 +1430,10 @@ def __init__(self, config: FunnelConfig, *inputs, **kwargs) -> None:
         self.funnel = TFFunnelBaseLayer(config, name="funnel")
         self.classifier = TFFunnelClassificationHead(config, 1, name="classifier")

+    @property
+    def dummy_inputs(self):
+        return {"input_ids": tf.ones((3, 3, 4), dtype=tf.int32)}
+
     @unpack_inputs
     @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
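The multiple-choice head takes 3D input_ids of shape (batch_size, num_choices, sequence_length), per the docstring format string above, so its dummy must be 3D as well; (3, 3, 4) means 3 examples with 3 choices of 4 tokens each. A sketch of the shapes involved; the flatten step mirrors how multiple-choice heads in the library typically run the encoder, and the sizes are illustrative:

import tensorflow as tf

batch_size, num_choices, seq_len = 3, 3, 4  # matches the override above

# Multiple-choice heads typically flatten (batch, choices, seq) into
# (batch * choices, seq) before running the encoder, so a 2D default
# dummy would never exercise that reshape.
input_ids = tf.ones((batch_size, num_choices, seq_len), dtype=tf.int32)
flat_input_ids = tf.reshape(input_ids, (-1, seq_len))
print(flat_input_ids.shape)  # (9, 4)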
