The official codebase for training video policies in VideoAgent
NEWS: We have released another repository for running our Meta-World and iTHOR experiments here!
This repository contains the code for training video policies presented in our work
VideoAgent: Self-Improving Video Generation
Achint Soni,
Sreyas Venkataraman,
Abhranil Chandra,
Sebastian Fischmeister,
Percy Liang,
Bo Dai,
Sherry Yang
website | paper | arXiv | experiment repo
@misc{soni2024videoagentselfimprovingvideogeneration,
title={VideoAgent: Self-Improving Video Generation},
author={Achint Soni and Sreyas Venkataraman and Abhranil Chandra and Sebastian Fischmeister and Percy Liang and Bo Dai and Sherry Yang},
year={2024},
eprint={2410.10076},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2410.10076},
}
We recommend creating a new conda environment with PyTorch installed.
conda create -n videoagent python=3.9
conda activate videoagent
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
Next, clone the repository and install the requirements:
git clone https://github.com/Video-as-Agent/VideoAgent
cd VideoAgent
pip install -r requirements.txt
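After installation, a quick sanity check can confirm that the packages installed above are importable. The snippet below is a minimal sketch, not part of the repository: `check_environment` is a hypothetical helper name, and the CUDA check only runs if `torch` can actually be imported.

```python
import importlib.util
import sys


def check_environment(packages=("torch", "torchvision", "torchaudio")):
    """Report which packages are importable in the active environment.

    Hypothetical helper for illustration; it is not provided by VideoAgent.
    """
    report = {"python": "{}.{}".format(*sys.version_info[:2])}
    for name in packages:
        # find_spec returns None when a top-level package is not installed.
        report[name] = importlib.util.find_spec(name) is not None
    if report.get("torch"):
        try:
            import torch

            report["cuda"] = torch.cuda.is_available()
        except Exception:
            # A broken install can be discoverable yet fail to import.
            report["cuda"] = False
    return report


report = check_environment()
for key, value in report.items():
    print(f"{key}: {value}")
```

If `torch` is reported as missing, the conda environment is probably not activated.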
The PyTorch dataset classes are defined in flowdiffusion/datasets.py.
For Meta-World experiments, run
cd flowdiffusion
python train_mw.py --mode train
# or python train_mw.py -m train
Alternatively, launch training with accelerate:
accelerate launch train_mw.py
For iTHOR experiments, run train_thor.py instead of train_mw.py.
For Bridge experiments, run train_bridge.py instead of train_mw.py.
The trained model should be saved in the ../results folder.
To resume training, use the -c --checkpoint_num argument.
# Resumes training from checkpoint 1 (the file should be named model-1.pt)
python train_mw.py --mode train -c 1
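The naming convention above (checkpoint number 1 corresponds to model-1.pt) can be sketched as a small path helper. This is an illustration only: `checkpoint_path` is a hypothetical function, and the directory passed in the example is an assumption about where the results folder sits relative to the working directory.

```python
from pathlib import Path
from typing import Union


def checkpoint_path(results_dir: Union[str, Path], checkpoint_num: int) -> Path:
    """Return the file that a given -c value is expected to refer to.

    Hypothetical helper; only the model-<num>.pt naming comes from the README.
    """
    return Path(results_dir) / f"model-{checkpoint_num}.pt"


# Assumed location of the results folder; adjust to your setup.
path = checkpoint_path("../results/mw", 1)
print(path.name)
if not path.exists():
    print("Checkpoint not found; training would have to start from scratch.")
```

Checking that the file exists before passing -c avoids a failed resume.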
Use the following arguments for inference:
-p --inference_path: specify the input video path
-t --text: specify the text description of the task
-n sample_steps: optional; the number of steps used in test-time sampling. If the specified value is less than 100, DDIM sampling will be used.
-g guidance_weight: optional; the weight used for classifier-free guidance. Set it to a positive value to turn on classifier-free guidance.
For example:
python train_mw.py --mode inference -c 1 -p ../examples/assembly.gif -t assembly -g 2 -n 20
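To make the meaning of these flags concrete, here is a minimal sketch of a parser that accepts the same command line, together with the two rules stated above (DDIM below 100 steps, guidance enabled for positive weights). It is not the repository's code: the long option names `--sample_steps` and `--guidance_weight`, the types, the defaults, the `"default"` sampler label, and the guidance formula (one common classifier-free guidance formulation) are all assumptions.

```python
import argparse
from typing import List


def build_parser() -> argparse.ArgumentParser:
    # Short flags follow the README; long names for -n and -g, the types,
    # and the defaults are assumptions made for this sketch.
    parser = argparse.ArgumentParser(description="Illustrative CLI sketch")
    parser.add_argument("-m", "--mode", choices=["train", "inference"], default="train")
    parser.add_argument("-c", "--checkpoint_num", type=int, default=None)
    parser.add_argument("-p", "--inference_path", type=str, default=None)
    parser.add_argument("-t", "--text", type=str, default=None)
    parser.add_argument("-n", "--sample_steps", type=int, default=100)
    parser.add_argument("-g", "--guidance_weight", type=float, default=0.0)
    return parser


def sampler_name(sample_steps: int) -> str:
    # README rule: fewer than 100 steps switches to DDIM sampling.
    # "default" is a placeholder label for whatever sampler is used otherwise.
    return "ddim" if sample_steps < 100 else "default"


def guided_prediction(cond: List[float], uncond: List[float], weight: float) -> List[float]:
    # Assumed formulation: (1 + w) * conditional - w * unconditional.
    # With w == 0 this reduces to the plain conditional prediction.
    return [(1 + weight) * c - weight * u for c, u in zip(cond, uncond)]


args = build_parser().parse_args(
    ["--mode", "inference", "-c", "1", "-p", "../examples/assembly.gif",
     "-t", "assembly", "-g", "2", "-n", "20"]
)
print(sampler_name(args.sample_steps), args.guidance_weight > 0)
```

Under these assumptions, the example command selects DDIM sampling with 20 steps and turns classifier-free guidance on.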
We also provide checkpoints for the models described in our experiments:
VideoAgent | VideoAgent-Online | VideoAgent-Suggestive
Download the .pt file and place it in the results/[environment] folder. The resulting directory structure should be results/{mw, thor, bridge}/model-[x].pt, for example results/mw/model-305.pt.
Alternatively, use download.sh:
./download.sh metaworld
# ./download.sh ithor
# ./download.sh bridge
After this, you can use the -c [x] argument to resume training or run inference with our checkpoint. For example:
python train_mw.py --mode train -c 305
Or
python train_mw.py --mode inference -c 305 -p ../examples/assembly.gif -t assembly
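To see which checkpoint numbers are valid values for -c, the directory layout described above can be scanned. The following is a minimal sketch under the stated layout (results/{mw, thor, bridge}/model-[x].pt); `available_checkpoints` is a hypothetical helper, and the demo builds a temporary directory rather than touching real checkpoints.

```python
import re
import tempfile
from pathlib import Path
from typing import Dict, List

ENVIRONMENTS = ("mw", "thor", "bridge")
CHECKPOINT_PATTERN = re.compile(r"^model-(\d+)\.pt$")


def available_checkpoints(results_root) -> Dict[str, List[int]]:
    """Map each environment folder to the checkpoint numbers found in it.

    Hypothetical helper; only the folder layout comes from the README.
    """
    found: Dict[str, List[int]] = {}
    for env in ENVIRONMENTS:
        env_dir = Path(results_root) / env
        numbers = []
        if env_dir.is_dir():
            for entry in env_dir.iterdir():
                match = CHECKPOINT_PATTERN.match(entry.name)
                if match:
                    numbers.append(int(match.group(1)))
        found[env] = sorted(numbers)
    return found


# Demo on a throwaway directory that mimics results/mw/model-305.pt.
with tempfile.TemporaryDirectory() as root:
    (Path(root) / "mw").mkdir()
    (Path(root) / "mw" / "model-305.pt").touch()
    demo = available_checkpoints(root)
print(demo)
```

In a real checkout, pass the path of the results folder instead of the temporary directory.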
This codebase is modified from the following repositories:
avdc