
Commit 7336896

Fix importing unofficial TF models with extra optimizer weights

monologg authored and LysandreJik committed Feb 7, 2020
1 parent d7dabfe commit 7336896
Showing 4 changed files with 19 additions and 4 deletions.
8 changes: 7 additions & 1 deletion src/transformers/modeling_albert.py
@@ -117,7 +117,13 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
         name = name.split("/")
 
         # Ignore optimizer state variables added by the LAMB/ADAM optimizers.
-        if "adam_m" in name or "adam_v" in name or "global_step" in name:
+        if (
+            "adam_m" in name
+            or "adam_v" in name
+            or "AdamWeightDecayOptimizer" in name
+            or "AdamWeightDecayOptimizer_1" in name
+            or "global_step" in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue

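A minimal standalone sketch of the new skip test (the variable names below are illustrative, not taken from a real checkpoint): after name.split("/"), each string check is a list-membership test against the path components, so optimizer slots are skipped while model weights pass through.

    # Sketch of the skip test; checkpoint variable names here are hypothetical.
    OPTIMIZER_PARTS = {"adam_m", "adam_v", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"}

    def should_skip(full_name):
        # Equivalent to the chained `"..." in name` checks after name.split("/"):
        # skip if any path component names an optimizer variable.
        return any(part in OPTIMIZER_PARTS for part in full_name.split("/"))

    assert should_skip("encoder/layer_0/kernel/adam_m")            # Adam slot
    assert should_skip("encoder/AdamWeightDecayOptimizer_1/beta")  # unofficial checkpoint slot
    assert not should_skip("encoder/layer_0/kernel")               # actual model weight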
5 changes: 4 additions & 1 deletion src/transformers/modeling_bert.py
@@ -86,7 +86,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
         name = name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
         # which are not required for using the pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue
         pointer = model
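To see which entries the loader would now skip in a given checkpoint, a short inspection sketch (the checkpoint path is hypothetical); tf.train.list_variables returns (name, shape) pairs, and checkpoints saved mid-training by the original BERT code carry AdamWeightDecayOptimizer slots alongside the weights:

    import tensorflow as tf

    # Hypothetical path to an unofficial TF checkpoint.
    CKPT = "/path/to/unofficial_bert/model.ckpt"
    SKIP = ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]

    for var_name, shape in tf.train.list_variables(CKPT):
        action = "SKIP" if any(n in SKIP for n in var_name.split("/")) else "LOAD"
        print(action, var_name, shape)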
5 changes: 4 additions & 1 deletion src/transformers/modeling_t5.py
@@ -79,7 +79,10 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
         name = txt_name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
         # which are not required for using the pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             tf_weights.pop(txt_name, None)
             continue
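One T5-specific detail: this loader also pops skipped names out of tf_weights. A minimal sketch of that bookkeeping (dict contents hypothetical): whatever remains in the dict after conversion is logged as not copied, so skipped optimizer slots have to be removed up front.

    # Hypothetical tf_weights dict as built by the T5 loader (name -> array).
    tf_weights = {
        "encoder/block_000/layer_000/SelfAttention/q": "...",
        "global_step": "...",  # optimizer bookkeeping, should be skipped
    }

    SKIP = ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
    for txt_name in list(tf_weights):
        if any(n in SKIP for n in txt_name.split("/")):
            tf_weights.pop(txt_name, None)  # drop so it is not reported as un-copied

    print(list(tf_weights))  # only real weights remain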
5 changes: 4 additions & 1 deletion templates/adding_a_new_model/modeling_xxx.py
@@ -76,7 +76,10 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
         name = name.split("/")
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
         # which are not required for using the pretrained model
-        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
+        if any(
+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
+            for n in name
+        ):
             logger.info("Skipping {}".format("/".join(name)))
             continue
         pointer = model
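With this fix in place, an unofficial TF checkpoint that still carries optimizer variables can be converted as usual. A sketch following the pattern of the repository's conversion scripts (paths are hypothetical; requires TensorFlow to be installed):

    from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

    # Hypothetical paths to an unofficial TF checkpoint and its config.
    config = BertConfig.from_json_file("/path/to/bert_config.json")
    model = BertForPreTraining(config)
    load_tf_weights_in_bert(model, config, "/path/to/model.ckpt")  # extra optimizer vars are now skipped
    model.save_pretrained("/path/to/pytorch_dump")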
