diff --git a/farm/data_handler/input_features.py b/farm/data_handler/input_features.py
index a24335c0d..ca0850a77 100644
--- a/farm/data_handler/input_features.py
+++ b/farm/data_handler/input_features.py
@@ -358,6 +358,8 @@ def sample_to_features_qa(sample, tokenizer, max_seq_len, answer_type_list=None,
# seq_2_start_t is the index of the first token in the second text sequence (e.g. passage)
if tokenizer.__class__.__name__ in ["RobertaTokenizer", "XLMRobertaTokenizer"]:
seq_2_start_t = get_roberta_seq_2_start(input_ids)
+ elif tokenizer.__class__.__name__ == "CamembertTokenizer":
+ seq_2_start_t = get_camembert_seq_2_start(input_ids)
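+ # (CamembertTokenizer, like RobertaTokenizer, returns all-zero token_type_ids,
+ # so segment_ids cannot be used to locate the second sequence here)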
else:
seq_2_start_t = segment_ids.index(1)
@@ -514,6 +516,17 @@ def get_roberta_seq_2_start(input_ids):
second_backslash_s = input_ids.index(2, first_backslash_s + 1)
return second_backslash_s + 1
+def get_camembert_seq_2_start(input_ids):
+ # CamembertTokenizer.encode_plus returns only zeros in token_type_ids (same as RobertaTokenizer).
+ # This is another way to find the start of the second sequence (following get_roberta_seq_2_start).
+ # Camembert input sequences have the following format:
+ # <s> P1 </s> </s> P2 </s>
+ # <s> has index 5 and </s> has index 6. To find the beginning of the second sequence,
+ # this function first finds the index of the second </s>.
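+ # Illustrative example with hypothetical ids for the text tokens:
+ # input_ids = [5, 101, 102, 6, 6, 201, 202, 6]
+ # the second </s> (id 6) sits at index 4, so the function returns 5,
+ # the position of the first passage token (201).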
+ first_backslash_s = input_ids.index(6)
+ second_backslash_s = input_ids.index(6, first_backslash_s + 1)
+ return second_backslash_s + 1
+
def sample_to_features_squadOLD(
sample, tokenizer, max_seq_len, doc_stride, max_query_length, tasks,
):
diff --git a/farm/modeling/adaptive_model.py b/farm/modeling/adaptive_model.py
index c3c9d8561..aada3c358 100644
--- a/farm/modeling/adaptive_model.py
+++ b/farm/modeling/adaptive_model.py
@@ -256,7 +256,6 @@ def __init__(
self.lm_output_types = (
[lm_output_types] if isinstance(lm_output_types, str) else lm_output_types
)
- assert len(self.lm_output_types) == len(self.prediction_heads)
self.log_params()
# default loss aggregation function is a simple sum (without using any of the optional params)
if not loss_aggregation_fn:
diff --git a/farm/modeling/language_model.py b/farm/modeling/language_model.py
index 32de9ee5e..f76cda425 100644
--- a/farm/modeling/language_model.py
+++ b/farm/modeling/language_model.py
@@ -41,13 +41,13 @@
from transformers.modeling_xlm_roberta import XLMRobertaModel, XLMRobertaConfig
from transformers.modeling_distilbert import DistilBertModel, DistilBertConfig
from transformers.modeling_electra import ElectraModel, ElectraConfig
+from transformers.modeling_camembert import CamembertModel, CamembertConfig
from transformers.modeling_utils import SequenceSummary
from transformers.tokenization_bert import load_vocab
from farm.modeling import wordembedding_utils
from farm.modeling.wordembedding_utils import s3e_pooling
-
# These are the names of the attributes in various model configs which refer to the number of dimensions
# in the output vectors
OUTPUT_DIM_NAMES = ["dim", "hidden_size", "d_model"]
@@ -134,6 +134,8 @@ def load(cls, pretrained_model_name_or_path, n_added_tokens=0, language_model_cl
language_model_class = 'XLMRoberta'
elif 'roberta' in pretrained_model_name_or_path:
language_model_class = 'Roberta'
+ elif 'camembert' in pretrained_model_name_or_path or 'umberto' in pretrained_model_name_or_path:
+ language_model_class = "Camembert"
elif 'albert' in pretrained_model_name_or_path:
language_model_class = 'Albert'
elif 'distilbert' in pretrained_model_name_or_path:
@@ -234,7 +236,17 @@ def _infer_language_from_name(cls, name):
"multilingual",
)
matches = [lang for lang in known_languages if lang in name]
- if len(matches) == 0:
+ if "camembert" in name:
+ language = "french"
+ logger.info(
+ f"Automatically detected language from language model name: {language}"
+ )
+ elif "umberto" in name:
+ language = "italian"
+ logger.info(
+ f"Automatically detected language from language model name: {language}"
+ )
+ elif len(matches) == 0:
language = "english"
logger.warning(
"Could not automatically detect from language model name what language it is. \n"
@@ -242,11 +254,13 @@ def _infer_language_from_name(cls, name):
"\t If not: Init the language model by supplying the 'language' param."
)
elif len(matches) > 1:
- raise ValueError(
+ logger.warning(
"Could not automatically detect from language model name what language it is.\n"
f"\t Found multiple matches: {matches}\n"
"\t Please init the language model by manually supplying the 'language' as a parameter.\n"
+ f"\t Using {matches[0]} as language parameter for now.\n"
)
+ language = matches[0]
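+ # e.g. a hypothetical name like "german-english-bert" matches twice and now
+ # falls back to the first entry in known_languages instead of raising an error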
else:
language = matches[0]
logger.info(
@@ -879,7 +893,7 @@ def forward(
)
# XLNet also only returns the sequence_output (one vec per token)
# We need to manually aggregate that to get a pooled output (one vec per seq)
- #TODO verify that this is really doing correct pooling
+ # TODO verify that this is really doing correct pooling
pooled_output = self.pooler(output_tuple[0])
if self.model.output_hidden_states == True:
@@ -1282,3 +1296,47 @@ def enable_hidden_states_output(self):
def disable_hidden_states_output(self):
self.model.config.output_hidden_states = False
+
+class Camembert(Roberta):
+ """
+ A Camembert model that wraps HuggingFace's implementation
+ (https://github.com/huggingface/transformers) to fit the LanguageModel class.
+ """
+ def __init__(self):
+ super(Camembert, self).__init__()
+ self.model = None
+ self.name = "camembert"
+
+ @classmethod
+ def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
+ """
+ Load a language model either by supplying
+
+ * the name of a remote model on s3 ("camembert-base" ...)
+ * or a local path of a model trained via transformers ("some_dir/huggingface_model")
+ * or a local path of a model trained via FARM ("some_dir/farm_model")
+
+ :param pretrained_model_name_or_path: name or path of a model
+ :param language: (Optional) Name of language the model was trained for (e.g. "german").
+ If not supplied, FARM will try to infer it from the model name.
+ :return: Language Model
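+
+ Illustrative usage (assuming the public "camembert-base" checkpoint)::
+
+ camembert = Camembert.load("camembert-base")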
+
+ """
+ camembert = cls()
+ if "farm_lm_name" in kwargs:
+ camembert.name = kwargs["farm_lm_name"]
+ else:
+ camembert.name = pretrained_model_name_or_path
+ # We need to differentiate between loading model using FARM format and Pytorch-Transformers format
+ farm_lm_config = Path(pretrained_model_name_or_path) / "language_model_config.json"
+ if os.path.exists(farm_lm_config):
+ # FARM style
+ config = CamembertConfig.from_pretrained(farm_lm_config)
+ farm_lm_model = Path(pretrained_model_name_or_path) / "language_model.bin"
+ camembert.model = CamembertModel.from_pretrained(farm_lm_model, config=config, **kwargs)
+ camembert.language = camembert.model.config.language
+ else:
+ # Huggingface transformer Style
+ camembert.model = CamembertModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs)
+ camembert.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
+ return camembert
diff --git a/farm/modeling/tokenization.py b/farm/modeling/tokenization.py
index ee6c7f010..d012cb0e6 100644
--- a/farm/modeling/tokenization.py
+++ b/farm/modeling/tokenization.py
@@ -31,6 +31,7 @@
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_xlm_roberta import XLMRobertaTokenizer
from transformers.tokenization_xlnet import XLNetTokenizer
+from transformers.tokenization_camembert import CamembertTokenizer
from farm.modeling.wordembedding_utils import load_from_cache, EMBEDDING_VOCAB_FILES_MAP, run_split_on_punc
@@ -69,6 +70,8 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, **kwargs):
tokenizer_class = "XLMRobertaTokenizer"
elif "roberta" in pretrained_model_name_or_path.lower():
tokenizer_class = "RobertaTokenizer"
+ elif "camembert" in pretrained_model_name_or_path.lower() or "umberto" in pretrained_model_name_or_path:
+ tokenizer_class = "CamembertTokenizer"
elif "distilbert" in pretrained_model_name_or_path.lower():
tokenizer_class = "DistilBertTokenizer"
elif "bert" in pretrained_model_name_or_path.lower():
@@ -104,6 +107,8 @@ def load(cls, pretrained_model_name_or_path, tokenizer_class=None, **kwargs):
ret = ElectraTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif tokenizer_class == "EmbeddingTokenizer":
ret = EmbeddingTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
+ elif tokenizer_class == "CamembertTokenizer":
+ ret = CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
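+ # CamembertTokenizer is SentencePiece-based; UmBERTo checkpoints also use
+ # SentencePiece, so they resolve through this branch as well.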
if ret is None:
raise Exception("Unable to load tokenizer")
else:
diff --git a/readme.rst b/readme.rst
index 89b56e984..122085379 100644
--- a/readme.rst
+++ b/readme.rst
@@ -66,7 +66,7 @@ Core features
- Simple **deployment** and **visualization** to showcase your model
+------------------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+
-| Task | BERT | RoBERTa | XLNet | ALBERT | DistilBERT | XLMRoBERTa |
+| Task | BERT | RoBERTa* | XLNet | ALBERT | DistilBERT | XLMRoBERTa |
+==============================+===================+===================+===================+===================+===================+===================+
| Text classification | x | x | x | x | x | x |
+------------------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+
@@ -89,6 +89,8 @@ Core features
| Passage Ranking | x | x | x | x | x | x |
+------------------------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+
+\* including CamemBERT and UmBERTo
+
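+A minimal sketch of loading such a model in FARM (the model name assumes the public
+``camembert-base`` checkpoint on the HuggingFace model hub)::
+
+ from farm.modeling.language_model import LanguageModel
+ from farm.modeling.tokenization import Tokenizer
+
+ tokenizer = Tokenizer.load("camembert-base")
+ language_model = LanguageModel.load("camembert-base")
+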
**NEW** Interested in doing Question Answering at scale? Check out `Haystack <https://github.com/deepset-ai/haystack>`_!
Resources