Recurrent Dqn #8

Open
lake4790k opened this issue Apr 21, 2016 · 31 comments

@lake4790k
Collaborator

One central element of the Atari DQN is the use of 4 consecutive frames as input, making the state more Markov, i.e. carrying the vital dynamic movement information. This paper http://arxiv.org/abs/1507.06527v3 discusses DRQN: the multi-frame input can be replaced with an LSTM to the same effect (though with no systematic advantage for either approach). The DeepMind async paper also mentions using an LSTM instead of multi-frame inputs for more challenging visual domains (TORCS and Labyrinth).

I think this would fit well in this codebase; I'll try to contribute this at some point.

@Kaixhin
Owner

Kaixhin commented Apr 21, 2016

Yep, a switch for using a DRQN architecture would be great. For now I'd go for using histLen as the number of frames to run BPTT over for a single-frame DRQN. It would be good to base it on the rnn library, especially since it now has the optimised SeqLSTM.
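For concreteness, here's a minimal sketch of what a single-frame DRQN head could look like on top of the rnn library (the conv trunk and layer sizes are illustrative placeholders for a 1x24x24 input, not the actual model in this repo):

```lua
-- Illustrative single-frame DRQN sketch using rnn's Sequencer/SeqLSTM.
-- Sizes assume a 1x24x24 input (e.g. Catch); this is not the repo's real model.
require 'nn'
require 'rnn'

local histLen, nActions = 4, 3

-- Per-frame trunk, applied independently to each of the histLen frames
local trunk = nn.Sequential()
trunk:add(nn.SpatialConvolution(1, 32, 5, 5, 2, 2))  -- 1x24x24 -> 32x10x10
trunk:add(nn.ReLU(true))
trunk:add(nn.View(32 * 10 * 10))
trunk:add(nn.Linear(32 * 10 * 10, 256))
trunk:add(nn.ReLU(true))

-- Recurrent head: SeqLSTM carries state across the histLen steps used for BPTT
local drqn = nn.Sequential()
drqn:add(nn.Sequencer(trunk))      -- input: histLen x batch x 1 x 24 x 24
drqn:add(nn.SeqLSTM(256, 256))     -- input/output: histLen x batch x 256
drqn:add(nn.Select(1, histLen))    -- Q-values from the last time step only
drqn:add(nn.Linear(256, nActions))
```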

@lake4790k
Collaborator Author

This is the Caffe implementation from the paper:
https://github.com/mhauskn/dqn/tree/recurrent

Although I've never looked at Caffe, it will probably help.

@lake4790k
Collaborator Author

@Kaixhin I see you started working on this, cool. I'll have some time now, so I'll look at the multigpu and async modes.

@Kaixhin
Owner

Kaixhin commented May 2, 2016

@lake4790k Almost have something working. Disabling this line lets the DRQN train, as otherwise it crashes here, somehow propagating a batch of size 20 forward but expecting the normal batch size of 32 backwards.

I'm new to the rnn library, so let me know if you have any ideas. Performance is considerably slower, which will be due to having to process several time steps sequentially. This is in line with Appendix B in that paper though.

@lake4790k
Collaborator Author

@Kaixhin Awesome! I have no experience with rnn either; I'll need to study it to get an idea. I have two 980 Tis and will be able to run longer experiments to see if it goes anywhere.

@Kaixhin
Owner

Kaixhin commented May 4, 2016

@lake4790k I'd have to delve into the original paper/code, but it looks like they train the network every step (as opposed to every 4). This seems like it'll be a problem for BPTT. In any case, if you haven't used rnn before, I'll focus on this.

@lake4790k
Collaborator Author

@Kaixhin Cool, I'll have my hands full with async for now, but in the meantime I'll be able to help by running longer DRQN experiments on my workstation when you think it's worth trying.

@Kaixhin
Owner

Kaixhin commented May 12, 2016

Here's the result of running ./run.sh demo -recurrent true, so I'm reasonably confident that the DRQN is capable of learning, but I'm not testing this further for now so I'm leaving this issue open. In any case, I still haven't solved this issue (which I mentioned above).

[scores plot]

@Kaixhin
Owner

Kaixhin commented May 19, 2016

Pinging @JoostvDoorn since he's contributed to rnn and may have ideas about the minibatch problem, performance improvements, and whether it's possible to save and restore state before and after training (and whether that should be done, since the parameters will have changed slightly).

@JoostvDoorn
Contributor

@Kaixhin I will have a look later.

@lake4790k
Collaborator Author

@Kaixhin I'm not getting the error you mentioned when validation runs on the last batch of size 20 when running demo. I'm using the master code, which has sequencer:remember('both') enabled. You mentioned you had to disable that to avoid the crash...? master runs fine for me as it is.

@JoostvDoorn
Contributor

I think this is in the rnn branch. This may or may not be a bug when using FastLSTM with the nngraph version. Setting nn.FastLSTM.usenngraph = false changed the error for me, but I only got the chance to look at this for a moment.

@lake4790k
Collaborator Author

lake4790k commented May 23, 2016

OK, so there are two issues:

  1. nn.FastLSTM.usenngraph = true
    nngraph/gmodule.lua:335: split(4) cannot split 32 outputs
    This is an issue in both rnn and master.
  2. nn.FastLSTM.usenngraph = false
    Wrong size for view. Input size: 20x1x3. Output size: 32x3
    This is only in rnn, because @Kaixhin fixed Agent.valMemory and validate() #16 in master (but not in rnn), which returns before doing the backward pass during validation (since it isn't needed there), so maybe it's not an issue after all?

@Kaixhin
Owner

Kaixhin commented May 23, 2016

  1. With nn.FastLSTM.usenngraph = true, I get the same error as @lake4790k. This seems to be an incompatibility with nngraph's gmodule in evaluate mode (Element-Research/rnn#172). Which is a shame, as apparently it's significantly faster with this flag enabled (see Slow LSTM speed relative to theano, Element-Research/rnn#182).
  2. Yes, so if you remove the return on line 374 in master then it fails, so I consider this a bug, albeit one that is being hidden by that return. Why is this occurring even when states is 20x4x1x24x24 and QCurr is 20x1x3? If the error depends on previous batches then the learning must be incorrect. I was wrong: removing sequencer:remember('both') doesn't stop the crash.

@lake4790k
Collaborator Author

@Kaixhin Re 2: agreed, this error is bad, so returning early is not a solution. I'm not sure whether learning is wrong with the normal batch sizes; it could just be that a batch size change isn't handled properly somewhere. I tried an isolated FastLSTM+Sequencer net, and there switching batch sizes worked fine, which is weird. I'm looking at adding LSTM to async; once I get that working I'll experiment with this further.
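For reference, this is roughly the kind of isolated test I mean (sizes are arbitrary; with the default Sequencer behaviour the state is forgotten between forward calls, which may be why it doesn't reproduce the crash):

```lua
-- Standalone FastLSTM + Sequencer check with consecutive batches of different
-- sizes; sizes here are arbitrary and this is not the repo's actual setup.
require 'rnn'

local seqLen, featSize = 4, 10
local net = nn.Sequencer(nn.FastLSTM(featSize, featSize))

for _, batchSize in ipairs({32, 20}) do
  local input = torch.rand(seqLen, batchSize, featSize)      -- seqLen x batch x features
  local gradOutput = torch.rand(seqLen, batchSize, featSize)
  net:forward(input)
  net:backward(input, gradOutput)
end
```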

@Kaixhin
Owner

Kaixhin commented May 24, 2016

@lake4790k I also tried a simple FastLSTM + Sequencer net with different batch sizes - no problem. I agree that it's likely some module isn't switching its internal variables to the correct size, but finding out exactly where the problem lies is tricky. It may be that I haven't set up the recurrency correctly, but apart from this batch size issue it seems to work fine.

@lake4790k
Collaborator Author

lake4790k commented May 24, 2016

@Kaixhin I need to refresh async from master for the recurrent work; should I do a merge or a rebase (I'm leaning towards a merge)? Does it even matter when merging back from async to master eventually?

@Kaixhin
Owner

Kaixhin commented May 24, 2016

@lake4790k I'd go with a merge since it preserves history correctly. It's better to make sure all the changes in master are integrated sooner rather than later.

@lake4790k
Collaborator Author

lake4790k commented May 24, 2016

I've done the merge and added recurrent support for 1-step Q in async. This is 7 minutes of training; it seems to work well:

[scores plot]

The agent sees only the latest frame per step; backpropagation unrolls 5 steps, weights are updated every 5 steps (or on terminal), and no Sequencer is needed in this algorithm. I used sharedRmsProp and kept the ReLU after the FastLSTM to have a setup comparable to my usual async testing. A rough sketch of the update loop is below.
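Roughly, the update loop looks like this (a schematic sketch: policyNet, targetNet, epsilonGreedy and sharedRmsProp are placeholder names, not the repo's actual API):

```lua
-- Schematic of the recurrent 1-step Q update in async; placeholder names.
-- The net ends in a FastLSTM, so only the latest frame is forwarded each step
-- and the stored steps are backpropagated in reverse at update time.
local updateFreq, gamma = 5, 0.99
local observations, gradQs = {}, {}

policyNet:forget()                      -- clear the LSTM state at episode start
local observation = env:start()
local step, terminal = 0, false
repeat
  step = step + 1
  local QValues = policyNet:forward(observation)    -- single frame in
  local action = epsilonGreedy(QValues)             -- hypothetical helper
  local reward, nextObservation
  reward, nextObservation, terminal = env:step(action)

  -- 1-step Q-learning target and gradient w.r.t. the chosen action's Q-value
  local target = reward
  if not terminal then
    target = target + gamma * targetNet:forward(nextObservation):max()
  end
  local gradQ = QValues:clone():zero()
  gradQ[action] = QValues[action] - target

  observations[#observations + 1] = observation
  gradQs[#gradQs + 1] = gradQ

  if step % updateFreq == 0 or terminal then
    -- BPTT: backward through the stored steps in reverse order
    for t = #observations, 1, -1 do
      policyNet:backward(observations[t], gradQs[t])
    end
    sharedRmsProp(policyNet)            -- hypothetical optimiser step
    policyNet:zeroGradParameters()
    observations, gradQs = {}, {}       -- handling of hidden state across updates is simplified here
  end

  observation = nextObservation
until terminal
```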

Pretty cool that it works. I'll check whether it performs similarly with a flickering Catch, as they did in the paper with flickering Pong. Also, in the async paper they added a half-size LSTM layer after the linear layer instead of replacing it; I'll try that as well (although the DRQN paper says replacing is best).

I'll add support for the n-step methods as well. There it's a bit trickier to get right, as steps are taken forwards and backwards to calculate the n-step returns, so I'll have to make sure the forward/backward passes are correct for the LSTM too.

@lake4790k
Collaborator Author

I also tried replacing the FastLSTM with a GRU with everything else the same; interestingly, that did not converge even after running it for longer.

@JoostvDoorn
Contributor

@lake4790k Do you have the flickering catch version somewhere?

@lake4790k
Collaborator Author

@JoostvDoorn I haven't got around to it yet, but it probably only takes a few lines to add to rlenvs.

@Kaixhin
Owner

Kaixhin commented Jun 23, 2016

@JoostvDoorn I can add that to rlenvs.Catch if you want. You may also be interested in the obscured option I set up, which blanks a strip of the screen at the bottom so that the agent has to infer the motion of the ball properly. It's quick to test by adding opt.obscured = true in Setup.lua.

@Kaixhin
Owner

Kaixhin commented Jun 23, 2016

@JoostvDoorn Done. Just get the latest version of rlenvs and this repo. -flickering is a probability between 0 and 1 of the screen blanking out.
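For example, something like ./run.sh demo -recurrent true -flickering 0.5 should give a 50% chance of a blanked screen each step (assuming the two options can be combined like this).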

@JoostvDoorn
Contributor

@Kaixhin Great thanks.

Have you tried storing the state instead of calling forget at every time step? I am doing this now; it takes longer to train, but it will probably converge. I agree this has to do with the changing state distribution, but we can't really let the agent explore without considering the history if we want to take full advantage of the LSTM.
[scores4 plot]

@Kaixhin
Owner

Kaixhin commented Jun 26, 2016

@JoostvDoorn I thought that this line would actually set remember for all internal modules, but I'm not certain. If that's not the case then yes, I agree it should be set on the LSTM units themselves.

To sum up, in Agent:observe the only place that forget is called is at a terminal state. Of course, when learning it should call forget before passing the minibatch through, and after learning as well. This means that memSampleFreq is the maximum amount of history the LSTMs keep during training, but they receive the entire history during validation/evaluation. Schematically, it would look like the sketch below.
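(Placeholder structure only, not the actual Agent code; learnMinibatch is a hypothetical helper.)

```lua
-- Schematic only: where forget() fits relative to learning.
function Agent:learn()
  self.policyNet:forget()             -- drop the online-interaction state before the minibatch
  local loss = self:learnMinibatch()  -- hypothetical: BPTT over histLen-step sequences
  self.policyNet:forget()             -- drop the minibatch state before resuming interaction
  return loss
end
-- Elsewhere, in Agent:observe, forget() is only called on terminal states.
```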

@JoostvDoorn
Contributor

@Kaixhin Yes that line is enough, I will change that in my pull request.

I missed memSampleFreq, so I assumed it was calling forget every time. I guess memSampleFreq >= histLen is a good thing here, so that training and updating have a similar distribution. Do note though that the 5th action will be updated based on the 2nd, 3rd, 4th, and 5th states in the Q-learning update, while the policy followed will only be based on the 5th state, right?

@Kaixhin
Owner

Kaixhin commented Jun 26, 2016

@JoostvDoorn Yep, memSampleFreq >= histLen would be sensible. Sorry, I'm not sure I understand your last question though. During learning updates for recurrent networks, histLen is used to determine the sequence length of states fed in (there is no concatenating of frames in time as with a normal DQN). During training the hidden state will go back until the last time forget was called (and forget is called every memSampleFreq steps).

@JoostvDoorn
Contributor

I see it like this: forget is called at the first time step, so the LSTM will not have accumulated any information at this point; once here, it starts accumulating state information (note, however, that when torch.uniform() < epsilon we don't accumulate info, which is a bug). Then, after calling Agent:learn, we call forget again. Once the episode continues and reaches the point here, the state information is the same as at the start of the episode, which, depending on the environment, is a problem.
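Concretely, the exploratory branch skips the forward pass, so something along these lines would fix it (schematic, not the actual Agent:observe code; self.epsilon and self.m are placeholders):

```lua
-- Schematic fix: always advance the LSTM state, even when acting randomly.
local QValues = self.policyNet:forward(state)   -- accumulates recurrent state every step
local action
if torch.uniform() < self.epsilon then
  action = torch.random(1, self.m)              -- self.m: number of actions (placeholder)
else
  local _, maxIdx = QValues:max(1)
  action = maxIdx[1]
end
```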

@Kaixhin
Owner

Kaixhin commented Jun 26, 2016

Thanks for spotting the bug. @lake4790k please check 626712b to make sure async agents are accounted for as well.

@JoostvDoorn If I understand correctly, the issue is that the agent can't retain information during training because observe is interspersed with forget calls during learn? That's what I was wondering about above. My reasoning comes from the rnn docs. Also, it would be prohibitive to keep old states from before learn and pass them all through the network before starting again.

@lake4790k
Collaborator Author

@Kaixhin Yes, this is needed for async; I just created #47 to do it a bit differently.
