Skip to content

Commit

Permalink
Add support for Torch-TensorRT in Docker
Browse files Browse the repository at this point in the history
- Add install for Torch-TRT nightly
- Add install validation
  • Loading branch information
gs-olive committed Sep 26, 2023
1 parent 64409d5 commit ba8722d
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion utils/cuda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,39 @@ def install_pytorch_nightly(cuda_version: str, env, dryrun=False):
pytorch_nightly_url = f"https://download.pytorch.org/whl/nightly/{CUDA_VERSION_MAP[cuda_version]['pytorch_url']}"
install_torch_cmd = ["pip", "install", "--pre", "--no-cache-dir"]
install_torch_cmd.extend(TORCHBENCH_TORCH_NIGHTLY_PACKAGES)
install_torch_cmd.extend(["-i", pytorch_nightly_url])
install_torch_cmd.extend(["-i", pytorch_nightly_url])
if dryrun:
print(f"Install pytorch nightly: {install_torch_cmd}")
else:
subprocess.check_call(install_torch_cmd, env=env)

# Install Torch-TensorRT with validation
uninstall_torchtrt_cmd = ["pip", "uninstall", "-y", "torch_tensorrt"]
if dryrun:
print(f"Uninstall torch-tensorrt: {uninstall_torchtrt_cmd}")
else:
subprocess.check_call(uninstall_torchtrt_cmd)

install_torchtrt_cmd = [
"pip",
"install",
"--pre",
"--no-cache-dir",
"torch_tensorrt",
"--extra-index-url",
pytorch_nightly_url,
]
validate_torchtrt_cmd = ["python", "-c", "'import torch_tensorrt'"]
if dryrun:
print(f"Install torch-tensorrt nightly: {install_torchtrt_cmd}")
else:
try:
subprocess.check_call(install_torchtrt_cmd, env=env)
subprocess.check_call(validate_torchtrt_cmd, env=env)
except subprocess.CalledProcessError:
print(f"Failed to install torch-tensorrt, skipping install")
pass

def install_torch_deps(cuda_version: str):
# install magma
magma_pkg = CUDA_VERSION_MAP[cuda_version]["magma_version"]
Expand Down

0 comments on commit ba8722d

Please sign in to comment.