-
Notifications
You must be signed in to change notification settings - Fork 69
/
train.py
112 lines (92 loc) · 3.03 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.callbacks import *
from tensorflow.keras.optimizers import Adam, Nadam
from tensorflow.keras.metrics import *
from glob import glob
from sklearn.model_selection import train_test_split
from model import build_model
from utils import *
from metrics import *
def read_image(x):
    """Load a 3-channel image from a bytes path and normalize it to [0, 1].

    Args:
        x: image file path as ``bytes`` (as delivered by ``tf.numpy_function``).

    Returns:
        ``np.float32`` array of shape (H, W, 3) with values in [0, 1].

    Raises:
        FileNotFoundError: if the file is missing or cannot be decoded.
    """
    x = x.decode()
    image = cv2.imread(x, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread silently returns None on a missing/unreadable file;
        # fail loudly here instead of crashing later in np.median.
        raise FileNotFoundError(f"Unable to read image: {x}")
    # Shift the median intensity to 127 — a simple illumination normalization.
    image = np.clip(image - np.median(image)+127, 0, 255)
    image = image/255.0
    image = image.astype(np.float32)
    return image
def read_mask(y):
    """Load a grayscale mask from a bytes path, scale to [0, 1], add a channel axis.

    Args:
        y: mask file path as ``bytes`` (as delivered by ``tf.numpy_function``).

    Returns:
        ``np.float32`` array of shape (H, W, 1) with values in [0, 1].

    Raises:
        FileNotFoundError: if the file is missing or cannot be decoded.
    """
    y = y.decode()
    mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # Consistent with read_image: surface unreadable files immediately.
        raise FileNotFoundError(f"Unable to read mask: {y}")
    mask = mask/255.0
    mask = mask.astype(np.float32)
    # (H, W) -> (H, W, 1) so the mask stacks cleanly along the channel axis.
    mask = np.expand_dims(mask, axis=-1)
    return mask
def parse_data(x, y):
    """Graph-compatible wrapper: load one (image, mask) pair via numpy readers.

    The mask channel is duplicated so the target tensor has two channels
    (presumably one per model output head — TODO confirm against the model).

    Args:
        x: scalar string tensor, image path.
        y: scalar string tensor, mask path.

    Returns:
        Tuple of float32 tensors with static shapes
        (384, 512, 3) and (384, 512, 2).
    """

    def _load_pair(img_path, mask_path):
        # Runs eagerly in numpy-land inside tf.numpy_function.
        img = read_image(img_path)
        mask = read_mask(mask_path)
        target = np.concatenate([mask, mask], axis=-1)
        return img, target

    image, target = tf.numpy_function(_load_pair, [x, y], [tf.float32, tf.float32])
    # numpy_function erases shape information; restore it for downstream layers.
    image.set_shape([384, 512, 3])
    target.set_shape([384, 512, 2])
    return image, target
def tf_dataset(x, y, batch=8):
    """Build a shuffled, infinitely repeating, batched tf.data input pipeline.

    Args:
        x: list of image file paths.
        y: list of mask file paths, aligned element-wise with ``x``.
        batch: batch size (default 8).

    Returns:
        A ``tf.data.Dataset`` yielding (image, mask) batches forever; the
        caller must bound iteration with ``steps_per_epoch``.
    """
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    dataset = dataset.shuffle(buffer_size=32)
    # Decode/preprocess on multiple threads instead of one.
    dataset = dataset.map(map_func=parse_data,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch)
    # Overlap preprocessing with training so the accelerator is not starved.
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset
if __name__ == "__main__":
    # Seed numpy and TensorFlow for reproducible runs.
    np.random.seed(42)
    tf.random.set_seed(42)

    # Output directory for checkpoints and logs.
    create_dir("files")

    train_path = "../1/new_data/train/"
    valid_path = "../1/new_data/valid/"

    ## Training file lists; sorted so image/mask pairs line up by filename.
    train_x = sorted(glob(os.path.join(train_path, "image", "*.jpg")))
    train_y = sorted(glob(os.path.join(train_path, "mask", "*.jpg")))

    ## Shuffle images and masks together so pairs stay aligned.
    train_x, train_y = shuffling(train_x, train_y)

    ## Validation file lists.
    valid_x = sorted(glob(os.path.join(valid_path, "image", "*.jpg")))
    valid_y = sorted(glob(os.path.join(valid_path, "mask", "*.jpg")))

    model_path = "files/model.h5"
    batch_size = 16
    epochs = 300
    lr = 1e-4
    shape = (384, 512, 3)

    model = build_model(shape)
    metrics = [
        dice_coef,
        iou,
        Recall(),
        Precision()
    ]

    train_dataset = tf_dataset(train_x, train_y, batch=batch_size)
    valid_dataset = tf_dataset(valid_x, valid_y, batch=batch_size)

    model.compile(loss=dice_loss, optimizer=Adam(lr), metrics=metrics)

    callbacks = [
        # FIX: keep only the best weights by validation loss. Without
        # save_best_only the checkpoint is overwritten every epoch, and
        # because EarlyStopping below uses restore_best_weights=False the
        # saved model.h5 would hold the *last* epoch's weights, not the best.
        ModelCheckpoint(model_path, monitor='val_loss', save_best_only=True),
        ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=20),
        CSVLogger("files/data.csv"),
        TensorBoard(),
        EarlyStopping(monitor='val_loss', patience=50, restore_best_weights=False)
    ]

    # ceil(len / batch) steps per epoch — required because the datasets
    # repeat() forever and would otherwise never end an epoch.
    train_steps = (len(train_x)//batch_size)
    valid_steps = (len(valid_x)//batch_size)
    if len(train_x) % batch_size != 0:
        train_steps += 1
    if len(valid_x) % batch_size != 0:
        valid_steps += 1

    model.fit(train_dataset,
        epochs=epochs,
        validation_data=valid_dataset,
        steps_per_epoch=train_steps,
        validation_steps=valid_steps,
        callbacks=callbacks,
        shuffle=False)  # shuffling already happens inside the tf.data pipeline