torch.compile() support #1960

Merged: 29 commits, Dec 7, 2022
1 change: 1 addition & 0 deletions examples/image_classifier/compile.json
@@ -0,0 +1 @@
{"pt2" : "inductor"}
49 changes: 49 additions & 0 deletions examples/pt2/README.md
@@ -0,0 +1,49 @@
## PyTorch 2.x integration

PyTorch 2.0 brings more compiler options to PyTorch; for you, that should mean better performance, either in the form of lower latency or lower memory consumption. Integrating PyTorch 2.0 is fairly trivial, but for now the support will remain experimental until the PyTorch 1.14 release.

## Get started

Install TorchServe with the nightly torch binaries:

```
python ts_scripts/install_dependencies.py --cuda=cu117 --nightly_torch
pip install torchserve torch-model-archiver
```
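To sanity check the install, you can confirm that the nightly build exposes the PT2 entry point; a quick sketch:

```python
# A quick sanity check (a sketch) that the installed nightly exposes torch.compile
import torch

print(torch.__version__)
print(hasattr(torch, "compile"))  # should print True on a PT2-capable nightly
```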

## Package your model

PyTorch 2.0 supports several compiler backends; you pick the one you want by passing in an optional `compile.json` file during model packaging:

`{"pt2" : "inductor"}`

As an example, let's expand our getting started guide; the only difference is passing in the extra `compile.json` file:

```
mkdir model_store
torch-model-archiver --model-name densenet161 --version 1.0 --model-file ./serve/examples/image_classifier/densenet_161/model.py --export-path model_store --extra-files ./serve/examples/image_classifier/index_to_name.json,./serve/examples/image_classifier/compile.json --handler image_classifier
torchserve --start --ncs --model-store model_store --models densenet161.mar
```
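Once the server is up you can send inference requests as usual. A minimal sketch using the standard inference API (the `requests` dependency and the `kitten.jpg` file name are illustrative):

```python
# A sketch of an inference request against the densenet161 model served above;
# kitten.jpg is a placeholder for any local image file.
import requests

with open("kitten.jpg", "rb") as f:
    response = requests.post("http://127.0.0.1:8080/predictions/densenet161", data=f)
print(response.json())
```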

The exact same approach works with any other model. What's going on under the hood is the following:

```python
# 1. Convert a regular module to an optimized module
opt_mod = torch.compile(mod)
# 2. Train the optimized module
# ....
# 3. Save the original module (weights are shared)
torch.save(mod, "model.pt")

# 4. Load the non-optimized model
mod = torch.load("model.pt")

# 5. Compile the module and then run inferences with it
opt_mod = torch.compile(mod)
```

TorchServe takes care of steps 4 and 5 for you, while the remaining steps are your responsibility. You can do the exact same thing with the vast majority of TIMM or HuggingFace models.
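For instance, a minimal sketch of the same save step for a TIMM model (this assumes `timm` is installed; the model name is only an illustration):

```python
import timm
import torch

# Save a TIMM model in eager mode; with compile.json in the archive,
# TorchServe will call torch.compile on it at initialization.
model = timm.create_model("resnet50", pretrained=True)
model.eval()
torch.save(model, "model.pt")
```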

## Next steps

For now, PyTorch 2.0 has mostly been focused on accelerating training, so production-grade applications should instead opt for TensorRT for accelerated inference performance, which is also natively supported in TorchServe. We just wanted to make it really easy for users to experiment with the PyTorch 2.0 stack. You can learn more in the performance guide: https://github.com/pytorch/serve/blob/master/docs/performance_guide.md
70 changes: 59 additions & 11 deletions ts/torch_handler/base_handler.py
@@ -12,7 +12,12 @@
import torch
from pkg_resources import packaging

-from ..utils.util import list_classes_from_module, load_label_mapping
+from ..utils.util import (
+    list_classes_from_module,
+    load_compiler_config,
+    load_label_mapping,
+)


if packaging.version.parse(torch.__version__) >= packaging.version.parse("1.8.1"):
    from torch.profiler import ProfilerActivity, profile, record_function
Expand All @@ -24,6 +29,33 @@

logger = logging.getLogger(__name__)

# Possible values for backend in utils.py
def check_pt2_enabled():
    try:
        import torch._dynamo

        pt2_enabled = True
        if torch.cuda.is_available():
            # If Ampere enable tensor cores and ideally get yourself an A10G or A100
            if torch.cuda.get_device_capability() >= (8, 0):
                torch.backends.cuda.matmul.allow_tf32 = True
    except ImportError as error:
        logger.warning(
            "dynamo/inductor are not installed. \n For GPU please run pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117 \n for CPU please run pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu"
        )
        pt2_enabled = False
    return pt2_enabled


ipex_enabled = False
if os.environ.get("TS_IPEX_ENABLE", "false") == "true":
    try:
        import intel_extension_for_pytorch as ipex

        ipex_enabled = True
    except ImportError as error:
        logger.warning(
            "IPEX is enabled but intel-extension-for-pytorch is not installed. Proceeding without IPEX."
        )

class BaseHandler(abc.ABC):
"""
@@ -86,17 +118,18 @@ def initialize(self, context):
        serialized_file = self.manifest["model"]["serializedFile"]
        self.model_pt_path = os.path.join(model_dir, serialized_file)

-        if self.model_pt_path.endswith("onnx"):
-            try:
-                # import numpy as np
-                import onnxruntime as ort
-                import psutil
+        if self.model_pt_path:
+            if self.model_pt_path.endswith("onnx"):
+                try:
+                    # import numpy as np
+                    import onnxruntime as ort
+                    import psutil

-                onnx_enabled = True
-                logger.info("ONNX enabled")
-            except ImportError as error:
-                onnx_enabled = False
-                logger.warning("proceeding without onnxruntime")
+                    onnx_enabled = True
+                    logger.info("ONNX enabled")
+                except ImportError as error:
+                    onnx_enabled = False
+                    logger.warning("proceeding without onnxruntime")

        # model def file
        model_file = self.manifest["model"].get("modelFile", "")
@@ -138,6 +171,21 @@ def initialize(self, context):
        else:
            raise RuntimeError("No model weights could be loaded")

        self.model.eval()
        optimization_config = os.path.join(model_dir, "compile.json")
        backend = load_compiler_config(optimization_config)

        # PT 2.0 support is opt in
        if check_pt2_enabled() and backend:
            # Compilation will delay your model initialization
            try:
                self.model = torch.compile(self.model, backend=backend)
                logger.info(f"Compiled model with backend {backend}")
            except:
                logger.warning(
                    f"Compiling model with backend {backend} has failed \n Proceeding without compilation"
                )

        if ipex_enabled:
            self.model = self.model.to(memory_format=torch.channels_last)
            self.model = ipex.optimize(self.model)
35 changes: 35 additions & 0 deletions ts/utils/util.py
@@ -1,13 +1,28 @@
"""
Utility functions for TorchServe
"""
import enum
import inspect
import itertools
import json
import logging
import os
import re


class PT2Backend(str, enum.Enum):
    EAGER = "eager"
    AOT_EAGER = "aot_eager"
    INDUCTOR = "inductor"
    NVFUSER = "nvfuser"
    AOT_NVFUSER = "aot_nvfuser"
    AOT_CUDAGRAPHS = "aot_cudagraphs"
    OFI = "ofi"
    FX2TRT = "fx2trt"
    ONNXRT = "onnxrt"
    IPEX = "ipex"


logger = logging.getLogger(__name__)

CLEANUP_REGEX = re.compile("<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
@@ -38,6 +53,26 @@ def list_classes_from_module(module, parent_class=None):
    return classes


def load_compiler_config(config_file_path):
    """
    Load a compiler mapping {compiler_name -> compiler}
    Can be extended to also support kwargs for ONNX and TensorRT
    """
    if not os.path.isfile(config_file_path):
        logger.warning(f"{config_file_path} is missing. PT 2.0 will not be used")
        return None

    with open(config_file_path) as f:
        mapping = json.load(f)

    backend_values = [member.value for member in PT2Backend]
    if mapping["pt2"] in backend_values:
        return mapping["pt2"]
    else:
        logger.warning(f"{mapping['pt2']} is not a supported backend")
        return mapping["pt2"]


def load_label_mapping(mapping_file_path):
"""
Load a JSON mapping { class ID -> friendly class name }.
Expand Down
19 changes: 15 additions & 4 deletions ts_scripts/install_dependencies.py
@@ -43,7 +43,7 @@ def install_torch_packages(self, cuda_version):
f"{sys.executable} -m pip install -U -r requirements/torch_{platform.system().lower()}.txt"
)

def install_python_packages(self, cuda_version, requirements_file_path):
def install_python_packages(self, cuda_version, requirements_file_path, nightly):
check = "where" if platform.system() == "Windows" else "which"
if os.system(f"{check} conda") == 0:
# conda install command should run before the pip install commands
@@ -55,6 +55,10 @@ def install_python_packages(self, cuda_version, requirements_file_path):
        # developer.txt also installs packages from common.txt
        os.system(f"{sys.executable} -m pip install -U -r {requirements_file_path}")
        # If conda is available install conda-build package
        if nightly:
            os.system(
                f"pip3 install numpy --pre torch[dynamo] torchvision torchtext torchaudio --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117"
            )

    def install_node_packages(self):
        os.system(
@@ -140,7 +144,7 @@ def install_wget(self):
os.system("brew install wget")


-def install_dependencies(cuda_version=None):
+def install_dependencies(cuda_version=None, nightly=False):
    os_map = {"Linux": Linux, "Windows": Windows, "Darwin": Darwin}
    system = os_map[platform.system()]()

@@ -157,7 +161,7 @@ def install_dependencies(cuda_version=None):
requirements_file_path = "requirements/" + (
"production.txt" if args.environment == "prod" else "developer.txt"
)
system.install_python_packages(cuda_version, requirements_file_path)
system.install_python_packages(cuda_version, requirements_file_path, nightly)


def get_brew_version():
@@ -183,11 +187,18 @@ def get_brew_version():
choices=["prod", "dev"],
help="environment(production or developer) on which dependencies will be installed",
)

parser.add_argument(
"--nightly_torch",
action="store_true",
help="Install nightly version of torch package",
)

parser.add_argument(
"--force",
action="store_true",
help="force reinstall dependencies wget, node, java and apt-update",
)
args = parser.parse_args()

install_dependencies(cuda_version=args.cuda)
install_dependencies(cuda_version=args.cuda, nightly=args.nightly_torch)
3 changes: 3 additions & 0 deletions ts_scripts/spellcheck_conf/wordlist.txt
Expand Up @@ -1001,3 +1001,6 @@ sess
InferenceTimeInMS
MetricTypes
MetricsCache
TIMM
backends
inductor