Skip to content

0x1DA9430/Decision-Transformer-vs-Decision-Mamba

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 
 
 
 
 
 
 
 
 

Repository files navigation

Decision Transformer vs. Decision Mamba: Analysing the Complexity of Sequential Decision Making in Atari Games

arXiv

Abstract

This work analyses the disparity in performance between Decision Transformer (DT) and Decision Mamba (DM) in sequence modelling reinforcement learning tasks for different Atari games. The study first observed that DM generally outperformed DT in the games Breakout and Qbert, while DT performed better in more complicated games, such as Hero and Kung Fu Master. To understand these differences, we expanded the number of games to 12 and performed a comprehensive analysis of game characteristics, including action space complexity, visual complexity, average trajectory length, and average steps to the first non-zero reward. In order to further analyse the key factors that impact the disparity in performance between DT and DM, we employ various approaches, including quantifying visual complexity, random forest regression, correlation analysis, and action space simplification strategies. The results indicate that the performance gap between DT and DM is affected by the complex interaction of multiple factors, with the complexity of the action space and visual complexity (particularly evaluated by compression ratio) being the primary determining factors. DM performs well in environments with simple action and visual elements, while DT shows an advantage in games with higher action and visual complexity. Our findings contribute to a deeper understanding of how the game characteristics affect the performance difference in sequential modelling reinforcement learning, potentially guiding the development of future model design and applications for diverse and complex environments.

The Decision Transformer (Left) and Decision Mamba (Right). $N$ represents normalization layers, activation function $\sigma$ stands for GELU (Gaussian Error Linear Unit), and $+$ are addition operations for skip connections.

Dependencies

conda create -n [env name] python=3.9

Activate environment and install requirements

conda activate [env name]
pip install -r requirements.txt

it may take 5 minutes, depends on the device

Or PyTorch for CUDA 12.1 (or CUDA 11.8)

conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia

Install mamba-ssm and causal-conv1d manually

pip install mamba-ssm --no-cache-dir
pip install causal-conv1d --no-cache-dir

To be specific: mamba-ssm==2.0.3 causal-conv1d==1.2.2.post1 newer versions may also work.

Upgrade charset_normalizer manually (if needed)

pip install --upgrade charset_normalizer

Dataset

Download dqn_replay

Create a directory for dataset

mkdir [DIRECTORY]

or use the existing directory /data/data_atari

Download the datasets using gsutil (you may install gsutil first).

gsutil -m cp -R gs://atari-replay-datasets/dqn/[GAME_NAME] [DIRECTORY]

e.g.

gsutil -m cp -R gs://atari-replay-datasets/dqn/Breakout ./msc-project/atari/data/data_atari

Download game ROMS

wget http://www.atarimania.com/roms/Roms.rar
unrar x Roms.rar

Load the ROMS

python -m atari_py.import_roms ROMS

Run a Minimal Experiment

Command Template

python train_atari.py \
        --game '[GAME NAME]' \
        --data_dir_prefix [PATH TO DATA DIR] \
        --context_length 5 \
        --n_layer 3 \
        --n_embd 8 \
        --token_mixer '[mamba for DM / attn for DT]' \
        --epochs 2 \
        --batch_size 64 \
        --num_steps 2 \
        --num_buffers 1 \
        --trajectories_per_buffer 5 \
        --output [PATH TO OUTPUT DIR] \
        --experiment [experiment name] \
        --seed 123

e.g.

python train_atari.py \
        --game 'Breakout' \
        --data_dir_prefix ./data/data_atari/ \
        --context_length 5 \
        --n_layer 3 \
        --n_embd 8 \
        --token_mixer 'mamba' \
        --epochs 2 \
        --batch_size 64 \
        --num_steps 5000 \
        --num_buffers 50 \
        --trajectories_per_buffer 5 \
        --output ./output/ \
        --experiment test_experiment \
        --seed 123

Run a Standard Experiment

  • run.sh is a script to run the standard experiments for Breakout, with 3 random seeds, context length 10.

Use action fusion for the game Hero/KungFuMaster

add --use_action_fusion as the flag at the end of the command

e.g.

python train_atari.py \
        --game 'Hero' \
        --data_dir_prefix ./data/data_atari/ \
        --context_length 5 \
        --n_layer 3 \
        --n_embd 8 \
        --token_mixer 'mamba' \
        --epochs 2 \
        --batch_size 64 \
        --num_steps 5000 \
        --num_buffers 50 \
        --trajectories_per_buffer 5 \
        --output ./output/ \
        --experiment test_experiment \
        --seed 123 \
        --use_action_fusion

Run complexity analysis for the games in the dataset

Random sampling the data for analysis:

python dataset_analyze_rand.py --game [GAME NAME] 

Only use the last 1% of the data for analysis:

python dataset_analyze_last_1p.py --game [GAME NAME] 

Analyze the outputs

  • Calculate the Normalized Scores

    • atari/output_analyze/analyse_output.ipynb
  • Visualize the Results

    • atari/output_analyze/plot_output.ipynb
  • Random Forest Regression, Correlation Analysis

    • atari/output_analyze/regression_analysis.ipynb