kwargs for torch.compile in BaseHandler #2796

Merged
merged 5 commits on Nov 16, 2023
12 changes: 9 additions & 3 deletions examples/pt2/README.md
@@ -17,7 +17,15 @@ pip install torchserve torch-model-archiver

PyTorch 2.0 supports several compiler backends, and you pick the one you want by passing an optional `model_config.yaml` file during model packaging:

`pt2: "inductor"`
```yaml
pt2: "inductor"
```

You can also pass a dictionary with compile options if you need more control over `torch.compile`:

```yaml
pt2 : {backend: inductor, mode: reduce-overhead}
```
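For reference, the two config forms above map onto `torch.compile` roughly as in the following sketch; the `torch.nn.Linear` toy model is a hypothetical stand-in for the archived example model:

```python
import torch

# Hypothetical toy model standing in for the archived example model.
model = torch.nn.Linear(1, 1)

# `pt2: "inductor"` -- only the backend is chosen.
compiled_str = torch.compile(model, backend="inductor")

# `pt2: {backend: inductor, mode: reduce-overhead}` -- the whole mapping
# is forwarded to torch.compile as keyword arguments.
compiled_dict = torch.compile(model, backend="inductor", mode="reduce-overhead")
```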

As an example, let's expand our getting started guide, with the only difference being the extra `model_config.yaml` file passed in.

@@ -99,5 +107,3 @@ print(extra_files['foo.txt'])
# from inference()
print(ep(torch.randn(5)))
```


1 change: 1 addition & 0 deletions test/pytest/test_data/torch_compile/pt2_dict.yaml
@@ -0,0 +1 @@
pt2 : {backend: inductor, mode: reduce-overhead}
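As a side note, PyYAML parses the flow-style mapping in this test config into a plain Python dict, which is the shape the handler's dict branch expects; a quick sketch, assuming PyYAML is available:

```python
import yaml  # PyYAML

config = yaml.safe_load("pt2 : {backend: inductor, mode: reduce-overhead}")
# -> {'pt2': {'backend': 'inductor', 'mode': 'reduce-overhead'}}
assert config["pt2"] == {"backend": "inductor", "mode": "reduce-overhead"}
```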
64 changes: 45 additions & 19 deletions test/pytest/test_torch_compile.py
@@ -22,7 +22,8 @@

MODEL_FILE = os.path.join(TEST_DATA_DIR, "model.py")
HANDLER_FILE = os.path.join(TEST_DATA_DIR, "compile_handler.py")
YAML_CONFIG = os.path.join(TEST_DATA_DIR, "pt2.yaml")
YAML_CONFIG_STR = os.path.join(TEST_DATA_DIR, "pt2.yaml") # backend as string
YAML_CONFIG_DICT = os.path.join(TEST_DATA_DIR, "pt2_dict.yaml") # arbitrary kwargs dict


SERIALIZED_FILE = os.path.join(TEST_DATA_DIR, "model.pt")
@@ -41,19 +42,32 @@ def teardown_class(self):

def test_archive_model_artifacts(self):
assert len(glob.glob(MODEL_FILE)) == 1
assert len(glob.glob(YAML_CONFIG)) == 1
assert len(glob.glob(YAML_CONFIG_STR)) == 1
assert len(glob.glob(YAML_CONFIG_DICT)) == 1
subprocess.run(f"cd {TEST_DATA_DIR} && python model.py", shell=True, check=True)
subprocess.run(f"mkdir -p {MODEL_STORE_DIR}", shell=True, check=True)

# register 2 models, one with the backend as str config, the other with the kwargs as dict config
subprocess.run(
f"torch-model-archiver --model-name {MODEL_NAME}_str --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --config-file {YAML_CONFIG_STR} --export-path {MODEL_STORE_DIR} --handler {HANDLER_FILE} -f",
shell=True,
check=True,
)
subprocess.run(
f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --config-file {YAML_CONFIG} --export-path {MODEL_STORE_DIR} --handler {HANDLER_FILE} -f",
f"torch-model-archiver --model-name {MODEL_NAME}_dict --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --config-file {YAML_CONFIG_DICT} --export-path {MODEL_STORE_DIR} --handler {HANDLER_FILE} -f",
shell=True,
check=True,
)
assert len(glob.glob(SERIALIZED_FILE)) == 1
assert len(glob.glob(os.path.join(MODEL_STORE_DIR, f"{MODEL_NAME}.mar"))) == 1
assert (
len(glob.glob(os.path.join(MODEL_STORE_DIR, f"{MODEL_NAME}_str.mar"))) == 1
)
assert (
len(glob.glob(os.path.join(MODEL_STORE_DIR, f"{MODEL_NAME}_dict.mar"))) == 1
)

def test_start_torchserve(self):
cmd = f"torchserve --start --ncs --models {MODEL_NAME}.mar --model-store {MODEL_STORE_DIR}"
cmd = f"torchserve --start --ncs --models {MODEL_NAME}_str.mar,{MODEL_NAME}_dict.mar --model-store {MODEL_STORE_DIR}"
subprocess.run(
cmd,
shell=True,
@@ -90,9 +104,16 @@ def test_registered_model(self):
capture_output=True,
check=True,
)
expected_registered_model_str = '{"models": [{"modelName": "half_plus_two", "modelUrl": "half_plus_two.mar"}]}'
expected_registered_model = json.loads(expected_registered_model_str)
assert json.loads(result.stdout) == expected_registered_model

def _response_to_tuples(response_str):
models = json.loads(response_str)["models"]
return {(k, v) for d in models for k, v in d.items()}

# transform to set of tuples so order won't cause inequality
expected_registered_model_str = '{"models": [{"modelName": "half_plus_two_str", "modelUrl": "half_plus_two_str.mar"}, {"modelName": "half_plus_two_dict", "modelUrl": "half_plus_two_dict.mar"}]}'
assert _response_to_tuples(result.stdout) == _response_to_tuples(
expected_registered_model_str
)

@pytest.mark.skipif(
os.environ.get("TS_RUN_IN_DOCKER", False),
@@ -103,20 +124,25 @@ def test_serve_inference(self):
request_data = {"instances": [[1.0], [2.0], [3.0]]}
request_json = json.dumps(request_data)

result = subprocess.run(
f"curl -s -X POST -H \"Content-Type: application/json;\" http://localhost:8080/predictions/half_plus_two -d '{request_json}'",
shell=True,
capture_output=True,
check=True,
)
for model_name in [f"{MODEL_NAME}_str", f"{MODEL_NAME}_dict"]:
result = subprocess.run(
f"curl -s -X POST -H \"Content-Type: application/json;\" http://localhost:8080/predictions/{model_name} -d '{request_json}'",
shell=True,
capture_output=True,
check=True,
)

string_result = result.stdout.decode("utf-8")
float_result = float(string_result)
expected_result = 3.5
string_result = result.stdout.decode("utf-8")
float_result = float(string_result)
expected_result = 3.5

assert float_result == expected_result
assert float_result == expected_result

model_log_path = glob.glob("logs/model_log.log")[0]
with open(model_log_path, "rt") as model_log_file:
model_log = model_log_file.read()
assert "Compiled model with backend inductor" in model_log
assert "Compiled model with backend inductor\n" in model_log
assert (
"Compiled model with backend inductor, mode reduce-overhead"
in model_log
)
27 changes: 22 additions & 5 deletions ts/torch_handler/base_handler.py
@@ -184,23 +184,40 @@ def initialize(self, context):
raise RuntimeError("No model weights could be loaded")

if hasattr(self, "model_yaml_config") and "pt2" in self.model_yaml_config:
pt2_backend = self.model_yaml_config["pt2"]
valid_backend = check_valid_pt2_backend(pt2_backend)
pt2_value = self.model_yaml_config["pt2"]

# pt2_value can be the backend, passed as a str, or arbitrary kwargs, passed as a dict
if isinstance(pt2_value, str):
compile_options = dict(backend=pt2_value)
elif isinstance(pt2_value, dict):
compile_options = pt2_value
else:
raise ValueError("pt2 should be str or dict")

# if backend is not provided, compile will use its default, which is valid
valid_backend = (
check_valid_pt2_backend(compile_options["backend"])
if "backend" in compile_options
else True
)
else:
valid_backend = False

# PT 2.0 support is opt in
if PT2_AVAILABLE and valid_backend:
compile_options_str = ", ".join(
[f"{k} {v}" for k, v in compile_options.items()]
)
# Compilation will delay your model initialization
try:
self.model = torch.compile(
self.model,
backend=pt2_backend,
**compile_options,
)
logger.info(f"Compiled model with backend {pt2_backend}")
logger.info(f"Compiled model with {compile_options_str}")
except Exception as e:
logger.warning(
f"Compiling model model with backend {pt2_backend} has failed \n Proceeding without compilation"
f"Compiling model model with {compile_options_str} has failed \n Proceeding without compilation"
)
logger.warning(e)
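
For illustration only, the str-or-dict normalization added above can be read as a small standalone helper; `normalize_pt2_config` is a hypothetical name and is not part of the handler:

```python
from typing import Union


def normalize_pt2_config(pt2_value: Union[str, dict]) -> dict:
    """Map the `pt2` value from model_config.yaml to torch.compile kwargs."""
    if isinstance(pt2_value, str):
        # `pt2: "inductor"` -> compile with that backend only.
        return {"backend": pt2_value}
    if isinstance(pt2_value, dict):
        # `pt2: {backend: inductor, mode: reduce-overhead}` -> forwarded as-is.
        return dict(pt2_value)
    raise ValueError("pt2 should be str or dict")


# e.g. torch.compile(model, **normalize_pt2_config({"backend": "inductor"}))
```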
