forked from xuebinqin/U-2-Net
-
Notifications
You must be signed in to change notification settings - Fork 14
/
demo.py
104 lines (77 loc) · 2.68 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
from collections import defaultdict
from glob import glob
import PIL
import numpy as np
import streamlit as st
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, CenterCrop
import torchvision.transforms.functional as F
from pymatting import *
from lib import U2NET_full
from lib.utils.oom import free_up_memory
def create_ui(samples):
    """Render the sidebar controls and return the user's selections.

    Parameters
    ----------
    samples : list[str]
        Paths of sample images offered in the sample selector.

    Returns
    -------
    tuple
        ``(model_select, sample_select)`` — the chosen model name and
        the chosen sample image path.
    """
    st.sidebar.title('u2net - segmentation')

    st.sidebar.title('Select a model')
    model_choices = ['u2net_human_seg', 'checkpoint.pth']
    # index=1 makes 'checkpoint.pth' the default selection.
    model_select = st.sidebar.selectbox('', model_choices, index=1)

    st.sidebar.title('Select a sample')
    sample_select = st.sidebar.selectbox('', samples)

    return model_select, sample_select
def load_samples(folder_path='./dataset/demo'):
    """Collect demo image paths from *folder_path*.

    Parameters
    ----------
    folder_path : str
        Directory that holds the demo images.

    Returns
    -------
    list[str]
        Sorted paths of every ``*.png`` file in the directory.

    Raises
    ------
    NotADirectoryError
        If *folder_path* does not exist or is not a directory.
    """
    # Raise explicitly: `assert` is stripped under `python -O`, which would
    # let a missing directory fall through to a confusing empty result.
    if not os.path.isdir(folder_path):
        raise NotADirectoryError(f'Unable to open {folder_path}')
    # Sort for a deterministic sidebar order — glob() order is
    # filesystem-dependent. (The original also used a needless f-string.)
    return sorted(glob(os.path.join(folder_path, '*.png')))
# NOTE(review): the original `device = 'cuda'` here was dead code — the
# variable is unconditionally reassigned to 'cpu' further down, before its
# first use, so the assignment is dropped.
samples = load_samples()
model_select, sample_select = create_ui(samples)
def square_pad(image, fill=255):
    """Pad *image* with a constant border so the result is exactly square.

    Parameters
    ----------
    image : PIL.Image.Image
        Input image (``image.size`` is (width, height)).
    fill : int
        Constant value for the border pixels (default 255).

    Returns
    -------
    The padded, square image, produced by torchvision's functional pad.
    """
    w, h = image.size
    side = max(w, h)
    # Split the padding across both edges and give the remainder to the
    # right/bottom edge. The original used int((side - w) / 2) for BOTH
    # sides, which loses one pixel whenever the difference is odd, so the
    # output was not actually square.
    left = (side - w) // 2
    top = (side - h) // 2
    right = side - w - left
    bottom = side - h - top
    return F.pad(image, (left, top, right, bottom), fill, 'constant')
def get_transform():
    """Build the preprocessing pipeline applied before inference.

    Returns a torchvision ``Compose`` that converts a PIL image to a
    tensor and normalizes each channel as ``(x - 0.5) / 0.5``, mapping
    pixel values into [-1, 1].
    """
    # transforms.append(Resize(440)) # TBD: keep aspect ratio
    return Compose([
        ToTensor(),
        Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
    ])
# Inference runs on CPU for this demo.
device = 'cpu'

# A checkpoint file is either a raw state dict, or a dict that wraps the
# state dict under the 'model' key — accept both layouts.
checkpoint = torch.load(f'./checkpoints/{model_select}.pth', map_location=device)
model = U2NET_full().to(device=device)
state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
model.load_state_dict(state_dict)
# Load the selected sample, pad it to a square with a black (0) border,
# and resize to the network's 448x448 input resolution.
image = Image.open(sample_select).convert('RGB')
image = square_pad(image, 0)
# Image.ANTIALIAS was deprecated and removed in Pillow 10; Image.LANCZOS
# is the same filter under its proper name and works on old Pillow too.
image = image.resize((448, 448), Image.LANCZOS)
st.image(image, width=800)

transforms = get_transform()
# Run the forward pass without autograd bookkeeping (no backward pass is
# needed, and skipping gradient tracking saves memory).
model.eval()
with torch.no_grad():
    x = transforms(image)                      # PIL image -> normalized tensor
    x = x.to(device=device).unsqueeze(dim=0)   # add batch dimension
    # The model returns multiple outputs; only the first is used here —
    # presumably the fused prediction map (confirm against lib.U2NET_full).
    y_hat, _ = model(x)

    # Scale the prediction to [0, 255] for display as a grayscale image.
    # NOTE(review): assumes y_hat values lie in [0, 1] — confirm the model
    # applies a sigmoid to its outputs.
    alpha_image = y_hat.mul(255)
    alpha_image = Image.fromarray(alpha_image.squeeze().cpu().detach().numpy()).convert('L')
    st.image(alpha_image, width=800)

    # Composite the segmented subject over a solid green background using
    # alpha matting (pymatting). Background values are in [0, 1] floats,
    # matching `image`, which is rescaled to [0, 1] below.
    image = np.asarray(image)
    background = np.zeros(image.shape)
    background[:, :] = [0, 177 / 255, 64 / 255]

    # Use the prediction as the alpha matte (float values, not quantized).
    alpha = y_hat.squeeze().cpu().detach()
    alpha = np.asarray(alpha)
    # alpha = (alpha * 255).astype(np.uint8)

    image = image.astype(np.float32) / 255
    foreground = estimate_foreground_ml(
        image, alpha)  # , n_big_iterations=1, n_small_iterations=1, regularization=10e-10
    new_image = blend(foreground, background, alpha)
    st.image(new_image, width=800)

    # Drop the prediction tensor and release cached memory so repeated
    # Streamlit reruns don't accumulate allocations.
    del y_hat
    free_up_memory()