
Commit

fix xlmr bug (#240)
brandenchan authored Feb 13, 2020
1 parent 41ef0c0 commit 5e62b71
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion farm/data_handler/input_features.py
@@ -355,7 +355,7 @@ def sample_to_features_squad(sample, tokenizer, max_seq_len, max_answers=6):
     segment_ids = encoded["token_type_ids"]

     # seq_2_start_t is the index of the first token in the second text sequence (e.g. passage)
-    if tokenizer.__class__.__name__ == "RobertaTokenizer":
+    if tokenizer.__class__.__name__ in ["RobertaTokenizer", "XLMRobertaTokenizer"]:
         seq_2_start_t = get_roberta_seq_2_start(input_ids)
     else:
         seq_2_start_t = segment_ids.index(1)
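
For context on the branch above: RoBERTa-style tokenizers (including XLM-RoBERTa) return token_type_ids that are all zeros, so segment_ids.index(1) cannot locate the start of the passage and a separator-based helper is needed instead. Below is a minimal sketch of how a helper in the spirit of get_roberta_seq_2_start could work, assuming the pair is encoded as "<s> question </s></s> passage </s>" and that "</s>" has token id 2; it is illustrative only, not the repository's actual implementation.

    def seq_2_start_from_input_ids(input_ids, sep_token_id=2):
        # first </s> closes the question
        first_sep = input_ids.index(sep_token_id)
        # second </s> sits directly in front of the passage
        second_sep = input_ids.index(sep_token_id, first_sep + 1)
        # the passage starts right after that second separator
        return second_sep + 1

    # toy example: <s>=0, </s>=2, question tokens 10-12, passage tokens 20-22
    example_ids = [0, 10, 11, 12, 2, 2, 20, 21, 22, 2]
    assert seq_2_start_from_input_ids(example_ids) == 6
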
4 changes: 2 additions & 2 deletions farm/modeling/language_model.py
@@ -110,9 +110,9 @@ def load(cls, pretrained_model_name_or_path, n_added_tokens=0, **kwargs):
     # it's transformers format (either from model hub or local)
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
     if "xlm" in pretrained_model_name_or_path and "roberta" in pretrained_model_name_or_path:
-        language_model = cls.subclasses["XLMRoberta"].load(pretrained_model_name_or_path, **kwargs)
         # TODO: for some reason, the pretrained XLMRoberta has different vocab size in the tokenizer compared to the model this is a hack to resolve that
         n_added_tokens = 3
+        language_model = cls.subclasses["XLMRoberta"].load(pretrained_model_name_or_path, **kwargs)
     elif 'roberta' in pretrained_model_name_or_path:
         language_model = cls.subclasses["Roberta"].load(pretrained_model_name_or_path, **kwargs)
     elif 'albert' in pretrained_model_name_or_path:
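
Aside on the branch order shown above: every XLM-RoBERTa checkpoint name also contains the substring "roberta", so the combined "xlm" + "roberta" check has to run before the plain "roberta" branch, otherwise XLM-R models would be loaded as vanilla RoBERTa. A small standalone sketch of that dispatch, using the class keys from the diff (the final fallback is an assumption, not shown in this excerpt):

    def pick_language_model_class(name: str) -> str:
        name = name.lower()
        if "xlm" in name and "roberta" in name:
            return "XLMRoberta"
        elif "roberta" in name:
            return "Roberta"
        elif "albert" in name:
            return "Albert"
        return "Bert"  # assumed fallback; the real chain continues beyond this excerpt

    assert pick_language_model_class("xlm-roberta-large") == "XLMRoberta"
    assert pick_language_model_class("roberta-base") == "Roberta"
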
@@ -139,7 +139,7 @@ def load(cls, pretrained_model_name_or_path, n_added_tokens=0, **kwargs):
         model_emb_size = language_model.model.resize_token_embeddings(new_num_tokens=None).num_embeddings
         vocab_size = model_emb_size + n_added_tokens
         logger.info(
-            f"Resizing embedding layer of LM from {model_emb_size} to {vocab_size} to cope for custom vocab.")
+            f"Resizing embedding layer of LM from {model_emb_size} to {vocab_size} to cope with custom vocab.")
         language_model.model.resize_token_embeddings(vocab_size)
         # verify
         model_emb_size = language_model.model.resize_token_embeddings(new_num_tokens=None).num_embeddings
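
The hunk above is where n_added_tokens from the XLM-R branch takes effect: the language model's embedding matrix is grown so every id the tokenizer can emit maps to a row. A hedged, standalone sketch of the same pattern against the Hugging Face Transformers API directly; the checkpoint name and custom tokens are placeholders, not taken from the commit.

    from transformers import AutoModel, AutoTokenizer

    model_name = "roberta-base"  # placeholder checkpoint for illustration
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    # Adding tokens enlarges the tokenizer but not the model's embedding matrix.
    n_added_tokens = tokenizer.add_tokens(["[CUSTOM1]", "[CUSTOM2]"])

    # Passing new_num_tokens=None returns the current embedding layer unchanged,
    # which is how the diff reads off the model's current embedding size.
    model_emb_size = model.resize_token_embeddings(new_num_tokens=None).num_embeddings

    # Grow the embedding layer so the newly added token ids are in range.
    new_size = model_emb_size + n_added_tokens
    model.resize_token_embeddings(new_size)
    assert model.resize_token_embeddings(new_num_tokens=None).num_embeddings == new_size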
