-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
58 lines (47 loc) · 1.65 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import torch
from data_test import *
from ResNet_FillAndHollow import *
from torch.utils.data import DataLoader
from torchvision import transforms
# Model identifier and I/O locations (paths are relative to the working directory).
model_name = "resnet_fill_hollow"
test_path = "./dataset/test"
model_path = "./checkpoint/"
result_path = "./result/"

# Test datasets, both rooted at ``test_path`` and distinguished by ``name``.
dbd_dataset_dut = ImageData(name='dut', root_path=test_path)
dbd_dataset_xu = ImageData(name='xu', root_path=test_path)
# Build the network and load the trained weights from the checkpoint.
# Use the GPU when one is present, otherwise fall back to the CPU instead of
# crashing in an unconditional ``.cuda()`` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = resnet_dbd_edge().to(device)

# map_location remaps tensors that were saved on a GPU, so the checkpoint
# also loads on a CPU-only machine.
net = torch.load(os.path.join(model_path, 'checkpoint.pth'), map_location=device)

# Keys may carry a leading 'module.' (typically added when the model was
# saved while wrapped in nn.DataParallel / DistributedDataParallel).
# Strip only that leading prefix: splitting on every occurrence of
# 'module.' would also cut inside names such as 'layer1.submodule.weight'
# and produce keys that no longer match the model.
prefix = 'module.'
model_new = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in net.items()
}
model.load_state_dict(model_new)
model.eval()  # switch layers such as dropout / batch norm to inference behaviour
# Run the model over the test dataset(s) and save each predicted map as
# <result_path>/result/<1-based sample index>.bmp.
#
# Datasets to evaluate. Only 'xu' is enabled; append dbd_dataset_dut to run
# 'dut' as well.
# NOTE(review): all datasets share one output directory and numbering
# restarts at 1.bmp, so enabling a second dataset would overwrite the first
# one's images. Give each dataset its own sub-directory before doing that.
eval_datasets = [dbd_dataset_xu]

# Send inputs to whatever device the model's weights live on (CPU or GPU),
# so inputs and weights can never end up on different devices.
device = next(model.parameters()).device

save_path = os.path.join(result_path, 'result')
os.makedirs(save_path, exist_ok=True)  # avoids the exists()/makedirs() race

to_pil = transforms.ToPILImage()

# Inference only: no_grad skips building the autograd graph, saving memory
# and time. (torch.autograd.Variable is a deprecated no-op wrapper and is not
# needed; plain tensors are passed to the model.)
with torch.no_grad():
    for dataset in eval_datasets:
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
        for i, sample_batch in enumerate(dataloader):
            # Only the image is needed for prediction; the 'dbd' target
            # entry is not used here, so it is not transferred to the device.
            images_batch = sample_batch[0]['image'].to(device)

            # The model returns a pair; only the first output is saved.
            output1_dbd1, _ = model(images_batch)

            # Take the single sample of this batch (batch_size=1), drop
            # size-1 dims and move it to the CPU for PIL conversion.
            # NOTE(review): assumes values lie in [0, 1]. ToPILImage scales
            # float tensors by 255 before casting to uint8. Confirm against
            # the model's final activation.
            output_dbd1 = torch.squeeze(output1_dbd1[0, :, :, :]).cpu()
            img = to_pil(output_dbd1)
            img.save(os.path.join(save_path, f'{i + 1}.bmp'))