diff --git a/examples/Huggingface_Transformers/Download_Transformer_models.py b/examples/Huggingface_Transformers/Download_Transformer_models.py
new file mode 100644
index 0000000000..3357ba8240
--- /dev/null
+++ b/examples/Huggingface_Transformers/Download_Transformer_models.py
@@ -0,0 +1,103 @@
+import transformers
+from pathlib import Path
+import os
+import json
+import torch
+from transformers import (AutoModelForSequenceClassification, AutoTokenizer, AutoModelForQuestionAnswering,
+                          AutoModelForTokenClassification, AutoConfig)
+"""This script saves the checkpoint and config file, along with the tokenizer config and vocab files,
+of a transformer model of your choice.
+"""
+print('Transformers version', transformers.__version__)
+
+
+def transformers_model_downloader(mode, pretrained_model_name, num_labels, do_lower_case):
+    print("Download model and tokenizer", pretrained_model_name)
+    # loading pre-trained model and tokenizer
+    if mode == "sequence_classification":
+        config = AutoConfig.from_pretrained(pretrained_model_name, num_labels=num_labels)
+        model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name, config=config)
+        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name, do_lower_case=do_lower_case)
+    elif mode == "question_answering":
+        model = AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name)
+        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name, do_lower_case=do_lower_case)
+    elif mode == "token_classification":
+        config = AutoConfig.from_pretrained(pretrained_model_name, num_labels=num_labels)
+        model = AutoModelForTokenClassification.from_pretrained(pretrained_model_name, config=config)
+        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name, do_lower_case=do_lower_case)
+
+    # NOTE: for demonstration purposes, we do not go through fine-tuning here.
+    # A fine-tuning process based on your needs can be added.
+    # An example Colab notebook for the fine-tuning process is linked in the README.
+
+    """For demonstration purposes, we show an example of question answering data
+    preprocessing and inference. Using the pre-trained models as-is will not yield
+    good results; fine-tuned models are needed. For example, instead of "bert-base-uncased",
+    passing 'bert-large-uncased-whole-word-masking-finetuned-squad' to
+    AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name)
+    achieves a better result. Models such as RoBERTa, XLM, XLNet, etc. can also be
+    passed as the pre-trained models.
+    """
+
+    text = r"""
+    🤗 Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose
+    architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural
+    Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between
+    TensorFlow 2.0 and PyTorch.
+    """
+
+    questions = [
+        "How many pretrained models are available in Transformers?",
+        "What does Transformers provide?",
+        "Transformers provides interoperability between which frameworks?",
+    ]
+
+    # This script runs standalone, so there is no TorchServe context to supply a gpu_id;
+    # simply pick the GPU when one is available.
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    for question in questions:
+        inputs = tokenizer.encode_plus(question, text, add_special_tokens=True, return_tensors="pt")
+        for key in inputs.keys():
+            inputs[key] = inputs[key].to(device)
+
+        input_ids = inputs["input_ids"].tolist()[0]
+
+        # text_tokens = tokenizer.convert_ids_to_tokens(input_ids)
+
+        answer_start_scores, answer_end_scores = model(**inputs)
+
+        answer_start = torch.argmax(
+            answer_start_scores
+        )  # Get the most likely beginning of the answer with the argmax of the score
+        answer_end = torch.argmax(answer_end_scores) + 1  # Get the most likely end of the answer with the argmax of the score
+
+        answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
+
+        print(f"Question: {question}")
+        print(f"Answer: {answer}\n")
+
+    NEW_DIR = "./Transformer_model"
+    try:
+        os.mkdir(NEW_DIR)
+    except OSError:
+        print("Creation of directory %s failed" % NEW_DIR)
+    else:
+        print("Successfully created directory %s " % NEW_DIR)
+
+    print("Save model and tokenizer", pretrained_model_name, 'in directory', NEW_DIR)
+    model.save_pretrained(NEW_DIR)
+    tokenizer.save_pretrained(NEW_DIR)
+
+
+if __name__ == "__main__":
+    dirname = os.path.dirname(__file__)
+    filename = os.path.join(dirname, 'setup_config.json')
+    with open(filename) as f:
+        options = json.load(f)
+    mode = options["mode"]
+    model_name = options["model_name"]
+    num_labels = int(options["num_labels"])
+    do_lower_case = options["do_lower_case"]
+    transformers_model_downloader(mode, model_name, num_labels, do_lower_case)
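Once the script has run, the contents of `Transformer_model` can be loaded back directly as a quick sanity check; a minimal sketch (not part of this PR), assuming the question-answering mode used above:

```python
# Reload the artifacts saved by Download_Transformer_models.py (sketch).
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

model = AutoModelForQuestionAnswering.from_pretrained("./Transformer_model")
tokenizer = AutoTokenizer.from_pretrained("./Transformer_model")
print(type(model).__name__)  # e.g. RobertaForQuestionAnswering for roberta-base
```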
+ """ + + questions = [ + "How many pretrained models are available in Transformers?", + "What does Transformers provide?", + "Transformers provides interoperability between which frameworks?", + ] + + device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu") + model.to(device) + + + for question in questions: + inputs = tokenizer.encode_plus(question, text, add_special_tokens=True, return_tensors="pt") + for key in inputs.keys(): + inputs[key]= inputs[key].to(device) + + input_ids = inputs["input_ids"].tolist()[0] + + # text_tokens = tokenizer.convert_ids_to_tokens(input_ids) + + answer_start_scores, answer_end_scores = model(**inputs) + + answer_start = torch.argmax( + answer_start_scores + ) # Get the most likely beginning of answer with the argmax of the score + answer_end = torch.argmax(answer_end_scores) + 1 # Get the most likely end of answer with the argmax of the score + + answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end])) + + print(f"Question: {question}") + print(f"Answer: {answer}\n") + + + NEW_DIR = "./Transformer_model" + try: + os.mkdir(NEW_DIR) + except OSError: + print ("Creation of directory %s failed" % NEW_DIR) + else: + print ("Successfully created directory %s " % NEW_DIR) + + print("Save model and tokenizer", pretrained_model_name, 'in directory', NEW_DIR) + model.save_pretrained(NEW_DIR) + tokenizer.save_pretrained(NEW_DIR) + return +if __name__== "__main__": + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, 'setup_config.json') + f = open(filename) + options = json.load(f) + mode = options["mode"] + model_name = options["model_name"] + num_labels = int(options["num_labels"]) + do_lower_case = options["do_lower_case"] + transformers_model_dowloader(mode,model_name, num_labels,do_lower_case) diff --git a/examples/Huggingface_Transformers/README.md b/examples/Huggingface_Transformers/README.md new file mode 100644 index 0000000000..24b9b70a9a --- /dev/null +++ b/examples/Huggingface_Transformers/README.md @@ -0,0 +1,87 @@ +## Serving Huggingface Transformers using TorchServe + +In this example, we show how to serve a Fine_tuned or off-the-shelf Transformer model from [huggingface](https://huggingface.co/transformers/index.html) using TorchServe. We use a custom handler, Transformer_handler.py. This handler enables us to use pre-trained transformer models from Hugginface, such as BERT, RoBERTA, XLM, etc. for use-cases defined in the AutoModel class such as AutoModelForSequenceClassification, AutoModelForQuestionAnswering, AutoModelForTokenClassification, and AutoModelWithLMHead will be added later. + +We borrowed ideas to write a custom handler for transformers from tutorial presented in [mnist from image classifiers examples](https://github.com/pytorch/serve/tree/master/examples/image_classifier/mnist) and the post by [MFreidank](https://medium.com/analytics-vidhya/deploy-huggingface-s-bert-to-production-with-pytorch-serve-27b068026d18). + +First, we need to make sure that have installed the Transformers, it can be installed using the following command. + + `pip install transformers` + +### Objectives +1. Demonstrate how to package a transformer model with custom handler into torch model archive (.mar) file +2. Demonstrate how to load model archive (.mar) file into TorchServe and run inference. + +### Serving a Model using TorchServe + +To serve a model using TrochServe following steps are required. 
+Once setup_config.json has been set properly, the next step is to run "Download_Transformer_models.py":
+
+`python Download_Transformer_models.py`
+
+This produces all the required files for packaging, using a huggingface transformer model off-the-shelf without a fine-tuning process. Using this option will create and save the required files into the Transformer_model directory. In case "vocab.txt" was not saved into this directory, we can load the tokenizer from the pre-trained model vocab; this case has been addressed in the handler.
+
+#### Setting the extra_files
+
+A few files are used for model packaging and at inference time. "index_to_name.json" is passed as an extra file to the model archiver and is used for mapping predictions to labels. "sample_text.txt" is used at inference time to pass the text that we want to run inference on.
+
+index_to_name.json is not required for question answering.
+
+If the Transformer handler is intended for token classification, index_to_name.json should be formatted as follows, for example:
+
+`{"label_list":"[O, B-MISC, I-MISC, B-PER,I-PER,B-ORG,I-ORG,B-LOC,I-LOC]"}`
+
+To use the Transformer handler for question answering, sample_text.txt should be formatted as follows:
+
+`{"question" :"Who was Jim Henson?", "context": "Jim Henson was a nice puppet"}`
+
+"question" represents the question to be asked about the source text, named "context" here.
+
+### Creating a torch Model Archive
+
+Once setup_config.json, sample_text.txt and index_to_name.json are set properly, we can go ahead and package the model and start serving it. The current setting in "setup_config.json" is based on the "roberta-base" model for question answering. A fine-tuned RoBERTa can be obtained by running the [squad example](https://huggingface.co/transformers/examples.html#squad) from huggingface. Alternatively, a fine-tuned BERT model can be used by setting "model_name" to "bert-large-uncased-whole-word-masking-finetuned-squad" in "setup_config.json".
+
+```
+torch-model-archiver --model-name RobertaQA --version 1.0 --serialized-file Transformer_model/pytorch_model.bin --handler ./Transformer_handler_generalized.py --extra-files "Transformer_model/config.json,./setup_config.json"
+```
+
+### Registering the Model on TorchServe and Running Inference
+
+To register the model on TorchServe using the above model archive file, we run the following commands:
+
+```
+mkdir model_store
+mv RobertaQA.mar model_store/
+torchserve --start --model-store model_store --models my_tc=RobertaQA.mar
+```
+
+- To run inference using our registered model, open a new terminal and run: `curl -X POST http://127.0.0.1:8080/predictions/my_tc -T ./sample_text.txt`
\ No newline at end of file
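Inference can equally be driven from Python; a small sketch (an assumption on our part, mirroring the requests snippet used by the text-to-speech example later in this PR):

```python
import requests

# POST the question/context payload to the model registered as my_tc.
with open("sample_text.txt", "rb") as f:
    response = requests.post("http://127.0.0.1:8080/predictions/my_tc", data=f)
print(response.text)  # the answer span extracted from the context
```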
diff --git a/examples/Huggingface_Transformers/Transformer_handler_generalized.py b/examples/Huggingface_Transformers/Transformer_handler_generalized.py
new file mode 100644
index 0000000000..084d233d09
--- /dev/null
+++ b/examples/Huggingface_Transformers/Transformer_handler_generalized.py
@@ -0,0 +1,157 @@
+from abc import ABC
+import json
+import logging
+import os
+import ast
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForTokenClassification
+
+from ts.torch_handler.base_handler import BaseHandler
+
+logger = logging.getLogger(__name__)
+
+
+class TransformersSeqClassifierHandler(BaseHandler, ABC):
+    """
+    Transformers handler class for sequence classification, token classification
+    and question answering.
+    """
+    def __init__(self):
+        super(TransformersSeqClassifierHandler, self).__init__()
+        self.initialized = False
+        self.mapping = None  # filled from index_to_name.json when one is provided
+
+    def initialize(self, ctx):
+        self.manifest = ctx.manifest
+
+        properties = ctx.system_properties
+        model_dir = properties.get("model_dir")
+        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
+        # read configs for the mode, model_name, etc. from setup_config.json
+        setup_config_path = os.path.join(model_dir, "setup_config.json")
+
+        if os.path.isfile(setup_config_path):
+            with open(setup_config_path) as setup_config_file:
+                self.setup_config = json.load(setup_config_file)
+        else:
+            logger.warning('Missing the setup_config.json file.')
+
+        # Loading the model and tokenizer from checkpoint and config files,
+        # based on the user's choice of mode; further setup config can be added.
+        if self.setup_config["mode"] == "sequence_classification":
+            self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
+        elif self.setup_config["mode"] == "question_answering":
+            self.model = AutoModelForQuestionAnswering.from_pretrained(model_dir)
+        elif self.setup_config["mode"] == "token_classification":
+            self.model = AutoModelForTokenClassification.from_pretrained(model_dir)
+        else:
+            logger.warning('Missing the operation mode.')
+
+        if not os.path.isfile(os.path.join(model_dir, "vocab.txt")):
+            self.tokenizer = AutoTokenizer.from_pretrained(self.setup_config["model_name"], do_lower_case=self.setup_config["do_lower_case"])
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(model_dir, do_lower_case=self.setup_config["do_lower_case"])
+
+        self.model.to(self.device)
+        self.model.eval()
+
+        logger.debug('Transformer model from path {0} loaded successfully'.format(model_dir))
+
+        # Read the mapping file, index to object name.
+        mapping_file_path = os.path.join(model_dir, "index_to_name.json")
+        # Question answering does not need the index_to_name.json file.
+        if not self.setup_config["mode"] == "question_answering":
+            if os.path.isfile(mapping_file_path):
+                with open(mapping_file_path) as f:
+                    self.mapping = json.load(f)
+            else:
+                logger.warning('Missing the index_to_name.json file.')
+
+        self.initialized = True
+
+    def preprocess(self, data):
+        """Basic text preprocessing, based on the user's choice of application mode."""
+        text = data[0].get("data")
+        if text is None:
+            text = data[0].get("body")
+        input_text = text.decode('utf-8')
+        logger.info("Received text: '%s'", input_text)
+        # preprocessing text for sequence_classification and token_classification.
+        if self.setup_config["mode"] == "sequence_classification" or self.setup_config["mode"] == "token_classification":
+            inputs = self.tokenizer.encode_plus(input_text, add_special_tokens=True, return_tensors='pt')
+        # preprocessing text for question_answering.
+        elif self.setup_config["mode"] == "question_answering":
+            # the sample text for question_answering should be formatted as a dictionary
+            # with "question" and "context" as keys and the related text as values;
+            # we use this format here to separate the question and context for encoding.
+            # TODO extend to handle multiple questions, cleaning, and dealing with long contexts.
+            question_context = ast.literal_eval(input_text)
+            question = question_context["question"]
+            context = question_context["context"]
+            inputs = self.tokenizer.encode_plus(question, context, add_special_tokens=True, return_tensors="pt")
+
+        return inputs
+
+    def inference(self, inputs):
+        """Predict the class (or classes) of the received text using the serialized transformers checkpoint."""
+        input_ids = inputs["input_ids"].to(self.device)
+        # Handling inference for sequence_classification.
+        if self.setup_config["mode"] == "sequence_classification":
+            predictions = self.model(input_ids)
+            prediction = predictions[0].argmax(1).item()
+
+            logger.info("Model predicted: '%s'", prediction)
+
+            if self.mapping:
+                prediction = self.mapping[str(prediction)]
+        # Handling inference for question_answering.
+        elif self.setup_config["mode"] == "question_answering":
+            # the output should be only answer_start and answer_end;
+            # we output the decoded words just for demonstration.
+            answer_start_scores, answer_end_scores = self.model(input_ids)
+            answer_start = torch.argmax(answer_start_scores)  # Most likely beginning of the answer, via the argmax of the score
+            answer_end = torch.argmax(answer_end_scores) + 1  # Most likely end of the answer, via the argmax of the score
+            input_ids = inputs["input_ids"].tolist()[0]
+            prediction = self.tokenizer.convert_tokens_to_string(self.tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
+
+            logger.info("Model predicted: '%s'", prediction)
+        # Handling inference for token_classification.
+        elif self.setup_config["mode"] == "token_classification":
+            outputs = self.model(input_ids)[0]
+            predictions = torch.argmax(outputs, dim=2)
+            tokens = self.tokenizer.tokenize(self.tokenizer.decode(inputs["input_ids"][0]))
+            if self.mapping:
+                label_list = self.mapping["label_list"]
+                label_list = label_list.strip('][').split(', ')
+                prediction = [(token, label_list[prediction]) for token, prediction in zip(tokens, predictions[0].tolist())]
+
+            logger.info("Model predicted: '%s'", prediction)
+
+        return [prediction]
+
+    def postprocess(self, inference_output):
+        # TODO: Add any needed post-processing of the model predictions here
+        return inference_output
+
+
+_service = TransformersSeqClassifierHandler()
+
+
+def handle(data, context):
+    try:
+        if not _service.initialized:
+            _service.initialize(context)
+
+        if data is None:
+            return None
+
+        data = _service.preprocess(data)
+        data = _service.inference(data)
+        data = _service.postprocess(data)
+
+        return data
+    except Exception as e:
+        raise e
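Because the handler only touches `ctx.system_properties` and `ctx.manifest`, it can be smoke-tested without a running TorchServe instance. A rough sketch; `MockContext` is a hypothetical stand-in (not a TorchServe class), and `Transformer_model` must contain the saved model files plus `setup_config.json`:

```python
# Hypothetical local smoke test for the handler above.
import Transformer_handler_generalized as handler_module

class MockContext:
    manifest = {}
    system_properties = {"model_dir": "./Transformer_model", "gpu_id": 0}

payload = [{"data": b'{"question": "Who was Jim Henson?", "context": "Jim Henson was a nice puppet"}'}]
print(handler_module.handle(payload, MockContext()))
```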
diff --git a/examples/Huggingface_Transformers/index_to_name.json b/examples/Huggingface_Transformers/index_to_name.json
new file mode 100644
index 0000000000..9ccff719f6
--- /dev/null
+++ b/examples/Huggingface_Transformers/index_to_name.json
@@ -0,0 +1,4 @@
+{
+ "0":"Not Accepted",
+ "1":"Accepted"
+}
diff --git a/examples/Huggingface_Transformers/sample_text.txt b/examples/Huggingface_Transformers/sample_text.txt
new file mode 100644
index 0000000000..e011e0184b
--- /dev/null
+++ b/examples/Huggingface_Transformers/sample_text.txt
@@ -0,0 +1 @@
+{"question" :"Who was Jim Henson?", "context": "Jim Henson was a nice puppet"}
diff --git a/examples/Huggingface_Transformers/setup_config.json b/examples/Huggingface_Transformers/setup_config.json
new file mode 100644
index 0000000000..c24780961e
--- /dev/null
+++ b/examples/Huggingface_Transformers/setup_config.json
@@ -0,0 +1,6 @@
+{
+ "model_name":"roberta-base",
+ "mode":"question_answering",
+ "do_lower_case":"True",
+ "num_labels":"0"
+}
diff --git a/examples/text_classification/Transformers_README.md b/examples/text_classification/Transformers_README.md
new file mode 100644
index 0000000000..15a92e6668
--- /dev/null
+++ b/examples/text_classification/Transformers_README.md
@@ -0,0 +1,46 @@
+# Serving Transformers BERT using TorchServe
+
+#### Prerequisite:
+
+```
+pip install transformers
+pip install torch torchtext
+pip install torchserve torch-model-archiver
+```
+
+```
+git clone https://github.com/HamidShojanazeri/serve.git
+```
+
+#### Preparing the serialized file for torch-model-archiver:
+
+```
+cd serve/examples/text_classification
+python bert/bert_serialization.py  # outputs the jit-traced model "traced_bert.pt"
+mv traced_bert.pt bert/
+```
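The bert_serialization.py script itself is not included in this diff; a plausible minimal version, assuming a sequence-classification BERT exported with the `torchscript` flag, might look like:

```python
# Sketch only: the real bert/bert_serialization.py is not part of this diff.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", torchscript=True)
model.eval()

# Trace the model with a representative input and save the TorchScript module.
inputs = tokenizer.encode_plus("a sample sentence", return_tensors="pt")
traced_model = torch.jit.trace(model, [inputs["input_ids"]])
torch.jit.save(traced_model, "traced_bert.pt")
```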
+#### Archive the model:
+
+```
+torch-model-archiver --model-name BertSeqClassification --version 1.0 --serialized-file bert/traced_bert.pt --handler bert/bert_handler.py --extra-files ./index_to_name.json
+```
+
+```
+mkdir model_store
+mv BertSeqClassification.mar model_store/
+```
+
+#### Start TorchServe to serve the model:
+
+```
+torchserve --start --model-store model_store --models my_tc=BertSeqClassification.mar
+```
+
+#### Get predictions from the model:
+
+```
+curl -X POST http://127.0.0.1:8080/predictions/my_tc -T ./sample_text.txt
+```
diff --git a/examples/text_classification/sample_text.txt b/examples/text_classification/sample_text.txt
index dccbbf96f9..760903ca07 100644
--- a/examples/text_classification/sample_text.txt
+++ b/examples/text_classification/sample_text.txt
@@ -1 +1 @@
-MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75 at Royal Portrush, which considering the wind and the rain was a respectable showing. Thursday’s first round at the WGC-FedEx St. Jude Invitational was another story. With temperatures in the mid-80s and hardly any wind, the Spaniard was 13 strokes better in a flawless round. Thanks to his best putting performance on the PGA Tour, Rahm finished with an 8-under 62 for a three-stroke lead, which was even more impressive considering he’d never played the front nine at TPC Southwind.
\ No newline at end of file
+Bloomberg has decided to publish a new report on the global economic situation.
diff --git a/examples/text_to_speech_synthesizer/README.md b/examples/text_to_speech_synthesizer/README.md
new file mode 100644
index 0000000000..8fac78d154
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/README.md
@@ -0,0 +1,59 @@
+# Text-to-speech synthesis using the WaveGlow & Tacotron2 models
+
+**This example works only on an NVIDIA CUDA device, not on CPU**
+
+We have used the following WaveGlow/Tacotron2 model for this example:
+
+https://pytorch.org/hub/nvidia_deeplearningexamples_waveglow/
+
+We have copied WaveGlow's model file from the following github repo:
+https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/waveglow/model.py
+
+# Install pip dependencies using the following commands
+
+```bash
+pip install numpy scipy unidecode inflect
+pip install librosa --user
+```
+
+# Serve the WaveGlow speech synthesis model on TorchServe
+
+ * Generate the model archive for the WaveGlow speech synthesis model using the following command
+
+    ```bash
+    ./create_mar.sh
+    ```
+
+ * Register the model on TorchServe using the above model archive file
+
+    ```bash
+    mkdir model_store
+    mv waveglow_synthesizer.mar model_store/
+    torchserve --start --model-store model_store --models waveglow_synthesizer.mar
+    ```
+ * Run inference and download the audio output using curl:
+    ```bash
+    curl -X POST http://127.0.0.1:8080/predictions/waveglow_synthesizer -T sample_text.txt -o audio.wav
+    ```
+
+ * Run inference and download the audio output using a Python script:
+
+    ```python
+    import requests
+
+    files = {'data': open('sample_text.txt','rb')}
+    response = requests.post('http://localhost:8080/predictions/waveglow_synthesizer', files=files)
+    data = response.content
+
+    with open('audio.wav', 'wb') as audio_file:
+        audio_file.write(data)
+    ```
+
+ * Change the host and port in the above samples as per your server configuration.
+
+ * Response:
+   An audio.wav file gets downloaded.
+
+ **Note:** The above example works only for smaller text sizes. Refer to the following NVIDIA/DeepLearningExamples ticket for more details:
+ https://github.com/NVIDIA/DeepLearningExamples/issues/497
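As a quick check on the result (an addition, not part of the example), the downloaded file can be inspected with scipy; 22050 Hz is the rate the handler below writes:

```python
# Verify the synthesized audio (sketch); 22050 Hz matches waveglow_handler.py.
from scipy.io.wavfile import read

rate, samples = read("audio.wav")
print(rate, samples.shape)  # expect 22050 and a 1-D array of samples
```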
diff --git a/examples/text_to_speech_synthesizer/create_mar.sh b/examples/text_to_speech_synthesizer/create_mar.sh
new file mode 100755
index 0000000000..b4c7b6e8ef
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/create_mar.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+set -euxo pipefail
+
+cd /tmp
+rm -f torchhub.zip  # -f so a missing file does not trip `set -e` on the first run
+wget https://github.com/nvidia/DeepLearningExamples/archive/torchhub.zip
+rm -rf DeepLearningExamples-torchhub
+unzip torchhub.zip
+cd -
+rm -f tacotron.zip
+rm -rf PyTorch
+mkdir -p PyTorch/SpeechSynthesis
+cp -r /tmp/DeepLearningExamples-torchhub/PyTorch/SpeechSynthesis/* PyTorch/SpeechSynthesis/
+zip -r tacotron.zip PyTorch
+wget https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2pyt_fp32/versions/1/files/nvidia_tacotron2pyt_fp32_20190306.pth
+wget https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth
+torch-model-archiver --model-name waveglow_synthesizer --version 1.0 --model-file waveglow_model.py --serialized-file nvidia_waveglowpyt_fp32_20190306.pth --handler waveglow_handler.py --extra-files tacotron.zip,nvidia_tacotron2pyt_fp32_20190306.pth
+rm nvidia_*
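A .mar file is a plain zip archive, so the packaged contents can be listed to confirm the extra files made it in; a quick check (an assumption, not part of the script):

```python
# List the files packaged by torch-model-archiver (.mar files are zip archives).
import zipfile
print(zipfile.ZipFile("waveglow_synthesizer.mar").namelist())
```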
diff --git a/examples/text_to_speech_synthesizer/sample_text.txt b/examples/text_to_speech_synthesizer/sample_text.txt
new file mode 100644
index 0000000000..98cfee721f
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/sample_text.txt
@@ -0,0 +1 @@
+hello world, I missed you
\ No newline at end of file
diff --git a/examples/text_to_speech_synthesizer/waveglow_handler.py b/examples/text_to_speech_synthesizer/waveglow_handler.py
new file mode 100644
index 0000000000..88f1b82f88
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/waveglow_handler.py
@@ -0,0 +1,120 @@
+import logging
+import numpy as np
+import os
+import torch
+import uuid
+import zipfile
+from waveglow_model import WaveGlow
+from scipy.io.wavfile import write
+
+logger = logging.getLogger(__name__)
+
+
+class WaveGlowSpeechSynthesizer(object):
+
+    def __init__(self):
+        self.waveglow_model = None
+        self.tacotron2_model = None
+        self.mapping = None
+        self.device = None
+        self.initialized = False
+        self.metrics = None
+
+    # From https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
+    def _unwrap_distributed(self, state_dict):
+        """
+        Unwraps a model state dict from DistributedDataParallel.
+        DDP wraps the model with an additional "module." prefix, which needs to be
+        removed for single-GPU inference.
+        :param state_dict: model's state dict
+        """
+        new_state_dict = {}
+        for key, value in state_dict.items():
+            new_key = key.replace('module.', '')
+            new_state_dict[new_key] = value
+        return new_state_dict
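To illustrate `_unwrap_distributed` concretely, a toy state dict (hypothetical keys, not from the real checkpoint) is renamed like this:

```python
# Toy example: DDP's "module." prefix is stripped from every key.
state = {"module.conv.weight": 1, "module.conv.bias": 2}
print({k.replace("module.", ""): v for k, v in state.items()})
# -> {'conv.weight': 1, 'conv.bias': 2}
```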
+    def _load_tacotron2_model(self, model_dir):
+        from PyTorch.SpeechSynthesis.Tacotron2.tacotron2 import model as tacotron2
+        from PyTorch.SpeechSynthesis.Tacotron2.tacotron2.text import text_to_sequence
+        tacotron2_checkpoint = torch.load(os.path.join(model_dir, 'nvidia_tacotron2pyt_fp32_20190306.pth'))
+        tacotron2_state_dict = self._unwrap_distributed(tacotron2_checkpoint['state_dict'])
+        tacotron2_config = tacotron2_checkpoint['config']
+        self.tacotron2_model = tacotron2.Tacotron2(**tacotron2_config)
+        self.tacotron2_model.load_state_dict(tacotron2_state_dict)
+        self.tacotron2_model.text_to_sequence = text_to_sequence
+        self.tacotron2_model.to(self.device)
+
+    def initialize(self, ctx):
+        """Load the WaveGlow and Tacotron2 checkpoints (eager-mode, state_dict based models)."""
+        properties = ctx.system_properties
+        model_dir = properties.get("model_dir")
+        if not torch.cuda.is_available():
+            raise RuntimeError("This model is not supported on CPU machines.")
+        self.device = torch.device("cuda:" + str(properties.get("gpu_id")))
+
+        # tacotron.zip carries the NVIDIA Tacotron2 sources; extract it so the
+        # PyTorch.SpeechSynthesis imports in _load_tacotron2_model resolve.
+        with zipfile.ZipFile(model_dir + '/tacotron.zip', 'r') as zip_ref:
+            zip_ref.extractall(model_dir)
+
+        waveglow_checkpoint = torch.load(os.path.join(model_dir, "nvidia_waveglowpyt_fp32_20190306.pth"))
+        waveglow_state_dict = self._unwrap_distributed(waveglow_checkpoint['state_dict'])
+        waveglow_config = waveglow_checkpoint['config']
+        self.waveglow_model = WaveGlow(**waveglow_config)
+        self.waveglow_model.load_state_dict(waveglow_state_dict)
+        self.waveglow_model = self.waveglow_model.remove_weightnorm(self.waveglow_model)
+        self.waveglow_model.to(self.device)
+        self.waveglow_model.eval()
+
+        self._load_tacotron2_model(model_dir)
+
+        logger.debug('WaveGlow model file loaded successfully')
+        self.initialized = True
+
+    def preprocess(self, data):
+        """
+        Converts the input text into the Tacotron2 sequence representation
+        and returns it as a tensor on the target device.
+        """
+        text = data[0].get("data")
+        if text is None:
+            text = data[0].get("body")
+        text = text.decode('utf-8')
+
+        sequence = np.array(self.tacotron2_model.text_to_sequence(text, ['english_cleaners']))[None, :]
+        sequence = torch.from_numpy(sequence).to(device=self.device, dtype=torch.int64)
+
+        return sequence
+
+    def inference(self, data):
+        with torch.no_grad():
+            _, mel, _, _ = self.tacotron2_model.infer(data)
+            audio = self.waveglow_model.infer(mel)
+
+        return audio
+
+    def postprocess(self, inference_output):
+        audio_numpy = inference_output[0].data.cpu().numpy()
+        path = "/tmp/{}.wav".format(uuid.uuid4().hex)
+        write(path, 22050, audio_numpy)
+        with open(path, 'rb') as output:
+            data = output.read()
+        os.remove(path)
+        return [data]
+
+
+_service = WaveGlowSpeechSynthesizer()
+
+
+def handle(data, context):
+    if not _service.initialized:
+        _service.initialize(context)
+
+    if data is None:
+        return None
+
+    data = _service.preprocess(data)
+    data = _service.inference(data)
+    data = _service.postprocess(data)
+
+    return data
diff --git a/examples/text_to_speech_synthesizer/waveglow_model.py b/examples/text_to_speech_synthesizer/waveglow_model.py
new file mode 100644
index 0000000000..c799709a87
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/waveglow_model.py
@@ -0,0 +1,292 @@
+# *****************************************************************************
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the +# names of its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# ***************************************************************************** +import torch +from torch.autograd import Variable +import torch.nn.functional as F + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class Invertible1x1Conv(torch.nn.Module): + """ + The layer outputs both the convolution, and the log determinant + of its weight matrix. If reverse=True it does convolution with + inverse + """ + + def __init__(self, c): + super(Invertible1x1Conv, self).__init__() + self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, + bias=False) + + # Sample a random orthonormal matrix to initialize weights + W = torch.qr(torch.FloatTensor(c, c).normal_())[0] + + # Ensure determinant is 1.0 not -1.0 + if torch.det(W) < 0: + W[:, 0] = -1 * W[:, 0] + W = W.view(c, c, 1) + self.conv.weight.data = W + + def forward(self, z, reverse=False): + # shape + batch_size, group_size, n_of_groups = z.size() + + W = self.conv.weight.squeeze() + + if reverse: + if not hasattr(self, 'W_inverse'): + # Reverse computation + W_inverse = W.float().inverse() + W_inverse = Variable(W_inverse[..., None]) + if z.type() == 'torch.cuda.HalfTensor' or z.type() == 'torch.HalfTensor': + W_inverse = W_inverse.half() + self.W_inverse = W_inverse + z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0) + return z + else: + # Forward computation + log_det_W = batch_size * n_of_groups * torch.logdet(W.unsqueeze(0).float()).squeeze() + z = self.conv(z) + return z, log_det_W + + +class WN(torch.nn.Module): + """ + This is the WaveNet like layer for the affine coupling. The primary + difference from WaveNet is the convolutions need not be causal. There is + also no dilation size reset. 
+    The dilation only doubles on each layer.
+    """
+
+    def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
+                 kernel_size):
+        super(WN, self).__init__()
+        assert(kernel_size % 2 == 1)
+        assert(n_channels % 2 == 0)
+        self.n_layers = n_layers
+        self.n_channels = n_channels
+        self.in_layers = torch.nn.ModuleList()
+        self.res_skip_layers = torch.nn.ModuleList()
+        self.cond_layers = torch.nn.ModuleList()
+
+        start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
+        start = torch.nn.utils.weight_norm(start, name='weight')
+        self.start = start
+
+        # Initializing last layer to 0 makes the affine coupling layers
+        # do nothing at first. This helps with training stability
+        end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1)
+        end.weight.data.zero_()
+        end.bias.data.zero_()
+        self.end = end
+
+        for i in range(n_layers):
+            dilation = 2 ** i
+            padding = int((kernel_size * dilation - dilation) / 2)
+            in_layer = torch.nn.Conv1d(n_channels, 2 * n_channels, kernel_size,
+                                       dilation=dilation, padding=padding)
+            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+            self.in_layers.append(in_layer)
+
+            cond_layer = torch.nn.Conv1d(n_mel_channels, 2 * n_channels, 1)
+            cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+            self.cond_layers.append(cond_layer)
+
+            # last one is not necessary
+            if i < n_layers - 1:
+                res_skip_channels = 2 * n_channels
+            else:
+                res_skip_channels = n_channels
+            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
+            res_skip_layer = torch.nn.utils.weight_norm(
+                res_skip_layer, name='weight')
+            self.res_skip_layers.append(res_skip_layer)
+
+    def forward(self, forward_input):
+        audio, spect = forward_input
+        audio = self.start(audio)
+
+        for i in range(self.n_layers):
+            acts = fused_add_tanh_sigmoid_multiply(
+                self.in_layers[i](audio),
+                self.cond_layers[i](spect),
+                torch.IntTensor([self.n_channels]))
+
+            res_skip_acts = self.res_skip_layers[i](acts)
+            if i < self.n_layers - 1:
+                audio = res_skip_acts[:, :self.n_channels, :] + audio
+                skip_acts = res_skip_acts[:, self.n_channels:, :]
+            else:
+                skip_acts = res_skip_acts
+
+            if i == 0:
+                output = skip_acts
+            else:
+                output = skip_acts + output
+        return self.end(output)
+
+
+class WaveGlow(torch.nn.Module):
+    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
+                 n_early_size, WN_config):
+        super(WaveGlow, self).__init__()
+
+        self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
+                                                 n_mel_channels,
+                                                 1024, stride=256)
+        assert(n_group % 2 == 0)
+        self.n_flows = n_flows
+        self.n_group = n_group
+        self.n_early_every = n_early_every
+        self.n_early_size = n_early_size
+        self.WN = torch.nn.ModuleList()
+        self.convinv = torch.nn.ModuleList()
+
+        n_half = int(n_group / 2)
+
+        # Set up layers with the right sizes based on how many dimensions
+        # have been output already
+        n_remaining_channels = n_group
+        for k in range(n_flows):
+            if k % self.n_early_every == 0 and k > 0:
+                n_half = n_half - int(self.n_early_size / 2)
+                n_remaining_channels = n_remaining_channels - self.n_early_size
+            self.convinv.append(Invertible1x1Conv(n_remaining_channels))
+            self.WN.append(WN(n_half, n_mel_channels * n_group, **WN_config))
+        self.n_remaining_channels = n_remaining_channels
+
+    def forward(self, forward_input):
+        """
+        forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
+        forward_input[1] = audio: batch x time
+        """
+        spect, audio = forward_input
+
+        # Upsample spectrogram to size of audio
+        spect = self.upsample(spect)
+        assert(spect.size(2) >= audio.size(1))
+        if spect.size(2) > audio.size(1):
+            spect = spect[:, :, :audio.size(1)]
+
+        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
+        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1)
+        spect = spect.permute(0, 2, 1)
+
+        audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
+        output_audio = []
+        log_s_list = []
+        log_det_W_list = []
+
+        for k in range(self.n_flows):
+            if k % self.n_early_every == 0 and k > 0:
+                output_audio.append(audio[:, :self.n_early_size, :])
+                audio = audio[:, self.n_early_size:, :]
+
+            audio, log_det_W = self.convinv[k](audio)
+            log_det_W_list.append(log_det_W)
+
+            n_half = int(audio.size(1) / 2)
+            audio_0 = audio[:, :n_half, :]
+            audio_1 = audio[:, n_half:, :]
+
+            output = self.WN[k]((audio_0, spect))
+            log_s = output[:, n_half:, :]
+            b = output[:, :n_half, :]
+            audio_1 = torch.exp(log_s) * audio_1 + b
+            log_s_list.append(log_s)
+
+            audio = torch.cat([audio_0, audio_1], 1)
+
+        output_audio.append(audio)
+        return torch.cat(output_audio, 1), log_s_list, log_det_W_list
+
+    def infer(self, spect, sigma=1.0):
+
+        spect = self.upsample(spect)
+        # trim conv artifacts. maybe pad spec to kernel multiple
+        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
+        spect = spect[:, :, :-time_cutoff]
+
+        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
+        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1)
+        spect = spect.permute(0, 2, 1)
+
+        audio = torch.randn(spect.size(0),
+                            self.n_remaining_channels,
+                            spect.size(2), device=spect.device).to(spect.dtype)
+
+        audio = torch.autograd.Variable(sigma * audio)
+
+        for k in reversed(range(self.n_flows)):
+            n_half = int(audio.size(1) / 2)
+            audio_0 = audio[:, :n_half, :]
+            audio_1 = audio[:, n_half:, :]
+
+            output = self.WN[k]((audio_0, spect))
+            s = output[:, n_half:, :]
+            b = output[:, :n_half, :]
+            audio_1 = (audio_1 - b) / torch.exp(s)
+            audio = torch.cat([audio_0, audio_1], 1)
+
+            audio = self.convinv[k](audio, reverse=True)
+
+            if k % self.n_early_every == 0 and k > 0:
+                z = torch.randn(spect.size(0), self.n_early_size, spect.size(
+                    2), device=spect.device).to(spect.dtype)
+                audio = torch.cat((sigma * z, audio), 1)
+
+        audio = audio.permute(
+            0, 2, 1).contiguous().view(
+            audio.size(0), -1).data
+        return audio
+
+    @staticmethod
+    def remove_weightnorm(model):
+        waveglow = model
+        for WN in waveglow.WN:
+            WN.start = torch.nn.utils.remove_weight_norm(WN.start)
+            WN.in_layers = remove(WN.in_layers)
+            WN.cond_layers = remove(WN.cond_layers)
+            WN.res_skip_layers = remove(WN.res_skip_layers)
+        return waveglow
+
+
+def remove(conv_list):
+    new_conv_list = torch.nn.ModuleList()
+    for old_conv in conv_list:
+        old_conv = torch.nn.utils.remove_weight_norm(old_conv)
+        new_conv_list.append(old_conv)
+    return new_conv_list
diff --git a/frontend/build.gradle b/frontend/build.gradle
index 105a3b62aa..8c9bc4d7ba 100644
--- a/frontend/build.gradle
+++ b/frontend/build.gradle
@@ -30,7 +30,7 @@ def javaProjects() {
 }
 
 configure(javaProjects()) {
-    apply plugin: 'java'
+    apply plugin: 'java-library'
 
     sourceCompatibility = 1.8
     targetCompatibility = 1.8
@@ -42,7 +42,7 @@ configure(javaProjects()) {
 
     test {
         useTestNG() {
-            // suiteXmlFiles << new File(rootDir, "testng.xml") //This is how to add custom testng.xml
+            suites 'testng.xml'
        }
 
         testLogging {
diff --git a/frontend/gradle.properties b/frontend/gradle.properties
index 97d3803c09..8c7e8ceab5 100644
--- a/frontend/gradle.properties
+++ b/frontend/gradle.properties
@@ -5,5 +5,5 @@ slf4j_api_version=1.7.25
slf4j_log4j12_version=1.7.25 gson_version=2.8.5 commons_cli_version=1.3.1 -testng_version=6.8.1 -torchserve_sdk_version=0.0.3 \ No newline at end of file +testng_version=7.1.0 +torchserve_sdk_version=0.0.3 diff --git a/frontend/gradle/wrapper/gradle-wrapper.jar b/frontend/gradle/wrapper/gradle-wrapper.jar index 6ffa237849..62d4c05355 100644 Binary files a/frontend/gradle/wrapper/gradle-wrapper.jar and b/frontend/gradle/wrapper/gradle-wrapper.jar differ diff --git a/frontend/gradle/wrapper/gradle-wrapper.properties b/frontend/gradle/wrapper/gradle-wrapper.properties index 57c147560f..4c5803d13c 100644 --- a/frontend/gradle/wrapper/gradle-wrapper.properties +++ b/frontend/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,5 @@ -#Thu Apr 13 16:20:04 PDT 2017 distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-6.4-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.2.2-bin.zip diff --git a/frontend/gradlew b/frontend/gradlew index 9aa616c273..fbd7c51583 100755 --- a/frontend/gradlew +++ b/frontend/gradlew @@ -1,4 +1,20 @@ -#!/usr/bin/env bash +#!/usr/bin/env sh + +# +# Copyright 2015 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ############################################################################## ## @@ -28,16 +44,16 @@ APP_NAME="Gradle" APP_BASE_NAME=`basename "$0"` # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS="" +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD="maximum" -warn ( ) { +warn () { echo "$*" } -die ( ) { +die () { echo echo "$*" echo @@ -66,6 +82,7 @@ esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + # Determine the Java command to use to start the JVM. 
if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then @@ -109,10 +126,11 @@ if $darwin; then GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" fi -# For Cygwin, switch paths to Windows format before running java -if $cygwin ; then +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then APP_HOME=`cygpath --path --mixed "$APP_HOME"` CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` # We build the pattern for arguments to be converted via cygpath @@ -138,32 +156,30 @@ if $cygwin ; then else eval `echo args$i`="\"$arg\"" fi - i=$((i+1)) + i=`expr $i + 1` done case $i in - (0) set -- ;; - (1) set -- "$args0" ;; - (2) set -- "$args0" "$args1" ;; - (3) set -- "$args0" "$args1" "$args2" ;; - (4) set -- "$args0" "$args1" "$args2" "$args3" ;; - (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + 0) set -- ;; + 1) set -- "$args0" ;; + 2) set -- "$args0" "$args1" ;; + 3) set -- "$args0" "$args1" "$args2" ;; + 4) set -- "$args0" "$args1" "$args2" "$args3" ;; + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; esac fi -# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules -function splitJvmOpts() { - JVM_OPTS=("$@") +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " } -eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS -JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" +APP_ARGS=`save "$@"` -# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong -if [[ "$(uname)" == "Darwin" ]] && [[ "$HOME" == "$PWD" ]]; then - cd "$(dirname "$0")" -fi +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" -exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" +exec "$JAVACMD" "$@" diff --git a/frontend/gradlew.bat b/frontend/gradlew.bat index e95643d6a2..a9f778a7a9 100755 --- a/frontend/gradlew.bat +++ b/frontend/gradlew.bat @@ -1,3 +1,19 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. 
+@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + @if "%DEBUG%" == "" @echo off @rem ########################################################################## @rem @@ -13,8 +29,11 @@ if "%DIRNAME%" == "" set DIRNAME=. set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS= +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" @rem Find java.exe if defined JAVA_HOME goto findJavaFromJavaHome @@ -65,6 +84,7 @@ set CMD_LINE_ARGS=%* set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + @rem Execute Gradle "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% diff --git a/frontend/modelarchive/build.gradle b/frontend/modelarchive/build.gradle index b080f62765..b8d6e992a4 100644 --- a/frontend/modelarchive/build.gradle +++ b/frontend/modelarchive/build.gradle @@ -1,9 +1,9 @@ dependencies { - compile "commons-io:commons-io:2.6" - compile "org.slf4j:slf4j-api:${slf4j_api_version}" - compile "org.slf4j:slf4j-log4j12:${slf4j_log4j12_version}" - compile "com.google.code.gson:gson:${gson_version}" + api "commons-io:commons-io:2.6" + api "org.slf4j:slf4j-api:${slf4j_api_version}" + api "org.slf4j:slf4j-log4j12:${slf4j_log4j12_version}" + api "com.google.code.gson:gson:${gson_version}" - testCompile "commons-cli:commons-cli:${commons_cli_version}" - testCompile "org.testng:testng:${testng_version}" + testImplementation "commons-cli:commons-cli:${commons_cli_version}" + testImplementation "org.testng:testng:${testng_version}" } diff --git a/frontend/modelarchive/testng.xml b/frontend/modelarchive/testng.xml new file mode 100644 index 0000000000..b34b60ac4b --- /dev/null +++ b/frontend/modelarchive/testng.xml @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/frontend/server/build.gradle b/frontend/server/build.gradle index 75b3a80aa9..6f4d114340 100644 --- a/frontend/server/build.gradle +++ b/frontend/server/build.gradle @@ -1,9 +1,9 @@ dependencies { - compile "io.netty:netty-all:${netty_version}" - compile project(":modelarchive") - compile "commons-cli:commons-cli:${commons_cli_version}" - compile "org.pytorch:torchserve-plugins-sdk:${torchserve_sdk_version}" - testCompile "org.testng:testng:${testng_version}" + implementation "io.netty:netty-all:${netty_version}" + implementation project(":modelarchive") + implementation "commons-cli:commons-cli:${commons_cli_version}" + implementation "org.pytorch:torchserve-plugins-sdk:${torchserve_sdk_version}" + testImplementation "org.testng:testng:${testng_version}" } apply from: file("${project.rootProject.projectDir}/tools/gradle/launcher.gradle") @@ -13,7 +13,7 @@ jar { attributes 'Main-Class': 'org.pytorch.serve.ModelServer' } includeEmptyDirs = false - from configurations.runtime.collect { it.isDirectory() ? it : zipTree(it) } + from configurations.runtimeClasspath.collect { it.isDirectory() ? 
it : zipTree(it) } exclude "META-INF/maven/**" exclude "META-INF/INDEX.LIST" diff --git a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java index 3a67719f8f..bc8cafb599 100644 --- a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java +++ b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java @@ -44,7 +44,6 @@ import org.testng.annotations.Test; public class ModelServerTest { - private static final String ERROR_NOT_FOUND = "Requested resource is not found, please refer to API document."; private static final String ERROR_METHOD_NOT_ALLOWED = @@ -92,157 +91,8 @@ public void afterSuite() { } @Test - public void test() - throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException, - IOException, NoSuchFieldException, IllegalAccessException { - Channel channel = null; - Channel managementChannel = null; - for (int i = 0; i < 5; ++i) { - channel = TestUtils.connect(false, configManager); - if (channel != null) { - break; - } - Thread.sleep(100); - } - - for (int i = 0; i < 5; ++i) { - managementChannel = TestUtils.connect(true, configManager); - if (managementChannel != null) { - break; - } - Thread.sleep(100); - } - - Assert.assertNotNull(channel, "Failed to connect to inference port."); - Assert.assertNotNull(managementChannel, "Failed to connect to management port."); - - testPing(channel); - - testRoot(channel, listInferenceApisResult); - testRoot(managementChannel, listManagementApisResult); - testApiDescription(channel, listInferenceApisResult); - testDescribeApi(channel); - testUnregisterModel(managementChannel, "noop", null); - testLoadModel(managementChannel, "noop.mar", "noop_v1.0"); - testSyncScaleModel(managementChannel, "noop_v1.0", null); - testListModels(managementChannel); - testDescribeModel(managementChannel, "noop_v1.0", null, "1.11"); - testLoadModelWithInitialWorkers(managementChannel, "noop.mar", "noop"); - testLoadModelWithInitialWorkers(managementChannel, "noop.mar", "noopversioned"); - testLoadModelWithInitialWorkers(managementChannel, "noop_v2.mar", "noopversioned"); - testDescribeModel(managementChannel, "noopversioned", null, "1.11"); - testDescribeModel(managementChannel, "noopversioned", "all", "1.2.1"); - testDescribeModel(managementChannel, "noopversioned", "1.11", "1.11"); - testPredictions(channel, "noopversioned", "OK", "1.2.1"); - testSetDefault(managementChannel, "noopversioned", "1.2.1"); - testLoadModelWithInitialWorkersWithJSONReqBody(managementChannel); - testScaleModel(managementChannel); - testPredictions(channel, "noop", "OK", null); - testPredictionsBinary(channel); - testPredictionsJson(channel); - testInvocationsJson(channel); - testInvocationsMultipart(channel); - testModelsInvokeJson(channel); - testModelsInvokeMultipart(channel); - testLegacyPredict(channel); - testPredictionsInvalidRequestSize(channel); - testPredictionsValidRequestSize(channel); - testPredictionsDecodeRequest(channel, managementChannel); - testPredictionsDoNotDecodeRequest(channel, managementChannel); - testPredictionsModifyResponseHeader(channel, managementChannel); - testPredictionsNoManifest(channel, managementChannel); - testModelRegisterWithDefaultWorkers(managementChannel); - testLoadModelFromURL(managementChannel); - testUnregisterURLModel(managementChannel); - testLoadingMemoryError(); - testPredictionMemoryError(); - testMetricManager(); - testErrorBatch(); - - channel.close(); - managementChannel.close(); - - // negative test 
case, channel will be closed by server - testInvalidRootRequest(); - testInvalidInferenceUri(); - testInvalidPredictionsUri(); - testInvalidDescribeModel(); - testPredictionsModelNotFound(); - testPredictionsModelVersionNotFound(); - - testInvalidManagementUri(); - testInvalidModelsMethod(); - testInvalidModelMethod(); - testDescribeModelNotFound(); - testDescribeModelVersionNotFound(); - testRegisterModelMissingUrl(); - testRegisterModelInvalidRuntime(); - testRegisterModelNotFound(); - testRegisterModelConflict(); - testRegisterModelMalformedUrl(); - testRegisterModelConnectionFailed(); - testRegisterModelHttpError(); - testRegisterModelInvalidPath(); - testScaleModelNotFound(); - testScaleModelVersionNotFound(); - testScaleModelFailure(); - testUnregisterModelNotFound(); - testUnregisterModelVersionNotFound(); - testUnregisterModelTimeout(); - testSetInvalidVersionDefault("noopversioned", "3.3.3"); - testUnregisterModelFailure("noopversioned", "1.2.1"); - - testTS(); - } - - public void testTS() - throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException, - IOException, NoSuchFieldException, IllegalAccessException { - Channel channel = null; - Channel managementChannel = null; - for (int i = 0; i < 5; ++i) { - channel = TestUtils.connect(false, configManager, 300); - if (channel != null) { - break; - } - Thread.sleep(100); - } - - for (int i = 0; i < 5; ++i) { - managementChannel = TestUtils.connect(true, configManager, 300); - if (managementChannel != null) { - break; - } - Thread.sleep(100); - } - - Assert.assertNotNull(channel, "Failed to connect to inference port."); - Assert.assertNotNull(managementChannel, "Failed to connect to management port."); - - testLoadModelWithInitialWorkers(managementChannel, "mnist.mar", "mnist"); - testPredictions(channel, "mnist", "0", null); - testUnregisterModel(managementChannel, "mnist", null); - testLoadModelWithInitialWorkers(managementChannel, "mnist_scripted.mar", "mnist_scripted"); - testPredictions(channel, "mnist_scripted", "0", null); - testUnregisterModel(managementChannel, "mnist_scripted", null); - testLoadModelWithInitialWorkers(managementChannel, "mnist_traced.mar", "mnist_traced"); - testPredictions(channel, "mnist_traced", "0", null); - testUnregisterModel(managementChannel, "mnist_traced", null); - - channel.close(); - managementChannel.close(); - } - - private void testRoot(Channel channel, String expected) throws InterruptedException { - TestUtils.setResult(null); - TestUtils.setLatch(new CountDownLatch(1)); - TestUtils.getRoot(channel); - TestUtils.getLatch().await(); - - Assert.assertEquals(TestUtils.getResult(), expected); - } - - private void testPing(Channel channel) throws InterruptedException { + public void testPing() throws InterruptedException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/ping"); @@ -254,259 +104,240 @@ private void testPing(Channel channel) throws InterruptedException { Assert.assertTrue(TestUtils.getHeaders().contains("x-request-id")); } - private void testApiDescription(Channel channel, String expected) throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testPing"}) + public void testRootInference() throws InterruptedException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - 
TestUtils.getApiDescription(channel); + TestUtils.getRoot(channel); TestUtils.getLatch().await(); - Assert.assertEquals(TestUtils.getResult(), expected); + Assert.assertEquals(TestUtils.getResult(), listInferenceApisResult); } - private void testDescribeApi(Channel channel) throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testRootInference"}) + public void testRootManagement() throws InterruptedException { + Channel channel = TestUtils.getManagementChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - TestUtils.describeModelApi(channel, "noop"); + TestUtils.getRoot(channel); TestUtils.getLatch().await(); - Assert.assertEquals(TestUtils.getResult(), noopApiResult); + Assert.assertEquals(TestUtils.getResult(), listManagementApisResult); } - private void testLoadModel(Channel channel, String url, String modelName) - throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testRootManagement"}) + public void testApiDescription() throws InterruptedException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - TestUtils.registerModel(channel, url, modelName, false, false); + TestUtils.getApiDescription(channel); TestUtils.getLatch().await(); - StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); - Assert.assertEquals(resp.getStatus(), "Model \"" + modelName + "\" registered"); + Assert.assertEquals(TestUtils.getResult(), listInferenceApisResult); } - private void testLoadModelFromURL(Channel channel) throws InterruptedException { - testLoadModel( - channel, - "https://torchserve.s3.amazonaws.com/mar_files/squeezenet1_1.mar", - "squeezenet"); - Assert.assertTrue(new File(configManager.getModelStore(), "squeezenet1_1.mar").exists()); - } - - private void testUnregisterURLModel(Channel channel) throws InterruptedException { - testUnregisterModel(channel, "squeezenet", null); - Assert.assertTrue(!new File(configManager.getModelStore(), "squeezenet1_1.mar").exists()); - } - - private void testLoadModelWithInitialWorkers(Channel channel, String url, String modelName) - throws InterruptedException { - + @Test( + alwaysRun = true, + dependsOnMethods = {"testApiDescription"}) + public void testDescribeApi() throws InterruptedException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - TestUtils.registerModel(channel, url, modelName, true, false); + TestUtils.describeModelApi(channel, "noop"); TestUtils.getLatch().await(); - StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); - Assert.assertEquals(resp.getStatus(), "Workers scaled"); + Assert.assertEquals(TestUtils.getResult(), noopApiResult); } - private void testLoadModelWithInitialWorkersWithJSONReqBody(Channel channel) - throws InterruptedException { - testUnregisterModel(channel, "noop", null); - TestUtils.setResult(null); - TestUtils.setLatch(new CountDownLatch(1)); - DefaultFullHttpRequest req = - new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/models"); - req.headers().add("Content-Type", "application/json"); - req.content() - .writeCharSequence( - "{'url':'noop.mar', 'model_name':'noop', 'initial_workers':'1', 'synchronous':'true'}", - CharsetUtil.UTF_8); - HttpUtil.setContentLength(req, req.content().readableBytes()); - channel.writeAndFlush(req); - 
TestUtils.getLatch().await(); - - StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); - Assert.assertEquals(resp.getStatus(), "Workers scaled"); + @Test( + alwaysRun = true, + dependsOnMethods = {"testDescribeApi"}) + public void testUnregisterNoopModel() throws InterruptedException { + testUnregisterModel("noop", null); } - private void testScaleModel(Channel channel) throws InterruptedException { - TestUtils.setResult(null); - TestUtils.setLatch(new CountDownLatch(1)); - TestUtils.scaleModel(channel, "noop_v1.0", null, 2, false); - TestUtils.getLatch().await(); - - StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); - Assert.assertEquals(resp.getStatus(), "Processing worker updates..."); + @Test( + alwaysRun = true, + dependsOnMethods = {"testUnregisterNoopModel"}) + public void testLoadNoopModel() throws InterruptedException { + testLoadModel("noop.mar", "noop_v1.0"); } - private void testSyncScaleModel(Channel channel, String modelName, String version) - throws InterruptedException { - TestUtils.setResult(null); - TestUtils.setLatch(new CountDownLatch(1)); - TestUtils.scaleModel(channel, modelName, version, 1, true); - - TestUtils.getLatch().await(); - - StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); - Assert.assertEquals(resp.getStatus(), "Workers scaled"); + @Test( + alwaysRun = true, + dependsOnMethods = {"testLoadNoopModel"}) + public void testSyncScaleNoopModel() throws InterruptedException { + testSyncScaleModel("noop_v1.0", null); } - private void testUnregisterModel(Channel channel, String modelName, String version) - throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testSyncScaleNoopModel"}) + public void testListModels() throws InterruptedException { + Channel channel = TestUtils.getManagementChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - TestUtils.unregisterModel(channel, modelName, version, false); + TestUtils.listModels(channel); TestUtils.getLatch().await(); - StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); - Assert.assertEquals(resp.getStatus(), "Model \"" + modelName + "\" unregistered"); + ListModelsResponse resp = + JsonUtils.GSON.fromJson(TestUtils.getResult(), ListModelsResponse.class); + Assert.assertEquals(resp.getModels().size(), 1); } - private void testUnregisterModelFailure(String modelName, String version) - throws InterruptedException { - Channel channel = TestUtils.connect(true, configManager); - Assert.assertNotNull(channel); - TestUtils.setResult(null); - TestUtils.setLatch(new CountDownLatch(1)); - TestUtils.unregisterModel(channel, modelName, version, false); - TestUtils.getLatch().await(); + @Test( + alwaysRun = true, + dependsOnMethods = {"testListModels"}) + public void testDescribeNoopModel() throws InterruptedException { + testDescribeModel("noop_v1.0", null, "1.11"); + } - ErrorResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), ErrorResponse.class); - Assert.assertEquals(resp.getCode(), HttpResponseStatus.FORBIDDEN.code()); - Assert.assertEquals( - resp.getMessage(), "Cannot remove default version for model " + modelName); + @Test( + alwaysRun = true, + dependsOnMethods = {"testDescribeNoopModel"}) + public void testLoadNoopModelWithInitialWorkers() throws InterruptedException { + testLoadModelWithInitialWorkers("noop.mar", "noop"); + } - channel = TestUtils.connect(true, configManager); - 
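The registerModel/scaleModel helpers used throughout this file drive TorchServe's management API, which is plain HTTP. A sketch of what a synchronous register with one initial worker presumably looks like on the wire (TestUtils builds the query internally; this helper is not part of the diff):

    import io.netty.channel.Channel;
    import io.netty.handler.codec.http.DefaultFullHttpRequest;
    import io.netty.handler.codec.http.HttpMethod;
    import io.netty.handler.codec.http.HttpVersion;

    final class RegisterSketch {
        // POST /models with the registration parameters in the query string.
        // With initial_workers and synchronous=true the server replies
        // "Workers scaled" only once the workers are actually up, which is
        // exactly the status string the tests above assert on.
        static void registerNoop(Channel managementChannel) {
            DefaultFullHttpRequest req =
                    new DefaultFullHttpRequest(
                            HttpVersion.HTTP_1_1,
                            HttpMethod.POST,
                            "/models?url=noop.mar&model_name=noop_v1.0"
                                    + "&initial_workers=1&synchronous=true");
            managementChannel.writeAndFlush(req);
        }
    }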
Assert.assertNotNull(channel); - testUnregisterModel(channel, "noopversioned", "1.11"); - testUnregisterModel(channel, "noopversioned", "1.2.1"); + @Test( + alwaysRun = true, + dependsOnMethods = {"testLoadNoopModelWithInitialWorkers"}) + public void testLoadNoopV1ModelWithInitialWorkers() throws InterruptedException { + testLoadModelWithInitialWorkers("noop.mar", "noopversioned"); } - private void testListModels(Channel channel) throws InterruptedException { - TestUtils.setResult(null); - TestUtils.setLatch(new CountDownLatch(1)); - TestUtils.listModels(channel); - TestUtils.getLatch().await(); + @Test( + alwaysRun = true, + dependsOnMethods = {"testLoadNoopV1ModelWithInitialWorkers"}) + public void testLoadNoopV2ModelWithInitialWorkers() throws InterruptedException { + testLoadModelWithInitialWorkers("noop_v2.mar", "noopversioned"); + } - ListModelsResponse resp = - JsonUtils.GSON.fromJson(TestUtils.getResult(), ListModelsResponse.class); - Assert.assertEquals(resp.getModels().size(), 1); + @Test( + alwaysRun = true, + dependsOnMethods = {"testLoadNoopV2ModelWithInitialWorkers"}) + public void testDescribeDefaultModelVersion() throws InterruptedException { + testDescribeModel("noopversioned", null, "1.11"); } - private void testDescribeModel( - Channel channel, String modelName, String requestVersion, String expectedVersion) - throws InterruptedException { - TestUtils.setResult(null); - TestUtils.setLatch(new CountDownLatch(1)); - TestUtils.describeModel(channel, modelName, requestVersion); - TestUtils.getLatch().await(); + @Test( + alwaysRun = true, + dependsOnMethods = {"testDescribeDefaultModelVersion"}) + public void testDescribeAllModelVersion() throws InterruptedException { + testDescribeModel("noopversioned", "all", "1.2.1"); + } - DescribeModelResponse[] resp = - JsonUtils.GSON.fromJson(TestUtils.getResult(), DescribeModelResponse[].class); - if ("all".equals(requestVersion)) { - Assert.assertTrue(resp.length >= 1); - } else { - Assert.assertTrue(resp.length == 1); - } + @Test( + alwaysRun = true, + dependsOnMethods = {"testDescribeAllModelVersion"}) + public void testDescribeSpecificModelVersion() throws InterruptedException { + testDescribeModel("noopversioned", "1.11", "1.11"); + } - Assert.assertTrue(expectedVersion.equals(resp[0].getModelVersion())); + @Test( + alwaysRun = true, + dependsOnMethods = {"testDescribeSpecificModelVersion"}) + public void testNoopVersionedPrediction() throws InterruptedException { + testPredictions("noopversioned", "OK", "1.11"); } - private void testSetDefault(Channel channel, String modelName, String defaultVersion) - throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testNoopVersionedPrediction"}) + public void testSetDefaultVersionNoop() throws InterruptedException { + Channel channel = TestUtils.getManagementChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - String requestURL = "/models/" + modelName + "/" + defaultVersion + "/set-default"; - - HttpRequest req = - new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, requestURL); - channel.writeAndFlush(req); + TestUtils.setDefault(channel, "noopversioned", "1.2.1"); TestUtils.getLatch().await(); StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); Assert.assertEquals( resp.getStatus(), - "Default vesion succsesfully updated for model \"" - + modelName - + "\" to \"" - + defaultVersion - + "\""); + "Default vesion succsesfully updated for model \"noopversioned\" to 
\"1.2.1\""); } - private void testSetInvalidVersionDefault(String modelName, String defaultVersion) - throws InterruptedException { - Channel channel = TestUtils.connect(true, configManager); - Assert.assertNotNull(channel); + @Test( + alwaysRun = true, + dependsOnMethods = {"testSetDefaultVersionNoop"}) + public void testLoadModelWithInitialWorkersWithJSONReqBody() throws InterruptedException { + Channel channel = TestUtils.getManagementChannel(configManager); + testUnregisterModel("noop", null); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - String requestURL = "/models/" + modelName + "/" + defaultVersion + "/set-default"; - - HttpRequest req = - new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, requestURL); + DefaultFullHttpRequest req = + new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/models"); + req.headers().add("Content-Type", "application/json"); + req.content() + .writeCharSequence( + "{'url':'noop.mar', 'model_name':'noop', 'initial_workers':'1', 'synchronous':'true'}", + CharsetUtil.UTF_8); + HttpUtil.setContentLength(req, req.content().readableBytes()); channel.writeAndFlush(req); TestUtils.getLatch().await(); - ErrorResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), ErrorResponse.class); - Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code()); - Assert.assertEquals( - resp.getMessage(), - "Model version " + defaultVersion + " does not exist for model " + modelName); + StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); + Assert.assertEquals(resp.getStatus(), "Workers scaled"); } - private void testPredictions( - Channel channel, String modelName, String expectedOutput, String version) - throws InterruptedException { - TestUtils.setResult(null); - TestUtils.setLatch(new CountDownLatch(1)); - String requestURL = "/predictions/" + modelName; - if (version != null) { - requestURL += "/" + version; - } - DefaultFullHttpRequest req = - new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, requestURL); - req.content().writeCharSequence("data=test", CharsetUtil.UTF_8); - HttpUtil.setContentLength(req, req.content().readableBytes()); - req.headers() - .set( - HttpHeaderNames.CONTENT_TYPE, - HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED); - channel.writeAndFlush(req); - - TestUtils.getLatch().await(); - Assert.assertEquals(TestUtils.getResult(), expectedOutput); + @Test( + alwaysRun = true, + dependsOnMethods = {"testLoadModelWithInitialWorkersWithJSONReqBody"}) + public void testNoopPrediction() throws InterruptedException { + testPredictions("noop", "OK", null); } - private void testPredictionsJson(Channel channel) throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testNoopPrediction"}) + public void testPredictionsBinary() throws InterruptedException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); DefaultFullHttpRequest req = new DefaultFullHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop"); - req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8); + req.content().writeCharSequence("test", CharsetUtil.UTF_8); HttpUtil.setContentLength(req, req.content().readableBytes()); - req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); + req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM); channel.writeAndFlush(req); 
TestUtils.getLatch().await(); + Assert.assertEquals(TestUtils.getResult(), "OK"); } - private void testPredictionsBinary(Channel channel) throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testPredictionsBinary"}) + public void testPredictionsJson() throws InterruptedException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); DefaultFullHttpRequest req = new DefaultFullHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop"); - req.content().writeCharSequence("test", CharsetUtil.UTF_8); + req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8); HttpUtil.setContentLength(req, req.content().readableBytes()); - req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM); + req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); channel.writeAndFlush(req); TestUtils.getLatch().await(); - Assert.assertEquals(TestUtils.getResult(), "OK"); } - private void testInvocationsJson(Channel channel) throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testPredictionsJson"}) + public void testInvocationsJson() throws InterruptedException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); DefaultFullHttpRequest req = @@ -521,9 +352,13 @@ private void testInvocationsJson(Channel channel) throws InterruptedException { Assert.assertEquals(TestUtils.getResult(), "OK"); } - private void testInvocationsMultipart(Channel channel) + @Test( + alwaysRun = true, + dependsOnMethods = {"testInvocationsJson"}) + public void testInvocationsMultipart() throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException, IOException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); DefaultFullHttpRequest req = @@ -546,7 +381,11 @@ private void testInvocationsMultipart(Channel channel) Assert.assertEquals(TestUtils.getResult(), "OK"); } - private void testModelsInvokeJson(Channel channel) throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testInvocationsMultipart"}) + public void testModelsInvokeJson() throws InterruptedException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); DefaultFullHttpRequest req = @@ -561,9 +400,13 @@ private void testModelsInvokeJson(Channel channel) throws InterruptedException { Assert.assertEquals(TestUtils.getResult(), "OK"); } - private void testModelsInvokeMultipart(Channel channel) + @Test( + alwaysRun = true, + dependsOnMethods = {"testModelsInvokeJson"}) + public void testModelsInvokeMultipart() throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException, IOException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); DefaultFullHttpRequest req = @@ -586,7 +429,27 @@ private void testModelsInvokeMultipart(Channel channel) Assert.assertEquals(TestUtils.getResult(), "OK"); } - private void testPredictionsInvalidRequestSize(Channel channel) throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testModelsInvokeMultipart"}) + public void testLegacyPredict() throws InterruptedException { + Channel channel = 
TestUtils.getInferenceChannel(configManager); + TestUtils.setResult(null); + TestUtils.setLatch(new CountDownLatch(1)); + DefaultFullHttpRequest req = + new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "/noop/predict?data=test"); + channel.writeAndFlush(req); + + TestUtils.getLatch().await(); + Assert.assertEquals(TestUtils.getResult(), "OK"); + } + + @Test( + alwaysRun = true, + dependsOnMethods = {"testLegacyPredict"}) + public void testPredictionsInvalidRequestSize() throws InterruptedException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); DefaultFullHttpRequest req = @@ -603,7 +466,11 @@ private void testPredictionsInvalidRequestSize(Channel channel) throws Interrupt Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); } - private void testPredictionsValidRequestSize(Channel channel) throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testPredictionsInvalidRequestSize"}) + public void testPredictionsValidRequestSize() throws InterruptedException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); DefaultFullHttpRequest req = @@ -620,35 +487,121 @@ private void testPredictionsValidRequestSize(Channel channel) throws Interrupted Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK); } - private void loadTests(Channel channel, String model, String modelName) - throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testPredictionsValidRequestSize"}) + public void testPredictionsDecodeRequest() + throws InterruptedException, NoSuchFieldException, IllegalAccessException { + Channel inferChannel = TestUtils.getInferenceChannel(configManager); + Channel mgmtChannel = TestUtils.getManagementChannel(configManager); + setConfiguration("decode_input_request", "true"); + loadTests(mgmtChannel, "noop-v1.0-config-tests.mar", "noop-config"); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - TestUtils.registerModel(channel, model, modelName, true, false); + DefaultFullHttpRequest req = + new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop-config"); + req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8); + HttpUtil.setContentLength(req, req.content().readableBytes()); + req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); + inferChannel.writeAndFlush(req); TestUtils.getLatch().await(); + + Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK); + Assert.assertFalse(TestUtils.getResult().contains("bytearray")); + unloadTests(mgmtChannel, "noop-config"); } - private void unloadTests(Channel channel, String modelName) throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testPredictionsDecodeRequest"}) + public void testPredictionsDoNotDecodeRequest() + throws InterruptedException, NoSuchFieldException, IllegalAccessException { + Channel inferChannel = TestUtils.getInferenceChannel(configManager); + Channel mgmtChannel = TestUtils.getManagementChannel(configManager); + setConfiguration("decode_input_request", "false"); + loadTests(mgmtChannel, "noop-v1.0-config-tests.mar", "noop-config"); + TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - String expected = "Model \"" + modelName + "\" unregistered"; - 
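The two decode tests flip decode_input_request and bracket each run with loadTests/unloadTests. One caveat with the pattern as written: if an assertion fails between the load and the unload, the model stays registered and can poison later tests. A try/finally keeps the cleanup unconditional (sketch only, not part of this diff):

    loadTests(mgmtChannel, "noop-v1.0-config-tests.mar", "noop-config");
    try {
        // ... send the prediction and assert on the echoed payload ...
    } finally {
        unloadTests(mgmtChannel, "noop-config"); // runs even when an assertion trips
    }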
TestUtils.unregisterModel(channel, modelName, null, false); + DefaultFullHttpRequest req = + new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop-config"); + req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8); + HttpUtil.setContentLength(req, req.content().readableBytes()); + req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); + inferChannel.writeAndFlush(req); + TestUtils.getLatch().await(); - StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); - Assert.assertEquals(resp.getStatus(), expected); + + Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK); + Assert.assertTrue(TestUtils.getResult().contains("bytearray")); + unloadTests(mgmtChannel, "noop-config"); } - private void setConfiguration(String key, String val) - throws NoSuchFieldException, IllegalAccessException { - Field f = configManager.getClass().getDeclaredField("prop"); - f.setAccessible(true); - Properties p = (Properties) f.get(configManager); - p.setProperty(key, val); + @Test( + alwaysRun = true, + dependsOnMethods = {"testPredictionsDoNotDecodeRequest"}) + public void testPredictionsModifyResponseHeader() + throws NoSuchFieldException, IllegalAccessException, InterruptedException { + Channel inferChannel = TestUtils.getInferenceChannel(configManager); + Channel mgmtChannel = TestUtils.getManagementChannel(configManager); + setConfiguration("decode_input_request", "false"); + loadTests(mgmtChannel, "respheader-test.mar", "respheader"); + + TestUtils.setResult(null); + TestUtils.setLatch(new CountDownLatch(1)); + DefaultFullHttpRequest req = + new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/respheader"); + + req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8); + HttpUtil.setContentLength(req, req.content().readableBytes()); + req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); + inferChannel.writeAndFlush(req); + + TestUtils.getLatch().await(); + + Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK); + Assert.assertEquals(TestUtils.getHeaders().get("dummy"), "1"); + Assert.assertEquals(TestUtils.getHeaders().get("content-type"), "text/plain"); + Assert.assertTrue(TestUtils.getResult().contains("bytearray")); + unloadTests(mgmtChannel, "respheader"); } - private void testModelRegisterWithDefaultWorkers(Channel mgmtChannel) + @Test( + alwaysRun = true, + dependsOnMethods = {"testPredictionsModifyResponseHeader"}) + public void testPredictionsNoManifest() + throws InterruptedException, NoSuchFieldException, IllegalAccessException { + Channel inferChannel = TestUtils.getInferenceChannel(configManager); + Channel mgmtChannel = TestUtils.getManagementChannel(configManager); + setConfiguration("default_service_handler", "service:handle"); + loadTests(mgmtChannel, "noop-no-manifest.mar", "nomanifest"); + TestUtils.setResult(null); + TestUtils.setLatch(new CountDownLatch(1)); + DefaultFullHttpRequest req = + new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/nomanifest"); + req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8); + HttpUtil.setContentLength(req, req.content().readableBytes()); + req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); + inferChannel.writeAndFlush(req); + + TestUtils.getLatch().await(); + + Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK); + 
Assert.assertEquals(TestUtils.getResult(), "OK"); + unloadTests(mgmtChannel, "nomanifest"); + } + + @Test( + alwaysRun = true, + dependsOnMethods = {"testPredictionsNoManifest"}) + public void testModelRegisterWithDefaultWorkers() throws NoSuchFieldException, IllegalAccessException, InterruptedException { + Channel mgmtChannel = TestUtils.getManagementChannel(configManager); setConfiguration("default_workers_per_model", "1"); loadTests(mgmtChannel, "noop.mar", "noop_default_model_workers"); @@ -666,109 +619,161 @@ private void testModelRegisterWithDefaultWorkers(Channel mgmtChannel) setConfiguration("default_workers_per_model", "0"); } - private void testPredictionsDecodeRequest(Channel inferChannel, Channel mgmtChannel) - throws InterruptedException, NoSuchFieldException, IllegalAccessException { - setConfiguration("decode_input_request", "true"); - loadTests(mgmtChannel, "noop-v1.0-config-tests.mar", "noop-config"); + @Test( + alwaysRun = true, + dependsOnMethods = {"testModelRegisterWithDefaultWorkers"}) + public void testLoadModelFromURL() throws InterruptedException { + testLoadModel( + "https://torchserve.s3.amazonaws.com/mar_files/squeezenet1_1.mar", "squeezenet"); + Assert.assertTrue(new File(configManager.getModelStore(), "squeezenet1_1.mar").exists()); + } + + @Test( + alwaysRun = true, + dependsOnMethods = {"testLoadModelFromURL"}) + public void testUnregisterURLModel() throws InterruptedException { + testUnregisterModel("squeezenet", null); + Assert.assertTrue(!new File(configManager.getModelStore(), "squeezenet1_1.mar").exists()); + } + + @Test( + alwaysRun = true, + dependsOnMethods = {"testUnregisterURLModel"}) + public void testLoadingMemoryError() throws InterruptedException { + Channel channel = TestUtils.getManagementChannel(configManager); + Assert.assertNotNull(channel); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - DefaultFullHttpRequest req = - new DefaultFullHttpRequest( - HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop-config"); - req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8); - HttpUtil.setContentLength(req, req.content().readableBytes()); - req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); - inferChannel.writeAndFlush(req); + TestUtils.registerModel(channel, "loading-memory-error.mar", "memory_error", true, false); TestUtils.getLatch().await(); - Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK); - Assert.assertFalse(TestUtils.getResult().contains("bytearray")); - unloadTests(mgmtChannel, "noop-config"); + Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.INSUFFICIENT_STORAGE); + channel.close(); } - private void testPredictionsDoNotDecodeRequest(Channel inferChannel, Channel mgmtChannel) - throws InterruptedException, NoSuchFieldException, IllegalAccessException { - setConfiguration("decode_input_request", "false"); - loadTests(mgmtChannel, "noop-v1.0-config-tests.mar", "noop-config"); + @Test( + alwaysRun = true, + dependsOnMethods = {"testLoadingMemoryError"}) + public void testPredictionMemoryError() throws InterruptedException { + // Load the model + Channel channel = TestUtils.getManagementChannel(configManager); + Assert.assertNotNull(channel); + TestUtils.setResult(null); + TestUtils.setLatch(new CountDownLatch(1)); + + TestUtils.registerModel(channel, "prediction-memory-error.mar", "pred-err", true, false); + TestUtils.getLatch().await(); + Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK); + 
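testLoadModelFromURL and testUnregisterURLModel, a few hunks up, pin down a side effect of registering by URL: the downloaded archive lands in the model store and is deleted again on unregister. As an aside, the Assert.assertTrue(!...exists()) there reads more naturally as assertFalse:

    java.io.File downloaded =
            new java.io.File(configManager.getModelStore(), "squeezenet1_1.mar");
    Assert.assertTrue(downloaded.exists());   // present after POST /models?url=https://...
    // ... unregister the model ...
    Assert.assertFalse(downloaded.exists());  // archive removed with the model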
channel.close(); + // Test for prediction + channel = TestUtils.connect(false, configManager); + Assert.assertNotNull(channel); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); DefaultFullHttpRequest req = new DefaultFullHttpRequest( - HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop-config"); - req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8); - HttpUtil.setContentLength(req, req.content().readableBytes()); - req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); - inferChannel.writeAndFlush(req); + HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/pred-err"); + req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8); + + channel.writeAndFlush(req); + TestUtils.getLatch().await(); + + Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.INSUFFICIENT_STORAGE); + channel.close(); + + // Unload the model + channel = TestUtils.connect(true, configManager); + TestUtils.setHttpStatus(null); + TestUtils.setLatch(new CountDownLatch(1)); + Assert.assertNotNull(channel); + TestUtils.unregisterModel(channel, "pred-err", null, false); TestUtils.getLatch().await(); - Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK); - Assert.assertTrue(TestUtils.getResult().contains("bytearray")); - unloadTests(mgmtChannel, "noop-config"); } - private void testPredictionsModifyResponseHeader( - Channel inferChannel, Channel managementChannel) - throws NoSuchFieldException, IllegalAccessException, InterruptedException { - setConfiguration("decode_input_request", "false"); - loadTests(managementChannel, "respheader-test.mar", "respheader"); + @Test( + alwaysRun = true, + dependsOnMethods = {"testPredictionMemoryError"}) + public void testErrorBatch() throws InterruptedException { + Channel channel = TestUtils.getManagementChannel(configManager); + Assert.assertNotNull(channel); + TestUtils.setHttpStatus(null); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - DefaultFullHttpRequest req = - new DefaultFullHttpRequest( - HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/respheader"); - - req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8); - HttpUtil.setContentLength(req, req.content().readableBytes()); - req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); - inferChannel.writeAndFlush(req); + TestUtils.registerModel(channel, "error_batch.mar", "err_batch", true, false); TestUtils.getLatch().await(); - Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK); - Assert.assertEquals(TestUtils.getHeaders().get("dummy"), "1"); - Assert.assertEquals(TestUtils.getHeaders().get("content-type"), "text/plain"); - Assert.assertTrue(TestUtils.getResult().contains("bytearray")); - unloadTests(managementChannel, "respheader"); - } + StatusResponse status = + JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); + Assert.assertEquals(status.getStatus(), "Workers scaled"); + + channel.close(); + + channel = TestUtils.connect(false, configManager); + Assert.assertNotNull(channel); - private void testPredictionsNoManifest(Channel inferChannel, Channel mgmtChannel) - throws InterruptedException, NoSuchFieldException, IllegalAccessException { - setConfiguration("default_service_handler", "service:handle"); - loadTests(mgmtChannel, "noop-no-manifest.mar", "nomanifest"); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); + TestUtils.setHttpStatus(null); DefaultFullHttpRequest req 
= new DefaultFullHttpRequest( - HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/nomanifest"); - req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8); + HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/err_batch"); + req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8); HttpUtil.setContentLength(req, req.content().readableBytes()); - req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); - inferChannel.writeAndFlush(req); + req.headers() + .set( + HttpHeaderNames.CONTENT_TYPE, + HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED); + channel.writeAndFlush(req); TestUtils.getLatch().await(); - Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK); - Assert.assertEquals(TestUtils.getResult(), "OK"); - unloadTests(mgmtChannel, "nomanifest"); + Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.INSUFFICIENT_STORAGE); + Assert.assertEquals(TestUtils.getResult(), "Invalid response"); } - private void testLegacyPredict(Channel channel) throws InterruptedException { - TestUtils.setResult(null); - TestUtils.setLatch(new CountDownLatch(1)); - DefaultFullHttpRequest req = - new DefaultFullHttpRequest( - HttpVersion.HTTP_1_1, HttpMethod.GET, "/noop/predict?data=test"); - channel.writeAndFlush(req); + @Test( + alwaysRun = true, + dependsOnMethods = {"testErrorBatch"}) + public void testMetricManager() throws JsonParseException, InterruptedException { + MetricManager.scheduleMetrics(configManager); + MetricManager metricManager = MetricManager.getInstance(); + List metrics = metricManager.getMetrics(); - TestUtils.getLatch().await(); - Assert.assertEquals(TestUtils.getResult(), "OK"); + // Wait till first value is read in + int count = 0; + while (metrics.isEmpty()) { + Thread.sleep(500); + metrics = metricManager.getMetrics(); + Assert.assertTrue(++count < 5); + } + for (Metric metric : metrics) { + if (metric.getMetricName().equals("CPUUtilization")) { + Assert.assertEquals(metric.getUnit(), "Percent"); + } + if (metric.getMetricName().equals("MemoryUsed")) { + Assert.assertEquals(metric.getUnit(), "Megabytes"); + } + if (metric.getMetricName().equals("DiskUsed")) { + List dimensions = metric.getDimensions(); + for (Dimension dimension : dimensions) { + if (dimension.getName().equals("Level")) { + Assert.assertEquals(dimension.getValue(), "Host"); + } + } + } + } } - private void testInvalidRootRequest() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testMetricManager"}) + public void testInvalidRootRequest() throws InterruptedException { Channel channel = TestUtils.connect(false, configManager); Assert.assertNotNull(channel); @@ -782,7 +787,10 @@ private void testInvalidRootRequest() throws InterruptedException { Assert.assertEquals(resp.getMessage(), ERROR_METHOD_NOT_ALLOWED); } - private void testInvalidInferenceUri() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testInvalidRootRequest"}) + public void testInvalidInferenceUri() throws InterruptedException { Channel channel = TestUtils.connect(false, configManager); Assert.assertNotNull(channel); @@ -797,7 +805,10 @@ private void testInvalidInferenceUri() throws InterruptedException { Assert.assertEquals(resp.getMessage(), ERROR_NOT_FOUND); } - private void testInvalidDescribeModel() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testInvalidInferenceUri"}) + public void testInvalidDescribeModel() throws InterruptedException { Channel 
channel = TestUtils.connect(false, configManager); Assert.assertNotNull(channel); @@ -810,7 +821,10 @@ private void testInvalidDescribeModel() throws InterruptedException { Assert.assertEquals(resp.getMessage(), "Model not found: InvalidModel"); } - private void testInvalidPredictionsUri() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testInvalidDescribeModel"}) + public void testInvalidPredictionsUri() throws InterruptedException { Channel channel = TestUtils.connect(false, configManager); Assert.assertNotNull(channel); @@ -825,7 +839,10 @@ private void testInvalidPredictionsUri() throws InterruptedException { Assert.assertEquals(resp.getMessage(), ERROR_NOT_FOUND); } - private void testPredictionsModelNotFound() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testInvalidPredictionsUri"}) + public void testPredictionsModelNotFound() throws InterruptedException { Channel channel = TestUtils.connect(false, configManager); Assert.assertNotNull(channel); @@ -841,9 +858,13 @@ private void testPredictionsModelNotFound() throws InterruptedException { Assert.assertEquals(resp.getMessage(), "Model not found: InvalidModel"); } - private void testPredictionsModelVersionNotFound() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testPredictionsModelNotFound"}) + public void testPredictionsModelVersionNotFound() throws InterruptedException { Channel channel = TestUtils.connect(false, configManager); Assert.assertNotNull(channel); + HttpRequest req = new DefaultFullHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.GET, "/predictions/noopversioned/1.3.1"); @@ -857,7 +878,10 @@ private void testPredictionsModelVersionNotFound() throws InterruptedException { resp.getMessage(), "Model version: 1.3.1 does not exist for model: noopversioned"); } - private void testInvalidManagementUri() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testPredictionsModelNotFound"}) + public void testInvalidManagementUri() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); @@ -872,7 +896,10 @@ private void testInvalidManagementUri() throws InterruptedException { Assert.assertEquals(resp.getMessage(), ERROR_NOT_FOUND); } - private void testInvalidModelsMethod() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testInvalidManagementUri"}) + public void testInvalidModelsMethod() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); @@ -887,7 +914,10 @@ private void testInvalidModelsMethod() throws InterruptedException { Assert.assertEquals(resp.getMessage(), ERROR_METHOD_NOT_ALLOWED); } - private void testInvalidModelMethod() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testInvalidModelsMethod"}) + public void testInvalidModelMethod() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); @@ -902,7 +932,10 @@ private void testInvalidModelMethod() throws InterruptedException { Assert.assertEquals(resp.getMessage(), ERROR_METHOD_NOT_ALLOWED); } - private void testDescribeModelNotFound() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testInvalidModelMethod"}) + public void testDescribeModelNotFound() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); 
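From testInvalidRootRequest onward, every negative test repeats one shape: open a throwaway channel (the server closes it after answering an error), send a single bad request, decode the ErrorResponse, and check status code plus message. The repetition could be folded into a helper along these lines (a sketch; TestUtils, JsonUtils and ErrorResponse are the project's own classes):

    private void assertError(
            boolean management,
            HttpMethod method,
            String uri,
            HttpResponseStatus expectedStatus,
            String expectedMessage)
            throws InterruptedException {
        Channel channel = TestUtils.connect(management, configManager);
        Assert.assertNotNull(channel);
        TestUtils.setResult(null);
        TestUtils.setLatch(new CountDownLatch(1));
        HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, method, uri);
        channel.writeAndFlush(req);
        channel.closeFuture().sync(); // the server closes the channel on error

        ErrorResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), expectedStatus.code());
        Assert.assertEquals(resp.getMessage(), expectedMessage);
    }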
Assert.assertNotNull(channel); @@ -918,9 +951,13 @@ private void testDescribeModelNotFound() throws InterruptedException { Assert.assertEquals(resp.getMessage(), "Model not found: InvalidModel"); } - private void testDescribeModelVersionNotFound() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testDescribeModelNotFound"}) + public void testDescribeModelVersionNotFound() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); + HttpRequest req = new DefaultFullHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.GET, "/models/noopversioned/1.3.1"); @@ -928,12 +965,16 @@ private void testDescribeModelVersionNotFound() throws InterruptedException { channel.closeFuture().sync(); ErrorResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), ErrorResponse.class); + Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code()); Assert.assertEquals( resp.getMessage(), "Model version: 1.3.1 does not exist for model: noopversioned"); } - private void testRegisterModelMissingUrl() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testDescribeModelNotFound"}) + public void testRegisterModelMissingUrl() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); @@ -948,7 +989,10 @@ private void testRegisterModelMissingUrl() throws InterruptedException { Assert.assertEquals(resp.getMessage(), "Parameter url is required."); } - private void testRegisterModelInvalidRuntime() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testRegisterModelMissingUrl"}) + public void testRegisterModelInvalidRuntime() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); @@ -966,7 +1010,10 @@ private void testRegisterModelInvalidRuntime() throws InterruptedException { Assert.assertEquals(resp.getMessage(), "Invalid RuntimeType value: InvalidRuntime"); } - private void testRegisterModelNotFound() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testRegisterModelInvalidRuntime"}) + public void testRegisterModelNotFound() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); @@ -982,7 +1029,10 @@ private void testRegisterModelNotFound() throws InterruptedException { Assert.assertEquals(resp.getMessage(), "Model not found in model store: InvalidUrl"); } - private void testRegisterModelConflict() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testRegisterModelNotFound"}) + public void testRegisterModelConflict() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); @@ -1009,7 +1059,10 @@ private void testRegisterModelConflict() throws InterruptedException { resp.getMessage(), "Model version 1.11 is already registered for model noop_v1.0"); } - private void testRegisterModelMalformedUrl() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testRegisterModelConflict"}) + public void testRegisterModelMalformedUrl() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); @@ -1028,7 +1081,10 @@ private void testRegisterModelMalformedUrl() throws InterruptedException { resp.getMessage(), "Failed to download model from: http://localhost:aaaa"); } - private 
void testRegisterModelConnectionFailed() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testRegisterModelMalformedUrl"}) + public void testRegisterModelConnectionFailed() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); @@ -1048,7 +1104,10 @@ private void testRegisterModelConnectionFailed() throws InterruptedException { "Failed to download model from: http://localhost:18888/fake.mar"); } - private void testRegisterModelHttpError() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testRegisterModelConnectionFailed"}) + public void testRegisterModelHttpError() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); @@ -1068,7 +1127,10 @@ private void testRegisterModelHttpError() throws InterruptedException { "Failed to download model from: https://localhost:8443/fake.mar"); } - private void testRegisterModelInvalidPath() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testRegisterModelHttpError"}) + public void testRegisterModelInvalidPath() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); @@ -1086,7 +1148,10 @@ private void testRegisterModelInvalidPath() throws InterruptedException { Assert.assertEquals(resp.getMessage(), "Relative path is not allowed in url: ../fake.mar"); } - private void testScaleModelNotFound() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testRegisterModelInvalidPath"}) + public void testScaleModelNotFound() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); @@ -1101,7 +1166,10 @@ private void testScaleModelNotFound() throws InterruptedException { Assert.assertEquals(resp.getMessage(), "Model not found: fake"); } - private void testScaleModelVersionNotFound() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testScaleModelNotFound"}) + public void testScaleModelVersionNotFound() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); TestUtils.setResult(null); @@ -1116,7 +1184,10 @@ private void testScaleModelVersionNotFound() throws InterruptedException { resp.getMessage(), "Model version: 1.3.1 does not exist for model: noop_v1.0"); } - private void testUnregisterModelNotFound() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testScaleModelNotFound"}) + public void testUnregisterModelNotFound() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); @@ -1128,19 +1199,25 @@ private void testUnregisterModelNotFound() throws InterruptedException { Assert.assertEquals(resp.getMessage(), "Model not found: fake"); } - private void testUnregisterModelVersionNotFound() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testUnregisterModelNotFound"}) + public void testUnregisterModelVersionNotFound() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); - TestUtils.unregisterModel(channel, "noop_v1.0", "1.3.1", true); + TestUtils.unregisterModel(channel, "noopversioned", "1.3.1", true); ErrorResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), ErrorResponse.class); 
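The MNIST blocks further down repeat a load, predict, unregister triplet for the eager, scripted, and traced flavours of the same model. Since the private helpers at the end of this file already take the archive and model name as parameters, a @DataProvider could express the three flows as one parameterized test (sketch; it trades the per-step dependsOnMethods chain for per-row isolation):

    import org.testng.annotations.DataProvider;
    import org.testng.annotations.Test;

    @DataProvider(name = "mnistVariants")
    public Object[][] mnistVariants() {
        return new Object[][] {
            {"mnist.mar", "mnist"},
            {"mnist_scripted.mar", "mnist_scripted"},
            {"mnist_traced.mar", "mnist_traced"},
        };
    }

    @Test(dataProvider = "mnistVariants")
    public void testMnistRoundTrip(String mar, String modelName) throws InterruptedException {
        testLoadModelWithInitialWorkers(mar, modelName);
        testPredictions(modelName, "0", null);
        testUnregisterModel(modelName, null);
    }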
Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code()); Assert.assertEquals( - resp.getMessage(), "Model version: 1.3.1 does not exist for model: noop_v1.0"); + resp.getMessage(), "Model version: 1.3.1 does not exist for model: noopversioned"); } - private void testUnregisterModelTimeout() + @Test( + alwaysRun = true, + dependsOnMethods = {"testUnregisterModelNotFound"}) + public void testUnregisterModelTimeout() throws InterruptedException, NoSuchFieldException, IllegalAccessException { Channel channel = TestUtils.connect(true, configManager); setConfiguration("unregister_model_timeout", "0"); @@ -1158,7 +1235,10 @@ private void testUnregisterModelTimeout() Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK); } - private void testScaleModelFailure() throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testUnregisterModelTimeout"}) + public void testScaleModelFailure() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); @@ -1185,85 +1265,183 @@ private void testScaleModelFailure() throws InterruptedException { Assert.assertEquals(resp.getMessage(), "Failed to start workers"); } - private void testLoadingMemoryError() throws InterruptedException { - Channel channel = TestUtils.connect(true, configManager); - Assert.assertNotNull(channel); + @Test( + alwaysRun = true, + dependsOnMethods = {"testScaleModelFailure"}) + public void testLoadMNISTEagerModel() throws InterruptedException { + testLoadModelWithInitialWorkers("mnist.mar", "mnist"); + } + + @Test( + alwaysRun = true, + dependsOnMethods = {"testLoadMNISTEagerModel"}) + public void testPredictionMNISTEagerModel() throws InterruptedException { + testPredictions("mnist", "0", null); + } + + @Test( + alwaysRun = true, + dependsOnMethods = {"testPredictionMNISTEagerModel"}) + public void testUnregistedMNISTEagerModel() throws InterruptedException { + testUnregisterModel("mnist", null); + } + + @Test( + alwaysRun = true, + dependsOnMethods = {"testUnregistedMNISTEagerModel"}) + public void testLoadMNISTScriptedModel() throws InterruptedException { + testLoadModelWithInitialWorkers("mnist_scripted.mar", "mnist_scripted"); + } + + @Test( + alwaysRun = true, + dependsOnMethods = {"testLoadMNISTScriptedModel"}) + public void testPredictionMNISTScriptedModel() throws InterruptedException { + testPredictions("mnist_scripted", "0", null); + } + + @Test( + alwaysRun = true, + dependsOnMethods = {"testPredictionMNISTScriptedModel"}) + public void testUnregistedMNISTScriptedModel() throws InterruptedException { + testUnregisterModel("mnist_scripted", null); + } + + @Test( + alwaysRun = true, + dependsOnMethods = {"testUnregistedMNISTScriptedModel"}) + public void testLoadMNISTTracedModel() throws InterruptedException { + testLoadModelWithInitialWorkers("mnist_traced.mar", "mnist_traced"); + } + + @Test( + alwaysRun = true, + dependsOnMethods = {"testLoadMNISTTracedModel"}) + public void testPredictionMNISTTracedModel() throws InterruptedException { + testPredictions("mnist_traced", "0", null); + } + + @Test( + alwaysRun = true, + dependsOnMethods = {"testPredictionMNISTTracedModel"}) + public void testUnregistedMNISTTracedModel() throws InterruptedException { + testUnregisterModel("mnist_traced", null); + } + + @Test( + alwaysRun = true, + dependsOnMethods = {"testUnregistedMNISTTracedModel"}) + public void testSetInvalidDefaultVersion() throws InterruptedException { + Channel channel = 
TestUtils.getManagementChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - - TestUtils.registerModel(channel, "loading-memory-error.mar", "memory_error", true, false); + TestUtils.setDefault(channel, "noopversioned", "3.3.3"); TestUtils.getLatch().await(); - Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.INSUFFICIENT_STORAGE); - channel.close(); + ErrorResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), ErrorResponse.class); + Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code()); + Assert.assertEquals( + resp.getMessage(), "Model version 3.3.3 does not exist for model noopversioned"); } - private void testPredictionMemoryError() throws InterruptedException { - // Load the model + @Test( + alwaysRun = true, + dependsOnMethods = {"testSetInvalidDefaultVersion"}) + public void testUnregisterModelFailure() throws InterruptedException { Channel channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - - TestUtils.registerModel(channel, "prediction-memory-error.mar", "pred-err", true, false); + TestUtils.unregisterModel(channel, "noopversioned", "1.2.1", false); TestUtils.getLatch().await(); - Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK); - channel.close(); - // Test for prediction - channel = TestUtils.connect(false, configManager); + ErrorResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), ErrorResponse.class); + Assert.assertEquals(resp.getCode(), HttpResponseStatus.FORBIDDEN.code()); + Assert.assertEquals( + resp.getMessage(), "Cannot remove default version for model noopversioned"); + + channel = TestUtils.connect(true, configManager); Assert.assertNotNull(channel); + TestUtils.unregisterModel(channel, "noopversioned", "1.11", false); + TestUtils.unregisterModel(channel, "noopversioned", "1.2.1", false); + } + + private void testLoadModel(String url, String modelName) throws InterruptedException { + Channel channel = TestUtils.getManagementChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - DefaultFullHttpRequest req = - new DefaultFullHttpRequest( - HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/pred-err"); - req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8); + TestUtils.registerModel(channel, url, modelName, false, false); + TestUtils.getLatch().await(); - channel.writeAndFlush(req); + StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); + Assert.assertEquals(resp.getStatus(), "Model \"" + modelName + "\" registered"); + } + + private void testUnregisterModel(String modelName, String version) throws InterruptedException { + Channel channel = TestUtils.getManagementChannel(configManager); + TestUtils.setResult(null); + TestUtils.setLatch(new CountDownLatch(1)); + TestUtils.unregisterModel(channel, modelName, version, false); TestUtils.getLatch().await(); - Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.INSUFFICIENT_STORAGE); - channel.close(); + StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); + Assert.assertEquals(resp.getStatus(), "Model \"" + modelName + "\" unregistered"); + } - // Unload the model - channel = TestUtils.connect(true, configManager); - TestUtils.setHttpStatus(null); + private void testSyncScaleModel(String modelName, String version) throws InterruptedException { + Channel 
channel = TestUtils.getManagementChannel(configManager); + TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - Assert.assertNotNull(channel); + TestUtils.scaleModel(channel, modelName, version, 1, true); - TestUtils.unregisterModel(channel, "pred-err", null, false); TestUtils.getLatch().await(); - Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK); - } - private void testErrorBatch() throws InterruptedException { - Channel channel = TestUtils.connect(true, configManager); - Assert.assertNotNull(channel); + StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); + Assert.assertEquals(resp.getStatus(), "Workers scaled"); + } - TestUtils.setHttpStatus(null); + private void testDescribeModel(String modelName, String requestVersion, String expectedVersion) + throws InterruptedException { + Channel channel = TestUtils.getManagementChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - - TestUtils.registerModel(channel, "error_batch.mar", "err_batch", true, false); + TestUtils.describeModel(channel, modelName, requestVersion); TestUtils.getLatch().await(); - StatusResponse status = - JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); - Assert.assertEquals(status.getStatus(), "Workers scaled"); + DescribeModelResponse[] resp = + JsonUtils.GSON.fromJson(TestUtils.getResult(), DescribeModelResponse[].class); + if ("all".equals(requestVersion)) { + Assert.assertTrue(resp.length >= 1); + } else { + Assert.assertTrue(resp.length == 1); + } + Assert.assertTrue(expectedVersion.equals(resp[0].getModelVersion())); + } - channel.close(); + private void testLoadModelWithInitialWorkers(String url, String modelName) + throws InterruptedException { + Channel channel = TestUtils.getManagementChannel(configManager); + TestUtils.setResult(null); + TestUtils.setLatch(new CountDownLatch(1)); + TestUtils.registerModel(channel, url, modelName, true, false); + TestUtils.getLatch().await(); - channel = TestUtils.connect(false, configManager); - Assert.assertNotNull(channel); + StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); + Assert.assertEquals(resp.getStatus(), "Workers scaled"); + } + private void testPredictions(String modelName, String expectedOutput, String version) + throws InterruptedException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - TestUtils.setHttpStatus(null); + String requestURL = "/predictions/" + modelName; + if (version != null) { + requestURL += "/" + version; + } DefaultFullHttpRequest req = - new DefaultFullHttpRequest( - HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/err_batch"); - req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8); + new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, requestURL); + req.content().writeCharSequence("data=test", CharsetUtil.UTF_8); HttpUtil.setContentLength(req, req.content().readableBytes()); req.headers() .set( @@ -1272,38 +1450,33 @@ private void testErrorBatch() throws InterruptedException { channel.writeAndFlush(req); TestUtils.getLatch().await(); + Assert.assertEquals(TestUtils.getResult(), expectedOutput); + } - Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.INSUFFICIENT_STORAGE); - Assert.assertEquals(TestUtils.getResult(), "Invalid response"); + private void loadTests(Channel channel, String model, String modelName) + throws 
InterruptedException { + TestUtils.setResult(null); + TestUtils.setLatch(new CountDownLatch(1)); + TestUtils.registerModel(channel, model, modelName, true, false); + + TestUtils.getLatch().await(); } - private void testMetricManager() throws JsonParseException, InterruptedException { - MetricManager.scheduleMetrics(configManager); - MetricManager metricManager = MetricManager.getInstance(); - List metrics = metricManager.getMetrics(); + private void unloadTests(Channel channel, String modelName) throws InterruptedException { + TestUtils.setResult(null); + TestUtils.setLatch(new CountDownLatch(1)); + String expected = "Model \"" + modelName + "\" unregistered"; + TestUtils.unregisterModel(channel, modelName, null, false); + TestUtils.getLatch().await(); + StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class); + Assert.assertEquals(resp.getStatus(), expected); + } - // Wait till first value is read in - int count = 0; - while (metrics.isEmpty()) { - Thread.sleep(500); - metrics = metricManager.getMetrics(); - Assert.assertTrue(++count < 5); - } - for (Metric metric : metrics) { - if (metric.getMetricName().equals("CPUUtilization")) { - Assert.assertEquals(metric.getUnit(), "Percent"); - } - if (metric.getMetricName().equals("MemoryUsed")) { - Assert.assertEquals(metric.getUnit(), "Megabytes"); - } - if (metric.getMetricName().equals("DiskUsed")) { - List dimensions = metric.getDimensions(); - for (Dimension dimension : dimensions) { - if (dimension.getName().equals("Level")) { - Assert.assertEquals(dimension.getValue(), "Host"); - } - } - } - } + private void setConfiguration(String key, String val) + throws NoSuchFieldException, IllegalAccessException { + Field f = configManager.getClass().getDeclaredField("prop"); + f.setAccessible(true); + Properties p = (Properties) f.get(configManager); + p.setProperty(key, val); } } diff --git a/frontend/server/src/test/java/org/pytorch/serve/SnapshotTest.java b/frontend/server/src/test/java/org/pytorch/serve/SnapshotTest.java index 92144a4e20..fb764a60df 100644 --- a/frontend/server/src/test/java/org/pytorch/serve/SnapshotTest.java +++ b/frontend/server/src/test/java/org/pytorch/serve/SnapshotTest.java @@ -64,69 +64,21 @@ public void beforeSuite() } @AfterClass - public void afterSuite() { + public void afterSuite() throws InterruptedException { + TestUtils.closeChannels(); server.stop(); } @Test - public void test() - throws InterruptedException, IOException, GeneralSecurityException, - InvalidSnapshotException { - Channel channel = null; - Channel managementChannel = null; - for (int i = 0; i < 5; ++i) { - channel = TestUtils.connect(false, configManager); - if (channel != null) { - break; - } - Thread.sleep(100); - } - - for (int i = 0; i < 5; ++i) { - managementChannel = TestUtils.connect(true, configManager); - if (managementChannel != null) { - break; - } - Thread.sleep(100); - } - - Assert.assertNotNull(channel, "Failed to connect to inference port."); - Assert.assertNotNull(managementChannel, "Failed to connect to management port."); - - testStartupSnapshot("snapshot1.cfg"); - testUnregisterSnapshot(managementChannel); - testRegisterSnapshot(managementChannel); - testSyncScaleModelSnapshot(managementChannel); - testNoSnapshotOnListModels(managementChannel); - testNoSnapshotOnDescribeModel(managementChannel); - testLoadModelWithInitialWorkersSnapshot(managementChannel); - testRegisterSecondModelSnapshot(managementChannel); - testSecondModelVersionSnapshot(managementChannel); - 
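The SnapshotTest changes mirror the ModelServerTest refactor: channels now come from TestUtils.getInferenceChannel/getManagementChannel and are torn down once per class, which is why afterSuite gains TestUtils.closeChannels() and a throws InterruptedException. The deleted retry-connect loops suggest that logic moved into the get*Channel helpers. A plausible shape for the new teardown, purely as an assumption since its body is not in this diff:

    private static Channel inferenceChannel;  // cached by getInferenceChannel (assumed)
    private static Channel managementChannel; // cached by getManagementChannel (assumed)

    public static void closeChannels() throws InterruptedException {
        if (inferenceChannel != null) {
            inferenceChannel.close().sync();  // blocks until the close completes
            inferenceChannel = null;
        }
        if (managementChannel != null) {
            managementChannel.close().sync();
            managementChannel = null;
        }
    }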
testNoSnapshotOnPrediction(channel); - testSetDefaultSnapshot(managementChannel); - testAsyncScaleModelSnapshot(managementChannel); - - channel.close(); - managementChannel.close(); - - testStopTorchServeSnapshot(); - testStartTorchServeWithLastSnapshot(); - testRestartTorchServeWithSnapshotAsConfig(); - - // Negative management API calls, channel will be closed by server - testNoSnapshotOnInvalidModelRegister(); - testNoSnapshotOnInvalidModelUnregister(); - testNoSnapshotOnInvalidModelVersionUnregister(); - testNoSnapshotOnInvalidModelScale(); - testNoSnapshotOnInvalidModelVersionScale(); - testNoSnapshotOnInvalidModelVersionSetDefault(); + public void testStartupSnapshot() { + validateSnapshot("snapshot1.cfg"); } - private void testStartupSnapshot(String expectedSnapshot) { - validateSnapshot(expectedSnapshot); - } - - private void testUnregisterSnapshot(Channel managementChannel) throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testStartupSnapshot"}) + public void testUnregisterSnapshot() throws InterruptedException { + Channel managementChannel = TestUtils.getManagementChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); TestUtils.unregisterModel(managementChannel, "noop", null, false); @@ -135,7 +87,11 @@ private void testUnregisterSnapshot(Channel managementChannel) throws Interrupte waitForSnapshot(); } - private void testRegisterSnapshot(Channel managementChannel) throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testUnregisterSnapshot"}) + public void testRegisterSnapshot() throws InterruptedException { + Channel managementChannel = TestUtils.getManagementChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); TestUtils.registerModel(managementChannel, "noop.mar", "noop_v1.0", false, false); @@ -144,16 +100,24 @@ private void testRegisterSnapshot(Channel managementChannel) throws InterruptedE waitForSnapshot(); } - private void testSyncScaleModelSnapshot(Channel channel) throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testRegisterSnapshot"}) + public void testSyncScaleModelSnapshot() throws InterruptedException { + Channel managementChannel = TestUtils.getManagementChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); - TestUtils.scaleModel(channel, "noop_v1.0", null, 1, true); + TestUtils.scaleModel(managementChannel, "noop_v1.0", null, 1, true); TestUtils.getLatch().await(); validateSnapshot("snapshot4.cfg"); waitForSnapshot(); } - private void testNoSnapshotOnListModels(Channel channel) throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testSyncScaleModelSnapshot"}) + public void testNoSnapshotOnListModels() throws InterruptedException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); TestUtils.listModels(channel); @@ -161,7 +125,11 @@ private void testNoSnapshotOnListModels(Channel channel) throws InterruptedExcep validateNoSnapshot(); } - private void testNoSnapshotOnDescribeModel(Channel channel) throws InterruptedException { + @Test( + alwaysRun = true, + dependsOnMethods = {"testNoSnapshotOnListModels"}) + public void testNoSnapshotOnDescribeModel() throws InterruptedException { + Channel channel = TestUtils.getInferenceChannel(configManager); TestUtils.setResult(null); TestUtils.setLatch(new CountDownLatch(1)); 
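The snapshot tests encode one invariant: management calls that mutate state (register, scale, set-default) must write a new snapshot, while read-only calls (list, describe, predictions) must not. The second half is expressed by validateNoSnapshot, which simply re-validates the most recent snapshot:

    private void validateNoSnapshot() {
        // nothing new was written: the last snapshot on disk must still match
        validateSnapshot(lastSnapshot);
    }

Note that this diff removes that helper further down while the converted read-only tests here still call it, so it presumably survives in some form outside the hunks shown.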
         TestUtils.describeModel(channel, "noop_v1.0", null);
@@ -169,17 +137,24 @@ private void testNoSnapshotOnDescribeModel(Channel channel) throws InterruptedEx
         validateNoSnapshot();
     }

-    private void testLoadModelWithInitialWorkersSnapshot(Channel channel)
-            throws InterruptedException {
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testNoSnapshotOnDescribeModel"})
+    public void testLoadModelWithInitialWorkersSnapshot() throws InterruptedException {
+        Channel managementChannel = TestUtils.getManagementChannel(configManager);
         TestUtils.setResult(null);
         TestUtils.setLatch(new CountDownLatch(1));
-        TestUtils.registerModel(channel, "noop.mar", "noop", true, false);
+        TestUtils.registerModel(managementChannel, "noop.mar", "noop", true, false);
         TestUtils.getLatch().await();
         validateSnapshot("snapshot5.cfg");
         waitForSnapshot();
     }

-    private void testNoSnapshotOnPrediction(Channel channel) {
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testLoadModelWithInitialWorkersSnapshot"})
+    public void testNoSnapshotOnPrediction() throws InterruptedException {
+        Channel channel = TestUtils.getInferenceChannel(configManager);
         TestUtils.setResult(null);
         TestUtils.setLatch(new CountDownLatch(1));
         String requestURL = "/predictions/noop_v1.0";
@@ -194,54 +169,76 @@ private void testNoSnapshotOnPrediction(Channel channel) {
         channel.writeAndFlush(req);
     }

-    private void testRegisterSecondModelSnapshot(Channel channel) throws InterruptedException {
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testNoSnapshotOnPrediction"})
+    public void testRegisterSecondModelSnapshot() throws InterruptedException {
+        Channel managementChannel = TestUtils.getManagementChannel(configManager);
         TestUtils.setResult(null);
         TestUtils.setLatch(new CountDownLatch(1));
-        TestUtils.registerModel(channel, "noop.mar", "noopversioned", true, false);
+        TestUtils.registerModel(managementChannel, "noop.mar", "noopversioned", true, false);
         TestUtils.getLatch().await();
         validateSnapshot("snapshot6.cfg");
         waitForSnapshot();
     }

-    private void testSecondModelVersionSnapshot(Channel channel) throws InterruptedException {
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testRegisterSecondModelSnapshot"})
+    public void testSecondModelVersionSnapshot() throws InterruptedException {
+        Channel managementChannel = TestUtils.getManagementChannel(configManager);
         TestUtils.setResult(null);
         TestUtils.setLatch(new CountDownLatch(1));
-        TestUtils.registerModel(channel, "noop_v2.mar", "noopversioned", true, false);
+        TestUtils.registerModel(managementChannel, "noop_v2.mar", "noopversioned", true, false);
         TestUtils.getLatch().await();
         validateSnapshot("snapshot7.cfg");
         waitForSnapshot();
     }

-    private void testSetDefaultSnapshot(Channel channel) throws InterruptedException {
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testSecondModelVersionSnapshot"})
+    public void testSetDefaultSnapshot() throws InterruptedException {
+        Channel managementChannel = TestUtils.getManagementChannel(configManager);
         TestUtils.setResult(null);
         TestUtils.setLatch(new CountDownLatch(1));
         String requestURL = "/models/noopversioned/1.2.1/set-default";

         HttpRequest req =
                 new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, requestURL);
-        channel.writeAndFlush(req);
+        managementChannel.writeAndFlush(req);
         TestUtils.getLatch().await();
         validateSnapshot("snapshot8.cfg");
         waitForSnapshot();
     }

-    private void testAsyncScaleModelSnapshot(Channel channel) throws InterruptedException {
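+    // An async scale request returns before the workers are actually adjusted, so the
+    // test below waits (waitForSnapshot(5000)) for the snapshot to land before validating.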
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testSetDefaultSnapshot"})
+    public void testAsyncScaleModelSnapshot() throws InterruptedException {
+        Channel managementChannel = TestUtils.getManagementChannel(configManager);
         TestUtils.setResult(null);
         TestUtils.setLatch(new CountDownLatch(1));
-        TestUtils.scaleModel(channel, "noop_v1.0", null, 2, false);
+        TestUtils.scaleModel(managementChannel, "noop_v1.0", null, 2, false);
         TestUtils.getLatch().await();
         waitForSnapshot(5000);
         validateSnapshot("snapshot9.cfg");
         waitForSnapshot();
     }

-    private void testStopTorchServeSnapshot() {
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testAsyncScaleModelSnapshot"})
+    public void testStopTorchServeSnapshot() {
         server.stop();
         validateSnapshot("snapshot9.cfg");
     }

-    private void testStartTorchServeWithLastSnapshot()
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testStopTorchServeSnapshot"})
+    public void testStartTorchServeWithLastSnapshot()
             throws InterruptedException, IOException, GeneralSecurityException,
                     InvalidSnapshotException {
         System.setProperty("tsConfigFile", "");
@@ -260,7 +257,10 @@ private void testStartTorchServeWithLastSnapshot()
         validateSnapshot("snapshot9.cfg");
     }

-    private void testRestartTorchServeWithSnapshotAsConfig()
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testStartTorchServeWithLastSnapshot"})
+    public void testRestartTorchServeWithSnapshotAsConfig()
             throws InterruptedException, IOException, GeneralSecurityException,
                     InvalidSnapshotException {
         server.stop();
@@ -282,11 +282,10 @@ private void testRestartTorchServeWithSnapshotAsConfig()
         validateSnapshot("snapshot9.cfg");
     }

-    private void validateNoSnapshot() {
-        validateSnapshot(lastSnapshot);
-    }
-
-    private void testNoSnapshotOnInvalidModelRegister() throws InterruptedException {
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testRestartTorchServeWithSnapshotAsConfig"})
+    public void testNoSnapshotOnInvalidModelRegister() throws InterruptedException {
         Channel channel = TestUtils.connect(true, configManager);
         Assert.assertNotNull(channel);
         TestUtils.registerModel(channel, "InvalidModel", "InvalidModel", false, true);
@@ -294,7 +293,10 @@ private void testNoSnapshotOnInvalidModelRegister() throws InterruptedException
         validateSnapshot("snapshot9.cfg");
     }

-    private void testNoSnapshotOnInvalidModelUnregister() throws InterruptedException {
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testNoSnapshotOnInvalidModelRegister"})
+    public void testNoSnapshotOnInvalidModelUnregister() throws InterruptedException {
         Channel channel = TestUtils.connect(true, configManager);
         Assert.assertNotNull(channel);
         TestUtils.unregisterModel(channel, "InvalidModel", null, true);
@@ -302,7 +304,10 @@ private void testNoSnapshotOnInvalidModelUnregister() throws InterruptedExceptio
         validateSnapshot("snapshot9.cfg");
     }

-    private void testNoSnapshotOnInvalidModelVersionUnregister() throws InterruptedException {
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testNoSnapshotOnInvalidModelUnregister"})
+    public void testNoSnapshotOnInvalidModelVersionUnregister() throws InterruptedException {
         Channel channel = TestUtils.connect(true, configManager);
         Assert.assertNotNull(channel);
         TestUtils.registerModel(channel, "noopversioned", "3.0", false, true);
@@ -310,7 +315,10 @@ private void testNoSnapshotOnInvalidModelVersionUnregister() throws InterruptedE
         validateSnapshot("snapshot9.cfg");
     }

-    private void testNoSnapshotOnInvalidModelScale() throws InterruptedException {
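+    // Negative management API calls: the server closes the channel after each invalid
+    // request, so every test below opens a fresh connection rather than a cached channel.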
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testNoSnapshotOnInvalidModelVersionUnregister"})
+    public void testNoSnapshotOnInvalidModelScale() throws InterruptedException {
         Channel channel = TestUtils.connect(true, configManager);
         Assert.assertNotNull(channel);
         TestUtils.scaleModel(channel, "invalidModel", null, 1, true);
@@ -318,7 +326,10 @@ private void testNoSnapshotOnInvalidModelScale() throws InterruptedException {
         validateSnapshot("snapshot9.cfg");
     }

-    private void testNoSnapshotOnInvalidModelVersionScale() throws InterruptedException {
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testNoSnapshotOnInvalidModelScale"})
+    public void testNoSnapshotOnInvalidModelVersionScale() throws InterruptedException {
         Channel channel = TestUtils.connect(true, configManager);
         Assert.assertNotNull(channel);
         TestUtils.scaleModel(channel, "noopversioned", "3.0", 1, true);
@@ -326,7 +337,10 @@ private void testNoSnapshotOnInvalidModelVersionScale() throws InterruptedExcept
         validateSnapshot("snapshot9.cfg");
     }

-    private void testNoSnapshotOnInvalidModelVersionSetDefault() throws InterruptedException {
+    @Test(
+            alwaysRun = true,
+            dependsOnMethods = {"testNoSnapshotOnInvalidModelVersionScale"})
+    public void testNoSnapshotOnInvalidModelVersionSetDefault() throws InterruptedException {
         Channel channel = TestUtils.connect(true, configManager);
         Assert.assertNotNull(channel);
         String requestURL = "/models/noopversioned/3.0/set-default";
@@ -339,6 +353,10 @@ private void testNoSnapshotOnInvalidModelVersionSetDefault() throws InterruptedE
         validateSnapshot("snapshot9.cfg");
     }

+    private void validateNoSnapshot() {
+        validateSnapshot(lastSnapshot);
+    }
+
     private void validateSnapshot(String expectedSnapshot) {
         lastSnapshot = expectedSnapshot;
         File expectedSnapshotFile = new File("src/test/resources/snapshots", expectedSnapshot);
diff --git a/frontend/server/src/test/java/org/pytorch/serve/TestUtils.java b/frontend/server/src/test/java/org/pytorch/serve/TestUtils.java
index 5ea9c7109c..afc58d196b 100644
--- a/frontend/server/src/test/java/org/pytorch/serve/TestUtils.java
+++ b/frontend/server/src/test/java/org/pytorch/serve/TestUtils.java
@@ -39,6 +39,8 @@ public final class TestUtils {
     static HttpResponseStatus httpStatus;
     static String result;
     static HttpHeaders headers;
+    private static Channel inferenceChannel;
+    private static Channel managementChannel;

     private TestUtils() {}

@@ -194,6 +196,14 @@ public static void listModels(Channel channel) {
         channel.writeAndFlush(req);
     }

+    public static void setDefault(Channel channel, String modelName, String defaultVersion) {
+        String requestURL = "/models/" + modelName + "/" + defaultVersion + "/set-default";
+
+        HttpRequest req =
+                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, requestURL);
+        channel.writeAndFlush(req);
+    }
+
     public static Channel connect(boolean management, ConfigManager configManager) {
         return connect(management, configManager, 120);
     }

@@ -236,6 +246,53 @@ public void initChannel(Channel ch) {
         return null;
     }

+    public static Channel getInferenceChannel(ConfigManager configManager)
+            throws InterruptedException {
+        return getChannel(false, configManager);
+    }
+
+    public static Channel getManagementChannel(ConfigManager configManager)
+            throws InterruptedException {
+        return getChannel(true, configManager);
+    }
+
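+    // Returns the cached channel for the requested port while it is still active;
+    // otherwise retries the connection up to five times, 100 ms apart, and caches the result.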
+    private static Channel getChannel(boolean isManagementChannel, ConfigManager configManager)
+            throws InterruptedException {
+        if (isManagementChannel && managementChannel != null && managementChannel.isActive()) {
+            return managementChannel;
+        } else if (!isManagementChannel
+                && inferenceChannel != null
+                && inferenceChannel.isActive()) {
+            return inferenceChannel;
+        } else {
+            Channel channel = null;
+            for (int i = 0; i < 5; ++i) {
+                channel = TestUtils.connect(isManagementChannel, configManager);
+                if (channel != null) {
+                    break;
+                }
+                Thread.sleep(100);
+            }
+            if (isManagementChannel) {
+                managementChannel = channel;
+            } else {
+                inferenceChannel = channel;
+            }
+            return channel;
+        }
+    }
+
+    public static void closeChannels() throws InterruptedException {
+        // close() initiates the shutdown and sync() waits for it to complete; waiting on
+        // closeFuture() alone would block forever, since nothing else closes these channels.
+        if (managementChannel != null) {
+            managementChannel.close().sync();
+        }
+        if (inferenceChannel != null) {
+            inferenceChannel.close().sync();
+        }
+    }
+
     @ChannelHandler.Sharable
     private static class TestHandler extends SimpleChannelInboundHandler<FullHttpResponse> {
diff --git a/frontend/server/src/test/java/org/pytorch/serve/util/ConfigManagerTest.java b/frontend/server/src/test/java/org/pytorch/serve/util/ConfigManagerTest.java
index bd9fb82f68..ecec62347b 100644
--- a/frontend/server/src/test/java/org/pytorch/serve/util/ConfigManagerTest.java
+++ b/frontend/server/src/test/java/org/pytorch/serve/util/ConfigManagerTest.java
@@ -8,12 +8,12 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
-import org.junit.Assert;
 import org.pytorch.serve.TestUtils;
 import org.pytorch.serve.metrics.Dimension;
 import org.pytorch.serve.metrics.Metric;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.testng.Assert;
 import org.testng.annotations.Test;

 public class ConfigManagerTest {
diff --git a/frontend/server/testng.xml b/frontend/server/testng.xml
new file mode 100644
index 0000000000..58e956424b
--- /dev/null
+++ b/frontend/server/testng.xml
@@ -0,0 +1,12 @@
+[The 12 added XML lines were lost in this rendering (angle-bracketed content stripped); the file is the TestNG suite definition that registers the frontend server test classes.]
diff --git a/frontend/tools/conf/pmd.xml b/frontend/tools/conf/pmd.xml
index 684635ef00..ac787e1139 100644
--- a/frontend/tools/conf/pmd.xml
+++ b/frontend/tools/conf/pmd.xml
@@ -140,4 +140,10 @@
+[The 6 added XML lines were lost in this rendering; they extend the PMD ruleset configuration.]
diff --git a/frontend/tools/gradle/check.gradle b/frontend/tools/gradle/check.gradle
index 0dd6de4a03..881874f9d3 100644
--- a/frontend/tools/gradle/check.gradle
+++ b/frontend/tools/gradle/check.gradle
@@ -21,7 +21,7 @@ checkstyle {
     configFile = file("${rootProject.projectDir}/tools/conf/checkstyle.xml")
 }
 checkstyleMain {
-    classpath += configurations.compile
+    classpath += configurations.compileClasspath
 }
 tasks.withType(Checkstyle) {
     reports {