
revert LanguageModel
jcjohnson committed Mar 4, 2016
1 parent 806b734 commit cd8d0bc
Showing 1 changed file with 34 additions and 4 deletions.
38 changes: 34 additions & 4 deletions LanguageModel.lua
@@ -1,7 +1,6 @@
 require 'torch'
 require 'nn'
 
-require 'TemporalAdapter'
 require 'VanillaRNN'
 require 'LSTM'
 
@@ -29,8 +28,10 @@ function LM:__init(kwargs)
 
   local V, D, H = self.vocab_size, self.wordvec_dim, self.rnn_size
 
-  self.rnns = {}
   self.net = nn.Sequential()
+  self.rnns = {}
+  self.bn_view_in = {}
+  self.bn_view_out = {}
 
   self.net:add(nn.LookupTable(V, D))
   for i = 1, self.num_layers do
@@ -46,18 +47,47 @@ function LM:__init(kwargs)
     table.insert(self.rnns, rnn)
     self.net:add(rnn)
     if self.batchnorm == 1 then
-      self.net:add(nn.TemporalAdapter(nn.BatchNormalization(H)))
+      local view_in = nn.View(1, 1, -1):setNumInputDims(3)
+      table.insert(self.bn_view_in, view_in)
+      self.net:add(view_in)
+      self.net:add(nn.BatchNormalization(H))
+      local view_out = nn.View(1, -1):setNumInputDims(2)
+      table.insert(self.bn_view_out, view_out)
+      self.net:add(view_out)
     end
     if self.dropout > 0 then
       self.net:add(nn.Dropout(self.dropout))
     end
   end
 
-  self.net:add(nn.TemporalAdapter(nn.Linear(H, V)))
+  -- After all the RNNs run, we will have a tensor of shape (N, T, H);
+  -- we want to apply a 1D temporal convolution to predict scores for each
+  -- vocab element, giving a tensor of shape (N, T, V). Unfortunately
+  -- nn.TemporalConvolution is SUPER slow, so instead we will use a pair of
+  -- views (N, T, H) -> (NT, H) and (NT, V) -> (N, T, V) with a nn.Linear in
+  -- between. Unfortunately N and T can change on every minibatch, so we need
+  -- to set them in the forward pass.
+  self.view1 = nn.View(1, 1, -1):setNumInputDims(3)
+  self.view2 = nn.View(1, -1):setNumInputDims(2)
+
+  self.net:add(self.view1)
+  self.net:add(nn.Linear(H, V))
+  self.net:add(self.view2)
 end
 
 
 function LM:updateOutput(input)
+  local N, T = input:size(1), input:size(2)
+  self.view1:resetSize(N * T, -1)
+  self.view2:resetSize(N, T, -1)
+
+  for _, view_in in ipairs(self.bn_view_in) do
+    view_in:resetSize(N * T, -1)
+  end
+  for _, view_out in ipairs(self.bn_view_out) do
+    view_out:resetSize(N, T, -1)
+  end
+
   return self.net:forward(input)
 end


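The comment block restored by this commit describes the core trick: instead of a slow nn.TemporalConvolution (or the reverted nn.TemporalAdapter wrapper), the (N, T, H) output of the RNN stack is flattened to (N*T, H) by one nn.View, passed through an ordinary nn.Linear(H, V), and reshaped back to (N, T, V) by a second nn.View; the batch-normalization layers use the same pair of views around nn.BatchNormalization(H). The sketch below shows that reshaping pattern in isolation, assuming a working Torch/nn install; the sizes N, T, H, V and all variable names are illustrative only and are not part of this commit.

-- Flatten (N, T, H) -> (N*T, H), apply a Linear layer, then restore (N, T, V).
require 'torch'
require 'nn'

local N, T, H, V = 2, 3, 4, 5  -- made-up minibatch and layer sizes

local view1 = nn.View(1, 1, -1):setNumInputDims(3)  -- will map (N, T, H) -> (N*T, H)
local view2 = nn.View(1, -1):setNumInputDims(2)     -- will map (N*T, V) -> (N, T, V)

local net = nn.Sequential()
net:add(view1)
net:add(nn.Linear(H, V))
net:add(view2)

-- N and T can change on every minibatch, so the views are resized before
-- each forward pass, just as LM:updateOutput does in the diff above.
view1:resetSize(N * T, -1)
view2:resetSize(N, T, -1)

local x = torch.randn(N, T, H)
local scores = net:forward(x)
print(scores:size())  -- expect 2 x 3 x 5, i.e. (N, T, V)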