-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdemo_interactive.py
85 lines (67 loc) · 3.29 KB
/
demo_interactive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch, cv2
import os
import timeit
from davisinteractive.session import DavisInteractiveSession
from davisinteractive import utils as interactive_utils
from davisinteractive.dataset import Davis
from osvos_scribble import OSVOSScribble
from mypath import Path
def main():
    """Run the OSVOSScribble model against a DAVIS interactive session.

    For every interaction served by the session, the model is trained on the
    received scribbles and tested once per object of the sequence; the
    per-object masks are combined and submitted back to the session. Once
    the session has no more interactions, the report is fetched and the
    global summary is requested with ``summary.json`` (inside the report
    directory) as its ``save_file``.

    All settings (GPU id, subset, host, time budgets, model file) are
    hard-coded below; the function takes no arguments and returns nothing.
    """
    # General parameters
    gpu_id = 0

    # Configuration used in the challenges
    max_nb_interactions = 8  # Maximum number of interactions
    max_time_per_interaction = 30  # Maximum time per interaction per object

    # Total time available to interact with a sequence and an initial set of scribbles
    max_time = max_nb_interactions * max_time_per_interaction  # Maximum time per object

    # Interactive parameters
    subset = 'val'
    host = 'localhost'  # 'localhost' for subsets train and val.

    # OSVOS parameters
    time_budget_per_object = 20
    parent_model = 'osvos_parent.pth'
    prev_mask = True  # Use previous mask as no-care area when fine-tuning

    save_model_dir = Path.models_dir()
    report_save_dir = Path.save_root_dir()
    save_result_dir = report_save_dir

    model = OSVOSScribble(parent_model, save_model_dir, gpu_id, time_budget_per_object,
                          save_result_dir=save_result_dir)

    # Number of times each sequence has been started with a fresh first scribble.
    seen_seq = {}
    with DavisInteractiveSession(host=host,
                                 davis_root=Path.db_root_dir(),
                                 subset=subset,
                                 report_save_dir=report_save_dir,
                                 max_nb_interactions=max_nb_interactions,
                                 max_time=max_time) as sess:
        while sess.next():
            t_start = timeit.default_timer()

            # Get the current iteration scribbles
            sequence, scribbles, first_scribble = sess.get_scribbles()
            if first_scribble:
                # A first scribble resets the per-sequence state that the
                # following (non-first) interactions keep reusing.
                n_interaction = 1
                n_objects = Davis.dataset[sequence]['num_objects']
                first_frame = interactive_utils.scribbles.annotated_frames(scribbles)[0]
                seen_seq[sequence] = seen_seq.get(sequence, 0) + 1
            else:
                # NOTE(review): relies on the session always serving a first
                # scribble before any follow-up one; otherwise n_interaction,
                # n_objects and first_frame are unbound here.
                n_interaction += 1
            scribble_iter = seen_seq[sequence]

            print('\nRunning sequence {} in interaction {} and scribble iteration {}'
                  .format(sequence, n_interaction, scribble_iter))

            # Object ids are 1-based; each object is trained and tested separately.
            pred_masks = []
            for obj_id in range(1, n_objects + 1):
                model.train(first_frame, n_interaction, obj_id, scribbles, scribble_iter,
                            subset=subset,
                            use_previous_mask=prev_mask)
                pred_masks.append(model.test(sequence, n_interaction, obj_id,
                                             subset=subset,
                                             scribble_iter=scribble_iter))

            final_masks = interactive_utils.mask.combine_masks(pred_masks)

            # Submit your prediction
            sess.submit_masks(final_masks)
            elapsed = timeit.default_timer() - t_start
            print('Total time (training and testing) for single interaction: {}'.format(elapsed))

        # Get the DataFrame report. The return value is not used here; the call
        # is kept so the interaction with the session stays unchanged.
        sess.get_report()

        # Get the global summary, passing summary.json as the save file.
        sess.get_global_summary(save_file=os.path.join(report_save_dir, 'summary.json'))
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()