This is a re-implementation of VIBI (arxiv, github) in Python and PyTorch, including experiments on MNIST and CIFAR10.
To run the experiments, first clone this repository and install the requirements.
git clone
pip install -r requirements.txt
Run all experiments shown in the results:
chmod +x
Alternatively, run the script with command-line arguments.
Optional arguments:
--dataset {MNIST,CIFAR10}
--cuda Enable cuda.
--num_epochs NUM_EPOCHS
Number of training epochs for VIBI.
--explainer_type {Unet,ResNet_2x,ResNet_4x,ResNet_8x}
--xpl_channels {1,3}
--k K Number of chunks selected for the explanation.
--beta BETA beta in objective J = I(y,t) - beta * I(x,t).
--num_samples NUM_SAMPLES
Number of samples used for estimating the expectation over p(t|x).
--resume_training Resume training VIBI from the last saved checkpoint.
--save_best Save only the best models (measured by validation accuracy).
Save explanation images every epoch.
--jump_start Use a pretrained model trained with beta=0 as the starting point.
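For orientation, the documented options correspond roughly to the argparse parser sketched below. It is an illustration, not the repository's parser: `build_parser` is a hypothetical name, the types are inferred from the descriptions, no defaults are set (the real script surely has its own), and the option for saving explanation images every epoch is left out because its flag name is not given here.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the options listed in this README.

    Types are inferred from the descriptions; no defaults are set, so an
    omitted value option is None and an omitted flag is False.
    """
    parser = argparse.ArgumentParser(description="VIBI experiments (sketch)")
    parser.add_argument("--dataset", choices=["MNIST", "CIFAR10"])
    parser.add_argument("--cuda", action="store_true", help="Enable cuda.")
    parser.add_argument("--num_epochs", type=int,
                        help="Number of training epochs for VIBI.")
    parser.add_argument("--explainer_type",
                        choices=["Unet", "ResNet_2x", "ResNet_4x", "ResNet_8x"])
    parser.add_argument("--xpl_channels", type=int, choices=[1, 3])
    parser.add_argument("--k", type=int, help="Number of chunks.")
    parser.add_argument("--beta", type=float,
                        help="beta in objective J = I(y,t) - beta * I(x,t).")
    parser.add_argument("--num_samples", type=int,
                        help="Samples for estimating the expectation over p(t|x).")
    parser.add_argument("--resume_training", action="store_true")
    parser.add_argument("--save_best", action="store_true")
    parser.add_argument("--jump_start", action="store_true")
    return parser


# Example: explainer settings of the first results figure (dataset omitted).
args = build_parser().parse_args(
    ["--explainer_type", "ResNet_4x", "--k", "4", "--beta", "0.01"]
)
print(args.explainer_type, args.k, args.beta)  # ResNet_4x 4 0.01
```

Passing a value outside the listed choices (for example `--explainer_type Resnet4x`) makes argparse exit with an error, which is why the exact spellings above matter.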
The goal is to create interpretable explanations for black-box models.
This is achieved with two neural networks: the explainer and the approximator.
Given an input image, the explainer network produces a probability distribution over the input chunks,
from which a relaxed k-hot vector is sampled.
This k-hot vector is used to mask the input, and the masked input is then
fed into the approximator network.
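The sampling and masking steps above can be sketched in framework-free Python. This is not the repository's code and rests on several assumptions: the relaxation follows the L2X-style construction (k independent Gumbel-softmax draws combined with an element-wise max), chunks are square and numbered row-major, and the image is single-channel. `relaxed_k_hot`, `top_k_hot`, `chunk_mask`, `apply_mask` and the temperature `tau` are hypothetical names; `top_k_hot` is only a guess at how the deterministic "Top-k Explanation" could be formed at test time.

```python
import math
import random


def relaxed_k_hot(logits, k, tau=0.5, rng=random):
    """Relaxed k-hot sample over len(logits) chunks (L2X-style, assumed here).

    Draws k independent Gumbel-softmax samples from the categorical
    distribution defined by `logits` and keeps the element-wise maximum,
    so up to k entries approach 1 as tau goes to 0.
    """
    result = [0.0] * len(logits)
    for _ in range(k):
        # Gumbel(0, 1) noise g = -log(-log(u)), with u clamped away from 0 and 1
        noisy = []
        for z in logits:
            u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
            noisy.append((z - math.log(-math.log(u))) / tau)
        # Numerically stable softmax over the perturbed logits
        m = max(noisy)
        exps = [math.exp(s - m) for s in noisy]
        total = sum(exps)
        result = [max(r, e / total) for r, e in zip(result, exps)]
    return result


def top_k_hot(logits, k):
    """Deterministic k-hot vector keeping the k highest-scoring chunks."""
    top = set(sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k])
    return [1.0 if i in top else 0.0 for i in range(len(logits))]


def chunk_mask(chunk_weights, height, width, chunk_size):
    """Expand one weight per square chunk (row-major order) to a pixel mask."""
    cols = width // chunk_size
    return [[chunk_weights[(r // chunk_size) * cols + c // chunk_size]
             for c in range(width)]
            for r in range(height)]


def apply_mask(image, mask):
    """Element-wise product of a single-channel image and a pixel mask."""
    return [[p * w for p, w in zip(img_row, mask_row)]
            for img_row, mask_row in zip(image, mask)]


# Example: a 4x4 image split into four 2x2 chunks, explained with k=1.
log_probs = [math.log(p) for p in (0.1, 0.6, 0.2, 0.1)]
weights = relaxed_k_hot(log_probs, k=1)   # training: soft and stochastic
hard = top_k_hot(log_probs, k=1)          # [0.0, 1.0, 0.0, 0.0]
image = [[1.0] * 4 for _ in range(4)]
masked = apply_mask(image, chunk_mask(hard, 4, 4, 2))  # keeps the top-right chunk
```

The element-wise max is what lets a single relaxed vector carry up to k selected chunks while staying differentiable with respect to the explainer's scores; in PyTorch the same role would be played by tensor operations rather than Python lists.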
The approximator network aims to match the black-box model's output
probability distribution.
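One way to quantify this matching is the cross-entropy H(p,q), where p is the black-box model's class distribution for the original image and q is the approximator's distribution for a masked input. The sketch below averages it over several sampled masks, in the spirit of the `--num_samples` option; the function names are hypothetical, plain lists stand in for PyTorch tensors, and whether the repository averages in exactly this way is an assumption.

```python
import math


def soft_cross_entropy(p, q, eps=1e-12):
    """H(p, q) = -sum_i p_i * log q_i for class-probability lists p and q."""
    return -sum(pi * math.log(max(qi, eps)) for pi, qi in zip(p, q))


def fidelity_loss(p_blackbox, q_samples):
    """Monte Carlo estimate of E_t[H(p, q(.|t))] over sampled masked inputs t.

    p_blackbox: black-box class probabilities for the original image.
    q_samples: approximator class probabilities, one list per sampled mask
               (len(q_samples) plays the role of --num_samples).
    """
    return sum(soft_cross_entropy(p_blackbox, q) for q in q_samples) / len(q_samples)


p = [0.7, 0.2, 0.1]                       # black-box output
q_good = [[0.7, 0.2, 0.1]] * 2            # approximator matches exactly
q_poor = [[0.2, 0.2, 0.6], [0.1, 0.8, 0.1]]
print(fidelity_loss(p, q_good) < fidelity_loss(p, q_poor))  # True
```

By Gibbs' inequality, H(p,q) is smallest, and equal to the entropy H(p), exactly when q = p, so lowering this loss pushes the approximator toward the black-box distribution rather than just its top class.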
The whole idea builds heavily on L2X (Learning to Explain).
The only difference is that VIBI's additional term, beta * I(x,t), acts as a penalty on the explainer's distribution over chunks, p(z|x),
which effectively increases the entropy of that distribution, whereas L2X only minimizes the cross-entropy H(p,q)
between the black-box model's predictions and the approximator's predictions.
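The entropy argument can be checked numerically. Assuming, for this sketch only, that beta * I(x,t) is implemented as a KL penalty between the explainer's distribution p(z|x) and a uniform prior over the d chunks, the identity KL(p || uniform) = log d - H(p) shows that lowering the penalty raises the entropy. The hypothetical `vibi_style_loss` adds this penalty to a precomputed cross-entropy; with beta = 0 only the cross-entropy remains, which is the L2X-style loss (and the setting used for `--jump_start`).

```python
import math


def entropy(p):
    """Shannon entropy H(p) in nats; zero-probability terms contribute 0."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def kl_to_uniform(p):
    """KL(p || Uniform(d)) = sum_i p_i * log(d * p_i) = log d - H(p)."""
    d = len(p)
    return sum(pi * math.log(d * pi) for pi in p if pi > 0)


def vibi_style_loss(cross_entropy, chunk_probs, beta):
    """Cross-entropy plus beta times the KL penalty on the explainer's
    distribution over chunks; beta = 0 leaves the L2X-style loss."""
    return cross_entropy + beta * kl_to_uniform(chunk_probs)


peaked = [0.85, 0.05, 0.05, 0.05]
flat = [0.25, 0.25, 0.25, 0.25]
print(kl_to_uniform(peaked) > kl_to_uniform(flat))  # True: flatter is penalized less
print(vibi_style_loss(1.0, peaked, beta=0.0))        # 1.0
```

Because log d is a constant for a fixed chunk grid, minimizing the KL term is equivalent to maximizing H(p(z|x)), which is the sense in which the extra term "increases the entropy".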
[Figure: Test Batch | Explanation Distribution | Top-k Explanation] Using explainer_type=ResNet_4x, k=4, beta=0.01.
[Figure: Test Batch | Explanation Distribution | Top-k Explanation] Using explainer_type=Unet, k=64, beta=0.001.
Green boxes indicate that the black-box model's prediction is correct; red boxes indicate an incorrect prediction.
The intensity of the outline color, calculated as 1 - JS(p,q) where JS is the Jensen-Shannon divergence, shows how well the approximator's prediction (using the top-k chunks) fits the black-box model's output.
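A minimal sketch of how such an outline could be computed, assuming base-2 logarithms so that JS(p,q), and therefore the strength 1 - JS(p,q), lies in [0, 1]. Here p is the black-box output, q is the approximator's output on the top-k masked input, and `label` is the ground-truth class. The function names and the returned (color, strength) pair are hypothetical, not the repository's plotting code.

```python
import math


def kl_bits(p, q):
    """KL(p || q) in bits; terms with p_i = 0 contribute 0."""
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js_divergence(p, q):
    """Jensen-Shannon divergence in bits, bounded in [0, 1]."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    # m_i >= p_i / 2, so m_i > 0 wherever p_i > 0 (and likewise for q)
    return 0.5 * kl_bits(p, m) + 0.5 * kl_bits(q, m)


def outline(p_blackbox, q_approx, label):
    """Return (color, strength) for one test image (hypothetical helper).

    color: green if the black-box model's top class equals `label`, else red.
    strength: 1 - JS(p, q); 1.0 means the approximator matches exactly.
    """
    pred = max(range(len(p_blackbox)), key=lambda i: p_blackbox[i])
    color = "green" if pred == label else "red"
    return color, 1.0 - js_divergence(p_blackbox, q_approx)


print(outline([0.9, 0.1], [0.9, 0.1], label=0))  # ('green', 1.0)
print(outline([1.0, 0.0], [0.0, 1.0], label=1))  # ('red', 0.0)
```

Unlike KL, the Jensen-Shannon divergence is symmetric and stays finite even when one distribution puts zero mass where the other does not, which makes it a convenient bounded score for an opacity value.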