This is a re-implementation of VIBI (arXiv, GitHub), including experiments on MNIST and CIFAR10, written in Python and PyTorch.
To run the experiments, first clone this repository and install the requirements:

```bash
git clone https://github.com/willisk/VIBI
cd VIBI
pip install -r requirements.txt
```
To run all experiments shown in the results:

```bash
chmod +x run_experiments.sh
./run_experiments.sh
```
Alternatively, run the training script directly with custom arguments:

```bash
python train.py [ARGS]
```

Optional arguments:

```
--dataset {MNIST,CIFAR10}
--cuda                  Enable CUDA.
--num_epochs NUM_EPOCHS
                        Number of training epochs for VIBI.
--explainer_type {Unet,ResNet_2x,ResNet_4x,ResNet_8x}
--xpl_channels {1,3}
--k K                   Number of chunks.
--beta BETA             beta in the objective J = I(y,t) - beta * I(x,t).
--num_samples NUM_SAMPLES
                        Number of samples used for estimating the expectation over p(t|x).
--resume_training       Resume training VIBI from the last saved checkpoint.
--save_best             Save only the best models (measured by validation accuracy).
--save_images_every_epoch
                        Save explanation images every epoch.
--jump_start            Use a pretrained model with beta=0 as the starting point.
```
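For example (an illustrative invocation; the flag values are arbitrary):

```bash
python train.py --dataset CIFAR10 --explainer_type ResNet_4x --k 4 --beta 0.01 --num_epochs 50 --cuda --save_best
```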
The goal is to create interpretable explanations for black-box models.
This is achieved by two neural network, the explainer and the approximator.
The explainer network produces a probability distribution over the input chunks,
given an input image. A relaxed k-hot vector is sampled from this distribution.
This k-hot vector is used to create a masked input, which is then
fed into the approximator network.
The approximator network aims to match the probability distribution of the
black-box model output.
The whole idea builds heavily on L2X (Learning to Explain). The only difference is that VIBI's additional term effectively increases the entropy of the distribution p(z), whereas L2X only minimizes the cross-entropy H(p,q) between the black-box model's predictions and the approximator's output.
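A variational surrogate for this objective can be sketched as follows (an illustrative sketch, assuming the I(x,t) term is upper-bounded by the KL divergence between the chunk distribution and a uniform prior; the function and argument names are hypothetical):

```python
import math
import torch.nn.functional as F

def vibi_loss(approx_logits, blackbox_probs, chunk_logits, beta=0.01):
    # H(p, q): cross-entropy between the black-box output p and the
    # approximator output q -- the only term L2X optimizes.
    ce = -(blackbox_probs * F.log_softmax(approx_logits, dim=-1)).sum(-1).mean()
    # Bound on I(x, t): KL(p(z|x) || uniform) = log N - H(p(z|x)), so
    # minimizing this term increases the entropy of the chunk distribution.
    log_p = F.log_softmax(chunk_logits, dim=-1)
    kl = (log_p.exp() * log_p).sum(-1).mean() + math.log(chunk_logits.shape[-1])
    return ce + beta * kl
```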
Test Batch | Explanation Distribution | Top-k Explanation |
---|---|---|
*(image)* | *(image)* | *(image)* |

Using `explainer_type=ResNet_4x`, `k=4`, `beta=0.01`.
Test Batch | Explanation Distribution | Top-k Explanation |
---|---|---|
*(image)* | *(image)* | *(image)* |

Using `explainer_type=Unet`, `k=64`, `beta=0.001`.
Green boxes indicate that the black-box model's prediction is correct; red boxes indicate incorrect predictions. The strength of the outline color (computed as `1 - JS(p, q)`) shows how well the approximator's top-k prediction matches the black-box model's output.
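For reference, a sketch of that strength value (the function name is hypothetical; base-2 logs are assumed so that JS(p, q), and hence the strength, lies in [0, 1]):

```python
import torch

def outline_strength(p, q, eps=1e-10):
    # 1 - JS(p, q) between the black-box distribution p and the
    # approximator's top-k distribution q, per sample along the last dim.
    m = 0.5 * (p + q)
    kl_pm = (p * torch.log2((p + eps) / (m + eps))).sum(-1)
    kl_qm = (q * torch.log2((q + eps) / (m + eps))).sum(-1)
    return 1.0 - 0.5 * (kl_pm + kl_qm)
```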