Refactor PT2 code changes #2222

Merged: 27 commits from msaroufim/pt2changes into master on Apr 19, 2023

Commits
- e11810c: Refactor PT2 code changes (msaroufim, Apr 6, 2023)
- db8f8a9: Added config file to pytest (msaroufim, Apr 6, 2023)
- 580c720: update (msaroufim, Apr 13, 2023)
- e22532a: update tests (msaroufim, Apr 13, 2023)
- e7e36bf: test (msaroufim, Apr 13, 2023)
- c402a7b: update xla test (msaroufim, Apr 13, 2023)
- bcd398f: Update base_handler.py (msaroufim, Apr 13, 2023)
- 6a574ff: update (msaroufim, Apr 13, 2023)
- 832a537: Merge branch 'msaroufim/pt2changes' of https://github.com/pytorch/ser… (msaroufim, Apr 13, 2023)
- f456d19: update (msaroufim, Apr 13, 2023)
- 62c533a: Merge branch 'master' into msaroufim/pt2changes (msaroufim, Apr 13, 2023)
- 455ecd8: xla update (msaroufim, Apr 13, 2023)
- c80c89b: Merge branch 'msaroufim/pt2changes' of https://github.com/pytorch/ser… (msaroufim, Apr 13, 2023)
- a489c4f: Update base_handler.py (msaroufim, Apr 13, 2023)
- f493fc4: Update base_handler.py (msaroufim, Apr 13, 2023)
- d4c9c54: update (msaroufim, Apr 17, 2023)
- dc738ca: Merge branch 'msaroufim/pt2changes' of https://github.com/pytorch/ser… (msaroufim, Apr 17, 2023)
- d6e1964: tests pass (msaroufim, Apr 18, 2023)
- ba1f8db: update (msaroufim, Apr 18, 2023)
- 574ca9c: Trigger Lint (msaroufim, Apr 18, 2023)
- 1be4b09: lint (msaroufim, Apr 18, 2023)
- db07a4f: hamid feedback (msaroufim, Apr 18, 2023)
- 70438d8: Merge branch 'master' into msaroufim/pt2changes (msaroufim, Apr 18, 2023)
- 1e7a148: Merge branch 'master' into msaroufim/pt2changes (msaroufim, Apr 19, 2023)
- 48b86a1: Update test_torch_compile.py (msaroufim, Apr 19, 2023)
- 6212b0b: Trigger Build (msaroufim, Apr 19, 2023)
- b3b48dd: lint (msaroufim, Apr 19, 2023)
20 changes: 10 additions & 10 deletions examples/pt2/README.md
@@ -1,30 +1,30 @@
 ## PyTorch 2.x integration

-PyTorch 2.0 brings more compiler options to PyTorch, for you that should mean better perf either in the form of lower latency or lower memory consumption. Integrating PyTorch 2.0 is fairly trivial but for now the support will be experimental until the official release and while we are relying on the nightly builds.
+PyTorch 2.0 brings more compiler options to PyTorch; for you, that should mean better performance, either as lower latency or lower memory consumption. Integrating PyTorch 2.0 is fairly trivial, but for now the support remains experimental, given that most public benchmarks have focused on training rather than inference.

 We strongly recommend you leverage newer hardware, so for GPUs that would be an Ampere architecture. You'll get even more benefit from server GPU deployments like A10G and A100 than from consumer cards, but you should expect to see some speedup on any Volta or Ampere architecture.

 ## Get started

-Install torchserve with nightly torch binaries
+Install torchserve and ensure that you're using at least `torch>=2.0.0`:

-```
-python ts_scripts/install_dependencies.py --cuda=cu117 --nightly_torch
+```sh
+python ts_scripts/install_dependencies.py --cuda=cu117
+pip install torchserve torch-model-archiver
 ```

 ## Package your model

-PyTorch 2.0 supports several compiler backends and you pick which one you want by passing in an optional file `compile.json` during your model packaging
+PyTorch 2.0 supports several compiler backends, and you pick the one you want by passing in an optional file `model_config.yaml` during model packaging:

-`{"pt2" : "inductor"}`
+`pt2: "inductor"`
Collaborator: does this PR cover the feature to enable/disable the torch.compile flag?

@msaroufim (Member, Author), Apr 18, 2023: No. We discussed it verbally a few weeks ago; to make that happen we'd need to be able to load a new config dynamically, not just package it with the archiver. Right now users will see logs that compilation failed, and they will fall back to the non-compiled model. Compilation is only attempted if the pt2 flag shows up in the YAML, so the default behavior is unchanged.

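A minimal sketch of the guarded-compile fallback described in that reply (function and config names here are illustrative, not the PR's actual base_handler.py code):

```python
import logging

import torch

logger = logging.getLogger(__name__)


def maybe_compile(model, model_yaml_config):
    # Compilation is only attempted when a "pt2" key is present in the
    # model YAML, so the default behavior is unchanged.
    backend = (model_yaml_config or {}).get("pt2")
    if backend is None:
        return model
    try:
        compiled = torch.compile(model, backend=backend)
        logger.info("Compiled model with backend %s", backend)
        return compiled
    except Exception as err:
        # On failure, log and fall back to the eager (non-compiled) model.
        logger.warning("Compilation failed, falling back to eager mode: %s", err)
        return model
```
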
-As an example let's expand our getting started guide with the only difference being passing in the extra `compile.json` file
+As an example, let's expand our getting started guide, with the only difference being passing in the extra `model_config.yaml` file:

 ```
 mkdir model_store
-torch-model-archiver --model-name densenet161 --version 1.0 --model-file ./serve/examples/image_classifier/densenet_161/model.py --export-path model_store --extra-files ./serve/examples/image_classifier/index_to_name.json,./serve/examples/image_classifier/compile.json --handler image_classifier
-torchserve --start --ncs --model-store model_store --models densenet161.mar
+torch-model-archiver --model-name densenet161 --version 1.0 --model-file ./serve/examples/image_classifier/densenet_161/model.py --export-path model_store --extra-files ./serve/examples/image_classifier/index_to_name.json --handler image_classifier
+torchserve --start --ncs --model-store model_store --models densenet161.mar --config-file model_config.yaml
 ```

 The exact same approach works with any other model; what's going on is shown below:

@@ -35,7 +35,7 @@ opt_mod = torch.compile(mod)
 # 2. Train the optimized module
 # ....
 # 3. Save the original module (weights are shared)
-torch.save(model, "model.pt")
+torch.save(model, "model.pt")

 # 4. Load the non optimized model
 mod = torch.load(model)
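For reference, here is a self-contained sketch of that save/load flow (the module definition and training step are illustrative placeholders, not from the PR):

```python
import torch
import torch.nn as nn

# 1. Define and compile a module
mod = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
opt_mod = torch.compile(mod)

# 2. "Train" the optimized module; its weights are shared with `mod`
opt_mod(torch.randn(2, 4)).sum().backward()

# 3. Save the original (non-compiled) module
torch.save(mod, "model.pt")

# 4. Load the non-optimized model and compile it again for inference
loaded = torch.load("model.pt")
opt_loaded = torch.compile(loaded)
```
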
20 changes: 20 additions & 0 deletions test/pytest/test_data/torch_compile/compile_handler.py
@@ -0,0 +1,20 @@
import torch

from ts.torch_handler.base_handler import BaseHandler


class CompileHandler(BaseHandler):
    def __init__(self):
        super().__init__()

    def initialize(self, context):
        super().initialize(context)

    def preprocess(self, data):
        # Parse the JSON request body into a float32 tensor
        instances = data[0]["body"]["instances"]
        input_tensor = torch.as_tensor(instances, dtype=torch.float32)
        return input_tensor

    def postprocess(self, data):
        # Convert the output tensor to a list and return only its third element
        return data.tolist()[2]
1 change: 1 addition & 0 deletions test/pytest/test_data/torch_compile/pt2.yaml
@@ -0,0 +1 @@
pt2 : "inductor"
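The string value selects a torch.compile backend. A quick way to see which backend names are registered in a given torch 2.x install (standard dynamo API; the exact output varies by build):

```python
import torch._dynamo as dynamo

# Prints registered compiler backends, e.g. ['inductor', 'onnxrt', ...]
print(dynamo.list_backends())
```
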
1 change: 1 addition & 0 deletions test/pytest/test_data/torch_compile/xla.yaml
@@ -0,0 +1 @@
pt2 : "torchxla_trace_once"
1 change: 0 additions & 1 deletion test/pytest/test_data/torch_xla/compile.json

This file was deleted.

104 changes: 104 additions & 0 deletions test/pytest/test_torch_compile.py
@@ -0,0 +1,104 @@
import glob
import json
import os
import subprocess
import time
from pathlib import Path

import pytest
import torch
from pkg_resources import packaging

PT_2_AVAILABLE = (
    True
    if packaging.version.parse(torch.__version__) >= packaging.version.parse("2.0")
    else False
)

CURR_FILE_PATH = Path(__file__).parent
TEST_DATA_DIR = os.path.join(CURR_FILE_PATH, "test_data", "torch_compile")

MODEL_FILE = os.path.join(TEST_DATA_DIR, "model.py")
HANDLER_FILE = os.path.join(TEST_DATA_DIR, "compile_handler.py")
YAML_CONFIG = os.path.join(TEST_DATA_DIR, "pt2.yaml")


SERIALIZED_FILE = os.path.join(TEST_DATA_DIR, "model.pt")
MODEL_STORE_DIR = os.path.join(TEST_DATA_DIR, "model_store")
MODEL_NAME = "half_plus_two"


@pytest.mark.skipif(PT_2_AVAILABLE == False, reason="torch version is < 2.0.0")
class TestTorchCompile:

Collaborator: the functions such as test_start_torchserve, test_server_status, test_registered_model, and test_serve_inference can be generalized to be shared with the other test cases. This work can be done later.

    def teardown_class(self):
        subprocess.run("torchserve --stop", shell=True, check=True)
        time.sleep(10)

    def test_archive_model_artifacts(self):
        assert len(glob.glob(MODEL_FILE)) == 1
        assert len(glob.glob(YAML_CONFIG)) == 1
        subprocess.run(f"cd {TEST_DATA_DIR} && python model.py", shell=True, check=True)
        subprocess.run(f"mkdir -p {MODEL_STORE_DIR}", shell=True, check=True)
        subprocess.run(
            f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --config-file {YAML_CONFIG} --export-path {MODEL_STORE_DIR} --handler {HANDLER_FILE} -f",
            shell=True,
            check=True,
        )
        assert len(glob.glob(SERIALIZED_FILE)) == 1
        assert len(glob.glob(os.path.join(MODEL_STORE_DIR, f"{MODEL_NAME}.mar"))) == 1

    def test_start_torchserve(self):
        cmd = f"torchserve --start --ncs --models {MODEL_NAME}.mar --model-store {MODEL_STORE_DIR}"
        subprocess.run(
            cmd,
            shell=True,
            check=True,
        )
        time.sleep(10)
        assert len(glob.glob("logs/access_log.log")) == 1
        assert len(glob.glob("logs/model_log.log")) == 1
        assert len(glob.glob("logs/ts_log.log")) == 1

    def test_server_status(self):
        result = subprocess.run(
            "curl http://localhost:8080/ping",
            shell=True,
            capture_output=True,
            check=True,
        )
        expected_server_status_str = '{"status": "Healthy"}'
        expected_server_status = json.loads(expected_server_status_str)
        assert json.loads(result.stdout) == expected_server_status

    def test_registered_model(self):
        result = subprocess.run(
            "curl http://localhost:8081/models",
            shell=True,
            capture_output=True,
            check=True,
        )
        expected_registered_model_str = '{"models": [{"modelName": "half_plus_two", "modelUrl": "half_plus_two.mar"}]}'
        expected_registered_model = json.loads(expected_registered_model_str)
        assert json.loads(result.stdout) == expected_registered_model

    def test_serve_inference(self):
        request_data = {"instances": [[1.0], [2.0], [3.0]]}
        request_json = json.dumps(request_data)

        result = subprocess.run(
            f"curl -s -X POST -H \"Content-Type: application/json;\" http://localhost:8080/predictions/half_plus_two -d '{request_json}'",
            shell=True,
            capture_output=True,
            check=True,
        )

        string_result = result.stdout.decode("utf-8")
        float_result = float(string_result)
        # half_plus_two computes y = x / 2 + 2; the handler returns only the
        # third instance, so the expected response is 3.0 / 2 + 2 == 3.5
        expected_result = 3.5

        assert float_result == expected_result

        model_log_path = glob.glob("logs/model_log.log")[0]
        with open(model_log_path, "rt") as model_log_file:
            model_log = model_log_file.read()
        assert "Compiled model with backend inductor" in model_log
8 changes: 4 additions & 4 deletions test/pytest/test_torch_xla.py
@@ -21,10 +21,10 @@
     TORCHXLA_AVAILABLE = False

 CURR_FILE_PATH = Path(__file__).parent
-TORCH_XLA_TEST_DATA_DIR = os.path.join(CURR_FILE_PATH, "test_data")
+TORCH_XLA_TEST_DATA_DIR = os.path.join(CURR_FILE_PATH, "test_data", "torch_compile")

 MODEL_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model.py")
-EXTRA_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "compile.json")
+YAML_CONFIG = os.path.join(TORCH_XLA_TEST_DATA_DIR, "xla.yaml")
 CONFIG_PROPERTIES = os.path.join(TORCH_XLA_TEST_DATA_DIR, "config.properties")

 SERIALIZED_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model.pt")
@@ -40,14 +40,14 @@ def teardown_class(self):

     def test_archive_model_artifacts(self):
         assert len(glob.glob(MODEL_FILE)) == 1
-        assert len(glob.glob(EXTRA_FILE)) == 1
+        assert len(glob.glob(YAML_CONFIG)) == 1
         assert len(glob.glob(CONFIG_PROPERTIES)) == 1
         subprocess.run(
             f"cd {TORCH_XLA_TEST_DATA_DIR} && python model.py", shell=True, check=True
         )
         subprocess.run(f"mkdir -p {MODEL_STORE_DIR}", shell=True, check=True)
         subprocess.run(
-            f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --extra-files {EXTRA_FILE} --export-path {MODEL_STORE_DIR} --handler base_handler -f",
+            f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --config-file {YAML_CONFIG} --export-path {MODEL_STORE_DIR} --handler base_handler -f",
             shell=True,
             check=True,
         )
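The XLA config exercises the same guarded-compile path with a different backend name; a minimal sketch (assumes a working torch_xla install, which is what registers this backend, and an illustrative model):

```python
import torch
import torch.nn as nn

model = nn.Linear(1, 1)
# Backend name taken from test/pytest/test_data/torch_compile/xla.yaml;
# it is only available when torch_xla is importable.
compiled = torch.compile(model, backend="torchxla_trace_once")
```
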
28 changes: 19 additions & 9 deletions test/pytest/test_utils.py
@@ -118,31 +118,41 @@ def model_archiver_command_builder(
     handler=None,
     extra_files=None,
     force=False,
+    config_file=None,
 ):
-    cmd = "torch-model-archiver"
+    # Initialize a list to store the command-line arguments
+    cmd_parts = ["torch-model-archiver"]

+    # Append arguments to the list
     if model_name:
-        cmd += " --model-name {0}".format(model_name)
+        cmd_parts.append(f"--model-name {model_name}")

     if version:
-        cmd += " --version {0}".format(version)
+        cmd_parts.append(f"--version {version}")

     if model_file:
-        cmd += " --model-file {0}".format(model_file)
+        cmd_parts.append(f"--model-file {model_file}")

     if serialized_file:
-        cmd += " --serialized-file {0}".format(serialized_file)
+        cmd_parts.append(f"--serialized-file {serialized_file}")

     if handler:
-        cmd += " --handler {0}".format(handler)
+        cmd_parts.append(f"--handler {handler}")

     if extra_files:
-        cmd += " --extra-files {0}".format(extra_files)
+        cmd_parts.append(f"--extra-files {extra_files}")

+    if config_file:
+        cmd_parts.append(f"--config-file {config_file}")
+
     if force:
-        cmd += " --force"
+        cmd_parts.append("--force")

-    cmd += " --export-path {0}".format(MODEL_STORE)
+    # Append the export-path argument to the list
+    cmd_parts.append(f"--export-path {MODEL_STORE}")
+
+    # Convert the list into a string to represent the complete command
+    cmd = " ".join(cmd_parts)

     return cmd
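A hypothetical usage of the refactored builder with the new config_file parameter (paths are illustrative, echoing this PR's test data; MODEL_STORE is the module-level constant):

```python
cmd = model_archiver_command_builder(
    model_name="half_plus_two",
    version="1.0",
    model_file="test_data/torch_compile/model.py",
    serialized_file="test_data/torch_compile/model.pt",
    handler="test_data/torch_compile/compile_handler.py",
    config_file="test_data/torch_compile/pt2.yaml",
    force=True,
)
# -> "torch-model-archiver --model-name half_plus_two --version 1.0 ...
#     --config-file test_data/torch_compile/pt2.yaml --force --export-path <MODEL_STORE>"
print(cmd)
```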