This repository contains the official implementation of the ICLR 2023 paper:
Improving Object-centric Learning With Query Optimization
Baoxiong Jia*, Yu Liu*, Siyuan Huang
We provide all environment configurations in requirements.txt. To install all packages, you can create a conda environment and install the packages as follows:
conda create -n BO-QSA python=3.8
conda activate BO-QSA
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
pip install -r requirements.txt
In our experiments, we used NVIDIA CUDA 11.3 on Ubuntu 20.04. Other CUDA versions should also work, provided you install the matching versions of torch and torchvision.
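As a quick sanity check after installation, the pinned versions in requirements.txt can be compared against what is actually installed. This is a minimal sketch using only the standard library; the `check_pins` helper and the assumption that entries use the `pkg==version` pinning format are ours, not part of the repository:

```python
from importlib.metadata import version, PackageNotFoundError

def check_pins(lines):
    """Compare 'pkg==ver' pins against the installed environment.

    Returns a list of (name, pinned, installed) tuples; installed is
    None when the package is missing from the environment.
    """
    report = []
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#") or "==" not in line:
            continue  # skip blank lines, comments, and unpinned entries
        name, pinned = line.split("==", 1)
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        report.append((name, pinned, installed))
    return report
```

For example, `check_pins(open("requirements.txt"))` lists each pinned package together with the version found locally, so mismatches surface before training starts.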
Download the ShapeStacks, ObjectsRoom, CLEVRTex, and Flowers datasets with:
chmod +x scripts/downloads_data.sh
./scripts/downloads_data.sh
For the ObjectsRoom dataset, you need to run objectsroom_process.py to convert the TFRecords dataset into PNG format. Remember to change DATA_ROOT in downloads_data.sh and objectsroom_process.py to your own paths.
Download the PTR dataset following the instructions at http://ptr.csail.mit.edu. Download the CUB-Birds, Stanford Dogs, and Cars datasets from here, provided by the authors of DRC. We use birds.zip, cars.tar, and dogs.zip, and uncompress them.
The YCB, ScanNet, and COCO datasets are available from here, provided by the authors of UnsupObjSeg. Please organize the data following the layout described here before running experiments.
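Before launching experiments, a quick structural check can catch mis-organized data early. A minimal sketch (the `missing_dataset_dirs` helper and the example directory names are our assumptions; substitute the layout linked above):

```python
from pathlib import Path

def missing_dataset_dirs(data_root, expected):
    """Return the expected dataset sub-directories absent under data_root."""
    root = Path(data_root)
    return [name for name in expected if not (root / name).is_dir()]
```

For example, `missing_dataset_dirs("/path/to/data", ["birds", "dogs", "cars"])` returns the names of datasets that are not yet in place.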
To train a model from scratch, we provide the following model files:
train_trans_dec.py: transformer-based model
train_mixture_dec.py: mixture-based model
train_base_sa.py: original slot-attention
We provide training scripts under scripts/train. Please use the following commands and change the .sh file to the model you want to experiment with. Taking the transformer-based decoder experiment on Birds as an example, you can run:
$ cd scripts
$ cd train
$ chmod +x trans_dec_birds.sh
$ ./trans_dec_birds.sh
Remember to change the paths in path.json to your own paths.
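The exact schema of path.json is defined by the repository; as an illustration only, such a file typically maps dataset names to local directories. The keys and values below are assumptions, not the required format:

```json
{
  "birds": "/your/path/to/birds",
  "shapestacks": "/your/path/to/shapestacks"
}
```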
To reload checkpoints and run inference only, we provide the following model files:
test_trans_dec.py: transformer-based model
test_mixture_dec.py: mixture-based model
test_base_sa.py: original slot-attention
Similarly, we provide testing scripts under scripts/test. We provide transformer-based models for the real-world datasets (Birds, Dogs, Cars, Flowers, YCB, ScanNet, COCO) and mixture-based models for the synthetic datasets (ShapeStacks, ObjectsRoom, CLEVRTex, PTR). We provide all checkpoints here. Please use the following commands and change the .sh file to the model you want to experiment with:
$ cd scripts
$ cd test
$ chmod +x trans_dec_birds.sh
$ ./trans_dec_birds.sh
If you find our paper and/or code helpful, please consider citing:
@inproceedings{jia2023improving,
title={Improving Object-centric Learning with Query Optimization},
author={Jia, Baoxiong and Liu, Yu and Huang, Siyuan},
booktitle={The Eleventh International Conference on Learning Representations},
year={2023}
}
This code builds heavily on resources from SLATE, SlotAttention, GENESISv2, DRC, Multi-Object Datasets, and shapestacks. We thank the authors for open-sourcing their awesome projects.