Skip to content

Commit

Permalink
Merge pull request #1 from GirinChutia/dev
Browse files Browse the repository at this point in the history
update latest changes
  • Loading branch information
GirinChutia authored May 4, 2024
2 parents 1a8081a + 78d9713 commit ff0559b
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 33 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
/fig
/dist
/src/SAM_ONNX.egg-info
/.env
/build
/src/sam_onnx/__pycache__
/model_weights
/.ipynb_checkpoints
/src/sam_onnx/.ipynb_checkpoints
25 changes: 9 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,27 +1,20 @@
# SAM ONNX (* Under Development)

`
Welcome to the documentation for Project Name! This repository contains the source code for the project.

## Documentation
## Environment creation Creation (For Development)

Explore our detailed documentation on [GitHub Pages](https://girinchutia.github.io/SAM_ONNX/).
> python -m venv sam_onnx_env
## Installation
- Use pip 24.0 or higher version

1. Clone the repository: `git clone https://github.com/GirinChutia/SAM_ONNX.git`
2. Build the project: `py -m build`
3. Install the package: `pip install dist/SAM_ONNX-0.0.1-py3-none-any.whl`
## Installation
In SAM_ONNX Folder
> python -m pip install -e .
## Usage

```
import onnx_sam
```

## Contributing

We welcome contributions! Please see our [Contribution Guidelines](CONTRIBUTING.md) for more details.
![alt text](repo_assests/demo.png)

## License

This project is licensed under the [MIT License](LICENSE).
This project is licensed under the [MIT License](LICENSE).
65 changes: 60 additions & 5 deletions demo.ipynb

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
[build-system]
requires = ["setuptools>=61.0","matplotlib==3.8.3","numpy==1.26.4","onnxruntime==1.17.1","opencv_python==4.9.0.80"]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "SAM_ONNX"
version = "0.0.1"
description = "A simple package for using SAM with ONNX without pytorch dependencies"
description = "A simple package for using SAM (ONNX format) without pytorch dependencies"
readme = "README.md"
requires-python = ">=3.9"
dependencies = ["matplotlib>=3.8.3", "numpy>=1.26.4","onnxruntime==1.17.1","opencv_python==4.9.0.80","gdown==5.1.0"]
authors = [{ name = "Girin Chutia", email = "girin.iitm@gmail.com" }]
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]

[project.optional-dependencies]
dev=['jupyter']
Binary file added repo_assests/demo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
85 changes: 75 additions & 10 deletions src/sam_onnx/sam_onnx.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,62 @@
from copy import deepcopy
from typing import Any, Tuple, Union
from typing import Any, Tuple, Union, List
import cv2
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import glob


import gdown
import os

def check_and_download_weights(model_name='l0'):

__supported_modelnames = ['l0', 'xl0']

assert model_name in __supported_modelnames, f'Model name not supported. Please use one of : {__supported_modelnames}'

l0_weights = {'encoder' : 'https://drive.google.com/file/d/1a0tRmHQeGTAbSeMqBMhu4DinsOR3cSv6/view?usp=sharing',
'decoder': 'https://drive.google.com/file/d/13J7pNfh016sBqOQ17CludkUFdKgkkyQM/view?usp=sharing'}

xl0_weights = {'encoder': 'https://drive.google.com/file/d/1NzavgCAqk6mSzTnQ_LKfl78V_O68lWNX/view?usp=sharing',
'decoder': 'https://drive.google.com/file/d/1lrn5bQRE01Mwtp-nr9DBNTHcxk4Q6iiP/view?usp=sharing'}

if os.path.exists('model_weights'):
model_weights_folder_path = os.path.abspath('model_weights')
else:
os.makedirs('model_weights',
exist_ok = False)
model_weights_folder_path = os.path.abspath('model_weights')

if os.path.exists(f'model_weights/{model_name}/encoder.onnx'):
encoder_weights_path = os.path.abspath(f'model_weights/{model_name}/encoder.onnx')
else:
os.makedirs(f'model_weights/{model_name}',
exist_ok = True)
if model_name == 'l0':
gdown.download(l0_weights['encoder'],
f'model_weights/{model_name}/encoder.onnx',
fuzzy=True)
if model_name == 'xl0':
gdown.download(xl0_weights['encoder'],
f'model_weights/{model_name}/encoder.onnx',
fuzzy=True)
encoder_weights_path = os.path.abspath(f'model_weights/{model_name}/encoder.onnx')

if os.path.exists(f'model_weights/{model_name}/decoder.onnx'):
decoder_weights_path = os.path.abspath(f'model_weights/{model_name}/decoder.onnx')
else:
if model_name == 'l0':
gdown.download(l0_weights['decoder'],
f'model_weights/{model_name}/decoder.onnx',
fuzzy=True)
if model_name == 'xl0':
gdown.download(xl0_weights['decoder'],
f'model_weights/{model_name}/decoder.onnx',
fuzzy=True)
decoder_weights_path = os.path.abspath(f'model_weights/{model_name}/decoder.onnx')

return encoder_weights_path, decoder_weights_path

def show_mask(mask, ax, random_color=False):
"""
Visualize a mask image on the given axis.
Expand Down Expand Up @@ -526,23 +576,27 @@ class InferSAM:
"""

def __init__(self, model_dir: str, model_name: str = "l0"):
assert model_dir is not None, "model_dir is null"
def __init__(self, model_name: str = "l0"):
# assert model_dir is not None, "model_dir is null"
assert model_name is not None, "model_name is null"

self.model_name = model_name

encoder_weights_path, decoder_weights_path = check_and_download_weights(model_name)

# Find encoder and decoder models
encoder_path = glob.glob(model_dir + "/*_encoder.onnx")[0]
decoder_path = glob.glob(model_dir + "/*_decoder.onnx")[0]
encoder_path = encoder_weights_path # glob.glob(model_dir + "/*_encoder.onnx")[0]
decoder_path = decoder_weights_path # glob.glob(model_dir + "/*_decoder.onnx")[0]

self.encoder = SamEncoder(encoder_path)
self.decoder = SamDecoder(decoder_path)

self.figsize = (10,10)

def infer(
self,
img_path: str,
boxes: list[list] = [[80, 50, 320, 420], [300, 20, 530, 420]],
boxes: List[list] = [[80, 50, 320, 420], [300, 20, 530, 420]],
visualize=False,
) -> np.array:
"""
Infer segmentation masks for a given image using the SAM model.
Expand Down Expand Up @@ -585,7 +639,18 @@ def infer(
origin_image_size=origin_image_size,
boxes=boxes,
)

if visualize:
plt.figure(figsize=self.figsize)
plt.imshow(raw_img)
for mask in masks:
show_mask(mask, plt.gca(),
random_color=True)
for box in boxes:
show_box(box, plt.gca())
plt.show()
return masks

def set_figsize(self,figsize=(10,10)):
self.figsize = figsize


Binary file added tests/images/test1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/images/test2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit ff0559b

Please sign in to comment.