Infer model type from config #600

Merged · 4 commits · Oct 27, 2020
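
This PR replaces name-based guessing of the LanguageModel and Tokenizer classes with inference from the model config: `AutoConfig.from_pretrained()` exposes a `model_type` field that identifies the architecture regardless of how a model repository is named, and the old string matching is kept only as a fallback. A minimal sketch of the underlying idea, not code from this PR (the model name is only an illustration):

from transformers import AutoConfig

# Every transformers model ships a config whose `model_type` field
# names the architecture, independent of the repository name.
config = AutoConfig.from_pretrained("deepset/roberta-base-squad2")
print(config.model_type)  # "roberta"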
farm/modeling/language_model.py (73 changes: 57 additions & 16 deletions)
@@ -42,7 +42,7 @@
 from transformers.modeling_distilbert import DistilBertModel, DistilBertConfig
 from transformers.modeling_electra import ElectraModel, ElectraConfig
 from transformers.modeling_camembert import CamembertModel, CamembertConfig
-from transformers.modeling_auto import AutoModel
+from transformers.modeling_auto import AutoModel, AutoConfig
 from transformers.modeling_utils import SequenceSummary
 from transformers.tokenization_bert import load_vocab
 import transformers
@@ -116,7 +116,7 @@ def load(cls, pretrained_model_name_or_path, n_added_tokens=0, language_model_cl
 
         See all supported model variations here: https://huggingface.co/models
 
-        The appropriate language model class is inferred automatically from `pretrained_model_name_or_path`
+        The appropriate language model class is inferred automatically from the model config
         or can be manually supplied via `language_model_class`.
 
         :param pretrained_model_name_or_path: The path of the saved pretrained model or its name.
@@ -162,30 +162,70 @@ def load(cls, pretrained_model_name_or_path, n_added_tokens=0, language_model_cl
 
         return language_model
 
-    @classmethod
-    def get_language_model_class(cls, model_name_or_path):
+    @staticmethod
+    def get_language_model_class(model_name_or_path):
         # it's transformers format (either from model hub or local)
         model_name_or_path = str(model_name_or_path)
-        if "xlm" in model_name_or_path and "roberta" in model_name_or_path:
-            language_model_class = 'XLMRoberta'
-        elif 'roberta' in model_name_or_path:
-            language_model_class = 'Roberta'
-        elif 'codebert' in model_name_or_path.lower():
+
+        config = AutoConfig.from_pretrained(model_name_or_path)
+        model_type = config.model_type
+        if model_type == "xlm-roberta":
+            language_model_class = "XLMRoberta"
+        elif model_type == "roberta":
+            if "mlm" in model_name_or_path.lower():
+                raise NotImplementedError("MLM part of codebert is currently not supported in FARM")
+            language_model_class = "Roberta"
+        elif model_type == "camembert":
+            language_model_class = "Camembert"
+        elif model_type == "albert":
+            language_model_class = "Albert"
+        elif model_type == "distilbert":
+            language_model_class = "DistilBert"
+        elif model_type == "bert":
+            language_model_class = "Bert"
+        elif model_type == "xlnet":
+            language_model_class = "XLNet"
+        elif model_type == "electra":
+            language_model_class = "Electra"
+        elif model_type == "dpr":
+            if config.architectures[0] == "DPRQuestionEncoder":
+                language_model_class = "DPRQuestionEncoder"
+            elif config.architectures[0] == "DPRContextEncoder":
+                language_model_class = "DPRContextEncoder"
+            elif config.architectures[0] == "DPRReader":
+                raise NotImplementedError("DPRReader models are currently not supported.")
+        else:
+            # Fall back to inferring the type from the model name
+            logger.warning("Could not infer LanguageModel class from config. Trying to infer "
+                           "LanguageModel class from model name.")
+            language_model_class = LanguageModel._infer_language_model_class_from_string(model_name_or_path)
+
+        return language_model_class
+
+    @staticmethod
+    def _infer_language_model_class_from_string(model_name_or_path):
+        # If inferring the language model class from config doesn't succeed,
+        # fall back to inferring it from the model name.
+        if "xlm" in model_name_or_path.lower() and "roberta" in model_name_or_path.lower():
+            language_model_class = "XLMRoberta"
+        elif "roberta" in model_name_or_path.lower():
+            language_model_class = "Roberta"
+        elif "codebert" in model_name_or_path.lower():
             if "mlm" in model_name_or_path.lower():
                 raise NotImplementedError("MLM part of codebert is currently not supported in FARM")
             else:
-                language_model_class = 'Roberta'
-        elif 'camembert' in model_name_or_path or 'umberto' in model_name_or_path:
+                language_model_class = "Roberta"
+        elif "camembert" in model_name_or_path.lower() or "umberto" in model_name_or_path.lower():
             language_model_class = "Camembert"
-        elif 'albert' in model_name_or_path:
+        elif "albert" in model_name_or_path.lower():
             language_model_class = 'Albert'
-        elif 'distilbert' in model_name_or_path:
+        elif "distilbert" in model_name_or_path.lower():
             language_model_class = 'DistilBert'
-        elif 'bert' in model_name_or_path:
+        elif "bert" in model_name_or_path.lower():
             language_model_class = 'Bert'
-        elif 'xlnet' in model_name_or_path:
+        elif "xlnet" in model_name_or_path.lower():
             language_model_class = 'XLNet'
-        elif 'electra' in model_name_or_path:
+        elif "electra" in model_name_or_path.lower():
             language_model_class = 'Electra'
         elif "word2vec" in model_name_or_path.lower() or "glove" in model_name_or_path.lower():
             language_model_class = 'WordEmbedding_LM'
@@ -197,6 +237,7 @@ def get_language_model_class(cls, model_name_or_path):
             language_model_class = "DPRContextEncoder"
         else:
             language_model_class = None
+
         return language_model_class
 
     def get_output_dims(self):
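With this change, `LanguageModel.get_language_model_class` also resolves models whose names do not encode their architecture. A usage sketch, assuming a FARM installation with this PR merged ("my-org/qa-model" is hypothetical):

from farm.modeling.language_model import LanguageModel

# Resolved via config.model_type rather than string matching, so a
# repository name that never mentions the architecture still works.
print(LanguageModel.get_language_model_class("my-org/qa-model"))

# Well-known names behave as before:
assert LanguageModel.get_language_model_class("bert-base-uncased") == "Bert"

# DPR configs share model_type "dpr"; the concrete class comes from
# config.architectures, e.g. DPRQuestionEncoder vs. DPRContextEncoder.
print(LanguageModel.get_language_model_class(
    "facebook/dpr-question_encoder-single-nq-base"))  # "DPRQuestionEncoder"
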
farm/modeling/tokenization.py (139 changes: 99 additions & 40 deletions)
@@ -32,6 +32,7 @@
 from transformers.tokenization_xlm_roberta import XLMRobertaTokenizer
 from transformers.tokenization_xlnet import XLNetTokenizer
 from transformers.tokenization_camembert import CamembertTokenizer
+from transformers.modeling_auto import AutoConfig
 from transformers import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
 from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
 
@@ -53,7 +54,7 @@ class Tokenizer:
     def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=False, **kwargs):
         """
         Enables loading of different Tokenizer classes with a uniform interface. Either infer the class from
-        `pretrained_model_name_or_path` or define it manually via `tokenizer_class`.
+        the model config or define it manually via `tokenizer_class`.
 
         :param pretrained_model_name_or_path: The path of the saved pretrained model or its name (e.g. `bert-base-uncased`)
         :type pretrained_model_name_or_path: str
@@ -66,47 +67,12 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=Fals
         :param kwargs:
         :return: Tokenizer
         """
-
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-        # guess tokenizer type from name
-
         if tokenizer_class is None:
-            if "albert" in pretrained_model_name_or_path.lower():
-                tokenizer_class = "AlbertTokenizer"
-            elif "xlm-roberta" in pretrained_model_name_or_path.lower():
-                tokenizer_class = "XLMRobertaTokenizer"
-            elif "roberta" in pretrained_model_name_or_path.lower():
-                tokenizer_class = "RobertaTokenizer"
-            elif 'codebert' in pretrained_model_name_or_path.lower():
-                if "mlm" in pretrained_model_name_or_path.lower():
-                    raise NotImplementedError("MLM part of codebert is currently not supported in FARM")
-                else:
-                    tokenizer_class = "RobertaTokenizer"
-            elif "camembert" in pretrained_model_name_or_path.lower() or "umberto" in pretrained_model_name_or_path:
-                tokenizer_class = "CamembertTokenizer"
-            elif "distilbert" in pretrained_model_name_or_path.lower():
-                tokenizer_class = "DistilBertTokenizer"
-            elif "bert" in pretrained_model_name_or_path.lower():
-                tokenizer_class = "BertTokenizer"
-            elif "xlnet" in pretrained_model_name_or_path.lower():
-                tokenizer_class = "XLNetTokenizer"
-            elif "electra" in pretrained_model_name_or_path.lower():
-                tokenizer_class = "ElectraTokenizer"
-            elif "word2vec" in pretrained_model_name_or_path.lower() or \
-                    "glove" in pretrained_model_name_or_path.lower() or \
-                    "fasttext" in pretrained_model_name_or_path.lower():
-                tokenizer_class = "EmbeddingTokenizer"
-            elif "minilm" in pretrained_model_name_or_path.lower():
-                tokenizer_class = "BertTokenizer"
-            elif "dpr-question_encoder" in pretrained_model_name_or_path.lower():
-                tokenizer_class = "DPRQuestionEncoderTokenizer"
-            elif "dpr-ctx_encoder" in pretrained_model_name_or_path.lower():
-                tokenizer_class = "DPRContextEncoderTokenizer"
-            else:
-                raise ValueError(f"Could not infer tokenizer_class from name '{pretrained_model_name_or_path}'. Set "
-                                 f"arg `tokenizer_class` in Tokenizer.load() to one of: AlbertTokenizer, "
-                                 f"XLMRobertaTokenizer, RobertaTokenizer, DistilBertTokenizer, BertTokenizer, or "
-                                 f"XLNetTokenizer.")
-            logger.info(f"Loading tokenizer of type '{tokenizer_class}'")
+            tokenizer_class = cls._infer_tokenizer_class(pretrained_model_name_or_path)
+
+        logger.info(f"Loading tokenizer of type '{tokenizer_class}'")
         # return appropriate tokenizer object
         ret = None
         if tokenizer_class == "AlbertTokenizer":
@@ -175,6 +141,99 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=Fals
         else:
             return ret
 
+    @staticmethod
+    def _infer_tokenizer_class(pretrained_model_name_or_path):
+        # Infer Tokenizer from model type in config
+        try:
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        except OSError:
+            # FARM model (no 'config.json' file)
+            try:
+                config = AutoConfig.from_pretrained(pretrained_model_name_or_path + "/language_model_config.json")
+            except Exception:
+                logger.warning("No config file found. Trying to infer Tokenizer type from model name")
+                tokenizer_class = Tokenizer._infer_tokenizer_class_from_string(pretrained_model_name_or_path)
+                return tokenizer_class
+
+        model_type = config.model_type
+        if model_type == "xlm-roberta":
+            tokenizer_class = "XLMRobertaTokenizer"
+        elif model_type == "roberta":
+            if "mlm" in pretrained_model_name_or_path.lower():
+                raise NotImplementedError("MLM part of codebert is currently not supported in FARM")
+            tokenizer_class = "RobertaTokenizer"
+        elif model_type == "camembert":
+            tokenizer_class = "CamembertTokenizer"
+        elif model_type == "albert":
+            tokenizer_class = "AlbertTokenizer"
+        elif model_type == "distilbert":
+            tokenizer_class = "DistilBertTokenizer"
+        elif model_type == "bert":
+            tokenizer_class = "BertTokenizer"
+        elif model_type == "xlnet":
+            tokenizer_class = "XLNetTokenizer"
+        elif model_type == "electra":
+            tokenizer_class = "ElectraTokenizer"
+        elif model_type == "dpr":
+            if config.architectures[0] == "DPRQuestionEncoder":
+                tokenizer_class = "DPRQuestionEncoderTokenizer"
+            elif config.architectures[0] == "DPRContextEncoder":
+                tokenizer_class = "DPRContextEncoderTokenizer"
+            elif config.architectures[0] == "DPRReader":
+                raise NotImplementedError("DPRReader models are currently not supported.")
+        else:
+            # Fall back to inferring type from model name
+            logger.warning("Could not infer Tokenizer type from config. Trying to infer "
+                           "Tokenizer type from model name.")
+            tokenizer_class = Tokenizer._infer_tokenizer_class_from_string(pretrained_model_name_or_path)
+
+        return tokenizer_class
+
+    @staticmethod
+    def _infer_tokenizer_class_from_string(pretrained_model_name_or_path):
+        # If inferring the tokenizer class from config doesn't succeed,
+        # fall back to inferring it from the model name.
+        if "albert" in pretrained_model_name_or_path.lower():
+            tokenizer_class = "AlbertTokenizer"
+        elif "xlm-roberta" in pretrained_model_name_or_path.lower():
+            tokenizer_class = "XLMRobertaTokenizer"
+        elif "roberta" in pretrained_model_name_or_path.lower():
+            tokenizer_class = "RobertaTokenizer"
+        elif "codebert" in pretrained_model_name_or_path.lower():
+            if "mlm" in pretrained_model_name_or_path.lower():
+                raise NotImplementedError("MLM part of codebert is currently not supported in FARM")
+            else:
+                tokenizer_class = "RobertaTokenizer"
+        elif "camembert" in pretrained_model_name_or_path.lower() or "umberto" in pretrained_model_name_or_path.lower():
+            tokenizer_class = "CamembertTokenizer"
+        elif "distilbert" in pretrained_model_name_or_path.lower():
+            tokenizer_class = "DistilBertTokenizer"
+        elif "bert" in pretrained_model_name_or_path.lower():
+            tokenizer_class = "BertTokenizer"
+        elif "xlnet" in pretrained_model_name_or_path.lower():
+            tokenizer_class = "XLNetTokenizer"
+        elif "electra" in pretrained_model_name_or_path.lower():
+            tokenizer_class = "ElectraTokenizer"
+        elif "word2vec" in pretrained_model_name_or_path.lower() or \
+                "glove" in pretrained_model_name_or_path.lower() or \
+                "fasttext" in pretrained_model_name_or_path.lower():
+            tokenizer_class = "EmbeddingTokenizer"
+        elif "minilm" in pretrained_model_name_or_path.lower():
+            tokenizer_class = "BertTokenizer"
+        elif "dpr-question_encoder" in pretrained_model_name_or_path.lower():
+            tokenizer_class = "DPRQuestionEncoderTokenizer"
+        elif "dpr-ctx_encoder" in pretrained_model_name_or_path.lower():
+            tokenizer_class = "DPRContextEncoderTokenizer"
+        else:
+            raise ValueError(f"Could not infer tokenizer_class from model config or "
+                             f"name '{pretrained_model_name_or_path}'. Set arg `tokenizer_class` "
+                             f"in Tokenizer.load() to one of: AlbertTokenizer, XLMRobertaTokenizer, "
+                             f"RobertaTokenizer, DistilBertTokenizer, BertTokenizer, XLNetTokenizer, "
+                             f"CamembertTokenizer, ElectraTokenizer, DPRQuestionEncoderTokenizer, "
+                             f"DPRContextEncoderTokenizer.")
+
+        return tokenizer_class
 
 
 class EmbeddingTokenizer(PreTrainedTokenizer):
     """Constructs an EmbeddingTokenizer.
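
The tokenizer side mirrors the same two-step strategy: `_infer_tokenizer_class` first asks the config (also handling FARM's own layout, where the config lives in `language_model_config.json`), then falls back to name matching. A usage sketch under the same assumption that this PR is merged:

from farm.modeling.tokenization import Tokenizer

# No `tokenizer_class` argument needed: the class is inferred from
# config.model_type ("bert" -> "BertTokenizer").
tokenizer = Tokenizer.load(pretrained_model_name_or_path="bert-base-uncased")

# Pinning the class explicitly still works and skips inference:
tokenizer = Tokenizer.load("bert-base-uncased", tokenizer_class="BertTokenizer")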