-
Notifications
You must be signed in to change notification settings - Fork 2
/
run_dir.py
75 lines (62 loc) · 2.28 KB
/
run_dir.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import sys
SOLO_ROOT = os.environ.get("SOLO_ROOT")


def solo_import_dir(solo_root: str) -> str:
    """Return the absolute parent directory of *solo_root*.

    ``abspath`` is applied *before* ``dirname`` on purpose. In the opposite
    order a trailing slash breaks the result: ``dirname("/x/SOLO/")`` is
    ``"/x/SOLO"`` (only the slash is stripped), so the directory containing
    the ``SOLO`` package would never reach ``sys.path``.
    """
    return os.path.dirname(os.path.abspath(solo_root))


# Presumably SOLO_ROOT points at the SOLO checkout; putting its parent on
# sys.path lets ``from SOLO.mmdet.apis import ...`` below resolve to it.
# Index 1 keeps the script's own directory (sys.path[0]) first.
# An empty value is treated the same as an unset variable.
if SOLO_ROOT:
    sys.path.insert(1, solo_import_dir(SOLO_ROOT))
import argparse
from pathlib import Path
import shutil
from tqdm import tqdm
from PIL import Image
import cv2
import torch
from SOLO.mmdet.apis import inference_detector, init_detector
from process import f, g
def _parse_args() -> argparse.Namespace:
    """Parse and validate this script's command-line arguments.

    Invalid paths are reported through ``parser.error`` (usage message and
    exit status 2) rather than ``assert``, which is stripped under ``-O``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--src_dir", type=Path, required=True)
    parser.add_argument("--dst_dir", type=Path, required=True)
    parser.add_argument("--cfg", type=Path, required=True)
    parser.add_argument("--ckpt", type=Path, required=True)
    # Score threshold forwarded to process.f.
    parser.add_argument("--threshold", type=float, default=0.5)
    # Zero or more integer label ids, forwarded to process.f as
    # ``retained_labels``.
    parser.add_argument("--labels", nargs="*", default=None, type=int)
    # Forwarded verbatim to process.g.
    parser.add_argument("--policy", type=str, default="aggregate")
    # Suffix appended to the "*" glob over --src_dir; "" matches every entry.
    parser.add_argument("--src_extension", type=str, default="")
    args = parser.parse_args()

    if not args.src_dir.is_dir():
        parser.error(
            f"--src_dir does not exist or is not a directory: {args.src_dir}"
        )
    for name in ("cfg", "ckpt"):
        path = getattr(args, name)
        if not path.exists():
            parser.error(f"--{name} does not exist: {path}")
    return args


def _predict_mask(model, image, retained_labels, threshold, policy):
    """Run the detector on one image and reduce its output to one mask.

    Args:
        model: detector returned by ``init_detector``.
        image: array returned by ``cv2.imread`` (first two dims are h, w).
        retained_labels: tensor of label ids or None; passed to ``process.f``.
        threshold: score threshold passed to ``process.f``.
        policy: policy name passed to ``process.g``.

    Returns:
        Whatever ``process.g`` returns for the filtered masks. When the
        detector yields no result, an all-False ``torch.bool`` tensor of
        shape (h, w) is returned instead.
    """
    # One image is passed in, so exactly one result is unpacked.
    (result,) = inference_detector(model, image)
    if result is None:
        h, w = image.shape[:2]
        return torch.zeros(h, w, dtype=torch.bool)
    masks, labels, scores = result
    kept = f(
        masks=masks,
        labels=labels,
        scores=scores,
        threshold=threshold,
        retained_labels=retained_labels,
    )
    return g(masks=kept, policy=policy)


def main() -> None:
    """Write one mask PNG to --dst_dir for every readable image in --src_dir."""
    args = _parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = init_detector(str(args.cfg), str(args.ckpt), device=device)
    # An empty --labels list is treated like omitting the flag (None is
    # passed on to process.f).
    retained_labels = (
        torch.tensor(args.labels, device=device) if args.labels else None
    )

    print(args.src_dir)
    print(args.dst_dir)
    print(args.cfg)
    print(args.ckpt)

    # Directories (e.g. a --dst_dir nested inside --src_dir) also match the
    # glob, especially with the default empty extension; keep files only.
    src_images = sorted(
        path
        for path in args.src_dir.glob(f"*{args.src_extension}")
        if path.is_file()
    )
    print(f"found {len(src_images)} source images")

    args.dst_dir.mkdir(parents=True, exist_ok=True)
    for src_image in tqdm(src_images):
        image = cv2.imread(str(src_image))
        if image is None:
            # cv2.imread signals unreadable / non-image files by returning None.
            tqdm.write(f"skipping unreadable image: {src_image}")
            continue
        mask = _predict_mask(
            model, image, retained_labels, args.threshold, args.policy
        )
        # Outputs are named by stem only, so inputs sharing a stem
        # (e.g. a.jpg and a.png) overwrite each other's mask.
        Image.fromarray(mask.cpu().numpy()).save(
            args.dst_dir / f"{src_image.stem}.png"
        )


if __name__ == "__main__":
    main()