forked from lRomul/argus-tgs-salt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
after_train_folds.py
109 lines (86 loc) · 3.62 KB
/
after_train_folds.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
from os.path import join
import re
import math
import torch
import argus
from argus.callbacks import MonitorCheckpoint, LoggingToFile
from argus import load_model
from torch.utils.data import DataLoader
from src.dataset import SaltDataset
from src.transforms import SimpleDepthTransform, SaltTransform
from src.argus_models import SaltMetaModel
from src.losses import LovaszProbLoss
from src import config
def get_best_model_path(dir_path):
    """Return the path of the highest-scoring checkpoint in ``dir_path``.

    A checkpoint is any directory entry whose name ends with
    ``-<score>.pth``, where ``<score>`` is an integer or decimal number
    (e.g. ``model-042-0.851234.pth``). Entries that do not match are ignored.

    Args:
        dir_path: Directory to scan (not recursive).

    Returns:
        ``dir_path`` joined with the name of the entry whose score is
        numerically largest.

    Raises:
        FileNotFoundError: If ``dir_path`` does not exist or contains no
            entry matching the checkpoint naming pattern.
    """
    # The dot before "pth" is escaped and the pattern is anchored so that
    # only real ".pth" suffixes are accepted.
    score_pattern = re.compile(r'-(\d+(?:\.\d+)?)\.pth$')

    model_scores = []
    for model_name in os.listdir(dir_path):
        match = score_pattern.search(model_name)
        if match is not None:
            # Compare scores numerically: as strings, "9.5" sorts above "10.2".
            model_scores.append((model_name, float(match.group(1))))

    if not model_scores:
        raise FileNotFoundError(
            f"No '-<score>.pth' checkpoint found in {dir_path!r}")

    best_model_name, _ = max(model_scores, key=lambda item: item[1])
    return os.path.join(dir_path, best_model_name)
# Experiment whose per-fold checkpoints are used as the starting point
# (looked up under /workdir/data/experiments/<name>/fold_<n> by the script
# entry point).
BASE_EXPERIMENT_NAME = 'mos-fpn-lovasz-se-resnext50-001'
# Name of this continuation run; determines SAVE_DIR below.
EXPERIMENT_NAME = 'mos-fpn-lovasz-se-resnext50-001-after-001'
# Fold ids 0 .. config.N_FOLDS - 1; each one is used once as the validation fold.
FOLDS = list(range(config.N_FOLDS))
# Batch size for both the training and the validation DataLoader.
BATCH_SIZE = 16
# Size handed to SaltTransform for both the training and validation transforms.
IMAGE_SIZE = (128, 128)
# NOTE(review): not referenced anywhere in the visible part of this script.
OUTPUT_SIZE = (101, 101)
# CSV passed to SaltDataset together with the list of folds to load.
TRAIN_FOLDS_PATH = '/workdir/data/train_folds_148_mos_emb_1.csv'
# Base learning rate of the cosine schedule (its floor is derived from it).
LR = 0.005
# Root output directory; one fold_<n> sub-directory is used per validation fold.
SAVE_DIR = f'/workdir/data/experiments/{EXPERIMENT_NAME}'
class CosineAnnealingLR:
    """Cyclic cosine learning-rate schedule, callable with an epoch number.

    Within each cycle of ``T_max`` epochs the value starts at ``base_lr``
    and decays along a half cosine towards ``eta_min``; because the epoch is
    taken modulo ``T_max``, the value jumps back to ``base_lr`` at the start
    of every cycle.
    """

    def __init__(self, base_lr, T_max, eta_min=0.):
        """Store the schedule parameters.

        Args:
            base_lr: Value returned at the start of each cycle.
            T_max: Cycle length in epochs.
            eta_min: Lower bound the value decays towards.
        """
        self.base_lr = base_lr
        self.eta_min = eta_min
        self.T_max = T_max

    def __call__(self, epoch):
        """Return the learning rate for ``epoch``."""
        period = self.T_max
        floor = self.eta_min
        # Angle in [0, pi) describing the position inside the current cycle.
        angle = math.pi * (epoch % period) / period
        span = self.base_lr - floor
        return floor + span * (1 + math.cos(angle)) / 2
# Module-level schedule used by the update_lr callback: 50-epoch cosine
# cycles starting at LR and decaying towards 10% of LR.
cos_ann = CosineAnnealingLR(LR, 50, eta_min=LR*0.1)
@argus.callbacks.on_epoch_start
def update_lr(state: argus.engine.State) -> None:
    """Epoch-start callback that applies the cosine schedule to the model.

    Looks up the learning rate for ``state.epoch`` in the module-level
    ``cos_ann`` schedule, sets it on ``state.model`` and logs the value.
    """
    lr = cos_ann(state.epoch)
    state.model.set_lr(lr)
    state.logger.info(f"Set lr: {lr}")
def train_fold(save_dir, train_folds, val_folds, model_path):
    """Continue training a saved model on one train/validation fold split.

    The model is loaded from ``model_path``, both of its loss weights
    (``lovasz_weight`` and ``prob_weight``) are set to 0.5, and it is then
    fitted for up to 500 epochs with the ``update_lr`` schedule callback.

    Args:
        save_dir: Directory that receives the checkpoints and ``log.txt``.
        train_folds: Fold ids used to build the training dataset.
        val_folds: Fold ids used to build the validation dataset.
        model_path: Path of the saved model to start from.
    """
    depth_transform = SimpleDepthTransform()
    train_transform = SaltTransform(IMAGE_SIZE, True, 'crop')
    val_transform = SaltTransform(IMAGE_SIZE, False, 'crop')

    train_set = SaltDataset(TRAIN_FOLDS_PATH, train_folds,
                            train_transform, depth_transform)
    val_set = SaltDataset(TRAIN_FOLDS_PATH, val_folds,
                          val_transform, depth_transform)

    # Options shared by both loaders; only shuffling and drop_last differ.
    shared_loader_options = {'batch_size': BATCH_SIZE, 'num_workers': 8}
    train_loader = DataLoader(train_set, shuffle=True, drop_last=True,
                              **shared_loader_options)
    val_loader = DataLoader(val_set, shuffle=False, **shared_loader_options)

    model = load_model(model_path)
    # Both loss terms get the same weight for this training stage.
    for weight_name in ('lovasz_weight', 'prob_weight'):
        setattr(model.loss, weight_name, 0.5)

    log_path = os.path.join(save_dir, 'log.txt')
    # Keep at most three checkpoints, selected by the 'val_crop_iout' metric.
    checkpoint = MonitorCheckpoint(save_dir, monitor='val_crop_iout',
                                   max_saves=3, copy_last=False)
    file_logger = LoggingToFile(log_path)

    model.fit(
        train_loader,
        val_loader=val_loader,
        max_epochs=500,
        callbacks=[checkpoint, file_logger, update_lr],
        metrics=['crop_iout'],
    )
if __name__ == "__main__":
    # Script entry point: for every fold, resume from the best checkpoint of
    # the base experiment and continue training with that fold held out for
    # validation.
    if not os.path.exists(SAVE_DIR):
        os.makedirs(SAVE_DIR)
    else:
        print(f"Folder {SAVE_DIR} already exists.")

    # Keep a copy of this script next to the results. Both files are opened
    # in a `with` statement so neither handle is left open.
    with open(__file__) as src, \
            open(os.path.join(SAVE_DIR, 'source.py'), 'w') as outfile:
        outfile.write(src.read())

    for i, val_fold in enumerate(FOLDS):
        val_folds = [val_fold]
        train_folds = FOLDS[:i] + FOLDS[i + 1:]
        save_fold_dir = os.path.join(SAVE_DIR, f'fold_{val_fold}')
        print(f"Val folds: {val_folds}, Train folds: {train_folds}")
        print(f"Fold save dir {save_fold_dir}")
        # Locate the base checkpoint by fold id (not by loop position) so it
        # always matches the validation fold and the save directory, even if
        # FOLDS is edited to a subset.
        model_path = get_best_model_path(
            join('/workdir/data/experiments', BASE_EXPERIMENT_NAME,
                 f'fold_{val_fold}'))
        print(f'Base model path: {model_path}')
        train_fold(save_fold_dir, train_folds, val_folds, model_path)