-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathprocess.py
57 lines (48 loc) · 1.65 KB
/
process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from typing import Optional

import torch
from torch import Tensor
def f(
    masks: Tensor,
    labels: Tensor,
    scores: Tensor,
    threshold: float,
    retained_labels: Optional[Tensor] = None,
) -> Tensor:
    """Keep the masks whose score is above ``threshold`` and, optionally, whose label is retained.

    Args:
        masks: Stack of masks of shape ``(n, h, w)``; selection happens along dim 0.
        labels: One label per mask (length ``n``).
        scores: One score per mask (length ``n``).
        threshold: A mask is kept only if its score is strictly greater than this.
        retained_labels: Optional collection of label values (tensor, 0-d tensor,
            list or tuple). When given, a mask is kept only if its label equals one
            of these values; an empty collection therefore keeps nothing. ``None``
            disables label filtering.

    Returns:
        The selected masks, shape ``(k, h, w)`` with ``k <= n``, in their original order.
    """
    keep = scores > threshold
    if retained_labels is not None:
        # Flatten so scalars and Python sequences are accepted, and move to the
        # labels' device so the broadcast comparison below never mixes devices.
        retained = torch.as_tensor(retained_labels, device=labels.device).reshape(-1)
        # (n, 1) == (1, k) -> (n, k): a mask matches if its label equals any retained value.
        keep &= (labels[:, None] == retained[None, :]).any(dim=-1)
    # Boolean indexing along dim 0 preserves the trailing (h, w) dims, including the
    # degenerate h == 0 / w == 0 case where a reshape(-1, h, w) would be ambiguous.
    return masks[keep]
_POSITION_POLICIES = ("left", "right", "top", "bottom", "center")
_POLICIES = ("aggregate", "biggest") + _POSITION_POLICIES


def _mask_centroids(masks: Tensor):
    """Return ``(x, y)``: per-mask mean column and mean row of the nonzero pixels.

    ``masks`` has shape ``(m, h, w)``; both outputs have shape ``(m,)``. Masks with
    no nonzero pixel get NaN for both coordinates (0 / 0).
    """
    fg = masks != 0
    area = fg.sum(dim=(-2, -1))
    rows = torch.arange(masks.shape[-2], device=masks.device)
    cols = torch.arange(masks.shape[-1], device=masks.device)
    # Integer accumulation keeps the coordinate sums exact; the final division
    # is a single float rounding.
    y = (fg.sum(dim=-1) * rows).sum(dim=-1) / area
    x = (fg.sum(dim=-2) * cols).sum(dim=-1) / area
    return x, y


def g(masks: Tensor, policy: str) -> Tensor:
    """Reduce a stack of masks to a single ``(h, w)`` mask according to ``policy``.

    Policies:
        ``"aggregate"``: union (logical OR) of all masks; result is ``torch.bool``.
        ``"biggest"``: the mask with the most nonzero pixels.
        ``"left"`` / ``"right"``: the mask whose centroid has the smallest / largest column.
        ``"top"`` / ``"bottom"``: the mask whose centroid has the smallest / largest row.
        ``"center"``: the mask whose centroid is closest to ``(w / 2, h / 2)``.
    Ties resolve to the lowest index. Empty masks are never chosen by a positional
    policy unless every mask is empty (then index 0 is returned). Non-aggregate
    policies return one of the input masks unchanged, so they keep its dtype.

    Args:
        masks: Stack of masks of shape ``(m, h, w)``.
        policy: One of the policy names above.

    Returns:
        A mask of shape ``(h, w)``. When ``m == 0`` this is an all-False
        ``torch.bool`` mask on ``masks.device``.

    Raises:
        ValueError: If ``policy`` is not recognised (checked even when ``m == 0``).
    """
    if policy not in _POLICIES:
        raise ValueError(f"Unknown policy: {policy}")
    m, h, w = masks.shape
    if m == 0:
        return torch.zeros(h, w, dtype=torch.bool, device=masks.device)
    if policy == "aggregate":
        # Cast to bool first, then OR. This matches sum(dtype=torch.bool), which
        # casts before reducing.
        return masks.bool().any(dim=0)
    if policy == "biggest":
        return masks[masks.sum(dim=(-2, -1)).argmax()]

    x, y = _mask_centroids(masks)
    if policy == "center":
        key = torch.stack([x - w / 2, y - h / 2], dim=-1).norm(p=2, dim=-1)
    else:
        # Express every positional policy as "smallest key wins". Negating keeps
        # argmax's first-occurrence tie-breaking.
        key = {"left": x, "right": -x, "top": y, "bottom": -y}[policy]
    # Empty masks have NaN keys, and argmin/argmax propagate NaN, so they would
    # otherwise always win. Push them to +inf instead.
    key = key.masked_fill(key.isnan(), float("inf"))
    return masks[key.argmin()]