forked from jeya-maria-jose/TransWeather
-
Notifications
You must be signed in to change notification settings - Fork 1
/
quantize2onnx.py
executable file
·155 lines (123 loc) · 4.72 KB
/
quantize2onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import sys
from tabnanny import verbose
# sys.path.append("/content/drive/MyDrive/DERAIN/TransWeather")
import time
import torch
import argparse
import torch.nn as nn
from torch.utils.data import DataLoader
from val_data_functions import ValData
from utils import validation, validation_val, calc_psnr, calc_ssim
import os
import numpy as np
import random
from transweather_model_extra import Transweather
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Normalize
from random import randrange
import torchvision.utils as utils
import cv2
import re
from tqdm import tqdm
from skimage import img_as_ubyte
from torchinfo import summary
def _compute_target_size(height, width, max_side=2048, multiple=16):
    """Return the (height, width) an image is resized to before inference.

    If the longer side exceeds ``max_side`` it is clamped to ``max_side`` and
    the shorter side is scaled proportionally (rounded up). Both sides are
    then rounded up to the next multiple of ``multiple``. Note that this final
    rounding can push a side slightly above ``max_side`` when ``max_side`` is
    not itself a multiple of ``multiple``.
    """
    if width > height and width > max_side:
        height = int(np.ceil(height * max_side / width))
        width = max_side
    elif width <= height and height > max_side:
        width = int(np.ceil(width * max_side / height))
        height = max_side
    height = int(multiple * np.ceil(height / float(multiple)))
    width = int(multiple * np.ceil(width / float(multiple)))
    return height, width


def preprocessImage(input_img, max_side=2048, multiple=16):
    """Resize an image for the network and convert it to a float32 tensor.

    Args:
        input_img: NumPy image array of shape (H, W) or (H, W, C), e.g. as
            returned by OpenCV. In this script the caller converts it from
            BGR to RGB before calling.
        max_side: longer-side limit. Larger images are downscaled so the
            longer side equals this value, preserving the aspect ratio.
        multiple: both output sides are rounded up to a multiple of this
            value.

    Returns:
        torch.Tensor of dtype float32 holding the resized array, in the same
        channels-last layout that ``cv2.resize`` returns. No channel
        permutation and no normalization are applied, so pixel values keep
        the input range (0-255 for uint8 input).
        NOTE(review): the model is presumably expected to handle the HWC
        layout and normalization itself (the ToTensor/Normalize code was
        commented out) -- confirm against transweather_model_extra.
    """
    height, width = input_img.shape[:2]
    new_height, new_width = _compute_target_size(height, width, max_side, multiple)
    # cv2.resize takes dsize as (width, height), not (height, width).
    resized = cv2.resize(input_img, (new_width, new_height), interpolation=cv2.INTER_AREA)
    return torch.from_numpy(resized.astype(np.float32))
# Run settings: one image per batch, checkpoint directory name.
val_batch_size = 1
exp_name = "ckpt"

# Reproducibility: seed NumPy, PyTorch (CPU) and Python's `random` with the
# same value.
seed = 19
if seed is not None:
    for _seed_rng in (np.random.seed, torch.manual_seed, random.seed):
        _seed_rng(seed)
    del _seed_rng
    # torch.cuda.manual_seed(seed)
    print('Seed:\t{}'.format(seed))
video_path = "/home/ao/tmp/clip_videos/h97cam_water_video.mp4"
output_video_path = "./videos/h97cam_water_lambda00_video.avi"
model_path = "ckpt/best_psnr+lambda0.01"

video = cv2.VideoCapture(video_path)
# Fail early with a clear message instead of a confusing error further down
# when the first frame cannot be read.
if not video.isOpened():
    print("[ERROR] cannot open video:", video_path)
    sys.exit(1)
# video_saving = cv2.VideoWriter(output_video_path,cv2.VideoWriter_fourcc('M','J','P','G'),30,(2040,720))

device_ids = list(range(torch.cuda.device_count()))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = Transweather()
# Wrap before loading, presumably because the checkpoint was saved from a
# DataParallel model (keys prefixed with "module.") -- confirm.
net = nn.DataParallel(net)
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
net.load_state_dict(torch.load(model_path, map_location=device))
net.to(device)
print("====> model ", model_path, " loaded")
net.eval()
# Unwrap to the plain Transweather module for quantization/export.
net = net.module
# Dynamic int8 quantization of the Linear layers (weights quantized ahead of
# time, activations quantized on the fly).
backend = "qnnpack"
# NOTE(review): get_default_qconfig returns a *static* qconfig. When
# quantize_dynamic is given a set of types it applies its own dynamic qconfig
# to those types, so this assignment is presumably not needed -- confirm
# before relying on it.
net.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
# Dynamically quantized kernels only run on CPU, and the sample input used for
# export stays on CPU. When a GPU is available the float model was moved to
# CUDA above, so bring it back first.
net = net.cpu()
net_int8 = torch.quantization.quantize_dynamic(
    net,  # the original float model
    {torch.nn.Linear},  # the set of layer types to dynamically quantize
    dtype=torch.qint8,  # target dtype for the quantized weights
)
# Grab the first frame of the video as the example input for tracing/export.
sample_image = None
ret, frame = video.read()
if ret:
    # Downscale to 960x540; preprocessImage then pads this up to a multiple
    # of 16 (544x960).
    sample_image = cv2.resize(frame, (960, 540))

if sample_image is not None:
    print("[INFO] image shape: ", sample_image.shape)
else:
    print("[INFO] image is None")
    sys.exit(1)

input_img = cv2.cvtColor(sample_image, cv2.COLOR_BGR2RGB)
input_img = preprocessImage(input_img)
input_img = input_img.unsqueeze(0)  # add a batch dimension
# input stays on CPU: the dynamically quantized model only runs there.

# Smoke-test forward pass of the quantized model; no gradients are needed.
with torch.no_grad():
    res = net_int8(input_img)

# NOTE(review): ONNX export of dynamically quantized Linear layers may not be
# supported by every torch/opset version -- verify the exported graph.
torch.onnx.export(net_int8, input_img, "./ckpt/transweather_quant.onnx", verbose=True, input_names=['input'], output_names=['output'])
print("[FINISHED] quantized onnx model exported")
# ### NOTE: start evaluation ###
# with torch.no_grad():
# while True:
# ret, frame = video.read()
# if not ret:
# break
# frame = frame[:, 180:1200, :]
# # pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
# input_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# input_img = preprocessImage(input_img)
# input_img = input_img.to(device)
# input_img = input_img.unsqueeze(0)
# print("[INFO] ", input_img.shape)
# pred_image = net(input_img)
# pred_image_cpu = pred_image[0].permute(1,2,0).cpu().numpy()
# pred_image_cpu = img_as_ubyte(pred_image_cpu)
# pred_image_cpu = cv2.resize(pred_image_cpu, (frame.shape[1],frame.shape[0]))
# image = np.concatenate((frame, pred_image_cpu[..., ::-1]), axis=1)
# # video_saving.write(image)
# cv2.imshow("image", image)
# if cv2.waitKey(1) == 27:
# break