diff --git a/examples/Huggingface_Transformers/Download_Transformer_models.py b/examples/Huggingface_Transformers/Download_Transformer_models.py
new file mode 100644
index 0000000000..3357ba8240
--- /dev/null
+++ b/examples/Huggingface_Transformers/Download_Transformer_models.py
@@ -0,0 +1,103 @@
+import transformers
+from pathlib import Path
+import os
+import json
+import torch
+from transformers import (AutoModelForSequenceClassification, AutoTokenizer, AutoModelForQuestionAnswering,
+ AutoModelForTokenClassification, AutoConfig)
+""" This function, save the checkpoint, config file along with tokenizer config and vocab files
+ of a transformer model of your choice.
+"""
+print('Transformers version',transformers.__version__)
+
+def transformers_model_dowloader(mode,pretrained_model_name,num_labels,do_lower_case):
+ print("Download model and tokenizer", pretrained_model_name)
+ #loading pre-trained model and tokenizer
+ if mode== "sequence_classification":
+ config = AutoConfig.from_pretrained(pretrained_model_name,num_labels=num_labels)
+ model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name, config=config)
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name,do_lower_case=do_lower_case)
+ elif mode== "question_answering":
+ model = AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name)
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name,do_lower_case=do_lower_case)
+ elif mode== "token_classification":
+ config = AutoConfig.from_pretrained(pretrained_model_name,num_labels=num_labels)
+ model = AutoModelForTokenClassification.from_pretrained(pretrained_model_name)
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name,do_lower_case=do_lower_case)
+
+
+    # NOTE : for demonstration purposes, we do not go through the fine-tuning process here.
+    # A fine-tuning step based on your needs can be added.
+    # An example Colab notebook for the fine-tuning process is provided in the README.
+
+
+ """ For demonstration purposes, we show an example of using question answering
+ data preprocessing and inference. Using the pre-trained models will not yeild
+ good results, we need to use fine_tuned models. For example, instead of "bert-base-uncased",
+ if 'bert-large-uncased-whole-word-masking-finetuned-squad' be passed to
+ the AutoModelForSequenceClassification.from_pretrained(pretrained_model_name)
+ a better result can be achieved. Models such as RoBERTa, xlm, xlnet,etc. can be
+ passed as the pre_trained models.
+ """
+
+ text = r"""
+ 🤗 Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose
+ architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural
+ Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between
+ TensorFlow 2.0 and PyTorch.
+ """
+
+ questions = [
+ "How many pretrained models are available in Transformers?",
+ "What does Transformers provide?",
+ "Transformers provides interoperability between which frameworks?",
+ ]
+
+ device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+
+ for question in questions:
+ inputs = tokenizer.encode_plus(question, text, add_special_tokens=True, return_tensors="pt")
+ for key in inputs.keys():
+ inputs[key]= inputs[key].to(device)
+
+ input_ids = inputs["input_ids"].tolist()[0]
+
+ # text_tokens = tokenizer.convert_ids_to_tokens(input_ids)
+
+ answer_start_scores, answer_end_scores = model(**inputs)
+
+ answer_start = torch.argmax(
+ answer_start_scores
+ ) # Get the most likely beginning of answer with the argmax of the score
+ answer_end = torch.argmax(answer_end_scores) + 1 # Get the most likely end of answer with the argmax of the score
+
+ answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
+
+ print(f"Question: {question}")
+ print(f"Answer: {answer}\n")
+
+
+ NEW_DIR = "./Transformer_model"
+ try:
+ os.mkdir(NEW_DIR)
+ except OSError:
+ print ("Creation of directory %s failed" % NEW_DIR)
+ else:
+ print ("Successfully created directory %s " % NEW_DIR)
+
+ print("Save model and tokenizer", pretrained_model_name, 'in directory', NEW_DIR)
+ model.save_pretrained(NEW_DIR)
+ tokenizer.save_pretrained(NEW_DIR)
+ return
+if __name__== "__main__":
+ dirname = os.path.dirname(__file__)
+ filename = os.path.join(dirname, 'setup_config.json')
+    with open(filename) as f:
+        options = json.load(f)
+ mode = options["mode"]
+ model_name = options["model_name"]
+ num_labels = int(options["num_labels"])
+ do_lower_case = options["do_lower_case"]
+ transformers_model_dowloader(mode,model_name, num_labels,do_lower_case)
diff --git a/examples/Huggingface_Transformers/README.md b/examples/Huggingface_Transformers/README.md
new file mode 100644
index 0000000000..24b9b70a9a
--- /dev/null
+++ b/examples/Huggingface_Transformers/README.md
@@ -0,0 +1,87 @@
+## Serving Huggingface Transformers using TorchServe
+
+In this example, we show how to serve a fine-tuned or off-the-shelf Transformer model from [huggingface](https://huggingface.co/transformers/index.html) using TorchServe. We use a custom handler, Transformer_handler_generalized.py. This handler enables us to use pre-trained transformer models from Huggingface, such as BERT, RoBERTa, XLM, etc., for use-cases defined in the AutoModel classes AutoModelForSequenceClassification, AutoModelForQuestionAnswering and AutoModelForTokenClassification; support for AutoModelWithLMHead will be added later.
+
+We borrowed ideas for writing a custom handler for transformers from the tutorial presented in [mnist from image classifiers examples](https://github.com/pytorch/serve/tree/master/examples/image_classifier/mnist) and the post by [MFreidank](https://medium.com/analytics-vidhya/deploy-huggingface-s-bert-to-production-with-pytorch-serve-27b068026d18).
+
+First, we need to make sure that the Transformers library is installed; it can be installed using the following command.
+
+ `pip install transformers`
+
+### Objectives
+1. Demonstrate how to package a transformer model with custom handler into torch model archive (.mar) file
+2. Demonstrate how to load model archive (.mar) file into TorchServe and run inference.
+
+### Serving a Model using TorchServe
+
+To serve a model using TorchServe, the following steps are required:
+
+- First, prepare the requirements for packaging a model, including the serialized model and other required files.
+- Create a torch model archive using the "torch-model-archiver" utility to archive the above files along with a handler (in this example, the custom handler for transformers).
+- Register the model on TorchServe using the above model archive file and run inference.
+
+### **Getting Started with the Demo**
+
+There are two paths to obtain the required model files for this demo.
+
+- **Option A** : To yield the desired results, one should fine-tune each of the intended models beforehand and save the model and tokenizer using "save_pretrained()" (a minimal sketch is shown after this list). This will result in a pytorch_model.bin file along with vocab.txt and config.json files. These files should be moved to a folder named "Transformer_model" in the current directory.
+
+- **Option B**: Another option, just for demonstration purposes, is to simply run "Download_Transformer_models.py". This script loads and saves the required files mentioned above into the "Transformer_model" directory, using a setup config file, "setup_config.json". The settings in "setup_config.json" are also used in the handler, "Transformer_handler_generalized.py", to operate on the selected mode and other related settings.
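+
+For reference, a minimal sketch of Option A for a sequence classification model (the fine-tuning loop itself is omitted, and "bert-base-uncased" with two labels is only an assumed example) could look like this:
+
+```
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
+tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+# ... fine-tune the model on your task here ...
+
+# save_pretrained() writes pytorch_model.bin and config.json (and the tokenizer files, including vocab.txt)
+model.save_pretrained("./Transformer_model")
+tokenizer.save_pretrained("./Transformer_model")
+```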
+
+#### Setting the setup_config.json
+
+In the setup_config.json:
+
+*model_name* : bert-base-uncased, roberta-base or other available pre-trained models.
+
+*mode* : "sequence_classification" for sequence classification, "question_answering" for question answering and "token_classification" for token classification.
+
+*do_lower_case* : True or False for use of the Tokenizer.
+
+*num_labels* : number of outputs for "sequence_classification", or "token_classification".
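+
+For example, the "setup_config.json" included in this example is set up for question answering with roberta-base:
+
+```
+{
+ "model_name":"roberta-base",
+ "mode":"question_answering",
+ "do_lower_case":"True",
+ "num_labels":"0"
+}
+```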
+
+Once setup_config.json has been set properly, the next step is to run "Download_Transformer_models.py":
+
+`python Download_Transformer_models.py`
+
+This produces all the required files for packaging an off-the-shelf huggingface transformer model without going through a fine-tuning process. Using this option creates and saves the required files into the Transformer_model directory. In case "vocab.txt" was not saved into this directory, the tokenizer can be loaded from the pre-trained model's vocab; this case has been addressed in the handler.
+
+
+
+#### Setting the extra_files
+
+There are a few files that are used for model packaging and at inference time. "index_to_name.json" is passed as an extra file to the model archiver and is used for mapping predictions to labels. "sample_text.txt" is used at inference time to pass the text that we want to run inference on.
+
+index_to_name.json for question answering is not required.
+
+To use the Transformer handler for token classification, index_to_name.json should be formatted as follows, for example:
+
+`{"label_list":"[O, B-MISC, I-MISC, B-PER,I-PER,B-ORG,I-ORG,B-LOC,I-LOC]"}`
+
+To use Transformer handler for question answering, the sample_text.txt should be formatted as follows:
+
+`{"question" :"Who was Jim Henson?", "context": "Jim Henson was a nice puppet"}`
+
+"question" represents the question to be asked from the source text named as "context" here.
+
+### Creating a torch Model Archive
+
+Once setup_config.json, sample_text.txt and index_to_name.json are set properly, we can go ahead, package the model and start serving it. The current setting in "setup_config.json" is based on the "roberta-base" model for question answering. A fine-tuned RoBERTa can be obtained by running the [squad example](https://huggingface.co/transformers/examples.html#squad) from huggingface. Alternatively, a fine-tuned BERT model can be used by setting "model_name" to "bert-large-uncased-whole-word-masking-finetuned-squad" in "setup_config.json".
+
+```
+torch-model-archiver --model-name RobertaQA --version 1.0 --serialized-file Transformer_model/pytorch_model.bin --handler ./Transformer_handler_generalized.py --extra-files "Transformer_model/config.json,./setup_config.json"
+
+```
+
+### Registering the Model on TorchServe and Running Inference
+
+To register the model on TorchServe using the above model archive file, we run the following commands:
+
+```
+mkdir model_store
+mv RobertaQA.mar model_store/
+torchserve --start --model-store model_store --models my_tc=RobertaQA.mar
+
+```
+
+- To run the inference using our registered model, open a new terminal and run: `curl -X POST http://127.0.0.1:8080/predictions/my_tc -T ./sample_text.txt`
\ No newline at end of file
diff --git a/examples/Huggingface_Transformers/Transformer_handler_generalized.py b/examples/Huggingface_Transformers/Transformer_handler_generalized.py
new file mode 100644
index 0000000000..084d233d09
--- /dev/null
+++ b/examples/Huggingface_Transformers/Transformer_handler_generalized.py
@@ -0,0 +1,157 @@
+from abc import ABC
+import json
+import logging
+import os
+import ast
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForQuestionAnswering,AutoModelForTokenClassification
+
+from ts.torch_handler.base_handler import BaseHandler
+
+logger = logging.getLogger(__name__)
+
+
+class TransformersSeqClassifierHandler(BaseHandler, ABC):
+ """
+ Transformers handler class for sequence, token classification and question answering.
+ """
+ def __init__(self):
+ super(TransformersSeqClassifierHandler, self).__init__()
+ self.initialized = False
+
+ def initialize(self, ctx):
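+        """ Loads the serialized model checkpoint, setup_config.json, the tokenizer and,
+            for the classification modes, the index_to_name.json mapping file.
+        """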
+ self.manifest = ctx.manifest
+
+ properties = ctx.system_properties
+ model_dir = properties.get("model_dir")
+ self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
+ #read configs for the mode, model_name, etc. from setup_config.json
+ setup_config_path = os.path.join(model_dir, "setup_config.json")
+
+ if os.path.isfile(setup_config_path):
+ with open(setup_config_path) as setup_config_file:
+ self.setup_config = json.load(setup_config_file)
+ else:
+ logger.warning('Missing the setup_config.json file.')
+
+ #Loading the model and tokenizer from checkpoint and config files based on the user's choice of mode
+ #further setup config can be added.
+
+ if self.setup_config["mode"]== "sequence_classification":
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
+ elif self.setup_config["mode"]== "question_answering":
+ self.model = AutoModelForQuestionAnswering.from_pretrained(model_dir)
+ elif self.setup_config["mode"]== "token_classification":
+ self.model = AutoModelForTokenClassification.from_pretrained(model_dir)
+ else:
+ logger.warning('Missing the operation mode.')
+
+ if not os.path.isfile(os.path.join(model_dir, "vocab.txt")):
+ self.tokenizer = AutoTokenizer.from_pretrained(self.setup_config["model_name"],do_lower_case=self.setup_config["do_lower_case"])
+ else:
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir,do_lower_case=self.setup_config["do_lower_case"])
+
+ self.model.to(self.device)
+ self.model.eval()
+
+ logger.debug('Transformer model from path {0} loaded successfully'.format(model_dir))
+
+ # Read the mapping file, index to object name
+ mapping_file_path = os.path.join(model_dir, "index_to_name.json")
+ # Question answering does not need the index_to_name.json file.
+ if not self.setup_config["mode"]== "question_answering":
+ if os.path.isfile(mapping_file_path):
+ with open(mapping_file_path) as f:
+ self.mapping = json.load(f)
+ else:
+ logger.warning('Missing the index_to_name.json file.')
+
+ self.initialized = True
+
+ def preprocess(self, data):
+ """ Basic text preprocessing, based on the user's chocie of application mode.
+ """
+ text = data[0].get("data")
+ if text is None:
+ text = data[0].get("body")
+ input_text = text.decode('utf-8')
+ logger.info("Received text: '%s'", input_text)
+ #preprocessing text for sequence_classification and token_classification.
+ if self.setup_config["mode"]== "sequence_classification" or self.setup_config["mode"]== "token_classification" :
+ inputs = self.tokenizer.encode_plus(input_text, add_special_tokens = True, return_tensors = 'pt')
+ #preprocessing text for question_answering.
+ elif self.setup_config["mode"]== "question_answering":
+            # the sample text for question_answering should be formatted as a dictionary
+            # with "question" and "context" as keys and the related text as values.
+            # We use this format here to separate the question and the context for encoding.
+            # TODO: extend to handle multiple questions, cleaning, and dealing with long contexts.
+ question_context= ast.literal_eval(input_text)
+ question = question_context["question"]
+ context = question_context["context"]
+ inputs = self.tokenizer.encode_plus(question, context, add_special_tokens=True, return_tensors="pt")
+
+ return inputs
+
+ def inference(self, inputs):
+ """ Predict the class (or classes) of the received text using the serialized transformers checkpoint.
+ """
+
+
+ input_ids = inputs["input_ids"].to(self.device)
+ # Handling inference for sequence_classification.
+ if self.setup_config["mode"]== "sequence_classification":
+ predictions = self.model(input_ids)
+ prediction = predictions[0].argmax(1).item()
+
+ logger.info("Model predicted: '%s'", prediction)
+
+ if self.mapping:
+ prediction = self.mapping[str(prediction)]
+ # Handling inference for question_answering.
+ elif self.setup_config["mode"]== "question_answering":
+            # the output should be only answer_start and answer_end;
+            # we are outputting the words just for demonstration.
+ answer_start_scores, answer_end_scores = self.model(input_ids)
+ answer_start = torch.argmax(answer_start_scores) # Get the most likely beginning of answer with the argmax of the score
+ answer_end = torch.argmax(answer_end_scores) + 1 # Get the most likely end of answer with the argmax of the score
+ input_ids = inputs["input_ids"].tolist()[0]
+ prediction = self.tokenizer.convert_tokens_to_string(self.tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
+
+ logger.info("Model predicted: '%s'", prediction)
+ # Handling inference for token_classification.
+ elif self.setup_config["mode"]== "token_classification":
+ outputs = self.model(input_ids)[0]
+ predictions = torch.argmax(outputs, dim=2)
+ tokens = self.tokenizer.tokenize(self.tokenizer.decode(inputs["input_ids"][0]))
+ if self.mapping:
+ label_list = self.mapping["label_list"]
+ label_list = label_list.strip('][').split(', ')
+ prediction = [(token, label_list[prediction]) for token, prediction in zip(tokens, predictions[0].tolist())]
+
+ logger.info("Model predicted: '%s'", prediction)
+
+ return [prediction]
+
+ def postprocess(self, inference_output):
+ # TODO: Add any needed post-processing of the model predictions here
+ return inference_output
+
+
+_service = TransformersSeqClassifierHandler()
+
+
+def handle(data, context):
+ try:
+ if not _service.initialized:
+ _service.initialize(context)
+
+ if data is None:
+ return None
+
+ data = _service.preprocess(data)
+ data = _service.inference(data)
+ data = _service.postprocess(data)
+
+ return data
+ except Exception as e:
+ raise e
diff --git a/examples/Huggingface_Transformers/index_to_name.json b/examples/Huggingface_Transformers/index_to_name.json
new file mode 100644
index 0000000000..9ccff719f6
--- /dev/null
+++ b/examples/Huggingface_Transformers/index_to_name.json
@@ -0,0 +1,4 @@
+{
+ "0":"Not Accepted",
+ "1":"Accepted"
+}
diff --git a/examples/Huggingface_Transformers/sample_text.txt b/examples/Huggingface_Transformers/sample_text.txt
new file mode 100644
index 0000000000..e011e0184b
--- /dev/null
+++ b/examples/Huggingface_Transformers/sample_text.txt
@@ -0,0 +1 @@
+{"question" :"Who was Jim Henson?", "context": "Jim Henson was a nice puppet"}
diff --git a/examples/Huggingface_Transformers/setup_config.json b/examples/Huggingface_Transformers/setup_config.json
new file mode 100644
index 0000000000..c24780961e
--- /dev/null
+++ b/examples/Huggingface_Transformers/setup_config.json
@@ -0,0 +1,6 @@
+{
+ "model_name":"roberta-base",
+ "mode":"question_answering",
+ "do_lower_case":"True",
+ "num_labels":"0"
+}
diff --git a/examples/text_classification/Transformers_README.md b/examples/text_classification/Transformers_README.md
new file mode 100644
index 0000000000..15a92e6668
--- /dev/null
+++ b/examples/text_classification/Transformers_README.md
@@ -0,0 +1,46 @@
+# Serving Transformers Bert using TorchServe
+
+#### Prerequisite:
+
+```
+pip install transformers
+pip install torch torchtext
+pip install torchserve torch-model-archiver
+```
+
+```
+git clone https://github.com/HamidShojanazeri/serve.git
+
+```
+
+#### Preparing Serialized file for torch-model-archiver:
+
+```
+cd serve/examples/text_classification
+python bert/bert_serialization.py # outputs the jit.traced model "traced_bert.pt".
+mv traced_bert.pt bert/
+```
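+
+For reference, a minimal sketch of what such a serialization script could look like (the actual bert/bert_serialization.py may differ; the "bert-base-uncased" checkpoint and the dummy input below are only assumptions used for illustration):
+
+```
+import torch
+from transformers import BertForSequenceClassification, BertTokenizer
+
+# torchscript=True makes the model return tuples so it can be traced
+model = BertForSequenceClassification.from_pretrained("bert-base-uncased", torchscript=True)
+tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+model.eval()
+
+# Trace the model with a dummy input and save it as traced_bert.pt
+inputs = tokenizer.encode_plus("a sample input text", add_special_tokens=True, return_tensors="pt")
+traced_model = torch.jit.trace(model, [inputs["input_ids"]])
+torch.jit.save(traced_model, "traced_bert.pt")
+```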
+
+#### Archive the model:
+
+```
+torch-model-archiver --model-name BertSeqClassification --version 1.0 --serialized-file bert/traced_bert.pt --handler bert/bert_handler.py --extra-files ./index_to_name.json
+```
+
+```
+mkdir model_store
+mv BertSeqClassification.mar model_store/
+```
+
+#### Start TorchServe to serve the model:
+
+```
+torchserve --start --model-store model_store --models my_tc=BertSeqClassification.mar
+```
+
+#### Get predictions from a model:
+
+```
+curl -X POST http://127.0.0.1:8080/predictions/my_tc -T ./sample_text.txt
+```
+
diff --git a/examples/text_classification/sample_text.txt b/examples/text_classification/sample_text.txt
index dccbbf96f9..760903ca07 100644
--- a/examples/text_classification/sample_text.txt
+++ b/examples/text_classification/sample_text.txt
@@ -1 +1 @@
-MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75 at Royal Portrush, which considering the wind and the rain was a respectable showing. Thursday’s first round at the WGC-FedEx St. Jude Invitational was another story. With temperatures in the mid-80s and hardly any wind, the Spaniard was 13 strokes better in a flawless round. Thanks to his best putting performance on the PGA Tour, Rahm finished with an 8-under 62 for a three-stroke lead, which was even more impressive considering he’d never played the front nine at TPC Southwind.
\ No newline at end of file
+Bloomberg has decided to publish a new report on global economic situation.
diff --git a/examples/text_to_speech_synthesizer/README.md b/examples/text_to_speech_synthesizer/README.md
new file mode 100644
index 0000000000..8fac78d154
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/README.md
@@ -0,0 +1,59 @@
+# Text to speech synthesis using WaveGlow & Tacotron2 model.
+
+**This example works only on NVIDIA CUDA device and not on CPU**
+
+We have used the following Waveglow/Tacotron2 model for this example:
+
+https://pytorch.org/hub/nvidia_deeplearningexamples_waveglow/
+
+We have copied WaveGlow's model file from the following github repo:
+https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/waveglow/model.py
+
+
+# Install pip dependencies using the following commands
+
+```bash
+pip install numpy scipy unidecode inflect
+pip install librosa --user
+```
+
+# Serve the WaveGlow speech synthesis model on TorchServe
+
+ * Generate the model archive for the waveglow speech synthesis model using the following command
+
+ ```bash
+ ./create_mar.sh
+ ```
+
+ * Register the model on TorchServe using the above model archive file
+
+ ```bash
+ mkdir model_store
+ mv waveglow_synthesizer.mar model_store/
+ torchserve --start --model-store model_store --models waveglow_synthesizer.mar
+ ```
+ * Run inference and download the audio output using the curl command :
+ ```bash
+ curl -X POST http://127.0.0.1:8080/predictions/waveglow_synthesizer -T sample_text.txt -o audio.wav
+ ```
+
+ * Run inference and download the audio output using the following python script :
+
+ ```python
+ import requests
+
+ files = {'data': open('sample_text.txt','rb')}
+ response = requests.post('http://localhost:8080/predictions/waveglow_synthesizer', files=files)
+ data = response.content
+
+ with open('audio.wav', 'wb') as audio_file:
+ audio_file.write(data)
+ ```
+
+ * Change the host and port in above samples as per your server configuration.
+
+ * Response :
+ An audio.wav file gets downloaded.
+
+ **Note :** The above example works only for smaller text sizes. Refer to the following NVIDIA/DeepLearningExamples ticket for more details :
+ https://github.com/NVIDIA/DeepLearningExamples/issues/497
diff --git a/examples/text_to_speech_synthesizer/create_mar.sh b/examples/text_to_speech_synthesizer/create_mar.sh
new file mode 100755
index 0000000000..b4c7b6e8ef
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/create_mar.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+set -euxo pipefail
+
+cd /tmp
+rm -f torchhub.zip
+wget https://github.com/nvidia/DeepLearningExamples/archive/torchhub.zip
+rm -rf DeepLearningExamples-torchhub
+unzip torchhub.zip
+cd -
+rm -f tacotron.zip
+rm -rf PyTorch
+mkdir -p PyTorch/SpeechSynthesis
+cp -r /tmp/DeepLearningExamples-torchhub/PyTorch/SpeechSynthesis/* PyTorch/SpeechSynthesis/
+zip -r tacotron.zip PyTorch
+wget https://api.ngc.nvidia.com/v2/models/nvidia/tacotron2pyt_fp32/versions/1/files/nvidia_tacotron2pyt_fp32_20190306.pth
+wget https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth
+torch-model-archiver --model-name waveglow_synthesizer --version 1.0 --model-file waveglow_model.py --serialized-file nvidia_waveglowpyt_fp32_20190306.pth --handler waveglow_handler.py --extra-files tacotron.zip,nvidia_tacotron2pyt_fp32_20190306.pth
+rm nvidia_*
diff --git a/examples/text_to_speech_synthesizer/sample_text.txt b/examples/text_to_speech_synthesizer/sample_text.txt
new file mode 100644
index 0000000000..98cfee721f
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/sample_text.txt
@@ -0,0 +1 @@
+hello world, I missed you
\ No newline at end of file
diff --git a/examples/text_to_speech_synthesizer/waveglow_handler.py b/examples/text_to_speech_synthesizer/waveglow_handler.py
new file mode 100644
index 0000000000..88f1b82f88
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/waveglow_handler.py
@@ -0,0 +1,120 @@
+import logging
+import numpy as np
+import os
+import torch
+import uuid
+import zipfile
+from waveglow_model import WaveGlow
+from scipy.io.wavfile import write, read
+
+logger = logging.getLogger(__name__)
+
+
+class WaveGlowSpeechSynthesizer(object):
+
+ def __init__(self):
+ self.waveglow_model = None
+ self.tacotron2_model = None
+ self.mapping = None
+ self.device = None
+ self.initialized = False
+ self.metrics = None
+
+ # From https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/inference.py
+ def _unwrap_distributed(self, state_dict):
+ """
+ Unwraps model from DistributedDataParallel.
+ DDP wraps model in additional "module.", it needs to be removed for single
+ GPU inference.
+ :param state_dict: model's state dict
+ """
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ new_key = key.replace('module.', '')
+ new_state_dict[new_key] = value
+ return new_state_dict
+
+ def _load_tacotron2_model(self, model_dir):
+ from PyTorch.SpeechSynthesis.Tacotron2.tacotron2 import model as tacotron2
+ from PyTorch.SpeechSynthesis.Tacotron2.tacotron2.text import text_to_sequence
+ tacotron2_checkpoint = torch.load(os.path.join(model_dir, 'nvidia_tacotron2pyt_fp32_20190306.pth'))
+ tacotron2_state_dict = self._unwrap_distributed(tacotron2_checkpoint['state_dict'])
+ tacotron2_config = tacotron2_checkpoint['config']
+ self.tacotron2_model = tacotron2.Tacotron2(**tacotron2_config)
+ self.tacotron2_model.load_state_dict(tacotron2_state_dict)
+ self.tacotron2_model.text_to_sequence = text_to_sequence
+ self.tacotron2_model.to(self.device)
+
+ def initialize(self, ctx):
+ """First try to load torchscript else load eager mode state_dict based model"""
+
+ properties = ctx.system_properties
+ model_dir = properties.get("model_dir")
+ if not torch.cuda.is_available():
+ raise RuntimeError("This model is not supported on CPU machines.")
+ self.device = torch.device("cuda:" + str(properties.get("gpu_id")))
+
+ with zipfile.ZipFile(model_dir + '/tacotron.zip', 'r') as zip_ref:
+ zip_ref.extractall(model_dir)
+
+ waveglow_checkpoint = torch.load(os.path.join(model_dir, "nvidia_waveglowpyt_fp32_20190306.pth"))
+ waveglow_state_dict = self._unwrap_distributed(waveglow_checkpoint['state_dict'])
+ waveglow_config = waveglow_checkpoint['config']
+ self.waveglow_model = WaveGlow(**waveglow_config)
+ self.waveglow_model.load_state_dict(waveglow_state_dict)
+ self.waveglow_model = self.waveglow_model.remove_weightnorm(self.waveglow_model)
+ self.waveglow_model.to(self.device)
+ self.waveglow_model.eval()
+
+ self._load_tacotron2_model(model_dir)
+
+ logger.debug('WaveGlow model file loaded successfully')
+ self.initialized = True
+
+ def preprocess(self, data):
+ """
+ Scales, crops, and normalizes a PIL image for a MNIST model,
+ returns an Numpy array
+ """
+ text = data[0].get("data")
+ if text is None:
+ text = data[0].get("body")
+ text = text.decode('utf-8')
+
+ sequence = np.array(self.tacotron2_model.text_to_sequence(text, ['english_cleaners']))[None, :]
+ sequence = torch.from_numpy(sequence).to(device=self.device, dtype=torch.int64)
+
+ return sequence
+
+ def inference(self, data):
+ with torch.no_grad():
+ _, mel, _, _ = self.tacotron2_model.infer(data)
+ audio = self.waveglow_model.infer(mel)
+
+ return audio
+
+ def postprocess(self, inference_output):
+ audio_numpy = inference_output[0].data.cpu().numpy()
+ path = "/tmp/{}.wav".format(uuid.uuid4().hex)
+ write(path, 22050, audio_numpy)
+ with open(path, 'rb') as output:
+ data = output.read()
+ os.remove(path)
+ return [data]
+
+
+_service = WaveGlowSpeechSynthesizer()
+
+
+def handle(data, context):
+ if not _service.initialized:
+ _service.initialize(context)
+
+ if data is None:
+ return None
+
+ data = _service.preprocess(data)
+ data = _service.inference(data)
+ data = _service.postprocess(data)
+
+ return data
diff --git a/examples/text_to_speech_synthesizer/waveglow_model.py b/examples/text_to_speech_synthesizer/waveglow_model.py
new file mode 100644
index 0000000000..c799709a87
--- /dev/null
+++ b/examples/text_to_speech_synthesizer/waveglow_model.py
@@ -0,0 +1,292 @@
+# *****************************************************************************
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of the NVIDIA CORPORATION nor the
+# names of its contributors may be used to endorse or promote products
+# derived from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# *****************************************************************************
+import torch
+from torch.autograd import Variable
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+
+class Invertible1x1Conv(torch.nn.Module):
+ """
+ The layer outputs both the convolution, and the log determinant
+ of its weight matrix. If reverse=True it does convolution with
+ inverse
+ """
+
+ def __init__(self, c):
+ super(Invertible1x1Conv, self).__init__()
+ self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
+ bias=False)
+
+ # Sample a random orthonormal matrix to initialize weights
+ W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
+
+ # Ensure determinant is 1.0 not -1.0
+ if torch.det(W) < 0:
+ W[:, 0] = -1 * W[:, 0]
+ W = W.view(c, c, 1)
+ self.conv.weight.data = W
+
+ def forward(self, z, reverse=False):
+ # shape
+ batch_size, group_size, n_of_groups = z.size()
+
+ W = self.conv.weight.squeeze()
+
+ if reverse:
+ if not hasattr(self, 'W_inverse'):
+ # Reverse computation
+ W_inverse = W.float().inverse()
+ W_inverse = Variable(W_inverse[..., None])
+ if z.type() == 'torch.cuda.HalfTensor' or z.type() == 'torch.HalfTensor':
+ W_inverse = W_inverse.half()
+ self.W_inverse = W_inverse
+ z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
+ return z
+ else:
+ # Forward computation
+ log_det_W = batch_size * n_of_groups * torch.logdet(W.unsqueeze(0).float()).squeeze()
+ z = self.conv(z)
+ return z, log_det_W
+
+
+class WN(torch.nn.Module):
+ """
+ This is the WaveNet like layer for the affine coupling. The primary
+ difference from WaveNet is the convolutions need not be causal. There is
+ also no dilation size reset. The dilation only doubles on each layer
+ """
+
+ def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
+ kernel_size):
+ super(WN, self).__init__()
+ assert(kernel_size % 2 == 1)
+ assert(n_channels % 2 == 0)
+ self.n_layers = n_layers
+ self.n_channels = n_channels
+ self.in_layers = torch.nn.ModuleList()
+ self.res_skip_layers = torch.nn.ModuleList()
+ self.cond_layers = torch.nn.ModuleList()
+
+ start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
+ start = torch.nn.utils.weight_norm(start, name='weight')
+ self.start = start
+
+ # Initializing last layer to 0 makes the affine coupling layers
+ # do nothing at first. This helps with training stability
+ end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1)
+ end.weight.data.zero_()
+ end.bias.data.zero_()
+ self.end = end
+
+ for i in range(n_layers):
+ dilation = 2 ** i
+ padding = int((kernel_size * dilation - dilation) / 2)
+ in_layer = torch.nn.Conv1d(n_channels, 2 * n_channels, kernel_size,
+ dilation=dilation, padding=padding)
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+ self.in_layers.append(in_layer)
+
+ cond_layer = torch.nn.Conv1d(n_mel_channels, 2 * n_channels, 1)
+ cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+ self.cond_layers.append(cond_layer)
+
+ # last one is not necessary
+ if i < n_layers - 1:
+ res_skip_channels = 2 * n_channels
+ else:
+ res_skip_channels = n_channels
+ res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
+ res_skip_layer = torch.nn.utils.weight_norm(
+ res_skip_layer, name='weight')
+ self.res_skip_layers.append(res_skip_layer)
+
+ def forward(self, forward_input):
+ audio, spect = forward_input
+ audio = self.start(audio)
+
+ for i in range(self.n_layers):
+ acts = fused_add_tanh_sigmoid_multiply(
+ self.in_layers[i](audio),
+ self.cond_layers[i](spect),
+ torch.IntTensor([self.n_channels]))
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if i < self.n_layers - 1:
+ audio = res_skip_acts[:, :self.n_channels, :] + audio
+ skip_acts = res_skip_acts[:, self.n_channels:, :]
+ else:
+ skip_acts = res_skip_acts
+
+ if i == 0:
+ output = skip_acts
+ else:
+ output = skip_acts + output
+ return self.end(output)
+
+
+class WaveGlow(torch.nn.Module):
+ def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
+ n_early_size, WN_config):
+ super(WaveGlow, self).__init__()
+
+ self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
+ n_mel_channels,
+ 1024, stride=256)
+ assert(n_group % 2 == 0)
+ self.n_flows = n_flows
+ self.n_group = n_group
+ self.n_early_every = n_early_every
+ self.n_early_size = n_early_size
+ self.WN = torch.nn.ModuleList()
+ self.convinv = torch.nn.ModuleList()
+
+ n_half = int(n_group / 2)
+
+ # Set up layers with the right sizes based on how many dimensions
+ # have been output already
+ n_remaining_channels = n_group
+ for k in range(n_flows):
+ if k % self.n_early_every == 0 and k > 0:
+ n_half = n_half - int(self.n_early_size / 2)
+ n_remaining_channels = n_remaining_channels - self.n_early_size
+ self.convinv.append(Invertible1x1Conv(n_remaining_channels))
+ self.WN.append(WN(n_half, n_mel_channels * n_group, **WN_config))
+ self.n_remaining_channels = n_remaining_channels
+
+ def forward(self, forward_input):
+ """
+ forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
+ forward_input[1] = audio: batch x time
+ """
+ spect, audio = forward_input
+
+ # Upsample spectrogram to size of audio
+ spect = self.upsample(spect)
+ assert(spect.size(2) >= audio.size(1))
+ if spect.size(2) > audio.size(1):
+ spect = spect[:, :, :audio.size(1)]
+
+ spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
+ spect = spect.contiguous().view(spect.size(0), spect.size(1), -1)
+ spect = spect.permute(0, 2, 1)
+
+ audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
+ output_audio = []
+ log_s_list = []
+ log_det_W_list = []
+
+ for k in range(self.n_flows):
+ if k % self.n_early_every == 0 and k > 0:
+ output_audio.append(audio[:, :self.n_early_size, :])
+ audio = audio[:, self.n_early_size:, :]
+
+ audio, log_det_W = self.convinv[k](audio)
+ log_det_W_list.append(log_det_W)
+
+ n_half = int(audio.size(1) / 2)
+ audio_0 = audio[:, :n_half, :]
+ audio_1 = audio[:, n_half:, :]
+
+ output = self.WN[k]((audio_0, spect))
+ log_s = output[:, n_half:, :]
+ b = output[:, :n_half, :]
+ audio_1 = torch.exp(log_s) * audio_1 + b
+ log_s_list.append(log_s)
+
+ audio = torch.cat([audio_0, audio_1], 1)
+
+ output_audio.append(audio)
+ return torch.cat(output_audio, 1), log_s_list, log_det_W_list
+
+ def infer(self, spect, sigma=1.0):
+
+ spect = self.upsample(spect)
+ # trim conv artifacts. maybe pad spec to kernel multiple
+ time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
+ spect = spect[:, :, :-time_cutoff]
+
+ spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
+ spect = spect.contiguous().view(spect.size(0), spect.size(1), -1)
+ spect = spect.permute(0, 2, 1)
+
+ audio = torch.randn(spect.size(0),
+ self.n_remaining_channels,
+ spect.size(2), device=spect.device).to(spect.dtype)
+
+ audio = torch.autograd.Variable(sigma * audio)
+
+ for k in reversed(range(self.n_flows)):
+ n_half = int(audio.size(1) / 2)
+ audio_0 = audio[:, :n_half, :]
+ audio_1 = audio[:, n_half:, :]
+
+ output = self.WN[k]((audio_0, spect))
+ s = output[:, n_half:, :]
+ b = output[:, :n_half, :]
+ audio_1 = (audio_1 - b) / torch.exp(s)
+ audio = torch.cat([audio_0, audio_1], 1)
+
+ audio = self.convinv[k](audio, reverse=True)
+
+ if k % self.n_early_every == 0 and k > 0:
+ z = torch.randn(spect.size(0), self.n_early_size, spect.size(
+ 2), device=spect.device).to(spect.dtype)
+ audio = torch.cat((sigma * z, audio), 1)
+
+ audio = audio.permute(
+ 0, 2, 1).contiguous().view(
+ audio.size(0), -1).data
+ return audio
+
+
+ @staticmethod
+ def remove_weightnorm(model):
+ waveglow = model
+ for WN in waveglow.WN:
+ WN.start = torch.nn.utils.remove_weight_norm(WN.start)
+ WN.in_layers = remove(WN.in_layers)
+ WN.cond_layers = remove(WN.cond_layers)
+ WN.res_skip_layers = remove(WN.res_skip_layers)
+ return waveglow
+
+
+def remove(conv_list):
+ new_conv_list = torch.nn.ModuleList()
+ for old_conv in conv_list:
+ old_conv = torch.nn.utils.remove_weight_norm(old_conv)
+ new_conv_list.append(old_conv)
+ return new_conv_list
diff --git a/frontend/build.gradle b/frontend/build.gradle
index 105a3b62aa..8c9bc4d7ba 100644
--- a/frontend/build.gradle
+++ b/frontend/build.gradle
@@ -30,7 +30,7 @@ def javaProjects() {
}
configure(javaProjects()) {
- apply plugin: 'java'
+ apply plugin: 'java-library'
sourceCompatibility = 1.8
targetCompatibility = 1.8
@@ -42,7 +42,7 @@ configure(javaProjects()) {
test {
useTestNG() {
- // suiteXmlFiles << new File(rootDir, "testng.xml") //This is how to add custom testng.xml
+ suites 'testng.xml'
}
testLogging {
diff --git a/frontend/gradle.properties b/frontend/gradle.properties
index 97d3803c09..8c7e8ceab5 100644
--- a/frontend/gradle.properties
+++ b/frontend/gradle.properties
@@ -5,5 +5,5 @@ slf4j_api_version=1.7.25
slf4j_log4j12_version=1.7.25
gson_version=2.8.5
commons_cli_version=1.3.1
-testng_version=6.8.1
-torchserve_sdk_version=0.0.3
\ No newline at end of file
+testng_version=7.1.0
+torchserve_sdk_version=0.0.3
diff --git a/frontend/gradle/wrapper/gradle-wrapper.jar b/frontend/gradle/wrapper/gradle-wrapper.jar
index 6ffa237849..62d4c05355 100644
Binary files a/frontend/gradle/wrapper/gradle-wrapper.jar and b/frontend/gradle/wrapper/gradle-wrapper.jar differ
diff --git a/frontend/gradle/wrapper/gradle-wrapper.properties b/frontend/gradle/wrapper/gradle-wrapper.properties
index 57c147560f..4c5803d13c 100644
--- a/frontend/gradle/wrapper/gradle-wrapper.properties
+++ b/frontend/gradle/wrapper/gradle-wrapper.properties
@@ -1,6 +1,5 @@
-#Thu Apr 13 16:20:04 PDT 2017
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
+distributionUrl=https\://services.gradle.org/distributions/gradle-6.4-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-6.2.2-bin.zip
diff --git a/frontend/gradlew b/frontend/gradlew
index 9aa616c273..fbd7c51583 100755
--- a/frontend/gradlew
+++ b/frontend/gradlew
@@ -1,4 +1,20 @@
-#!/usr/bin/env bash
+#!/usr/bin/env sh
+
+#
+# Copyright 2015 the original author or authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
##############################################################################
##
@@ -28,16 +44,16 @@ APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
-DEFAULT_JVM_OPTS=""
+DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
-warn ( ) {
+warn () {
echo "$*"
}
-die ( ) {
+die () {
echo
echo "$*"
echo
@@ -66,6 +82,7 @@ esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
@@ -109,10 +126,11 @@ if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
-# For Cygwin, switch paths to Windows format before running java
-if $cygwin ; then
+# For Cygwin or MSYS, switch paths to Windows format before running java
+if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
@@ -138,32 +156,30 @@ if $cygwin ; then
else
eval `echo args$i`="\"$arg\""
fi
- i=$((i+1))
+ i=`expr $i + 1`
done
case $i in
- (0) set -- ;;
- (1) set -- "$args0" ;;
- (2) set -- "$args0" "$args1" ;;
- (3) set -- "$args0" "$args1" "$args2" ;;
- (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
- (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
- (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
- (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
- (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
- (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+ 0) set -- ;;
+ 1) set -- "$args0" ;;
+ 2) set -- "$args0" "$args1" ;;
+ 3) set -- "$args0" "$args1" "$args2" ;;
+ 4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+ 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+ 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+ 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+ 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+ 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac
fi
-# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules
-function splitJvmOpts() {
- JVM_OPTS=("$@")
+# Escape application args
+save () {
+ for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
+ echo " "
}
-eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS
-JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME"
+APP_ARGS=`save "$@"`
-# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
-if [[ "$(uname)" == "Darwin" ]] && [[ "$HOME" == "$PWD" ]]; then
- cd "$(dirname "$0")"
-fi
+# Collect all arguments for the java command, following the shell quoting and substitution rules
+eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
-exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@"
+exec "$JAVACMD" "$@"
diff --git a/frontend/gradlew.bat b/frontend/gradlew.bat
index e95643d6a2..a9f778a7a9 100755
--- a/frontend/gradlew.bat
+++ b/frontend/gradlew.bat
@@ -1,3 +1,19 @@
+@rem
+@rem Copyright 2015 the original author or authors.
+@rem
+@rem Licensed under the Apache License, Version 2.0 (the "License");
+@rem you may not use this file except in compliance with the License.
+@rem You may obtain a copy of the License at
+@rem
+@rem https://www.apache.org/licenses/LICENSE-2.0
+@rem
+@rem Unless required by applicable law or agreed to in writing, software
+@rem distributed under the License is distributed on an "AS IS" BASIS,
+@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+@rem See the License for the specific language governing permissions and
+@rem limitations under the License.
+@rem
+
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@@ -13,8 +29,11 @@ if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
+@rem Resolve any "." and ".." in APP_HOME to make it shorter.
+for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
+
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
-set DEFAULT_JVM_OPTS=
+set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
@@ -65,6 +84,7 @@ set CMD_LINE_ARGS=%*
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
diff --git a/frontend/modelarchive/build.gradle b/frontend/modelarchive/build.gradle
index b080f62765..b8d6e992a4 100644
--- a/frontend/modelarchive/build.gradle
+++ b/frontend/modelarchive/build.gradle
@@ -1,9 +1,9 @@
dependencies {
- compile "commons-io:commons-io:2.6"
- compile "org.slf4j:slf4j-api:${slf4j_api_version}"
- compile "org.slf4j:slf4j-log4j12:${slf4j_log4j12_version}"
- compile "com.google.code.gson:gson:${gson_version}"
+ api "commons-io:commons-io:2.6"
+ api "org.slf4j:slf4j-api:${slf4j_api_version}"
+ api "org.slf4j:slf4j-log4j12:${slf4j_log4j12_version}"
+ api "com.google.code.gson:gson:${gson_version}"
- testCompile "commons-cli:commons-cli:${commons_cli_version}"
- testCompile "org.testng:testng:${testng_version}"
+ testImplementation "commons-cli:commons-cli:${commons_cli_version}"
+ testImplementation "org.testng:testng:${testng_version}"
}
diff --git a/frontend/modelarchive/testng.xml b/frontend/modelarchive/testng.xml
new file mode 100644
index 0000000000..b34b60ac4b
--- /dev/null
+++ b/frontend/modelarchive/testng.xml
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/server/build.gradle b/frontend/server/build.gradle
index 75b3a80aa9..6f4d114340 100644
--- a/frontend/server/build.gradle
+++ b/frontend/server/build.gradle
@@ -1,9 +1,9 @@
dependencies {
- compile "io.netty:netty-all:${netty_version}"
- compile project(":modelarchive")
- compile "commons-cli:commons-cli:${commons_cli_version}"
- compile "org.pytorch:torchserve-plugins-sdk:${torchserve_sdk_version}"
- testCompile "org.testng:testng:${testng_version}"
+ implementation "io.netty:netty-all:${netty_version}"
+ implementation project(":modelarchive")
+ implementation "commons-cli:commons-cli:${commons_cli_version}"
+ implementation "org.pytorch:torchserve-plugins-sdk:${torchserve_sdk_version}"
+ testImplementation "org.testng:testng:${testng_version}"
}
apply from: file("${project.rootProject.projectDir}/tools/gradle/launcher.gradle")
@@ -13,7 +13,7 @@ jar {
attributes 'Main-Class': 'org.pytorch.serve.ModelServer'
}
includeEmptyDirs = false
- from configurations.runtime.collect { it.isDirectory() ? it : zipTree(it) }
+ from configurations.runtimeClasspath.collect { it.isDirectory() ? it : zipTree(it) }
exclude "META-INF/maven/**"
exclude "META-INF/INDEX.LIST"
diff --git a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java
index 3a67719f8f..bc8cafb599 100644
--- a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java
+++ b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java
@@ -44,7 +44,6 @@
import org.testng.annotations.Test;
public class ModelServerTest {
-
private static final String ERROR_NOT_FOUND =
"Requested resource is not found, please refer to API document.";
private static final String ERROR_METHOD_NOT_ALLOWED =
@@ -92,157 +91,8 @@ public void afterSuite() {
}
@Test
- public void test()
- throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException,
- IOException, NoSuchFieldException, IllegalAccessException {
- Channel channel = null;
- Channel managementChannel = null;
- for (int i = 0; i < 5; ++i) {
- channel = TestUtils.connect(false, configManager);
- if (channel != null) {
- break;
- }
- Thread.sleep(100);
- }
-
- for (int i = 0; i < 5; ++i) {
- managementChannel = TestUtils.connect(true, configManager);
- if (managementChannel != null) {
- break;
- }
- Thread.sleep(100);
- }
-
- Assert.assertNotNull(channel, "Failed to connect to inference port.");
- Assert.assertNotNull(managementChannel, "Failed to connect to management port.");
-
- testPing(channel);
-
- testRoot(channel, listInferenceApisResult);
- testRoot(managementChannel, listManagementApisResult);
- testApiDescription(channel, listInferenceApisResult);
- testDescribeApi(channel);
- testUnregisterModel(managementChannel, "noop", null);
- testLoadModel(managementChannel, "noop.mar", "noop_v1.0");
- testSyncScaleModel(managementChannel, "noop_v1.0", null);
- testListModels(managementChannel);
- testDescribeModel(managementChannel, "noop_v1.0", null, "1.11");
- testLoadModelWithInitialWorkers(managementChannel, "noop.mar", "noop");
- testLoadModelWithInitialWorkers(managementChannel, "noop.mar", "noopversioned");
- testLoadModelWithInitialWorkers(managementChannel, "noop_v2.mar", "noopversioned");
- testDescribeModel(managementChannel, "noopversioned", null, "1.11");
- testDescribeModel(managementChannel, "noopversioned", "all", "1.2.1");
- testDescribeModel(managementChannel, "noopversioned", "1.11", "1.11");
- testPredictions(channel, "noopversioned", "OK", "1.2.1");
- testSetDefault(managementChannel, "noopversioned", "1.2.1");
- testLoadModelWithInitialWorkersWithJSONReqBody(managementChannel);
- testScaleModel(managementChannel);
- testPredictions(channel, "noop", "OK", null);
- testPredictionsBinary(channel);
- testPredictionsJson(channel);
- testInvocationsJson(channel);
- testInvocationsMultipart(channel);
- testModelsInvokeJson(channel);
- testModelsInvokeMultipart(channel);
- testLegacyPredict(channel);
- testPredictionsInvalidRequestSize(channel);
- testPredictionsValidRequestSize(channel);
- testPredictionsDecodeRequest(channel, managementChannel);
- testPredictionsDoNotDecodeRequest(channel, managementChannel);
- testPredictionsModifyResponseHeader(channel, managementChannel);
- testPredictionsNoManifest(channel, managementChannel);
- testModelRegisterWithDefaultWorkers(managementChannel);
- testLoadModelFromURL(managementChannel);
- testUnregisterURLModel(managementChannel);
- testLoadingMemoryError();
- testPredictionMemoryError();
- testMetricManager();
- testErrorBatch();
-
- channel.close();
- managementChannel.close();
-
- // negative test case, channel will be closed by server
- testInvalidRootRequest();
- testInvalidInferenceUri();
- testInvalidPredictionsUri();
- testInvalidDescribeModel();
- testPredictionsModelNotFound();
- testPredictionsModelVersionNotFound();
-
- testInvalidManagementUri();
- testInvalidModelsMethod();
- testInvalidModelMethod();
- testDescribeModelNotFound();
- testDescribeModelVersionNotFound();
- testRegisterModelMissingUrl();
- testRegisterModelInvalidRuntime();
- testRegisterModelNotFound();
- testRegisterModelConflict();
- testRegisterModelMalformedUrl();
- testRegisterModelConnectionFailed();
- testRegisterModelHttpError();
- testRegisterModelInvalidPath();
- testScaleModelNotFound();
- testScaleModelVersionNotFound();
- testScaleModelFailure();
- testUnregisterModelNotFound();
- testUnregisterModelVersionNotFound();
- testUnregisterModelTimeout();
- testSetInvalidVersionDefault("noopversioned", "3.3.3");
- testUnregisterModelFailure("noopversioned", "1.2.1");
-
- testTS();
- }
-
- public void testTS()
- throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException,
- IOException, NoSuchFieldException, IllegalAccessException {
- Channel channel = null;
- Channel managementChannel = null;
- for (int i = 0; i < 5; ++i) {
- channel = TestUtils.connect(false, configManager, 300);
- if (channel != null) {
- break;
- }
- Thread.sleep(100);
- }
-
- for (int i = 0; i < 5; ++i) {
- managementChannel = TestUtils.connect(true, configManager, 300);
- if (managementChannel != null) {
- break;
- }
- Thread.sleep(100);
- }
-
- Assert.assertNotNull(channel, "Failed to connect to inference port.");
- Assert.assertNotNull(managementChannel, "Failed to connect to management port.");
-
- testLoadModelWithInitialWorkers(managementChannel, "mnist.mar", "mnist");
- testPredictions(channel, "mnist", "0", null);
- testUnregisterModel(managementChannel, "mnist", null);
- testLoadModelWithInitialWorkers(managementChannel, "mnist_scripted.mar", "mnist_scripted");
- testPredictions(channel, "mnist_scripted", "0", null);
- testUnregisterModel(managementChannel, "mnist_scripted", null);
- testLoadModelWithInitialWorkers(managementChannel, "mnist_traced.mar", "mnist_traced");
- testPredictions(channel, "mnist_traced", "0", null);
- testUnregisterModel(managementChannel, "mnist_traced", null);
-
- channel.close();
- managementChannel.close();
- }
-
- private void testRoot(Channel channel, String expected) throws InterruptedException {
- TestUtils.setResult(null);
- TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.getRoot(channel);
- TestUtils.getLatch().await();
-
- Assert.assertEquals(TestUtils.getResult(), expected);
- }
-
- private void testPing(Channel channel) throws InterruptedException {
+ public void testPing() throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/ping");
@@ -254,259 +104,240 @@ private void testPing(Channel channel) throws InterruptedException {
Assert.assertTrue(TestUtils.getHeaders().contains("x-request-id"));
}
- private void testApiDescription(Channel channel, String expected) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPing"})
+ public void testRootInference() throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.getApiDescription(channel);
+ TestUtils.getRoot(channel);
TestUtils.getLatch().await();
- Assert.assertEquals(TestUtils.getResult(), expected);
+ Assert.assertEquals(TestUtils.getResult(), listInferenceApisResult);
}
- private void testDescribeApi(Channel channel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testRootInference"})
+ public void testRootManagement() throws InterruptedException {
+ Channel channel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.describeModelApi(channel, "noop");
+ TestUtils.getRoot(channel);
TestUtils.getLatch().await();
- Assert.assertEquals(TestUtils.getResult(), noopApiResult);
+ Assert.assertEquals(TestUtils.getResult(), listManagementApisResult);
}
- private void testLoadModel(Channel channel, String url, String modelName)
- throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testRootManagement"})
+ public void testApiDescription() throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.registerModel(channel, url, modelName, false, false);
+ TestUtils.getApiDescription(channel);
TestUtils.getLatch().await();
- StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
- Assert.assertEquals(resp.getStatus(), "Model \"" + modelName + "\" registered");
+ Assert.assertEquals(TestUtils.getResult(), listInferenceApisResult);
}
- private void testLoadModelFromURL(Channel channel) throws InterruptedException {
- testLoadModel(
- channel,
- "https://torchserve.s3.amazonaws.com/mar_files/squeezenet1_1.mar",
- "squeezenet");
- Assert.assertTrue(new File(configManager.getModelStore(), "squeezenet1_1.mar").exists());
- }
-
- private void testUnregisterURLModel(Channel channel) throws InterruptedException {
- testUnregisterModel(channel, "squeezenet", null);
- Assert.assertTrue(!new File(configManager.getModelStore(), "squeezenet1_1.mar").exists());
- }
-
- private void testLoadModelWithInitialWorkers(Channel channel, String url, String modelName)
- throws InterruptedException {
-
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testApiDescription"})
+ public void testDescribeApi() throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.registerModel(channel, url, modelName, true, false);
+ TestUtils.describeModelApi(channel, "noop");
TestUtils.getLatch().await();
- StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
- Assert.assertEquals(resp.getStatus(), "Workers scaled");
+ Assert.assertEquals(TestUtils.getResult(), noopApiResult);
}
- private void testLoadModelWithInitialWorkersWithJSONReqBody(Channel channel)
- throws InterruptedException {
- testUnregisterModel(channel, "noop", null);
- TestUtils.setResult(null);
- TestUtils.setLatch(new CountDownLatch(1));
- DefaultFullHttpRequest req =
- new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/models");
- req.headers().add("Content-Type", "application/json");
- req.content()
- .writeCharSequence(
- "{'url':'noop.mar', 'model_name':'noop', 'initial_workers':'1', 'synchronous':'true'}",
- CharsetUtil.UTF_8);
- HttpUtil.setContentLength(req, req.content().readableBytes());
- channel.writeAndFlush(req);
- TestUtils.getLatch().await();
-
- StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
- Assert.assertEquals(resp.getStatus(), "Workers scaled");
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testDescribeApi"})
+ public void testUnregisterNoopModel() throws InterruptedException {
+ testUnregisterModel("noop", null);
}
- private void testScaleModel(Channel channel) throws InterruptedException {
- TestUtils.setResult(null);
- TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.scaleModel(channel, "noop_v1.0", null, 2, false);
- TestUtils.getLatch().await();
-
- StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
- Assert.assertEquals(resp.getStatus(), "Processing worker updates...");
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testUnregisterNoopModel"})
+ public void testLoadNoopModel() throws InterruptedException {
+ testLoadModel("noop.mar", "noop_v1.0");
}
- private void testSyncScaleModel(Channel channel, String modelName, String version)
- throws InterruptedException {
- TestUtils.setResult(null);
- TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.scaleModel(channel, modelName, version, 1, true);
-
- TestUtils.getLatch().await();
-
- StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
- Assert.assertEquals(resp.getStatus(), "Workers scaled");
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testLoadNoopModel"})
+ public void testSyncScaleNoopModel() throws InterruptedException {
+ testSyncScaleModel("noop_v1.0", null);
}
- private void testUnregisterModel(Channel channel, String modelName, String version)
- throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testSyncScaleNoopModel"})
+ public void testListModels() throws InterruptedException {
+ Channel channel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.unregisterModel(channel, modelName, version, false);
+ TestUtils.listModels(channel);
TestUtils.getLatch().await();
- StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
- Assert.assertEquals(resp.getStatus(), "Model \"" + modelName + "\" unregistered");
+ ListModelsResponse resp =
+ JsonUtils.GSON.fromJson(TestUtils.getResult(), ListModelsResponse.class);
+ Assert.assertEquals(resp.getModels().size(), 1);
}
- private void testUnregisterModelFailure(String modelName, String version)
- throws InterruptedException {
- Channel channel = TestUtils.connect(true, configManager);
- Assert.assertNotNull(channel);
- TestUtils.setResult(null);
- TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.unregisterModel(channel, modelName, version, false);
- TestUtils.getLatch().await();
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testListModels"})
+ public void testDescribeNoopModel() throws InterruptedException {
+ testDescribeModel("noop_v1.0", null, "1.11");
+ }
- ErrorResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), ErrorResponse.class);
- Assert.assertEquals(resp.getCode(), HttpResponseStatus.FORBIDDEN.code());
- Assert.assertEquals(
- resp.getMessage(), "Cannot remove default version for model " + modelName);
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testDescribeNoopModel"})
+ public void testLoadNoopModelWithInitialWorkers() throws InterruptedException {
+ testLoadModelWithInitialWorkers("noop.mar", "noop");
+ }
- channel = TestUtils.connect(true, configManager);
- Assert.assertNotNull(channel);
- testUnregisterModel(channel, "noopversioned", "1.11");
- testUnregisterModel(channel, "noopversioned", "1.2.1");
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testLoadNoopModelWithInitialWorkers"})
+ public void testLoadNoopV1ModelWithInitialWorkers() throws InterruptedException {
+ testLoadModelWithInitialWorkers("noop.mar", "noopversioned");
}
- private void testListModels(Channel channel) throws InterruptedException {
- TestUtils.setResult(null);
- TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.listModels(channel);
- TestUtils.getLatch().await();
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testLoadNoopV1ModelWithInitialWorkers"})
+ public void testLoadNoopV2ModelWithInitialWorkers() throws InterruptedException {
+ testLoadModelWithInitialWorkers("noop_v2.mar", "noopversioned");
+ }
- ListModelsResponse resp =
- JsonUtils.GSON.fromJson(TestUtils.getResult(), ListModelsResponse.class);
- Assert.assertEquals(resp.getModels().size(), 1);
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testLoadNoopV2ModelWithInitialWorkers"})
+ public void testDescribeDefaultModelVersion() throws InterruptedException {
+ testDescribeModel("noopversioned", null, "1.11");
}
- private void testDescribeModel(
- Channel channel, String modelName, String requestVersion, String expectedVersion)
- throws InterruptedException {
- TestUtils.setResult(null);
- TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.describeModel(channel, modelName, requestVersion);
- TestUtils.getLatch().await();
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testDescribeDefaultModelVersion"})
+ public void testDescribeAllModelVersion() throws InterruptedException {
+ testDescribeModel("noopversioned", "all", "1.2.1");
+ }
- DescribeModelResponse[] resp =
- JsonUtils.GSON.fromJson(TestUtils.getResult(), DescribeModelResponse[].class);
- if ("all".equals(requestVersion)) {
- Assert.assertTrue(resp.length >= 1);
- } else {
- Assert.assertTrue(resp.length == 1);
- }
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testDescribeAllModelVersion"})
+ public void testDescribeSpecificModelVersion() throws InterruptedException {
+ testDescribeModel("noopversioned", "1.11", "1.11");
+ }
- Assert.assertTrue(expectedVersion.equals(resp[0].getModelVersion()));
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testDescribeSpecificModelVersion"})
+ public void testNoopVersionedPrediction() throws InterruptedException {
+ testPredictions("noopversioned", "OK", "1.11");
}
- private void testSetDefault(Channel channel, String modelName, String defaultVersion)
- throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testNoopVersionedPrediction"})
+ public void testSetDefaultVersionNoop() throws InterruptedException {
+ Channel channel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- String requestURL = "/models/" + modelName + "/" + defaultVersion + "/set-default";
-
- HttpRequest req =
- new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, requestURL);
- channel.writeAndFlush(req);
+ TestUtils.setDefault(channel, "noopversioned", "1.2.1");
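+ // The expected status string below must match the server's response text exactly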
TestUtils.getLatch().await();
StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
Assert.assertEquals(
resp.getStatus(),
- "Default vesion succsesfully updated for model \""
- + modelName
- + "\" to \""
- + defaultVersion
- + "\"");
+ "Default vesion succsesfully updated for model \"noopversioned\" to \"1.2.1\"");
}
- private void testSetInvalidVersionDefault(String modelName, String defaultVersion)
- throws InterruptedException {
- Channel channel = TestUtils.connect(true, configManager);
- Assert.assertNotNull(channel);
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testSetDefaultVersionNoop"})
+ public void testLoadModelWithInitialWorkersWithJSONReqBody() throws InterruptedException {
+ Channel channel = TestUtils.getManagementChannel(configManager);
+ testUnregisterModel("noop", null);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- String requestURL = "/models/" + modelName + "/" + defaultVersion + "/set-default";
-
- HttpRequest req =
- new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, requestURL);
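+ // Registers the model via POST /models with a JSON request body instead of query parameters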
+ DefaultFullHttpRequest req =
+ new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/models");
+ req.headers().add("Content-Type", "application/json");
+ req.content()
+ .writeCharSequence(
+ "{'url':'noop.mar', 'model_name':'noop', 'initial_workers':'1', 'synchronous':'true'}",
+ CharsetUtil.UTF_8);
+ HttpUtil.setContentLength(req, req.content().readableBytes());
channel.writeAndFlush(req);
TestUtils.getLatch().await();
- ErrorResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), ErrorResponse.class);
- Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
- Assert.assertEquals(
- resp.getMessage(),
- "Model version " + defaultVersion + " does not exist for model " + modelName);
+ StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
+ Assert.assertEquals(resp.getStatus(), "Workers scaled");
}
- private void testPredictions(
- Channel channel, String modelName, String expectedOutput, String version)
- throws InterruptedException {
- TestUtils.setResult(null);
- TestUtils.setLatch(new CountDownLatch(1));
- String requestURL = "/predictions/" + modelName;
- if (version != null) {
- requestURL += "/" + version;
- }
- DefaultFullHttpRequest req =
- new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, requestURL);
- req.content().writeCharSequence("data=test", CharsetUtil.UTF_8);
- HttpUtil.setContentLength(req, req.content().readableBytes());
- req.headers()
- .set(
- HttpHeaderNames.CONTENT_TYPE,
- HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED);
- channel.writeAndFlush(req);
-
- TestUtils.getLatch().await();
- Assert.assertEquals(TestUtils.getResult(), expectedOutput);
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testLoadModelWithInitialWorkersWithJSONReqBody"})
+ public void testNoopPrediction() throws InterruptedException {
+ testPredictions("noop", "OK", null);
}
- private void testPredictionsJson(Channel channel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testNoopPrediction"})
+ public void testPredictionsBinary() throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
DefaultFullHttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop");
- req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
+ req.content().writeCharSequence("test", CharsetUtil.UTF_8);
HttpUtil.setContentLength(req, req.content().readableBytes());
- req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
+ req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM);
channel.writeAndFlush(req);
TestUtils.getLatch().await();
+
Assert.assertEquals(TestUtils.getResult(), "OK");
}
- private void testPredictionsBinary(Channel channel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPredictionsBinary"})
+ public void testPredictionsJson() throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
DefaultFullHttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop");
- req.content().writeCharSequence("test", CharsetUtil.UTF_8);
+ req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
HttpUtil.setContentLength(req, req.content().readableBytes());
- req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM);
+ req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
channel.writeAndFlush(req);
TestUtils.getLatch().await();
-
Assert.assertEquals(TestUtils.getResult(), "OK");
}
- private void testInvocationsJson(Channel channel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPredictionsJson"})
+ public void testInvocationsJson() throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
DefaultFullHttpRequest req =
@@ -521,9 +352,13 @@ private void testInvocationsJson(Channel channel) throws InterruptedException {
Assert.assertEquals(TestUtils.getResult(), "OK");
}
- private void testInvocationsMultipart(Channel channel)
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testInvocationsJson"})
+ public void testInvocationsMultipart()
throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException,
IOException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
DefaultFullHttpRequest req =
@@ -546,7 +381,11 @@ private void testInvocationsMultipart(Channel channel)
Assert.assertEquals(TestUtils.getResult(), "OK");
}
- private void testModelsInvokeJson(Channel channel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testInvocationsMultipart"})
+ public void testModelsInvokeJson() throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
DefaultFullHttpRequest req =
@@ -561,9 +400,13 @@ private void testModelsInvokeJson(Channel channel) throws InterruptedException {
Assert.assertEquals(TestUtils.getResult(), "OK");
}
- private void testModelsInvokeMultipart(Channel channel)
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testModelsInvokeJson"})
+ public void testModelsInvokeMultipart()
throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException,
IOException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
DefaultFullHttpRequest req =
@@ -586,7 +429,27 @@ private void testModelsInvokeMultipart(Channel channel)
Assert.assertEquals(TestUtils.getResult(), "OK");
}
- private void testPredictionsInvalidRequestSize(Channel channel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testModelsInvokeMultipart"})
+ public void testLegacyPredict() throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
+ TestUtils.setResult(null);
+ TestUtils.setLatch(new CountDownLatch(1));
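+ // Exercises the legacy GET /{model_name}/predict?data=... prediction endpoint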
+ DefaultFullHttpRequest req =
+ new DefaultFullHttpRequest(
+ HttpVersion.HTTP_1_1, HttpMethod.GET, "/noop/predict?data=test");
+ channel.writeAndFlush(req);
+
+ TestUtils.getLatch().await();
+ Assert.assertEquals(TestUtils.getResult(), "OK");
+ }
+
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testLegacyPredict"})
+ public void testPredictionsInvalidRequestSize() throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
DefaultFullHttpRequest req =
@@ -603,7 +466,11 @@ private void testPredictionsInvalidRequestSize(Channel channel) throws Interrupt
Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE);
}
- private void testPredictionsValidRequestSize(Channel channel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPredictionsInvalidRequestSize"})
+ public void testPredictionsValidRequestSize() throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
DefaultFullHttpRequest req =
@@ -620,35 +487,121 @@ private void testPredictionsValidRequestSize(Channel channel) throws Interrupted
Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
}
- private void loadTests(Channel channel, String model, String modelName)
- throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPredictionsValidRequestSize"})
+ public void testPredictionsDecodeRequest()
+ throws InterruptedException, NoSuchFieldException, IllegalAccessException {
+ Channel inferChannel = TestUtils.getInferenceChannel(configManager);
+ Channel mgmtChannel = TestUtils.getManagementChannel(configManager);
+ setConfiguration("decode_input_request", "true");
+ loadTests(mgmtChannel, "noop-v1.0-config-tests.mar", "noop-config");
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.registerModel(channel, model, modelName, true, false);
+ DefaultFullHttpRequest req =
+ new DefaultFullHttpRequest(
+ HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop-config");
+ req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
+ HttpUtil.setContentLength(req, req.content().readableBytes());
+ req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
+ inferChannel.writeAndFlush(req);
TestUtils.getLatch().await();
+
+ Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
+ Assert.assertFalse(TestUtils.getResult().contains("bytearray"));
+ unloadTests(mgmtChannel, "noop-config");
}
- private void unloadTests(Channel channel, String modelName) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPredictionsDecodeRequest"})
+ public void testPredictionsDoNotDecodeRequest()
+ throws InterruptedException, NoSuchFieldException, IllegalAccessException {
+ Channel inferChannel = TestUtils.getInferenceChannel(configManager);
+ Channel mgmtChannel = TestUtils.getManagementChannel(configManager);
+ setConfiguration("decode_input_request", "false");
+ loadTests(mgmtChannel, "noop-v1.0-config-tests.mar", "noop-config");
+
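+ // With decode_input_request=false the handler receives the raw payload, so the result is expected to contain "bytearray"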
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- String expected = "Model \"" + modelName + "\" unregistered";
- TestUtils.unregisterModel(channel, modelName, null, false);
+ DefaultFullHttpRequest req =
+ new DefaultFullHttpRequest(
+ HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop-config");
+ req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
+ HttpUtil.setContentLength(req, req.content().readableBytes());
+ req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
+ inferChannel.writeAndFlush(req);
+
TestUtils.getLatch().await();
- StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
- Assert.assertEquals(resp.getStatus(), expected);
+
+ Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
+ Assert.assertTrue(TestUtils.getResult().contains("bytearray"));
+ unloadTests(mgmtChannel, "noop-config");
}
- private void setConfiguration(String key, String val)
- throws NoSuchFieldException, IllegalAccessException {
- Field f = configManager.getClass().getDeclaredField("prop");
- f.setAccessible(true);
- Properties p = (Properties) f.get(configManager);
- p.setProperty(key, val);
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPredictionsDoNotDecodeRequest"})
+ public void testPredictionsModifyResponseHeader()
+ throws NoSuchFieldException, IllegalAccessException, InterruptedException {
+ Channel inferChannel = TestUtils.getInferenceChannel(configManager);
+ Channel mgmtChannel = TestUtils.getManagementChannel(configManager);
+ setConfiguration("decode_input_request", "false");
+ loadTests(mgmtChannel, "respheader-test.mar", "respheader");
+
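+ // The respheader model's handler adds a custom "dummy" response header and a text/plain content type; both must be forwarded to the client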
+ TestUtils.setResult(null);
+ TestUtils.setLatch(new CountDownLatch(1));
+ DefaultFullHttpRequest req =
+ new DefaultFullHttpRequest(
+ HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/respheader");
+
+ req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
+ HttpUtil.setContentLength(req, req.content().readableBytes());
+ req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
+ inferChannel.writeAndFlush(req);
+
+ TestUtils.getLatch().await();
+
+ Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
+ Assert.assertEquals(TestUtils.getHeaders().get("dummy"), "1");
+ Assert.assertEquals(TestUtils.getHeaders().get("content-type"), "text/plain");
+ Assert.assertTrue(TestUtils.getResult().contains("bytearray"));
+ unloadTests(mgmtChannel, "respheader");
}
- private void testModelRegisterWithDefaultWorkers(Channel mgmtChannel)
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPredictionsModifyResponseHeader"})
+ public void testPredictionsNoManifest()
+ throws InterruptedException, NoSuchFieldException, IllegalAccessException {
+ Channel inferChannel = TestUtils.getInferenceChannel(configManager);
+ Channel mgmtChannel = TestUtils.getManagementChannel(configManager);
+ setConfiguration("default_service_handler", "service:handle");
+ loadTests(mgmtChannel, "noop-no-manifest.mar", "nomanifest");
+ TestUtils.setResult(null);
+ TestUtils.setLatch(new CountDownLatch(1));
+ DefaultFullHttpRequest req =
+ new DefaultFullHttpRequest(
+ HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/nomanifest");
+ req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
+ HttpUtil.setContentLength(req, req.content().readableBytes());
+ req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
+ inferChannel.writeAndFlush(req);
+
+ TestUtils.getLatch().await();
+
+ Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
+ Assert.assertEquals(TestUtils.getResult(), "OK");
+ unloadTests(mgmtChannel, "nomanifest");
+ }
+
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPredictionsNoManifest"})
+ public void testModelRegisterWithDefaultWorkers()
throws NoSuchFieldException, IllegalAccessException, InterruptedException {
+ Channel mgmtChannel = TestUtils.getManagementChannel(configManager);
setConfiguration("default_workers_per_model", "1");
loadTests(mgmtChannel, "noop.mar", "noop_default_model_workers");
@@ -666,109 +619,161 @@ private void testModelRegisterWithDefaultWorkers(Channel mgmtChannel)
setConfiguration("default_workers_per_model", "0");
}
- private void testPredictionsDecodeRequest(Channel inferChannel, Channel mgmtChannel)
- throws InterruptedException, NoSuchFieldException, IllegalAccessException {
- setConfiguration("decode_input_request", "true");
- loadTests(mgmtChannel, "noop-v1.0-config-tests.mar", "noop-config");
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testModelRegisterWithDefaultWorkers"})
+ public void testLoadModelFromURL() throws InterruptedException {
+ testLoadModel(
+ "https://torchserve.s3.amazonaws.com/mar_files/squeezenet1_1.mar", "squeezenet");
+ Assert.assertTrue(new File(configManager.getModelStore(), "squeezenet1_1.mar").exists());
+ }
+
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testLoadModelFromURL"})
+ public void testUnregisterURLModel() throws InterruptedException {
+ testUnregisterModel("squeezenet", null);
+ Assert.assertTrue(!new File(configManager.getModelStore(), "squeezenet1_1.mar").exists());
+ }
+
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testUnregisterURLModel"})
+ public void testLoadingMemoryError() throws InterruptedException {
+ Channel channel = TestUtils.getManagementChannel(configManager);
+ Assert.assertNotNull(channel);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- DefaultFullHttpRequest req =
- new DefaultFullHttpRequest(
- HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop-config");
- req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
- HttpUtil.setContentLength(req, req.content().readableBytes());
- req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
- inferChannel.writeAndFlush(req);
+ TestUtils.registerModel(channel, "loading-memory-error.mar", "memory_error", true, false);
TestUtils.getLatch().await();
- Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
- Assert.assertFalse(TestUtils.getResult().contains("bytearray"));
- unloadTests(mgmtChannel, "noop-config");
+ Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.INSUFFICIENT_STORAGE);
+ channel.close();
}
- private void testPredictionsDoNotDecodeRequest(Channel inferChannel, Channel mgmtChannel)
- throws InterruptedException, NoSuchFieldException, IllegalAccessException {
- setConfiguration("decode_input_request", "false");
- loadTests(mgmtChannel, "noop-v1.0-config-tests.mar", "noop-config");
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testLoadingMemoryError"})
+ public void testPredictionMemoryError() throws InterruptedException {
+ // Load the model
+ Channel channel = TestUtils.getManagementChannel(configManager);
+ Assert.assertNotNull(channel);
+ TestUtils.setResult(null);
+ TestUtils.setLatch(new CountDownLatch(1));
+
+ TestUtils.registerModel(channel, "prediction-memory-error.mar", "pred-err", true, false);
+ TestUtils.getLatch().await();
+ Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
+ channel.close();
+ // Test for prediction
+ channel = TestUtils.connect(false, configManager);
+ Assert.assertNotNull(channel);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
DefaultFullHttpRequest req =
new DefaultFullHttpRequest(
- HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop-config");
- req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
- HttpUtil.setContentLength(req, req.content().readableBytes());
- req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
- inferChannel.writeAndFlush(req);
+ HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/pred-err");
+ req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8);
+
+ channel.writeAndFlush(req);
+ TestUtils.getLatch().await();
+
+ Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.INSUFFICIENT_STORAGE);
+ channel.close();
+
+ // Unload the model
+ channel = TestUtils.connect(true, configManager);
+ TestUtils.setHttpStatus(null);
+ TestUtils.setLatch(new CountDownLatch(1));
+ Assert.assertNotNull(channel);
+ TestUtils.unregisterModel(channel, "pred-err", null, false);
TestUtils.getLatch().await();
-
Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
- Assert.assertTrue(TestUtils.getResult().contains("bytearray"));
- unloadTests(mgmtChannel, "noop-config");
}
- private void testPredictionsModifyResponseHeader(
- Channel inferChannel, Channel managementChannel)
- throws NoSuchFieldException, IllegalAccessException, InterruptedException {
- setConfiguration("decode_input_request", "false");
- loadTests(managementChannel, "respheader-test.mar", "respheader");
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPredictionMemoryError"})
+ public void testErrorBatch() throws InterruptedException {
+ Channel channel = TestUtils.getManagementChannel(configManager);
+ Assert.assertNotNull(channel);
+ TestUtils.setHttpStatus(null);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- DefaultFullHttpRequest req =
- new DefaultFullHttpRequest(
- HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/respheader");
-
- req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
- HttpUtil.setContentLength(req, req.content().readableBytes());
- req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
- inferChannel.writeAndFlush(req);
+ TestUtils.registerModel(channel, "error_batch.mar", "err_batch", true, false);
TestUtils.getLatch().await();
- Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
- Assert.assertEquals(TestUtils.getHeaders().get("dummy"), "1");
- Assert.assertEquals(TestUtils.getHeaders().get("content-type"), "text/plain");
- Assert.assertTrue(TestUtils.getResult().contains("bytearray"));
- unloadTests(managementChannel, "respheader");
- }
+ StatusResponse status =
+ JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
+ Assert.assertEquals(status.getStatus(), "Workers scaled");
+
+ channel.close();
+
+ channel = TestUtils.connect(false, configManager);
+ Assert.assertNotNull(channel);
- private void testPredictionsNoManifest(Channel inferChannel, Channel mgmtChannel)
- throws InterruptedException, NoSuchFieldException, IllegalAccessException {
- setConfiguration("default_service_handler", "service:handle");
- loadTests(mgmtChannel, "noop-no-manifest.mar", "nomanifest");
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
+ TestUtils.setHttpStatus(null);
DefaultFullHttpRequest req =
new DefaultFullHttpRequest(
- HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/nomanifest");
- req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
+ HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/err_batch");
+ req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8);
HttpUtil.setContentLength(req, req.content().readableBytes());
- req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
- inferChannel.writeAndFlush(req);
+ req.headers()
+ .set(
+ HttpHeaderNames.CONTENT_TYPE,
+ HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED);
+ channel.writeAndFlush(req);
TestUtils.getLatch().await();
- Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
- Assert.assertEquals(TestUtils.getResult(), "OK");
- unloadTests(mgmtChannel, "nomanifest");
+ Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.INSUFFICIENT_STORAGE);
+ Assert.assertEquals(TestUtils.getResult(), "Invalid response");
}
- private void testLegacyPredict(Channel channel) throws InterruptedException {
- TestUtils.setResult(null);
- TestUtils.setLatch(new CountDownLatch(1));
- DefaultFullHttpRequest req =
- new DefaultFullHttpRequest(
- HttpVersion.HTTP_1_1, HttpMethod.GET, "/noop/predict?data=test");
- channel.writeAndFlush(req);
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testErrorBatch"})
+ public void testMetricManager() throws JsonParseException, InterruptedException {
+ MetricManager.scheduleMetrics(configManager);
+ MetricManager metricManager = MetricManager.getInstance();
+ List<Metric> metrics = metricManager.getMetrics();
- TestUtils.getLatch().await();
- Assert.assertEquals(TestUtils.getResult(), "OK");
+ // Wait until the first metric value is collected (poll up to 5 times, 500 ms apart)
+ int count = 0;
+ while (metrics.isEmpty()) {
+ Thread.sleep(500);
+ metrics = metricManager.getMetrics();
+ Assert.assertTrue(++count < 5);
+ }
+ for (Metric metric : metrics) {
+ if (metric.getMetricName().equals("CPUUtilization")) {
+ Assert.assertEquals(metric.getUnit(), "Percent");
+ }
+ if (metric.getMetricName().equals("MemoryUsed")) {
+ Assert.assertEquals(metric.getUnit(), "Megabytes");
+ }
+ if (metric.getMetricName().equals("DiskUsed")) {
+ List<Dimension> dimensions = metric.getDimensions();
+ for (Dimension dimension : dimensions) {
+ if (dimension.getName().equals("Level")) {
+ Assert.assertEquals(dimension.getValue(), "Host");
+ }
+ }
+ }
+ }
}
- private void testInvalidRootRequest() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testMetricManager"})
+ public void testInvalidRootRequest() throws InterruptedException {
Channel channel = TestUtils.connect(false, configManager);
Assert.assertNotNull(channel);
@@ -782,7 +787,10 @@ private void testInvalidRootRequest() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), ERROR_METHOD_NOT_ALLOWED);
}
- private void testInvalidInferenceUri() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testInvalidRootRequest"})
+ public void testInvalidInferenceUri() throws InterruptedException {
Channel channel = TestUtils.connect(false, configManager);
Assert.assertNotNull(channel);
@@ -797,7 +805,10 @@ private void testInvalidInferenceUri() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), ERROR_NOT_FOUND);
}
- private void testInvalidDescribeModel() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testInvalidInferenceUri"})
+ public void testInvalidDescribeModel() throws InterruptedException {
Channel channel = TestUtils.connect(false, configManager);
Assert.assertNotNull(channel);
@@ -810,7 +821,10 @@ private void testInvalidDescribeModel() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), "Model not found: InvalidModel");
}
- private void testInvalidPredictionsUri() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testInvalidDescribeModel"})
+ public void testInvalidPredictionsUri() throws InterruptedException {
Channel channel = TestUtils.connect(false, configManager);
Assert.assertNotNull(channel);
@@ -825,7 +839,10 @@ private void testInvalidPredictionsUri() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), ERROR_NOT_FOUND);
}
- private void testPredictionsModelNotFound() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testInvalidPredictionsUri"})
+ public void testPredictionsModelNotFound() throws InterruptedException {
Channel channel = TestUtils.connect(false, configManager);
Assert.assertNotNull(channel);
@@ -841,9 +858,13 @@ private void testPredictionsModelNotFound() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), "Model not found: InvalidModel");
}
- private void testPredictionsModelVersionNotFound() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPredictionsModelNotFound"})
+ public void testPredictionsModelVersionNotFound() throws InterruptedException {
Channel channel = TestUtils.connect(false, configManager);
Assert.assertNotNull(channel);
+
HttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.GET, "/predictions/noopversioned/1.3.1");
@@ -857,7 +878,10 @@ private void testPredictionsModelVersionNotFound() throws InterruptedException {
resp.getMessage(), "Model version: 1.3.1 does not exist for model: noopversioned");
}
- private void testInvalidManagementUri() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPredictionsModelNotFound"})
+ public void testInvalidManagementUri() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -872,7 +896,10 @@ private void testInvalidManagementUri() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), ERROR_NOT_FOUND);
}
- private void testInvalidModelsMethod() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testInvalidManagementUri"})
+ public void testInvalidModelsMethod() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -887,7 +914,10 @@ private void testInvalidModelsMethod() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), ERROR_METHOD_NOT_ALLOWED);
}
- private void testInvalidModelMethod() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testInvalidModelsMethod"})
+ public void testInvalidModelMethod() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -902,7 +932,10 @@ private void testInvalidModelMethod() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), ERROR_METHOD_NOT_ALLOWED);
}
- private void testDescribeModelNotFound() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testInvalidModelMethod"})
+ public void testDescribeModelNotFound() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -918,9 +951,13 @@ private void testDescribeModelNotFound() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), "Model not found: InvalidModel");
}
- private void testDescribeModelVersionNotFound() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testDescribeModelNotFound"})
+ public void testDescribeModelVersionNotFound() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
+
HttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.GET, "/models/noopversioned/1.3.1");
@@ -928,12 +965,16 @@ private void testDescribeModelVersionNotFound() throws InterruptedException {
channel.closeFuture().sync();
ErrorResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), ErrorResponse.class);
+
Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
Assert.assertEquals(
resp.getMessage(), "Model version: 1.3.1 does not exist for model: noopversioned");
}
- private void testRegisterModelMissingUrl() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testDescribeModelNotFound"})
+ public void testRegisterModelMissingUrl() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -948,7 +989,10 @@ private void testRegisterModelMissingUrl() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), "Parameter url is required.");
}
- private void testRegisterModelInvalidRuntime() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testRegisterModelMissingUrl"})
+ public void testRegisterModelInvalidRuntime() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -966,7 +1010,10 @@ private void testRegisterModelInvalidRuntime() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), "Invalid RuntimeType value: InvalidRuntime");
}
- private void testRegisterModelNotFound() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testRegisterModelInvalidRuntime"})
+ public void testRegisterModelNotFound() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -982,7 +1029,10 @@ private void testRegisterModelNotFound() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), "Model not found in model store: InvalidUrl");
}
- private void testRegisterModelConflict() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testRegisterModelNotFound"})
+ public void testRegisterModelConflict() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -1009,7 +1059,10 @@ private void testRegisterModelConflict() throws InterruptedException {
resp.getMessage(), "Model version 1.11 is already registered for model noop_v1.0");
}
- private void testRegisterModelMalformedUrl() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testRegisterModelConflict"})
+ public void testRegisterModelMalformedUrl() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -1028,7 +1081,10 @@ private void testRegisterModelMalformedUrl() throws InterruptedException {
resp.getMessage(), "Failed to download model from: http://localhost:aaaa");
}
- private void testRegisterModelConnectionFailed() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testRegisterModelMalformedUrl"})
+ public void testRegisterModelConnectionFailed() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -1048,7 +1104,10 @@ private void testRegisterModelConnectionFailed() throws InterruptedException {
"Failed to download model from: http://localhost:18888/fake.mar");
}
- private void testRegisterModelHttpError() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testRegisterModelConnectionFailed"})
+ public void testRegisterModelHttpError() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -1068,7 +1127,10 @@ private void testRegisterModelHttpError() throws InterruptedException {
"Failed to download model from: https://localhost:8443/fake.mar");
}
- private void testRegisterModelInvalidPath() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testRegisterModelHttpError"})
+ public void testRegisterModelInvalidPath() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -1086,7 +1148,10 @@ private void testRegisterModelInvalidPath() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), "Relative path is not allowed in url: ../fake.mar");
}
- private void testScaleModelNotFound() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testRegisterModelInvalidPath"})
+ public void testScaleModelNotFound() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -1101,7 +1166,10 @@ private void testScaleModelNotFound() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), "Model not found: fake");
}
- private void testScaleModelVersionNotFound() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testScaleModelNotFound"})
+ public void testScaleModelVersionNotFound() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
TestUtils.setResult(null);
@@ -1116,7 +1184,10 @@ private void testScaleModelVersionNotFound() throws InterruptedException {
resp.getMessage(), "Model version: 1.3.1 does not exist for model: noop_v1.0");
}
- private void testUnregisterModelNotFound() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testScaleModelNotFound"})
+ public void testUnregisterModelNotFound() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -1128,19 +1199,25 @@ private void testUnregisterModelNotFound() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), "Model not found: fake");
}
- private void testUnregisterModelVersionNotFound() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testUnregisterModelNotFound"})
+ public void testUnregisterModelVersionNotFound() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
- TestUtils.unregisterModel(channel, "noop_v1.0", "1.3.1", true);
+ TestUtils.unregisterModel(channel, "noopversioned", "1.3.1", true);
ErrorResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), ErrorResponse.class);
Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
Assert.assertEquals(
- resp.getMessage(), "Model version: 1.3.1 does not exist for model: noop_v1.0");
+ resp.getMessage(), "Model version: 1.3.1 does not exist for model: noopversioned");
}
- private void testUnregisterModelTimeout()
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testUnregisterModelNotFound"})
+ public void testUnregisterModelTimeout()
throws InterruptedException, NoSuchFieldException, IllegalAccessException {
Channel channel = TestUtils.connect(true, configManager);
setConfiguration("unregister_model_timeout", "0");
@@ -1158,7 +1235,10 @@ private void testUnregisterModelTimeout()
Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
}
- private void testScaleModelFailure() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testUnregisterModelTimeout"})
+ public void testScaleModelFailure() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
@@ -1185,85 +1265,183 @@ private void testScaleModelFailure() throws InterruptedException {
Assert.assertEquals(resp.getMessage(), "Failed to start workers");
}
- private void testLoadingMemoryError() throws InterruptedException {
- Channel channel = TestUtils.connect(true, configManager);
- Assert.assertNotNull(channel);
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testScaleModelFailure"})
+ public void testLoadMNISTEagerModel() throws InterruptedException {
+ testLoadModelWithInitialWorkers("mnist.mar", "mnist");
+ }
+
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testLoadMNISTEagerModel"})
+ public void testPredictionMNISTEagerModel() throws InterruptedException {
+ testPredictions("mnist", "0", null);
+ }
+
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPredictionMNISTEagerModel"})
+ public void testUnregisterMNISTEagerModel() throws InterruptedException {
+ testUnregisterModel("mnist", null);
+ }
+
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testUnregistedMNISTEagerModel"})
+ public void testLoadMNISTScriptedModel() throws InterruptedException {
+ testLoadModelWithInitialWorkers("mnist_scripted.mar", "mnist_scripted");
+ }
+
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testLoadMNISTScriptedModel"})
+ public void testPredictionMNISTScriptedModel() throws InterruptedException {
+ testPredictions("mnist_scripted", "0", null);
+ }
+
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPredictionMNISTScriptedModel"})
+ public void testUnregisterMNISTScriptedModel() throws InterruptedException {
+ testUnregisterModel("mnist_scripted", null);
+ }
+
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testUnregistedMNISTScriptedModel"})
+ public void testLoadMNISTTracedModel() throws InterruptedException {
+ testLoadModelWithInitialWorkers("mnist_traced.mar", "mnist_traced");
+ }
+
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testLoadMNISTTracedModel"})
+ public void testPredictionMNISTTracedModel() throws InterruptedException {
+ testPredictions("mnist_traced", "0", null);
+ }
+
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testPredictionMNISTTracedModel"})
+ public void testUnregisterMNISTTracedModel() throws InterruptedException {
+ testUnregisterModel("mnist_traced", null);
+ }
+
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testUnregistedMNISTTracedModel"})
+ public void testSetInvalidDefaultVersion() throws InterruptedException {
+ Channel channel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
-
- TestUtils.registerModel(channel, "loading-memory-error.mar", "memory_error", true, false);
+ TestUtils.setDefault(channel, "noopversioned", "3.3.3");
TestUtils.getLatch().await();
- Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.INSUFFICIENT_STORAGE);
- channel.close();
+ ErrorResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), ErrorResponse.class);
+ Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
+ Assert.assertEquals(
+ resp.getMessage(), "Model version 3.3.3 does not exist for model noopversioned");
}
- private void testPredictionMemoryError() throws InterruptedException {
- // Load the model
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testSetInvalidDefaultVersion"})
+ public void testUnregisterModelFailure() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
-
- TestUtils.registerModel(channel, "prediction-memory-error.mar", "pred-err", true, false);
+ TestUtils.unregisterModel(channel, "noopversioned", "1.2.1", false);
TestUtils.getLatch().await();
- Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
- channel.close();
- // Test for prediction
- channel = TestUtils.connect(false, configManager);
+ ErrorResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), ErrorResponse.class);
+ Assert.assertEquals(resp.getCode(), HttpResponseStatus.FORBIDDEN.code());
+ Assert.assertEquals(
+ resp.getMessage(), "Cannot remove default version for model noopversioned");
+
+ channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
+ TestUtils.unregisterModel(channel, "noopversioned", "1.11", false);
+ TestUtils.unregisterModel(channel, "noopversioned", "1.2.1", false);
+ }
+
+ private void testLoadModel(String url, String modelName) throws InterruptedException {
+ Channel channel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- DefaultFullHttpRequest req =
- new DefaultFullHttpRequest(
- HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/pred-err");
- req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8);
+ TestUtils.registerModel(channel, url, modelName, false, false);
+ TestUtils.getLatch().await();
- channel.writeAndFlush(req);
+ StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
+ Assert.assertEquals(resp.getStatus(), "Model \"" + modelName + "\" registered");
+ }
+
+ private void testUnregisterModel(String modelName, String version) throws InterruptedException {
+ Channel channel = TestUtils.getManagementChannel(configManager);
+ TestUtils.setResult(null);
+ TestUtils.setLatch(new CountDownLatch(1));
+ TestUtils.unregisterModel(channel, modelName, version, false);
TestUtils.getLatch().await();
- Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.INSUFFICIENT_STORAGE);
- channel.close();
+ StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
+ Assert.assertEquals(resp.getStatus(), "Model \"" + modelName + "\" unregistered");
+ }
- // Unload the model
- channel = TestUtils.connect(true, configManager);
- TestUtils.setHttpStatus(null);
+ private void testSyncScaleModel(String modelName, String version) throws InterruptedException {
+ Channel channel = TestUtils.getManagementChannel(configManager);
+ TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- Assert.assertNotNull(channel);
+ TestUtils.scaleModel(channel, modelName, version, 1, true);
- TestUtils.unregisterModel(channel, "pred-err", null, false);
TestUtils.getLatch().await();
- Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
- }
- private void testErrorBatch() throws InterruptedException {
- Channel channel = TestUtils.connect(true, configManager);
- Assert.assertNotNull(channel);
+ StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
+ Assert.assertEquals(resp.getStatus(), "Workers scaled");
+ }
- TestUtils.setHttpStatus(null);
+ private void testDescribeModel(String modelName, String requestVersion, String expectedVersion)
+ throws InterruptedException {
+ Channel channel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
-
- TestUtils.registerModel(channel, "error_batch.mar", "err_batch", true, false);
+ TestUtils.describeModel(channel, modelName, requestVersion);
TestUtils.getLatch().await();
- StatusResponse status =
- JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
- Assert.assertEquals(status.getStatus(), "Workers scaled");
+ DescribeModelResponse[] resp =
+ JsonUtils.GSON.fromJson(TestUtils.getResult(), DescribeModelResponse[].class);
+ if ("all".equals(requestVersion)) {
+ Assert.assertTrue(resp.length >= 1);
+ } else {
+ Assert.assertTrue(resp.length == 1);
+ }
+ Assert.assertTrue(expectedVersion.equals(resp[0].getModelVersion()));
+ }
- channel.close();
+ private void testLoadModelWithInitialWorkers(String url, String modelName)
+ throws InterruptedException {
+ Channel channel = TestUtils.getManagementChannel(configManager);
+ TestUtils.setResult(null);
+ TestUtils.setLatch(new CountDownLatch(1));
+ TestUtils.registerModel(channel, url, modelName, true, false);
+ TestUtils.getLatch().await();
- channel = TestUtils.connect(false, configManager);
- Assert.assertNotNull(channel);
+ StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
+ Assert.assertEquals(resp.getStatus(), "Workers scaled");
+ }
+ private void testPredictions(String modelName, String expectedOutput, String version)
+ throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
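+ // POSTs form-encoded "data=test" to /predictions/{model_name}, optionally suffixed with /{version}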
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.setHttpStatus(null);
+ String requestURL = "/predictions/" + modelName;
+ if (version != null) {
+ requestURL += "/" + version;
+ }
DefaultFullHttpRequest req =
- new DefaultFullHttpRequest(
- HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/err_batch");
- req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8);
+ new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, requestURL);
+ req.content().writeCharSequence("data=test", CharsetUtil.UTF_8);
HttpUtil.setContentLength(req, req.content().readableBytes());
req.headers()
.set(
@@ -1272,38 +1450,33 @@ private void testErrorBatch() throws InterruptedException {
channel.writeAndFlush(req);
TestUtils.getLatch().await();
+ Assert.assertEquals(TestUtils.getResult(), expectedOutput);
+ }
- Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.INSUFFICIENT_STORAGE);
- Assert.assertEquals(TestUtils.getResult(), "Invalid response");
+ private void loadTests(Channel channel, String model, String modelName)
+ throws InterruptedException {
+ TestUtils.setResult(null);
+ TestUtils.setLatch(new CountDownLatch(1));
+ TestUtils.registerModel(channel, model, modelName, true, false);
+
+ TestUtils.getLatch().await();
}
- private void testMetricManager() throws JsonParseException, InterruptedException {
- MetricManager.scheduleMetrics(configManager);
- MetricManager metricManager = MetricManager.getInstance();
- List<Metric> metrics = metricManager.getMetrics();
+ private void unloadTests(Channel channel, String modelName) throws InterruptedException {
+ TestUtils.setResult(null);
+ TestUtils.setLatch(new CountDownLatch(1));
+ String expected = "Model \"" + modelName + "\" unregistered";
+ TestUtils.unregisterModel(channel, modelName, null, false);
+ TestUtils.getLatch().await();
+ StatusResponse resp = JsonUtils.GSON.fromJson(TestUtils.getResult(), StatusResponse.class);
+ Assert.assertEquals(resp.getStatus(), expected);
+ }
- // Wait till first value is read in
- int count = 0;
- while (metrics.isEmpty()) {
- Thread.sleep(500);
- metrics = metricManager.getMetrics();
- Assert.assertTrue(++count < 5);
- }
- for (Metric metric : metrics) {
- if (metric.getMetricName().equals("CPUUtilization")) {
- Assert.assertEquals(metric.getUnit(), "Percent");
- }
- if (metric.getMetricName().equals("MemoryUsed")) {
- Assert.assertEquals(metric.getUnit(), "Megabytes");
- }
- if (metric.getMetricName().equals("DiskUsed")) {
- List<Dimension> dimensions = metric.getDimensions();
- for (Dimension dimension : dimensions) {
- if (dimension.getName().equals("Level")) {
- Assert.assertEquals(dimension.getValue(), "Host");
- }
- }
- }
- }
+ private void setConfiguration(String key, String val)
+ throws NoSuchFieldException, IllegalAccessException {
+ Field f = configManager.getClass().getDeclaredField("prop");
+ f.setAccessible(true);
+ Properties p = (Properties) f.get(configManager);
+ p.setProperty(key, val);
}
}
diff --git a/frontend/server/src/test/java/org/pytorch/serve/SnapshotTest.java b/frontend/server/src/test/java/org/pytorch/serve/SnapshotTest.java
index 92144a4e20..fb764a60df 100644
--- a/frontend/server/src/test/java/org/pytorch/serve/SnapshotTest.java
+++ b/frontend/server/src/test/java/org/pytorch/serve/SnapshotTest.java
@@ -64,69 +64,21 @@ public void beforeSuite()
}
@AfterClass
- public void afterSuite() {
+ public void afterSuite() throws InterruptedException {
+ TestUtils.closeChannels();
server.stop();
}
@Test
- public void test()
- throws InterruptedException, IOException, GeneralSecurityException,
- InvalidSnapshotException {
- Channel channel = null;
- Channel managementChannel = null;
- for (int i = 0; i < 5; ++i) {
- channel = TestUtils.connect(false, configManager);
- if (channel != null) {
- break;
- }
- Thread.sleep(100);
- }
-
- for (int i = 0; i < 5; ++i) {
- managementChannel = TestUtils.connect(true, configManager);
- if (managementChannel != null) {
- break;
- }
- Thread.sleep(100);
- }
-
- Assert.assertNotNull(channel, "Failed to connect to inference port.");
- Assert.assertNotNull(managementChannel, "Failed to connect to management port.");
-
- testStartupSnapshot("snapshot1.cfg");
- testUnregisterSnapshot(managementChannel);
- testRegisterSnapshot(managementChannel);
- testSyncScaleModelSnapshot(managementChannel);
- testNoSnapshotOnListModels(managementChannel);
- testNoSnapshotOnDescribeModel(managementChannel);
- testLoadModelWithInitialWorkersSnapshot(managementChannel);
- testRegisterSecondModelSnapshot(managementChannel);
- testSecondModelVersionSnapshot(managementChannel);
- testNoSnapshotOnPrediction(channel);
- testSetDefaultSnapshot(managementChannel);
- testAsyncScaleModelSnapshot(managementChannel);
-
- channel.close();
- managementChannel.close();
-
- testStopTorchServeSnapshot();
- testStartTorchServeWithLastSnapshot();
- testRestartTorchServeWithSnapshotAsConfig();
-
- // Negative management API calls, channel will be closed by server
- testNoSnapshotOnInvalidModelRegister();
- testNoSnapshotOnInvalidModelUnregister();
- testNoSnapshotOnInvalidModelVersionUnregister();
- testNoSnapshotOnInvalidModelScale();
- testNoSnapshotOnInvalidModelVersionScale();
- testNoSnapshotOnInvalidModelVersionSetDefault();
+ public void testStartupSnapshot() {
+ validateSnapshot("snapshot1.cfg");
}
- private void testStartupSnapshot(String expectedSnapshot) {
- validateSnapshot(expectedSnapshot);
- }
-
- private void testUnregisterSnapshot(Channel managementChannel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testStartupSnapshot"})
+ public void testUnregisterSnapshot() throws InterruptedException {
+ Channel managementChannel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
TestUtils.unregisterModel(managementChannel, "noop", null, false);
@@ -135,7 +87,11 @@ private void testUnregisterSnapshot(Channel managementChannel) throws Interrupte
waitForSnapshot();
}
- private void testRegisterSnapshot(Channel managementChannel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testUnregisterSnapshot"})
+ public void testRegisterSnapshot() throws InterruptedException {
+ Channel managementChannel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
TestUtils.registerModel(managementChannel, "noop.mar", "noop_v1.0", false, false);
@@ -144,16 +100,24 @@ private void testRegisterSnapshot(Channel managementChannel) throws InterruptedE
waitForSnapshot();
}
- private void testSyncScaleModelSnapshot(Channel channel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testRegisterSnapshot"})
+ public void testSyncScaleModelSnapshot() throws InterruptedException {
+ Channel managementChannel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.scaleModel(channel, "noop_v1.0", null, 1, true);
+ TestUtils.scaleModel(managementChannel, "noop_v1.0", null, 1, true);
TestUtils.getLatch().await();
validateSnapshot("snapshot4.cfg");
waitForSnapshot();
}
- private void testNoSnapshotOnListModels(Channel channel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testSyncScaleModelSnapshot"})
+ public void testNoSnapshotOnListModels() throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
TestUtils.listModels(channel);
@@ -161,7 +125,11 @@ private void testNoSnapshotOnListModels(Channel channel) throws InterruptedExcep
validateNoSnapshot();
}
- private void testNoSnapshotOnDescribeModel(Channel channel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testNoSnapshotOnListModels"})
+ public void testNoSnapshotOnDescribeModel() throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
TestUtils.describeModel(channel, "noop_v1.0", null);
@@ -169,17 +137,24 @@ private void testNoSnapshotOnDescribeModel(Channel channel) throws InterruptedEx
validateNoSnapshot();
}
- private void testLoadModelWithInitialWorkersSnapshot(Channel channel)
- throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testNoSnapshotOnDescribeModel"})
+ public void testLoadModelWithInitialWorkersSnapshot() throws InterruptedException {
+ Channel managementChannel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.registerModel(channel, "noop.mar", "noop", true, false);
+ TestUtils.registerModel(managementChannel, "noop.mar", "noop", true, false);
TestUtils.getLatch().await();
validateSnapshot("snapshot5.cfg");
waitForSnapshot();
}
- private void testNoSnapshotOnPrediction(Channel channel) {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testLoadModelWithInitialWorkersSnapshot"})
+ public void testNoSnapshotOnPrediction() throws InterruptedException {
+ Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
String requestURL = "/predictions/noop_v1.0";
@@ -194,54 +169,76 @@ private void testNoSnapshotOnPrediction(Channel channel) {
channel.writeAndFlush(req);
}
- private void testRegisterSecondModelSnapshot(Channel channel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testNoSnapshotOnPrediction"})
+ public void testRegisterSecondModelSnapshot() throws InterruptedException {
+ Channel managementChannel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.registerModel(channel, "noop.mar", "noopversioned", true, false);
+ TestUtils.registerModel(managementChannel, "noop.mar", "noopversioned", true, false);
TestUtils.getLatch().await();
validateSnapshot("snapshot6.cfg");
waitForSnapshot();
}
- private void testSecondModelVersionSnapshot(Channel channel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testRegisterSecondModelSnapshot"})
+ public void testSecondModelVersionSnapshot() throws InterruptedException {
+ Channel managementChannel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.registerModel(channel, "noop_v2.mar", "noopversioned", true, false);
+ TestUtils.registerModel(managementChannel, "noop_v2.mar", "noopversioned", true, false);
TestUtils.getLatch().await();
validateSnapshot("snapshot7.cfg");
waitForSnapshot();
}
- private void testSetDefaultSnapshot(Channel channel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testSecondModelVersionSnapshot"})
+ public void testSetDefaultSnapshot() throws InterruptedException {
+ Channel managementChannel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
String requestURL = "/models/noopversioned/1.2.1/set-default";
HttpRequest req =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, requestURL);
- channel.writeAndFlush(req);
+ managementChannel.writeAndFlush(req);
TestUtils.getLatch().await();
validateSnapshot("snapshot8.cfg");
waitForSnapshot();
}
- private void testAsyncScaleModelSnapshot(Channel channel) throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testSetDefaultSnapshot"})
+ public void testAsyncScaleModelSnapshot() throws InterruptedException {
+ Channel managementChannel = TestUtils.getManagementChannel(configManager);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
- TestUtils.scaleModel(channel, "noop_v1.0", null, 2, false);
+ TestUtils.scaleModel(managementChannel, "noop_v1.0", null, 2, false);
TestUtils.getLatch().await();
waitForSnapshot(5000);
validateSnapshot("snapshot9.cfg");
waitForSnapshot();
}
- private void testStopTorchServeSnapshot() {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testAsyncScaleModelSnapshot"})
+ public void testStopTorchServeSnapshot() {
server.stop();
validateSnapshot("snapshot9.cfg");
}
- private void testStartTorchServeWithLastSnapshot()
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testStopTorchServeSnapshot"})
+ public void testStartTorchServeWithLastSnapshot()
throws InterruptedException, IOException, GeneralSecurityException,
InvalidSnapshotException {
System.setProperty("tsConfigFile", "");
@@ -260,7 +257,10 @@ private void testStartTorchServeWithLastSnapshot()
validateSnapshot("snapshot9.cfg");
}
- private void testRestartTorchServeWithSnapshotAsConfig()
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testStartTorchServeWithLastSnapshot"})
+ public void testRestartTorchServeWithSnapshotAsConfig()
throws InterruptedException, IOException, GeneralSecurityException,
InvalidSnapshotException {
server.stop();
@@ -282,11 +282,10 @@ private void testRestartTorchServeWithSnapshotAsConfig()
validateSnapshot("snapshot9.cfg");
}
- private void validateNoSnapshot() {
- validateSnapshot(lastSnapshot);
- }
-
- private void testNoSnapshotOnInvalidModelRegister() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testRestartTorchServeWithSnapshotAsConfig"})
+ public void testNoSnapshotOnInvalidModelRegister() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
TestUtils.registerModel(channel, "InvalidModel", "InvalidModel", false, true);
@@ -294,7 +293,10 @@ private void testNoSnapshotOnInvalidModelRegister() throws InterruptedException
validateSnapshot("snapshot9.cfg");
}
- private void testNoSnapshotOnInvalidModelUnregister() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testNoSnapshotOnInvalidModelRegister"})
+ public void testNoSnapshotOnInvalidModelUnregister() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
TestUtils.unregisterModel(channel, "InvalidModel", null, true);
@@ -302,7 +304,10 @@ private void testNoSnapshotOnInvalidModelUnregister() throws InterruptedExceptio
validateSnapshot("snapshot9.cfg");
}
- private void testNoSnapshotOnInvalidModelVersionUnregister() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testNoSnapshotOnInvalidModelUnregister"})
+ public void testNoSnapshotOnInvalidModelVersionUnregister() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
TestUtils.registerModel(channel, "noopversioned", "3.0", false, true);
@@ -310,7 +315,10 @@ private void testNoSnapshotOnInvalidModelVersionUnregister() throws InterruptedE
validateSnapshot("snapshot9.cfg");
}
- private void testNoSnapshotOnInvalidModelScale() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testNoSnapshotOnInvalidModelVersionUnregister"})
+ public void testNoSnapshotOnInvalidModelScale() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
TestUtils.scaleModel(channel, "invalidModel", null, 1, true);
@@ -318,7 +326,10 @@ private void testNoSnapshotOnInvalidModelScale() throws InterruptedException {
validateSnapshot("snapshot9.cfg");
}
- private void testNoSnapshotOnInvalidModelVersionScale() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testNoSnapshotOnInvalidModelScale"})
+ public void testNoSnapshotOnInvalidModelVersionScale() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
TestUtils.scaleModel(channel, "noopversioned", "3.0", 1, true);
@@ -326,7 +337,10 @@ private void testNoSnapshotOnInvalidModelVersionScale() throws InterruptedExcept
validateSnapshot("snapshot9.cfg");
}
- private void testNoSnapshotOnInvalidModelVersionSetDefault() throws InterruptedException {
+ @Test(
+ alwaysRun = true,
+ dependsOnMethods = {"testNoSnapshotOnInvalidModelVersionScale"})
+ public void testNoSnapshotOnInvalidModelVersionSetDefault() throws InterruptedException {
Channel channel = TestUtils.connect(true, configManager);
Assert.assertNotNull(channel);
String requestURL = "/models/noopversioned/3.0/set-default";
@@ -339,6 +353,10 @@ private void testNoSnapshotOnInvalidModelVersionSetDefault() throws InterruptedE
validateSnapshot("snapshot9.cfg");
}
+ private void validateNoSnapshot() {
+ validateSnapshot(lastSnapshot);
+ }
+
private void validateSnapshot(String expectedSnapshot) {
lastSnapshot = expectedSnapshot;
File expectedSnapshotFile = new File("src/test/resources/snapshots", expectedSnapshot);
diff --git a/frontend/server/src/test/java/org/pytorch/serve/TestUtils.java b/frontend/server/src/test/java/org/pytorch/serve/TestUtils.java
index 5ea9c7109c..afc58d196b 100644
--- a/frontend/server/src/test/java/org/pytorch/serve/TestUtils.java
+++ b/frontend/server/src/test/java/org/pytorch/serve/TestUtils.java
@@ -39,6 +39,8 @@ public final class TestUtils {
static HttpResponseStatus httpStatus;
static String result;
static HttpHeaders headers;
+ private static Channel inferenceChannel;
+ private static Channel managementChannel;
private TestUtils() {}
@@ -194,6 +196,14 @@ public static void listModels(Channel channel) {
channel.writeAndFlush(req);
}
+ public static void setDefault(Channel channel, String modelName, String defaultVersion) {
+ String requestURL = "/models/" + modelName + "/" + defaultVersion + "/set-default";
+
+ HttpRequest req =
+ new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, requestURL);
+ channel.writeAndFlush(req);
+ }
+
public static Channel connect(boolean management, ConfigManager configManager) {
return connect(management, configManager, 120);
}
@@ -236,6 +246,53 @@ public void initChannel(Channel ch) {
return null;
}
+ public static Channel getInferenceChannel(ConfigManager configManager)
+ throws InterruptedException {
+ return getChannel(false, configManager);
+ }
+
+ public static Channel getManagementChannel(ConfigManager configManager)
+ throws InterruptedException {
+ return getChannel(true, configManager);
+ }
+
+ private static Channel getChannel(boolean isManagementChannel, ConfigManager configManager)
+ throws InterruptedException {
+ if (isManagementChannel && managementChannel != null && managementChannel.isActive()) {
+ return managementChannel;
+ } else if (!isManagementChannel
+ && inferenceChannel != null
+ && inferenceChannel.isActive()) {
+ return inferenceChannel;
+ } else {
+ Channel channel = null;
+ for (int i = 0; i < 5; ++i) {
+ channel = TestUtils.connect(isManagementChannel, configManager);
+ if (channel != null) {
+ break;
+ }
+ Thread.sleep(100);
+ }
+ if (isManagementChannel) {
+ managementChannel = channel;
+ } else {
+ inferenceChannel = channel;
+ }
+ return channel;
+ }
+ }
+
+ public static void closeChannels() throws InterruptedException {
+ if (managementChannel != null) {
+ managementChannel.closeFuture().sync();
+ }
+ if (inferenceChannel != null) {
+ inferenceChannel.closeFuture().sync();
+ }
+ }
+
@ChannelHandler.Sharable
private static class TestHandler extends SimpleChannelInboundHandler<FullHttpResponse> {
diff --git a/frontend/server/src/test/java/org/pytorch/serve/util/ConfigManagerTest.java b/frontend/server/src/test/java/org/pytorch/serve/util/ConfigManagerTest.java
index bd9fb82f68..ecec62347b 100644
--- a/frontend/server/src/test/java/org/pytorch/serve/util/ConfigManagerTest.java
+++ b/frontend/server/src/test/java/org/pytorch/serve/util/ConfigManagerTest.java
@@ -8,12 +8,12 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
-import org.junit.Assert;
import org.pytorch.serve.TestUtils;
import org.pytorch.serve.metrics.Dimension;
import org.pytorch.serve.metrics.Metric;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.testng.Assert;
import org.testng.annotations.Test;
public class ConfigManagerTest {
diff --git a/frontend/server/testng.xml b/frontend/server/testng.xml
new file mode 100644
index 0000000000..58e956424b
--- /dev/null
+++ b/frontend/server/testng.xml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/tools/conf/pmd.xml b/frontend/tools/conf/pmd.xml
index 684635ef00..ac787e1139 100644
--- a/frontend/tools/conf/pmd.xml
+++ b/frontend/tools/conf/pmd.xml
@@ -140,4 +140,10 @@
+
+
+
+
+
+
diff --git a/frontend/tools/gradle/check.gradle b/frontend/tools/gradle/check.gradle
index 0dd6de4a03..881874f9d3 100644
--- a/frontend/tools/gradle/check.gradle
+++ b/frontend/tools/gradle/check.gradle
@@ -21,7 +21,7 @@ checkstyle {
configFile = file("${rootProject.projectDir}/tools/conf/checkstyle.xml")
}
checkstyleMain {
- classpath += configurations.compile
+ classpath += configurations.compileClasspath
}
tasks.withType(Checkstyle) {
reports {