-
Notifications
You must be signed in to change notification settings - Fork 1
/
global_features_web.py
293 lines (253 loc) · 10.4 KB
/
global_features_web.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import uvicorn
if __name__ == '__main__':
    # Launch the ASGI server; uvicorn re-imports this module by the string
    # 'global_features_web:app', which executes the init code further below
    # in the worker process.
    uvicorn.run('global_features_web:app', host='127.0.0.1', port=33334, log_level="info")
    # exit() is provided by the site module and is meant for interactive use;
    # sys.exit() is the portable way to stop the launcher process here.
    import sys
    sys.exit()
from os import environ

# Resolve the GET_FILENAMES flag (0 or 1) from the environment; any missing
# or malformed value falls back to 0 with a diagnostic message.
_raw_get_filenames = environ.get("GET_FILENAMES")
if _raw_get_filenames is None:
    print("GET_FILENAMES not found! Defaulting to 0...")
    GET_FILENAMES = 0
elif _raw_get_filenames in ("0", "1"):
    GET_FILENAMES = int(_raw_get_filenames)
else:
    print("GET_FILENAMES has wrong argument! Defaulting to 0...")
    GET_FILENAMES = 0
import traceback
from os.path import exists
from typing import Optional, Union
from pydantic import BaseModel
from fastapi import FastAPI, File,Form, HTTPException, Response, status
import numpy as np
import asyncio
from PIL import Image
from pathlib import Path
import io
import faiss
import pickle
index = None  # global faiss index; populated by init_index()
DATA_CHANGED_SINCE_LAST_SAVE = False  # dirty flag, checked by periodically_save_index()
from modules.byte_ops import int_to_bytes, int_from_bytes
from modules.inference_ops import get_image_features, get_device
from modules.transform_ops import transform
from modules.lmdb_ops import get_dbs
# Dimensionality of the stored global feature vectors.
dim = 768
device = get_device()

# Optional PCA whitening: when a pre-fitted transform exists on disk it is
# applied to every feature vector before indexing/searching.
pca_w_file = Path("./data/pca_w.pkl")
pca = None
if pca_w_file.is_file():
    # NOTE(review): pickle.load executes arbitrary code — pca_w.pkl must be
    # a trusted, locally produced file.
    with open(pca_w_file, 'rb') as pickle_file:
        pca = pickle.load(pickle_file)
    print("USING PCA")
else:
    print("pca_w.pkl not found. Proceeding without PCA")

app = FastAPI()
def main():
    """Initialise the faiss index, open the LMDB databases and schedule
    the periodic index save loop. Runs at import time (see call at EOF)."""
    global DB_features, DB_filename_to_id, DB_id_to_filename
    init_index()
    DB_features, DB_filename_to_id, DB_id_to_filename = get_dbs()
    # NOTE(review): asyncio.get_event_loop() outside a running loop is
    # deprecated; this presumably picks up the loop uvicorn sets up around
    # module import — verify on the deployed Python/uvicorn versions.
    loop = asyncio.get_event_loop()
    # First save check fires 10 s after startup; it reschedules itself.
    loop.call_later(10, periodically_save_index,loop)
def read_img_buffer(image_data):
    """Decode raw image bytes into a PIL image, coerced to RGB mode."""
    picture = Image.open(io.BytesIO(image_data))
    # picture = picture.convert('L').convert('RGB')  # GREYSCALE
    return picture if picture.mode == 'RGB' else picture.convert('RGB')
def check_if_exists_by_image_id(image_id):
    """Return True when a feature record for image_id exists in DB_features."""
    with DB_features.begin(buffers=True) as txn:
        # evaluate truthiness inside the transaction: with buffers=True the
        # returned buffer is only valid while txn is open
        return bool(txn.get(int_to_bytes(image_id), default=False))
def get_filenames_bulk(image_ids):
    """Map a list of integer image ids to their stored file names."""
    keys = [int_to_bytes(image_id) for image_id in image_ids]
    with DB_id_to_filename.begin(buffers=False) as txn:
        with txn.cursor() as curs:
            pairs = curs.getmulti(keys)
    # getmulti yields (key, value) pairs; decode just the file-name values
    return [value.decode() for _, value in pairs]
def get_image_id_by_filename(file_name):
    """Look up the integer image id for file_name; False when absent."""
    with DB_filename_to_id.begin(buffers=True) as txn:
        raw = txn.get(file_name.encode(), default=False)
        if raw:
            # decode while the transaction (and its buffer) is still alive
            return int_from_bytes(raw)
    return False
def delete_descriptor_by_id(image_id):
    """Remove every record for image_id from the three LMDB databases.

    Deletes the feature vector, the id->filename entry and the matching
    filename->id entry. Missing keys are skipped silently.
    """
    image_id_bytes = int_to_bytes(image_id)
    with DB_features.begin(write=True, buffers=True) as txn:
        txn.delete(image_id_bytes)  # True = deleted, False = not found
    with DB_id_to_filename.begin(write=True, buffers=True) as txn:
        file_name_bytes = txn.get(image_id_bytes, default=False)
        if file_name_bytes is not False:
            # copy out of the LMDB buffer: it becomes invalid once txn closes
            file_name_bytes = bytes(file_name_bytes)
        txn.delete(image_id_bytes)
    # fix: only touch the filename->id DB when a mapping actually existed;
    # previously txn.delete(False) was attempted on a missing entry
    if file_name_bytes is not False:
        with DB_filename_to_id.begin(write=True, buffers=True) as txn:
            txn.delete(file_name_bytes)
def add_descriptor(image_id, features):
    """Persist a feature vector plus both id<->filename mappings."""
    id_key = int_to_bytes(image_id)
    # images added online get a synthetic "<id>.online" file name
    name_key = f"{image_id}.online".encode()
    with DB_features.begin(write=True, buffers=True) as txn:
        txn.put(id_key, features.tobytes())
    with DB_id_to_filename.begin(write=True, buffers=True) as txn:
        txn.put(id_key, name_key)
    with DB_filename_to_id.begin(write=True, buffers=True) as txn:
        txn.put(name_key, id_key)
def get_features(image_buffer):
    """Compute the global feature vector for a raw image byte buffer."""
    pil_image = read_img_buffer(image_buffer)
    # add a batch dimension and move the tensor to the inference device
    batch = transform(pil_image).unsqueeze(0).to(device)
    return get_image_features(batch)
def get_aqe_vector(feature_vector, n, alpha):
    """Alpha query expansion: blend the query with its n nearest neighbours.

    Each reconstructed neighbour is weighted by (similarity to the query)
    ** alpha; the weighted sum is L2-normalised and returned as a
    (1, dim) float32 array.

    Replaces the original O(dim * n) pure-Python accumulation loops with
    equivalent NumPy matrix operations.
    """
    _, I = index.search(feature_vector, n)
    # reconstruct the n nearest stored vectors -> shape (n, dim)
    top = np.stack([index.reconstruct(int(neighbour_id)).flatten()
                    for neighbour_id in I[0]])
    # one similarity weight per neighbour, shape (n,)
    weights = (feature_vector @ top.T).ravel() ** alpha
    # weighted sum over neighbours, shape (dim,)
    new_feature = weights @ top
    new_feature /= np.linalg.norm(new_feature)
    return new_feature.astype(np.float32).reshape(1, -1)
def nn_find_similar(feature_vector, k, distance_threshold, aqe_n, aqe_alpha):
    """Search the index for neighbours of feature_vector.

    Exactly one of k (top-k search) or distance_threshold (range search)
    must be provided — callers validate this before calling. When both
    aqe_n and aqe_alpha are given, alpha query expansion is applied first.
    Returns [{"image_id": int, "distance": float}, ...] sorted by distance.
    """
    if aqe_n is not None and aqe_alpha is not None:
        feature_vector = get_aqe_vector(feature_vector, aqe_n, aqe_alpha)
    if k is not None:
        D, I = index.search(feature_vector, k)
        D, I = D.flatten(), I.flatten()
    elif distance_threshold is not None:
        _, D, I = index.range_search(feature_vector, distance_threshold)
    hits = [{"image_id": int(image_id), "distance": float(dist)}
            for dist, image_id in zip(D, I)]
    return sorted(hits, key=lambda hit: hit["distance"])
@app.get("/")
async def read_root():
    """Liveness probe: return a trivial static payload."""
    return dict(Hello="World")
class Item_global_features_get_similar_images_by_id(BaseModel):
    """Request body for /global_features_get_similar_images_by_id.

    k and distance_threshold are mutually exclusive (the handler enforces
    exactly one). Values may arrive as strings and are coerced by the
    handler, hence the Union types.
    """
    image_id: int
    k: Union[str,int,None] = None
    distance_threshold: Union[str,float,None] = None
    aqe_n: Union[str,int,None] = None
    aqe_alpha: Union[str,float,None] = None
@app.post("/global_features_get_similar_images_by_id")
async def global_features_get_similar_images_by_id_handler(item: Item_global_features_get_similar_images_by_id):
    """Find images similar to an already-indexed image, addressed by id.

    Exactly one of k / distance_threshold must be supplied; aqe_n and
    aqe_alpha optionally enable alpha query expansion. Returns a list of
    {"image_id", "distance"} dicts (plus "file_name" when GET_FILENAMES).
    Raises HTTPException(500) on validation or search failure.
    """
    try:
        k = item.k
        distance_threshold = item.distance_threshold
        aqe_n = item.aqe_n
        aqe_alpha = item.aqe_alpha
        # fields may arrive as strings; coerce the ones that were given
        if k:
            k = int(k)
        if distance_threshold:
            distance_threshold = float(distance_threshold)
        if aqe_n:
            aqe_n = int(aqe_n)
        if aqe_alpha:
            aqe_alpha = float(aqe_alpha)
        if (k is None) == (distance_threshold is None):
            raise HTTPException(status_code=500, detail="both k and distance_threshold present")
        target_features = index.reconstruct(item.image_id).reshape(1, -1)
        similar = nn_find_similar(target_features, k, distance_threshold, aqe_n, aqe_alpha)
        if GET_FILENAMES:
            file_names = get_filenames_bulk([el["image_id"] for el in similar])
            for el, file_name in zip(similar, file_names):
                el["file_name"] = file_name
        return similar
    except HTTPException:
        # fix: re-raise as-is — the bare except below used to swallow the
        # specific validation detail and replace it with a generic message
        raise
    except Exception:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Error in global_features_get_similar_images_by_id")
@app.post("/global_features_get_similar_images_by_image_buffer")
async def global_features_get_similar_images_by_image_buffer_handler(image: bytes = File(...), k: Optional[str] = Form(None),
    distance_threshold: Optional[str] = Form(None), aqe_n: Optional[str] = Form(None), aqe_alpha: Optional[str] = Form(None)):
    """Find images similar to an uploaded image buffer.

    Exactly one of k / distance_threshold must be supplied; aqe_n and
    aqe_alpha optionally enable alpha query expansion. Returns a list of
    {"image_id", "distance"} dicts (plus "file_name" when GET_FILENAMES).
    Raises HTTPException(500) on validation or search failure.
    """
    try:
        # form fields arrive as strings; coerce the ones that were given
        if k:
            k = int(k)
        if distance_threshold:
            distance_threshold = float(distance_threshold)
        if aqe_n:
            aqe_n = int(aqe_n)
        if aqe_alpha:
            aqe_alpha = float(aqe_alpha)
        if (k is None) == (distance_threshold is None):
            raise HTTPException(status_code=500, detail="both k and distance_threshold present")
        target_features = get_features(image)
        if pca:
            # project through the pre-fitted PCA and re-normalise
            target_features = pca.transform(target_features)
            target_features /= np.linalg.norm(target_features)
            target_features = target_features.astype(np.float32)
        similar = nn_find_similar(target_features, k, distance_threshold, aqe_n, aqe_alpha)
        if GET_FILENAMES:
            file_names = get_filenames_bulk([el["image_id"] for el in similar])
            for el, file_name in zip(similar, file_names):
                el["file_name"] = file_name
        return similar
    except HTTPException:
        # fix: re-raise as-is — the bare except below used to swallow the
        # specific validation detail and replace it with a generic message
        raise
    except Exception:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Error in global_features_get_similar_images_by_image_buffer")
@app.post("/calculate_global_features")
async def calculate_global_features_handler(image: bytes = File(...), image_id: str = Form(...)):
    """Extract global features for an uploaded image and add them to the
    faiss index and the LMDB databases under the given integer image_id.

    Returns 500 with a plain-text body when the id is already present;
    200 on success. Raises HTTPException(500) on any processing failure.
    """
    try:
        global DATA_CHANGED_SINCE_LAST_SAVE
        image_id = int(image_id)
        if check_if_exists_by_image_id(image_id):
            return Response(content="Image with the same id is already in the db", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, media_type="text/plain")
        features = get_features(image)
        if pca:
            # project through the pre-fitted PCA and re-normalise
            features = pca.transform(features)
            features /= np.linalg.norm(features)
            features = features.astype(np.float32)
        add_descriptor(image_id, features)
        index.add_with_ids(features.reshape(1, -1), np.int64([image_id]))
        DATA_CHANGED_SINCE_LAST_SAVE = True
        return Response(status_code=status.HTTP_200_OK)
    # fix: narrow the bare except — it used to trap SystemExit and
    # KeyboardInterrupt as well as ordinary errors
    except Exception:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Can't calculate global features")
class Item_delete_global_features(BaseModel):
    """Request body for /delete_global_features.

    Supply either image_id or file_name; the handler prefers file_name
    when both are present.
    """
    image_id: Union[int ,None] = None
    file_name: Union[None,str] = None
@app.post("/delete_global_features")
async def delete_global_features_handler(item: Item_delete_global_features):
    """Remove an image from the faiss index and the LMDB databases.

    Accepts either file_name or image_id (file_name wins when both are
    set). Returns 200 even when nothing matched, logging a diagnostic.
    """
    try:
        global DATA_CHANGED_SINCE_LAST_SAVE
        if item.file_name:
            image_id = get_image_id_by_filename(item.file_name)
            if image_id is False:
                # fix: previously a missing file name fell through with
                # image_id == False, which np.int64 coerces to 0 — that
                # could delete the unrelated image with id 0
                print(f"err: no image with id {image_id}")
                return Response(status_code=status.HTTP_200_OK)
        else:
            image_id = item.image_id
            if image_id is None:
                # fix: previously crashed into a generic 500 via np.int64([None])
                raise HTTPException(status_code=500, detail="Can't delete global features")
        res = index.remove_ids(np.int64([image_id]))
        if res != 0:
            delete_descriptor_by_id(image_id)
            DATA_CHANGED_SINCE_LAST_SAVE = True
        else:  # nothing to delete
            print(f"err: no image with id {image_id}")
        return Response(status_code=status.HTTP_200_OK)
    except HTTPException:
        raise
    except Exception:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Can't delete global features")
def init_index():
    """Load the faiss index from disk into the module global, building it
    first via add_to_index.py when the index file is missing.

    Raises RuntimeError when the builder fails to produce the index file.
    """
    global index
    if exists("./data/populated.index"):
        index = faiss.read_index("./data/populated.index")
        return
    print("Index is not found!")
    print("Creating empty index")
    import subprocess
    import sys
    # fix: use the running interpreter instead of trying 'python3' and then
    # 'python' — the original ran add_to_index.py twice on systems where
    # both executables exist
    subprocess.call([sys.executable, 'add_to_index.py'])
    if not exists("./data/populated.index"):
        # fix: the original recursed unconditionally here, looping forever
        # when add_to_index.py never created the index file
        raise RuntimeError("add_to_index.py did not create ./data/populated.index")
    index = faiss.read_index("./data/populated.index")
def periodically_save_index(loop):
    """Persist the faiss index to disk when it changed, then reschedule
    this check on the given event loop 10 seconds from now."""
    global DATA_CHANGED_SINCE_LAST_SAVE, index
    if DATA_CHANGED_SINCE_LAST_SAVE:
        # clear the flag first so changes made during the write are kept dirty
        DATA_CHANGED_SINCE_LAST_SAVE = False
        faiss.write_index(index, "./data/populated.index")
    loop.call_later(10, periodically_save_index, loop)
main()