From ef770cfed51ad2216e04a38e820ecd00efc2c7db Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Wed, 19 Jun 2024 15:14:18 -0400 Subject: [PATCH] Add package constraints to torchbench --- install.py | 19 +++++++++++++++---- torchbenchmark/__init__.py | 4 ++-- torchbenchmark/util/env_check.py | 9 +++++---- utils/__init__.py | 30 ++++++++++++++++++++++-------- 4 files changed, 44 insertions(+), 18 deletions(-) diff --git a/install.py b/install.py index 943de9dc11..63930a33b7 100644 --- a/install.py +++ b/install.py @@ -5,7 +5,7 @@ from pathlib import Path from userbenchmark import list_userbenchmarks -from utils import get_pkg_versions, TORCH_DEPS +from utils import get_pkg_versions, TORCH_DEPS, generate_pkg_constraints REPO_ROOT = Path(__file__).parent @@ -38,6 +38,11 @@ def pip_install_requirements(requirements_txt="requirements.txt"): action="store_true", help="Run in test mode and check package versions", ) + parser.add_argument( + "--check-only", + action="store_true", + help="Only run the version check and generate the constraints" + ) parser.add_argument("--canary", action="store_true", help="Install canary model.") parser.add_argument("--continue_on_fail", action="store_true") parser.add_argument("--verbose", "-v", action="store_true") @@ -51,12 +56,12 @@ def pip_install_requirements(requirements_txt="requirements.txt"): os.chdir(os.path.realpath(os.path.dirname(__file__))) print( - f"checking packages {', '.join(TORCH_DEPS)} are installed...", + f"checking packages {', '.join(TORCH_DEPS)} are installed, generating constraints...", end="", flush=True, ) if args.userbenchmark: - TORCH_DEPS = ["torch"] + TORCH_DEPS = ["numpy", "torch"] try: versions = get_pkg_versions(TORCH_DEPS) except ModuleNotFoundError as e: @@ -65,8 +70,12 @@ def pip_install_requirements(requirements_txt="requirements.txt"): f"Error: Users must first manually install packages {TORCH_DEPS} before installing the benchmark." 
) sys.exit(-1) + generate_pkg_constraints(versions) print("OK") + if args.check_only: + sys.exit(0) + if args.userbenchmark: # Install userbenchmark dependencies if exists userbenchmark_dir = REPO_ROOT.joinpath("userbenchmark", args.userbenchmark) @@ -101,7 +110,9 @@ def pip_install_requirements(requirements_txt="requirements.txt"): new_versions = get_pkg_versions(TORCH_DEPS) if versions != new_versions: print( - f"The torch packages are re-installed after installing the benchmark deps. \ + f"The numpy and torch package versions become inconsistent after installing the benchmark deps. \ Before: {versions}, after: {new_versions}" ) sys.exit(-1) + else: + print(f"installed torchbench with package constraints: {versions}") diff --git a/torchbenchmark/__init__.py b/torchbenchmark/__init__.py index 142b547f4c..1bbdcb4f05 100644 --- a/torchbenchmark/__init__.py +++ b/torchbenchmark/__init__.py @@ -181,10 +181,10 @@ def setup( versions = get_pkg_versions(TORCH_DEPS) success, errmsg, stdout_stderr = _install_deps(model_path, verbose=verbose) if test_mode: - new_versions = get_pkg_versions(TORCH_DEPS, reload=True) + new_versions = get_pkg_versions(TORCH_DEPS) if versions != new_versions: print( - f"The torch packages are re-installed after installing the benchmark model {model_path}. \ + f"The numpy and torch packages are re-installed after installing the benchmark model {model_path}. \ Before: {versions}, after: {new_versions}" ) sys.exit(-1) diff --git a/torchbenchmark/util/env_check.py b/torchbenchmark/util/env_check.py index 9f97babb03..bb9755e0b2 100644 --- a/torchbenchmark/util/env_check.py +++ b/torchbenchmark/util/env_check.py @@ -3,9 +3,7 @@ This file may be loaded without torch packages installed, e.g., in OnDemand CI. 
""" -import argparse import copy -import importlib import os import shutil import argparse @@ -187,10 +185,13 @@ def deterministic_torch_manual_seed(*args, **kwargs): def get_pkg_versions(packages: List[str]) -> Dict[str, str]: + import sys + import subprocess versions = {} for module in packages: - module = importlib.import_module(module) - versions[module] = module.__version__ + cmd = [sys.executable, "-c", f'import {module}; print({module}.__version__)'] + version = subprocess.check_output(cmd).decode().strip() + versions[module] = version return versions diff --git a/utils/__init__.py b/utils/__init__.py index d117c848a8..49e32c8f7e 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,8 +1,10 @@ -import importlib import sys +import subprocess from typing import Dict, List +from pathlib import Path -TORCH_DEPS = ["torch", "torchvision", "torchaudio"] +REPO_DIR = Path(__file__).parent.parent +TORCH_DEPS = ["numpy", "torch", "torchvision", "torchaudio"] class add_path: @@ -18,12 +20,24 @@ def __exit__(self, exc_type, exc_value, traceback): except ValueError: pass - -def get_pkg_versions(packages: List[str], reload: bool = False) -> Dict[str, str]: +def get_pkg_versions(packages: List[str]) -> Dict[str, str]: + """Return {package: version}, importing each package in a subprocess; raise ModuleNotFoundError on failure.""" versions = {} for module in packages: - module = importlib.import_module(module) - if reload: - module = importlib.reload(module) - versions[module.__name__] = module.__version__ + cmd = [sys.executable, "-c", f'import {module}; print({module}.__version__)'] + try: + version = subprocess.check_output(cmd).decode().strip() + except subprocess.CalledProcessError as e: + raise ModuleNotFoundError(f"cannot import {module} or read its __version__") from e + versions[module] = version return versions + +def generate_pkg_constraints(package_versions: Dict[str, str]): + """ + Save the given package versions as `name==version` lines to {REPO_DIR}/build/constraints.txt + """ + output_dir = REPO_DIR.joinpath("build") + output_dir.mkdir(exist_ok=True) + with open(output_dir.joinpath("constraints.txt"), "w") as fp: + for k, v in package_versions.items(): + fp.write(f"{k}=={v}\n")