fix(executor): add pytorch at first
Xiong Ma committed Apr 3, 2020
1 parent be1a4ec commit 78048e5
Showing 1 changed file with 31 additions and 31 deletions.
jina/executors/encoders/nlp/transformer.py: 62 changes (31 additions, 31 deletions)
@@ -57,42 +57,42 @@ def post_init(self):
             raise ValueError
 
         try:
-            import tensorflow as tf
-            from transformers import TFBertModel, TFOpenAIGPTModel, TFGPT2Model, TFXLNetModel, TFXLMModel, \
-                TFDistilBertModel, TFRobertaModel, TFXLMRobertaModel
-            tf_model_dict = {
-                'bert-base-uncased': TFBertModel,
-                'openai-gpt': TFOpenAIGPTModel,
-                'gpt2': TFGPT2Model,
-                'xlnet-base-cased': TFXLNetModel,
-                'xlm-mlm-enfr-1024': TFXLMModel,
-                'distilbert-base-cased': TFDistilBertModel,
-                'roberta-base': TFRobertaModel,
-                'xlm-roberta-base': TFXLMRobertaModel,
+            import torch
+            from transformers import BertModel, OpenAIGPTModel, GPT2Model, XLNetModel, XLMModel, DistilBertModel, \
+                RobertaModel, XLMRobertaModel
+
+            model_dict = {
+                'bert-base-uncased': BertModel,
+                'openai-gpt': OpenAIGPTModel,
+                'gpt2': GPT2Model,
+                'xlnet-base-cased': XLNetModel,
+                'xlm-mlm-enfr-1024': XLMModel,
+                'distilbert-base-cased': DistilBertModel,
+                'roberta-base': RobertaModel,
+                'xlm-roberta-base': XLMRobertaModel,
             }
-            model_class = tf_model_dict[self.model_name]
-            self._tensor_func = tf.constant
-            self._sess_func = tf.GradientTape
+            model_class = model_dict[self.model_name]
+            self._tensor_func = torch.tensor
+            self._sess_func = torch.no_grad
 
         except:
             try:
-                import torch
-                from transformers import BertModel, OpenAIGPTModel, GPT2Model, XLNetModel, XLMModel, DistilBertModel, \
-                    RobertaModel, XLMRobertaModel
-
-                model_dict = {
-                    'bert-base-uncased': BertModel,
-                    'openai-gpt': OpenAIGPTModel,
-                    'gpt2': GPT2Model,
-                    'xlnet-base-cased': XLNetModel,
-                    'xlm-mlm-enfr-1024': XLMModel,
-                    'distilbert-base-cased': DistilBertModel,
-                    'roberta-base': RobertaModel,
-                    'xlm-roberta-base': XLMRobertaModel,
+                import tensorflow as tf
+                from transformers import TFBertModel, TFOpenAIGPTModel, TFGPT2Model, TFXLNetModel, TFXLMModel, \
+                    TFDistilBertModel, TFRobertaModel, TFXLMRobertaModel
+                tf_model_dict = {
+                    'bert-base-uncased': TFBertModel,
+                    'openai-gpt': TFOpenAIGPTModel,
+                    'gpt2': TFGPT2Model,
+                    'xlnet-base-cased': TFXLNetModel,
+                    'xlm-mlm-enfr-1024': TFXLMModel,
+                    'distilbert-base-cased': TFDistilBertModel,
+                    'roberta-base': TFRobertaModel,
+                    'xlm-roberta-base': TFXLMRobertaModel,
                 }
-                model_class = model_dict[self.model_name]
-                self._tensor_func = torch.tensor
-                self._sess_func = torch.no_grad
+                model_class = tf_model_dict[self.model_name]
+                self._tensor_func = tf.constant
+                self._sess_func = tf.GradientTape
             except:
                 raise ModuleNotFoundError('Tensorflow or Pytorch is required!')
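
The net effect of the commit is only the ordering of the two backends: PyTorch is now probed first and TensorFlow becomes the fallback. A minimal, self-contained sketch of that import-fallback pattern follows; the helper name load_backend is illustrative and not part of the jina codebase, and ImportError is caught explicitly here instead of the bare except used in the diff above.

# Sketch of the backend-selection order after this commit: PyTorch first,
# TensorFlow as fallback, and a clear error if neither framework is installed.
def load_backend():
    try:
        import torch  # preferred backend after this change
        return 'pytorch', torch.tensor, torch.no_grad
    except ImportError:
        try:
            import tensorflow as tf  # fallback backend
            return 'tensorflow', tf.constant, tf.GradientTape
        except ImportError:
            raise ModuleNotFoundError('Tensorflow or Pytorch is required!')


backend, tensor_func, grad_context = load_backend()
print(backend)  # 'pytorch' if torch is importable, otherwise 'tensorflow'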
