This repository is the official implementation of "Unsupervised Object-centric Learning with Bi-level Optimized Query Slot Attention".
We provide all environment configurations in requirements.txt. To install all packages, you can create a conda environment and install the packages as follows:
conda create -n BO-QSA python=3.8
conda activate BO-QSA
pip install -r requirements.txt
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
In our experiments, we used NVIDIA CUDA 11.3 on Ubuntu 20.04. Similar CUDA versions should also work, given the corresponding versions of torch and torchvision.
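After installation, you can quickly verify that the intended versions were picked up with a short Python check (a minimal sketch; the expected version strings simply mirror the pinned versions above):

# Sanity check that torch/torchvision and the CUDA toolkit are visible.
import torch
import torchvision

print(torch.__version__)          # expected: 1.10.1 (+cu113)
print(torchvision.__version__)    # expected: 0.11.2
print(torch.version.cuda)         # expected: 11.3
print(torch.cuda.is_available())  # should be True on a working CUDA setup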
ShapeStacks dataset is available at: https://ogroth.github.io/shapestacks/
# Download compressed dataset
$ cd data/ShapeStacks
$ wget http://shapestacks-file.robots.ox.ac.uk/static/download/v1/shapestacks-mjcf.tar.gz
$ wget http://shapestacks-file.robots.ox.ac.uk/static/download/v1/shapestacks-meta.tar.gz
$ wget http://shapestacks-file.robots.ox.ac.uk/static/download/v1/shapestacks-rgb.tar.gz
$ wget http://shapestacks-file.robots.ox.ac.uk/static/download/v1/shapestacks-iseg.tar.gz
# Uncompress files
$ tar xvzf shapestacks-meta.tar.gz
$ tar xvzf shapestacks-mjcf.tar.gz
$ tar xvzf shapestacks-rgb.tar.gz
$ tar xvzf shapestacks-iseg.tar.gz
ObjectsRoom dataset is available at: https://console.cloud.google.com/storage/browser/multi-object-datasets;tab=objects?prefix=&forceOnObjectsSortingFiltering=false, provided by Multi-Object Datasets
# Download compressed dataset
$ cd data/ObjectsRoom
$ gsutil -m cp -r \
"gs://multi-object-datasets/objects_room" \
.
Before you start training, you need to run objectsroom_process.py to convert the TFRecords dataset into PNG images.
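For reference, the conversion can be sketched with DeepMind's multi_object_datasets reader (a minimal sketch, not the repository's objectsroom_process.py; the TFRecords filename and output directory are assumptions):

# Minimal sketch: decode ObjectsRoom TFRecords and dump frames as PNGs.
# Requires tensorflow and DeepMind's multi_object_datasets package.
import os
from PIL import Image
from multi_object_datasets import objects_room

records = 'data/ObjectsRoom/objects_room/objects_room_train.tfrecords'  # assumed path
out_dir = 'data/ObjectsRoom/png'                                        # assumed output dir
os.makedirs(out_dir, exist_ok=True)

ds = objects_room.dataset(records, 'train')
for i, example in enumerate(ds.as_numpy_iterator()):
    Image.fromarray(example['image']).save(os.path.join(out_dir, f'{i:06d}.png'))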
ClevrTex dataset is available at: https://www.robots.ox.ac.uk/~vgg/data/clevrtex/
PTR dataset is available at: http://ptr.csail.mit.edu
Download the 'Training Images', 'Validation Images', 'Test Images', 'Training Annotations', 'Validation Annotations' and then uncompress them.
Birds, Dogs, and Cars datasets are available at: https://drive.google.com/drive/folders/1zEzsKV2hOlwaNRzrEXc9oGdpTBrrVIVk, provided by DRC.
Download the 'birds.zip', 'cars.tar' and 'dogs.zip' and then uncompress them.
Flowers dataset is available at: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/
Download the 'Dataset images', 'Image segmentations' and 'The data splits' and then uncompress them.
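Once uncompressed, the datasets can be consumed with a standard paired image/mask loader along these lines (a generic PyTorch sketch with a hypothetical directory layout; adjust the paths to each dataset's actual structure):

# Generic paired image/mask dataset; the images/ and masks/ layout is hypothetical.
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T

class ImageMaskDataset(Dataset):
    def __init__(self, root, size=128):
        self.img_dir = os.path.join(root, 'images')  # hypothetical layout
        self.mask_dir = os.path.join(root, 'masks')  # hypothetical layout
        self.names = sorted(os.listdir(self.img_dir))
        self.tf = T.Compose([T.Resize((size, size)), T.ToTensor()])

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        img = Image.open(os.path.join(self.img_dir, name)).convert('RGB')
        mask = Image.open(os.path.join(self.mask_dir, name)).convert('L')
        return self.tf(img), self.tf(mask)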
To train the model from scratch, we provide the following model files:
- train_trans_dec.py: transformer-based model
- train_mixture_dec.py: mixture-based model
- train_base_sa.py: original slot attention (a sketch of the query slot attention core follows the commands below)

We provide training scripts under scripts/train. Please use the following command and change the .sh file to the model you want to experiment with. Taking the transformer-based decoder experiment on Birds as an example, you can run the following:
$ cd scripts
$ cd train
$ chmod +x trans_dec_birds.sh
$ ./trans_dec_birds.sh
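For orientation, the core module these training files share can be sketched as follows (a minimal, simplified sketch based on the public Slot Attention formulation and the paper's bi-level optimized learnable queries; all names and hyperparameters are illustrative, not the repository's actual implementation, and the usual post-GRU MLP is omitted for brevity):

# Sketch of query slot attention with learnable initialization and the
# bi-level trick: gradients flow only through the final attention iteration.
import torch
import torch.nn as nn
import torch.nn.functional as F

class QuerySlotAttentionSketch(nn.Module):
    def __init__(self, num_slots=7, dim=64, iters=3, eps=1e-8):
        super().__init__()
        self.iters, self.eps, self.scale = iters, eps, dim ** -0.5
        # Learnable queries replace the random Gaussian slot initialization.
        self.slots_init = nn.Parameter(torch.randn(1, num_slots, dim) * dim ** -0.5)
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.gru = nn.GRUCell(dim, dim)
        self.norm_inputs = nn.LayerNorm(dim)
        self.norm_slots = nn.LayerNorm(dim)

    def forward(self, inputs):  # inputs: (B, N, dim) encoder features
        B, N, D = inputs.shape
        inputs = self.norm_inputs(inputs)
        k, v = self.to_k(inputs), self.to_v(inputs)
        slots = self.slots_init.expand(B, -1, -1)
        for t in range(self.iters):
            if t == self.iters - 1:
                # Stop gradients through earlier iterations; a straight-through
                # term keeps the learnable queries trainable via the last step.
                slots = slots.detach() + self.slots_init - self.slots_init.detach()
            q = self.to_q(self.norm_slots(slots))
            attn = F.softmax(torch.einsum('bnd,bkd->bnk', k, q) * self.scale, dim=-1)
            attn = attn + self.eps
            attn = attn / attn.sum(dim=1, keepdim=True)  # weighted mean over inputs
            updates = torch.einsum('bnk,bnd->bkd', attn, v)
            slots = self.gru(updates.reshape(-1, D), slots.reshape(-1, D)).reshape(B, -1, D)
        return slots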
To reload checkpoints and only run inference, we provide the following model files:
- test_trans_dec.py: transformer-based model
- test_mixture_dec.py: mixture-based model
- test_base_sa.py: original slot attention
Similarly, we provide testing scripts under scripts/test. We provide transformer-based models for the real-world datasets (Birds, Dogs, Cars, Flowers) and mixture-based models for the synthetic datasets (ShapeStacks, ObjectsRoom, ClevrTex, PTR). We provide all checkpoints at this link. Please use the following command and change the .sh file to the model you want to experiment with:
$ cd scripts
$ cd test
$ chmod +x trans_dec_birds.sh
$ ./trans_dec_birds.sh
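The underlying reload pattern is the standard PyTorch one, shown here as a self-contained round trip (a generic sketch; the checkpoint layout and 'model' key are assumptions, not the repository's format):

# Generic save/reload round trip illustrating checkpoint reuse for inference.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))  # stand-in model
torch.save({'model': model.state_dict()}, 'demo_ckpt.pth')              # training side

ckpt = torch.load('demo_ckpt.pth', map_location='cpu')
model.load_state_dict(ckpt['model'])  # the 'model' key is an assumption
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 64))  # dummy inference pass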
If you find our paper and/or code helpful, please consider citing:
@article{jia2022unsupervised,
title = {Unsupervised Object-Centric Learning with Bi-Level Optimized Query Slot Attention},
author = {Jia, Baoxiong and Liu, Yu and Huang, Siyuan},
journal = {arXiv preprint arXiv:2210.08990},
year = {2022}
}
This code heavily uses resources from SLATE, SlotAttention, GENESISv2, DRC, Multi-Object Datasets, and shapestacks. We thank the authors for open-sourcing their awesome projects.