Commit

Merge branch 'staging_0_1_1' into issue_340

harshbafna authored May 19, 2020
2 parents a2958bc + d96ab79 commit 11adf0e
Showing 29 changed files with 1,843 additions and 638 deletions.
103 changes: 103 additions & 0 deletions examples/Huggingface_Transformers/Download_Transformer_models.py
@@ -0,0 +1,103 @@
import transformers
from pathlib import Path
import os
import json
import torch
from transformers import (AutoModelForSequenceClassification, AutoTokenizer, AutoModelForQuestionAnswering,
                          AutoModelForTokenClassification, AutoConfig)
""" This script saves the checkpoint and config file, along with the tokenizer config and vocab files,
    of a transformer model of your choice.
"""
print('Transformers version', transformers.__version__)

def transformers_model_dowloader(mode, pretrained_model_name, num_labels, do_lower_case):
    print("Download model and tokenizer", pretrained_model_name)
    # loading pre-trained model and tokenizer
    if mode == "sequence_classification":
        config = AutoConfig.from_pretrained(pretrained_model_name, num_labels=num_labels)
        model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name, config=config)
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name, do_lower_case=do_lower_case)
    elif mode == "question_answering":
        model = AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name)
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name, do_lower_case=do_lower_case)
    elif mode == "token_classification":
        config = AutoConfig.from_pretrained(pretrained_model_name, num_labels=num_labels)
        # pass the config so that num_labels is applied to the saved model
        model = AutoModelForTokenClassification.from_pretrained(pretrained_model_name, config=config)
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name, do_lower_case=do_lower_case)


    # NOTE: for demonstration purposes, we do not go through the fine-tuning process here.
    # A fine-tuning step based on your needs can be added.
    # An example Colab notebook for the fine-tuning process is provided in the README.


    """ For demonstration purposes, we show an example of question answering
        data preprocessing and inference below. Using the pre-trained models will not yield
        good results; fine-tuned models should be used. For example, passing
        'bert-large-uncased-whole-word-masking-finetuned-squad' instead of "bert-base-uncased"
        to AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name) achieves
        a better result. Models such as RoBERTa, XLM, XLNet, etc. can also be passed as
        the pre-trained models.
    """

    text = r"""
🤗 Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose
architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural
Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between
TensorFlow 2.0 and PyTorch.
"""

    questions = [
        "How many pretrained models are available in Transformers?",
        "What does Transformers provide?",
        "Transformers provides interoperability between which frameworks?",
    ]

    # The TorchServe `properties` context is not available in this standalone script,
    # so fall back to the default CUDA device when a GPU is present.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)


    for question in questions:
        inputs = tokenizer.encode_plus(question, text, add_special_tokens=True, return_tensors="pt")
        for key in inputs.keys():
            inputs[key] = inputs[key].to(device)

        input_ids = inputs["input_ids"].tolist()[0]

        # text_tokens = tokenizer.convert_ids_to_tokens(input_ids)

        answer_start_scores, answer_end_scores = model(**inputs)

        answer_start = torch.argmax(
            answer_start_scores
        )  # Get the most likely beginning of answer with the argmax of the score
        answer_end = torch.argmax(answer_end_scores) + 1  # Get the most likely end of answer with the argmax of the score

        answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))

        print(f"Question: {question}")
        print(f"Answer: {answer}\n")


    NEW_DIR = "./Transformer_model"
    try:
        os.mkdir(NEW_DIR)
    except OSError:
        print("Creation of directory %s failed" % NEW_DIR)
    else:
        print("Successfully created directory %s " % NEW_DIR)

    print("Save model and tokenizer", pretrained_model_name, 'in directory', NEW_DIR)
    model.save_pretrained(NEW_DIR)
    tokenizer.save_pretrained(NEW_DIR)
    return
if __name__ == "__main__":
    dirname = os.path.dirname(__file__)
    filename = os.path.join(dirname, 'setup_config.json')
    with open(filename) as f:
        options = json.load(f)
    mode = options["mode"]
    model_name = options["model_name"]
    num_labels = int(options["num_labels"])
    do_lower_case = options["do_lower_case"]
    transformers_model_dowloader(mode, model_name, num_labels, do_lower_case)
87 changes: 87 additions & 0 deletions examples/Huggingface_Transformers/README.md
@@ -0,0 +1,87 @@
## Serving Huggingface Transformers using TorchServe

In this example, we show how to serve a fine-tuned or off-the-shelf Transformer model from [huggingface](https://huggingface.co/transformers/index.html) using TorchServe. We use a custom handler, Transformer_handler_generalized.py. This handler enables us to use pre-trained transformer models from Hugging Face, such as BERT, RoBERTa, XLM, etc., for the use cases covered by the AutoModel classes AutoModelForSequenceClassification, AutoModelForQuestionAnswering, and AutoModelForTokenClassification; support for AutoModelWithLMHead will be added later.

We borrowed ideas for writing a custom handler for transformers from the tutorial presented in [mnist from image classifiers examples](https://github.com/pytorch/serve/tree/master/examples/image_classifier/mnist) and from the post by [MFreidank](https://medium.com/analytics-vidhya/deploy-huggingface-s-bert-to-production-with-pytorch-serve-27b068026d18).

First, we need to make sure that the Transformers library is installed; it can be installed using the following command.

`pip install transformers`

### Objectives
1. Demonstrate how to package a transformer model with a custom handler into a torch model archive (.mar) file
2. Demonstrate how to load the model archive (.mar) file into TorchServe and run inference.

### Serving a Model using TorchServe

To serve a model using TorchServe, the following steps are required.

- First, prepare the requirements for packaging a model, including the serialized model and other required files.
- Create a torch model archive using the "torch-model-archiver" utility to archive the above files along with a handler (in this example, the custom handler for transformers).
- Register the model on TorchServe using the above model archive file and run inference.

### **Getting Started with the Demo**

There are two paths to obtain the required model files for this demo.

- **Option A**: To yield the desired results, one should fine-tune each of the intended models beforehand and save the model and tokenizer using "save_pretrained()". This results in a pytorch_model.bin file along with vocab.txt and config.json files. These files should be moved to a folder named "Transformer_model" in the current directory (a minimal sketch of this step is shown after the two options).

- **Option B**: Another option, just for demonstration purposes, is to simply run "Download_Transformer_models.py". This script loads and saves the required files mentioned above into the "Transformer_model" directory, using a setup config file, "setup_config.json". The settings in "setup_config.json" are also used in the handler, "Transformer_handler_generalized.py", to operate on the selected mode and other related settings.
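
For Option A, a minimal sketch of saving a fine-tuned model and tokenizer with "save_pretrained()" directly into the "Transformer_model" folder (the checkpoint name below is only an assumption for illustration; substitute your own fine-tuned model):

```
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Assumed example checkpoint; replace with your own fine-tuned model.
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# save_pretrained() writes pytorch_model.bin and config.json (and the tokenizer/vocab files)
# into the directory this example expects.
model.save_pretrained("./Transformer_model")
tokenizer.save_pretrained("./Transformer_model")
```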

#### Setting the setup_config.json

In setup_config.json:

*model_name*: bert-base-uncased, roberta-base, or another available pre-trained model.

*mode*: "sequence_classification" for sequence classification, "question_answering" for question answering, or "token_classification" for token classification.

*do_lower_case*: True or False, passed to the tokenizer.

*num_labels*: number of outputs for "sequence_classification" or "token_classification".
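
As a sketch, the snippet below writes a setup_config.json for a hypothetical binary sequence-classification setup (the model name and label count are assumptions for illustration):

```
import json

# Hypothetical example values; adjust the model, mode and label count to your task.
setup_config = {
    "model_name": "bert-base-uncased",
    "mode": "sequence_classification",
    "do_lower_case": "True",
    "num_labels": "2",
}

with open("setup_config.json", "w") as f:
    json.dump(setup_config, f, indent=2)
```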

Once setup_config.json has been set properly, the next step is to run "Download_Transformer_models.py":

`python Download_Transformer_models.py`

This produces all the files required for packaging an off-the-shelf huggingface transformer model without a fine-tuning process. Using this option creates and saves the required files into the Transformer_model directory. In case "vocab.txt" was not saved into this directory, the tokenizer is loaded from the pre-trained model's vocab instead; this case is handled in the handler.



#### Setting the extra_files

A few files are used for model packaging and at inference time. "index_to_name.json" is passed as an extra file to the model archiver and is used for mapping predictions to labels. "sample_text.txt" is used at inference time to pass the text that we want to run inference on.

"index_to_name.json" is not required for question answering.

If the Transformer handler is intended to be used for token classification, index_to_name.json should be formatted as in the following example:

`{"label_list":"[O, B-MISC, I-MISC, B-PER,I-PER,B-ORG,I-ORG,B-LOC,I-LOC]"}`

To use the Transformer handler for question answering, sample_text.txt should be formatted as follows:

`{"question" :"Who was Jim Henson?", "context": "Jim Henson was a nice puppet"}`

"question" represents the question to be asked from the source text named as "context" here.

### Creating a torch Model Archive

Once setup_config.json, sample_text.txt and index_to_name.json are set properly, we can go ahead and package the model and start serving it. The current setting in "setup_config.json" is based on the "roberta-base" model for question answering. A fine-tuned RoBERTa can be obtained by running the [squad example](https://huggingface.co/transformers/examples.html#squad) from huggingface. Alternatively, a fine-tuned BERT model can be used by setting "model_name" to "bert-large-uncased-whole-word-masking-finetuned-squad" in "setup_config.json".

```
torch-model-archiver --model-name RobertaQA --version 1.0 --serialized-file Transformer_model/pytorch_model.bin --handler ./Transformer_handler_generalized.py --extra-files "Transformer_model/config.json,./setup_config.json"
```

### Registering the Model on TorchServe and Running Inference

To register the model on TorchServe using the above model archive file, we run the following commands:

```
mkdir model_store
mv RobertaQA.mar model_store/
torchserve --start --model-store model_store --models my_tc=RobertaQA.mar
```

- To run inference using our registered model, open a new terminal and run: `curl -X POST http://127.0.0.1:8080/predictions/my_tc -T ./sample_text.txt`
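
The same request can be made from Python; a minimal sketch assuming TorchServe is running locally on the default inference port and the `requests` package is installed:

```
import requests

# Send the contents of sample_text.txt as the request body, just like `curl -T` does.
with open("sample_text.txt", "rb") as f:
    response = requests.post("http://127.0.0.1:8080/predictions/my_tc", data=f.read())

print(response.text)
```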
157 changes: 157 additions & 0 deletions examples/Huggingface_Transformers/Transformer_handler_generalized.py
@@ -0,0 +1,157 @@
from abc import ABC
import json
import logging
import os
import ast
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForQuestionAnswering,AutoModelForTokenClassification

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class TransformersSeqClassifierHandler(BaseHandler, ABC):
    """
    Transformers handler class for sequence, token classification and question answering.
    """
    def __init__(self):
        super(TransformersSeqClassifierHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        self.manifest = ctx.manifest

        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
        # read configs for the mode, model_name, etc. from setup_config.json
        setup_config_path = os.path.join(model_dir, "setup_config.json")

        if os.path.isfile(setup_config_path):
            with open(setup_config_path) as setup_config_file:
                self.setup_config = json.load(setup_config_file)
        else:
            logger.warning('Missing the setup_config.json file.')

        # Loading the model and tokenizer from checkpoint and config files based on the user's choice of mode
        # further setup config can be added.

        if self.setup_config["mode"] == "sequence_classification":
            self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
        elif self.setup_config["mode"] == "question_answering":
            self.model = AutoModelForQuestionAnswering.from_pretrained(model_dir)
        elif self.setup_config["mode"] == "token_classification":
            self.model = AutoModelForTokenClassification.from_pretrained(model_dir)
        else:
            logger.warning('Missing the operation mode.')

        if not os.path.isfile(os.path.join(model_dir, "vocab.txt")):
            self.tokenizer = AutoTokenizer.from_pretrained(self.setup_config["model_name"], do_lower_case=self.setup_config["do_lower_case"])
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_dir, do_lower_case=self.setup_config["do_lower_case"])

        self.model.to(self.device)
        self.model.eval()

        logger.debug('Transformer model from path {0} loaded successfully'.format(model_dir))

        # Read the mapping file, index to object name
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")
        # Question answering does not need the index_to_name.json file.
        if not self.setup_config["mode"] == "question_answering":
            if os.path.isfile(mapping_file_path):
                with open(mapping_file_path) as f:
                    self.mapping = json.load(f)
            else:
                logger.warning('Missing the index_to_name.json file.')

        self.initialized = True

    def preprocess(self, data):
        """ Basic text preprocessing, based on the user's choice of application mode.
        """
        text = data[0].get("data")
        if text is None:
            text = data[0].get("body")
        input_text = text.decode('utf-8')
        logger.info("Received text: '%s'", input_text)
        # preprocessing text for sequence_classification and token_classification.
        if self.setup_config["mode"] == "sequence_classification" or self.setup_config["mode"] == "token_classification":
            inputs = self.tokenizer.encode_plus(input_text, add_special_tokens=True, return_tensors='pt')
        # preprocessing text for question_answering.
        elif self.setup_config["mode"] == "question_answering":
            # the sample text for question_answering should be formatted as a dictionary
            # with question and context as keys and the related text as values.
            # we use this format here to separate the question and the context for encoding.
            # TODO: extend to handle multiple questions, cleaning and dealing with long contexts.
            question_context = ast.literal_eval(input_text)
            question = question_context["question"]
            context = question_context["context"]
            inputs = self.tokenizer.encode_plus(question, context, add_special_tokens=True, return_tensors="pt")

        return inputs

    def inference(self, inputs):
        """ Predict the class (or classes) of the received text using the serialized transformers checkpoint.
        """

        input_ids = inputs["input_ids"].to(self.device)
        # Handling inference for sequence_classification.
        if self.setup_config["mode"] == "sequence_classification":
            predictions = self.model(input_ids)
            prediction = predictions[0].argmax(1).item()

            logger.info("Model predicted: '%s'", prediction)

            if self.mapping:
                prediction = self.mapping[str(prediction)]
        # Handling inference for question_answering.
        elif self.setup_config["mode"] == "question_answering":
            # the output should be only answer_start and answer_end
            # we are outputting the words just for demonstration.
            answer_start_scores, answer_end_scores = self.model(input_ids)
            answer_start = torch.argmax(answer_start_scores)  # Get the most likely beginning of answer with the argmax of the score
            answer_end = torch.argmax(answer_end_scores) + 1  # Get the most likely end of answer with the argmax of the score
            input_ids = inputs["input_ids"].tolist()[0]
            prediction = self.tokenizer.convert_tokens_to_string(self.tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))

            logger.info("Model predicted: '%s'", prediction)
        # Handling inference for token_classification.
        elif self.setup_config["mode"] == "token_classification":
            outputs = self.model(input_ids)[0]
            predictions = torch.argmax(outputs, dim=2)
            tokens = self.tokenizer.tokenize(self.tokenizer.decode(inputs["input_ids"][0]))
            if self.mapping:
                label_list = self.mapping["label_list"]
                label_list = label_list.strip('][').split(', ')
                prediction = [(token, label_list[prediction]) for token, prediction in zip(tokens, predictions[0].tolist())]

            logger.info("Model predicted: '%s'", prediction)

        return [prediction]

    def postprocess(self, inference_output):
        # TODO: Add any needed post-processing of the model predictions here
        return inference_output


_service = TransformersSeqClassifierHandler()


def handle(data, context):
    try:
        if not _service.initialized:
            _service.initialize(context)

        if data is None:
            return None

        data = _service.preprocess(data)
        data = _service.inference(data)
        data = _service.postprocess(data)

        return data
    except Exception as e:
        raise e
4 changes: 4 additions & 0 deletions examples/Huggingface_Transformers/index_to_name.json
@@ -0,0 +1,4 @@
{
"0":"Not Accepted",
"1":"Accepted"
}
1 change: 1 addition & 0 deletions examples/Huggingface_Transformers/sample_text.txt
@@ -0,0 +1 @@
{"question" :"Who was Jim Henson?", "context": "Jim Henson was a nice puppet"}
6 changes: 6 additions & 0 deletions examples/Huggingface_Transformers/setup_config.json
@@ -0,0 +1,6 @@
{
"model_name":"roberta-base",
"mode":"question_answering",
"do_lower_case":"True",
"num_labels":"0"
}
