-
Notifications
You must be signed in to change notification settings - Fork 10
/
run.lua
61 lines (53 loc) · 1.78 KB
/
run.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
----------------------------------------------------------------------
-- Main script for training a model for semantic segmentation
--
-- Abhishek Chaurasia, Eugenio Culurciello
-- Sangpil Kim, Adam Paszke
-- Edited by Eren Golge
----------------------------------------------------------------------
require 'pl'
require 'nn'
require 'cudnn'
require 'cunn'
local opts = require 'opts'
local DataLoader = require 'data/dataloader'
torch.setdefaulttensortype('torch.FloatTensor')
----------------------------------------------------------------------
-- Get the input arguments parsed and stored in opt.
-- NOTE(review): `opt` is intentionally global -- presumably the required
-- train/test modules read it; confirm before making it local.
opt = opts.parse(arg)
print(opt)
--print(cutorch.getDeviceProperties(opt.devid))
-- `cutorch` is never required explicitly in this script; presumably it is
-- loaded as a global by `require 'cunn'` -- TODO confirm.
cutorch.setDevice(opt.devid)

-- Create the output folder (with parents). The path is single-quoted for the
-- shell so directories containing spaces or shell metacharacters work.
do
  local quoted = "'" .. (tostring(opt.save):gsub("'", "'\\''")) .. "'"
  local status = os.execute('mkdir -p ' .. quoted)
  -- os.execute returns 0 on success in Lua 5.1/LuaJIT, and true in Lua 5.2+.
  if status ~= 0 and status ~= true then
    error("could not create output folder " .. tostring(opt.save))
  end
end
-- Only report the folder once it actually exists.
print("Folder created at " .. opt.save)
----------------------------------------------------------------------
print '==> load modules'
-- data loading
local trainLoader, valLoader = DataLoader.create(opt)
-- Copy the class list exposed by the training loader into opt, so it is
-- included in the saved options below.
opt.classes = trainLoader.classes
-- save opt to file
print 'saving opt as txt and t7'
do
  -- Write "key : value" lines sorted by key so opt.txt is stable across runs
  -- (pairs() iteration order is unspecified).
  local keys = {}
  for k in pairs(opt) do
    keys[#keys + 1] = k
  end
  table.sort(keys, function(a, b) return tostring(a) < tostring(b) end)

  local filename = paths.concat(opt.save, 'opt.txt')
  -- Fail loudly (with the io.open error message) instead of crashing later
  -- on a nil file handle.
  local file = assert(io.open(filename, 'w'))
  for _, k in ipairs(keys) do
    file:write(tostring(k) .. ' : ' .. tostring(opt[k]) .. '\n')
  end
  file:close()
end
-- Use torch `paths` here too, consistent with opt.txt above.
torch.save(paths.concat(opt.save, 'opt.t7'), opt)
----------------------------------------------------------------------
print '==> training!'
-- create model
-- NOTE(review): `t` is deliberately global -- presumably the `train` module
-- reads the model/loss from it when it is required, so this dofile must run
-- before require 'train'. Confirm before making it local.
t = paths.dofile("models/"..opt.model..".lua")
local train = require 'train'
local test = require 'test'
-- Sentinel so the first validation error is always the best so far.
-- (Assumes test() compares errors with `<` and returns the updated best;
-- TODO confirm against test.lua.)
local besterr = math.huge
-- Run epochs 1..opt.maxepoch inclusive, i.e. exactly opt.maxepoch epochs.
for epoch = 1, opt.maxepoch do
  print("----- epoch # " .. epoch)
  local trainConf, model, loss = train(trainLoader, epoch)
  besterr = test(valLoader, epoch, trainConf, model, loss, besterr)
  -- Reclaim per-epoch garbage (tensors, closures) before the next epoch.
  collectgarbage()
end