Skip to content

Commit

Permalink
feat: Upgrade torch version to 2.2.1 (#1374)
Browse files Browse the repository at this point in the history
Co-authored-by: yyhhyy <yyhhyyyyyy@163.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
  • Loading branch information
3 people authored Apr 8, 2024
1 parent bb77e13 commit 5a0c20a
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 56 deletions.
2 changes: 1 addition & 1 deletion docker/base/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ARG BASE_IMAGE="nvidia/cuda:11.8.0-runtime-ubuntu22.04"
ARG BASE_IMAGE="nvidia/cuda:12.1.0-runtime-ubuntu22.04"

FROM ${BASE_IMAGE}
ARG BASE_IMAGE
Expand Down
4 changes: 2 additions & 2 deletions docker/base/build_image.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ SCRIPT_LOCATION=$0
cd "$(dirname "$SCRIPT_LOCATION")"
WORK_DIR=$(pwd)

BASE_IMAGE_DEFAULT="nvidia/cuda:11.8.0-runtime-ubuntu22.04"
BASE_IMAGE_DEFAULT="nvidia/cuda:12.1.0-runtime-ubuntu22.04"
BASE_IMAGE_DEFAULT_CPU="ubuntu:22.04"

BASE_IMAGE=$BASE_IMAGE_DEFAULT
Expand All @@ -21,7 +21,7 @@ BUILD_NETWORK=""
DB_GPT_INSTALL_MODEL="default"

usage () {
echo "USAGE: $0 [--base-image nvidia/cuda:11.8.0-runtime-ubuntu22.04] [--image-name db-gpt]"
echo "USAGE: $0 [--base-image nvidia/cuda:12.1.0-runtime-ubuntu22.04] [--image-name db-gpt]"
echo " [-b|--base-image base image name] Base image name"
echo " [-n|--image-name image name] Current image name, default: db-gpt"
echo " [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple"
Expand Down
151 changes: 98 additions & 53 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
import shutil
import subprocess
import sys
import urllib.request
from enum import Enum
from typing import Callable, List, Optional, Tuple
Expand Down Expand Up @@ -40,15 +41,22 @@ def parse_requirements(file_name: str) -> List[str]:
]


def find_python():
python_path = sys.executable
print(python_path)
if not python_path:
print("Python command not found.")
return None
return python_path


def get_latest_version(package_name: str, index_url: str, default_version: str):
python_command = shutil.which("python")
python_command = find_python()
if not python_command:
python_command = shutil.which("python3")
if not python_command:
print("Python command not found.")
return default_version
print("Python command not found.")
return default_version

command = [
command_index_versions = [
python_command,
"-m",
"pip",
Expand All @@ -59,20 +67,40 @@ def get_latest_version(package_name: str, index_url: str, default_version: str):
index_url,
]

result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if result.returncode != 0:
print("Error executing command.")
print(result.stderr.decode())
return default_version
result_index_versions = subprocess.run(
command_index_versions, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
if result_index_versions.returncode == 0:
output = result_index_versions.stdout.decode()
lines = output.split("\n")
for line in lines:
if "Available versions:" in line:
available_versions = line.split(":")[1].strip()
latest_version = available_versions.split(",")[0].strip()
# Query for compatibility with the latest version of torch
if package_name == "torch" or "torchvision":
latest_version = latest_version.split("+")[0]
return latest_version
else:
command_simulate_install = [
python_command,
"-m",
"pip",
"install",
f"{package_name}==",
]

output = result.stdout.decode()
lines = output.split("\n")
for line in lines:
if "Available versions:" in line:
available_versions = line.split(":")[1].strip()
latest_version = available_versions.split(",")[0].strip()
result_simulate_install = subprocess.run(
command_simulate_install, stderr=subprocess.PIPE
)
print(result_simulate_install)
stderr_output = result_simulate_install.stderr.decode()
print(stderr_output)
match = re.search(r"from versions: (.+?)\)", stderr_output)
if match:
available_versions = match.group(1).split(", ")
latest_version = available_versions[-1].strip()
return latest_version

return default_version


Expand Down Expand Up @@ -227,7 +255,7 @@ def _build_wheels(
base_url: str = None,
base_url_func: Callable[[str, str, str], str] = None,
pkg_file_func: Callable[[str, str, str, str, OSType], str] = None,
supported_cuda_versions: List[str] = ["11.7", "11.8"],
supported_cuda_versions: List[str] = ["11.8", "12.1"],
) -> Optional[str]:
"""
Build the URL for the package wheel file based on the package name, version, and CUDA version.
Expand All @@ -248,11 +276,17 @@ def _build_wheels(
py_version = "cp" + "".join(py_version.split(".")[0:2])
if os_type == OSType.DARWIN or not cuda_version:
return None
if cuda_version not in supported_cuda_versions:

if cuda_version in supported_cuda_versions:
cuda_version = cuda_version
else:
print(
f"Warnning: {pkg_name} supported cuda version: {supported_cuda_versions}, replace to {supported_cuda_versions[-1]}"
f"Warning: Your CUDA version {cuda_version} is not in our set supported_cuda_versions , we will use our set version."
)
cuda_version = supported_cuda_versions[-1]
if cuda_version < "12.1":
cuda_version = supported_cuda_versions[0]
else:
cuda_version = supported_cuda_versions[-1]

cuda_version = "cu" + cuda_version.replace(".", "")
os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
Expand All @@ -273,55 +307,57 @@ def _build_wheels(


def torch_requires(
torch_version: str = "2.0.1",
torchvision_version: str = "0.15.2",
torchaudio_version: str = "2.0.2",
torch_version: str = "2.2.1",
torchvision_version: str = "0.17.1",
torchaudio_version: str = "2.2.1",
):
os_type, _ = get_cpu_avx_support()
torch_pkgs = [
f"torch=={torch_version}",
f"torchvision=={torchvision_version}",
f"torchaudio=={torchaudio_version}",
]
torch_cuda_pkgs = []
os_type, _ = get_cpu_avx_support()
# Initialize torch_cuda_pkgs for non-Darwin OSes;
# it will be the same as torch_pkgs for Darwin or when no specific CUDA handling is needed
torch_cuda_pkgs = torch_pkgs[:]

if os_type != OSType.DARWIN:
cuda_version = get_cuda_version()
if cuda_version:
supported_versions = ["11.7", "11.8"]
# torch_url = f"https://download.pytorch.org/whl/{cuda_version}/torch-{torch_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
# torchvision_url = f"https://download.pytorch.org/whl/{cuda_version}/torchvision-{torchvision_version}+{cuda_version}-{py_version}-{py_version}-{os_pkg_name}.whl"
torch_url = _build_wheels(
"torch",
torch_version,
base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
supported_cuda_versions=supported_versions,
)
torchvision_url = _build_wheels(
"torchvision",
torchvision_version,
base_url_func=lambda v, x, y: f"https://download.pytorch.org/whl/{x}",
supported_cuda_versions=supported_versions,
)
supported_versions = ["11.8", "12.1"]
base_url_func = lambda v, x, y: f"https://download.pytorch.org/whl/{x}"
torch_url = _build_wheels(
"torch",
torch_version,
base_url_func=base_url_func,
supported_cuda_versions=supported_versions,
)
torchvision_url = _build_wheels(
"torchvision",
torchvision_version,
base_url_func=base_url_func,
supported_cuda_versions=supported_versions,
)

# Cache and add CUDA-dependent packages if URLs are available
if torch_url:
torch_url_cached = cache_package(
torch_url, "torch", os_type == OSType.WINDOWS
)
torch_cuda_pkgs[0] = f"torch @ {torch_url_cached}"
if torchvision_url:
torchvision_url_cached = cache_package(
torchvision_url, "torchvision", os_type == OSType.WINDOWS
)
torch_cuda_pkgs[1] = f"torchvision @ {torchvision_url_cached}"

torch_cuda_pkgs = [
f"torch @ {torch_url_cached}",
f"torchvision @ {torchvision_url_cached}",
f"torchaudio=={torchaudio_version}",
]

# Assuming 'setup_spec' is a dictionary where we're adding these dependencies
setup_spec.extras["torch"] = torch_pkgs
setup_spec.extras["torch_cpu"] = torch_pkgs
setup_spec.extras["torch_cuda"] = torch_cuda_pkgs


def llama_cpp_python_cuda_requires():
cuda_version = get_cuda_version()
supported_cuda_versions = ["11.8", "12.1"]
device = "cpu"
if not cuda_version:
print("CUDA not support, use cpu version")
Expand All @@ -330,7 +366,10 @@ def llama_cpp_python_cuda_requires():
print("Disable GPU acceleration")
return
# Supports GPU acceleration
device = "cu" + cuda_version.replace(".", "")
if cuda_version <= "11.8" and not None:
device = "cu" + supported_cuda_versions[0].replace(".", "")
else:
device = "cu" + supported_cuda_versions[-1].replace(".", "")
os_type, cpu_avx = get_cpu_avx_support()
print(f"OS: {os_type}, cpu avx: {cpu_avx}")
supported_os = [OSType.WINDOWS, OSType.LINUX]
Expand All @@ -346,7 +385,7 @@ def llama_cpp_python_cuda_requires():
cpu_device = "basic"
device += cpu_device
base_url = "https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui"
llama_cpp_version = "0.2.10"
llama_cpp_version = "0.2.26"
py_version = "cp310"
os_pkg_name = "manylinux_2_31_x86_64" if os_type == OSType.LINUX else "win_amd64"
extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}-{py_version}-{py_version}-{os_pkg_name}.whl"
Expand Down Expand Up @@ -493,7 +532,13 @@ def quantization_requires():
# autoawq requirements:
# 1. Compute Capability 7.5 (sm75). Turing and later architectures are supported.
# 2. CUDA Toolkit 11.8 and later.
quantization_pkgs.extend(["autoawq", _build_autoawq_requires(), "optimum"])
cuda_version = get_cuda_version()
autoawq_latest_version = get_latest_version("autoawq", "", "0.2.4")
if cuda_version is None or cuda_version == "12.1":
quantization_pkgs.extend(["autoawq", _build_autoawq_requires(), "optimum"])
else:
# TODO(yyhhyy): Add autoawq install method for CUDA version 11.8
quantization_pkgs.extend(["autoawq", _build_autoawq_requires(), "optimum"])

setup_spec.extras["quantization"] = ["cpm_kernels"] + quantization_pkgs

Expand Down

0 comments on commit 5a0c20a

Please sign in to comment.