-
Notifications
You must be signed in to change notification settings - Fork 480
/
augment_data.py
82 lines (68 loc) · 2.57 KB
/
augment_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import imgaug
import numpy as np
from concern.config import State
from .data_process import DataProcess
from data.augmenter import AugmenterBuilder
import cv2
import math
class AugmentData(DataProcess):
    """Data-processing step that augments (or only resizes) ``data['image']``.

    Keyword configuration read in ``__init__``:
        augmenter_args: spec handed to ``AugmenterBuilder().build``.
            ``resize_image`` also reads ``augmenter_args[0][1]`` as a mapping
            with ``'height'`` and ``'width'`` keys (presumably the first entry
            is a resize augmenter spec -- TODO confirm against AugmenterBuilder).
        keep_ratio: when truthy, ``resize_image`` keeps the source aspect
            ratio instead of using the configured width.
        only_resize: when truthy, ``process`` only resizes the image (no
            random augmentation) and marks the sample ``is_training=False``.
    """

    # Declared with autoload=False; __init__ assigns it explicitly from kwargs.
    # NOTE(review): presumably this stops the config system from loading it
    # automatically -- confirm against concern.config.State.
    augmenter_args = State(autoload=False)

    def __init__(self, **kwargs):
        self.augmenter_args = kwargs.get('augmenter_args')
        self.keep_ratio = kwargs.get('keep_ratio')
        self.only_resize = kwargs.get('only_resize')
        self.augmenter = AugmenterBuilder().build(self.augmenter_args)

    def may_augment_annotation(self, aug, data, shape=None):
        """Hook for subclasses to transform annotations alongside the image.

        ``process`` invokes this as ``may_augment_annotation(aug, data, shape)``
        so the signature must accept ``shape``; it defaults to None so any
        two-argument caller keeps working. The base implementation is a no-op.

        Args:
            aug: the deterministic augmenter applied to the image.
            data: the sample dict being processed.
            shape: the image's ``.shape`` before augmentation.

        Returns:
            ``data`` unchanged.
        """
        return data

    def resize_image(self, image):
        """Resize ``image`` to the configured target size.

        Height and width come from ``self.augmenter_args[0][1]``. With
        ``keep_ratio`` set, the width is instead derived from the source
        aspect ratio at the target height and rounded up to a multiple of 32.

        Args:
            image: array whose first two ``shape`` entries are (height, width);
                2-D (single-channel) and 3-D images are both accepted.

        Returns:
            The resized image from ``cv2.resize``.
        """
        origin_height, origin_width = image.shape[:2]
        resize_shape = self.augmenter_args[0][1]
        # cv2.resize requires integer sizes; config values may arrive as floats.
        height = int(resize_shape['height'])
        width = int(resize_shape['width'])
        if self.keep_ratio:
            # Scale width with the aspect ratio, then round up to a multiple
            # of 32 (presumably a model stride requirement -- TODO confirm).
            width = math.ceil(origin_width * height / origin_height / 32) * 32
        # OpenCV's dsize argument is ordered (width, height).
        return cv2.resize(image, (width, height))

    def process(self, data):
        """Augment or resize ``data['image']`` and record sample metadata.

        When an augmenter is configured, a deterministic copy is drawn via
        ``to_deterministic()`` and passed, together with the pre-augmentation
        shape, to ``may_augment_annotation`` so subclasses can transform
        annotations with the same augmenter.

        Args:
            data: sample dict; must contain ``'image'``. ``'filename'`` or
                ``'data_id'`` is used as the filename when present.

        Returns:
            The same ``data`` dict, updated in place with ``'image'`` (when an
            augmenter exists), ``'filename'``, ``'shape'`` (original height,
            width) and ``'is_training'``.
        """
        image = data['image']
        shape = image.shape
        aug = None

        if self.augmenter:
            aug = self.augmenter.to_deterministic()
            if self.only_resize:
                data['image'] = self.resize_image(image)
            else:
                data['image'] = aug.augment_image(image)
            self.may_augment_annotation(aug, data, shape)

        filename = data.get('filename', data.get('data_id', ''))
        data.update(filename=filename, shape=shape[:2])
        # Resize-only mode is the non-training path.
        data['is_training'] = not self.only_resize
        return data
class AugmentDetectionData(AugmentData):
    """``AugmentData`` variant that also carries text-line polygon annotations.

    It converts ``data['lines']`` into ``data['polys']``, mapping polygon
    points through the same deterministic augmenter used on the image.
    """

    def may_augment_annotation(self, aug, data, shape):
        """Build ``data['polys']`` from ``data['lines']``.

        Each output entry holds ``'points'`` (list of (x, y) tuples),
        ``'ignore'`` (True when the line text is ``'###'``) and ``'text'``.
        In ``only_resize`` mode the points are copied as-is. They are not
        scaled, even though ``process`` resizes the image; presumably this is
        intended so evaluation happens in original coordinates (TODO confirm).
        Otherwise each polygon is passed through ``may_augment_poly``.

        Args:
            aug: deterministic augmenter, or None.
            data: sample dict containing ``'lines'``; each line is a mapping
                with ``'poly'`` (sequence of points) and ``'text'``.
            shape: image shape before augmentation.

        Returns:
            ``data``, updated in place; returned untouched when ``aug`` is None.
        """
        if aug is None:
            return data

        def to_points(raw_poly):
            # Resize-only keeps coordinates; otherwise apply the augmentation.
            if self.only_resize:
                return [(pt[0], pt[1]) for pt in raw_poly]
            return self.may_augment_poly(aug, shape, raw_poly)

        data['polys'] = [
            {
                'points': to_points(entry['poly']),
                'ignore': entry['text'] == '###',
                'text': entry['text'],
            }
            for entry in data['lines']
        ]
        return data

    def may_augment_poly(self, aug, img_shape, poly):
        """Map each point of ``poly`` through ``aug``'s keypoint transform.

        Args:
            aug: deterministic imgaug augmenter.
            img_shape: shape of the image the points belong to, before
                augmentation.
            poly: sequence of points; only the first two coordinates of each
                point are used.

        Returns:
            A list of (x, y) tuples for the augmented points, in input order.
        """
        source_points = [imgaug.Keypoint(pt[0], pt[1]) for pt in poly]
        on_image = imgaug.KeypointsOnImage(source_points, shape=img_shape)
        augmented = aug.augment_keypoints([on_image])[0]
        return [(kp.x, kp.y) for kp in augmented.keypoints]