-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
121 lines (78 loc) · 2.84 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from LoadUCF101DataByTorch import trainset_loader, testset_loader
from C3D import C3D
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
# --- Training hyperparameters ---
EPOCH = 10
LEARNING_RATE = 0.003
MOMENTUM = 0.9
GAMMA = 0.5
STEP_SIZE = 3

# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# C3D video classifier over the 101 UCF101 action classes.
model = C3D(num_classes=101).to(device)

# Plain SGD with momentum; StepLR halves the lr (GAMMA=0.5) every STEP_SIZE epochs.
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)
def save_checkpoint(path, model, optimizer):
    """Serialize the model and optimizer state dicts to `path`.

    Creates the parent directory first: the original code wrote to
    'model/checkpoint-*.pth' and crashed with FileNotFoundError on a
    fresh clone where the 'model/' directory does not exist yet.
    """
    parent = os.path.dirname(path)
    if parent:  # skip for bare filenames in the current directory
        os.makedirs(parent, exist_ok=True)
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    torch.save(state, path)
def train(epoch):
    """Run the full training loop for `epoch` epochs.

    Per batch: forward, cross-entropy loss, backward, SGD step; the loss is
    printed and appended to log.txt. At the end of each epoch the model is
    checkpointed, evaluated via test(), and the lr schedule is stepped.
    Finally the collected per-iteration losses are plotted.
    NOTE(review): source indentation was lost in extraction; the per-epoch
    placement of checkpoint/test/scheduler.step is the reconstruction —
    confirm against the original file.
    """
    iteration = 0
    loss_history = []
    for epoch_idx in range(epoch):
        model.train()
        print('current lr', scheduler.get_last_lr())
        for batch_idx, (video_clips, label) in enumerate(trainset_loader):
            video_clips = video_clips.to(device)
            label = label.to(device)
            optimizer.zero_grad()
            output = model(video_clips)
            loss = F.cross_entropy(output, label)
            loss_history.append(loss.item())
            loss.backward()
            optimizer.step()
            iteration += 1
            print("Epoch:", epoch_idx, "/", epoch-1, "\tIteration:", batch_idx, "/", len(trainset_loader)-1, "\tLoss: " + str(loss.item()))
            with open('log.txt', 'a') as f:
                f.write("Epoch: " + str(epoch_idx) + "/" + str(epoch-1) + "\tIteration:" + str(batch_idx) + "/" + str(len(trainset_loader)-1) + "\tLoss: " + str(loss.item()) + "\n")
        # End of epoch: checkpoint (named by cumulative iteration count),
        # evaluate on the test split, then decay the learning rate.
        save_checkpoint('model/checkpoint-%i.pth' % iteration, model, optimizer)
        test(epoch_idx)
        scheduler.step()
    save_checkpoint('model/checkpoint-%i.pth' % iteration, model, optimizer)

    # Plot the per-iteration training loss curve.
    plt.figure()
    plt.plot(loss_history)
    plt.title('Loss')
    plt.xlabel('Iteration')
    plt.ylabel('')
    plt.show()
def test(i_epoch):
    """Evaluate the model on the test set and log the accuracy.

    Counts top-1 correct predictions over the whole test loader (gradients
    disabled), prints the accuracy as a percentage, and appends it to
    log.txt tagged with `i_epoch`.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for video_clips, label in testset_loader:
            video_clips = video_clips.to(device)
            label = label.to(device)
            output = model(video_clips)
            # Index 1 of max() is the argmax over classes, kept as a column vector.
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(label.view_as(pred)).sum().item()
    accuracy = correct * 1.0 * 100 / len(testset_loader.dataset)
    print("Accuracy: " + str(accuracy))
    with open('log.txt', 'a') as f:
        f.write("Epoch " + str(i_epoch) + "'s Accuracy: " + str(accuracy) + "\n")
# Entry point: run the training loop (which checkpoints, evaluates,
# and plots) for the configured number of epochs.
if __name__ == '__main__':
    train(EPOCH)