-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_unet_long.py
executable file
·134 lines (112 loc) · 3.97 KB
/
train_unet_long.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import time
import datetime
from unet_model import *
from gen_patches import *
import os
import os.path
import numpy as np
import tifffile as tiff
from keras.callbacks import CSVLogger
from keras.callbacks import TensorBoard
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
def normalize(img):
    """Linearly rescale an image's values into the range [-1, 1].

    Args:
        img: numpy array of raw pixel/band values.

    Returns:
        A float array of the same shape with values in [-1, 1].
        A constant image (max == min) maps to all zeros instead of
        dividing by zero (the original produced NaN/inf in that case).
    """
    # Avoid shadowing the builtins `min`/`max` with local names.
    lo = img.min()
    hi = img.max()
    span = hi - lo
    if span == 0:
        # Degenerate image: every value equal. Map to the midpoint (0.0).
        return np.zeros_like(img, dtype=float)
    return 2.0 * (img - lo) / span - 1.0
# --- Training configuration ---
N_BANDS = 8  # input channels per image (multi-band TIFF)
N_CLASSES = 5  # buildings, roads, trees, crops and water
CLASS_WEIGHTS = [0.2, 0.3, 0.1, 0.1, 0.3]  # per-class loss weights, passed to unet_model
N_EPOCHS = 100  # maximum training epochs
UPCONV = True  # forwarded to unet_model's `upconv` flag — presumably selects
               # transposed-conv vs. upsampling decoder; confirm in unet_model
PATCH_SZ = 160  # should divide by 16
BATCH_SIZE = 128
TRAIN_SZ = 1000  # train size
VAL_SZ = 200  # validation size
# Log the CPU configuration of the machine running this training job.
print('Some information about environment:')
os.system('lscpu')
def get_model():
    """Construct the U-Net for training, wired to the module-level
    configuration constants (classes, patch size, bands, weights)."""
    model_config = dict(
        n_channels=N_BANDS,
        upconv=UPCONV,
        class_weights=CLASS_WEIGHTS,
    )
    return unet_model(N_CLASSES, PATCH_SZ, **model_config)
# Checkpoint location; the directory is created on first run.
weights_path = "weights"
# exist_ok=True is race-free, unlike the check-then-create pattern
# (os.path.exists followed by os.makedirs).
os.makedirs(weights_path, exist_ok=True)
weights_path += "/unet_weights.hdf5"
# All available scene ids: "01" through "24".
trainIds = [str(i).zfill(2) for i in range(1, 25)]
class CustomCSVLoggerCallback(CSVLogger):
    """CSVLogger variant that also records wall-clock seconds per epoch,
    written to the CSV under the ``elapsed_time`` column."""

    def on_epoch_begin(self, epoch, logs=None):
        # Stamp the epoch start so on_epoch_end can compute the duration.
        self._epoch_t0 = time.time()
        super().on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs["elapsed_time"] = time.time() - self._epoch_t0
        super().on_epoch_end(epoch, logs)
if __name__ == "__main__":
    # Per-image-id dicts of (H, W, C) arrays: inputs (X) and ground-truth
    # masks (Y), split into train / validation partitions row-wise below.
    X_DICT_TRAIN = dict()
    Y_DICT_TRAIN = dict()
    X_DICT_VALIDATION = dict()
    Y_DICT_VALIDATION = dict()
    print("Reading images")
    for img_id in trainIds:
        # TIFFs are stored band-first; transpose to (H, W, bands), then
        # rescale inputs to [-1, 1] via normalize().
        img_m = normalize(
            tiff.imread("./data/mband/{}.tif".format(img_id)).transpose([1, 2, 0])
        )
        # Dividing by 255 maps mask values into [0, 1] — presumably the
        # ground-truth channels are 0/255 per class; confirm against data.
        mask = (
            tiff.imread("./data/gt_mband/{}.tif".format(img_id)).transpose([1, 2, 0])
            / 255
        )
        train_xsz = int(
            3 / 4 * img_m.shape[0]
        )  # use 75% of image as train and 25% for validation
        # Top 75% of rows -> training; bottom 25% -> validation.
        X_DICT_TRAIN[img_id] = img_m[:train_xsz, :, :]
        Y_DICT_TRAIN[img_id] = mask[:train_xsz, :, :]
        X_DICT_VALIDATION[img_id] = img_m[train_xsz:, :, :]
        Y_DICT_VALIDATION[img_id] = mask[train_xsz:, :, :]
        print(img_id + " read")
    print("Images were read")
    def train_net():
        """Sample random patches, build (or resume) the U-Net, and train.

        Best-val_loss weights are checkpointed to `weights_path`; per-epoch
        metrics (plus elapsed time) go to a timestamped CSV and TensorBoard.

        Returns:
            The trained Keras model.
        """
        print("start train net")
        x_train, y_train = get_patches(
            X_DICT_TRAIN, Y_DICT_TRAIN, n_patches=TRAIN_SZ, sz=PATCH_SZ
        )
        x_val, y_val = get_patches(
            X_DICT_VALIDATION, Y_DICT_VALIDATION, n_patches=VAL_SZ, sz=PATCH_SZ
        )
        model = get_model()
        # Resume from an existing checkpoint so repeated runs continue training.
        if os.path.isfile(weights_path):
            model.load_weights(weights_path)
        # model_checkpoint = ModelCheckpoint(weights_path, monitor='val_loss', save_weights_only=True, save_best_only=True)
        # early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1, mode='auto')
        # reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.1, patience=5, min_lr=0.00001)
        model_checkpoint = ModelCheckpoint(
            weights_path, monitor="val_loss", save_best_only=True
        )
        now = datetime.datetime.now()
        # NOTE(review): the ':' between hour and minute is invalid in Windows
        # filenames — confirm this only ever runs on Linux/macOS.
        csvfpath = (
            f"log_unet_{now.year}-{now.month}-{now.day}-{now.hour}:{now.minute}.csv"
        )
        # append=True so resumed runs extend the same CSV rather than truncate it.
        csv_logger = CustomCSVLoggerCallback(csvfpath, append=True, separator=";")
        tensorboard = TensorBoard(
            log_dir="./tensorboard_unet/", write_graph=True, write_images=True
        )
        model.fit(
            x_train,
            y_train,
            batch_size=BATCH_SIZE,
            epochs=N_EPOCHS,
            verbose=1,
            shuffle=True,
            callbacks=[model_checkpoint, csv_logger, tensorboard],
            validation_data=(x_val, y_val),
        )
        return model
    train_net()