-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdrift.py
202 lines (149 loc) · 5.3 KB
/
drift.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import datetime
import optparse
import os
import warnings
from typing import Tuple
import numpy as np
import requests
import tabulate
from sklearn.metrics.pairwise import cosine_similarity
# Suppress every FutureWarning for the remainder of the run.
warnings.filterwarnings("ignore", category=FutureWarning)

# Command-line configuration. Each option defaults to the environment variable
# of the same name, so the script can be driven by flags, by the environment,
# or by a mix of both (a flag given on the command line replaces the default).
parser = optparse.OptionParser()

parser.add_option(
    "--ROBOFLOW_KEY",
    dest="ROBOFLOW_KEY",
    help="Roboflow API Key",
    default=os.environ.get("ROBOFLOW_KEY"),
)
parser.add_option(
    "--ROBOFLOW_PROJECT",
    dest="ROBOFLOW_PROJECT",
    help="Roboflow Project ID",
    default=os.environ.get("ROBOFLOW_PROJECT"),
)
parser.add_option(
    "--ROBOFLOW_WORKSPACE",
    dest="ROBOFLOW_WORKSPACE",
    help="Roboflow Workspace ID",
    default=os.environ.get("ROBOFLOW_WORKSPACE"),
)
parser.add_option(
    "--DRIFT_PROJECT",
    dest="DRIFT_PROJECT",
    help="ID of your project storing representative images from Roboflow Collect",
    default=os.environ.get("DRIFT_PROJECT"),
)
parser.add_option(
    "--INCREMENT",
    dest="INCREMENT",
    help="Increment to group images by",
    default=os.environ.get("INCREMENT", "month"),
)

# parse_args() returns an (options, positional_args) pair. The rest of this
# file reads the options through args[0], so the pair is kept as-is.
args = parser.parse_args()

# These four settings have no usable fallback: show the help text and stop
# when any of them is missing or empty.
if not all(
    (
        args[0].ROBOFLOW_KEY,
        args[0].ROBOFLOW_PROJECT,
        args[0].ROBOFLOW_WORKSPACE,
        args[0].DRIFT_PROJECT,
    )
):
    parser.print_help()
    # Raise SystemExit directly rather than calling exit(): exit() is a helper
    # installed by the site module for interactive sessions and is not
    # guaranteed to exist. The non-zero status lets shells and CI jobs detect
    # that the script did not run.
    raise SystemExit(1)

if args[0].INCREMENT not in ["day", "month", "year"]:
    print("Increment must be day, month, or year")
    raise SystemExit(1)
def retrieve_by_period(period: str, images: list) -> Tuple[list, dict]:
    """
    Group vectors by time period and compute the mean vector of each group.

    `period` is a strptime/strftime format string. `images` is annotated as a
    list but is used as a mapping: each key is a date string that parses with
    `period`, and each value is a list of equal-length numeric vectors.

    Every key is parsed and re-formatted with `period`, so keys that normalise
    to the same text (for example "2023-1" and "2023-01") share one group.

    Returns a pair of dicts, both ordered by ascending period string:
    the first maps each period to the list of all its vectors, the second maps
    each period to the element-wise mean of those vectors as a NumPy array of
    shape (1, n).
    """
    grouped = {}
    for raw_key in images:
        bucket = datetime.datetime.strptime(raw_key, period).strftime(period)
        # extend() copies the entries into a fresh list, so the caller's
        # lists are never modified.
        grouped.setdefault(bucket, []).extend(images[raw_key])

    ordered = {bucket: grouped[bucket] for bucket in sorted(grouped)}

    # zip(*vectors) walks the vectors column by column; each column is averaged
    # and the result is reshaped into a single-row 2-D array.
    means = {
        bucket: np.array(
            [sum(column) / len(column) for column in zip(*vectors)]
        ).reshape(1, -1)
        for bucket, vectors in ordered.items()
    }

    return ordered, means
def get_clip_vectors(
    project_id: str, is_drift: bool = False, period: str = "%Y-%m"
) -> Tuple[list, dict]:
    """
    Download a project's image embeddings and group them by time period.

    The project's search endpoint is queried page by page (125 images per
    request) until a page comes back with no results.

    Args:
        project_id: Project to query inside the workspace named by
            ``args[0].ROBOFLOW_WORKSPACE``.
        is_drift: When False, only images whose split is "valid" are kept;
            every other image is reported on stdout and skipped. When True,
            all returned images are kept.
        period: strftime format used to bucket images by creation date,
            e.g. "%Y-%m-%d", "%Y-%m" or "%Y".

    Returns:
        A pair ``(vectors, vectors_by_period)``. ``vectors`` is a flat list of
        every kept embedding in the order received. ``vectors_by_period`` is
        the first element of ``retrieve_by_period``'s result: a dict, sorted by
        period string, mapping each period to the list of embeddings created
        in it (not averaged).

    Raises:
        RuntimeError: If the API answers with a status code other than 200.
    """
    project_clip_vectors = []
    images_by_time = {}

    limit = 125
    offset = 0
    # Without a timeout a stalled connection would block the script forever.
    request_timeout = 60

    while True:
        response = requests.post(
            f"https://api.roboflow.com/{args[0].ROBOFLOW_WORKSPACE}/{project_id}/search?api_key={args[0].ROBOFLOW_KEY}",
            json={
                "limit": limit,
                "query": "drift",
                "fields": ["split", "embedding", "tags", "created"],
                "in_dataset": "true",
                "offset": offset,
            },
            timeout=request_timeout,
        )
        offset += limit

        if response.status_code != 200:
            raise RuntimeError(f"Error retrieving images: {response.text}")

        results = response.json()["results"]

        # An empty page marks the end of the result set.
        if not results:
            break

        for image in results:
            if image["split"] != "valid" and not is_drift:
                print(f"Skipping {image['image_id']} because it's not in the valid split")
                continue

            # "created" is a timestamp in milliseconds; fromtimestamp() turns
            # it into a naive datetime in the machine's local timezone.
            created_at = datetime.datetime.fromtimestamp(image["created"] / 1000)

            # Bucket by the requested period, not a fixed year-month key, so
            # the caller's choice of increment takes effect.
            bucket = created_at.strftime(period)

            images_by_time.setdefault(bucket, []).append(image["embedding"])
            project_clip_vectors.append(image["embedding"])

    # NOTE(review): the grouped (un-averaged) vectors are returned, so indexing
    # a period's entry with [0] yields its first embedding, not a mean.
    vectors_by_period, _ = retrieve_by_period(period, images_by_time)

    return project_clip_vectors, vectors_by_period
def _mean_vector(vectors) -> list:
    """Return the element-wise mean of a sequence of equal-length vectors."""
    return [sum(column) / len(column) for column in zip(*vectors)]


def main():
    """
    Print a table of cosine similarities between the drift and main projects.

    Embeddings are fetched for ``args[0].ROBOFLOW_PROJECT`` and
    ``args[0].DRIFT_PROJECT``. For every period present in the main project,
    the mean drift embedding of that period is compared with the mean
    embedding of the whole main project. A final "All Time" row compares the
    overall means of the two projects. Periods for which the drift project has
    no images are listed as "n/a".

    Raises:
        ValueError: If ``args[0].INCREMENT`` is not "day", "month" or "year".
    """
    period_formats = {"day": "%Y-%m-%d", "month": "%Y-%m", "year": "%Y"}

    increment = args[0].INCREMENT
    if increment not in period_formats:
        raise ValueError("Increment must be day, month, or year")
    period = period_formats[increment]

    main_project_clip_vectors, main_project_clip_vectors_by_period = get_clip_vectors(
        args[0].ROBOFLOW_PROJECT,
        period=period,
    )

    drift_project_clip_vectors, drift_project_clip_vectors_by_period = get_clip_vectors(
        args[0].DRIFT_PROJECT, period=period, is_drift=True
    )

    # A mean over zero vectors is empty and cannot be compared, so stop early
    # with a readable message.
    if not main_project_clip_vectors or not drift_project_clip_vectors:
        print("Nothing to compare: at least one of the two projects returned no usable images.")
        return

    # The main-project mean is the reference for every row; compute it once.
    avg_main_vectors = _mean_vector(main_project_clip_vectors)

    rows = []

    for time_period in main_project_clip_vectors_by_period:
        drift_vectors = drift_project_clip_vectors_by_period.get(time_period)

        # The drift project may have no images in a period the main project
        # covers; report that instead of failing on the missing key.
        if drift_vectors is None or len(drift_vectors) == 0:
            rows.append([time_period, "n/a"])
            continue

        # Compare the mean of ALL drift vectors in the period, not only the
        # first one. cosine_similarity returns a 1x1 matrix here, so [0][0]
        # extracts the scalar.
        similarity = cosine_similarity(
            [_mean_vector(drift_vectors)], [avg_main_vectors]
        )[0][0]
        rows.append([time_period, float(similarity)])

    avg_drift_vectors = _mean_vector(drift_project_clip_vectors)

    rows.append(
        [
            "All Time",
            float(cosine_similarity([avg_drift_vectors], [avg_main_vectors])[0][0]),
        ]
    )

    # Label the first column after the chosen increment ("Day", "Month", "Year").
    print(tabulate.tabulate(rows, headers=[increment.capitalize(), "Cosine Similarity"]))


# Run the comparison only when the file is executed as a script.
if __name__ == "__main__":
    main()