From b7ecb3797680bdbd19855846ce732bd623852de2 Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Tue, 27 Oct 2020 18:34:47 +0100 Subject: [PATCH] Infer model type from config (#600) * Inference model and tokenizer type from config * Infer type from model name as fallback --- farm/modeling/language_model.py | 73 +++++++++++++---- farm/modeling/tokenization.py | 139 +++++++++++++++++++++++--------- 2 files changed, 156 insertions(+), 56 deletions(-) diff --git a/farm/modeling/language_model.py b/farm/modeling/language_model.py index 28f0a9327..f4f40eee3 100644 --- a/farm/modeling/language_model.py +++ b/farm/modeling/language_model.py @@ -42,7 +42,7 @@ from transformers.modeling_distilbert import DistilBertModel, DistilBertConfig from transformers.modeling_electra import ElectraModel, ElectraConfig from transformers.modeling_camembert import CamembertModel, CamembertConfig -from transformers.modeling_auto import AutoModel +from transformers.modeling_auto import AutoModel, AutoConfig from transformers.modeling_utils import SequenceSummary from transformers.tokenization_bert import load_vocab import transformers @@ -116,7 +116,7 @@ def load(cls, pretrained_model_name_or_path, n_added_tokens=0, language_model_cl See all supported model variations here: https://huggingface.co/models - The appropriate language model class is inferred automatically from `pretrained_model_name_or_path` + The appropriate language model class is inferred automatically from model config or can be manually supplied via `language_model_class`. :param pretrained_model_name_or_path: The path of the saved pretrained model or its name. 
@@ -162,30 +162,70 @@ def load(cls, pretrained_model_name_or_path, n_added_tokens=0, language_model_cl return language_model - @classmethod - def get_language_model_class(cls, model_name_or_path): + @staticmethod + def get_language_model_class(model_name_or_path): # it's transformers format (either from model hub or local) model_name_or_path = str(model_name_or_path) - if "xlm" in model_name_or_path and "roberta" in model_name_or_path: - language_model_class = 'XLMRoberta' - elif 'roberta' in model_name_or_path: - language_model_class = 'Roberta' - elif 'codebert' in model_name_or_path.lower(): + + config = AutoConfig.from_pretrained(model_name_or_path) + model_type = config.model_type + if model_type == "xlm-roberta": + language_model_class = "XLMRoberta" + elif model_type == "roberta": + if "mlm" in model_name_or_path.lower(): + raise NotImplementedError("MLM part of codebert is currently not supported in FARM") + language_model_class = "Roberta" + elif model_type == "camembert": + language_model_class = "Camembert" + elif model_type == "albert": + language_model_class = "Albert" + elif model_type == "distilbert": + language_model_class = "DistilBert" + elif model_type == "bert": + language_model_class = "Bert" + elif model_type == "xlnet": + language_model_class = "XLNet" + elif model_type == "electra": + language_model_class = "Electra" + elif model_type == "dpr": + if config.architectures[0] == "DPRQuestionEncoder": + language_model_class = "DPRQuestionEncoder" + elif config.architectures[0] == "DPRContextEncoder": + language_model_class = "DPRContextEncoder" + elif config.architectures[0] == "DPRReader": + raise NotImplementedError("DPRReader models are currently not supported.") + else: + # Fall back to inferring type from model name + logger.warning("Could not infer LanguageModel class from config. 
Trying to infer " + "LanguageModel class from model name.") + language_model_class = LanguageModel._infer_language_model_class_from_string(model_name_or_path) + + return language_model_class + + @staticmethod + def _infer_language_model_class_from_string(model_name_or_path): + # If inferring Language model class from config doesn't succeed, + # fall back to inferring Language model class from model name. + if "xlm" in model_name_or_path.lower() and "roberta" in model_name_or_path.lower(): + language_model_class = "XLMRoberta" + elif "roberta" in model_name_or_path.lower(): + language_model_class = "Roberta" + elif "codebert" in model_name_or_path.lower(): if "mlm" in model_name_or_path.lower(): raise NotImplementedError("MLM part of codebert is currently not supported in FARM") else: - language_model_class = 'Roberta' - elif 'camembert' in model_name_or_path or 'umberto' in model_name_or_path: + language_model_class = "Roberta" + elif "camembert" in model_name_or_path.lower() or "umberto" in model_name_or_path.lower(): language_model_class = "Camembert" - elif 'albert' in model_name_or_path: + elif "albert" in model_name_or_path.lower(): language_model_class = 'Albert' - elif 'distilbert' in model_name_or_path: + elif "distilbert" in model_name_or_path.lower(): language_model_class = 'DistilBert' - elif 'bert' in model_name_or_path: + elif "bert" in model_name_or_path.lower(): language_model_class = 'Bert' - elif 'xlnet' in model_name_or_path: + elif "xlnet" in model_name_or_path.lower(): language_model_class = 'XLNet' - elif 'electra' in model_name_or_path: + elif "electra" in model_name_or_path.lower(): language_model_class = 'Electra' elif "word2vec" in model_name_or_path.lower() or "glove" in model_name_or_path.lower(): language_model_class = 'WordEmbedding_LM' @@ -197,6 +237,7 @@ def get_language_model_class(cls, model_name_or_path): language_model_class = "DPRContextEncoder" else: language_model_class = None + return language_model_class def 
get_output_dims(self): diff --git a/farm/modeling/tokenization.py b/farm/modeling/tokenization.py index ebcf48cf6..ce1bfae1f 100644 --- a/farm/modeling/tokenization.py +++ b/farm/modeling/tokenization.py @@ -32,6 +32,7 @@ from transformers.tokenization_xlm_roberta import XLMRobertaTokenizer from transformers.tokenization_xlnet import XLNetTokenizer from transformers.tokenization_camembert import CamembertTokenizer +from transformers.modeling_auto import AutoConfig from transformers import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast @@ -53,7 +54,7 @@ class Tokenizer: def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=False, **kwargs): """ Enables loading of different Tokenizer classes with a uniform interface. Either infer the class from - `pretrained_model_name_or_path` or define it manually via `tokenizer_class`. + model config or define it manually via `tokenizer_class`. :param pretrained_model_name_or_path: The path of the saved pretrained model or its name (e.g. 
`bert-base-uncased`) :type pretrained_model_name_or_path: str @@ -66,47 +67,12 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=Fals :param kwargs: :return: Tokenizer """ - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - # guess tokenizer type from name + if tokenizer_class is None: - if "albert" in pretrained_model_name_or_path.lower(): - tokenizer_class = "AlbertTokenizer" - elif "xlm-roberta" in pretrained_model_name_or_path.lower(): - tokenizer_class = "XLMRobertaTokenizer" - elif "roberta" in pretrained_model_name_or_path.lower(): - tokenizer_class = "RobertaTokenizer" - elif 'codebert' in pretrained_model_name_or_path.lower(): - if "mlm" in pretrained_model_name_or_path.lower(): - raise NotImplementedError("MLM part of codebert is currently not supported in FARM") - else: - tokenizer_class = "RobertaTokenizer" - elif "camembert" in pretrained_model_name_or_path.lower() or "umberto" in pretrained_model_name_or_path: - tokenizer_class = "CamembertTokenizer" - elif "distilbert" in pretrained_model_name_or_path.lower(): - tokenizer_class = "DistilBertTokenizer" - elif "bert" in pretrained_model_name_or_path.lower(): - tokenizer_class = "BertTokenizer" - elif "xlnet" in pretrained_model_name_or_path.lower(): - tokenizer_class = "XLNetTokenizer" - elif "electra" in pretrained_model_name_or_path.lower(): - tokenizer_class = "ElectraTokenizer" - elif "word2vec" in pretrained_model_name_or_path.lower() or \ - "glove" in pretrained_model_name_or_path.lower() or \ - "fasttext" in pretrained_model_name_or_path.lower(): - tokenizer_class = "EmbeddingTokenizer" - elif "minilm" in pretrained_model_name_or_path.lower(): - tokenizer_class = "BertTokenizer" - elif "dpr-question_encoder" in pretrained_model_name_or_path.lower(): - tokenizer_class = "DPRQuestionEncoderTokenizer" - elif "dpr-ctx_encoder" in pretrained_model_name_or_path.lower(): - tokenizer_class = "DPRContextEncoderTokenizer" - else: - raise ValueError(f"Could 
not infer tokenizer_class from name '{pretrained_model_name_or_path}'. Set " - f"arg `tokenizer_class` in Tokenizer.load() to one of: AlbertTokenizer, " - f"XLMRobertaTokenizer, RobertaTokenizer, DistilBertTokenizer, BertTokenizer, or " - f"XLNetTokenizer.") - logger.info(f"Loading tokenizer of type '{tokenizer_class}'") + tokenizer_class = cls._infer_tokenizer_class(pretrained_model_name_or_path) + + logger.info(f"Loading tokenizer of type '{tokenizer_class}'") # return appropriate tokenizer object ret = None if tokenizer_class == "AlbertTokenizer": @@ -175,6 +141,99 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=Fals else: return ret + @staticmethod + def _infer_tokenizer_class(pretrained_model_name_or_path): + # Infer Tokenizer from model type in config + try: + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + except OSError: + # FARM model (no 'config.json' file) + try: + config = AutoConfig.from_pretrained(pretrained_model_name_or_path + "/language_model_config.json") + except Exception as e: + logger.warning("No config file found. 
Trying to infer Tokenizer type from model name") + tokenizer_class = Tokenizer._infer_tokenizer_class_from_string(pretrained_model_name_or_path) + return tokenizer_class + + model_type = config.model_type + if model_type == "xlm-roberta": + tokenizer_class = "XLMRobertaTokenizer" + elif model_type == "roberta": + if "mlm" in pretrained_model_name_or_path.lower(): + raise NotImplementedError("MLM part of codebert is currently not supported in FARM") + tokenizer_class = "RobertaTokenizer" + elif model_type == "camembert": + tokenizer_class = "CamembertTokenizer" + elif model_type == "albert": + tokenizer_class = "AlbertTokenizer" + elif model_type == "distilbert": + tokenizer_class = "DistilBertTokenizer" + elif model_type == "bert": + tokenizer_class = "BertTokenizer" + elif model_type == "xlnet": + tokenizer_class = "XLNetTokenizer" + elif model_type == "electra": + tokenizer_class = "ElectraTokenizer" + elif model_type == "dpr": + if config.architectures[0] == "DPRQuestionEncoder": + tokenizer_class = "DPRQuestionEncoderTokenizer" + elif config.architectures[0] == "DPRContextEncoder": + tokenizer_class = "DPRContextEncoderTokenizer" + elif config.architectures[0] == "DPRReader": + raise NotImplementedError("DPRReader models are currently not supported.") + else: + # Fall back to inferring type from model name + logger.warning("Could not infer Tokenizer type from config. Trying to infer " + "Tokenizer type from model name.") + tokenizer_class = Tokenizer._infer_tokenizer_class_from_string(pretrained_model_name_or_path) + + return tokenizer_class + + @staticmethod + def _infer_tokenizer_class_from_string(pretrained_model_name_or_path): + # If inferring tokenizer class from config doesn't succeed, + # fall back to inferring tokenizer class from model name. 
+ if "albert" in pretrained_model_name_or_path.lower(): + tokenizer_class = "AlbertTokenizer" + elif "xlm-roberta" in pretrained_model_name_or_path.lower(): + tokenizer_class = "XLMRobertaTokenizer" + elif "roberta" in pretrained_model_name_or_path.lower(): + tokenizer_class = "RobertaTokenizer" + elif "codebert" in pretrained_model_name_or_path.lower(): + if "mlm" in pretrained_model_name_or_path.lower(): + raise NotImplementedError("MLM part of codebert is currently not supported in FARM") + else: + tokenizer_class = "RobertaTokenizer" + elif "camembert" in pretrained_model_name_or_path.lower() or "umberto" in pretrained_model_name_or_path.lower(): + tokenizer_class = "CamembertTokenizer" + elif "distilbert" in pretrained_model_name_or_path.lower(): + tokenizer_class = "DistilBertTokenizer" + elif "bert" in pretrained_model_name_or_path.lower(): + tokenizer_class = "BertTokenizer" + elif "xlnet" in pretrained_model_name_or_path.lower(): + tokenizer_class = "XLNetTokenizer" + elif "electra" in pretrained_model_name_or_path.lower(): + tokenizer_class = "ElectraTokenizer" + elif "word2vec" in pretrained_model_name_or_path.lower() or \ + "glove" in pretrained_model_name_or_path.lower() or \ + "fasttext" in pretrained_model_name_or_path.lower(): + tokenizer_class = "EmbeddingTokenizer" + elif "minilm" in pretrained_model_name_or_path.lower(): + tokenizer_class = "BertTokenizer" + elif "dpr-question_encoder" in pretrained_model_name_or_path.lower(): + tokenizer_class = "DPRQuestionEncoderTokenizer" + elif "dpr-ctx_encoder" in pretrained_model_name_or_path.lower(): + tokenizer_class = "DPRContextEncoderTokenizer" + else: + raise ValueError(f"Could not infer tokenizer_class from model config or " + f"name '{pretrained_model_name_or_path}'. 
Set arg `tokenizer_class` " + f"in Tokenizer.load() to one of: AlbertTokenizer, XLMRobertaTokenizer, " + f"RobertaTokenizer, DistilBertTokenizer, BertTokenizer, XLNetTokenizer, " + f"CamembertTokenizer, ElectraTokenizer, DPRQuestionEncoderTokenizer," + f"DPRContextEncoderTokenizer.") + + return tokenizer_class + class EmbeddingTokenizer(PreTrainedTokenizer): """Constructs an EmbeddingTokenizer.