From da4fce0b832784d03811255a96daafdb516e9dec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bogdan=20Kosti=C4=87?= Date: Tue, 2 Jun 2020 18:24:05 +0200 Subject: [PATCH 1/5] Add support for Camembert-like models --- farm/modeling/adaptive_model.py | 1 - farm/modeling/language_model.py | 60 +++++++++++++++++++++++++++++++-- farm/modeling/tokenization.py | 5 +++ 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/farm/modeling/adaptive_model.py b/farm/modeling/adaptive_model.py index c3c9d8561..aada3c358 100644 --- a/farm/modeling/adaptive_model.py +++ b/farm/modeling/adaptive_model.py @@ -256,7 +256,6 @@ def __init__( self.lm_output_types = ( [lm_output_types] if isinstance(lm_output_types, str) else lm_output_types ) - assert len(self.lm_output_types) == len(self.prediction_heads) self.log_params() # default loss aggregation function is a simple sum (without using any of the optional params) if not loss_aggregation_fn: diff --git a/farm/modeling/language_model.py b/farm/modeling/language_model.py index 32de9ee5e..c68744f76 100644 --- a/farm/modeling/language_model.py +++ b/farm/modeling/language_model.py @@ -41,13 +41,13 @@ from transformers.modeling_xlm_roberta import XLMRobertaModel, XLMRobertaConfig from transformers.modeling_distilbert import DistilBertModel, DistilBertConfig from transformers.modeling_electra import ElectraModel, ElectraConfig +from transformers.modeling_camembert import CamembertModel, CamembertConfig from transformers.modeling_utils import SequenceSummary from transformers.tokenization_bert import load_vocab from farm.modeling import wordembedding_utils from farm.modeling.wordembedding_utils import s3e_pooling - # These are the names of the attributes in various model configs which refer to the number of dimensions # in the output vectors OUTPUT_DIM_NAMES = ["dim", "hidden_size", "d_model"] @@ -134,6 +134,8 @@ def load(cls, pretrained_model_name_or_path, n_added_tokens=0, language_model_cl language_model_class = 'XLMRoberta' elif 'roberta' in pretrained_model_name_or_path: language_model_class = 'Roberta' + elif 'camembert' in pretrained_model_name_or_path or 'umberto' in pretrained_model_name_or_path: + language_model_class = "Camembert" elif 'albert' in pretrained_model_name_or_path: language_model_class = 'Albert' elif 'distilbert' in pretrained_model_name_or_path: @@ -247,6 +249,16 @@ def _infer_language_from_name(cls, name): f"\t Found multiple matches: {matches}\n" "\t Please init the language model by manually supplying the 'language' as a parameter.\n" ) + elif "camembert" in name: + language = "french" + logger.info( + f"Automatically detected language from language model name: {language}" + ) + elif "umberto" in name: + language = "italian" + logger.info( + f"Automatically detected language from language model name: {language}" + ) else: language = matches[0] logger.info( @@ -879,7 +891,7 @@ def forward( ) # XLNet also only returns the sequence_output (one vec per token) # We need to manually aggregate that to get a pooled output (one vec per seq) - #TODO verify that this is really doing correct pooling + # TODO verify that this is really doing correct pooling pooled_output = self.pooler(output_tuple[0]) if self.model.output_hidden_states == True: @@ -1282,3 +1294,47 @@ def enable_hidden_states_output(self): def disable_hidden_states_output(self): self.model.config.output_hidden_states = False + +class Camembert(Roberta): + """ + A Camembert model that wraps the HuggingFace's implementation + (https://github.com/huggingface/transformers) to fit the LanguageModel class. + """ + def __init__(self): + super(Camembert, self).__init__() + self.model = None + self.name = "camembert" + + @classmethod + def load(cls, pretrained_model_name_or_path, language=None, **kwargs): + """ + Load a language model either by supplying + + * the name of a remote model on s3 ("camembert-base" ...) + * or a local path of a model trained via transformers ("some_dir/huggingface_model") + * or a local path of a model trained via FARM ("some_dir/farm_model") + + :param pretrained_model_name_or_path: name or path of a model + :param language: (Optional) Name of language the model was trained for (e.g. "german"). + If not supplied, FARM will try to infer it from the model name. + :return: Language Model + + """ + camembert = cls() + if "farm_lm_name" in kwargs: + camembert.name = kwargs["farm_lm_name"] + else: + camembert.name = pretrained_model_name_or_path + # We need to differentiate between loading model using FARM format and Pytorch-Transformers format + farm_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json" + if os.path.exists(farm_lm_config): + # FARM style + config = CamembertConfig.from_pretrained(farm_lm_config) + farm_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin" + camembert.model = CamembertModel.from_pretrained(farm_lm_model, config=config, **kwargs) + camembert.language = camembert.model.config.language + else: + # Huggingface transformer Style + camembert.model = CamembertModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs) + camembert.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path) + return camembert diff --git a/farm/modeling/tokenization.py b/farm/modeling/tokenization.py index ad53d175c..b87ab571e 100644 --- a/farm/modeling/tokenization.py +++ b/farm/modeling/tokenization.py @@ -31,6 +31,7 @@ from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_xlm_roberta import XLMRobertaTokenizer from transformers.tokenization_xlnet import XLNetTokenizer +from transformers.tokenization_camembert import CamembertTokenizer from farm.modeling.wordembedding_utils import load_from_cache, EMBEDDING_VOCAB_FILES_MAP, run_split_on_punc @@ -69,6 +70,8 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, **kwargs): tokenizer_class = "XLMRobertaTokenizer" elif "roberta" in pretrained_model_name_or_path.lower(): tokenizer_class = "RobertaTokenizer" + elif "camembert" in pretrained_model_name_or_path.lower() or "umberto" in pretrained_model_name_or_path: + tokenizer_class = "CamembertTokenizer" elif "distilbert" in pretrained_model_name_or_path.lower(): tokenizer_class = "DistilBertTokenizer" elif "bert" in pretrained_model_name_or_path.lower(): @@ -104,6 +107,8 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, **kwargs): ret = ElectraTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) elif tokenizer_class == "EmbeddingTokenizer": ret = EmbeddingTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) + elif tokenizer_class == "CamembertTokenizer": + ret = CamembertTokenizer._from_pretrained(pretrained_model_name_or_path, **kwargs) if ret is None: raise Exception("Unable to load tokenizer") else: From 6f278d24fa0e37bb580c25abc5c153c800e9f36a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bogdan=20Kosti=C4=87?= Date: Mon, 15 Jun 2020 11:23:33 +0200 Subject: [PATCH 2/5] Update readme + langauge detection for camembert and umberto --- farm/modeling/language_model.py | 22 +++++++++++----------- readme.rst | 4 +++- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/farm/modeling/language_model.py b/farm/modeling/language_model.py index c68744f76..58ea174af 100644 --- a/farm/modeling/language_model.py +++ b/farm/modeling/language_model.py @@ -236,7 +236,17 @@ def _infer_language_from_name(cls, name): "multilingual", ) matches = [lang for lang in known_languages if lang in name] - if len(matches) == 0: + if "camembert" in name: + language = "french" + logger.info( + f"Automatically detected language from language model name: {language}" + ) + elif "umberto" in name: + language = "italian" + logger.info( + f"Automatically detected language from language model name: {language}" + ) + elif len(matches) == 0: language = "english" logger.warning( "Could not automatically detect from language model name what language it is. \n" @@ -249,16 +259,6 @@ def _infer_language_from_name(cls, name): f"\t Found multiple matches: {matches}\n" "\t Please init the language model by manually supplying the 'language' as a parameter.\n" ) - elif "camembert" in name: - language = "french" - logger.info( - f"Automatically detected language from language model name: {language}" - ) - elif "umberto" in name: - language = "italian" - logger.info( - f"Automatically detected language from language model name: {language}" - ) else: language = matches[0] logger.info( diff --git a/readme.rst b/readme.rst index 89b56e984..fe303f68b 100644 --- a/readme.rst +++ b/readme.rst @@ -66,7 +66,7 @@ Core features - Simple **deployment** and **visualization** to showcase your model +------------------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+ -| Task | BERT | RoBERTa | XLNet | ALBERT | DistilBERT | XLMRoBERTa | +| Task | BERT | RoBERTa* | XLNet | ALBERT | DistilBERT | XLMRoBERTa | +==============================+===================+===================+===================+===================+===================+===================+ | Text classification | x | x | x | x | x | x | +------------------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+ @@ -89,6 +89,8 @@ Core features | Passage Ranking | x | x | x | x | x | x | +------------------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+ +* including CamemBERT and UmBERTo + ****NEW**** Interested in doing Question Answering at scale? Checkout `Haystack `_! Resources From c8f19571e939315f6b4056a3720809903e122c22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bogdan=20Kosti=C4=87?= Date: Mon, 15 Jun 2020 11:27:20 +0200 Subject: [PATCH 3/5] Fix format in readme --- readme.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/readme.rst b/readme.rst index fe303f68b..122085379 100644 --- a/readme.rst +++ b/readme.rst @@ -89,7 +89,7 @@ Core features | Passage Ranking | x | x | x | x | x | x | +------------------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+ -* including CamemBERT and UmBERTo +\* including CamemBERT and UmBERTo ****NEW**** Interested in doing Question Answering at scale? Checkout `Haystack `_! From 7df7b5f31b8398891ee4cc470b1b06d3ddd52ca4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bogdan=20Kosti=C4=87?= Date: Mon, 15 Jun 2020 14:21:29 +0200 Subject: [PATCH 4/5] Implement method to get sequence 2 start for camembert models --- farm/data_handler/input_features.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/farm/data_handler/input_features.py b/farm/data_handler/input_features.py index a24335c0d..ca0850a77 100644 --- a/farm/data_handler/input_features.py +++ b/farm/data_handler/input_features.py @@ -358,6 +358,8 @@ def sample_to_features_qa(sample, tokenizer, max_seq_len, answer_type_list=None, # seq_2_start_t is the index of the first token in the second text sequence (e.g. passage) if tokenizer.__class__.__name__ in ["RobertaTokenizer", "XLMRobertaTokenizer"]: seq_2_start_t = get_roberta_seq_2_start(input_ids) + elif tokenizer.__class__.__name__ == "CamembertTokenizer": + seq_2_start_t = get_camembert_seq_2_start(input_ids) else: seq_2_start_t = segment_ids.index(1) @@ -514,6 +516,17 @@ def get_roberta_seq_2_start(input_ids): second_backslash_s = input_ids.index(2, first_backslash_s + 1) return second_backslash_s + 1 +def get_camembert_seq_2_start(input_ids): + # CamembertTokenizer.encode_plus returns only zeros in token_type_ids (same as RobertaTokenizer). + # This is another way to find the start of the second sequence (following get_roberta_seq_2_start) + # Camembert input sequences have the following + # format: P1 P2 + # has index 5 and has index 6. To find the beginning of the second sequence, this function first finds + # the index of the second + first_backslash_s = input_ids.index(6) + second_backslash_s = input_ids.index(6, first_backslash_s + 1) + return second_backslash_s + 1 + def sample_to_features_squadOLD( sample, tokenizer, max_seq_len, doc_stride, max_query_length, tasks, ): From 8e6c5aadc99da49d8966f9e390c57ac3e7727eb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bogdan=20Kosti=C4=87?= Date: Mon, 15 Jun 2020 15:40:33 +0200 Subject: [PATCH 5/5] Log warning instead of throwing error when matching more than one language --- farm/modeling/language_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/farm/modeling/language_model.py b/farm/modeling/language_model.py index 58ea174af..f76cda425 100644 --- a/farm/modeling/language_model.py +++ b/farm/modeling/language_model.py @@ -254,11 +254,13 @@ def _infer_language_from_name(cls, name): "\t If not: Init the language model by supplying the 'language' param." ) elif len(matches) > 1: - raise ValueError( + logger.warning( "Could not automatically detect from language model name what language it is.\n" f"\t Found multiple matches: {matches}\n" "\t Please init the language model by manually supplying the 'language' as a parameter.\n" + f"\t Using {matches[0]} as language parameter for now.\n" ) + language = matches[0] else: language = matches[0] logger.info(