Official implementation of the RLC 2024 paper "Policy-Guided Diffusion"

EmptyJackson/policy-guided-diffusion

Policy-Guided Diffusion

The official implementation of Policy-Guided Diffusion, built by Matthew Jackson and Michael Matthews. It includes:

  • Offline RL agents (TD3+BC, IQL),
  • Trajectory-level U-Net diffusion model,
  • EDM diffusion training and sampling,
  • Runs on the D4RL benchmark.
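For readers new to EDM (the diffusion framework from "Elucidating the Design Space of Diffusion-Based Generative Models", Karras et al., 2022), the sketch below shows its main ingredients in JAX: the noise-level schedule, the denoiser preconditioning, and a simple sampler. It assumes the default constants from that paper (sigma_min = 0.002, sigma_max = 80, rho = 7, sigma_data = 0.5) and uses a first-order Euler sampler, whereas the paper's default is a second-order Heun sampler. The constants, the function names, and the `raw_net` callable standing in for the U-Net are assumptions for illustration, not this repository's code.

```python
import jax.numpy as jnp


def edm_sigmas(num_steps: int, sigma_min: float = 0.002,
               sigma_max: float = 80.0, rho: float = 7.0) -> jnp.ndarray:
    """Noise levels from sigma_max down to sigma_min (Karras et al., 2022), plus a final 0."""
    i = jnp.arange(num_steps)
    inv_rho = 1.0 / rho
    sigmas = (sigma_max ** inv_rho
              + i / (num_steps - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho
    # Append sigma = 0 so the last sampler step lands on the denoised estimate.
    return jnp.concatenate([sigmas, jnp.zeros(1)])


def edm_precond(sigma, sigma_data: float = 0.5):
    """Scalings for D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise)."""
    denom = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / denom
    c_out = sigma * sigma_data / jnp.sqrt(denom)
    c_in = 1.0 / jnp.sqrt(denom)
    c_noise = jnp.log(sigma) / 4.0
    return c_skip, c_out, c_in, c_noise


def denoise(raw_net, x, sigma, sigma_data: float = 0.5):
    """Preconditioned denoiser; raw_net(x, c_noise) is a hypothetical stand-in for the U-Net."""
    c_skip, c_out, c_in, c_noise = edm_precond(sigma, sigma_data)
    return c_skip * x + c_out * raw_net(c_in * x, c_noise)


def euler_sample(raw_net, x, sigmas):
    """First-order Euler solver for the probability-flow ODE dx/dsigma = (x - D(x, sigma)) / sigma."""
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoise(raw_net, x, sigma)) / sigma
        x = x + (sigma_next - sigma) * d
    return x
```

To sample, start from noise scaled by sigma_max, e.g. `euler_sample(raw_net, 80.0 * noise, edm_sigmas(18))`. A Python loop keeps the sketch readable; a JAX implementation would more likely express the loop with `jax.lax.scan` so it can be JIT-compiled as a whole.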

Diffusion and agent training are implemented entirely in JAX, with extensive JIT compilation and parallelization.
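One common way to get this kind of speed-up in JAX is to combine `jax.jit`, which compiles a function with XLA, with `jax.vmap`, which vectorizes it over a batch axis such as random seeds. The toy example below is a hypothetical one-parameter regression, not code from this repository; it only illustrates the pattern of compiling a whole training loop with `jax.lax.scan` and running several seeds in parallel.

```python
import jax
import jax.numpy as jnp


def train_one_seed(key, num_steps: int = 200, lr: float = 0.1):
    """Toy stand-in for a training run: fit w in y = 3x by gradient descent."""
    x_key, w_key = jax.random.split(key)
    xs = jax.random.normal(x_key, (64,))
    ys = 3.0 * xs
    w0 = jax.random.normal(w_key, ())

    def loss_fn(w):
        return jnp.mean((w * xs - ys) ** 2)

    def step(w, _):
        # One gradient-descent update; lax.scan traces this body once for the whole loop.
        w = w - lr * jax.grad(loss_fn)(w)
        return w, loss_fn(w)

    w_final, losses = jax.lax.scan(step, w0, None, length=num_steps)
    return w_final, losses


# vmap batches the run over a leading axis of PRNG keys; jit compiles the batched run.
train_many = jax.jit(jax.vmap(train_one_seed))
```

For example, `train_many(jax.random.split(jax.random.PRNGKey(0), 8))` trains eight seeds in a single compiled call and returns per-seed final weights and loss curves.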

Update (28/06/24): Added a WandB report with training logs for the diffusion and agent models.

Running experiments

Diffusion and agent training are run with python3 train_diffusion.py and python3 train_agent.py respectively, and all arguments are defined in util/args.py.

  • --log --wandb_entity [entity] --wandb_project [project] enables logging to WandB.
  • --debug disables JIT compilation.
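As a rough illustration of what a flag like --debug can do, the sketch below parses a hypothetical subset of the flags listed above with argparse and, when --debug is set, turns off JIT compilation globally through JAX's `jax_disable_jit` config option. The function names `parse_args` and `configure_jax` are assumptions; util/args.py defines the real arguments and may wire them differently.

```python
import argparse

import jax


def parse_args(argv=None):
    # Hypothetical subset of the flags described above; util/args.py is the real source.
    parser = argparse.ArgumentParser()
    parser.add_argument("--log", action="store_true", help="Enable WandB logging.")
    parser.add_argument("--wandb_entity", type=str, default=None)
    parser.add_argument("--wandb_project", type=str, default=None)
    parser.add_argument("--debug", action="store_true", help="Disable JIT compilation.")
    return parser.parse_args(argv)


def configure_jax(args):
    if args.debug:
        # Jitted functions now run eagerly in Python, so prints and breakpoints
        # inside them fire on every call (at the cost of speed).
        jax.config.update("jax_disable_jit", True)
```

Calling `configure_jax(parse_args())` once at startup, before any training code runs, is enough, since the option is read each time a jitted function is called.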

Docker installation

  1. Build the Docker image:
cd docker && ./build.sh && cd ..
  2. (To enable WandB logging) Add your account key to docker/wandb_key:
echo [KEY] > docker/wandb_key

Launching experiments

./run_docker.sh [GPU index] python3.9 [train_script] [args]

Diffusion training example:

./run_docker.sh 0 python3.9 train_diffusion.py --log --wandb_project diff --wandb_team flair --dataset_name walker2d-medium-v2

Agent training example:

./run_docker.sh 6 python3.9 train_agent.py --log --wandb_project agents --wandb_team flair --dataset_name walker2d-medium-v2 --agent iql
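To launch several runs at once, the launch pattern above can be generated programmatically. The sketch below is a hypothetical helper (`build_commands` is not part of the repository) that builds one ./run_docker.sh command per dataset and assigns GPUs round-robin; it only prints the commands. Dataset names other than walker2d-medium-v2 are assumed D4RL names, so check which datasets util/args.py accepts.

```python
import shlex


def build_commands(datasets, gpus, script="train_agent.py", extra_args=()):
    """Build one ./run_docker.sh invocation per dataset, assigning GPUs round-robin."""
    commands = []
    for i, dataset in enumerate(datasets):
        gpu = gpus[i % len(gpus)]
        cmd = ["./run_docker.sh", str(gpu), "python3.9", script,
               "--dataset_name", dataset, *extra_args]
        commands.append(cmd)
    return commands


if __name__ == "__main__":
    # Assumed dataset names beyond walker2d-medium-v2.
    datasets = ["walker2d-medium-v2", "hopper-medium-v2", "halfcheetah-medium-v2"]
    for cmd in build_commands(datasets, gpus=[0, 1], extra_args=["--agent", "iql"]):
        # Print rather than run; pass `cmd` to subprocess.run to actually launch it.
        print(shlex.join(cmd))
```

Returning argument lists rather than strings keeps each command safe to pass straight to subprocess.run without shell quoting issues.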

Citation

If you use this implementation in your work, please cite us with the following:

@misc{jackson2024policyguided,
      title={Policy-Guided Diffusion},
      author={Matthew Thomas Jackson and Michael Tryfan Matthews and Cong Lu and Benjamin Ellis and Shimon Whiteson and Jakob Foerster},
      year={2024},
      eprint={2404.06356},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
