Skip to content

Commit

Permalink
Infer model type from config (#600)
Browse files Browse the repository at this point in the history
* Infer model and tokenizer type from config

* Infer type from model name as fallback
  • Loading branch information
bogdankostic authored Oct 27, 2020
1 parent b33e17c commit b7ecb37
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 56 deletions.
73 changes: 57 additions & 16 deletions farm/modeling/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from transformers.modeling_distilbert import DistilBertModel, DistilBertConfig
from transformers.modeling_electra import ElectraModel, ElectraConfig
from transformers.modeling_camembert import CamembertModel, CamembertConfig
from transformers.modeling_auto import AutoModel
from transformers.modeling_auto import AutoModel, AutoConfig
from transformers.modeling_utils import SequenceSummary
from transformers.tokenization_bert import load_vocab
import transformers
Expand Down Expand Up @@ -116,7 +116,7 @@ def load(cls, pretrained_model_name_or_path, n_added_tokens=0, language_model_cl
See all supported model variations here: https://huggingface.co/models
The appropriate language model class is inferred automatically from `pretrained_model_name_or_path`
The appropriate language model class is inferred automatically from model config
or can be manually supplied via `language_model_class`.
:param pretrained_model_name_or_path: The path of the saved pretrained model or its name.
Expand Down Expand Up @@ -162,30 +162,70 @@ def load(cls, pretrained_model_name_or_path, n_added_tokens=0, language_model_cl

return language_model

@staticmethod
def get_language_model_class(model_name_or_path):
    """
    Infer the FARM LanguageModel class name for a transformers-format model
    (model hub name or local directory) from the `model_type` field of its config.

    :param model_name_or_path: name of a model on the Hugging Face model hub
                               or path to a local transformers model.
    :return: name of the matching LanguageModel subclass (e.g. "Bert",
             "XLMRoberta"), or None if it cannot be inferred.
    :raises NotImplementedError: for unsupported variants (codebert MLM, DPRReader).
    """
    # it's transformers format (either from model hub or local)
    model_name_or_path = str(model_name_or_path)

    config = AutoConfig.from_pretrained(model_name_or_path)
    model_type = config.model_type
    # model types with a one-to-one mapping to a LanguageModel class
    simple_mappings = {
        "xlm-roberta": "XLMRoberta",
        "camembert": "Camembert",
        "albert": "Albert",
        "distilbert": "DistilBert",
        "bert": "Bert",
        "xlnet": "XLNet",
        "electra": "Electra",
    }
    if model_type in simple_mappings:
        language_model_class = simple_mappings[model_type]
    elif model_type == "roberta":
        if "mlm" in model_name_or_path.lower():
            raise NotImplementedError("MLM part of codebert is currently not supported in FARM")
        language_model_class = "Roberta"
    elif model_type == "dpr":
        # `architectures` may be absent from the config — guard before indexing
        architecture = config.architectures[0] if config.architectures else None
        if architecture == "DPRQuestionEncoder":
            language_model_class = "DPRQuestionEncoder"
        elif architecture == "DPRContextEncoder":
            language_model_class = "DPRContextEncoder"
        elif architecture == "DPRReader":
            # fixed typo: original read `config.archictectures`, which raised
            # AttributeError instead of the intended NotImplementedError
            raise NotImplementedError("DPRReader models are currently not supported.")
        else:
            # unknown DPR architecture: fall back to name-based inference instead
            # of leaving `language_model_class` unassigned (UnboundLocalError in
            # the original)
            language_model_class = LanguageModel._infer_language_model_class_from_string(model_name_or_path)
    else:
        # Fall back to inferring type from model name
        logger.warning("Could not infer LanguageModel class from config. Trying to infer "
                       "LanguageModel class from model name.")
        language_model_class = LanguageModel._infer_language_model_class_from_string(model_name_or_path)

    return language_model_class

@staticmethod
def _infer_language_model_class_from_string(model_name_or_path):
# If inferring Language model class from config doesn't succeed,
# fall back to inferring Language model class from model name.
if "xlm" in model_name_or_path.lower() and "roberta" in model_name_or_path.lower():
language_model_class = "XLMRoberta"
elif "roberta" in model_name_or_path.lower():
language_model_class = "Roberta"
elif "codebert" in model_name_or_path.lower():
if "mlm" in model_name_or_path.lower():
raise NotImplementedError("MLM part of codebert is currently not supported in FARM")
else:
language_model_class = 'Roberta'
elif 'camembert' in model_name_or_path or 'umberto' in model_name_or_path:
language_model_class = "Roberta"
elif "camembert" in model_name_or_path.lower() or "umberto" in model_name_or_path.lower():
language_model_class = "Camembert"
elif 'albert' in model_name_or_path:
elif "albert" in model_name_or_path.lower():
language_model_class = 'Albert'
elif 'distilbert' in model_name_or_path:
elif "distilbert" in model_name_or_path.lower():
language_model_class = 'DistilBert'
elif 'bert' in model_name_or_path:
elif "bert" in model_name_or_path.lower():
language_model_class = 'Bert'
elif 'xlnet' in model_name_or_path:
elif "xlnet" in model_name_or_path.lower():
language_model_class = 'XLNet'
elif 'electra' in model_name_or_path:
elif "electra" in model_name_or_path.lower():
language_model_class = 'Electra'
elif "word2vec" in model_name_or_path.lower() or "glove" in model_name_or_path.lower():
language_model_class = 'WordEmbedding_LM'
Expand All @@ -197,6 +237,7 @@ def get_language_model_class(cls, model_name_or_path):
language_model_class = "DPRContextEncoder"
else:
language_model_class = None

return language_model_class

def get_output_dims(self):
Expand Down
139 changes: 99 additions & 40 deletions farm/modeling/tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from transformers.tokenization_xlm_roberta import XLMRobertaTokenizer
from transformers.tokenization_xlnet import XLNetTokenizer
from transformers.tokenization_camembert import CamembertTokenizer
from transformers.modeling_auto import AutoConfig
from transformers import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast

Expand All @@ -53,7 +54,7 @@ class Tokenizer:
def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=False, **kwargs):
"""
Enables loading of different Tokenizer classes with a uniform interface. Either infer the class from
`pretrained_model_name_or_path` or define it manually via `tokenizer_class`.
model config or define it manually via `tokenizer_class`.
:param pretrained_model_name_or_path: The path of the saved pretrained model or its name (e.g. `bert-base-uncased`)
:type pretrained_model_name_or_path: str
Expand All @@ -66,47 +67,12 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=Fals
:param kwargs:
:return: Tokenizer
"""

pretrained_model_name_or_path = str(pretrained_model_name_or_path)
# guess tokenizer type from name

if tokenizer_class is None:
if "albert" in pretrained_model_name_or_path.lower():
tokenizer_class = "AlbertTokenizer"
elif "xlm-roberta" in pretrained_model_name_or_path.lower():
tokenizer_class = "XLMRobertaTokenizer"
elif "roberta" in pretrained_model_name_or_path.lower():
tokenizer_class = "RobertaTokenizer"
elif 'codebert' in pretrained_model_name_or_path.lower():
if "mlm" in pretrained_model_name_or_path.lower():
raise NotImplementedError("MLM part of codebert is currently not supported in FARM")
else:
tokenizer_class = "RobertaTokenizer"
elif "camembert" in pretrained_model_name_or_path.lower() or "umberto" in pretrained_model_name_or_path:
tokenizer_class = "CamembertTokenizer"
elif "distilbert" in pretrained_model_name_or_path.lower():
tokenizer_class = "DistilBertTokenizer"
elif "bert" in pretrained_model_name_or_path.lower():
tokenizer_class = "BertTokenizer"
elif "xlnet" in pretrained_model_name_or_path.lower():
tokenizer_class = "XLNetTokenizer"
elif "electra" in pretrained_model_name_or_path.lower():
tokenizer_class = "ElectraTokenizer"
elif "word2vec" in pretrained_model_name_or_path.lower() or \
"glove" in pretrained_model_name_or_path.lower() or \
"fasttext" in pretrained_model_name_or_path.lower():
tokenizer_class = "EmbeddingTokenizer"
elif "minilm" in pretrained_model_name_or_path.lower():
tokenizer_class = "BertTokenizer"
elif "dpr-question_encoder" in pretrained_model_name_or_path.lower():
tokenizer_class = "DPRQuestionEncoderTokenizer"
elif "dpr-ctx_encoder" in pretrained_model_name_or_path.lower():
tokenizer_class = "DPRContextEncoderTokenizer"
else:
raise ValueError(f"Could not infer tokenizer_class from name '{pretrained_model_name_or_path}'. Set "
f"arg `tokenizer_class` in Tokenizer.load() to one of: AlbertTokenizer, "
f"XLMRobertaTokenizer, RobertaTokenizer, DistilBertTokenizer, BertTokenizer, or "
f"XLNetTokenizer.")
logger.info(f"Loading tokenizer of type '{tokenizer_class}'")
tokenizer_class = cls._infer_tokenizer_class(pretrained_model_name_or_path)

logger.info(f"Loading tokenizer of type '{tokenizer_class}'")
# return appropriate tokenizer object
ret = None
if tokenizer_class == "AlbertTokenizer":
Expand Down Expand Up @@ -175,6 +141,99 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=Fals
else:
return ret

@staticmethod
def _infer_tokenizer_class(pretrained_model_name_or_path):
    """
    Infer the Tokenizer class from the `model_type` field of the model config.

    :param pretrained_model_name_or_path: model hub name or local path of the model.
    :return: name of the matching Tokenizer class (e.g. "BertTokenizer").
    :raises NotImplementedError: for unsupported variants (codebert MLM, DPRReader).
    """
    # Infer Tokenizer from model type in config
    try:
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    except OSError:
        # FARM model (no 'config.json' file)
        try:
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path + "/language_model_config.json")
        except Exception:
            # no usable config at all — last resort is the model name
            logger.warning("No config file found. Trying to infer Tokenizer type from model name")
            return Tokenizer._infer_tokenizer_class_from_string(pretrained_model_name_or_path)

    model_type = config.model_type
    # model types with a one-to-one mapping to a Tokenizer class
    simple_mappings = {
        "xlm-roberta": "XLMRobertaTokenizer",
        "camembert": "CamembertTokenizer",
        "albert": "AlbertTokenizer",
        "distilbert": "DistilBertTokenizer",
        "bert": "BertTokenizer",
        "xlnet": "XLNetTokenizer",
        "electra": "ElectraTokenizer",
    }
    if model_type in simple_mappings:
        tokenizer_class = simple_mappings[model_type]
    elif model_type == "roberta":
        if "mlm" in pretrained_model_name_or_path.lower():
            raise NotImplementedError("MLM part of codebert is currently not supported in FARM")
        tokenizer_class = "RobertaTokenizer"
    elif model_type == "dpr":
        # `architectures` may be absent from the config — guard before indexing
        architecture = config.architectures[0] if config.architectures else None
        if architecture == "DPRQuestionEncoder":
            tokenizer_class = "DPRQuestionEncoderTokenizer"
        elif architecture == "DPRContextEncoder":
            tokenizer_class = "DPRContextEncoderTokenizer"
        elif architecture == "DPRReader":
            # fixed typo: original read `config.archictectures`, which raised
            # AttributeError instead of the intended NotImplementedError
            raise NotImplementedError("DPRReader models are currently not supported.")
        else:
            # unknown DPR architecture: fall back to name-based inference instead
            # of leaving `tokenizer_class` unassigned (UnboundLocalError in the
            # original)
            tokenizer_class = Tokenizer._infer_tokenizer_class_from_string(pretrained_model_name_or_path)
    else:
        # Fall back to inferring type from model name
        logger.warning("Could not infer Tokenizer type from config. Trying to infer "
                       "Tokenizer type from model name.")
        tokenizer_class = Tokenizer._infer_tokenizer_class_from_string(pretrained_model_name_or_path)

    return tokenizer_class

@staticmethod
def _infer_tokenizer_class_from_string(pretrained_model_name_or_path):
# If inferring tokenizer class from config doesn't succeed,
# fall back to inferring tokenizer class from model name.
if "albert" in pretrained_model_name_or_path.lower():
tokenizer_class = "AlbertTokenizer"
elif "xlm-roberta" in pretrained_model_name_or_path.lower():
tokenizer_class = "XLMRobertaTokenizer"
elif "roberta" in pretrained_model_name_or_path.lower():
tokenizer_class = "RobertaTokenizer"
elif "codebert" in pretrained_model_name_or_path.lower():
if "mlm" in pretrained_model_name_or_path.lower():
raise NotImplementedError("MLM part of codebert is currently not supported in FARM")
else:
tokenizer_class = "RobertaTokenizer"
elif "camembert" in pretrained_model_name_or_path.lower() or "umberto" in pretrained_model_name_or_path.lower():
tokenizer_class = "CamembertTokenizer"
elif "distilbert" in pretrained_model_name_or_path.lower():
tokenizer_class = "DistilBertTokenizer"
elif "bert" in pretrained_model_name_or_path.lower():
tokenizer_class = "BertTokenizer"
elif "xlnet" in pretrained_model_name_or_path.lower():
tokenizer_class = "XLNetTokenizer"
elif "electra" in pretrained_model_name_or_path.lower():
tokenizer_class = "ElectraTokenizer"
elif "word2vec" in pretrained_model_name_or_path.lower() or \
"glove" in pretrained_model_name_or_path.lower() or \
"fasttext" in pretrained_model_name_or_path.lower():
tokenizer_class = "EmbeddingTokenizer"
elif "minilm" in pretrained_model_name_or_path.lower():
tokenizer_class = "BertTokenizer"
elif "dpr-question_encoder" in pretrained_model_name_or_path.lower():
tokenizer_class = "DPRQuestionEncoderTokenizer"
elif "dpr-ctx_encoder" in pretrained_model_name_or_path.lower():
tokenizer_class = "DPRContextEncoderTokenizer"
else:
raise ValueError(f"Could not infer tokenizer_class from model config or "
f"name '{pretrained_model_name_or_path}'. Set arg `tokenizer_class` "
f"in Tokenizer.load() to one of: AlbertTokenizer, XLMRobertaTokenizer, "
f"RobertaTokenizer, DistilBertTokenizer, BertTokenizer, XLNetTokenizer, "
f"CamembertTokenizer, ElectraTokenizer, DPRQuestionEncoderTokenizer,"
f"DPRContextEncoderTokenizer.")

return tokenizer_class


class EmbeddingTokenizer(PreTrainedTokenizer):
"""Constructs an EmbeddingTokenizer.
Expand Down

0 comments on commit b7ecb37

Please sign in to comment.