Prepare_image.py
import numpy as np
from torchvision import transforms
from PIL import Image
import torch


class Image_load(object):
    """Adaptively resize an image, then crop it into overlapping square patches."""

    def __init__(self, size, stride, interpolation=Image.BILINEAR):
        assert isinstance(size, int)
        self.size = size  # target length of the shorter side after resizing
        self.stride = stride  # note: used as the patch side length, not the sliding stride
        self.interpolation = interpolation

    def __call__(self, img):
        image = self.adaptive_resize(img)
        return self.generate_patches(image, input_size=self.stride)
    def adaptive_resize(self, img):
        """
        Args:
            img (PIL Image): Image to be scaled.
        Returns:
            Tensor: Rescaled image as a (C, H, W) tensor.
        """
        w, h = img.size  # PIL reports (width, height)
        if h < self.size or w < self.size:
            # Already smaller than the target size: convert without resizing,
            # so generate_patches always receives a tensor.
            return transforms.ToTensor()(img)
        else:
            return transforms.ToTensor()(transforms.Resize(self.size, self.interpolation)(img))
    def to_numpy(self, image):
        # Convert a (C, H, W) tensor to an (H, W, C) numpy array.
        p = image.numpy()
        return p.transpose((1, 2, 0))
    def generate_patches(self, image, input_size, dtype=np.float32):
        img = self.to_numpy(image)
        img = img.astype(dtype=dtype)
        if img.ndim == 2:
            # Promote 2-D grayscale to (H, W, 1) so the indexing below is uniform.
            img = img[:, :, None]
        H, W, ch = img.shape
        if ch == 1:
            # Replicate the single channel so every patch has 3 channels.
            img = np.repeat(img, 3, axis=2)
        # Slide a square window of side input_size with 50% overlap.
        stride = int(input_size / 2)
        hIdxMax = H - input_size
        wIdxMax = W - input_size
        hIdx = [i * stride for i in range(int(hIdxMax / stride) + 1)]
        if H - input_size != hIdx[-1]:
            # Add a final window flush with the bottom edge.
            hIdx.append(H - input_size)
        wIdx = [i * stride for i in range(int(wIdxMax / stride) + 1)]
        if W - input_size != wIdx[-1]:
            # Add a final window flush with the right edge.
            wIdx.append(W - input_size)
        patches_numpy = [img[hId:hId + input_size, wId:wId + input_size, :]
                         for hId in hIdx
                         for wId in wIdx]
        patches_tensor = [transforms.ToTensor()(p) for p in patches_numpy]
        patches_tensor = torch.stack(patches_tensor, 0).contiguous()
        # Squeeze collapses the batch dimension only when a single patch was produced.
        return patches_tensor.squeeze(0)
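

if __name__ == '__main__':
    # Minimal usage sketch. The size/stride values and the input path are
    # illustrative assumptions, not taken from this file.
    loader = Image_load(size=512, stride=224)
    img = Image.open('example.jpg').convert('RGB')  # hypothetical input path
    patches = loader(img)
    print(patches.shape)  # expected: (num_patches, 3, 224, 224)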