-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain.py
91 lines (77 loc) · 2.51 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
Train a Mask R-CNN model on the cell dataset (`CellDataset` loaded from `train_dir`).
"""
from utils.setup import *
from utils.path import train_dir, model_dir, result_dir, log_dir
from model.mask_rcnn import MaskRCNN
from dataset.dataset import CellDataset, train_collate
from model.loss import total_loss
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.optim import Adam, SGD
import torch.nn as nn
import torch
import time
from config import Config
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Instantiate the config first and read every hyperparameter from that
# instance, so the values used by this script are the same ones carried by
# the `config` object that is handed to the dataset, model and loss function.
config = Config()
lr = config.LEARNING_RATE           # SGD learning rate
mom = config.LEARNING_MOMENTUM      # SGD momentum
num_epoch = 10                      # number of passes over the training set
batch_size = config.IMAGES_PER_GPU  # batch size used by the DataLoader
print_freq = 10                     # log every `print_freq` iterations

# NOTE: no image transforms (resize / to-tensor / normalize) are configured;
# the training dataset is constructed without a `transform` argument.
# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
# Training samples are read from `train_dir`; no `transform` is passed.
train_dataset = CellDataset(train_dir, config)
# Shuffled loader; `train_collate` assembles each batch into the
# (imgs, gts) pair that the training loop unpacks.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=train_collate,
)
# ---------------------------------------------------------------------------
# Model & optimizer
# ---------------------------------------------------------------------------
model = MaskRCNN(config)
model = model.cuda()  # move to the current CUDA device
# Plain SGD with momentum, using the learning rate / momentum configured above.
# Alternative that was tried: Adam(model.parameters(), lr=lr)
optimizer = SGD(params=model.parameters(), lr=lr, momentum=mom)
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
# Freshly constructed modules already start in training mode; calling train()
# just makes that intent explicit.
model.train()
for epoch in range(num_epoch):
    running_loss = 0.0  # sum of total_loss over the current logging window
    end = time.time()
    for i, (imgs, gts) in enumerate(train_loader):
        # Time spent waiting on the DataLoader for this batch, measured before
        # the host-to-device copy so it reflects data loading only.
        data_time = time.time() - end
        imgs = imgs.float().cuda()

        # Forward pass + loss. Calling the module (rather than `.forward()`)
        # also runs any registered forward hooks.
        logits = model(imgs)
        loss, saved_for_log = total_loss(logits, gts, config)

        # Learn & update params.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # float() accepts Python numbers and single-element tensors alike.
        # NOTE(review): if total_loss stores tensors in saved_for_log, summing
        # them directly could keep them (and any attached autograd graph)
        # alive for the whole logging window; converting avoids that.
        running_loss += float(saved_for_log['total_loss'])

        # Wall-clock time of the whole iteration (data wait + compute).
        # NOTE(review): CUDA work runs asynchronously, so without
        # torch.cuda.synchronize() this may not capture all GPU time.
        batch_time = time.time() - end
        end = time.time()

        if (i + 1) % print_freq == 0:
            # `Loss` is the mean over the last `print_freq` batches; the
            # component losses below are from the most recent batch only.
            rpn_loss = (saved_for_log['rpn_cls_loss']
                        + saved_for_log['rpn_reg_loss'])
            rcnn_loss = (saved_for_log['stage2_cls_loss']
                         + saved_for_log['stage2_reg_loss'])
            mask_loss = saved_for_log['stage2_mask_loss']
            # Batch index is printed 1-based, matching the epoch number.
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time:.3f}\t'
                  'Data {data_time:.3f}\t'
                  'Loss {loss:.4f}\t'
                  'Rpn {rpn:.4f}\t'
                  'Rcnn {rcnn:.4f}\t'
                  'Mask {mask:.4f}\t'.format(
                      epoch + 1, i + 1, len(train_loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=running_loss / print_freq,
                      rpn=rpn_loss,
                      mask=mask_loss,
                      rcnn=rcnn_loss))
            running_loss = 0.0

# NOTE(review): the trained weights are never saved; `model_dir` is imported at
# the top of the file but not used. Confirm whether a checkpoint (e.g. via
# torch.save(model.state_dict(), ...)) should be written here.