Skip to content

Commit

Permalink
build/fix: hard pin torch; preemptively install specific torch w/ conda
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Jul 25, 2024
1 parent 702a96d commit f30697d
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 17 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy==1.26.4
torch>=2.3.1
torch==2.3.1

horde_sdk~=0.14.0
horde_safety~=0.2.3
Expand Down
12 changes: 6 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ def tracked_dependencies() -> list[str]:


@pytest.fixture(scope="session")
def horde_dependency_versions() -> list[tuple[str, str]]:
def horde_dependency_versions() -> dict[str, str]:
"""Get the versions of horde dependencies from the requirements file."""
with open(REQUIREMENTS_FILE_PATH) as f:
requirements = f.readlines()

dependencies = []
dependencies = {}
for req in requirements:
for dep in TRACKED_DEPENDENCIES:
if req.startswith(dep):
Expand All @@ -55,18 +55,18 @@ def horde_dependency_versions() -> list[tuple[str, str]]:

# Strip any info starting from the `+` character
version = version.split("+")[0]
dependencies.append((dep, version))
dependencies[dep] = version

return dependencies


@pytest.fixture(scope="session")
def rocm_horde_dependency_versions() -> list[tuple[str, str]]:
def rocm_horde_dependency_versions() -> dict[str, str]:
"""Get the versions of horde dependencies from the ROCm requirements file."""
with open(ROCM_REQUIREMENTS_FILE_PATH) as f:
requirements = f.readlines()

dependencies = []
dependencies = {}
for req in requirements:
for dep in TRACKED_DEPENDENCIES:
if req.startswith(dep):
Expand All @@ -81,6 +81,6 @@ def rocm_horde_dependency_versions() -> list[tuple[str, str]]:

# Strip any info starting from the `+` character
version = version.split("+")[0]
dependencies.append((dep, version))
dependencies[dep] = version

return dependencies
33 changes: 26 additions & 7 deletions tests/test_horde_dep_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,45 @@ def test_horde_bridge_updating(horde_dependency_versions: list[tuple[str, str]])
assert "-U" in line, "No -U flag found in pip install command"
for dep, version in haidra_deps:
assert f"{dep}~={version}" in line, f"Dependency {dep} not found in pip install command"
assert found_line, "No pip install command found in horde-bridge.cmd"

assert found_line

HORDE_UPDATE_RUNTIME_SCRIPT = Path(__file__).parent.parent / "update-runtime.cmd"


def test_horde_update_runtime_updating(horde_dependency_versions: dict[str, str]) -> None:
"""Check that the versions of horde deps. in update-runtime.cmd match the versions in requirements.txt."""
torch_dep_string = "torch"
torch_version = horde_dependency_versions["torch"]

script_lines = HORDE_UPDATE_RUNTIME_SCRIPT.read_text().split("\n")

found_line = False
for line in script_lines:
if "python -s -m pip install torch==" in line:
found_line = True
assert (
f"{torch_dep_string}=={torch_version}" in line
), f"Torch {torch_version} not found in initial torch install command"

assert found_line, "No initial torch install command found"


def test_different_requirements_files_match(
horde_dependency_versions: list[tuple[str, str]],
horde_dependency_versions: dict[str, str],
rocm_horde_dependency_versions: list[tuple[str, str]],
) -> None:
"""Check that the versions of horde deps. in the main and rocm requirements files match."""
main_deps = dict(horde_dependency_versions)
rocm_deps = dict(rocm_horde_dependency_versions)

for dep in main_deps:
for dep in horde_dependency_versions:
assert dep in rocm_deps, f"Dependency {dep} not found in rocm requirements file"
assert (
main_deps[dep] == rocm_deps[dep]
horde_dependency_versions[dep] == rocm_deps[dep]
), f"Dependency {dep} has different versions in main and rocm requirements files"

for dep in rocm_deps:
assert dep in main_deps, f"Dependency {dep} not found in main requirements file"
assert dep in horde_dependency_versions, f"Dependency {dep} not found in main requirements file"
assert (
rocm_deps[dep] == main_deps[dep]
rocm_deps[dep] == horde_dependency_versions[dep]
), f"Dependency {dep} has different versions in main and rocm requirements files"
4 changes: 2 additions & 2 deletions tests/test_pre_commit_dep_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def test_pre_commit_dep_versions(
horde_dependency_versions: list[tuple[str, str]],
horde_dependency_versions: dict[str, str],
tracked_dependencies: list[str],
) -> None:
"""Check that the versions of horde deps. in .pre-commit-config.yaml match the versions in requirements.txt.
Expand Down Expand Up @@ -52,7 +52,7 @@ def test_pre_commit_dep_versions(
), f"Some dependencies are missing their versions.\n{versions}"

# Check if the versions match
matches = sum(1 for dep, version in horde_dependency_versions if versions.get(dep) == version)
matches = sum(1 for dep, version in horde_dependency_versions.items() if versions.get(dep) == version)

assert matches == len(
horde_dependency_versions,
Expand Down
4 changes: 3 additions & 1 deletion update-runtime.cmd
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,16 @@ micromamba.exe shell hook -s cmd.exe -p %MAMBA_ROOT_PREFIX% -v
call "%MAMBA_ROOT_PREFIX%\condabin\mamba_hook.bat"
call "%MAMBA_ROOT_PREFIX%\condabin\micromamba.bat" activate windows

python -s -m pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu121 -U

if defined hordelib (
python -s -m pip uninstall -y hordelib horde_engine horde_model_reference
python -s -m pip install horde_engine horde_model_reference --extra-index-url https://download.pytorch.org/whl/cu121
) else (
if defined scribe (
python -s -m pip install -r requirements-scribe.txt
) else (
python -s -m pip install -r requirements.txt -U --extra-index-url https://download.pytorch.org/whl/cu121
python -s -m pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 -U
)
)
call deactivate
Expand Down

0 comments on commit f30697d

Please sign in to comment.