From 05e38fba296d7664909eb0705a7993c4bd3a23f8 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 9 Nov 2022 06:34:36 +0000 Subject: [PATCH 01/27] Experimental torchdynamo support --- ts/torch_handler/base_handler.py | 13 +++++++++++ ts/torch_handler/utils.py | 37 ++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 ts/torch_handler/utils.py diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index 5849f89f4b..5e6a28a4e1 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -11,6 +11,8 @@ import torch from pkg_resources import packaging from ..utils.util import list_classes_from_module, load_label_mapping +from .utils import DynamoBackend +from os import environ if packaging.version.parse(torch.__version__) >= packaging.version.parse("1.8.1"): from torch.profiler import profile, record_function, ProfilerActivity @@ -22,6 +24,15 @@ logger = logging.getLogger(__name__) +# Possible values for backend in utils.py +if os.environ.get("DYNAMO_BACKEND"): + try: + import torch._dynamo + dynamo_enabled = True + torch.backends.cuda.matmul.allow_tf32 = True # Enable tensor cores and idealy get an A10G or A100 + except ImportError as error: + logger.warning("dynamo/inductor are not installed. \n For GPU please run pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117 \n for CPU please run pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu") + ipex_enabled = False if os.environ.get("TS_IPEX_ENABLE", "false") == "true": try: @@ -98,6 +109,8 @@ def initialize(self, context): self.model = self._load_torchscript_model(model_pt_path) self.model.eval() + if dynamo_enabled: + torch._dynamo.optimize(DynamoBackend.INDUCTOR)(self.model) if ipex_enabled: self.model = self.model.to(memory_format=torch.channels_last) self.model = ipex.optimize(self.model) diff --git a/ts/torch_handler/utils.py b/ts/torch_handler/utils.py new file mode 100644 index 0000000000..8b81713923 --- /dev/null +++ b/ts/torch_handler/utils.py @@ -0,0 +1,37 @@ +import enum + +class DynamoBackend(str, enum.Enum): + """ + Represents a dynamo backend (see https://github.com/pytorch/torchdynamo). + Values: + - **EAGER** -- Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo + issues. + - **AOT_EAGER** -- Uses AotAutograd with no compiler, i.e, just using PyTorch eager for the AotAutograd's + extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups. + - **INDUCTOR** -- Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton + kernels. [Read + more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747) + - **NVFUSER** -- nvFuser with TorchScript. [Read + more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593) + - **AOT_NVFUSER** -- nvFuser with AotAutograd. [Read + more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593) + - **AOT_CUDAGRAPHS** -- cudagraphs with AotAutograd. [Read + more](https://github.com/pytorch/torchdynamo/pull/757) + - **OFI** -- Uses Torchscript optimize_for_inference. Inference only. [Read + more](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html) + - **FX2TRT** -- Uses Nvidia TensorRT for inference optimizations. Inference only. [Read + more](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst) + - **ONNXRT** -- Uses ONNXRT for inference on CPU/GPU. Inference only. [Read more](https://onnxruntime.ai/) + - **IPEX** -- Uses IPEX for inference on CPU. Inference only. [Read + more](https://github.com/intel/intel-extension-for-pytorch). + """ + EAGER = "eager" + AOT_EAGER = "aot_eager" + INDUCTOR = "inductor" + NVFUSER = "nvfuser" + AOT_NVFUSER = "aot_nvfuser" + AOT_CUDAGRAPHS = "aot_cudagraphs" + OFI = "ofi" + FX2TRT = "fx2trt" + ONNXRT = "onnxrt" + IPEX = "ipex" \ No newline at end of file From 715ee129bb054ccf4281126fc00a6a15d0a46f86 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 9 Nov 2022 06:35:39 +0000 Subject: [PATCH 02/27] utils.py --- ts/torch_handler/utils.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/ts/torch_handler/utils.py b/ts/torch_handler/utils.py index 8b81713923..87c73b8a50 100644 --- a/ts/torch_handler/utils.py +++ b/ts/torch_handler/utils.py @@ -1,30 +1,6 @@ import enum class DynamoBackend(str, enum.Enum): - """ - Represents a dynamo backend (see https://github.com/pytorch/torchdynamo). - Values: - - **EAGER** -- Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo - issues. - - **AOT_EAGER** -- Uses AotAutograd with no compiler, i.e, just using PyTorch eager for the AotAutograd's - extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups. - - **INDUCTOR** -- Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton - kernels. [Read - more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747) - - **NVFUSER** -- nvFuser with TorchScript. [Read - more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593) - - **AOT_NVFUSER** -- nvFuser with AotAutograd. [Read - more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593) - - **AOT_CUDAGRAPHS** -- cudagraphs with AotAutograd. [Read - more](https://github.com/pytorch/torchdynamo/pull/757) - - **OFI** -- Uses Torchscript optimize_for_inference. Inference only. [Read - more](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html) - - **FX2TRT** -- Uses Nvidia TensorRT for inference optimizations. Inference only. [Read - more](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst) - - **ONNXRT** -- Uses ONNXRT for inference on CPU/GPU. Inference only. [Read more](https://onnxruntime.ai/) - - **IPEX** -- Uses IPEX for inference on CPU. Inference only. [Read - more](https://github.com/intel/intel-extension-for-pytorch). - """ EAGER = "eager" AOT_EAGER = "aot_eager" INDUCTOR = "inductor" From e0934019aacef17e53e14ded4bbb08480b6cb921 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 9 Nov 2022 06:41:46 +0000 Subject: [PATCH 03/27] [skip ci] push --- ts/torch_handler/base_handler.py | 2 ++ ts/torch_handler/utils.py | 9 ++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index 5e6a28a4e1..029e6bebf5 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -29,6 +29,7 @@ try: import torch._dynamo dynamo_enabled = True + dynamo_backend = os.environ.get("DYNAMO_BACKEND") torch.backends.cuda.matmul.allow_tf32 = True # Enable tensor cores and idealy get an A10G or A100 except ImportError as error: logger.warning("dynamo/inductor are not installed. \n For GPU please run pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117 \n for CPU please run pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu") @@ -110,6 +111,7 @@ def initialize(self, context): self.model.eval() if dynamo_enabled: + # For now just enable inductor by default torch._dynamo.optimize(DynamoBackend.INDUCTOR)(self.model) if ipex_enabled: self.model = self.model.to(memory_format=torch.channels_last) diff --git a/ts/torch_handler/utils.py b/ts/torch_handler/utils.py index 87c73b8a50..7de746b445 100644 --- a/ts/torch_handler/utils.py +++ b/ts/torch_handler/utils.py @@ -1,5 +1,6 @@ import enum +@add_mapping class DynamoBackend(str, enum.Enum): EAGER = "eager" AOT_EAGER = "aot_eager" @@ -10,4 +11,10 @@ class DynamoBackend(str, enum.Enum): OFI = "ofi" FX2TRT = "fx2trt" ONNXRT = "onnxrt" - IPEX = "ipex" \ No newline at end of file + IPEX = "ipex" + +def add_mapping(enum_cls): + for name in enum_cls.__MAPPING__: + member = enum_cls.__members__[name] + enum_cls.__MAPPING__[name] = member + return enum_cls \ No newline at end of file From 449602863b039f1fa721cb8a8b9d0ff0be66426a Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 8 Nov 2022 22:45:40 -0800 Subject: [PATCH 04/27] Update base_handler.py --- ts/torch_handler/base_handler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index 029e6bebf5..91c687b9dc 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -110,9 +110,8 @@ def initialize(self, context): self.model = self._load_torchscript_model(model_pt_path) self.model.eval() - if dynamo_enabled: - # For now just enable inductor by default - torch._dynamo.optimize(DynamoBackend.INDUCTOR)(self.model) + if dynamo_enabled: + torch._dynamo.optimize(dynamo_backend if dynamo_backend in DynamoBackend else DynamoBackend.INDUCTOR)(self.model) if ipex_enabled: self.model = self.model.to(memory_format=torch.channels_last) self.model = ipex.optimize(self.model) From 06eecca518b1c1276caa6089e7beae15c8db4c33 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 8 Nov 2022 23:19:54 -0800 Subject: [PATCH 05/27] Update base_handler.py --- ts/torch_handler/base_handler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index 91c687b9dc..ec332fedd2 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -33,7 +33,8 @@ torch.backends.cuda.matmul.allow_tf32 = True # Enable tensor cores and idealy get an A10G or A100 except ImportError as error: logger.warning("dynamo/inductor are not installed. \n For GPU please run pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117 \n for CPU please run pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu") - + dynamo_enabled = False + ipex_enabled = False if os.environ.get("TS_IPEX_ENABLE", "false") == "true": try: From 8ea85fc0be27cd8bb82b550ba68ec7ac68e86d38 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 8 Nov 2022 23:20:44 -0800 Subject: [PATCH 06/27] Update base_handler.py --- ts/torch_handler/base_handler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index ec332fedd2..f028b07779 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -30,7 +30,8 @@ import torch._dynamo dynamo_enabled = True dynamo_backend = os.environ.get("DYNAMO_BACKEND") - torch.backends.cuda.matmul.allow_tf32 = True # Enable tensor cores and idealy get an A10G or A100 + if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True # Enable tensor cores and idealy get an A10G or A100 except ImportError as error: logger.warning("dynamo/inductor are not installed. \n For GPU please run pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117 \n for CPU please run pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu") dynamo_enabled = False From 9d48c8709054ed28a007050feb4a2ba0a8ac5cb3 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 5 Dec 2022 22:10:23 +0000 Subject: [PATCH 07/27] add nightly installation instructions --- examples/pt2/README.md | 12 ++++++++++++ ts_scripts/install_dependencies.py | 19 +++++++++++++++---- 2 files changed, 27 insertions(+), 4 deletions(-) create mode 100644 examples/pt2/README.md diff --git a/examples/pt2/README.md b/examples/pt2/README.md new file mode 100644 index 0000000000..b6ae84dd25 --- /dev/null +++ b/examples/pt2/README.md @@ -0,0 +1,12 @@ +## PyTorch 2.x integration + +PyTorch 2.0 brings more compiler options to PyTorch, for you that should mean better perf either in the form of lower latency or lower memory consumption. Integrating PyTorch 2.0 is fairly trivial but for now the support will be experimental until the PyTorch 1.14 release. + +## Get started + +Install torchserve + +``` +python ts_scripts/dependencies.py --cuda=cu117 --nightly +pip install torchserve +``` diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py index 1b0484c4e5..197d0c8721 100644 --- a/ts_scripts/install_dependencies.py +++ b/ts_scripts/install_dependencies.py @@ -43,7 +43,7 @@ def install_torch_packages(self, cuda_version): f"{sys.executable} -m pip install -U -r requirements/torch_{platform.system().lower()}.txt" ) - def install_python_packages(self, cuda_version, requirements_file_path): + def install_python_packages(self, cuda_version, requirements_file_path, nightly): check = "where" if platform.system() == "Windows" else "which" if os.system(f"{check} conda") == 0: # conda install command should run before the pip install commands @@ -55,6 +55,10 @@ def install_python_packages(self, cuda_version, requirements_file_path): # developer.txt also installs packages from common.txt os.system(f"{sys.executable} -m pip install -U -r {requirements_file_path}") # If conda is available install conda-build package + if nightly: + os.system( + f"pip3 install numpy --pre torch[dynamo] torchvision torchtext torchaudio --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117" + ) def install_node_packages(self): os.system( @@ -140,7 +144,7 @@ def install_wget(self): os.system("brew install wget") -def install_dependencies(cuda_version=None): +def install_dependencies(cuda_version=None, nightly=False): os_map = {"Linux": Linux, "Windows": Windows, "Darwin": Darwin} system = os_map[platform.system()]() @@ -157,7 +161,7 @@ def install_dependencies(cuda_version=None): requirements_file_path = "requirements/" + ( "production.txt" if args.environment == "prod" else "developer.txt" ) - system.install_python_packages(cuda_version, requirements_file_path) + system.install_python_packages(cuda_version, requirements_file_path, nightly) def get_brew_version(): @@ -183,6 +187,13 @@ def get_brew_version(): choices=["prod", "dev"], help="environment(production or developer) on which dependencies will be installed", ) + + parser.add_argument( + "--nightly_torch", + action="store_true", + help="Install nightly version of torch package", + ) + parser.add_argument( "--force", action="store_true", @@ -190,4 +201,4 @@ def get_brew_version(): ) args = parser.parse_args() - install_dependencies(cuda_version=args.cuda) + install_dependencies(cuda_version=args.cuda, nightly=args.nightly_torch) From 3b6c123fb61daa6c91a3ed0ae227a3a21a200e5d Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 5 Dec 2022 22:34:18 +0000 Subject: [PATCH 08/27] weee --- examples/pt2/README.md | 2 ++ ts/torch_handler/base_handler.py | 47 ++++++++++++++++++++------------ ts/torch_handler/utils.py | 10 ++----- ts/utils/util.py | 35 ++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 26 deletions(-) diff --git a/examples/pt2/README.md b/examples/pt2/README.md index b6ae84dd25..39fcdc0c54 100644 --- a/examples/pt2/README.md +++ b/examples/pt2/README.md @@ -10,3 +10,5 @@ Install torchserve python ts_scripts/dependencies.py --cuda=cu117 --nightly pip install torchserve ``` + +## Package your model diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index f028b07779..bc46130532 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -4,18 +4,22 @@ """ import abc +import importlib.util import logging import os -import importlib.util import time + import torch from pkg_resources import packaging -from ..utils.util import list_classes_from_module, load_label_mapping -from .utils import DynamoBackend -from os import environ + +from ..utils.util import ( + list_classes_from_module, + load_compiler_config, + load_label_mapping, +) if packaging.version.parse(torch.__version__) >= packaging.version.parse("1.8.1"): - from torch.profiler import profile, record_function, ProfilerActivity + from torch.profiler import ProfilerActivity, profile, record_function PROFILER_AVAILABLE = True else: @@ -25,17 +29,23 @@ logger = logging.getLogger(__name__) # Possible values for backend in utils.py -if os.environ.get("DYNAMO_BACKEND"): +def check_pt2_enabled(): try: - import torch._dynamo - dynamo_enabled = True - dynamo_backend = os.environ.get("DYNAMO_BACKEND") + import torch.compile + + pt2_enabled = True if torch.cuda.is_available(): - torch.backends.cuda.matmul.allow_tf32 = True # Enable tensor cores and idealy get an A10G or A100 + # If Ampere enable tensor cores and ideally get yourself an A10G or A100 + if torch.cuda.get_device_capability() >= (8, 0): + torch.backends.cuda.matmul.allow_tf32 = True except ImportError as error: - logger.warning("dynamo/inductor are not installed. \n For GPU please run pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117 \n for CPU please run pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu") - dynamo_enabled = False - + logger.warning( + "dynamo/inductor are not installed. \n For GPU please run pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117 \n for CPU please run pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu" + ) + pt2_enabled = False + return pt2_enabled + + ipex_enabled = False if os.environ.get("TS_IPEX_ENABLE", "false") == "true": try: @@ -112,8 +122,11 @@ def initialize(self, context): self.model = self._load_torchscript_model(model_pt_path) self.model.eval() - if dynamo_enabled: - torch._dynamo.optimize(dynamo_backend if dynamo_backend in DynamoBackend else DynamoBackend.INDUCTOR)(self.model) + optimization_config = os.path.join(model_dir, "compile.json") + backend = load_compiler_config(optimization_config) + if check_pt2_enabled() and self.backend: + # Compilation will delay your model initialization + torch.compile(self.model, backend=backend) if ipex_enabled: self.model = self.model.to(memory_format=torch.channels_last) self.model = ipex.optimize(self.model) @@ -297,9 +310,7 @@ def _infer_with_profiler(self, data): self.profiler_args[ "on_trace_ready" ] = torch.profiler.tensorboard_trace_handler(result_path) - logger.info( - "Saving chrome trace to : %s", result_path - ) + logger.info("Saving chrome trace to : %s", result_path) with profile(**self.profiler_args) as prof: with record_function("preprocess"): diff --git a/ts/torch_handler/utils.py b/ts/torch_handler/utils.py index 7de746b445..e7b1daabd0 100644 --- a/ts/torch_handler/utils.py +++ b/ts/torch_handler/utils.py @@ -1,7 +1,7 @@ import enum -@add_mapping -class DynamoBackend(str, enum.Enum): + +class CompilerBackend(str, enum.Enum): EAGER = "eager" AOT_EAGER = "aot_eager" INDUCTOR = "inductor" @@ -12,9 +12,3 @@ class DynamoBackend(str, enum.Enum): FX2TRT = "fx2trt" ONNXRT = "onnxrt" IPEX = "ipex" - -def add_mapping(enum_cls): - for name in enum_cls.__MAPPING__: - member = enum_cls.__members__[name] - enum_cls.__MAPPING__[name] = member - return enum_cls \ No newline at end of file diff --git a/ts/utils/util.py b/ts/utils/util.py index 9a779363b1..59e275df73 100644 --- a/ts/utils/util.py +++ b/ts/utils/util.py @@ -1,6 +1,7 @@ """ Utility functions for TorchServe """ +import enum import inspect import itertools import json @@ -8,6 +9,20 @@ import os import re + +class PT2Backend(str, enum.Enum): + EAGER = "eager" + AOT_EAGER = "aot_eager" + INDUCTOR = "inductor" + NVFUSER = "nvfuser" + AOT_NVFUSER = "aot_nvfuser" + AOT_CUDAGRAPHS = "aot_cudagraphs" + OFI = "ofi" + FX2TRT = "fx2trt" + ONNXRT = "onnxrt" + IPEX = "ipex" + + logger = logging.getLogger(__name__) CLEANUP_REGEX = re.compile("<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});") @@ -38,6 +53,26 @@ def list_classes_from_module(module, parent_class=None): return classes +def load_compiler_config(config_file_path): + """ + Load a compiler {compiler_name -> compiler } + Can be extended to also support kwargs for ONNX and TensorRT + """ + if not os.path.isfile(config_file_path): + logger.warning(f"{config_file_path} is missing. PT 2.0 will not be used") + return None + + with open(config_file_path) as f: + mapping = json.load(f) + + backend_values = [member.value for member in PT2Backend] + if mapping["pt2"] in backend_values: + return mapping["pt2"] + else: + logger.warning(f"{mapping['pt2']} is not a supported backend") + return mapping["pt2"] + + def load_label_mapping(mapping_file_path): """ Load a JSON mapping { class ID -> friendly class name }. From d582ef9dc2989f9c64db2f923e5673b8ffef5a96 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 5 Dec 2022 22:44:38 +0000 Subject: [PATCH 09/27] push --- examples/pt2/README.md | 8 ++++++++ ts/torch_handler/base_handler.py | 10 +++++++++- ts/torch_handler/utils.py | 14 -------------- 3 files changed, 17 insertions(+), 15 deletions(-) delete mode 100644 ts/torch_handler/utils.py diff --git a/examples/pt2/README.md b/examples/pt2/README.md index 39fcdc0c54..dc2f51a70c 100644 --- a/examples/pt2/README.md +++ b/examples/pt2/README.md @@ -12,3 +12,11 @@ pip install torchserve ``` ## Package your model + +PyTorch 2.0 supports several compiler backends and you pick which one you want by passing in an optional file `compile.json` during your model packaging + +`{"pt2" : "inductor"}` + +## Next steps + +For now PyTorch 2.0 has mostly been focused on accelerating training so production grade applications should instead opt for TensorRT for accelerated inference performance which is also natively supported in torchserve. We just wanted to make it really easy for users to experiment with the PyTorch 2.0 stack. diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index bc46130532..0cb48d29c4 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -124,9 +124,17 @@ def initialize(self, context): self.model.eval() optimization_config = os.path.join(model_dir, "compile.json") backend = load_compiler_config(optimization_config) + + # PT 2.0 support is opt in if check_pt2_enabled() and self.backend: # Compilation will delay your model initialization - torch.compile(self.model, backend=backend) + try: + torch.compile(self.model, backend=backend) + logger.info(f"Compiled {self.model} with backend {backend}") + except: + logger.warning( + f"Compiling model {self.model} with backend {backend} has failed \n Proceeding without compilation" + ) if ipex_enabled: self.model = self.model.to(memory_format=torch.channels_last) self.model = ipex.optimize(self.model) diff --git a/ts/torch_handler/utils.py b/ts/torch_handler/utils.py deleted file mode 100644 index e7b1daabd0..0000000000 --- a/ts/torch_handler/utils.py +++ /dev/null @@ -1,14 +0,0 @@ -import enum - - -class CompilerBackend(str, enum.Enum): - EAGER = "eager" - AOT_EAGER = "aot_eager" - INDUCTOR = "inductor" - NVFUSER = "nvfuser" - AOT_NVFUSER = "aot_nvfuser" - AOT_CUDAGRAPHS = "aot_cudagraphs" - OFI = "ofi" - FX2TRT = "fx2trt" - ONNXRT = "onnxrt" - IPEX = "ipex" From 6301e04eece2e1495b90521f079287a98690fc43 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 5 Dec 2022 23:28:36 +0000 Subject: [PATCH 10/27] update --- examples/pt2/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/pt2/README.md b/examples/pt2/README.md index dc2f51a70c..46f1219971 100644 --- a/examples/pt2/README.md +++ b/examples/pt2/README.md @@ -4,11 +4,11 @@ PyTorch 2.0 brings more compiler options to PyTorch, for you that should mean be ## Get started -Install torchserve +Install torchserve with nightly torch binaries ``` -python ts_scripts/dependencies.py --cuda=cu117 --nightly -pip install torchserve +python ts_scripts/dependencies.py --cuda=cu117 --nightly_torch +pip install torchserve torch-model-archiver ``` ## Package your model From 797dc075932764e36672d6240bba83fa1e640223 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 5 Dec 2022 23:34:59 +0000 Subject: [PATCH 11/27] update --- examples/pt2/README.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/examples/pt2/README.md b/examples/pt2/README.md index 46f1219971..cc8697c6f9 100644 --- a/examples/pt2/README.md +++ b/examples/pt2/README.md @@ -17,6 +17,33 @@ PyTorch 2.0 supports several compiler backends and you pick which one you want b `{"pt2" : "inductor"}` +As an example let's expand our getting started guidde with the only difference being passing in the extra `compile.json` file + +``` +mkdir model_store +torch-model-archiver --model-name densenet161 --version 1.0 --model-file ./serve/examples/image_classifier/densenet_161/model.py --export-path model_store --extra-files ./serve/examples/image_classifier/index_to_name.json,./serve/examples/image_classifier/compile.json --handler image_classifier +torchserve --start --ncs --model-store model_store --models densenet161.mar +``` + +The exact same approach works with any other mdoel, what's going on is the beelow + +```python +# 1. Convert a regular module to an optimized module +opt_mod = torch.compile(mod) +# 2. Train the optimized module +# .... +# 3. Save the original module (weights are shared) +torch.save(model, "model.pt") + +# 4. Load the non optimized model +mod = torch.load(model) + +# 5. Compile the module and then run inferences with it +opt_mod = torch.compile(mod) +``` + +torchserve takes care of 4 and 5 for you while the remaining steps are your responsibility. You can do the exact same thing on the vast majority of TIMM or HuggingFace models. + ## Next steps For now PyTorch 2.0 has mostly been focused on accelerating training so production grade applications should instead opt for TensorRT for accelerated inference performance which is also natively supported in torchserve. We just wanted to make it really easy for users to experiment with the PyTorch 2.0 stack. From d74162a3a05413a02095b4062e69685de91c87a7 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 5 Dec 2022 23:52:46 +0000 Subject: [PATCH 12/27] updates --- examples/image_classifier/compile.json | 1 + examples/pt2/README.md | 2 +- ts/torch_handler/base_handler.py | 23 ++++++++++++----------- 3 files changed, 14 insertions(+), 12 deletions(-) create mode 100644 examples/image_classifier/compile.json diff --git a/examples/image_classifier/compile.json b/examples/image_classifier/compile.json new file mode 100644 index 0000000000..c4f0ac85c7 --- /dev/null +++ b/examples/image_classifier/compile.json @@ -0,0 +1 @@ +{"pt2" : "inductor"} diff --git a/examples/pt2/README.md b/examples/pt2/README.md index cc8697c6f9..70a5102cd8 100644 --- a/examples/pt2/README.md +++ b/examples/pt2/README.md @@ -7,7 +7,7 @@ PyTorch 2.0 brings more compiler options to PyTorch, for you that should mean be Install torchserve with nightly torch binaries ``` -python ts_scripts/dependencies.py --cuda=cu117 --nightly_torch +python ts_scripts/install_dependencies.py --cuda=cu117 --nightly_torch pip install torchserve torch-model-archiver ``` diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index 783eb69fd4..6791d842df 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -119,17 +119,18 @@ def initialize(self, context): serialized_file = self.manifest["model"]["serializedFile"] self.model_pt_path = os.path.join(model_dir, serialized_file) - if self.model_pt_path.endswith("onnx"): - try: - # import numpy as np - import onnxruntime as ort - import psutil - - onnx_enabled = True - logger.info("ONNX enabled") - except ImportError as error: - onnx_enabled = False - logger.warning("proceeding without onnxruntime") + if self.model_pt_path: + if self.model_pt_path.endswith("onnx"): + try: + # import numpy as np + import onnxruntime as ort + import psutil + + onnx_enabled = True + logger.info("ONNX enabled") + except ImportError as error: + onnx_enabled = False + logger.warning("proceeding without onnxruntime") # model def file model_file = self.manifest["model"].get("modelFile", "") From 2a28a7653600b9cf41096a837f5fc41290c8dd39 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 5 Dec 2022 23:53:24 +0000 Subject: [PATCH 13/27] push --- ts/torch_handler/base_handler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index 6791d842df..46c9dfc98b 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -32,8 +32,7 @@ # Possible values for backend in utils.py def check_pt2_enabled(): try: - import torch.compile - + import torch._dynamo pt2_enabled = True if torch.cuda.is_available(): # If Ampere enable tensor cores and ideally get yourself an A10G or A100 From 999f41794a05f033fc6c145590b9552f982e9df3 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 5 Dec 2022 23:57:14 +0000 Subject: [PATCH 14/27] push --- ts/torch_handler/base_handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index 46c9dfc98b..1326900cbc 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -176,14 +176,14 @@ def initialize(self, context): backend = load_compiler_config(optimization_config) # PT 2.0 support is opt in - if check_pt2_enabled() and self.backend: + if check_pt2_enabled() and backend: # Compilation will delay your model initialization try: torch.compile(self.model, backend=backend) - logger.info(f"Compiled {self.model} with backend {backend}") + logger.info(f"Compiled model with backend {backend}") except: logger.warning( - f"Compiling model {self.model} with backend {backend} has failed \n Proceeding without compilation" + f"Compiling model model with backend {backend} has failed \n Proceeding without compilation" ) if ipex_enabled: From e9644eab6abe843b4faa575a969e11930100d9e7 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 6 Dec 2022 00:01:26 +0000 Subject: [PATCH 15/27] fixes --- examples/pt2/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pt2/README.md b/examples/pt2/README.md index 70a5102cd8..0ff6db0d06 100644 --- a/examples/pt2/README.md +++ b/examples/pt2/README.md @@ -17,7 +17,7 @@ PyTorch 2.0 supports several compiler backends and you pick which one you want b `{"pt2" : "inductor"}` -As an example let's expand our getting started guidde with the only difference being passing in the extra `compile.json` file +As an example let's expand our getting started gui de with the only difference being passing in the extra `compile.json` file ``` mkdir model_store @@ -25,7 +25,7 @@ torch-model-archiver --model-name densenet161 --version 1.0 --model-file ./serve torchserve --start --ncs --model-store model_store --models densenet161.mar ``` -The exact same approach works with any other mdoel, what's going on is the beelow +The exact same approach works with any other model, what's going on is the below ```python # 1. Convert a regular module to an optimized module From 682f03e5d922f4c9f44792b0c123c55c189d683b Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 6 Dec 2022 00:07:59 +0000 Subject: [PATCH 16/27] updates --- examples/pt2/README.md | 2 +- ts/torch_handler/base_handler.py | 2 +- ts_scripts/spellcheck_conf/wordlist.txt | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/pt2/README.md b/examples/pt2/README.md index 0ff6db0d06..47cad2b81c 100644 --- a/examples/pt2/README.md +++ b/examples/pt2/README.md @@ -46,4 +46,4 @@ torchserve takes care of 4 and 5 for you while the remaining steps are your resp ## Next steps -For now PyTorch 2.0 has mostly been focused on accelerating training so production grade applications should instead opt for TensorRT for accelerated inference performance which is also natively supported in torchserve. We just wanted to make it really easy for users to experiment with the PyTorch 2.0 stack. +For now PyTorch 2.0 has mostly been focused on accelerating training so production grade applications should instead opt for TensorRT for accelerated inference performance which is also natively supported in torchserve. We just wanted to make it really easy for users to experiment with the PyTorch 2.0 stack. You can learn more here https://github.com/pytorch/serve/blob/master/docs/performance_guide.md diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index 1326900cbc..32b0a15a72 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -179,7 +179,7 @@ def initialize(self, context): if check_pt2_enabled() and backend: # Compilation will delay your model initialization try: - torch.compile(self.model, backend=backend) + self.model = torch.compile(self.model, backend=backend) logger.info(f"Compiled model with backend {backend}") except: logger.warning( diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index 70c1c52097..23c7200cad 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1001,3 +1001,6 @@ sess InferenceTimeInMS MetricTypes MetricsCache +TIMM +backends +inductor \ No newline at end of file From 299803cde256f07fc82a9f18ef2db2aa59953f4e Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 5 Dec 2022 16:10:09 -0800 Subject: [PATCH 17/27] Update README.md --- examples/pt2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pt2/README.md b/examples/pt2/README.md index 47cad2b81c..eb11c6045c 100644 --- a/examples/pt2/README.md +++ b/examples/pt2/README.md @@ -17,7 +17,7 @@ PyTorch 2.0 supports several compiler backends and you pick which one you want b `{"pt2" : "inductor"}` -As an example let's expand our getting started gui de with the only difference being passing in the extra `compile.json` file +As an example let's expand our getting started guide with the only difference being passing in the extra `compile.json` file ``` mkdir model_store From 85077b668c2b3d163bb2cabcbb1194171fd16245 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 5 Dec 2022 22:21:28 -0800 Subject: [PATCH 18/27] Update install_dependencies.py --- ts_scripts/install_dependencies.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py index ff4cc2b9ec..8586617b2d 100644 --- a/ts_scripts/install_dependencies.py +++ b/ts_scripts/install_dependencies.py @@ -55,6 +55,9 @@ def install_python_packages(self, cuda_version, requirements_file_path, nightly) # developer.txt also installs packages from common.txt os.system(f"{sys.executable} -m pip install -U -r {requirements_file_path}") # If conda is available install conda-build package + + # TODO: This will run 2 installations for torch but to make this cleaner we should first refactor all of our requirements.txt into just 2 files + # And then make torch an optional dependency for the common.txt if nightly: os.system( f"pip3 install numpy --pre torch[dynamo] torchvision torchtext torchaudio --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117" From 31f38951d3fc3b392b0047a9b528b273c1a4ed87 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 5 Dec 2022 22:21:50 -0800 Subject: [PATCH 19/27] Update util.py --- ts/utils/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ts/utils/util.py b/ts/utils/util.py index 0a9027a03a..c87661ef2b 100644 --- a/ts/utils/util.py +++ b/ts/utils/util.py @@ -59,7 +59,7 @@ def load_compiler_config(config_file_path): Can be extended to also support kwargs for ONNX and TensorRT """ if not os.path.isfile(config_file_path): - logger.warning(f"{config_file_path} is missing. PT 2.0 will not be used") + logger.info(f"{config_file_path} is missing. PT 2.0 will not be used") return None with open(config_file_path) as f: From d7dc89ba2099602397ed9ae1fb4af2661150b20e Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 5 Dec 2022 22:22:32 -0800 Subject: [PATCH 20/27] Update README.md --- examples/pt2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pt2/README.md b/examples/pt2/README.md index eb11c6045c..a8411be1b6 100644 --- a/examples/pt2/README.md +++ b/examples/pt2/README.md @@ -1,6 +1,6 @@ ## PyTorch 2.x integration -PyTorch 2.0 brings more compiler options to PyTorch, for you that should mean better perf either in the form of lower latency or lower memory consumption. Integrating PyTorch 2.0 is fairly trivial but for now the support will be experimental until the PyTorch 1.14 release. +PyTorch 2.0 brings more compiler options to PyTorch, for you that should mean better perf either in the form of lower latency or lower memory consumption. Integrating PyTorch 2.0 is fairly trivial but for now the support will be experimental until the official release and while we are relying on the nightly builds. ## Get started From 723974d3a9b5ba52857ea118aa0528dab6d55d84 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 6 Dec 2022 08:08:09 -0800 Subject: [PATCH 21/27] Update base_handler.py --- ts/torch_handler/base_handler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index 32b0a15a72..b6ae5b85a2 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -35,7 +35,8 @@ def check_pt2_enabled(): import torch._dynamo pt2_enabled = True if torch.cuda.is_available(): - # If Ampere enable tensor cores and ideally get yourself an A10G or A100 + # If Ampere enable tensor cores which will give better performance + # Ideally get yourself an A10G or A100 for optimal performance if torch.cuda.get_device_capability() >= (8, 0): torch.backends.cuda.matmul.allow_tf32 = True except ImportError as error: @@ -179,7 +180,7 @@ def initialize(self, context): if check_pt2_enabled() and backend: # Compilation will delay your model initialization try: - self.model = torch.compile(self.model, backend=backend) + self.model = torch.compile(self.model, backend=backend, mode="reduce-overhead") logger.info(f"Compiled model with backend {backend}") except: logger.warning( From 0dc2928869ddb783aac0588bd77234ef6d061dbf Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 6 Dec 2022 20:16:37 +0000 Subject: [PATCH 22/27] update --- examples/pt2/README.md | 2 ++ ts/torch_handler/base_handler.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/pt2/README.md b/examples/pt2/README.md index a8411be1b6..1a6be197f0 100644 --- a/examples/pt2/README.md +++ b/examples/pt2/README.md @@ -2,6 +2,8 @@ PyTorch 2.0 brings more compiler options to PyTorch, for you that should mean better perf either in the form of lower latency or lower memory consumption. Integrating PyTorch 2.0 is fairly trivial but for now the support will be experimental until the official release and while we are relying on the nightly builds. +We strongly recommend you leverage newer hardware so for GPUs that would be an Ampere architecture. You'll get even more benefits from using server GPU deployments like A10G and A100 vs consumer cards. But you should expect to see some speedups for any Volta or Ampere architecture. + ## Get started Install torchserve with nightly torch binaries diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index b6ae5b85a2..64445fff0f 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -187,7 +187,7 @@ def initialize(self, context): f"Compiling model model with backend {backend} has failed \n Proceeding without compilation" ) - if ipex_enabled: + elif ipex_enabled: self.model = self.model.to(memory_format=torch.channels_last) self.model = ipex.optimize(self.model) From b7ef6733a49aa646b06b0db6a6e80017a0c6617f Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 6 Dec 2022 20:17:34 +0000 Subject: [PATCH 23/27] push --- ts_scripts/install_dependencies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py index 8586617b2d..e3288ffeb5 100644 --- a/ts_scripts/install_dependencies.py +++ b/ts_scripts/install_dependencies.py @@ -60,7 +60,7 @@ def install_python_packages(self, cuda_version, requirements_file_path, nightly) # And then make torch an optional dependency for the common.txt if nightly: os.system( - f"pip3 install numpy --pre torch[dynamo] torchvision torchtext torchaudio --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117" + f"pip3 install numpy --pre torch[dynamo] torchvision torchtext torchaudio --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/{cuda_version}" ) def install_node_packages(self): From c7bb068c7faa0bdea45073aac312b8eab89b9b0b Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 6 Dec 2022 12:59:19 -0800 Subject: [PATCH 24/27] Update util.py --- ts/utils/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ts/utils/util.py b/ts/utils/util.py index c87661ef2b..eae542ce9a 100644 --- a/ts/utils/util.py +++ b/ts/utils/util.py @@ -70,7 +70,7 @@ def load_compiler_config(config_file_path): return mapping["pt2"] else: logger.warning(f"{mapping['pt2']} is not a supported backend") - return mapping["pt2"] + return None def load_label_mapping(mapping_file_path): From ab4867b0b53deceafda0d9dc173250aef7ea7092 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 6 Dec 2022 13:36:54 -0800 Subject: [PATCH 25/27] lint --- ts/torch_handler/base_handler.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ts/torch_handler/base_handler.py b/ts/torch_handler/base_handler.py index 64445fff0f..1c8eafce38 100644 --- a/ts/torch_handler/base_handler.py +++ b/ts/torch_handler/base_handler.py @@ -18,7 +18,6 @@ load_label_mapping, ) - if packaging.version.parse(torch.__version__) >= packaging.version.parse("1.8.1"): from torch.profiler import ProfilerActivity, profile, record_function @@ -33,6 +32,7 @@ def check_pt2_enabled(): try: import torch._dynamo + pt2_enabled = True if torch.cuda.is_available(): # If Ampere enable tensor cores which will give better performance @@ -58,6 +58,7 @@ def check_pt2_enabled(): "IPEX is enabled but intel-extension-for-pytorch is not installed. Proceeding without IPEX." ) + class BaseHandler(abc.ABC): """ Base default handler to load torchscript or eager mode [state_dict] models @@ -180,7 +181,9 @@ def initialize(self, context): if check_pt2_enabled() and backend: # Compilation will delay your model initialization try: - self.model = torch.compile(self.model, backend=backend, mode="reduce-overhead") + self.model = torch.compile( + self.model, backend=backend, mode="reduce-overhead" + ) logger.info(f"Compiled model with backend {backend}") except: logger.warning( From a57299642ec8653fa2c4ccc69d7c484517c9e69c Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 6 Dec 2022 14:18:54 -0800 Subject: [PATCH 26/27] lint --- ts/utils/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ts/utils/util.py b/ts/utils/util.py index eae542ce9a..629f274008 100644 --- a/ts/utils/util.py +++ b/ts/utils/util.py @@ -118,7 +118,7 @@ def map_class_to_label(probs, mapping=None, lbl_classes=None): """ if not isinstance(probs, list): raise Exception("Convert classes to list before doing mapping") - + if mapping is not None and not isinstance(mapping, dict): raise Exception("Mapping must be a dict") From 2e5c21515deec9a5fe6515612a669b53d0e81126 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 6 Dec 2022 14:29:31 -0800 Subject: [PATCH 27/27] lint --- ts_scripts/install_dependencies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ts_scripts/install_dependencies.py b/ts_scripts/install_dependencies.py index e3288ffeb5..86a7a9755b 100644 --- a/ts_scripts/install_dependencies.py +++ b/ts_scripts/install_dependencies.py @@ -55,7 +55,7 @@ def install_python_packages(self, cuda_version, requirements_file_path, nightly) # developer.txt also installs packages from common.txt os.system(f"{sys.executable} -m pip install -U -r {requirements_file_path}") # If conda is available install conda-build package - + # TODO: This will run 2 installations for torch but to make this cleaner we should first refactor all of our requirements.txt into just 2 files # And then make torch an optional dependency for the common.txt if nightly: