[Project Website] [Paper] [Data] [Models]
This is the official PyTorch implementation of the paper "Diffusion Reward: Learning Rewards via Conditional Video Diffusion" by Tao Huang*, Guangqi Jiang*, Yanjie Ze, and Huazhe Xu.
Clone this repository.
git clone https://github.com/TaoHuang13/diffusion_reward.git
cd diffusion_reward
Create a virtual environment.
conda env create -f conda_env.yml
conda activate diffusion_reward
pip install -e .
Install extra dependencies.
- Install PyTorch.
pip3 install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
- Install mujoco210 and mujoco-py following the official mujoco-py installation instructions.
- Install the Adroit dependencies.
cd env_dependencies
pip install -e mj_envs/.
pip install -e mjrl/.
cd ..
- Install MetaWorld following the official MetaWorld installation instructions.
| Domain | Tasks | Episodes | Size | Collection | Link |
|---|---|---|---|---|---|
| Adroit | 3 | 150 | 23.8M | VRL3 | Download |
| MetaWorld | 7 | 140 | 38.8M | Scripts | Download |
To reproduce the results in the paper, download the datasets and place them in /video_dataset.
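Before training, it can help to confirm that the data is where the scripts expect it. The sketch below is a hypothetical helper (`require_nonempty_dir` is not part of the repository). It assumes `/video_dataset` above refers to a `video_dataset` folder at the repository root, and it only checks that the directory exists and is non-empty, not that its contents are valid.

```shell
# Hypothetical pre-flight check (not part of the repository): succeed only if
# the given directory exists and contains at least one entry.
require_nonempty_dir() {
  local dir="$1"
  if [ ! -d "$dir" ]; then
    echo "missing directory: $dir" >&2
    return 1
  fi
  # `ls -A` lists every entry except . and ..; empty output means an empty dir.
  if [ -z "$(ls -A "$dir")" ]; then
    echo "empty directory: $dir" >&2
    return 1
  fi
  echo "ok: $dir"
}

# Usage from the repository root (assumes the data lives in ./video_dataset):
# require_nonempty_dir video_dataset
```

The same check can be pointed at `exp_local` after placing the pre-trained models there.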
Train the VQGAN encoder.
bash scripts/run/codec_model/vqgan_${domain}.sh # [adroit, metaworld]
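The `${domain}` placeholder expands to one script per domain. Below is a minimal sketch of a hypothetical wrapper that trains both encoders one after another and stops at the first failure. `train_vqgan_all` and the `DRY_RUN` variable are illustrative names, not part of the repository.

```shell
# Hypothetical wrapper (not part of the repository): train the VQGAN encoder for
# each domain given as an argument, one after another. DRY_RUN defaults to 1,
# which only prints the commands; DRY_RUN=0 executes them from the repo root.
train_vqgan_all() {
  local domain script
  for domain in "$@"; do
    script="scripts/run/codec_model/vqgan_${domain}.sh"
    if [ "${DRY_RUN:-1}" = 1 ]; then
      echo "bash $script"
    else
      bash "$script" || return 1  # stop at the first domain that fails
    fi
  done
}

train_vqgan_all adroit metaworld
```

By default it only prints the commands, which is a cheap way to check the paths. Running `DRY_RUN=0 train_vqgan_all adroit metaworld` launches the actual training runs.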
Train video models.
bash scripts/run/video_model/${video_model}_${domain}.sh # [vqdiffusion, videogpt]_[adroit, metaworld]
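The video-model step combines two placeholders, giving four scripts in total. The sketch below enumerates them from the naming pattern. The helper name is hypothetical, and the sketch assumes every combination listed in the comment exists under `scripts/run/video_model/`. Because the README lists VQGAN training first, the printed commands are meant to be run after that step has finished.

```shell
# Hypothetical helper (not part of the repository): list every video-model
# training script implied by scripts/run/video_model/${video_model}_${domain}.sh
# (2 video models x 2 domains = 4 scripts).
list_video_model_scripts() {
  local video_model domain
  for video_model in vqdiffusion videogpt; do
    for domain in adroit metaworld; do
      echo "scripts/run/video_model/${video_model}_${domain}.sh"
    done
  done
}

# Print one launch command per combination (run them from the repo root).
list_video_model_scripts | while read -r script; do
  echo "bash $script"
done
```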
We also provide the pre-trained reward models (including Diffusion Reward and VIPER) used in the paper, so the results can be reproduced. Download the models together with their configuration files (see the Models link above) and place the folders in /exp_local.
Train DrQv2 with different rewards.
bash scripts/run/rl/drqv2_${domain}_${reward}.sh ${task} # [adroit, metaworld]_[diffusion_reward, viper, viper_std, amp, rnd, raw_sparse_reward]
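A typo in `${domain}` or `${reward}` only surfaces as a missing-file error from `bash`. A small validating wrapper can catch it earlier. The helper below is hypothetical, not part of the repository. It checks the domain and reward against the options in the comment above, assumes every domain/reward combination has a script, and passes the task name through unchecked because the README does not list valid task names.

```shell
# Hypothetical wrapper (not part of the repository): build the DrQv2 launch
# command, rejecting domains and rewards outside the documented options.
drqv2_cmd() {
  local domain="$1" reward="$2" task="$3"
  case "$domain" in
    adroit|metaworld) ;;
    *) echo "unknown domain: $domain" >&2; return 1 ;;
  esac
  case "$reward" in
    diffusion_reward|viper|viper_std|amp|rnd|raw_sparse_reward) ;;
    *) echo "unknown reward: $reward" >&2; return 1 ;;
  esac
  # The task is required but not validated (no task list is given here).
  if [ -z "$task" ]; then
    echo "missing task argument" >&2
    return 1
  fi
  echo "bash scripts/run/rl/drqv2_${domain}_${reward}.sh ${task}"
}

# Example with a placeholder task name; substitute a real task before running.
drqv2_cmd adroit diffusion_reward YOUR_TASK
```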
Note that you need to log in to wandb to log experiments online. To log locally instead, turn wandb off in the configuration file.
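As an alternative to editing the configuration file, wandb itself reads the `WANDB_MODE` environment variable (`online`, `offline`, or `disabled`). The sketch below sets it from the shell. `set_wandb_mode` is a hypothetical helper. Whether the variable takes effect here is an assumption: it holds only if the repository's code does not pass an explicit `mode` to `wandb.init`.

```shell
# Hypothetical helper (not part of the repository): export WANDB_MODE, which
# wandb reads when a run starts. Assumes the repo does not override the mode.
set_wandb_mode() {
  case "$1" in
    online|offline|disabled)
      export WANDB_MODE="$1"
      echo "WANDB_MODE=$WANDB_MODE"
      ;;
    *)
      echo "invalid wandb mode: $1 (expected online, offline, or disabled)" >&2
      return 1
      ;;
  esac
}

# Online logging needs a one-time `wandb login` first; to keep runs local:
set_wandb_mode offline
```

Runs logged in `offline` mode are stored locally and can be uploaded later with `wandb sync`.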
diffusion_reward
|- configs # experiment configs
| |- models # configs of codec models and video models
| |- rl # configs of rl
|
|- envs # environments, wrappers, env maker
| |- adroit.py # Adroit env
| |- metaworld.py # MetaWorld env
| |- wrapper.py # env wrapper and utils
|
|- models # implements core codec models and video models
| |- codec_models # image encoder, e.g., VQGAN
| |- video_models # video prediction models, e.g., VQDiffusion and VideoGPT
| |- reward_models # reward models, e.g., Diffusion Reward and VIPER
|
|- rl # implements core rl algorithms
For any questions, please feel free to email taou.cs13@gmail.com or luccachiang@gmail.com.
Our code is built upon VQGAN, VQ-Diffusion, VIPER, AMP, RND, and DrQv2. We thank the authors of these projects for open-sourcing their code and for their contributions to the community.
This repository is released under the MIT license. See LICENSE for additional details.
If you find our work useful, please consider citing:
@article{Huang2023DiffusionReward,
title={Diffusion Reward: Learning Rewards via Conditional Video Diffusion},
author={Tao Huang and Guangqi Jiang and Yanjie Ze and Huazhe Xu},
journal={arxiv},
year={2023},
}