Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Latest #2

Merged
merged 6 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 0 additions & 41 deletions .github/workflows/docs.yaml

This file was deleted.

3 changes: 2 additions & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ on:

jobs:
test:
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
os: [ubuntu-latest, windows-latest, macos-latest]

steps:
- uses: actions/checkout@v4
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
![GitHub Org's stars](https://img.shields.io/github/stars/Janelia-cellmap)
[![GitHub Org's stars](https://img.shields.io/github/stars/Janelia-cellmap)](https://github.com/janelia-cellmap)


<img src="assets/CellMapLogo2.png" alt="CellMap logo" width="85%">
<img src="https://raw.githubusercontent.com/janelia-cellmap/cellmap-models/main/assets/CellMapLogo2.png" alt="CellMap logo" width="85%">

# cellmap-models

Expand Down
9 changes: 9 additions & 0 deletions docs/cellpose.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<!-- Lightweight client-side loader that feature-detects and load polyfills only when necessary -->
<script src="https://cdn.jsdelivr.net/npm/@webcomponents/webcomponentsjs@2/webcomponents-loader.min.js"></script>

<!-- Load the element definition -->
<script type="module" src="https://cdn.jsdelivr.net/gh/zerodevx/zero-md@1/src/zero-md.min.js"></script>

<!-- Simply set the `src` attribute to your MD file and win -->
<zero-md
src="https://raw.githubusercontent.com/janelia-cellmap/cellmap-models/main/src/cellmap_models/pytorch/cellpose/README.md"></zero-md>
9 changes: 9 additions & 0 deletions docs/cosem.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<!-- Lightweight client-side loader that feature-detects and load polyfills only when necessary -->
<script src="https://cdn.jsdelivr.net/npm/@webcomponents/webcomponentsjs@2/webcomponents-loader.min.js"></script>

<!-- Load the element definition -->
<script type="module" src="https://cdn.jsdelivr.net/gh/zerodevx/zero-md@1/src/zero-md.min.js"></script>

<!-- Simply set the `src` attribute to your MD file and win -->
<zero-md
src="https://raw.githubusercontent.com/janelia-cellmap/cellmap-models/main/src/cellmap_models/pytorch/cosem/README.md"></zero-md>
8 changes: 8 additions & 0 deletions docs/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<!-- Lightweight client-side loader that feature-detects and load polyfills only when necessary -->
<script src="https://cdn.jsdelivr.net/npm/@webcomponents/webcomponentsjs@2/webcomponents-loader.min.js"></script>

<!-- Load the element definition -->
<script type="module" src="https://cdn.jsdelivr.net/gh/zerodevx/zero-md@1/src/zero-md.min.js"></script>

<!-- Simply set the `src` attribute to your MD file and win -->
<zero-md src="../README.md"></zero-md>
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ dev = [
'black',
'mypy',
'pdoc',
'pre-commit'
]
pretrained = [
'pre-commit',
'cellpose[gui]'
]

Expand Down
4 changes: 2 additions & 2 deletions src/cellmap_models/pytorch/cellpose/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<!-- FILEPATH: /Users/rhoadesj/Repos/cellmap-models/src/cellmap_models/pytorch/cellpose/README.md -->
<h1 style="height: 56pt;">Finetuned Cellpose Models<img src="https://www.cellpose.org/static/images/cellpose_transparent.png" alt="cellpose logo" height=56pt></h1>
<h1 style="height: 56pt;">Finetuned <a href="https://www.cellpose.org"> Cellpose Models<img src="https://www.cellpose.org/static/images/cellpose_transparent.png" alt="cellpose logo" height=56pt></h1></a>

This directory contains finetuned scripts for downloading Cellpose models, particularly for use with the `cellpose` package. The models are trained on a variety of cell types from CellMap FIBSEM images, and can be used for segmentation of new data.
This directory contains scripts for downloading finetuned Cellpose models, particularly for use with the `cellpose` package. The models are trained on a variety of cell types from CellMap FIBSEM images, and can be used for segmentation of new data. These models are returned as [CellposeModel objects](https://cellpose.readthedocs.io/en/latest/api.html#cellposemodel).

## Models

Expand Down
2 changes: 1 addition & 1 deletion src/cellmap_models/pytorch/cellpose/get_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ def get_model(
print(f"Downloading {model_name} from {models_dict[model_name]}")
download_url_to_file(models_dict[model_name], full_path)
print(f"Downloaded model {model_name} to {base_path}.")
return
return full_path
17 changes: 11 additions & 6 deletions src/cellmap_models/pytorch/cellpose/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from pathlib import Path
import torch
from .get_model import get_model
from cellpose.models import CellposeModel


def load_model(
model_name: str,
base_path: str = f"{Path(__file__).parent}/models",
device: str = "cuda",
):
device: str | torch.device = "cuda",
) -> torch.nn.Module:
"""Load model

Args:
Expand All @@ -19,11 +20,15 @@ def load_model(
Returns:
model: model
"""

get_model(model_name, base_path)
model_path = get_model(model_name, base_path)
if device == "cuda" and not torch.cuda.is_available():
device = "cpu"
print("CUDA not available. Using CPU.")
model = torch.jit.load(os.path.join(base_path, f"{model_name}.pt"), device)
model.eval()
if isinstance(device, str):
device = torch.device(device)

model = CellposeModel(pretrained_model=model_path, device=device)

print(f"{model.diam_labels} diameter labels were used for training")

return model
2 changes: 1 addition & 1 deletion src/cellmap_models/pytorch/cosem/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<img src="../../../../assets/COSEM_logo_semi-invert_transparent.png" alt="CellMap logo" width="85%">
<img src="https://raw.githubusercontent.com/janelia-cellmap/cellmap-models/main/assets/COSEM_logo_semi-invert_transparent.png" alt="COSEM logo" width="85%">

# COSEM Trained PyTorch Networks

Expand Down
2 changes: 2 additions & 0 deletions src/cellmap_models/pytorch/cosem/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@
}

models_list = list(models_dict.keys())

model_names = list(set(x.split("/")[0] for x in models_list))
33 changes: 23 additions & 10 deletions src/cellmap_models/pytorch/cosem/load_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,35 +30,47 @@ def get_param_dict(model_params):
return param_dict


def load_model(checkpoint_name):
def load_model(checkpoint_name: str) -> torch.nn.Module:
"""
Load a model from a checkpoint file.

Args:
checkpoint_name (str): Name of the checkpoint file.
"""
from . import models_dict, models_list # avoid circular import
from . import models_dict, models_list, model_names # avoid circular import

# Make sure the checkpoint exists
if (
checkpoint_name not in models_dict
and Path(checkpoint_name).with_suffix(".pth") not in models_list
):
raise ValueError(f"Model {checkpoint_name} not found")
checkpoint_path = Path(Path(__file__).parent / Path(checkpoint_name)).with_suffix(
".pth"
)
if not checkpoint_path.exists():
url = models_dict[checkpoint_name]
print(f"Downloading {checkpoint_name} from {url}")
download_url_to_file(url, checkpoint_path)
if checkpoint_name in model_names:
checkpoint_path = Path(
Path(__file__).parent / Path(checkpoint_name) / "model.py"
)
no_weights = True
else:
raise ValueError(f"Model {checkpoint_name} not found")
else:
checkpoint_path = Path(
Path(__file__).parent / Path(checkpoint_name)
).with_suffix(".pth")
if not checkpoint_path.exists():
url = models_dict[checkpoint_name]
print(f"Downloading {checkpoint_name} from {url}")
download_url_to_file(url, checkpoint_path)
no_weights = False

model_params = SourceFileLoader(
"model", str(Path(checkpoint_path).parent / "model.py")
).load_module()

model = Architecture(model_params)

if no_weights:
print(f"Not loading weights for {checkpoint_name}.")
return model

print(f"Loading model from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path)
new_checkpoint = deepcopy(checkpoint)
Expand All @@ -69,6 +81,7 @@ def load_model(checkpoint_name):
new_key = key.replace("architecture.", "")
new_checkpoint["model"][new_key] = new_checkpoint["model"].pop(key)
model.load_state_dict(new_checkpoint["model"])
model.eval()

return model

Expand Down
2 changes: 0 additions & 2 deletions src/cellmap_models/pytorch/cosem/setup04/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from pathlib import Path
import numpy as np
from cellmap_models import download_url_to_file

# voxel size parameters
voxel_size_output = np.array((4,) * 3)
Expand Down
2 changes: 0 additions & 2 deletions src/cellmap_models/pytorch/cosem/setup26.1/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from pathlib import Path
import numpy as np
from cellmap_models import download_url_to_file

# voxel size parameters
voxel_size_output = np.array((4,) * 3)
Expand Down
2 changes: 0 additions & 2 deletions src/cellmap_models/pytorch/cosem/setup28/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import numpy as np
from pathlib import Path
from cellmap_models import download_url_to_file

# voxel size parameters
voxel_size_output = np.array((4,) * 3)
Expand Down
2 changes: 0 additions & 2 deletions src/cellmap_models/pytorch/cosem/setup36/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from pathlib import Path
import numpy as np
from cellmap_models import download_url_to_file

# voxel size parameters
voxel_size_output = np.array((4,) * 3)
Expand Down
2 changes: 0 additions & 2 deletions src/cellmap_models/pytorch/cosem/setup45/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import numpy as np
from pathlib import Path
from cellmap_models import download_url_to_file

# voxel size parameters
voxel_size_output = np.array((4,) * 3)
Expand Down
2 changes: 0 additions & 2 deletions tests/test_assert.py

This file was deleted.

7 changes: 7 additions & 0 deletions tests/test_cosem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from cellmap_models import cosem


def test_load_model():
for model_name in cosem.model_names:
model = cosem.load_model(model_name)
assert model is not None
Loading