-
Notifications
You must be signed in to change notification settings - Fork 20
/
main.lua
75 lines (65 loc) · 1.5 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
require('torch')
require('nn')
require('image')
require('optim')
require('json')
-- Parse command-line options with the project's opts.lua, echo them, and
-- persist them as opts.json inside opt.result_path so the run is recorded.
opts = dofile('opts.lua')
opt = opts.parse(arg)
print(opt)
json.save(paths.concat(opt.result_path, 'opts.json'), opt)

-- Seed torch's generator; the CUDA generator is seeded below when enabled.
torch.manualSeed(opt.manual_seed)

-- GPU initialisation is skipped entirely when opt.no_cuda is truthy.
local use_cuda = not opt.no_cuda
if use_cuda then
  require('cutorch')
  require('cunn')
  require('cudnn')
  -- Turn on cudnn's fastest/benchmark algorithm selection, silence logging.
  cudnn.fastest, cudnn.benchmark, cudnn.verbose = true, true, false
  cutorch.setDevice(opt.gpu_id)
  cutorch.manualSeed(opt.manual_seed)
end
-- utils.lua returns a helper table; the remaining scripts are run for their
-- side effects (presumably defining globals such as `model` used below).
utils = dofile('utils.lua')
local side_effect_scripts = { 'mean.lua', 'model.lua', 'dataset.lua', 'data_threads.lua' }
for _, script in ipairs(side_effect_scripts) do
  dofile(script)
end
-- Training setup: plain SGD plus its state table, then train.lua (which is
-- presumably where train_epoch() is defined). The globals set here are kept
-- global on purpose because the dofile'd scripts may read them.
if not opt.no_train then
  optimizer = optim.sgd
  optim_state = {
    learningRate = opt.learning_rate,
    weightDecay = opt.weight_decay,
    momentum = opt.momentum,
    learningRateDecay = 0,
  }

  -- Per-parameter learning rates from utils; the helper's name suggests
  -- it zeroes rates for the first opt.freeze_params layers (confirm).
  local freezing = opt.freeze_params ~= 0
  if freezing then
    lrs, lrs_model = utils.get_learning_rates_for_freezing_layers(model, opt.freeze_params)
    optim_state.learningRates = lrs
  end

  dofile('train.lua')
end

-- Validation setup (val.lua presumably defines val_epoch()).
if not opt.no_val then
  dofile('val.lua')
end
--- Apply the epoch-dependent learning-rate / weight-decay schedule.
-- Each row of `regimes` is read as {first_epoch, last_epoch, lr, wd}.
-- Rows are scanned in order, so when several rows cover the same epoch,
-- the last matching row wins.
-- @param state optimizer state table (mutated in place)
-- @param regimes sequence of schedule rows
-- @param ep current epoch number
local function apply_regimes(state, regimes, ep)
  for _, row in ipairs(regimes) do
    if ep >= row[1] and ep <= row[2] then
      state.learningRate = row[3]
      state.weightDecay = row[4]
    end
  end
end

print('run')
for i = opt.begin_epoch, opt.n_epochs do
  -- `epoch` stays global: train.lua / val.lua presumably read it.
  epoch = i
  if not opt.no_train then
    -- Lua has no `None`; test for a configured schedule directly so an
    -- absent (nil) or disabled (false) opt.regimes is simply skipped.
    if opt.regimes then
      apply_regimes(optim_state, opt.regimes, epoch)
    end
    train_epoch()
  end
  if not opt.no_val then
    val_epoch()
  end
end
-- Optional final stage: load test_video.lua (expected to define the global
-- test_video function) and run it once training/validation are done.
local run_video_test = opt.test_video
if run_video_test then
  print('test video')
  dofile('test_video.lua')
  test_video()
end