diff --git a/farm/modeling/language_model.py b/farm/modeling/language_model.py index effb8bf84..ee257f6bf 100644 --- a/farm/modeling/language_model.py +++ b/farm/modeling/language_model.py @@ -53,7 +53,7 @@ def forward(self, input_ids, padding_mask, **kwargs): raise NotImplementedError @classmethod - def load(cls, pretrained_model_name_or_path): + def load(cls, pretrained_model_name_or_path, **kwargs): """ Load a pretrained language model either by @@ -87,11 +87,11 @@ def load(cls, pretrained_model_name_or_path): else: # it's a model name which we try to resolve from s3. for now only works for bert models if 'roberta' in pretrained_model_name_or_path: - language_model = cls.subclasses["Roberta"].load(pretrained_model_name_or_path) + language_model = cls.subclasses["Roberta"].load(pretrained_model_name_or_path, **kwargs) elif 'bert' in pretrained_model_name_or_path: - language_model = cls.subclasses["Bert"].load(pretrained_model_name_or_path) + language_model = cls.subclasses["Bert"].load(pretrained_model_name_or_path, **kwargs) elif 'xlnet' in pretrained_model_name_or_path: - language_model = cls.subclasses["XLNet"].load(pretrained_model_name_or_path) + language_model = cls.subclasses["XLNet"].load(pretrained_model_name_or_path, **kwargs) assert language_model is not None @@ -225,7 +225,7 @@ def __init__(self): self.name = "bert" @classmethod - def load(cls, pretrained_model_name_or_path, language=None): + def load(cls, pretrained_model_name_or_path, language=None, **kwargs): """ Load a pretrained model by supplying @@ -246,11 +246,11 @@ def load(cls, pretrained_model_name_or_path, language=None): # FARM style bert_config = BertConfig.from_pretrained(farm_lm_config) farm_lm_model = os.path.join(pretrained_model_name_or_path, "language_model.bin") - bert.model = BertModel.from_pretrained(farm_lm_model, config=bert_config) + bert.model = BertModel.from_pretrained(farm_lm_model, config=bert_config, **kwargs) bert.language = bert.model.config.language else: # Pytorch-transformer Style - bert.model = BertModel.from_pretrained(pretrained_model_name_or_path) + bert.model = BertModel.from_pretrained(pretrained_model_name_or_path, **kwargs) bert.language = cls._infer_language_from_name(pretrained_model_name_or_path) return bert @@ -316,7 +316,7 @@ def __init__(self): self.name = "roberta" @classmethod - def load(cls, pretrained_model_name_or_path, language=None): + def load(cls, pretrained_model_name_or_path, language=None, **kwargs): """ Load a language model either by supplying @@ -338,11 +338,11 @@ def load(cls, pretrained_model_name_or_path, language=None): # FARM style config = RobertaConfig.from_pretrained(farm_lm_config) farm_lm_model = os.path.join(pretrained_model_name_or_path, "language_model.bin") - roberta.model = RobertaModel.from_pretrained(farm_lm_model, config=config) + roberta.model = RobertaModel.from_pretrained(farm_lm_model, config=config, **kwargs) roberta.language = roberta.model.config.language else: # Huggingface transformer Style - roberta.model = RobertaModel.from_pretrained(pretrained_model_name_or_path) + roberta.model = RobertaModel.from_pretrained(pretrained_model_name_or_path, **kwargs) roberta.language = cls._infer_language_from_name(pretrained_model_name_or_path) return roberta @@ -408,7 +408,7 @@ def __init__(self): self.pooler = None @classmethod - def load(cls, pretrained_model_name_or_path, language=None): + def load(cls, pretrained_model_name_or_path, language=None, **kwargs): """ Load a language model either by supplying @@ -430,11 +430,11 @@ def load(cls, pretrained_model_name_or_path, language=None): # FARM style config = XLNetConfig.from_pretrained(farm_lm_config) farm_lm_model = os.path.join(pretrained_model_name_or_path, "language_model.bin") - xlnet.model = XLNetModel.from_pretrained(farm_lm_model, config=config) + xlnet.model = XLNetModel.from_pretrained(farm_lm_model, config=config, **kwargs) xlnet.language = xlnet.model.config.language else: # Pytorch-transformer Style - xlnet.model = XLNetModel.from_pretrained(pretrained_model_name_or_path) + xlnet.model = XLNetModel.from_pretrained(pretrained_model_name_or_path, **kwargs) xlnet.language = cls._infer_language_from_name(pretrained_model_name_or_path) config = xlnet.model.config # XLNet does not provide a pooled_output by default. Therefore, we need to initialize an extra pooler. @@ -499,4 +499,4 @@ def save_config(self, save_dir): setattr(self.model.config, "name", self.__class__.__name__) setattr(self.model.config, "language", self.language) string = self.model.config.to_json_string() - file.write(string) \ No newline at end of file + file.write(string)