diff --git a/async/QAgent.lua b/async/QAgent.lua index ce60afb..97ca98d 100644 --- a/async/QAgent.lua +++ b/async/QAgent.lua @@ -20,7 +20,6 @@ function QAgent:_init(opt, policyNet, targetNet, theta, targetTheta, atomic, sha self.dTheta:zero() self.doubleQ = opt.doubleQ - self.recurrent = opt.recurrent self.epsilonStart = opt.epsilonStart self.epsilon = self.epsilonStart @@ -34,7 +33,8 @@ function QAgent:_init(opt, policyNet, targetNet, theta, targetTheta, atomic, sha self.tic = 0 self.step = 0 - self.alwaysComputeGreedyQ = not self.doubleQ + -- Forward state anyway if recurrent + self.alwaysComputeGreedyQ = opt.recurrent or not self.doubleQ self.QCurr = torch.Tensor(0) end @@ -61,10 +61,6 @@ function QAgent:eGreedy(state, net) end if torch.uniform() < self.epsilon then - -- Forward state anyway if recurrent - if self.recurrent then - net:forward(state) - end return torch.random(1,self.m) end