From 3f84f62fc9cdc238557559696042fbfaba5aec95 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 18 Sep 2023 19:31:50 -0700 Subject: [PATCH] Add support for Torch-TensorRT in Docker - Add install for Torch-TRT nightly - Add install validation --- utils/cuda_utils.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/utils/cuda_utils.py b/utils/cuda_utils.py index 7015865361..ba482d4634 100644 --- a/utils/cuda_utils.py +++ b/utils/cuda_utils.py @@ -82,12 +82,39 @@ def install_pytorch_nightly(cuda_version: str, env, dryrun=False): pytorch_nightly_url = f"https://download.pytorch.org/whl/nightly/{CUDA_VERSION_MAP[cuda_version]['pytorch_url']}" install_torch_cmd = ["pip", "install", "--pre", "--no-cache-dir"] install_torch_cmd.extend(TORCHBENCH_TORCH_NIGHTLY_PACKAGES) - install_torch_cmd.extend(["-i", pytorch_nightly_url]) + install_torch_cmd.extend(["-i", pytorch_nightly_url]) if dryrun: print(f"Install pytorch nightly: {install_torch_cmd}") else: subprocess.check_call(install_torch_cmd, env=env) + # Install Torch-TensorRT with validation + uninstall_torchtrt_cmd = ["pip", "uninstall", "-y", "torch_tensorrt"] + if dryrun: + print(f"Uninstall torch-tensorrt: {uninstall_torchtrt_cmd}") + else: + subprocess.check_call(uninstall_torchtrt_cmd) + + install_torchtrt_cmd = [ + "pip", + "install", + "--pre", + "--no-cache-dir", + "torch_tensorrt", + "--extra-index-url", + pytorch_nightly_url, + ] + validate_torchtrt_cmd = ["python", "-c", "'import torch_tensorrt'"] + if dryrun: + print(f"Install torch-tensorrt nightly: {install_torchtrt_cmd}") + else: + try: + subprocess.check_call(install_torch_cmd, env=env) + subprocess.check_call(validate_torchtrt_cmd, env=env) + except subprocess.CalledProcessError: + print(f"Failed to install torch-tensorrt, skipping install") + pass + def install_torch_deps(cuda_version: str): # install magma magma_pkg = CUDA_VERSION_MAP[cuda_version]["magma_version"]