-
Notifications
You must be signed in to change notification settings - Fork 0
/
transform.py
90 lines (68 loc) · 2.44 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import numpy as np
from skimage.transform import rescale, rotate
from torchvision.transforms import Compose
def transforms(scale=None, angle=None, flip_prob=None):
    """Compose the enabled augmentations into a single callable.

    Each argument switches on one augmentation when it is not ``None``.
    Enabled steps always run in the fixed order
    ``Scale -> Rotate -> HorizontalFlip``.

    Args:
        scale: forwarded to ``Scale``; ``None`` disables random rescaling.
        angle: forwarded to ``Rotate``; ``None`` disables random rotation.
        flip_prob: forwarded to ``HorizontalFlip``; ``None`` disables flipping.

    Returns:
        A torchvision ``Compose`` wrapping the enabled steps (an empty
        ``Compose`` when every argument is ``None``).
    """
    # (augmentation class, its configuration) in application order.
    candidates = (
        (Scale, scale),
        (Rotate, angle),
        (HorizontalFlip, flip_prob),
    )
    steps = [augment(param) for augment, param in candidates if param is not None]
    return Compose(steps)
class Scale(object):
    """Randomly zoom an (image, mask) pair while keeping the input's spatial size.

    Each call draws a zoom factor uniformly from ``[1 - scale, 1 + scale]``
    and rescales both arrays by it along their first two axes. The result is
    then brought back to the original height and width: zero-padded
    symmetrically when it shrank, center-cropped when it grew. Height and
    width are handled independently, so non-square inputs keep their shape.
    """

    def __init__(self, scale):
        # Half-width of the uniform zoom-factor range around 1.0.
        self.scale = scale

    def __call__(self, sample):
        """Apply one random zoom to ``sample``.

        Args:
            sample: ``(image, mask)`` pair. Both are rescaled with
                ``channel_axis=-1``. NOTE(review): this assumes channel-last
                arrays whose first two axes are spatial and equal between
                image and mask; confirm against the dataset.

        Returns:
            ``(image, mask)`` whose first two axes match the input image's.
        """
        image, mask = sample
        # Remember the original spatial extent per axis (not just axis 0) so
        # non-square inputs are restored to their own height and width.
        target_shape = image.shape[:2]
        factor = np.random.uniform(low=1.0 - self.scale, high=1.0 + self.scale)
        image = rescale(
            image,
            (factor, factor),
            channel_axis=-1,
            preserve_range=True,
            mode="constant",
            anti_aliasing=False,
        )
        # order=0 (nearest neighbour) keeps mask values from being
        # interpolated into values that were not present in the input.
        mask = rescale(
            mask,
            (factor, factor),
            order=0,
            channel_axis=-1,
            preserve_range=True,
            mode="constant",
            anti_aliasing=False,
        )
        image = self._fit(image, target_shape)
        mask = self._fit(mask, target_shape)
        return image, mask

    @staticmethod
    def _fit(array, target_shape):
        """Center-pad (with zeros) or center-crop the first two axes to ``target_shape``.

        Trailing axes (e.g. channels) are left untouched. When an axis needs an
        odd amount of padding, the extra row/column goes at the end, matching
        a floor/ceil split of the difference.
        """
        pad_width = []
        crop = []
        for current, target in zip(array.shape[:2], target_shape):
            if current < target:
                missing = target - current
                pad_width.append((missing // 2, missing - missing // 2))
                crop.append(slice(None))
            else:
                start = (current - target) // 2
                pad_width.append((0, 0))
                crop.append(slice(start, start + target))
        # No padding on any non-spatial trailing axes.
        pad_width.extend([(0, 0)] * (array.ndim - 2))
        array = array[tuple(crop)]
        if any(before or after for before, after in pad_width):
            array = np.pad(array, pad_width, mode="constant", constant_values=0)
        return array
class Rotate:
    """Rotate an (image, mask) pair by one shared random angle.

    Each call draws an angle uniformly from ``[-angle, angle]`` and passes it
    to skimage's ``rotate``, which interprets it in degrees. The output keeps
    the input shape (``resize=False``), and uncovered regions are filled
    using ``mode="constant"``.
    """

    def __init__(self, angle):
        # Maximum absolute rotation drawn per call.
        self.angle = angle

    def __call__(self, sample):
        image, mask = sample
        limit = self.angle
        theta = np.random.uniform(low=-limit, high=limit)
        # Options common to both arrays so they stay geometrically aligned.
        common = dict(resize=False, preserve_range=True, mode="constant")
        rotated_image = rotate(image, theta, **common)
        # Nearest-neighbour (order=0) for the mask so its values are not blended.
        rotated_mask = rotate(mask, theta, order=0, **common)
        return rotated_image, rotated_mask
class HorizontalFlip:
    """Mirror an (image, mask) pair left-to-right with probability ``flip_prob``.

    Both arrays share one random decision, so they stay aligned.
    """

    def __init__(self, flip_prob):
        # Chance that a call flips the pair; a draw above it leaves the pair as-is.
        self.flip_prob = flip_prob

    def __call__(self, sample):
        image, mask = sample
        keep_unflipped = np.random.rand() > self.flip_prob
        if keep_unflipped:
            # Hand back the very same objects, untouched.
            return image, mask
        # fliplr yields a reversed-stride view; copying gives contiguous
        # arrays (presumably for consumers that reject negative strides).
        return tuple(np.fliplr(part).copy() for part in (image, mask))