Make --serialized-file argument optional #994

Merged 6 commits on Mar 16, 2021
2 changes: 1 addition & 1 deletion model-archiver/model_archiver/arg_parser.py

@@ -34,7 +34,7 @@ def export_model_args_parser():
                                    'specified, else it will be saved under the export path')

     parser_export.add_argument('--serialized-file',
-                               required=True,
+                               required=False,
                                type=str,
                                default=None,
                                help='Path to .pt or .pth file containing state_dict in case of eager mode\n'
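A minimal standalone sketch (not the real arg_parser.py) of what this one-word change does: with `required=False` and `default=None`, the archiver CLI now accepts invocations that omit `--serialized-file` entirely.

```python
import argparse

# Illustrative parser, mirroring only the --serialized-file argument.
parser = argparse.ArgumentParser(prog="torch-model-archiver")
parser.add_argument("--serialized-file", required=False, type=str, default=None)

# No --serialized-file supplied: parsing succeeds, value is None.
args = parser.parse_args([])
print(args.serialized_file)  # -> None

# Flag supplied: behaves exactly as before.
args = parser.parse_args(["--serialized-file", "model.pt"])
print(args.serialized_file)  # -> model.pt
```

With `required=True`, the first call would have exited with "the following arguments are required: --serialized-file"; now the downstream code is responsible for handling the `None` case.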
7 changes: 5 additions & 2 deletions model-archiver/model_archiver/manifest_components/model.py

@@ -12,7 +12,9 @@ def __init__(self, model_name, serialized_file, handler, model_file=None, model_
                  extensions=None, requirements_file=None):

         self.model_name = model_name
-        self.serialized_file = serialized_file.split("/")[-1]
+        self.serialized_file = None
+        if serialized_file:
+            self.serialized_file = serialized_file.split("/")[-1]
         self.model_file = model_file
         self.model_version = model_version
         self.extensions = extensions

@@ -26,7 +28,8 @@ def __to_dict__(self):

         model_dict['modelName'] = self.model_name

-        model_dict['serializedFile'] = self.serialized_file
+        if self.serialized_file:
+            model_dict['serializedFile'] = self.serialized_file

         model_dict['handler'] = self.handler
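The effect on the generated manifest can be sketched with a simplified stand-in for the `Model` class (the helper name `build_model_dict` is illustrative, not TorchServe API): the `serializedFile` key is emitted only when a serialized file was actually supplied.

```python
# Simplified sketch of the patched manifest logic.
def build_model_dict(model_name, serialized_file=None):
    model_dict = {"modelName": model_name}
    if serialized_file:
        # Keep only the basename, mirroring serialized_file.split("/")[-1].
        model_dict["serializedFile"] = serialized_file.split("/")[-1]
    return model_dict

print(build_model_dict("resnet18"))
# -> {'modelName': 'resnet18'}
print(build_model_dict("resnet18", "models/resnet18.pt"))
# -> {'modelName': 'resnet18', 'serializedFile': 'resnet18.pt'}
```

Consumers of the manifest therefore see an absent key, not a `null` value, when no serialized file was given.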
17 changes: 10 additions & 7 deletions ts/torch_handler/base_handler.py

@@ -53,11 +53,10 @@ def initialize(self, context):
         self.manifest = context.manifest

         model_dir = properties.get("model_dir")
-        serialized_file = self.manifest["model"]["serializedFile"]
-        model_pt_path = os.path.join(model_dir, serialized_file)
-
-        if not os.path.isfile(model_pt_path):
-            raise RuntimeError("Missing the model.pt file")
+        model_pt_path = None
+        if "serializedFile" in self.manifest["model"]:
+            serialized_file = self.manifest["model"]["serializedFile"]
+            model_pt_path = os.path.join(model_dir, serialized_file)

         # model def file
         model_file = self.manifest["model"].get("modelFile", "")

@@ -67,6 +66,9 @@
             self.model = self._load_pickled_model(model_dir, model_file, model_pt_path)
         else:
             logger.debug("Loading torchscript model")
+            if not os.path.isfile(model_pt_path):
dhanainme (Collaborator) commented on Mar 8, 2021:

Looks like the presence/absence of the --model-file arg in the manifest determines whether this is an eager-mode or TorchScript model. Given the usage of this in the AWS Inference toolkit:
    model_archiver_cmd = [
        "torch-model-archiver",
        "--model-name",
        DEFAULT_TS_MODEL_NAME,
        "--handler",
        handler_service,
        "--serialized-file",
        os.path.join(environment.model_dir, DEFAULT_TS_MODEL_SERIALIZED_FILE),
        "--export-path",
        DEFAULT_TS_MODEL_DIRECTORY,
        "--extra-files",
        os.path.join(environment.model_dir, DEFAULT_TS_CODE_DIR, environment.Environment().module_name + ".py"),
        "--version",
        "1",
    ]

https://github.com/aws/sagemaker-pytorch-inference-toolkit/blob/6936c08581e26ff3bac26824b1e4946ec68ffc85/src/sagemaker_pytorch_serving_container/torchserve.py#L110

Does the AWS Inference toolkit even support eager-mode models? Per the documentation, --model-file is a required arg for eager-mode models.

Should --model-file be passed?

lxning (Collaborator, Author) replied:

The toolkit missed the "--model-file" option for eager mode. It should be fixed in the toolkit.

maaquib (Collaborator) commented on Mar 8, 2021:

If customers have been using the TS archiver without the --model-file arg, we should make sure this doesn't cause any backward-compatibility issues.

lxning (Collaborator, Author) commented on Mar 9, 2021:

On the TS side, the only change to the "--serialized-file" option is that it is now optional in eager mode (instead of mandatory), so it is backward compatible.

+                raise RuntimeError("Missing the model.pt file")

             self.model = self._load_torchscript_model(model_pt_path)

         self.model.to(self.device)
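The revised lookup in `initialize` can be exercised in isolation with a small sketch (the helper name `resolve_model_pt_path` is illustrative, not TorchServe API): `model_pt_path` now stays `None` when the manifest carries no `serializedFile` entry, instead of raising a `KeyError` or failing the file check up front.

```python
import os

# Illustrative extraction of the patched manifest lookup.
def resolve_model_pt_path(manifest, model_dir):
    model_pt_path = None
    if "serializedFile" in manifest["model"]:
        serialized_file = manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)
    return model_pt_path

print(resolve_model_pt_path({"model": {}}, "/tmp/model"))
# -> None
print(resolve_model_pt_path({"model": {"serializedFile": "m.pt"}}, "/tmp/model"))
```

The "Missing the model.pt file" check is deferred to the TorchScript branch, which is the only path that strictly requires a serialized file.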
@@ -122,9 +124,10 @@ def _load_pickled_model(self, model_dir, model_file, model_pt_path):
         )

         model_class = model_class_definitions[0]
-        state_dict = torch.load(model_pt_path, map_location=self.map_location)
         model = model_class()
-        model.load_state_dict(state_dict)
+        if model_pt_path:
+            state_dict = torch.load(model_pt_path, map_location=self.map_location)
+            model.load_state_dict(state_dict)
         return model

     def preprocess(self, data):
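A hedged sketch of the `_load_pickled_model` change, with `torch.load` stood in by a plain `load_fn` so the snippet runs without PyTorch (`TinyModel` and `load_pickled_model` are illustrative names): the state_dict is loaded and applied only when a serialized-file path was provided; otherwise the freshly constructed model is returned as-is.

```python
class TinyModel:
    """Stand-in for an eager-mode model class found in the model file."""
    def __init__(self):
        self.weights = "random-init"

    def load_state_dict(self, state_dict):
        self.weights = state_dict

def load_pickled_model(model_class, model_pt_path, load_fn):
    # load_fn plays the role of torch.load in the real handler.
    model = model_class()
    if model_pt_path:
        state_dict = load_fn(model_pt_path)
        model.load_state_dict(state_dict)
    return model

m = load_pickled_model(TinyModel, None, lambda path: "trained-weights")
print(m.weights)  # -> random-init
m = load_pickled_model(TinyModel, "model.pt", lambda path: "trained-weights")
print(m.weights)  # -> trained-weights
```

This is what makes `--serialized-file` genuinely optional for eager mode: a model whose weights are initialized in its own constructor (or by a custom handler) no longer needs a checkpoint file in the archive.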