import tensorflow as tf
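# Allow TensorFlow to allocate GPU memory on demand instead of reserving it all at start-up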
try:
    for gpu in tf.config.experimental.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)
except Exception:
    pass
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from mltu.preprocessors import ImageReader
from mltu.transformers import ImageResizer, LabelIndexer, LabelPadding, ImageShowCV2
from mltu.augmentors import RandomBrightness, RandomRotate, RandomErodeDilate, RandomSharpen
from mltu.annotations.images import CVImage
from mltu.tensorflow.dataProvider import DataProvider
from mltu.tensorflow.losses import CTCloss
from mltu.tensorflow.callbacks import Model2onnx, TrainLogger
from mltu.tensorflow.metrics import CWERMetric
from model import train_model
from configs import ModelConfigs
import os
import tarfile
from tqdm import tqdm
from urllib.request import urlopen
from io import BytesIO
from zipfile import ZipFile
gpus = tf.config.experimental.list_physical_devices("GPU")
if gpus:
    print(tf.config.experimental.get_memory_growth(gpus[0]))
def download_and_unzip(url, extract_to="Datasets", chunk_size=1024*1024):
    http_response = urlopen(url)
    data = b""
    iterations = http_response.length // chunk_size + 1
    for _ in tqdm(range(iterations)):
        data += http_response.read(chunk_size)

    zipfile = ZipFile(BytesIO(data))
    zipfile.extractall(path=extract_to)
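# Download and extract the IAM Words dataset on the first run (skipped if it is already present)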
dataset_path = os.path.join("Datasets", "IAM_Words")
if not os.path.exists(dataset_path):
    download_and_unzip("https://git.io/J0fjL", extract_to="Datasets")

    file = tarfile.open(os.path.join(dataset_path, "words.tgz"))
    file.extractall(os.path.join(dataset_path, "words"))
dataset, vocab, max_len = [], set(), 0
# Preprocess the dataset by the specific IAM_Words dataset file structure
words = open(os.path.join(dataset_path, "words.txt"), "r").readlines()
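# Each non-comment entry in words.txt starts with the word id and ends with its transcription;
# field 1 flags whether the automatic word segmentation succeeded ("ok") or failed ("err")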
for line in tqdm(words):
    if line.startswith("#"):
        continue

    line_split = line.split(" ")
    if line_split[1] == "err":
        continue

    folder1 = line_split[0][:3]
    folder2 = "-".join(line_split[0].split("-")[:2])
    file_name = line_split[0] + ".png"
    label = line_split[-1].rstrip("\n")

    rel_path = os.path.join(dataset_path, "words", folder1, folder2, file_name)
    if not os.path.exists(rel_path):
        print(f"File not found: {rel_path}")
        continue

    dataset.append([rel_path, label])
    vocab.update(list(label))
    max_len = max(max_len, len(label))
# Create a ModelConfigs object to store model configurations
configs = ModelConfigs()
# Save vocab and maximum text length to configs
configs.vocab = "".join(vocab)
configs.max_text_length = max_len
configs.save()
# Create a data provider for the dataset
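# LabelIndexer maps each character to its index in the vocabulary; LabelPadding pads every
# label to max_text_length using the index just past the vocabulary (len(vocab)) as the pad token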
data_provider = DataProvider(
    dataset=dataset,
    skip_validation=True,
    batch_size=configs.batch_size,
    data_preprocessors=[ImageReader(CVImage)],
    transformers=[
        ImageResizer(configs.width, configs.height, keep_aspect_ratio=False),
        LabelIndexer(configs.vocab),
        LabelPadding(max_word_length=configs.max_text_length, padding_value=len(configs.vocab)),
    ],
)
# Split the dataset into training and validation sets
train_data_provider, val_data_provider = data_provider.split(split=0.9)
# Augment training data with random brightness, rotation and erode/dilate
train_data_provider.augmentors = [
    RandomBrightness(),
    RandomErodeDilate(),
    RandomSharpen(),
    RandomRotate(angle=10),
]
# Creating TensorFlow model architecture
model = train_model(
    input_dim=(configs.height, configs.width, 3),
    output_dim=len(configs.vocab),
)
# Compile the model and print summary
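# CWERMetric is given the same padding token used by LabelPadding above, so it can ignore
# padded label positions when computing the character/word error rates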
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=configs.learning_rate),
    loss=CTCloss(),
    metrics=[CWERMetric(padding_token=len(configs.vocab))],
)
model.summary(line_length=110)
# Define callbacks
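# "val_CER" is the validation character error rate reported by CWERMetric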
earlystopper = EarlyStopping(monitor="val_CER", patience=20, verbose=1)
checkpoint = ModelCheckpoint(f"{configs.model_path}/model.h5", monitor="val_CER", verbose=1, save_best_only=True, mode="min")
trainLogger = TrainLogger(configs.model_path)
tb_callback = TensorBoard(f"{configs.model_path}/logs", update_freq=1)
reduceLROnPlat = ReduceLROnPlateau(monitor="val_CER", factor=0.9, min_delta=1e-10, patience=10, verbose=1, mode="auto")
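# Model2onnx converts the best .h5 checkpoint saved by ModelCheckpoint into an ONNX model for inference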
model2onnx = Model2onnx(f"{configs.model_path}/model.h5")
# Train the model
model.fit(
    train_data_provider,
    validation_data=val_data_provider,
    epochs=100,
    callbacks=[earlystopper, checkpoint, trainLogger, reduceLROnPlat, tb_callback, model2onnx],
    workers=configs.train_workers,
)
# Save training and validation datasets as csv files
train_data_provider.to_csv(os.path.join(configs.model_path, "train.csv"))
val_data_provider.to_csv(os.path.join(configs.model_path, "val.csv"))