forked from Element-Research/dpnn
-
Notifications
You must be signed in to change notification settings - Fork 9
/
Sequential.lua
98 lines (88 loc) · 4.11 KB
/
Sequential.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
-- Class table extended by this file, and nn.Container, which is used below
-- for super-calls such as `parent.profile(self)`.
local Sequential = nn.Sequential
local parent = nn.Container
--- Put this Sequential into profiling mode.
--
-- Redefines updateOutput, updateGradInput, accGradParameters, backward and
-- accUpdateGradParameters on the shared `Sequential` class table (not only on
-- `self`), so every object that looks its methods up there is affected. The
-- replacements keep the normal container data flow but time each child-module
-- call and print one line per call:
--
--    <torch.type(module)> <method>: <seconds> s
--
-- Finally delegates to `parent.profile(self)`; presumably that recurses into
-- the child modules (NOTE(review): confirm against Container.lua).
function Sequential:profile()

   -- Invoke `module[method](module, ...)`, print its elapsed wall-clock time
   -- and return the call's first result.
   -- When the `cutorch` global is present the device is synchronized both
   -- before the timer starts and before it is read. GPU kernels run
   -- asynchronously, so without the leading sync any work still queued by
   -- earlier code would be billed to the first module timed in each pass.
   local function timedCall(module, method, ...)
      if cutorch then cutorch.synchronize() end
      local start = torch.Timer()
      local result = module[method](module, ...)
      if cutorch then cutorch.synchronize() end
      print(torch.type(module)..' '..method..': '..start:time().real.." s")
      return result
   end

   -- Forward pass: feed each module's output into the next one.
   function Sequential:updateOutput(input)
      local currentOutput = input
      for i = 1, #self.modules do
         currentOutput = timedCall(self.modules[i], 'updateOutput', currentOutput)
      end
      self.output = currentOutput
      return currentOutput
   end

   -- The backward-style passes below walk the modules last-to-first. Every
   -- module except the first receives the previous module's `output` as its
   -- input; the first module receives the container's `input`.

   function Sequential:updateGradInput(input, gradOutput)
      local currentGradOutput = gradOutput
      local currentModule = self.modules[#self.modules]
      for i = #self.modules - 1, 1, -1 do
         local previousModule = self.modules[i]
         currentGradOutput = timedCall(currentModule, 'updateGradInput',
                                       previousModule.output, currentGradOutput)
         currentModule = previousModule
      end
      currentGradOutput = timedCall(currentModule, 'updateGradInput',
                                    input, currentGradOutput)
      self.gradInput = currentGradOutput
      return currentGradOutput
   end

   function Sequential:accGradParameters(input, gradOutput, scale)
      scale = scale or 1
      local currentGradOutput = gradOutput
      local currentModule = self.modules[#self.modules]
      for i = #self.modules - 1, 1, -1 do
         local previousModule = self.modules[i]
         timedCall(currentModule, 'accGradParameters',
                   previousModule.output, currentGradOutput, scale)
         -- The call's return value is not used; the gradient handed to the
         -- next module is the module's stored gradInput (presumably filled by
         -- a preceding updateGradInput call).
         currentGradOutput = currentModule.gradInput
         currentModule = previousModule
      end
      timedCall(currentModule, 'accGradParameters',
                input, currentGradOutput, scale)
   end

   function Sequential:backward(input, gradOutput, scale)
      scale = scale or 1
      local currentGradOutput = gradOutput
      local currentModule = self.modules[#self.modules]
      for i = #self.modules - 1, 1, -1 do
         local previousModule = self.modules[i]
         currentGradOutput = timedCall(currentModule, 'backward',
                                       previousModule.output, currentGradOutput, scale)
         currentModule.gradInput = currentGradOutput
         currentModule = previousModule
      end
      currentGradOutput = timedCall(currentModule, 'backward',
                                    input, currentGradOutput, scale)
      self.gradInput = currentGradOutput
      return currentGradOutput
   end

   function Sequential:accUpdateGradParameters(input, gradOutput, lr)
      local currentGradOutput = gradOutput
      local currentModule = self.modules[#self.modules]
      for i = #self.modules - 1, 1, -1 do
         local previousModule = self.modules[i]
         timedCall(currentModule, 'accUpdateGradParameters',
                   previousModule.output, currentGradOutput, lr)
         -- As in accGradParameters: read the stored gradInput, not a return value.
         currentGradOutput = currentModule.gradInput
         currentModule = previousModule
      end
      timedCall(currentModule, 'accUpdateGradParameters',
                input, currentGradOutput, lr)
   end

   parent.profile(self)
end