-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
99 lines (77 loc) · 3.63 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import math
import os
from tqdm import tqdm
import torch
from model import DuckNet
from utils.dataloader import get_dataloader
from utils.losses import criterion
from utils.metrics import dice_coef, jaccard_index, precision, recall
from utils.config import Configs
def train_epoch(model:DuckNet, train_loader, optimizer:torch.optim.Optimizer, device:torch.device):
model.train()
train_loss = 0
for batch in tqdm(train_loader):
images, masks = batch['image'], batch['mask']
images, masks = images.to(device), masks.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(masks, outputs)
loss.backward()
optimizer.step()
train_loss += loss.item()
return train_loss / len(train_loader)
@torch.no_grad()
def valid_epoch(model:DuckNet, valid_loader, device:torch.device):
model.eval()
valid_loss = 0
dice = 0
jaccard = 0
prec = 0
rec = 0
for batch in valid_loader:
images, masks = batch['image'], batch['mask']
images, masks = images.to(device), masks.to(device)
outputs = model(images)
loss = criterion(masks, outputs)
valid_loss += loss.item()
dice += dice_coef(masks, outputs)
jaccard += jaccard_index(masks, outputs)
prec += precision(masks, outputs)
rec += recall(masks, outputs)
return valid_loss / len(valid_loader), dice / len(valid_loader), jaccard / len(valid_loader), prec / len(valid_loader), rec / len(valid_loader)
def main(config:Configs):
device = torch.device(f'cuda:{config.gpu_id}' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
model = DuckNet(input_channels=config.input_channels, num_classes=config.num_classes, num_filters=config.num_filters)
model.to(device)
print('Model created')
train_dl = get_dataloader(os.path.join(config.ROOT_DIR, 'data/train'), config.input_channels, config.batch_size, shuffle=True, num_workers=config.num_workers)
valid_dl = get_dataloader(os.path.join(config.ROOT_DIR, 'data/val'), config.input_channels, config.batch_size, shuffle=False, num_workers=config.num_workers)
print('DataLoaders created')
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, betas=config.betas, weight_decay=config.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.step_size, gamma=config.gamma)
os.makedirs(os.path.join(config.ROOT_DIR, config.save_dir), exist_ok=True)
best_valid_loss = math.inf
early_stop_patience = 0
print('Starting training')
for epoch in range(config.num_epochs):
train_loss = train_epoch(model, train_dl, optimizer, device)
valid_loss, dice, jaccard, prec, rec = valid_epoch(model, valid_dl, device)
scheduler.step()
msg = f'Epoch {epoch+1}/{config.num_epochs} - Train Loss: {train_loss:.4f} - Valid Loss: {valid_loss:.4f} - Dice: {dice:.4f} - Jaccard: {jaccard:.4f} - Precision: {prec:.4f} - Recall: {rec:.4f}'
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), os.path.join(config.ROOT_DIR, config.save_dir, 'best_model.pt'))
msg += ' - Model saved'
early_stop_patience = 0
else:
early_stop_patience += 1
if early_stop_patience >= config.early_stopping:
msg += ' - Early stopping'
break
print(msg)
print('Training completed')
if __name__ == '__main__':
config = Configs()
config.ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
main(config)