From f30697decc03d6813eef137f240f3124e69dbd10 Mon Sep 17 00:00:00 2001 From: tazlin Date: Thu, 25 Jul 2024 19:57:32 -0400 Subject: [PATCH] build/fix: hard pin torch; preemptively install specific torch w/ conda --- requirements.txt | 2 +- tests/conftest.py | 12 +++++----- tests/test_horde_dep_updates.py | 33 ++++++++++++++++++++++------ tests/test_pre_commit_dep_version.py | 4 ++-- update-runtime.cmd | 4 +++- 5 files changed, 38 insertions(+), 17 deletions(-) diff --git a/requirements.txt b/requirements.txt index c084ec55..f41f457d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy==1.26.4 -torch>=2.3.1 +torch==2.3.1 horde_sdk~=0.14.0 horde_safety~=0.2.3 diff --git a/tests/conftest.py b/tests/conftest.py index 8cf04591..69a1fe65 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,12 +35,12 @@ def tracked_dependencies() -> list[str]: @pytest.fixture(scope="session") -def horde_dependency_versions() -> list[tuple[str, str]]: +def horde_dependency_versions() -> dict[str, str]: """Get the versions of horde dependencies from the requirements file.""" with open(REQUIREMENTS_FILE_PATH) as f: requirements = f.readlines() - dependencies = [] + dependencies = {} for req in requirements: for dep in TRACKED_DEPENDENCIES: if req.startswith(dep): @@ -55,18 +55,18 @@ def horde_dependency_versions() -> list[tuple[str, str]]: # Strip any info starting from the `+` character version = version.split("+")[0] - dependencies.append((dep, version)) + dependencies[dep] = version return dependencies @pytest.fixture(scope="session") -def rocm_horde_dependency_versions() -> list[tuple[str, str]]: +def rocm_horde_dependency_versions() -> dict[str, str]: """Get the versions of horde dependencies from the ROCm requirements file.""" with open(ROCM_REQUIREMENTS_FILE_PATH) as f: requirements = f.readlines() - dependencies = [] + dependencies = {} for req in requirements: for dep in TRACKED_DEPENDENCIES: if req.startswith(dep): @@ -81,6 +81,6 @@ def rocm_horde_dependency_versions() -> list[tuple[str, str]]: # Strip any info starting from the `+` character version = version.split("+")[0] - dependencies.append((dep, version)) + dependencies[dep] = version return dependencies diff --git a/tests/test_horde_dep_updates.py b/tests/test_horde_dep_updates.py index 8c4ad114..38b45445 100644 --- a/tests/test_horde_dep_updates.py +++ b/tests/test_horde_dep_updates.py @@ -17,26 +17,45 @@ def test_horde_bridge_updating(horde_dependency_versions: list[tuple[str, str]]) assert "-U" in line, "No -U flag found in pip install command" for dep, version in haidra_deps: assert f"{dep}~={version}" in line, f"Dependency {dep} not found in pip install command" + assert found_line, "No pip install command found in horde-bridge.cmd" - assert found_line + +HORDE_UPDATE_RUNTIME_SCRIPT = Path(__file__).parent.parent / "update-runtime.cmd" + + +def test_horde_update_runtime_updating(horde_dependency_versions: dict[str, str]) -> None: + """Check that the versions of horde deps. in update-runtime.cmd match the versions in requirements.txt.""" + torch_dep_string = "torch" + torch_version = horde_dependency_versions["torch"] + + script_lines = HORDE_UPDATE_RUNTIME_SCRIPT.read_text().split("\n") + + found_line = False + for line in script_lines: + if "python -s -m pip install torch==" in line: + found_line = True + assert ( + f"{torch_dep_string}=={torch_version}" in line + ), f"Torch {torch_version} not found in initial torch install command" + + assert found_line, "No initial torch install command found" def test_different_requirements_files_match( - horde_dependency_versions: list[tuple[str, str]], + horde_dependency_versions: dict[str, str], rocm_horde_dependency_versions: list[tuple[str, str]], ) -> None: """Check that the versions of horde deps. in the main and rocm requirements files match.""" - main_deps = dict(horde_dependency_versions) rocm_deps = dict(rocm_horde_dependency_versions) - for dep in main_deps: + for dep in horde_dependency_versions: assert dep in rocm_deps, f"Dependency {dep} not found in rocm requirements file" assert ( - main_deps[dep] == rocm_deps[dep] + horde_dependency_versions[dep] == rocm_deps[dep] ), f"Dependency {dep} has different versions in main and rocm requirements files" for dep in rocm_deps: - assert dep in main_deps, f"Dependency {dep} not found in main requirements file" + assert dep in horde_dependency_versions, f"Dependency {dep} not found in main requirements file" assert ( - rocm_deps[dep] == main_deps[dep] + rocm_deps[dep] == horde_dependency_versions[dep] ), f"Dependency {dep} has different versions in main and rocm requirements files" diff --git a/tests/test_pre_commit_dep_version.py b/tests/test_pre_commit_dep_version.py index 153658cf..fac58db7 100644 --- a/tests/test_pre_commit_dep_version.py +++ b/tests/test_pre_commit_dep_version.py @@ -6,7 +6,7 @@ def test_pre_commit_dep_versions( - horde_dependency_versions: list[tuple[str, str]], + horde_dependency_versions: dict[str, str], tracked_dependencies: list[str], ) -> None: """Check that the versions of horde deps. in .pre-commit-config.yaml match the versions in requirements.txt. @@ -52,7 +52,7 @@ def test_pre_commit_dep_versions( ), f"Some dependencies are missing their versions.\n{versions}" # Check if the versions match - matches = sum(1 for dep, version in horde_dependency_versions if versions.get(dep) == version) + matches = sum(1 for dep, version in horde_dependency_versions.items() if versions.get(dep) == version) assert matches == len( horde_dependency_versions, diff --git a/update-runtime.cmd b/update-runtime.cmd index fb1e5d7c..0375811a 100644 --- a/update-runtime.cmd +++ b/update-runtime.cmd @@ -49,6 +49,8 @@ micromamba.exe shell hook -s cmd.exe -p %MAMBA_ROOT_PREFIX% -v call "%MAMBA_ROOT_PREFIX%\condabin\mamba_hook.bat" call "%MAMBA_ROOT_PREFIX%\condabin\micromamba.bat" activate windows +python -s -m pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121 -U + if defined hordelib ( python -s -m pip uninstall -y hordelib horde_engine horde_model_reference python -s -m pip install horde_engine horde_model_reference --extra-index-url https://download.pytorch.org/whl/cu121 @@ -56,7 +58,7 @@ if defined hordelib ( if defined scribe ( python -s -m pip install -r requirements-scribe.txt ) else ( - python -s -m pip install -r requirements.txt -U --extra-index-url https://download.pytorch.org/whl/cu121 + python -s -m pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 -U ) ) call deactivate