-
Notifications
You must be signed in to change notification settings - Fork 6
/
FastSam_segmentation.py
148 lines (119 loc) · 4.75 KB
/
FastSam_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from typing import Optional, Union
import yaml
from FastSAM.fastsam import FastSAM, FastSAMPrompt
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image
import numpy as np
import os
from PIL import Image
import cv2
import torch
# YAML file (relative path) read by the script entry point; the list stored
# under "pills_images_to_discard" -> "name" is excluded from processing.
IMAGES_TO_DISCARD = "images_to_discard.yaml"
# YAML file (relative path) read by the script entry point; the list stored
# under "pills_images_to_not_mask" -> "name" is saved without masking.
IMAGES_TO_NOT_MASK = "images_to_not_mask.yaml"
"""
Mask the images using the FastSAM model and save the masked images in the output folder.
Some images do not show pills, so they are discarded.
After the first masking pass, any image that was not masked well is added to the not_to_mask list. These images are saved to the output folder without being masked.
"""
class FastSAM_segmentation:
    """
    FastSAM model interface.

    process: segment one image with a text prompt, write the result to disk
    and return the array that was written.
    """

    def __init__(self, device: Union[str, int] = "cpu"):
        """
        Load the FastSAM weights stored under "FastSAM/weights" next to this file.

        device: device identifier forwarded to the model and to FastSAMPrompt.
        """
        local_path = os.path.dirname(os.path.realpath(__file__))
        self.model = FastSAM(local_path + "/FastSAM/weights/FastSAM.pt")
        self.device = device

    def process(self, image_path: str, output_name: str, test=False) -> np.ndarray:
        """
        Mask a single image and save it as "<output_name>_masked.png".

        image_path: path of the image to segment.
        output_name: stem used for the written file.
        test: when true the file goes to "./test", otherwise to "./output".

        Returns the array that was written: the masked image, or an
        unmodified copy of the input when masking (or saving the masked
        result) failed.

        Raises ValueError when OpenCV cannot read image_path.
        """
        everything_results = self.model(
            image_path,
            device=self.device,
            retina_masks=True,
            conf=0.4,
            iou=0.9,
        )
        prompt_process = FastSAMPrompt(
            image_path, everything_results, device=self.device
        )
        # NOTE(review): the return value was never used (it was overwritten by
        # the text prompt). The call is kept in case it has side effects in
        # the FastSAM library -- confirm and remove if it does not.
        prompt_process.everything_prompt()
        ann = prompt_process.text_prompt(text="only the first or above pill")

        img = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read image: {image_path}")

        out_dir = "./test" if test else "./output"
        os.makedirs(out_dir, exist_ok=True)
        out_path = f"{out_dir}/{output_name}_masked.png"

        try:
            mask_array = np.array(ann[0])
            # Add a trailing axis so the mask broadcasts over the colour
            # channels (assumes ann[0] is a 2-D mask -- TODO confirm).
            masked_complete = img * np.expand_dims(mask_array, axis=-1)
            cv2.imwrite(out_path, masked_complete)
        except Exception as exc:
            # Deliberate best-effort: keep the unmasked image rather than
            # aborting the batch, but report it so bad masks can be found.
            print(
                f"Masking failed for {image_path} ({exc!r}); "
                "saving the unmasked image instead."
            )
            masked_complete = img.copy()
            cv2.imwrite(out_path, masked_complete)
        return masked_complete
def segment_image(image_paths: Optional[list[str]], test=False):
    """
    Segment the images and return the masked images.
    Used to compute the embeddings for the images.

    image_paths: paths of the images to mask; None or an empty list yields
        an empty result without loading the model.
    test: forwarded to FastSAM_segmentation.process, which selects the
        folder the masked images are written to.

    Returns a list with one masked image per input path, in input order.
    """
    if not image_paths:
        # Nothing to do: skip the model load entirely.
        return []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = FastSAM_segmentation(device)

    masked_images = []
    for image_path in image_paths:
        # File stem: last path component up to its first dot.
        image_name = image_path.split("/")[-1].split(".")[0]
        masked_images.append(
            model.process(image_path=image_path, output_name=image_name, test=test)
        )
    return masked_images
def save_not_masked_images(not_to_mask: Optional[list]):
    """
    Save and return the not masked images.

    not_to_mask: paths of images to copy unmodified into "./output" as
        "<stem>_masked.png"; None or an empty list yields an empty result.

    Returns the list of images as read by OpenCV, in input order.

    Raises ValueError when OpenCV cannot read one of the paths.
    """
    if not not_to_mask:
        return []

    # The output folder may not exist yet when this is called on its own.
    os.makedirs("./output", exist_ok=True)

    masked_images = []
    for image_path in not_to_mask:
        # File stem: last path component up to its first dot.
        image_name = image_path.split("/")[-1].split(".")[0]
        img = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read image: {image_path}")
        cv2.imwrite(f"./output/{image_name}_masked.png", img)
        masked_images.append(img)
    return masked_images
if __name__ == "__main__":
    # Names of images to skip entirely. An empty YAML list loads as None,
    # hence the "or []".
    with open(IMAGES_TO_DISCARD, "r") as file:
        discard = yaml.safe_load(file)["pills_images_to_discard"]["name"] or []
    # Names of images that are saved without masking.
    with open(IMAGES_TO_NOT_MASK, "r") as file:
        not_mask = yaml.safe_load(file)["pills_images_to_not_mask"]["name"] or []

    # On a first run the output folder does not exist yet; create it so the
    # listing below does not fail.
    os.makedirs("./output", exist_ok=True)
    already_masked = [
        filename.split("_masked")[0]
        for filename in os.listdir("./output")
        if filename.endswith(".png")
    ]

    image_folder = "./images"
    not_mask_paths = [os.path.join(image_folder, filename) for filename in not_mask]

    # Stems to leave out of the masking pass, built once. The not-to-mask
    # entries are reduced to stems here: comparing a stem against the joined
    # "./images/..." paths could never match, so those images used to be
    # masked anyway and then overwritten.
    excluded = (
        set(discard)
        | set(already_masked)
        | {name.split("/")[-1].split(".")[0] for name in not_mask}
    )

    image_paths = sorted(
        os.path.join(image_folder, filename)
        for filename in os.listdir(image_folder)
        if filename.endswith(".jpg")
        and filename.split("/")[-1].split(".")[0] not in excluded
    )

    masked_images = segment_image(image_paths)
    not_masked_images = save_not_masked_images(not_mask_paths)