-
Notifications
You must be signed in to change notification settings - Fork 0
/
object_tracking_store.py
71 lines (53 loc) · 2.08 KB
/
object_tracking_store.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
from collections import deque
from pprint import pprint
# Maximum Euclidean distance (np.linalg.norm) between a new position and a
# tracked object's mean position for ObjectTrackingStore.query_and_push to
# assign the position to that object; anything farther starts a new object.
DISTANCE_THRESHOLD = 0.5
class ObjectTrackingStore:
    """Assign incoming positions to tracked objects by nearest mean position.

    Each tracked object keeps a sliding window of its most recent positions
    (a bounded deque) together with the mean of that window. A new position
    is attributed to the object whose mean is nearest, provided that distance
    is within the threshold; otherwise a fresh object id is allocated.

    ``self.store`` maps ``object_id -> {'positions', 'mean', 'count'}``.
    """

    def __init__(self, distance_threshold=None, window_size=10) -> None:
        """Create an empty store.

        Args:
            distance_threshold: Maximum Euclidean distance from an object's
                mean for a position to be assigned to that object. ``None``
                (the default) uses the module-level ``DISTANCE_THRESHOLD``,
                looked up each time ``query_and_push`` needs it.
            window_size: How many of the most recent positions are kept per
                object; the stored mean is taken over this window.

        Raises:
            ValueError: If ``window_size`` is smaller than 1.
        """
        if window_size < 1:
            raise ValueError("window_size must be at least 1")
        self.store = {}
        self.next_object_id = 0
        self.distance_threshold = distance_threshold
        self.window_size = window_size

    def push(self, object_id, position):
        """Record ``position`` for ``object_id`` and refresh its mean.

        The object's entry is created on first use.

        Returns:
            The object's push counter: the number of pushes for this object
            modulo 10, so it runs 1, 2, ..., 9, 0, 1, ...
        """
        entry = self.store.get(object_id)
        if entry is None:
            entry = {
                'positions': deque(maxlen=self.window_size),
                'mean': None,
                'count': 0,
            }
            self.store[object_id] = entry
        entry['positions'].append(position)
        # Average the window directly. An incremental update weighted by
        # len(positions) is wrong once the deque is full and evicts samples,
        # and it would also hard-code the dimensionality of the mean.
        entry['mean'] = np.mean(np.asarray(entry['positions'], dtype=float), axis=0)
        entry['count'] = (entry['count'] + 1) % 10
        return entry['count']

    def query(self, position):
        """Return ``(object_id, distance)`` of the object whose mean is nearest.

        Distance is the Euclidean norm of ``position - mean``. Ties go to the
        object inserted first.

        Raises:
            ValueError: If the store holds no objects yet.
        """
        if not self.store:
            raise ValueError("query() called on an empty ObjectTrackingStore")
        distances = {
            object_id: np.linalg.norm(position - entry['mean'])
            for object_id, entry in self.store.items()
        }
        object_id = min(distances, key=distances.get)
        return object_id, distances[object_id]

    def query_and_push(self, position):
        """Attribute ``position`` to the nearest object or start a new one.

        Returns:
            ``(object_id, count)`` where ``count`` is the value returned by
            ``push`` for that object.
        """
        if self.store:
            object_id, distance = self.query(position)
            threshold = (DISTANCE_THRESHOLD if self.distance_threshold is None
                         else self.distance_threshold)
            if distance <= threshold:
                return object_id, self.push(object_id, position)
        # Empty store, or no tracked object is close enough: new object.
        object_id = self.next_object_id
        count = self.push(object_id, position)
        self.next_object_id += 1
        return object_id, count
def _test():
    """Feed a handful of sample points through a fresh tracker and print its state."""
    tracker = ObjectTrackingStore()
    sample_points = (
        [0, 0, 0],
        [0, 0, 1],
        [1, 1, 1],
        [1, 1, 2],
        [2, 2, 2],
        [2, 2, 2.5],
        [0.5, 1, 1],
    )
    for point in sample_points:
        tracker.query_and_push(np.array(point))
    pprint(tracker.store)


if __name__ == '__main__':
    _test()