diff --git a/jina/executors/encoders/nlp/transformer.py b/jina/executors/encoders/nlp/transformer.py index cbb4db76885a5..a4c6aa5df067f 100644 --- a/jina/executors/encoders/nlp/transformer.py +++ b/jina/executors/encoders/nlp/transformer.py @@ -70,9 +70,9 @@ def post_init(self): if self.model_name in ('bert-base-uncased', 'distilbert-base-cased', 'roberta-base', 'xlm-roberta-base'): self.cls_pos = 'head' elif self.model_name in ('xlnet-base-cased'): - self.tokenizer.pad_token = '' self.cls_pos = 'tail' - elif self.model_name in ('openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024'): + + if self.model_name in ('openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024', 'xlnet-base-cased'): self.tokenizer.pad_token = '' def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray': @@ -169,7 +169,7 @@ def post_init(self): import tensorflow as tf from transformers import TFBertModel, TFOpenAIGPTModel, TFGPT2Model, TFXLNetModel, TFXLMModel, \ TFDistilBertModel, TFRobertaModel, TFXLMRobertaModel - tf_model_dict = { + model_dict = { 'bert-base-uncased': TFBertModel, 'openai-gpt': TFOpenAIGPTModel, 'gpt2': TFGPT2Model, @@ -179,18 +179,16 @@ def post_init(self): 'roberta-base': TFRobertaModel, 'xlm-roberta-base': TFXLMRobertaModel, } - self.model = tf_model_dict[self.model_name].from_pretrained(self._tmp_path) + self.model = model_dict[self.model_name].from_pretrained(self._tmp_path) self._tensor_func = tf.constant self._sess_func = tf.GradientTape - if self.model_name in ('xlnet-base-cased'): - self.model.resize_token_embeddings(len(self.tokenizer)) - elif self.model_name in ('openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024'): + if self.model_name in ('xlnet-base-cased', 'openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024'): self.model.resize_token_embeddings(len(self.tokenizer)) class TransformerTorchEncoder(TransformerEncoder): """ - Internally, TransformerTFEncoder wraps the pytorch-version of transformers from huggingface. + Internally, TransformerTorchEncoder wraps the pytorch-version of transformers from huggingface. """ def post_init(self): super().post_init() @@ -213,7 +211,5 @@ def post_init(self): self._tensor_func = torch.tensor self._sess_func = torch.no_grad - if self.model_name in ('xlnet-base-cased'): - self.model.resize_token_embeddings(len(self.tokenizer)) - elif self.model_name in ('openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024'): + if self.model_name in ('xlnet-base-cased', 'openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024'): self.model.resize_token_embeddings(len(self.tokenizer)) \ No newline at end of file