Fixed model-archiver to accept handler name or handler_name:entry_pnt_func combinations #472

Merged
merged 22 commits into master from issue_465
Aug 4, 2020
Commits
14d14f1
Fix for error while loading model on cpu which was saved on gpu
dhaniram-kshirsagar Jun 9, 2020
d8092b2
Merge branch 'master' of https://github.com/pytorch/serve
dhaniram-kshirsagar Jun 18, 2020
9217583
Fixed issue 465
dhaniram-kshirsagar Jun 22, 2020
449206e
Added more changes
dhaniram-kshirsagar Jun 22, 2020
6275bfc
Fixed UTs
dhaniram-kshirsagar Jun 22, 2020
ad923db
Offline code review comments
dhaniram-kshirsagar Jun 23, 2020
53582f4
Merge branch 'master' into issue_465
dhaniram-kshirsagar Jun 29, 2020
bde62d6
Merge branch 'master' into issue_465
maaquib Jul 6, 2020
06a5452
Merge branch 'master' into issue_465
maaquib Jul 9, 2020
26a0032
Merge branch 'master' into issue_465
dhaniram-kshirsagar Jul 20, 2020
d96c463
Merge branch 'master' into issue_465
dhaniram-kshirsagar Jul 22, 2020
d4c7e34
Fixed pylint error
dhaniram-kshirsagar Jul 22, 2020
e4ea81b
Merge branch 'master' into issue_465
dhaniram-kshirsagar Jul 22, 2020
a551733
Merge branch 'master' into issue_465
dhaniram-kshirsagar Jul 22, 2020
58d1b04
Merge branch 'master' into issue_465
maaquib Jul 22, 2020
1fd88b1
Merge branch 'master' into issue_465
dhaniram-kshirsagar Jul 23, 2020
0a0532d
Merge branch 'master' into issue_465
maaquib Jul 23, 2020
75da63f
Fixed conflicts and merged master
dhaniram-kshirsagar Jul 24, 2020
352374a
Added UT-IT
dhaniram-kshirsagar Jul 24, 2020
6ddaee5
Merge branch 'master' into issue_465
harshbafna Jul 29, 2020
916e063
Merge branch 'master' into issue_465
maaquib Jul 31, 2020
5188727
Merge branch 'master' into issue_465
maaquib Jul 31, 2020
2 changes: 1 addition & 1 deletion examples/image_classifier/resnet_18/README.md
@@ -4,7 +4,7 @@ Run the commands given in following steps from the parent directory of the root

```bash
wget https://download.pytorch.org/models/resnet18-5c106cde.pth
torch-model-archiver --model-name resnet-18 --version 1.0 --model-file ./serve/examples/image_classifier/resnet_18/model.py --serialized-file resnet18-5c106cde.pth --handler image_classifier --extra-files ./serve/examples/image_classifier/index_to_name.json
torch-model-archiver --model-name resnet-18 --version 1.0 --model-file ./examples/image_classifier/resnet_18/model.py --serialized-file resnet18-5c106cde.pth --handler image_classifier --extra-files ./examples/image_classifier/index_to_name.json
mkdir model_store
mv resnet-18.mar model_store/
torchserve --start --model-store model_store --models resnet-18=resnet-18.mar
6 changes: 5 additions & 1 deletion examples/text_classification/README.md
@@ -21,8 +21,12 @@ The above command generated the model's state dict as model.pt and the vocab use
* Create a torch model archive using the torch-model-archiver utility to archive the above files.

```bash
torch-model-archiver --model-name my_text_classifier --version 1.0 --model-file model.py --serialized-file model.pt --source-vocab source_vocab.pt --handler text_classifier --extra-files index_to_name.json
torch-model-archiver --model-name my_text_classifier --version 1.0 --model-file model.py --serialized-file model.pt --handler text_classifier --extra-files "index_to_name.json,source_vocab.pt"
```

NOTE - `run_script.sh` generates `source_vocab.pt`, which is a mandatory file for this handler.
If you plan to override it or use a custom source vocab, name the file `source_vocab.pt` and supply it via `--extra-files` as in the example above.
Alternatively, extend `TextHandler` and override the `get_source_vocab_path` function in your custom handler. Refer to [custom handler](../../docs/custom_service.md) for details.
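For instance, the override could look like the sketch below. This is a minimal, self-contained illustration: the class is shown standalone (in practice it would extend `TextHandler` from `ts.torch_handler.text_handler`), and the vocab file name `my_vocab.pt` is hypothetical.

```python
import os

class CustomTextHandler:  # in practice: class CustomTextHandler(TextHandler)
    # Hypothetical custom vocab file name packaged via --extra-files.
    VOCAB_FILE = "my_vocab.pt"

    def get_source_vocab_path(self, ctx):
        # TorchServe exposes the extracted model directory through the
        # context's system_properties (as the base TextHandler assumes).
        model_dir = ctx.system_properties.get("model_dir")
        vocab_path = os.path.join(model_dir, self.VOCAB_FILE)
        if not os.path.isfile(vocab_path):
            raise FileNotFoundError(
                "Missing vocab file %s in %s" % (self.VOCAB_FILE, model_dir))
        return vocab_path
```

With this override in place, the archived model no longer needs a file named `source_vocab.pt`.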

* Register the model on TorchServe using the above model archive file and run text classification inference

Binary file not shown.
Binary file not shown.
@@ -367,6 +367,34 @@ public void testPredictionsJson() throws InterruptedException {
@Test(
alwaysRun = true,
dependsOnMethods = {"testPredictionsJson"})
public void testLoadModelWithHandlerName() throws InterruptedException {
testLoadModelWithInitialWorkers("noop_handlername.mar", "noop_handlername", "1.0");
}

@Test(
alwaysRun = true,
dependsOnMethods = {"testLoadModelWithHandlerName"})
public void testNoopWithHandlerNamePrediction() throws InterruptedException {
testPredictions("noop_handlername", "OK", "1.0");
}

@Test(
alwaysRun = true,
dependsOnMethods = {"testNoopWithHandlerNamePrediction"})
public void testLoadModelWithEntryPntFuncName() throws InterruptedException {
testLoadModelWithInitialWorkers("noop_entrypntfunc.mar", "noop_entrypntfunc", "1.0");
}

@Test(
alwaysRun = true,
dependsOnMethods = {"testLoadModelWithEntryPntFuncName"})
public void testNoopWithEntryPntFuncPrediction() throws InterruptedException {
testPredictions("noop_entrypntfunc", "OK", "1.0");
}

@Test(
alwaysRun = true,
dependsOnMethods = {"testNoopWithEntryPntFuncPrediction"})
public void testInvocationsJson() throws InterruptedException {
Channel channel = TestUtils.getInferenceChannel(configManager);
TestUtils.setResult(null);
11 changes: 6 additions & 5 deletions model-archiver/README.md
@@ -80,9 +80,6 @@ optional arguments:
class definition extended from torch.nn.modules.
--handler HANDLER TorchServe's default handler name or handler python
file path to handle custom TorchServe inference logic.
--source-vocab SOURCE_VOCAB
Vocab file for source language required for text
based models
--extra-files EXTRA_FILES
Comma separated path to extra dependency files.
--runtime {python,python2,python3}
@@ -140,13 +137,17 @@ A serialized file (.pt or .pth) should be a checkpoint in case of torchscript an

### Handler

Handler can be TorchServe's inbuilt handler name or path to a py to handle custom TorchServe inference logic. TorchServe supports the following handlers out of the box:
Handler can be TorchServe's inbuilt handler name or path to a py file to handle custom TorchServe inference logic. TorchServe supports the following handlers out of the box:
1. `image_classifier`
2. `object_detector`
3. `text_classifier`
4. `image_segmenter`

For more details refer [default handler documentation](../docs/default_handlers.md)
For a custom handler, if you plan to provide just `module_name` or `module_name:entry_point_function_name`, make sure it is prefixed with the absolute or relative path of the python file.
e.g. if your custom handler `custom_image_classifier.py` is in `/home/serve/examples`, then use
`--handler /home/serve/examples/custom_image_classifier`, or, if it has a module-level function `my_entry_point_func`, use `--handler /home/serve/examples/custom_image_classifier:my_entry_point_func`
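For illustration, such a module-level entry point could be shaped as follows. This is a hedged sketch, not the project's exact contract: the `(data, context)` signature mirrors common TorchServe handler examples, and the placeholder response is an assumption.

```python
# custom_image_classifier.py -- hypothetical module used with
#   --handler /home/serve/examples/custom_image_classifier:my_entry_point_func
def my_entry_point_func(data, context):
    """Module-level entry point TorchServe would invoke per request batch."""
    if data is None:
        # Handlers are typically invoked once with no data at load time.
        return None
    # A real handler would run preprocessing and inference here; this
    # sketch returns one placeholder response per request in the batch.
    return ["OK"] * len(data)
```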

For more details refer to the [default handler documentation](../docs/default_handlers.md) or the [custom handler documentation](../docs/custom_service.md)
## Creating a Model Archive

**1. Download the torch model archiver source**
8 changes: 1 addition & 7 deletions model-archiver/model_archiver/arg_parser.py
@@ -55,13 +55,7 @@ def export_model_args_parser():
type=str,
default=None,
help="TorchServe's default handler name\n"
" or handler python file path to handle custom TorchServe inference logic.")

parser_export.add_argument('--source-vocab',
required=False,
type=str,
default=None,
help='Vocab file for source language. Required for text based models.')
" or Handler path to handle custom inference logic.")

parser_export.add_argument('--extra-files',
required=False,
8 changes: 3 additions & 5 deletions model-archiver/model_archiver/manifest_components/model.py
@@ -9,15 +9,16 @@ class Model(object):
"""

def __init__(self, model_name, serialized_file, handler, model_file=None, model_version=None,
extensions=None, source_vocab=None, requirements_file=None):
extensions=None, requirements_file=None):

self.model_name = model_name
self.serialized_file = serialized_file.split("/")[-1]
self.model_file = model_file
self.model_version = model_version
self.extensions = extensions
self.handler = handler.split("/")[-1]
self.source_vocab = source_vocab
self.requirements_file = requirements_file

self.model_dict = self.__to_dict__()

def __to_dict__(self):
@@ -29,9 +30,6 @@ def __to_dict__(self):

model_dict['handler'] = self.handler

if self.source_vocab:
model_dict['sourceVocab'] = self.source_vocab.split("/")[-1]

if self.model_file:
model_dict['modelFile'] = self.model_file.split("/")[-1]

21 changes: 2 additions & 19 deletions model-archiver/model_archiver/model_packaging.py
@@ -20,8 +20,8 @@ def package_model(args, manifest):
handler = args.handler
extra_files = args.extra_files
export_file_path = args.export_path
source_vocab = args.source_vocab
requirements_file = args.requirements_file

temp_files = []

try:
@@ -32,8 +32,7 @@

# Step 2 : Copy all artifacts to temp directory
artifact_files = {'model_file': model_file, 'serialized_file': serialized_file, 'handler': handler,
'extra_files': extra_files, 'source_vocab': source_vocab,
'requirements-file': requirements_file}
'extra_files': extra_files, 'requirements-file': requirements_file}

model_path = ModelExportUtils.copy_artifacts(model_name, **artifact_files)

@@ -54,25 +53,9 @@ def generate_model_archive():
:return:
"""

model_handlers = {
'text_classifier': 'text',
'image_classifier': 'vision',
'object_detector': 'vision',
'image_segmenter': 'vision'
}

logging.basicConfig(format='%(levelname)s - %(message)s')
args = ArgParser.export_model_args_parser().parse_args()

if args.handler in model_handlers.keys():
if model_handlers[args.handler] == "text":
if not args.source_vocab:
raise Exception("Please provide the source language vocab for {0} model.".format(args.handler))
elif not args.handler.endswith(".py"):
raise Exception("Handler should be one of the default TorchServe handlers [{0}]"
" or a py file to handle custom TorchServe inference logic."
.format(",".join(model_handlers.keys())))

manifest = ModelExportUtils.generate_manifest_json(args)
package_model(args, manifest=manifest)

18 changes: 15 additions & 3 deletions model-archiver/model_archiver/model_packaging_utils.py
@@ -22,6 +22,13 @@
"default": ".mar"
}

model_handlers = {
'text_classifier': 'text',
'image_classifier': 'vision',
'object_detector': 'vision',
'image_segmenter': 'vision'
}

MODEL_SERVER_VERSION = '1.0'
MODEL_ARCHIVE_VERSION = '1.0'
MANIFEST_FILE_NAME = 'MANIFEST.json'
@@ -88,7 +95,7 @@ def find_unique(files, suffix):
def generate_model(modelargs):
model = Model(model_name=modelargs.model_name, serialized_file=modelargs.serialized_file,
model_file=modelargs.model_file, handler=modelargs.handler, model_version=modelargs.version,
source_vocab=modelargs.source_vocab, requirements_file=modelargs.requirements_file)
requirements_file=modelargs.requirements_file)
return model

@staticmethod
@@ -129,8 +136,13 @@ def copy_artifacts(model_name, **kwargs):
ModelExportUtils.make_dir(model_path)
for file_type, path in kwargs.items():
if path:
if file_type == "handler" and len(path.split("/")[-1].split(".")) == 1:
continue
if file_type == "handler":
if path in model_handlers.keys():
continue

if '.py' not in path:
path = (path.split(':')[0] if ':' in path else path) + '.py'

if file_type == "extra_files":
for file in path.split(","):
shutil.copy(file, model_path)
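The handler branch added to `copy_artifacts` above can be summarized as a small normalization step: built-in handler names are skipped, and a custom handler path has any `:entry_point` suffix stripped and a `.py` extension appended before the file is copied. A standalone sketch of that path logic (the function name here is ours, not the archiver's):

```python
def normalize_handler_path(path):
    """Mirror the archiver's handler-path handling: drop an optional
    ':entry_point' suffix and ensure the path ends in '.py'."""
    if '.py' not in path:
        path = (path.split(':')[0] if ':' in path else path) + '.py'
    return path
```

Note the quirk this mirrors: a path that already contains `.py` is returned untouched, entry-point suffix included.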
@@ -59,4 +59,30 @@
"iterations": 2,
"version": "1.0",
"force": true
} ]
},
{
"name": "packaging_mar_with_handler_name",
"model-name": "model",
"model-file": "model_archiver/tests/integ_tests/resources/regular_model/test_model.py",
"serialized-file": "model_archiver/tests/integ_tests/resources/regular_model/test_serialized_file.pt",
"handler": "model_archiver/tests/integ_tests/resources/regular_model/test_handler",
"extra-files": "model_archiver/tests/integ_tests/resources/regular_model/test_index_to_name.json",
"export-path": "/tmp/model",
"archive-format": "default",
"iterations": 2,
"version": "1.0",
"force": true
},
{
"name": "packaging_mar_with_handler_entrypoint_func",
"model-name": "model",
"model-file": "model_archiver/tests/integ_tests/resources/regular_model/test_model.py",
"serialized-file": "model_archiver/tests/integ_tests/resources/regular_model/test_serialized_file.pt",
"handler": "model_archiver/tests/integ_tests/resources/regular_model/test_handler:my_handler",
"extra-files": "model_archiver/tests/integ_tests/resources/regular_model/test_index_to_name.json",
"export-path": "/tmp/model",
"archive-format": "default",
"iterations": 2,
"version": "1.0",
"force": true
}]
@@ -16,22 +16,9 @@
"model-file": "model_archiver/tests/integ_tests/resources/regular_model/test_model.py",
"serialized-file": "model_archiver/tests/integ_tests/resources/regular_model/test_serialized_file.pt",
"handler": "text_classifier",
"extra-files": "model_archiver/tests/integ_tests/resources/regular_model/test_index_to_name.json",
"iterations": 1,
"export-path": "/tmp/model",
"version": "1.0",
"source-vocab": "model_archiver/tests/integ_tests/resources/regular_model/source_vocab.pt"
},
{
"name": "text_no_source_vocab",
"model-name": "model",
"model-file": "model_archiver/tests/integ_tests/resources/regular_model/test_model.py",
"serialized-file": "model_archiver/tests/integ_tests/resources/regular_model/test_serialized_file.pt",
"handler": "text_classifier",
"extra-files": "model_archiver/tests/integ_tests/resources/regular_model/test_index_to_name.json",
"extra-files": "model_archiver/tests/integ_tests/resources/regular_model/test_index_to_name.json,model_archiver/tests/integ_tests/resources/regular_model/source_vocab.pt",
"iterations": 1,
"export-path": "/tmp/model",
"version": "1.0",
"expect-error": true
"version": "1.0"
}
]
@@ -0,0 +1,2 @@
def my_handler(data, context):
pass
@@ -78,7 +78,7 @@ def validate_files(file_list, prefix, default_handler=None):
assert os.path.join(prefix, "dummy-artifacts.txt") in file_list
assert os.path.join(prefix, "1.py") in file_list

if default_handler =="text_classifier":
if default_handler == "text_classifier":
assert os.path.join(prefix, "source_vocab.pt") in file_list


@@ -121,7 +121,7 @@ def validate(test):


def build_cmd(test):
args = ['model-name', 'model-file', 'serialized-file', 'handler', 'extra-files', 'archive-format', 'source-vocab',
args = ['model-name', 'model-file', 'serialized-file', 'handler', 'extra-files', 'archive-format',
'version', 'export-path', 'runtime']
cmd = ["torch-model-archiver"]

@@ -139,12 +139,11 @@ def __init__(self, **kwargs):
serialized_file = 'model.pt'
model_file = 'model.pt'
version = "1.0"
source_vocab = None
requirements_file = "requirements.txt"

args = Namespace(model_name=model_name, handler=handler, runtime=RuntimeType.PYTHON.value,
serialized_file=serialized_file, model_file=model_file, version=version,
source_vocab=source_vocab, requirements_file=requirements_file)
requirements_file=requirements_file)

def test_model(self):
mod = ModelExportUtils.generate_model(self.args)
2 changes: 1 addition & 1 deletion ts/torch_handler/text_classifier.py
@@ -40,7 +40,7 @@ def preprocess(self, data):
text = text.lower()
text = self._expand_contractions(text)
text = self._remove_accented_characters(text)
text = self._remove_puncutation(text)
text = self._remove_punctuation(text)
text = self._tokenize(text)
text = torch.tensor([self.source_vocab[token] for token in ngrams_iterator(text, ngrams)])

21 changes: 19 additions & 2 deletions ts/torch_handler/text_handler.py
@@ -4,6 +4,7 @@
Base module for all text based default handler.
Contains various text based utility methods
"""
import os
import re
import string
import unicodedata
@@ -27,9 +28,25 @@ def __init__(self):
def initialize(self, ctx):
super(TextHandler, self).initialize(ctx)
self.initialized = False
self.source_vocab = torch.load(self.manifest['model']['sourceVocab'])
source_vocab = self.manifest['model']['sourceVocab'] if 'sourceVocab' in self.manifest['model'] else None
if source_vocab:
# Backward compatibility
self.source_vocab = torch.load(source_vocab)
else:
self.source_vocab = torch.load(self.get_source_vocab_path(ctx))
self.initialized = True

def get_source_vocab_path(self, ctx):
properties = ctx.system_properties
model_dir = properties.get("model_dir")
source_vocab_path = os.path.join(model_dir, "source_vocab.pt")

if os.path.isfile(source_vocab_path):
return source_vocab_path
else:
raise Exception('Missing the source_vocab file. Refer default handler '
'documentation for details on using text_handler.')

def _expand_contractions(self, text):
def expand_match(contraction):
match = contraction.group(0)
@@ -54,7 +71,7 @@ def _remove_html_tags(self, text):
clean_text = re.sub(cleanup_regex, '', text)
return clean_text

def _remove_puncutation(self, text):
def _remove_punctuation(self, text):
return text.translate(str.maketrans('', '', string.punctuation))

def _tokenize(self, text):