train.py
import torch
from torch import nn
import numpy as np
from glob import glob
from transformers import get_linear_schedule_with_warmup
from utils import seed_everything
from preprocess import build_df
from dataset import split_df, get_loaders
from tokenizer import Tokenizer
from model import Encoder, Decoder, EncoderDecoder
from engine import train_eval
from config import CFG

if __name__ == '__main__':
    # fix random seeds for reproducibility
    seed_everything(42)

    # collect image and XML annotation file paths
    IMG_FILES = glob(CFG.img_path + "/*.jpg")
    XML_FILES = glob(CFG.xml_path + "/*.xml")
    assert len(IMG_FILES) == len(XML_FILES) != 0, "images or xml files not found"
    print("Number of found images: ", len(IMG_FILES))

    # parse the XML annotations into a dataframe and collect the class names
    df, classes = build_df(XML_FILES)

    # build id-to-class-name and vice versa mappings
    cls2id = {cls_name: i for i, cls_name in enumerate(classes)}
    id2cls = {i: cls_name for i, cls_name in enumerate(classes)}

    train_df, valid_df = split_df(df)
    print("Train size: ", train_df['id'].nunique())
    print("Valid size: ", valid_df['id'].nunique())

    # tokenizer that maps class labels and binned box coordinates to token ids
    tokenizer = Tokenizer(num_classes=len(classes), num_bins=CFG.num_bins,
                          width=CFG.img_size, height=CFG.img_size, max_len=CFG.max_len)
    CFG.pad_idx = tokenizer.PAD_code

    train_loader, valid_loader = get_loaders(
        train_df, valid_df, tokenizer, CFG.img_size, CFG.batch_size, CFG.max_len, tokenizer.PAD_code)

    # image encoder backbone plus transformer decoder over the token sequence
    encoder = Encoder(model_name=CFG.model_name, pretrained=True, out_dim=256)
    decoder = Decoder(vocab_size=tokenizer.vocab_size,
                      encoder_length=CFG.num_patches, dim=256, num_heads=8, num_layers=6)
    model = EncoderDecoder(encoder, decoder)
    model.to(CFG.device)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)

    # linear warmup over the first 5% of steps, then linear decay
    num_training_steps = CFG.epochs * (len(train_loader.dataset) // CFG.batch_size)
    num_warmup_steps = int(0.05 * num_training_steps)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                   num_training_steps=num_training_steps,
                                                   num_warmup_steps=num_warmup_steps)

    # padding tokens are excluded from the sequence loss
    criterion = nn.CrossEntropyLoss(ignore_index=CFG.pad_idx)

    train_eval(model,
               train_loader,
               valid_loader,
               criterion,
               optimizer,
               lr_scheduler=lr_scheduler,
               step='batch',
               logger=None)
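
# Illustration only (not part of the repo): a minimal sketch of the attributes
# that config.CFG is assumed to expose, inferred from how they are used above.
# All paths, sizes, and hyperparameter values below are placeholders, not the
# project's actual settings; the repo's config.py is authoritative.
#
# class CFG:
#     img_path = "./data/images"           # directory containing the *.jpg images
#     xml_path = "./data/annotations"      # directory containing the *.xml annotations
#     img_size = 384                       # input resolution fed to the encoder
#     num_bins = 384                       # bins used to quantize box coordinates
#     max_len = 300                        # maximum token-sequence length
#     batch_size = 16
#     epochs = 25
#     lr = 1e-4
#     weight_decay = 1e-4
#     model_name = "deit3_small_patch16_384"   # backbone name passed to Encoder
#     num_patches = 576                    # encoder output length expected by Decoder
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     pad_idx = None                       # set from tokenizer.PAD_code at runtime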