# train.py
import os
import glob
import torch
from tqdm import tqdm
from tensorboardX import SummaryWriter
from config import max_checkpoint_num, proposalN, eval_trainset, set
import eval


def train(model,
          trainloader,
          testloader,
          criterion,
          optimizer,
          scheduler,
          save_path,
          start_epoch,
          end_epoch,
          save_interval):
    for epoch in range(start_epoch + 1, end_epoch + 1):
        model.train()
        print('Training %d epoch' % epoch)

        # Record the current learning rate so it can be stored in the checkpoint.
        lr = next(iter(optimizer.param_groups))['lr']

        for i, data in enumerate(tqdm(trainloader)):
            # The CUB loader returns two extra items per sample that are not used here.
            if set == 'CUB':
                images, labels, _, _ = data
            else:
                images, labels = data
            images, labels = images.cuda(), labels.cuda()

            optimizer.zero_grad()

            proposalN_windows_scores, proposalN_windows_logits, proposalN_indices, \
                window_scores, raw_logits, local_logits, FMA_Output = model(images, epoch, i, 'train')

            # Classification losses on the raw image, the local image,
            # and each of the proposalN attention windows.
            raw_loss = criterion(raw_logits, labels)
            local_loss = criterion(local_logits, labels)
            windowscls_loss = criterion(proposalN_windows_logits,
                                        labels.unsqueeze(1).repeat(1, proposalN).view(-1))

            # Train on the raw-image loss alone for the first epoch,
            # then on the combined objective.
            if epoch < 2:
                total_loss = raw_loss
            else:
                total_loss = raw_loss + local_loss + windowscls_loss

            total_loss.backward()
            optimizer.step()

        scheduler.step()

        # save checkpoint
        if (epoch % save_interval == 0) or (epoch == end_epoch):
            print('Saving checkpoint')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'learning_rate': lr,
            }, os.path.join(save_path, 'epoch' + str(epoch) + '.pth'))

        # Limit the number of checkpoints to at most max_checkpoint_num
        # by deleting the one with the smallest epoch index.
        checkpoint_list = [os.path.basename(path) for path in glob.glob(os.path.join(save_path, '*.pth'))]
        if len(checkpoint_list) == max_checkpoint_num + 1:
            idx_list = [int(name.replace('epoch', '').replace('.pth', '')) for name in checkpoint_list]
            min_idx = min(idx_list)
            os.remove(os.path.join(save_path, 'epoch' + str(min_idx) + '.pth'))
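

# Usage sketch (not part of the original file): the names below are assumptions
# about the rest of this repository, e.g. a model class and the hyperparameters
# defined in config.py; only the train() signature above is taken from this file.
#
#   model = MainNet(proposalN=proposalN, num_classes=num_classes).cuda()
#   criterion = torch.nn.CrossEntropyLoss()
#   optimizer = torch.optim.SGD(model.parameters(), lr=init_lr, momentum=0.9, weight_decay=weight_decay)
#   scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
#   train(model, trainloader, testloader, criterion, optimizer, scheduler,
#         save_path=model_path, start_epoch=0, end_epoch=end_epoch, save_interval=save_interval)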