Skip to content

Commit

Permalink
Add cellmap.add_cellpose script and update model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Mar 8, 2024
1 parent 20381fc commit 2e8e1b9
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 27 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ exclude = ['setup*']
ignore_missing_imports = true

[project.scripts]
cellmap.add_cellpose = "cellmap_models.cellpose:add_model"
"cellmap.add_cellpose" = "cellmap_models.pytorch.cellpose:add_model"
Binary file modified src/cellmap_models/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file modified src/cellmap_models/__pycache__/utils.cpython-310.pyc
Binary file not shown.
1 change: 1 addition & 0 deletions src/cellmap_models/pytorch/cellpose/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .add_model import add_model
from .load_model import load_model
from .get_model import get_model

models_dict = {
"jrc_mus-epididymis-1_nuc_cp": "https://github.com/janelia-cellmap/cellmap-models/releases/download/2024.03.08/jrc_mus-epididymis-1_nuc_cp",
Expand Down
23 changes: 8 additions & 15 deletions src/cellmap_models/pytorch/cellpose/add_model.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,20 @@
from . import models_dict
from cellpose.io import _add_model
import sys
from typing import Optional
from cellpose.io import add_model as _add_model
from cellpose.models import MODEL_DIR
from cellpose.utils import download_url_to_file
from .get_model import get_model


def add_model(model_name: str):
def add_model(model_name: Optional[str] = None):
"""Add model to cellpose
Args:
model_name (str): model name
"""
# download model to cellpose directory
if model_name not in models_dict:
raise ValueError(
f"Model {model_name} is not available. Available models are {list(models_dict.keys())}."
)
if model_name is None:
model_name = sys.argv[1]
base_path = MODEL_DIR

if not (base_path / f"{model_name}.pth").exists():
print(f"Downloading {model_name} from {models_dict[model_name]}")
download_url_to_file(
models_dict[model_name], str(base_path / f"{model_name}.pth")
)
get_model(model_name, base_path)
_add_model(str(base_path / f"{model_name}.pth"))
print(
f"Added model {model_name}. This will now be available in the cellpose model list."
Expand Down
29 changes: 29 additions & 0 deletions src/cellmap_models/pytorch/cellpose/get_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from pathlib import Path
from cellpose.utils import download_url_to_file


def get_model(
model_name: str,
base_path: str = f"{Path(__file__).parent}/models",
):
"""Add model to cellpose
Args:
model_name (str): model name
base_path (str, optional): base path to store Torchscript model. Defaults to "./models".
"""
from . import models_dict

# download model to cellpose directory
if model_name not in models_dict:
raise ValueError(
f"Model {model_name} is not available. Available models are {list(models_dict.keys())}."
)

if not (base_path / f"{model_name}.pth").exists():
print(f"Downloading {model_name} from {models_dict[model_name]}")
download_url_to_file(
models_dict[model_name], str(base_path / f"{model_name}.pth")
)
print("Downloaded model {model_name} to {base_path}.")
return
14 changes: 3 additions & 11 deletions src/cellmap_models/pytorch/cellpose/load_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from pathlib import Path
from . import models_dict
from cellmap_models.utils import download_url_to_file
import torch
from .get_model import get_model


def load_model(
Expand All @@ -19,15 +18,8 @@ def load_model(
Returns:
model: model
"""
if model_name not in models_dict:
raise ValueError(
f"Model {model_name} is not available. Available models are {list(models_dict.keys())}."
)
if not (base_path / f"{model_name}.pth").exists():
print(f"Downloading {model_name} from {models_dict[model_name]}")
download_url_to_file(
models_dict[model_name], str(base_path / f"{model_name}.pth")
)

get_model(model_name, base_path)
if device == "cuda" and not torch.cuda.is_available():
device = "cpu"
print("CUDA not available. Using CPU.")
Expand Down

0 comments on commit 2e8e1b9

Please sign in to comment.