-
Notifications
You must be signed in to change notification settings - Fork 48
/
train.py
137 lines (116 loc) · 5.57 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# torch and visualization
from tqdm import tqdm
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms
from torch.utils.data import DataLoader
from model.parse_args_train import parse_args
# metric, loss .etc
from model.utils import *
from model.metric import *
from model.loss import *
from model.load_param_data import load_dataset, load_param
# model
from model.model_DNANet import Res_CBAM_block
from model.model_DNANet import DNANet
class Trainer(object):
    """Build the DNANet model, data loaders, optimizer and LR scheduler,
    then drive per-epoch training and evaluation.

    BUG FIX vs. original: ``training``/``testing`` referenced the
    module-level ``args`` (which only exists when the file is run as a
    script); they now use ``self.args`` so the class works when imported.
    """

    def __init__(self, args):
        # Initial
        self.args = args
        self.ROC = ROCMetric(1, 10)
        self.mIoU = mIoU(1)
        self.save_prefix = '_'.join([args.model, args.dataset])
        self.save_dir = args.save_dir
        nb_filter, num_blocks = load_param(args.channel_size, args.backbone)

        # Read image index from TXT
        # NOTE(review): dataset_dir / train_img_ids / val_img_ids are only
        # bound when args.mode == 'TXT'; any other mode raises NameError
        # below — confirm whether other modes are expected.
        if args.mode == 'TXT':
            dataset_dir = args.root + '/' + args.dataset
            train_img_ids, val_img_ids, test_txt = load_dataset(args.root, args.dataset, args.split_method)

        # Preprocess and load data (ImageNet mean/std normalization)
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
        trainset = TrainSetLoader(dataset_dir, img_id=train_img_ids, base_size=args.base_size,
                                  crop_size=args.crop_size, transform=input_transform, suffix=args.suffix)
        testset = TestSetLoader(dataset_dir, img_id=val_img_ids, base_size=args.base_size,
                                crop_size=args.crop_size, transform=input_transform, suffix=args.suffix)
        self.train_data = DataLoader(dataset=trainset, batch_size=args.train_batch_size, shuffle=True,
                                     num_workers=args.workers, drop_last=True)
        self.test_data = DataLoader(dataset=testset, batch_size=args.test_batch_size,
                                    num_workers=args.workers, drop_last=False)

        # Choose and load model (this paper is finished by one GPU)
        if args.model == 'DNANet':
            model = DNANet(num_classes=1, input_channels=args.in_channels, block=Res_CBAM_block,
                           num_blocks=num_blocks, nb_filter=nb_filter,
                           deep_supervision=args.deep_supervision)
        model = model.cuda()
        model.apply(weights_init_xavier)
        print("Model Initializing")
        self.model = model

        # Optimizer and lr scheduling
        if args.optimizer == 'Adam':
            self.optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
        elif args.optimizer == 'Adagrad':
            self.optimizer = optim.Adagrad(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
        if args.scheduler == 'CosineAnnealingLR':
            self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=args.epochs, eta_min=args.min_lr)
        self.scheduler.step()

        # Evaluation metrics: best precision/recall tracked at the
        # 11 thresholds used by ROCMetric(1, 10).
        self.best_iou = 0
        self.best_recall = [0] * 11
        self.best_precision = [0] * 11

    # Training
    def training(self, epoch):
        """Run one training epoch; stores the mean loss in self.train_loss."""
        tbar = tqdm(self.train_data)
        self.model.train()
        losses = AverageMeter()
        for i, (data, labels) in enumerate(tbar):
            data = data.cuda()
            labels = labels.cuda()
            # FIX: was the module-global `args`; use the stored self.args.
            if self.args.deep_supervision == 'True':
                # Deep supervision: average the SoftIoU loss over all side outputs.
                preds = self.model(data)
                loss = 0
                for pred in preds:
                    loss += SoftIoULoss(pred, labels)
                loss /= len(preds)
            else:
                pred = self.model(data)
                loss = SoftIoULoss(pred, labels)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            losses.update(loss.item(), pred.size(0))
            tbar.set_description('Epoch %d, training loss %.4f' % (epoch, losses.avg))
        self.train_loss = losses.avg

    # Testing
    def testing(self, epoch):
        """Evaluate on the test split and save the model via save_model
        (which compares mean_IOU against self.best_iou)."""
        tbar = tqdm(self.test_data)
        self.model.eval()
        self.mIoU.reset()
        losses = AverageMeter()
        with torch.no_grad():
            for i, (data, labels) in enumerate(tbar):
                data = data.cuda()
                labels = labels.cuda()
                # FIX: was the module-global `args`; use the stored self.args.
                if self.args.deep_supervision == 'True':
                    preds = self.model(data)
                    loss = 0
                    for pred in preds:
                        loss += SoftIoULoss(pred, labels)
                    loss /= len(preds)
                    # Metrics are computed on the final (deepest) output only.
                    pred = preds[-1]
                else:
                    pred = self.model(data)
                    loss = SoftIoULoss(pred, labels)
                losses.update(loss.item(), pred.size(0))
                self.ROC.update(pred, labels)
                self.mIoU.update(pred, labels)
                # FIX: typo `ture_positive_rate` -> `true_positive_rate`.
                true_positive_rate, false_positive_rate, recall, precision = self.ROC.get()
                _, mean_IOU = self.mIoU.get()
                tbar.set_description('Epoch %d, test loss %.4f, mean_IoU: %.4f' % (epoch, losses.avg, mean_IOU))
            test_loss = losses.avg
        # save high-performance model
        save_model(mean_IOU, self.best_iou, self.save_dir, self.save_prefix,
                   self.train_loss, test_loss, recall, precision, epoch, self.model.state_dict())
def main(args):
    """Construct a Trainer from parsed arguments and run the full
    train/test cycle from args.start_epoch up to args.epochs."""
    trainer = Trainer(args)
    for current_epoch in range(args.start_epoch, args.epochs):
        trainer.training(current_epoch)
        trainer.testing(current_epoch)
if __name__ == "__main__":
args = parse_args()
main(args)