-
Notifications
You must be signed in to change notification settings - Fork 4
/
trainTwoStreamNet.py
103 lines (64 loc) · 2.7 KB
/
trainTwoStreamNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os

import torch
import torch.nn.functional as F
import torch.optim as optim

from LoadUCF101Data import trainset_loader, testset_loader
from Two_Stream_Net import TwoStreamNet
# --- Training hyper-parameters -------------------------------------------
EPOCH = 100             # number of passes train() makes over trainset_loader
LEARNING_RATE = 0.0001  # SGD step size
MOMENTUM = 0.9          # SGD momentum factor
SAVE_INTERVAL = 500     # train() writes a periodic checkpoint at this step interval

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The model and its optimizer are module-level globals shared by train() and
# test() below.
twoStreamNet = TwoStreamNet().to(device)
optimizer = optim.SGD(twoStreamNet.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
def save_checkpoint(path, model, optimizer):
    """Serialize the model and optimizer state dicts with ``torch.save``.

    The saved object is a dict with two keys: ``'model'`` holds
    ``model.state_dict()`` and ``'optimizer'`` holds
    ``optimizer.state_dict()``.

    Args:
        path: Destination file path (``str``, ``bytes`` or ``os.PathLike``),
            or any file-like object accepted by ``torch.save``. For a
            filesystem path, missing parent directories are created first.
        model: Object whose ``state_dict()`` is stored under ``'model'``.
        optimizer: Object whose ``state_dict()`` is stored under
            ``'optimizer'``.
    """
    # torch.save does not create directories. Without this, saving to e.g.
    # 'model/checkpoint-N.pth' fails with FileNotFoundError unless 'model/'
    # already exists. File-like objects are passed through untouched.
    if isinstance(path, (str, bytes, os.PathLike)):
        directory = os.path.dirname(os.fspath(path))
        if directory:
            os.makedirs(directory, exist_ok=True)
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, path)
def _checkpoint_path(step):
    """Return the checkpoint file name used after *step* optimizer steps."""
    return 'model/checkpoint-%i.pth' % step


def train(epoch, save_interval):
    """Train the global ``twoStreamNet`` and evaluate it after every epoch.

    Each epoch re-enables training mode and then, for every batch of
    ``trainset_loader``, does the following:
      * takes one optimizer step on the cross-entropy loss;
      * prints the batch loss and appends it to ``log.txt``.
    After each epoch it calls ``test``. One last checkpoint is written after
    all epochs.

    Args:
        epoch: Number of passes over ``trainset_loader``.
        save_interval: Write ``model/checkpoint-<N>.pth`` every
            *save_interval* optimizer steps, where ``<N>`` is the number of
            steps completed so far. A value <= 0 disables the periodic
            checkpoints; the final checkpoint is always written.
    """
    iteration = 0  # optimizer steps completed across all epochs
    for i in range(epoch):
        # test() switches the model to eval mode, so switch back each epoch.
        twoStreamNet.train()
        for index, (RGB_images, OpticalFlow_images, label) in enumerate(trainset_loader):
            RGB_images = RGB_images.to(device)
            OpticalFlow_images = OpticalFlow_images.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            output = twoStreamNet(RGB_images, OpticalFlow_images)
            loss = F.cross_entropy(output, label)
            loss.backward()
            optimizer.step()
            iteration += 1

            # Count the step before checking, so checkpoint-N holds the weights
            # after exactly N updates. This matches how the final checkpoint
            # below is labelled.
            if save_interval > 0 and iteration % save_interval == 0:
                save_checkpoint(_checkpoint_path(iteration), twoStreamNet, optimizer)

            # .item() copies the scalar back from the device; do it only once.
            loss_value = loss.item()
            print("Loss: " + str(loss_value))
            with open('log.txt', 'a') as f:
                f.write("Epoch " + str(i + 1) + ", Iteration " + str(index + 1)
                        + "'s Loss: " + str(loss_value) + "\n")
        test(i + 1)
    save_checkpoint(_checkpoint_path(iteration), twoStreamNet, optimizer)
def test(i_epoch):
    """Evaluate the global ``twoStreamNet`` on ``testset_loader``.

    Switches the model to eval mode and runs every test batch under
    ``torch.no_grad()``. The prediction for each sample is the arg-max of the
    network output along dimension 1. The accuracy is the number of correct
    predictions as a percentage of ``len(testset_loader.dataset)``. It is
    printed and appended to ``log.txt``.

    Args:
        i_epoch: Epoch number written into the log line (``train`` passes
            the 1-based epoch index).

    Returns:
        The accuracy percentage as a float, or 0.0 when the test dataset is
        empty.
    """
    twoStreamNet.eval()
    correct = 0
    with torch.no_grad():
        for RGB_images, OpticalFlow_images, label in testset_loader:
            RGB_images = RGB_images.to(device)
            OpticalFlow_images = OpticalFlow_images.to(device)
            label = label.to(device)
            output = twoStreamNet(RGB_images, OpticalFlow_images)
            predictions = output.argmax(dim=1)
            correct += predictions.eq(label.view_as(predictions)).sum().item()

    total = len(testset_loader.dataset)
    # Avoid ZeroDivisionError when the test dataset is empty.
    accuracy = correct * 100.0 / total if total else 0.0
    print("Accuracy: " + str(accuracy))
    with open('log.txt', 'a') as f:
        f.write("Epoch " + str(i_epoch) + "'s Accuracy: " + str(accuracy) + "\n")
    return accuracy
if __name__ == '__main__':
    # Start the full training run only when executed as a script, not on import.
    train(epoch=EPOCH, save_interval=SAVE_INTERVAL)