Ensure that self.manifest is defined before calling the profiler #2488

Merged: 7 commits, Jul 22, 2023
docs/performance_guide.md: 2 additions & 0 deletions
@@ -77,6 +77,8 @@ You can find more information on TorchServe benchmarking [here](https://github.c

TorchServe has native support for the PyTorch profiler, which helps you find performance bottlenecks in your code.

If you create a custom `handle` or `initialize` method that overrides the `BaseHandler` implementation, you must define the `self.manifest` attribute to be able to run `_infer_with_profiler` (see the sketch after the snippet below).

```
export ENABLE_TORCH_PROFILER=TRUE
```
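
For illustration, here is a minimal sketch of a custom handler that keeps `self.manifest` populated. `MyHandler` and its structure are hypothetical, not part of TorchServe; the only TorchServe facts assumed are that `BaseHandler.initialize(context)` sets `self.manifest` and that the profiler path reads it:

```
from ts.torch_handler.base_handler import BaseHandler


class MyHandler(BaseHandler):  # hypothetical custom handler
    def initialize(self, context):
        # Calling the base implementation sets self.manifest
        # (among other things) from the model context.
        super().initialize(context)

    def handle(self, data, context):
        # If your initialize() bypasses the base class entirely,
        # set self.manifest yourself so _infer_with_profiler can
        # look up the model name.
        if self.manifest is None:
            self.manifest = context.manifest
        return super().handle(data, context)
```

With the attribute set, enabling `ENABLE_TORCH_PROFILER` routes `handle` through `_infer_with_profiler`, as in the change below.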
ts/torch_handler/base_handler.py: 2 additions & 0 deletions
@@ -333,6 +333,8 @@ def handle(self, data, context):
        is_profiler_enabled = os.environ.get("ENABLE_TORCH_PROFILER", None)
        if is_profiler_enabled:
            if PROFILER_AVAILABLE:
                if self.manifest is None:
                    # A custom initialize() may not have set the manifest;
                    # the profiler uses it to get the model name.
                    self.manifest = context.manifest
                output, _ = self._infer_with_profiler(data=data)
            else:
                raise RuntimeError(
ts/torch_handler/unit_tests/test_base_handler.py: 12 additions & 1 deletion
@@ -5,9 +5,10 @@
Ensures it can load and execute an example model
"""

import os
import pytest

-from ts.torch_handler.base_handler import BaseHandler
+from ts.torch_handler.base_handler import BaseHandler, PROFILER_AVAILABLE


@pytest.fixture()
@@ -30,3 +31,13 @@ def test_batch_handle(handler, base_model_context):
    processed = handler.handle(list_data, base_model_context)

    assert processed == [1, 0]


def test_inference_with_profiler_works_with_custom_initialize_method(
    handler, base_model_context
):
    # Rebinding PROFILER_AVAILABLE locally would be a no-op, so skip
    # instead when the torch profiler is genuinely unavailable.
    if not PROFILER_AVAILABLE:
        pytest.skip("torch profiler is not available in this environment")

    # Simulate a custom initialize() that never set self.manifest.
    handler.manifest = None
    os.environ["ENABLE_TORCH_PROFILER"] = "1"

    list_data = [[1.0, 2.0], [4.0, 3.0]]
    processed = handler.handle(list_data, base_model_context)
    assert processed == [1, 0]