# forked from hkchengrex/XMem
# colab.py — 101 lines (82 loc), 2.88 KB
# (GitHub page chrome and the rendered line-number column from the original
#  scrape were removed; the code below is the file's actual content.)
import os
from os import path
from argparse import ArgumentParser
import shutil
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from PIL import Image
from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset
from inference.data.mask_mapper import MaskMapper
from model.network import XMem
from inference.inference_core import InferenceCore
from progressbar import progressbar
# Inference-only script: turn autograd off globally to save memory/compute.
torch.set_grad_enabled(False)

# Select the compute device, preferring CUDA when it is available.
if torch.cuda.is_available():
    device = 'cuda'
    print('Using GPU')
else:
    device = 'cpu'
    print('CUDA not available. Please connect to a GPU instance if possible.')
# Default XMem inference configuration: memory-bank management knobs
# (mem_every, min/max mid-term frames, long-term limits) plus matching
# parameters (top_k, num_prototypes).
config = dict(
    top_k=30,
    mem_every=5,
    deep_update_every=-1,
    enable_long_term=True,
    enable_long_term_count_usage=True,
    num_prototypes=128,
    min_mid_term_frames=5,
    max_mid_term_frames=10,
    max_long_term_elements=10000,
)
# Build the XMem network from the config, load the pretrained checkpoint,
# switch to eval mode, and move it to the selected device.
# NOTE(review): assumes './saves/XMem.pth' has already been downloaded —
# confirm the checkpoint exists before this point.
network = XMem(config, './saves/XMem.pth').eval().to(device)
# print(network)
# Input files: the video to segment and the annotated mask for its first frame.
video_name = 'video.mp4'
mask_name = 'first_frame.png'
# Optional notebook display helpers (left disabled):
# from base64 import b64encode
# data_url = "data:video/mp4;base64," + b64encode(open(video_name, 'rb').read()).decode()
# import IPython.display
# IPython.display.Image('first_frame.png', width=400)
# Load the first-frame annotation as an integer label map.
mask = np.array(Image.open(mask_name))
print(np.unique(mask))
# Number of foreground objects; assumes label 0 is background and the
# remaining labels are the objects — TODO confirm the mask uses this scheme.
num_objects = len(np.unique(mask)) - 1
import cv2
from inference.interact.interactive_utils import image_to_torch, index_numpy_to_one_hot_torch, torch_prob_to_numpy_mask, overlay_davis
# Free any cached GPU memory before starting inference.
torch.cuda.empty_cache()
# Wrap the network in the inference engine that maintains the feature memory.
processor = InferenceCore(network, config=config)
processor.set_all_labels(range(1, num_objects+1)) # consecutive labels
# Open the input video for frame-by-frame reading.
cap = cv2.VideoCapture(video_name)
# You can change these two numbers
frames_to_propagate = 80   # stop after this many frames
visualize_every = 20       # presumably controls how often frames are shown — verify against loop below
current_frame_index = 0
import matplotlib.pyplot as plt
# Propagate the first-frame mask through the video under mixed precision
# (autocast reduces GPU memory use during inference).
with torch.cuda.amp.autocast(enabled=True):
    while (cap.isOpened()):
        # load frame-by-frame
        _, frame = cap.read()
        # Stop at end of video (read() yields None) or at the frame budget.
        if frame is None or current_frame_index > frames_to_propagate:
            break
        # convert numpy array to pytorch tensor format