*Demo clips: Breakout | Pong | Space Invaders | Ms Pacman*
This repository contains a PyTorch implementation of the Deep Q-Network (DQN) algorithm for playing Atari games. The implementation is based on the original paper by Mnih et al. (2015) and contains the following extensions:
- Double Q-Learning (van Hasselt et al., 2015)
- Dueling Network Architectures (Wang et al., 2016)
- Prioritized Experience Replay (Schaul et al., 2016)
- Vectorized environment for parallel training
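As background on the Double Q-Learning extension: it decouples action *selection* from action *evaluation*, using the online network to pick the greedy next action and the target network to score it, which reduces the overestimation bias of vanilla DQN. A minimal sketch of the target computation (function and network names here are illustrative, not this repository's internals):

```python
import torch

def double_dqn_target(online_net, target_net, rewards, next_states, dones, gamma=0.99):
    """Compute Double DQN targets: the online network selects the greedy
    next action, the target network evaluates it."""
    with torch.no_grad():
        # Online network chooses the greedy next action...
        next_actions = online_net(next_states).argmax(dim=1, keepdim=True)
        # ...but the target network provides the value estimate for it.
        next_q = target_net(next_states).gather(1, next_actions).squeeze(1)
        # Bootstrap only on non-terminal transitions.
        return rewards + gamma * (1.0 - dones) * next_q
```

With `double_dqn=False` the same update would instead take `target_net(next_states).max(dim=1)` for both selection and evaluation.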
It is recommended to install the dependencies in a virtual environment, preferably with conda. The following commands create a new environment and install the required packages using Poetry.

```shell
conda create -n dqn python=3.10
conda activate dqn
pip install poetry
poetry install
```
The DQN agent can be trained from the command line via Hydra. The following command trains a DQN agent to play the Breakout Atari game.

```shell
python src/dqn_atari/main.py model.env_id=BreakoutNoFrameskip-v4
```

Run `python src/dqn_atari/main.py --help` to see all available options.
The following code snippet demonstrates how to train a DQN agent to play the Breakout Atari game.

```python
from dqn_atari import DQN
from dqn_atari import PrioritizedReplayBuffer

# Initialize the DQN agent
dqn_model = DQN(
    'BreakoutNoFrameskip-v4',
    num_envs=8,  # or 1 for a single environment
    double_dqn=True,
    dueling=True,
    layers=[64, 64],
    buffer_class=PrioritizedReplayBuffer,
    buffer_size=100_000,
    batch_size=32,
    lr=2e-5,
    force_cpu=False,
)

# Train the agent
dqn_model.train(
    training_steps=1_000_000,
    eval_every=10_000,  # evaluate the agent every 10,000 steps
    eval_runs=30,
)

# It is also possible to evaluate the agent directly
reward = dqn_model.evaluate()
print(f'Reward: {reward}')

# Save and load the model
dqn_model.save('my/folder/breakout.pt')
dqn_model = DQN.load('my/folder/breakout.pt')

# Continue training the loaded model
dqn_model.train(training_steps=1_000_000)
```
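As background on the `PrioritizedReplayBuffer` option: prioritized experience replay (Schaul et al., 2016) samples transitions with probability proportional to their TD-error priority raised to an exponent `alpha`, and corrects the resulting sampling bias with importance-sampling weights controlled by `beta`. A minimal sketch of the sampling math (standalone functions for illustration, not this repository's buffer implementation):

```python
import numpy as np

def per_sample_probs(priorities, alpha=0.6):
    """Sampling probabilities P(i) = p_i^alpha / sum_k p_k^alpha.
    alpha=0 recovers uniform sampling; alpha=1 is fully prioritized."""
    scaled = np.asarray(priorities, dtype=np.float64) ** alpha
    return scaled / scaled.sum()

def importance_weights(probs, beta=0.4):
    """Importance-sampling weights w_i = (N * P(i))^-beta,
    normalized by the maximum weight for stability."""
    n = len(probs)
    w = (n * np.asarray(probs, dtype=np.float64)) ** (-beta)
    return w / w.max()
```

In practice `beta` is typically annealed toward 1 over training so that the bias correction is exact by the end.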