Lior Cohen, Kaixin Wang, Bingyi Kang, Shie Mannor
Paper: Improving Token-Based World Models with Parallel Observation Prediction.
If you find this code useful, please cite it in your paper:
```bibtex
@inproceedings{
    cohen2024improving,
    title={Improving Token-Based World Models with Parallel Observation Prediction},
    author={Lior Cohen and Kaixin Wang and Bingyi Kang and Shie Mannor},
    booktitle={Forty-first International Conference on Machine Learning},
    year={2024},
    url={https://openreview.net/forum?id=Lfp5Dk1xb6}
}
```
- Python 3.10
- Install PyTorch (`torch` and `torchvision`). The code was developed with `torch==1.13.1` and `torchvision==0.14.0`, but was also tested with `torch==2.2.0`.
- Install the other dependencies (a consolidated setup sketch follows this list): `pip install -r requirements.txt`
- Warning: the Atari ROMs will be downloaded with the dependencies, so by installing them you acknowledge that you have a license to use them.
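Putting the steps above together, a minimal setup sketch might look like the following. The conda environment name `rem` is only an illustration; any Python 3.10 environment works, and the package versions are the ones listed above.

```bash
# Create and activate a Python 3.10 environment (conda is just one option).
conda create -n rem python=3.10 -y
conda activate rem

# Install PyTorch with the versions the code was developed with.
pip install torch==1.13.1 torchvision==0.14.0

# Install the remaining dependencies (this also downloads the Atari ROMs).
pip install -r requirements.txt
```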
To launch a training run:

```bash
python src/main.py env.train.id=BreakoutNoFrameskip-v4 common.device=cuda:0 wandb.mode=online
```
By default, the logs are synced to Weights & Biases; set `wandb.mode=disabled` to turn this off.
- All configuration files are located in `config/`; the main configuration file is `config/config.yaml`.
- The simplest way to customize the configuration is to edit these files directly.
- Please refer to Hydra for more details regarding configuration management.
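Since the configuration is managed with Hydra, any of these values can also be overridden from the command line, as in the launch command above. A sketch (the game id is just an example; other keys should be checked against `config/config.yaml`):

```bash
# Train on a different game, on a specific GPU, without syncing logs to Weights & Biases.
python src/main.py env.train.id=PongNoFrameskip-v4 common.device=cuda:1 wandb.mode=disabled
```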
Each new run is located at `outputs/env.id/YYYY-MM-DD/hh-mm-ss/`. This folder is structured as:
```
outputs/env.id/YYYY-MM-DD/hh-mm-ss/
│
└─── checkpoints
│   │   last.pt
│   │   optimizer.pt
│   │   ...
│   │
│   └─── dataset
│       │   0.pt
│       │   1.pt
│       │   ...
│
└─── config
│   │   config.yaml
│
└─── media
│   │
│   └─── episodes
│   │   │   ...
│   │
│   └─── reconstructions
│       │   ...
│
└─── scripts
│   │   eval.py
│   │   play.sh
│   │   resume.sh
│   │   ...
│
└─── src
│   │   ...
│
└─── wandb
    │   ...
```
- `checkpoints`: contains the last checkpoint of the model, its optimizer, and the dataset.
- `media`:
  - `episodes`: contains train / test / imagination episodes for visualization purposes.
  - `reconstructions`: contains original frames alongside their reconstructions with the autoencoder.
- `scripts`: from the run folder, you can use the following three scripts.
  - `eval.py`: Launch `python ./scripts/eval.py` to evaluate the run.
  - `resume.sh`: Launch `./scripts/resume.sh` to resume a training that crashed.
  - `play.sh`: Tool to visualize some interesting aspects of the run.
    - Launch `./scripts/play.sh` to watch the agent play live in the environment. If you add the flag `-r`, the left panel displays the original frame, the center panel displays the same frame downscaled to the input resolution of the discrete autoencoder, and the right panel shows the output of the autoencoder (what the agent actually sees).
    - Launch `./scripts/play.sh -w` to unroll live trajectories with your keyboard inputs (i.e. to play in the world model). Note that since the world model was trained with segments of $H$ steps where the first $c$ observations serve as a context, the memory of the world model is flushed every $H-c$ frames.
    - Launch `./scripts/play.sh -a` to watch the agent play live in the world model. The world model memory flush applies here as well, for the same reasons.
    - Launch `./scripts/play.sh -e` to visualize the episodes contained in `media/episodes`.
    - Add the flag `-h` to display a header with additional information.
    - Press `,` to start and stop recording. The corresponding segment is saved in `media/recordings` in mp4 and numpy formats.
    - Add the flag `-s` to enter 'save mode', where the user is prompted to save trajectories upon completion.
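As an illustration, a typical post-training workflow from inside a run folder could look like the sketch below. The run path is a placeholder; the script invocations are the ones listed above.

```bash
# Move into a run folder (placeholder path).
cd outputs/BreakoutNoFrameskip-v4/2024-01-01/12-00-00/

# Evaluate the run.
python ./scripts/eval.py

# Resume a training run that crashed.
./scripts/resume.sh

# Watch the agent play live in the world model.
./scripts/play.sh -a
```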
The folder `results/data/` contains raw scores (for each game, and for each training run) for REM, the other baselines, and the ablations.
The Python scripts in the results folder can be used to reproduce the plots from the paper.
This repository started as a fork of IRIS.
We extend the RetNet implementation of yet-another-retnet.