fix(executor): delete some useless codes
Xiong Ma committed Apr 3, 2020
1 parent aeef453 commit 7fc5eef
Showing 1 changed file with 7 additions and 11 deletions: jina/executors/encoders/nlp/transformer.py
@@ -70,9 +70,9 @@ def post_init(self):
         if self.model_name in ('bert-base-uncased', 'distilbert-base-cased', 'roberta-base', 'xlm-roberta-base'):
             self.cls_pos = 'head'
         elif self.model_name in ('xlnet-base-cased'):
-            self.tokenizer.pad_token = '<PAD>'
             self.cls_pos = 'tail'
-        elif self.model_name in ('openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024'):
+
+        if self.model_name in ('openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024', 'xlnet-base-cased'):
             self.tokenizer.pad_token = '<PAD>'
 
     def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
@@ -169,7 +169,7 @@ def post_init(self):
         import tensorflow as tf
         from transformers import TFBertModel, TFOpenAIGPTModel, TFGPT2Model, TFXLNetModel, TFXLMModel, \
             TFDistilBertModel, TFRobertaModel, TFXLMRobertaModel
-        tf_model_dict = {
+        model_dict = {
             'bert-base-uncased': TFBertModel,
             'openai-gpt': TFOpenAIGPTModel,
             'gpt2': TFGPT2Model,
@@ -179,18 +179,16 @@ def post_init(self):
             'roberta-base': TFRobertaModel,
             'xlm-roberta-base': TFXLMRobertaModel,
         }
-        self.model = tf_model_dict[self.model_name].from_pretrained(self._tmp_path)
+        self.model = model_dict[self.model_name].from_pretrained(self._tmp_path)
         self._tensor_func = tf.constant
         self._sess_func = tf.GradientTape
 
-        if self.model_name in ('xlnet-base-cased'):
-            self.model.resize_token_embeddings(len(self.tokenizer))
-        elif self.model_name in ('openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024'):
+        if self.model_name in ('xlnet-base-cased', 'openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024'):
             self.model.resize_token_embeddings(len(self.tokenizer))
 
 class TransformerTorchEncoder(TransformerEncoder):
     """
-    Internally, TransformerTFEncoder wraps the pytorch-version of transformers from huggingface.
+    Internally, TransformerTorchEncoder wraps the pytorch-version of transformers from huggingface.
     """
     def post_init(self):
         super().post_init()
@@ -213,7 +211,5 @@ def post_init(self):
         self._tensor_func = torch.tensor
         self._sess_func = torch.no_grad
 
-        if self.model_name in ('xlnet-base-cased'):
-            self.model.resize_token_embeddings(len(self.tokenizer))
-        elif self.model_name in ('openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024'):
+        if self.model_name in ('xlnet-base-cased', 'openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024'):
            self.model.resize_token_embeddings(len(self.tokenizer))
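For reference, the net effect of the commit is that the special-token handling in post_init collapses from duplicated elif branches into single membership checks: the CLS-position selection keeps its if/elif chain, while the pad-token assignment and the token-embedding resize each become one shared check over the combined model-name tuple. Below is a minimal, self-contained sketch of that consolidated control flow. DummyTokenizer, DummyModel and configure_special_tokens are hypothetical stand-ins, not part of the Jina codebase; the real code works on the huggingface tokenizer/model loaded earlier in post_init, and splits the pad-token step (base class) from the resize step (TF/Torch subclasses). The sketch also writes the single-element tuple as ('xlnet-base-cased',), whereas the repository code writes ('xlnet-base-cased').

# Hypothetical sketch of the consolidated post_init logic after this commit.
class DummyTokenizer:
    def __init__(self):
        self.pad_token = None

    def __len__(self):
        return 30000  # stand-in vocabulary size


class DummyModel:
    def resize_token_embeddings(self, vocab_size: int):
        print(f'resized token embeddings to {vocab_size}')


def configure_special_tokens(model_name: str, tokenizer, model):
    # CLS position stays a chained if/elif, as in the diff above.
    if model_name in ('bert-base-uncased', 'distilbert-base-cased', 'roberta-base', 'xlm-roberta-base'):
        cls_pos = 'head'
    elif model_name in ('xlnet-base-cased',):
        cls_pos = 'tail'
    else:
        cls_pos = None

    # Pad-token handling and embedding resize now share one membership check
    # instead of being repeated across two elif branches.
    if model_name in ('openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024', 'xlnet-base-cased'):
        tokenizer.pad_token = '<PAD>'
        model.resize_token_embeddings(len(tokenizer))
    return cls_pos


if __name__ == '__main__':
    # Example: xlnet takes the 'tail' CLS position and also gets a pad token plus resize.
    print(configure_special_tokens('xlnet-base-cased', DummyTokenizer(), DummyModel()))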
