Skip to content

Commit

Permalink
allow strong preferences from pick_device using a string; release fid…
Browse files Browse the repository at this point in the history
…dles (#85)

* pick_device allows strong preferences now
* Fixing deprecation
* using python 3.9 and above

---------

Co-authored-by: Will Dumm <wrhdumm@gmail.com>
  • Loading branch information
matsen and willdumm authored Nov 13, 2024
1 parent 2f01d77 commit 34d26cf
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python-version: [3.8, "3.11"]
python-version: [3.9, "3.11"]

runs-on: ${{ matrix.os }}

Expand Down
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Neural NETworks for antibody Affinity Maturation.

## pip installation

Netam is available on PyPI, and works with Python 3.8 through 3.11.
Netam is available on PyPI, and works with Python 3.9 through 3.11.

```
pip install netam
Expand Down Expand Up @@ -60,3 +60,12 @@ If you are running one of the experiment repos, such as:
* [dnsm-experiments-1](https://github.com/matsengrp/dnsm-experiments-1/)

you will want to visit those repos and follow the installation instructions there.


## Troubleshooting
* On some machines, pip may install a version of numpy that is too new for the
available version of pytorch, returning an error such as `A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash.` The solution is to downgrade to `numpy<2`:
```console
pip install --force-reinstall "numpy<2"
```
18 changes: 14 additions & 4 deletions netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,23 @@ def find_least_used_cuda_gpu():
return utilization.index(min(utilization))


def pick_device(gpu_index=None):
def pick_device(gpu_preference=None):
"""Pick a device for PyTorch to use.
If CUDA is available, use the least used GPU, and if all are idle use the gpu_index
If gpu_preference is a string, use the device with that name. This is considered a
strong preference from a user who knows what they are doing.
If gpu_preference is an integer, this is a weak preference for a numbered GPU. If
CUDA is available, use the least used GPU, and if all are idle use the gpu_index
modulo the number of GPUs. If gpu_preference is None, then use a random GPU.
"""

# Strong preference for a specific device.
if gpu_preference is not None and isinstance(gpu_preference, str):
return torch.device(gpu_preference)

# else weak preference for a numbered GPU.

# check that CUDA is usable
def check_CUDA():
try:
Expand All @@ -216,10 +226,10 @@ def check_CUDA():
if torch.backends.cudnn.is_available() and check_CUDA():
which_gpu = find_least_used_cuda_gpu()
if which_gpu is None:
if gpu_index is None:
if gpu_preference is None:
which_gpu = np.random.randint(torch.cuda.device_count())
else:
which_gpu = gpu_index % torch.cuda.device_count()
which_gpu = gpu_preference % torch.cuda.device_count()
print(f"Using CUDA GPU {which_gpu}")
return torch.device(f"cuda:{which_gpu}")
elif torch.backends.mps.is_available():
Expand Down
7 changes: 3 additions & 4 deletions netam/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@

import os
import zipfile
import pkg_resources
from importlib.resources import files

import requests

from netam.framework import load_crepe

# This throws a deprecation warning. It could also be done by looking at
# __file__, or by using importlib.resources.
PRETRAINED_DIR = pkg_resources.resource_filename(__name__, "_pretrained")
with files(__package__).joinpath("_pretrained") as pretrained_path:
PRETRAINED_DIR = str(pretrained_path)

PACKAGE_LOCATIONS_AND_CONTENTS = (
# Order of entries:
Expand Down
2 changes: 1 addition & 1 deletion notebooks/thrifty_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"\n",
"To get a pretrained crepe, just ask for it by name. 🍰!\n",
"\n",
"Here we get the `ThriftyHumV1.0-45` model:"
"Here we get the `ThriftyHumV0.2-45` model:"
]
},
{
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
packages=find_packages(),
python_requires=">=3.8,<3.12",
python_requires=">=3.9,<3.12",
install_requires=[
"biopython",
"natsort",
Expand All @@ -22,6 +22,7 @@
"tensorboardX",
"torch",
"tqdm",
"fire",
],
classifiers=[
"Development Status :: 3 - Alpha",
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ def pcp_df():
)
df = add_shm_model_outputs_to_pcp_df(
df,
pretrained.load("ThriftyHumV1.0-45"),
pretrained.load("ThriftyHumV0.2-45"),
)
return df
2 changes: 1 addition & 1 deletion tests/test_multihit.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
@pytest.fixture
def mini_multihit_train_val_datasets():
df = pd.read_csv("data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz")
crepe = pretrained.load("ThriftyHumV1.0-45")
crepe = pretrained.load("ThriftyHumV0.2-45")
df = multihit.prepare_pcp_df(df, crepe, 500)
return multihit.train_test_datasets_of_pcp_df(df)

Expand Down
22 changes: 22 additions & 0 deletions tests/test_pretrained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import netam.pretrained as pretrained
from pathlib import Path
import shutil


def test_names_unique():
    """Check that model names are unique across all pretrained packages.

    ``MODEL_TO_LOCAL`` is a dict keyed by model name, so its length is the
    number of *distinct* names. If two packages declared the same model name,
    the dict would collapse them and its length would fall short of the total
    model count summed over ``PACKAGE_LOCATIONS_AND_CONTENTS``.
    """
    # A dict's keys are unique by construction, so len(d) already counts
    # distinct names -- wrapping them in set() was redundant.
    assert len(pretrained.MODEL_TO_LOCAL) == sum(
        len(models) for _, _, _, models in pretrained.PACKAGE_LOCATIONS_AND_CONTENTS
    )


def test_load_all_models():
    """Check every pretrained model can be loaded, both cold and warm.

    First pass loads each model with an empty cache directory (forcing a
    fresh fetch); second pass loads each again so the cached path is also
    exercised. Success is simply that each ``pretrained.load`` call completes.
    """
    # Remove cached models. ignore_errors=True keeps this robust when the
    # cache directory does not exist yet (e.g. a fresh checkout), so the
    # test does not depend on which tests ran before it.
    shutil.rmtree(Path(pretrained.PRETRAINED_DIR), ignore_errors=True)

    # Check each can be loaded without caching.
    for model_name in pretrained.MODEL_TO_LOCAL:
        pretrained.load(model_name)

    # Check each can be loaded with caching.
    for model_name in pretrained.MODEL_TO_LOCAL:
        pretrained.load(model_name)

0 comments on commit 34d26cf

Please sign in to comment.