-
Notifications
You must be signed in to change notification settings - Fork 60
/
basnet.py
87 lines (63 loc) · 2.03 KB
/
basnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import sys
sys.path.insert(0, 'BASNet')
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
from data_loader import RescaleT
from data_loader import ToTensorLab
from model import BASNet
# Build the BASNet network once at import time and load its pretrained weights.
model_dir = './BASNet/saved_models/basnet_bsi/basnet.pth'
print("Loading BASNet...")
net = BASNet(3, 1)
# Deserialize onto the CPU first: a checkpoint whose tensors were saved from a
# CUDA device cannot otherwise be loaded on a host without CUDA. The model is
# moved to the GPU below when one is available.
net.load_state_dict(torch.load(model_dir, map_location='cpu'))
if torch.cuda.is_available():
    net.cuda()
# Inference only: switch layers such as dropout/batch-norm to eval behaviour.
net.eval()
def normPRED(d):
    """Min-max normalise tensor ``d`` into the range [0, 1].

    Args:
        d: tensor of predictions (any shape).

    Returns:
        Tensor of the same shape as ``d`` computed as
        ``(d - min(d)) / (max(d) - min(d))``. When every element of ``d`` is
        equal the range is zero; instead of dividing by zero (which would
        produce NaN everywhere) the result is all zeros.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    # A constant input has zero range. Dividing by 1 instead keeps the same
    # arithmetic path (and result dtype) while yielding zeros, not NaN.
    if span == 0:
        span = torch.ones_like(span)
    return (d - mi) / span
def preprocess(image):
    """Apply BASNet's test-time transforms to a single image.

    The transforms imported from BASNet's ``data_loader`` are applied to a
    ``{'image': ..., 'label': ...}`` sample dict. No ground-truth mask exists
    at inference time, so an all-zero float64 placeholder label with the
    image's height and width is supplied.

    Args:
        image: NumPy image array, 2-D ``(H, W)`` or 3-D ``(H, W, C)``. A 2-D
            image is given a trailing channel axis, becoming ``(H, W, 1)``.

    Returns:
        The sample produced by
        ``transforms.Compose([RescaleT(256), ToTensorLab(flag=0)])`` applied
        to the image/label dict.
    """
    # Only the spatial size matters for the dummy label, so allocate (H, W)
    # directly instead of a full-size array from which one channel is sliced.
    label = np.zeros(image.shape[0:2])
    if 3 == len(image.shape):
        label = label[:, :, np.newaxis]
    elif 2 == len(image.shape):
        image = image[:, :, np.newaxis]
        label = label[:, :, np.newaxis]
    # Named `pipeline` so it does not shadow the module-level
    # `from skimage import transform` import.
    pipeline = transforms.Compose([RescaleT(256), ToTensorLab(flag=0)])
    return pipeline({'image': image, 'label': label})
def run(img):
    """Return BASNet's saliency mask for ``img`` as an RGB PIL image.

    Args:
        img: image array passed to ``preprocess`` (2-D, or 3-D with the
            channel axis last).

    Returns:
        PIL.Image.Image in mode 'RGB' whose grey level is channel 0 of the
        network's first output ``d1``, min-max normalised by ``normPRED``
        and scaled by 255. The mask keeps the network's output resolution
        (the input was rescaled by ``RescaleT(256)``); it is not resized
        back to the size of ``img``.
    """
    # Release cached, unused GPU memory left over from earlier calls.
    torch.cuda.empty_cache()
    sample = preprocess(img)
    # Add a leading batch dimension and force float32 (torch.FloatTensor).
    inputs_test = sample['image'].unsqueeze(0).type(torch.FloatTensor)
    if torch.cuda.is_available():
        inputs_test = inputs_test.cuda()

    # Inference only: no_grad stops autograd from recording the forward pass,
    # so intermediate activations are not retained for a backward pass.
    with torch.no_grad():
        d1, d2, d3, d4, d5, d6, d7, d8 = net(inputs_test)
        # Normalization: channel 0 of the first output, rescaled to [0, 1].
        pred = d1[:, 0, :, :]
        predict = normPRED(pred)

    # Convert to PIL Image: drop the batch dim, move to host memory.
    predict_np = predict.squeeze().cpu().numpy()
    im = Image.fromarray(predict_np * 255).convert('RGB')
    # Cleanup.
    del d1, d2, d3, d4, d5, d6, d7, d8
    return im