import json

import cv2
import numpy as np
import torch
from PIL import Image
from torch import nn

import model
import utils
from utils import FlatFolderDataset
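
# For reference, a config.json for this script might look like the sketch below
# (hypothetical values; every key is optional and falls back to the defaults
# used in DevEnvironment.__init__):
#
# {
#     "batch_size": 32,
#     "device": "cuda",
#     "lr": 1e-4,
#     "lr_decay": 0.99999,
#     "decay_after": 5000,
#     "iters": 200000,
#     "style_weight": 5.0,
#     "content_weight": 1.0,
#     "identity1_weight": 50.0,
#     "identity2_weight": 1.0,
#     "save_checkpoint_interval": 1000,
#     "log_generated_interval": 20,
#     "img_generated_interval": 100
# }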

# Loads all runtime and training configuration from a JSON file and builds the model.
class DevEnvironment():
    def __init__(self, config_file):
        '''
        config_file: path to the JSON config file
        '''
        config = json.loads(open(config_file).read())
        self.batch_size = config.get("batch_size", 32)
        # Device used for inference
        self.device = torch.device(config.get("device", "cuda"))
        # Keep only the layers we need from the pretrained VGG encoder
        self.vgg_encoder = model.vgg
        self.vgg_encoder.load_state_dict(torch.load("./model_save/vgg_normalised.pth"))
        self.vgg_encoder = nn.Sequential(*list(self.vgg_encoder.children())[:44])
        # Image decoder
        self.decoder = model.decoder.to(self.device)
        # Style Attention Network built from the encoder and decoder
        self.network: model.MultiLevelStyleAttention = model.MultiLevelStyleAttention(self.vgg_encoder, self.decoder)
        self.network.to(self.device)
        self.network.train()
        # Learning-rate schedule
        self.lr = config.get("lr", 1e-4)
        self.lr_decay = config.get("lr_decay", 0.99999)
        self.decay_after = config.get("decay_after", 5000)
        # Total number of training steps
        self.iters = config.get("iters", 200000)
        # Weights of the loss terms
        self.style_weight = config.get("style_weight", 5.0)
        self.content_weight = config.get("content_weight", 1.0)
        self.identity1_weight = config.get("identity1_weight", 50.0)
        self.identity2_weight = config.get("identity2_weight", 1.0)
        # Logging and checkpoint intervals
        self.save_checkpoint_interval = config.get("save_checkpoint_interval", 1000)
        self.log_generated_interval = config.get("log_generated_interval", 20)
        self.img_generated_interval = config.get("img_generated_interval", 100)

    def load_save(self, file_path: str):
        # Load the dict of parameter dicts stored at file_path
        saved = torch.load(file_path, map_location=lambda storage, loc: storage)
        # Load the different parts of the model separately
        self.network.decoder.load_state_dict(saved["decoder"], strict=False)
        self.network.sa_module.load_state_dict(saved["sa_module"], strict=False)
        # Half precision keeps real-time inference fast
        self.network = self.network.half()
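
# load_save expects a checkpoint dict holding (at least) "decoder" and
# "sa_module" entries. A sketch of how such a file could be produced on the
# training side (assuming a trained `network` of the same architecture):
#
#     torch.save({
#         "decoder": network.decoder.state_dict(),
#         "sa_module": network.sa_module.state_dict(),
#     }, f"model_save/{str(step).zfill(6)}.pt")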

# Load the config into the environment
env = DevEnvironment("config.json")
# Load a previously trained checkpoint (0 would mean starting from scratch)
start_iteration = 177000
if start_iteration != 0:
    env.load_save(f"model_save/{str(start_iteration).zfill(6)}.pt")

CONTENT_SIZE = 720
# Use a custom style folder to provide the style images
custom_style_dataset = FlatFolderDataset("custom_style/", 512)
# A second dataset instance, only used for its transform on camera frames
cam_transform = FlatFolderDataset("custom_style/", CONTENT_SIZE)
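# custom_style/ is assumed to be a flat directory of image files; the second
# FlatFolderDataset argument is presumably the edge size the images are
# resized to (512 for the style inputs, CONTENT_SIZE for the camera transform).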

def camera_feed():
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Could not open video device")
        return
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
    # Inference only: drop any gradient buffers
    for param in env.network.parameters():
        param.grad = None
    torch.backends.cudnn.benchmark = True
    cv2.namedWindow("window", cv2.WND_PROP_FULLSCREEN)
    cv2.setWindowProperty("window", cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
    # Automatically advance to the next style every `change_each` frames
    change_each = 30
    until_change = change_each
    i = 0
    # Whether to write the rendered frames to disk
    save = False

    def load_style(index):
        # Load a style image as a half-precision batch of one, plus a
        # displayable numpy copy resized to the content resolution
        style = torch.unsqueeze(custom_style_dataset.__getitem__(index, False), 0).to(env.device).half()
        style_np = cv2.resize(
            cv2.cvtColor(style[0].permute(1, 2, 0).float().cpu().numpy(), cv2.COLOR_BGR2RGB),
            (CONTENT_SIZE, CONTENT_SIZE))
        return style, style_np
    with torch.cuda.amp.autocast(), torch.no_grad():
        style_len = len(custom_style_dataset)
        style_index = 0
        style, style_np = load_style(style_index)
        while True:
            utils.clean()
            # Capture frame-by-frame
            ret, frame = cap.read()
            if not ret:
                break
            frame = cam_transform.transform_test(Image.fromarray(frame))
            out = env.network(torch.unsqueeze(frame.to(env.device), 0), style, train=False)
            # Stack style over content, downscale, then put the stylised output to the right
            res = np.vstack((style_np, frame.permute(1, 2, 0).float().numpy()))
            res = cv2.resize(res, (0, 0), fx=0.5, fy=0.5)
            res = np.hstack((res, out[0, [2, 1, 0]].permute(1, 2, 0).float().cpu().numpy()))
            # Display the resulting frame
            cv2.imshow('window', res)
            if save:
                # Save the composited frame to disk
                cv2.imwrite('vid/Frame' + str(i) + '.jpg', (res * 255.0).astype(np.uint8))
            # Loop through styles
            until_change -= 1
            i += 1
            k = cv2.waitKey(33)
            # Key bindings: 'q' quits, 'a'/'d' step backwards/forwards through the styles
            if k == ord('q'):
                break
            elif k == ord('a') or not until_change:
                # Previous style (wrapping around); also triggered by the auto-change timer
                style_index = style_index - 1 if style_index - 1 >= 0 else style_len - 1
                style, style_np = load_style(style_index)
                until_change = change_each
            elif k == ord('d'):
                # Next style (wrapping around)
                style_index = style_index + 1 if style_index + 1 < style_len else 0
                style, style_np = load_style(style_index)
                until_change = change_each
    cap.release()
    cv2.destroyAllWindows()

camera_feed()
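
# With `save` set to True above, stylised frames land in vid/Frame<i>.jpg. They
# can then be assembled into a clip, e.g. with ffmpeg (a sketch, assuming the
# ~30 fps implied by waitKey(33)):
#     ffmpeg -framerate 30 -i vid/Frame%d.jpg out.mp4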