-
Notifications
You must be signed in to change notification settings - Fork 0
/
DebugModule.lua
79 lines (64 loc) · 2.63 KB
/
DebugModule.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
require 'nn'
require 'image'
require 'strict'
-- Sink for unwanted return values. Declared as a local, presumably because
-- `require 'strict'` above rejects assignments to undeclared globals.
local _
--------------------------------------- DebugModule---------------------------------------
-- Pass-through (identity) module for debugging: prints, and optionally plots,
-- the data flowing through it in the forward (FW) and backward (BW) passes.
--TODO: log to a Logger (a special DebugLogging instance shared among all DebugModules), where the user specifies a function that aggregates input/gradOutput into a single number.
local DebugModule, parent = torch.class('myrock.DebugModule', 'nn.Module')
-- Help text handed to xlua.unpack in __init; currently only a placeholder.
local help_desc = [[todo]]
--- Construct a DebugModule.
-- Arguments are parsed by xlua.unpack against the specs below, e.g.
-- myrock.DebugModule{name='conv1', plot=true}.
-- @param config arguments for xlua.unpack:
--   name  (string, required) user identifier, printed with every report
--   plot  (boolean, default false) display tensors with image.display
--   print (boolean, default false) print the full data, not only its size
function DebugModule:__init(config)
   assert(config ~= nil,
      'myrock.DebugModule: missing configuration, e.g. myrock.DebugModule{name=...}')
   parent.__init(self)
   -- The first value returned by xlua.unpack is not needed; discard it into a
   -- function-local sink rather than a shared file-level variable.
   local _
   _, self.name, self.plot, self.print =
      xlua.unpack({config}, 'myrock.DebugModule', help_desc,
         {arg='name', type='string', help='User identifier of the module', req=true},
         {arg='plot', type='boolean', help='Plot yes/no', req=false, default=false},
         {arg='print', type='boolean', help='Full print yes/no', req=false, default=false}
      )
end
--- Forward pass of this pass-through module.
-- Reports `input` through displayData under the 'FW' tag, then forwards it.
-- @param input data coming from the previous module (must not be nil)
-- @return self.output, which is the very same object as `input` (no copy)
function DebugModule:updateOutput(input)
   assert(input ~= nil)
   local tag = 'FW'
   self:displayData(input, tag)
   self.output = input
   return self.output
end
--- Backward pass of this pass-through module.
-- Reports `gradOutput` through displayData under the 'BW' tag and hands it
-- back unchanged as the gradient with respect to the input.
-- @param input forward input (only checked for nil)
-- @param gradOutput gradient from the next module (must not be nil)
-- @return self.gradInput, which is the very same object as `gradOutput`
function DebugModule:updateGradInput(input, gradOutput)
   assert(input ~= nil and gradOutput ~= nil)
   local tag = 'BW'
   self:displayData(gradOutput, tag)
   self.gradInput = gradOutput
   return self.gradInput
end
--- Report one piece of data passing through the module.
-- Always prints a header line that includes a size description produced by
-- the `formatSizeStr` helper (defined outside this block).
-- When `self.print` is set, the data itself is printed.
-- When `self.plot` is set, non-empty tensors are shown with image.display.
-- @param input a tensor, a table (expected to hold tensors), or any other value
-- @param fwbw direction tag such as 'FW' or 'BW', used in headers and legends
-- @return input, unchanged
function DebugModule:displayData(input, fwbw)
   assert(input ~= nil and fwbw ~= nil)
   if torch.isTensor(input) then
      print('DebugModule ' .. self.name .. ': ' .. fwbw .. ' input size ' .. formatSizeStr(input))
      if self.print then
         print(input)
      end
      if self.plot then
         if input:dim() == 4 then
            -- One window per element along the first dimension.
            for i = 1, input:size(1) do
               image.display{image=input[i], legend=self.name .. ' /slice' .. i .. '_' .. fwbw}
            end
         elseif input:dim() > 0 then
            -- Empty tensors are skipped, matching the table branch below.
            image.display{image=input, legend=self.name .. '_' .. fwbw}
         end
      end
   elseif torch.type(input) == 'table' then
      print('DebugModule ' .. self.name .. ': ' .. fwbw)
      print(formatSizeStr(input))
      -- Honor the 'print' option for tables too, as it is for tensors.
      if self.print then
         print(input)
      end
      if self.plot then
         for i = 1, #input do
            local item = input[i]
            -- Only non-empty tensor entries can be plotted; skip anything else
            -- (e.g. nested tables) instead of failing on a missing :dim().
            if torch.isTensor(item) and item:dim() > 0 then
               image.display{image=item, legend=self.name .. '_' .. fwbw .. ' (' .. i .. ')'}
            end
         end
      end
   else
      print('DebugModule ' .. self.name .. ': ' .. fwbw .. ' unknown type ' .. torch.type(input))
   end
   return input
end