Vanilla DQN, Double DQN, and Dueling DQN implemented in PyTorch

dxyang/DQN_pytorch

Vanilla DQN, Double DQN, and Dueling DQN in PyTorch

Description

This repo is a PyTorch implementation of Vanilla DQN, Double DQN, and Dueling DQN, based on the papers cited in the Background section.

The starter code comes from Berkeley CS 294 Assignment 3 and was modified for PyTorch with some guidance from here. TensorBoard logging has also been added (thanks here) for visualization during training, in addition to what the Gym Monitor already does.

Background

Deep Q-networks use neural networks as function approximators for the action-value function, Q. The architecture used here takes frames from the Atari simulator as input (i.e., the state) and passes them through two convolutional layers and two fully connected layers before outputting a Q value for each action.
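As a rough illustration of the network shape described above, here is a minimal sketch with two convolutional layers followed by two fully connected layers. The class name `QNetwork`, the filter counts, kernel sizes, hidden width, the 4-frame 84x84 input, and the 6-action output are all assumptions chosen for illustration; they are not taken from this repo's code.

```python
import torch
import torch.nn as nn


class QNetwork(nn.Module):
    """Two conv layers + two fully connected layers, one Q value per action."""

    def __init__(self, in_channels: int = 4, num_actions: int = 6):
        super().__init__()
        # Layer sizes below are illustrative assumptions, not this repo's values.
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=8, stride=4),  # 84x84 -> 20x20
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),  # 20x20 -> 9x9
            nn.ReLU(),
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 9 * 9, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions),
        )

    def forward(self, frames: torch.Tensor) -> torch.Tensor:
        # frames: (batch, in_channels, 84, 84) stacked, preprocessed frames
        return self.head(self.features(frames))


q_net = QNetwork()
q_values = q_net(torch.zeros(2, 4, 84, 84))  # one row of Q values per state
greedy_actions = q_values.argmax(dim=1)      # greedy action for each state
```

The greedy action is simply the index of the largest Q value in each row; an epsilon-greedy policy would occasionally replace it with a random action.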

Human-level control through deep reinforcement learning introduced an experience replay buffer that stores past observations and samples them as training input, which reduces correlations between data samples. The authors also used a separate target network, whose weights are a snapshot of the main Q network from a past time step, for calculating the target Q value. These weights are periodically updated to match the latest weights of the main Q network, which reduces the correlation between the target and current Q values. With target-network weights $\theta^-$ and discount factor $\gamma$, the Q target is calculated as below.

$$y = r + \gamma \max_{a'} Q(s', a'; \theta^-)$$
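The target computation and the periodic weight sync described above can be sketched in a few lines of PyTorch. This is only an illustration under stated assumptions: the function name `dqn_target`, the `dones` mask that zeroes the bootstrap term on terminal transitions, the `gamma` default, and the tiny linear stand-in networks are hypothetical and not taken from this repo.

```python
import copy

import torch
import torch.nn as nn


def dqn_target(target_net, rewards, next_obs, dones, gamma=0.99):
    """Vanilla DQN target: reward plus discounted max Q from the target network."""
    with torch.no_grad():  # targets are treated as constants during the update
        next_q = target_net(next_obs).max(dim=1).values
    # dones is 1.0 for terminal transitions (assumed convention), so they do not bootstrap
    return rewards + gamma * (1.0 - dones) * next_q


torch.manual_seed(0)
q_net = nn.Linear(4, 3)            # stand-in for the main Q network
target_net = copy.deepcopy(q_net)  # frozen snapshot of past weights

rewards = torch.tensor([1.0, 0.0])
next_obs = torch.randn(2, 4)
dones = torch.tensor([0.0, 1.0])
y = dqn_target(target_net, rewards, next_obs, dones)

# Periodic sync: copy the latest main-network weights into the target network.
target_net.load_state_dict(q_net.state_dict())
```

In a training loop, the sync line would run only every fixed number of updates, so the target stays stable between syncs.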

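Noting that vanilla DQN can overestimate action values, Deep Reinforcement Learning with Double Q-learning proposes an alternative Q target that selects actions by taking the argmax of the current Q network evaluated on the next observations. These actions, together with the next observations, are passed into the frozen target network to yield the Q values used at each update. This new Q target is shown below, where $\theta$ denotes the weights of the current Q network and $\theta^-$ those of the target network.

$$y = r + \gamma \, Q\left(s', \arg\max_{a'} Q(s', a'; \theta); \theta^-\right)$$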
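A minimal sketch of the Double DQN target described above: the current network picks the action, and the target network scores it. The function name `double_dqn_target`, the `dones` mask, the `gamma` default, and the linear stand-in networks are assumptions for illustration, not this repo's actual code.

```python
import torch
import torch.nn as nn


def double_dqn_target(q_net, target_net, rewards, next_obs, dones, gamma=0.99):
    """Double DQN target: select with the current net, evaluate with the target net."""
    with torch.no_grad():
        # Action selection uses the current (online) Q network.
        best_actions = q_net(next_obs).argmax(dim=1, keepdim=True)
        # Action evaluation uses the frozen target network.
        next_q = target_net(next_obs).gather(1, best_actions).squeeze(1)
    # dones is 1.0 for terminal transitions (assumed convention)
    return rewards + gamma * (1.0 - dones) * next_q


torch.manual_seed(0)
q_net = nn.Linear(4, 3)       # stand-in for the current Q network
target_net = nn.Linear(4, 3)  # stand-in for the frozen target network

rewards = torch.tensor([1.0, 0.5])
next_obs = torch.randn(2, 4)
dones = torch.tensor([0.0, 0.0])
y = double_dqn_target(q_net, target_net, rewards, next_obs, dones)
```

Because the evaluated action is not necessarily the target network's own maximizer, this target can never exceed the vanilla DQN target computed from the same target network, which is how it counteracts overestimation.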

Finally, Dueling Network Architectures for Deep Reinforcement Learning proposes a different architecture for approximating Q functions. After the last convolutional layer, the output is split into two streams that separately estimate the state value and the advantage of each action within the state. These two estimates are then combined to generate a Q value through the equation below, the mean-subtracted aggregation from the paper, where $|\mathcal{A}|$ is the number of actions.

$$Q(s, a) = V(s) + \left( A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a') \right)$$

The architecture is also shown here in contrast to traditional Deep Q-Learning networks.
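The two-stream split can be sketched as a small PyTorch module that sits on top of the flattened convolutional features. The class name `DuelingHead`, the feature and hidden sizes, and the choice of the mean-subtracted aggregation (one of the combinations discussed in the dueling paper) are assumptions for illustration; the repo's actual layers may differ.

```python
import torch
import torch.nn as nn


class DuelingHead(nn.Module):
    """Separate value and advantage streams, combined into Q values."""

    def __init__(self, feature_dim: int, num_actions: int, hidden: int = 256):
        super().__init__()
        # Stream sizes are illustrative assumptions.
        self.value = nn.Sequential(
            nn.Linear(feature_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1)
        )
        self.advantage = nn.Sequential(
            nn.Linear(feature_dim, hidden), nn.ReLU(), nn.Linear(hidden, num_actions)
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        v = self.value(features)      # (batch, 1) state value
        a = self.advantage(features)  # (batch, num_actions) advantages
        # Subtract the mean advantage so that V and A are identifiable.
        return v + a - a.mean(dim=1, keepdim=True)


torch.manual_seed(0)
head = DuelingHead(feature_dim=32, num_actions=6)
features = torch.randn(2, 32)  # stand-in for flattened conv features
q_values = head(features)
```

With this aggregation, the mean of the Q values over actions equals the value stream's output, which is a quick sanity check when wiring up the two streams.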

Dependencies

Usage

  • Execute the following command to train a model with vanilla DQN:
$ python main.py train --task-id $TASK_ID

The task ID selects one of the following environments from the Atari40M spec:

  • 0: BeamRider
  • 1: Breakout
  • 2: Enduro
  • 3: Pong
  • 4: Qbert
  • 5: Seaquest
  • 6: SpaceInvaders

Here are some options that you can use:

  • --gpu: id of the GPU you want to use (if not specified, will train on CPU)
  • --double-dqn: 1 to train with double DQN, 0 for vanilla DQN
  • --dueling-dqn: 1 to train with dueling DQN, 0 for vanilla DQN

Results

SpaceInvaders

Sample gameplay

Pong

Sample gameplay

Breakout

Sample gameplay
