-
Notifications
You must be signed in to change notification settings - Fork 863
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor PT2 code changes #2222
Merged
Merged
Changes from 24 commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
e11810c
Refactor PT2 code changes
msaroufim db8f8a9
Added config file to pytest
msaroufim 580c720
update
msaroufim e22532a
update tests
msaroufim e7e36bf
test
msaroufim c402a7b
update xla test
msaroufim bcd398f
Update base_handler.py
msaroufim 6a574ff
update
msaroufim 832a537
Merge branch 'msaroufim/pt2changes' of https://github.com/pytorch/ser…
msaroufim f456d19
update
msaroufim 62c533a
Merge branch 'master' into msaroufim/pt2changes
msaroufim 455ecd8
xla update
msaroufim c80c89b
Merge branch 'msaroufim/pt2changes' of https://github.com/pytorch/ser…
msaroufim a489c4f
Update base_handler.py
msaroufim f493fc4
Update base_handler.py
msaroufim d4c9c54
update
msaroufim dc738ca
Merge branch 'msaroufim/pt2changes' of https://github.com/pytorch/ser…
msaroufim d6e1964
tests pass
msaroufim ba1f8db
update
msaroufim 574ca9c
Trigger Lint
msaroufim 1be4b09
lint
msaroufim db07a4f
hamid feedback
msaroufim 70438d8
Merge branch 'master' into msaroufim/pt2changes
msaroufim 1e7a148
Merge branch 'master' into msaroufim/pt2changes
msaroufim 48b86a1
Update test_torch_compile.py
msaroufim 6212b0b
Trigger Build
msaroufim b3b48dd
lint
msaroufim File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import torch | ||
|
||
from ts.torch_handler.base_handler import BaseHandler | ||
|
||
|
||
class CompileHandler(BaseHandler): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def initialize(self, context): | ||
super().initialize(context) | ||
|
||
def preprocess(self, data): | ||
instances = data[0]["body"]["instances"] | ||
input_tensor = torch.as_tensor(instances, dtype=torch.float32) | ||
return input_tensor | ||
|
||
def postprocess(self, data): | ||
# Convert the output tensor to a list and return | ||
return data.tolist()[2] | ||
agunapal marked this conversation as resolved.
Show resolved
Hide resolved
|
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pt2 : "inductor" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pt2 : "torchxla_trace_once" |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import glob | ||
import json | ||
import os | ||
import subprocess | ||
import time | ||
import warnings | ||
from pathlib import Path | ||
|
||
import pytest | ||
import torch | ||
from pkg_resources import packaging | ||
|
||
|
||
PT_2_AVAILABLE = ( | ||
True | ||
if packaging.version.parse(torch.__version__) >= packaging.version.parse("2.0") | ||
else False | ||
) | ||
|
||
CURR_FILE_PATH = Path(__file__).parent | ||
TEST_DATA_DIR = os.path.join(CURR_FILE_PATH, "test_data", "torch_compile") | ||
|
||
MODEL_FILE = os.path.join(TEST_DATA_DIR, "model.py") | ||
HANDLER_FILE = os.path.join(TEST_DATA_DIR, "compile_handler.py") | ||
YAML_CONFIG = os.path.join(TEST_DATA_DIR, "pt2.yaml") | ||
|
||
|
||
SERIALIZED_FILE = os.path.join(TEST_DATA_DIR, "model.pt") | ||
MODEL_STORE_DIR = os.path.join(TEST_DATA_DIR, "model_store") | ||
MODEL_NAME = "half_plus_two" | ||
|
||
|
||
@pytest.mark.skipif(PT_2_AVAILABLE == False, reason="torch version is < 2.0.0") | ||
class TestTorchCompile: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the functions such as test_start_torchserve, test_server_status, test_registered_model, and test_serve_inference can be generalized to be shared with the other test cases. This work can be done later. |
||
def teardown_class(self): | ||
subprocess.run("torchserve --stop", shell=True, check=True) | ||
time.sleep(10) | ||
|
||
def test_archive_model_artifacts(self): | ||
assert len(glob.glob(MODEL_FILE)) == 1 | ||
assert len(glob.glob(YAML_CONFIG)) == 1 | ||
subprocess.run(f"cd {TEST_DATA_DIR} && python model.py", shell=True, check=True) | ||
subprocess.run(f"mkdir -p {MODEL_STORE_DIR}", shell=True, check=True) | ||
subprocess.run( | ||
f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --config-file {YAML_CONFIG} --export-path {MODEL_STORE_DIR} --handler {HANDLER_FILE} -f", | ||
shell=True, | ||
check=True, | ||
) | ||
assert len(glob.glob(SERIALIZED_FILE)) == 1 | ||
assert len(glob.glob(os.path.join(MODEL_STORE_DIR, f"{MODEL_NAME}.mar"))) == 1 | ||
|
||
def test_start_torchserve(self): | ||
cmd = f"torchserve --start --ncs --models {MODEL_NAME}.mar --model-store {MODEL_STORE_DIR}" | ||
subprocess.run( | ||
cmd, | ||
shell=True, | ||
check=True, | ||
) | ||
time.sleep(10) | ||
assert len(glob.glob("logs/access_log.log")) == 1 | ||
assert len(glob.glob("logs/model_log.log")) == 1 | ||
assert len(glob.glob("logs/ts_log.log")) == 1 | ||
|
||
def test_server_status(self): | ||
result = subprocess.run( | ||
"curl http://localhost:8080/ping", | ||
shell=True, | ||
capture_output=True, | ||
check=True, | ||
) | ||
expected_server_status_str = '{"status": "Healthy"}' | ||
expected_server_status = json.loads(expected_server_status_str) | ||
assert json.loads(result.stdout) == expected_server_status | ||
|
||
def test_registered_model(self): | ||
result = subprocess.run( | ||
"curl http://localhost:8081/models", | ||
shell=True, | ||
capture_output=True, | ||
check=True, | ||
) | ||
expected_registered_model_str = '{"models": [{"modelName": "half_plus_two", "modelUrl": "half_plus_two.mar"}]}' | ||
expected_registered_model = json.loads(expected_registered_model_str) | ||
assert json.loads(result.stdout) == expected_registered_model | ||
|
||
def test_serve_inference(self): | ||
request_data = {"instances": [[1.0], [2.0], [3.0]]} | ||
request_json = json.dumps(request_data) | ||
|
||
result = subprocess.run( | ||
f'curl -s -X POST -H "Content-Type: application/json;" http://localhost:8080/predictions/half_plus_two -d \'{request_json}\'', | ||
shell=True, | ||
capture_output=True, | ||
check=True, | ||
) | ||
|
||
string_result = result.stdout.decode("utf-8") | ||
float_result = float(string_result) | ||
expected_result = 3.5 | ||
|
||
assert float_result == expected_result | ||
|
||
model_log_path = glob.glob("logs/model_log.log")[0] | ||
with open(model_log_path, "rt") as model_log_file: | ||
model_log = model_log_file.read() | ||
assert "Compiled model with backend inductor" in model_log |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this pr cover the feature enable/disable torch.compile flag?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No we discussed it verbally a few weeks ago to make that happen we'd need to be able to load a new config dynamically and not just package it with archiver. Right now users will see logs that compilation failed, they will fallback to non compiled model. Compilation will only be attempted if pt2 flag shows up in yaml so default behavior is unchanged