Main Takeaways | Results | How to use this repository
This repository contains the implementation for my thesis.
Abstract:
Object-centered perception lays the foundation of human intelligence and may be a stepping stone towards artificial general intelligence. Despite recent advances, artificial vision systems that decompose images into objects and their relations without any supervision lack real-world applicability and only work on simplistic datasets. While there are multiple reasons associated with this obstacle, we hypothesize that integrating a multi-class paradigm into object-structured representations is an essential prerequisite to model the variability of real-world images. In this thesis, we present and test novel generative models that aim to retrieve a multi-class representation from raw images without any supervision by interpreting objects as prototype modifications. For single-object datasets such as MNIST, these approaches can equally be used for image clustering. Despite its simplicity, our best model discovers clusters that match the original MNIST classes with an accuracy of 93.52%. This accuracy can be obtained by simply labeling the prototype images, which highlights the interpretability of our models.
-
One class to rule them all - Why we need a multi-class paradigm:
To the best of our knowledge, all previous works assume in some way that objects within an image originate from the same class. While a class can be arbitrarily complex (i.e., the number of attributes in the latent space could be arbitrarily large), it seems rather unnatural and inefficient to enforce a one-class-to-rule-them-all paradigm. E.g., how should the attributes that generate different images of a hand relate to the attributes that generate images of cars?
-
A multi-class paradigm introduces discreteness - Implementation of gradient estimators:
The reparameterization trick (Kingma & Welling, 2014; Rezende et al., 2014) is not applicable to discrete random nodes due to the associated discontinuous operations. Therefore, we implemented the following discrete gradient estimators (for a categorical distribution); a minimal sketch of the simplest estimator is given after this list:
- REINFORCE (Williams, 1992), also known as the score-function estimator, is an unbiased, high-variance estimator that is based on the log-derivative trick.
- NVIL (Mnih & Gregor, 2014) is essentially a variant of REINFORCE with control variates, in which a neural baseline is used.
- CONCRETE (Maddison et al., 2017; Jang et al., 2017) is based on the idea of approximating sampling from a discrete distribution via sampling from a CONtinuous relaxation of disCRETE random variables.
- REBAR (Tucker et al., 2017) combines the unbiased, high-variance REINFORCE estimator with the low-variance, but biased CONCRETE estimator as a control variate.
- RELAX (Grathwohl et al., 2018) combines the ideas of REBAR and NVIL by using a surrogate neural network to construct a control variate that aims at minimizing the variance.
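For reference, the following minimal PyTorch sketch shows the simplest of these estimators, the score-function (REINFORCE) surrogate loss for a categorical latent. The `encoder` and `decoder` modules are hypothetical placeholders and the KL term of the ELBO is omitted, so this illustrates the technique rather than the strimadec implementation.

```python
import torch
import torch.nn.functional as F
from torch.distributions import OneHotCategorical


def reinforce_loss(x, encoder, decoder):
    """Score-function (REINFORCE) surrogate loss for a categorical latent.

    `encoder` maps an image batch to unnormalized class logits and `decoder`
    maps a one-hot latent to reconstruction probabilities (both hypothetical).
    """
    logits = encoder(x)                                  # [batch, K]
    dist = OneHotCategorical(logits=logits)
    z = dist.sample()                                    # non-differentiable sample
    x_tilde = decoder(z)                                 # pixel means in (0, 1)
    # per-sample reconstruction cost (Bernoulli decoder assumed)
    nll = F.binary_cross_entropy(x_tilde, x, reduction="none").flatten(1).sum(1)
    # score-function term: log-probability weighted by the detached cost
    surrogate = (nll.detach() * dist.log_prob(z)).mean()
    # adding the NLL itself keeps the decoder gradients; the KL term is omitted here
    return surrogate + nll.mean()
```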
-
Variational Autoencoders (VAEs) with a discrete latent space learn prototypes:
The concept of prototypes naturally arises in D-VAEs (discrete VAEs) due to the finite number of reconstructions. Intuitively, the D-VAE aims to reconstruct all images within the dataset X using only K (the latent space dimension) realizations of the decoder distribution. Thus, it has to find class-associated prototypes in order to minimize the reconstruction error. More accurately, the generative process of a D-VAE assumes that each image is generated by sampling from one of those K distributions. Due to their modeling assumptions, D-VAEs are highly limited in practice: each image must closely resemble some prototype image, i.e., positional and/or shape variations are only captured through the variance of the scene model.
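To make the prototype interpretation concrete, the decoder of a D-VAE can be viewed as a table of K learnable prototype images from which a (one-hot) latent selects. The sketch below is an assumed, simplified illustration of this view, not the decoder architecture used in strimadec.

```python
import torch
import torch.nn as nn


class PrototypeDecoder(nn.Module):
    """D-VAE decoder viewed as K learnable prototype images.

    A one-hot latent z of dimension K selects (or, for relaxed samples,
    blends) one of the K prototypes as the reconstruction mean.
    """

    def __init__(self, K: int, img_shape=(1, 28, 28)):
        super().__init__()
        # unconstrained parameters; sigmoid maps them to pixel means in [0, 1]
        self.prototype_logits = nn.Parameter(torch.zeros(K, *img_shape))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        prototypes = torch.sigmoid(self.prototype_logits)   # [K, C, H, W]
        # z: [batch, K] (one-hot or relaxed) -> convex combination of prototypes
        return torch.einsum("bk,kchw->bchw", z, prototypes)
```

With a strictly one-hot z, the reconstruction is exactly one of the K prototypes, which is why minimizing the reconstruction error forces the model to discover class-associated prototypes.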
-
D-VAE with a Spatial Transformer (D-VAE-ST) - Objects are prototype modifications:
To overcome the limitations of the D-VAE, we can include a spatial transformer that generates the reconstructed image by disturbing the prototype image using a latent (inferred) transformation vector. Thereby, we can model positional and shape variations within the data. Although the D-VAE-ST is based on a more realistic generative model, it is still only applicable to simple 2D data in which the same texture is assumed for all objects within a class. For real-world applicability, it would be interesting to extend the prototype definition, e.g., we could define prototypes as three-dimensional objects and infer the view to generate a reconstructed image.
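The transformation step can be sketched with PyTorch's `affine_grid`/`grid_sample`: an inferred transformation vector parameterizes an affine matrix that warps the selected prototype into the reconstruction. The scale/shift parameterization below is an assumption made for illustration; the thesis code may parameterize the transformation differently.

```python
import torch
import torch.nn.functional as F


def transform_prototype(prototype: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Warp prototype images with an inferred transformation vector.

    prototype: [batch, C, H, W] selected prototype images
    t:         [batch, 4] inferred (scale_x, scale_y, shift_x, shift_y);
               a simplified parameterization assumed for illustration.
    """
    s_x, s_y, t_x, t_y = t.unbind(dim=1)
    zeros = torch.zeros_like(s_x)
    # build a batch of 2x3 affine matrices (no rotation/shear in this sketch)
    theta = torch.stack(
        [torch.stack([s_x, zeros, t_x], dim=1),
         torch.stack([zeros, s_y, t_y], dim=1)], dim=1
    )                                                     # [batch, 2, 3]
    grid = F.affine_grid(theta, size=prototype.shape, align_corners=False)
    return F.grid_sample(prototype, grid, align_corners=False)
```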
-
Discrete Attend-Infer-Repeat (DAIR) - A multi-object-multi-class model:
We showed how the D-VAE-ST can be incorporated into AIR (Eslami et al., 2016), an existing model that explicitly reasons about the number of objects. We term this adaptation DAIR and compare it against AIR on the MultiMNIST dataset. Ideally, after training, DAIR should be able to locate and classify objects within an image in a purely unsupervised fashion. While the preliminary results for DAIR were rather poor, we expect that some modifications and a different hyperparameter setup could lead to significant improvements.
-
Robustness is a problem:
We observed that our approaches often get trapped in local minima. In future work, it would be interesting to study how the training procedure could be made more robust, e.g., using dropout, batch normalization, and convolutional layers.
This repository contains the whole implementation and experimental setups such that all results can easily be reproduced.
Simply clone the repository, create a virtual environment (not necessary, but advised), install the required packages, and install the strimadec package.
git clone --depth 1 https://github.com/borea17/strimadec/
cd strimadec
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
python setup.py install
After you have successfully installed the package, you can run the following commands to reproduce the thesis results:
# Toy Experiments
python strimadec/experiments/toy_experiment/01_replication_experiment.py \
--num_epochs 10000 --num_repetitions 50
python strimadec/experiments/toy_experiment/02_toy_experiment.py \
--num_epochs 5000 --num_repetitions 50
python strimadec/experiments/toy_experiment/03_hyperparameter_experiment.py \
--num_epochs 5000 --num_repetitions 25
# Single-Object-Multi-Class Experiments
python strimadec/experiments/single_object_multi_class/01_DVAE_experiments.py \
--num_epochs 150 --num_repetitions 10 --dataset_name 'SimplifiedMNIST'
python strimadec/experiments/single_object_multi_class/01_DVAE_experiments.py \
--num_epochs 150 --num_repetitions 10 --dataset_name 'FullMNIST'
python strimadec/experiments/single_object_multi_class/01_DVAE_experiments.py \
--num_epochs 150 --num_repetitions 10 --dataset_name 'FashionMNIST'
python strimadec/experiments/single_object_multi_class/01_DVAE_experiments.py \
--num_epochs 150 --num_repetitions 10 --dataset_name 'Letters'
python strimadec/experiments/single_object_multi_class/02_DVAEST_experiments.py \
--num_epochs 150 --num_repetitions 10 --dataset_name 'SimplifiedMNIST'
python strimadec/experiments/single_object_multi_class/02_DVAEST_experiments.py \
--num_epochs 150 --num_repetitions 10 --dataset_name 'FullMNIST'
python strimadec/experiments/single_object_multi_class/02_DVAEST_experiments.py \
--num_epochs 150 --num_repetitions 10 --dataset_name 'FashionMNIST'
python strimadec/experiments/single_object_multi_class/02_DVAEST_experiments.py \
--num_epochs 150 --num_repetitions 10 --dataset_name 'Letters'
python strimadec/experiments/single_object_multi_class/03_DVAEST_overclustering_experiments.py \
--num_epochs 150 --num_repetitions 5 --dataset_name 'SimplifiedMNIST'
python strimadec/experiments/single_object_multi_class/03_DVAEST_overclustering_experiments.py \
--num_epochs 150 --num_repetitions 5 --dataset_name 'FullMNIST'
python strimadec/experiments/single_object_multi_class/03_DVAEST_overclustering_experiments.py \
--num_epochs 150 --num_repetitions 5 --dataset_name 'FashionMNIST'
# Multi-Object-Multi-Class Experiments
python strimadec/experiments/multi_object_multi_class/01_comparison_experiment.py \
--num_epochs 200 --num_repetitions 10 --dataset_name 'MultiMNIST'
WARNING: Running all of the commands may take days and is not advised.
All results (trained networks, plots, and TensorBoard directories) can be downloaded from this Google Drive (about 15 GB). Copy the downloaded folders into the respective folders under strimadec/experiments/; then you can run the scripts with the flag --load to obtain the results within a few minutes.
Contributions are welcome, and if needed, I will provide more detailed documentation.