[T5] Fix init in TF and Flax for pretraining (huggingface#17294)
* fix init

* Apply suggestions from code review

* fix

* finish

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

2 people authored and elusenji committed Jun 12, 2022
1 parent 6b9d938 commit 286f9b1
Showing 2 changed files with 11 additions and 4 deletions.
src/transformers/models/t5/modeling_t5.py (2 additions, 0 deletions)
@@ -768,6 +768,8 @@ def _init_weights(self, module):
             # Mesh TensorFlow embeddings initialization
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
             module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
+            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
+                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
         elif isinstance(module, T5DenseReluDense):
             # Mesh TensorFlow FF initialization
             # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
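
Note on the hunk above: when tie_word_embeddings is disabled, lm_head owns a separate weight matrix that _init_weights previously left at the nn.Linear default; the new branch draws it from the same Mesh TensorFlow scheme as the shared embedding. A minimal sketch of the effect (the config values and the final check are illustrative, not part of the commit):

# Illustrative check, assuming a transformers version that includes this fix.
from transformers import T5Config, T5ForConditionalGeneration

# Untied embeddings: lm_head gets its own weight matrix.
config = T5Config(tie_word_embeddings=False, initializer_factor=1.0)
model = T5ForConditionalGeneration(config)

# After the fix, lm_head.weight is sampled from N(0, factor * 1.0), so its
# empirical std should sit near 1.0 rather than at nn.Linear's default scale.
print(float(model.lm_head.weight.data.std()))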
src/transformers/models/t5/modeling_tf_t5.py (9 additions, 4 deletions)
@@ -1112,7 +1112,9 @@ def _shift_right(self, input_ids):
 class TFT5Model(TFT5PreTrainedModel):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
-        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
+        self.shared = TFSharedEmbeddings(
+            config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
+        )
 
         # retrieve correct absolute scope for embed token wrapper
         with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
@@ -1259,8 +1261,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.model_dim = config.d_model
-
-        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
+        self.shared = TFSharedEmbeddings(
+            config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
+        )
 
         # retrieve correct absolute scope for embed token wrapper
         with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
@@ -1600,7 +1603,9 @@ def _reorder_cache(self, past, beam_idx):
 class TFT5EncoderModel(TFT5PreTrainedModel):
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
-        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
+        self.shared = TFSharedEmbeddings(
+            config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
+        )
 
         # retrieve correct absolute scope for embed token wrapper
         with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
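
Note on the three TF hunks: TFSharedEmbeddings falls back to hidden_size ** -0.5 as its initializer stddev when initializer_range is not given, so the shared T5 embedding was initialized far more narrowly than intended; passing config.initializer_factor aligns it with the PyTorch init above. A standalone sketch of the difference (not from the commit; assumes the transformers TF utilities of this era):

# Standalone sketch of the stddev difference; vocab/d_model values match T5-small.
import tensorflow as tf
from transformers.modeling_tf_utils import TFSharedEmbeddings

old = TFSharedEmbeddings(32128, 512, name="old_shared")  # falls back to 512 ** -0.5 ~= 0.044
new = TFSharedEmbeddings(32128, 512, name="new_shared", initializer_range=1.0)

for layer in (old, new):
    layer.build(None)  # force weight creation
    # The new layer's weight std should be roughly the configured factor,
    # the old one roughly 0.04.
    print(layer.name, float(tf.math.reduce_std(layer.weight)))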
