This repository contains code for the paper *An Empirical Study of Example Forgetting during Deep Neural Network Learning* (ICLR 2019).
Bibtex:

```
@inproceedings{Forgetting,
  title={An Empirical Study of Example Forgetting during Deep Neural Network Learning},
  author={Toneva, Mariya and Sordoni, Alessandro and Combes, Remi Tachet des and Trischler, Adam and Bengio, Yoshua and Gordon, Geoffrey J},
  booktitle={ICLR},
  year={2019}
}
```
The environment requirements for the code in this repository are provided in requirements.txt. They can be installed in bulk with the following command:
```
pip install -r requirements.txt
```
Our approach consists of two main steps:
- Train a model on the full training set, collecting statistics (loss, accuracy, misclassification margin) after each presentation of every example.
- Given these presentation statistics, compute the number of forgetting events per example and sort the examples by their forgetting counts. A forgetting event is a transition in an example's training accuracy from 1 to 0 on two consecutive presentations. Note that a misclassification is not necessarily a forgetting event (see the sketch below).
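For concreteness, here is a minimal Python sketch (ours, not code from this repository) of counting forgetting events from the per-presentation accuracy sequence of a single example:

```python
import numpy as np

def count_forgetting_events(acc_seq):
    """Count 1 -> 0 transitions in a binary accuracy sequence, where
    acc_seq[t] is 1 if the example was classified correctly on its
    t-th presentation and 0 otherwise."""
    acc = np.asarray(acc_seq)
    # A forgetting event is a correct presentation followed immediately
    # by an incorrect one on the next presentation.
    return int(np.sum((acc[:-1] == 1) & (acc[1:] == 0)))

# This example is misclassified three times (t = 0, 3, 5) but forgotten
# only twice: the initial misclassification, before the example is ever
# learned, is not a forgetting event.
print(count_forgetting_events([0, 1, 1, 0, 1, 0]))  # -> 2
```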
We present results on MNIST and CIFAR. Below we provide an overview of the supplied code and the commands needed to reproduce our results for both steps.

To collect training statistics on MNIST and permuted MNIST, run:
```
python run_mnist.py \
    --dataset [mnist/permuted_mnist] \
    --no_dropout \
    --output_dir [mnist/permuted_mnist]_results \
    --seed s
```

where `s` ranges from 1 to 5. The default settings were used for all other flags. Use `--dataset mnist` and `--output_dir mnist_results` to run on MNIST, and `--dataset permuted_mnist` and `--output_dir permuted_mnist_results` for permuted MNIST. Each training run with a different seed saves a file containing that run's presentation statistics (loss, accuracy, misclassification margin) in the specified `--output_dir`. The names of the saved files encode the arguments (and argument values) used to generate them.
To compute forgetting counts and sort the examples, run:

```
python order_examples_by_forgetting.py \
    --output_dir [mnist/permuted_mnist]_results \
    --output_name [mnist/permuted_mnist]_sorted \
    --input_dir [mnist/permuted_mnist]_results \
    --input_fname_args \
        dataset [mnist/permuted_mnist] \
        no_dropout True \
        sorting_file none \
        remove_n 0 \
        keep_lowest_n 0
```
This finds all output files produced by `run_mnist.py` that are in `--input_dir` and match the arguments and values specified by `--input_fname_args`. Note that `--seed` is not specified above, which lets us match the output files of all 5 training runs. Using all matched files, `order_examples_by_forgetting.py` calculates the total number of times each example is forgotten across all epochs of all training runs (i.e., the example's forgetting count). All examples are then sorted by their forgetting counts in ascending order (i.e., unforgettable examples, those forgotten 0 times, come first), and the sorted examples with their total forgetting counts are saved in a dictionary under the name specified by `--output_name`. This script also outputs the number of unforgettable examples across all completed training runs.
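The core of this computation is straightforward. Below is a rough sketch (the actual script additionally handles file matching, loading, and saving), assuming each run's accuracy statistics are available as a binary matrix of shape `(num_presentations, num_examples)`:

```python
import numpy as np

def forgetting_counts(acc_matrix):
    """Per-example forgetting counts for one run: count 1 -> 0
    transitions along the presentation axis."""
    a = np.asarray(acc_matrix)
    return ((a[:-1] == 1) & (a[1:] == 0)).sum(axis=0)

def sort_by_forgetting(runs):
    """Sum forgetting counts across runs and sort example indices in
    ascending order, so unforgettable examples come first."""
    total = sum(forgetting_counts(r) for r in runs)
    order = np.argsort(total, kind="stable")
    return order, total[order]

# Toy usage: 2 runs, 4 presentations, 3 examples.
run1 = [[0, 1, 0], [1, 1, 0], [0, 1, 1], [1, 1, 1]]
run2 = [[1, 1, 0], [0, 1, 1], [1, 1, 1], [1, 1, 0]]
order, counts = sort_by_forgetting([run1, run2])
print(order, counts)  # -> [1 2 0] [0 1 2]
```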
To collect training statistics on CIFAR-10 and CIFAR-100, run:

```
python run_cifar.py \
    --dataset [cifar10/cifar100] \
    --data_augmentation \
    --output_dir [cifar10/cifar100]_results \
    --seed s
```

where `s` ranges from 1 to 5. The default settings were used for all other flags. This script has functionality similar to `run_mnist.py`.
To sort the CIFAR examples by forgetting counts, run:

```
python order_examples_by_forgetting.py \
    --output_dir [cifar10/cifar100]_results \
    --output_name [cifar10/cifar100]_sorted \
    --input_dir [cifar10/cifar100]_results \
    --input_fname_args \
        dataset [cifar10/cifar100] \
        data_augmentation True \
        cutout False \
        sorting_file none \
        remove_n 0 \
        keep_lowest_n 0 \
        remove_subsample 0 \
        noise_percent_labels 0 \
        noise_percent_pixels 0 \
        noise_std_pixels 0
```
The removal experiments remove a specified number of examples entirely from the sorted training set. We achieve this by providing three extra flags to `run_mnist.py` and `run_cifar.py`: `--sorting_file`, the name of the file output by `order_examples_by_forgetting.py` that specifies the ordering of the examples by forgetting counts; `--remove_n`, the number of examples to remove; and `--keep_lowest_n`, which specifies where in the list of sorted training examples the removal should begin (see the sketch below). We found that near state-of-the-art generalization performance can be maintained even when all unforgettable examples (i.e., examples with 0 forgetting events) are removed.
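The following sketch illustrates the removal semantics described above; it is illustrative only (`remaining_indices` is a hypothetical helper, not a function from this repository):

```python
import numpy as np

def remaining_indices(sorted_indices, remove_n, keep_lowest_n, seed=0):
    """sorted_indices: example indices in ascending order of forgetting
    counts, as produced by order_examples_by_forgetting.py."""
    sorted_indices = np.asarray(sorted_indices)
    if keep_lowest_n < 0:
        # Random baseline: remove remove_n examples uniformly at random.
        rng = np.random.default_rng(seed)
        keep = rng.choice(len(sorted_indices),
                          size=len(sorted_indices) - remove_n, replace=False)
        return sorted_indices[np.sort(keep)]
    # Selected removal: keep the keep_lowest_n least-forgotten examples,
    # then remove the next remove_n examples in the sorted order.
    return np.concatenate([sorted_indices[:keep_lowest_n],
                           sorted_indices[keep_lowest_n + remove_n:]])
```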
For the Figure 5 (left) results, run:

```
python run_cifar.py \
    --dataset cifar10 \
    --data_augmentation \
    --cutout \
    --sorting_file cifar10_sorted \
    --input_dir cifar10_results \
    --output_dir cifar10_results \
    --seed s \
    --remove_n r \
    --keep_lowest_n k
```

where `s` is in `range(1, 6)`, `r` is in `range(0, 50000, 1000)`, and `k` is 0 (for selected removal) or -1 (for random removal).
For the Figure 5 (right) results, run:

```
python run_cifar.py \
    --dataset cifar10 \
    --data_augmentation \
    --cutout \
    --sorting_file cifar10_sorted \
    --input_dir cifar10_results \
    --output_dir cifar10_results \
    --seed s \
    --remove_n r \
    --keep_lowest_n k
```

where `s` is in `range(1, 6)`, `r` is 5000, and `k` is in `range(0, 50000, 1000)` (for selected removal) or -1 (for random removal).
The Cutout implementation is taken from the original Cutout repository.
For the Figure 6 results, run:

```
python run_mnist.py \
    --dataset [mnist/permuted_mnist] \
    --sorting_file [mnist/permuted_mnist]_sorted \
    --input_dir [mnist/permuted_mnist]_results \
    --output_dir [mnist/permuted_mnist]_results \
    --seed s \
    --remove_n r \
    --keep_lowest_n k
```

where `s` is in `range(1, 6)`, `r` is in `range(0, 60000, 1000)`, and `k` is 0 (for selected removal) or -1 (for random removal). The CIFAR-10 results in this figure are the same as those from Figure 5 (left).
We also investigate how adding noise to the example labels affects forgetting. We introduce label noise by assigning random labels to a specified percentage of the training set.
For the Figure 3 results, run:

```
python run_cifar.py \
    --dataset cifar10 \
    --data_augmentation \
    --output_dir cifar10_results \
    --noise_percent_labels 20
```
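As a rough illustration, label noise of this kind could be introduced as follows (`corrupt_labels` is a hypothetical helper, not the repository's implementation, and assumes integer class labels):

```python
import numpy as np

def corrupt_labels(labels, noise_percent, num_classes=10, seed=0):
    """Assign random labels to noise_percent % of the training set.
    Note: a randomly drawn label may coincide with the true one."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels).copy()
    n_noisy = int(round(len(labels) * noise_percent / 100.0))
    # Pick which examples to corrupt, then overwrite their labels.
    noisy_idx = rng.choice(len(labels), size=n_noisy, replace=False)
    labels[noisy_idx] = rng.integers(0, num_classes, size=n_noisy)
    return labels, noisy_idx
```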
In the supplementary material, we further investigate how introducing additive Gaussian noise to the example pixels affects forgetting.
For the Supplementary Figure 11 results, run:

```
python run_cifar.py \
    --dataset cifar10 \
    --data_augmentation \
    --output_dir cifar10_results \
    --noise_percent_pixels 100 \
    --noise_std_pixels n
```

where `n` is in `[0.5, 1, 2, 10]`.
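A minimal sketch of additive Gaussian pixel noise, under the assumption that images are stored as a float array (e.g., normalized CIFAR-10 inputs); `add_pixel_noise` is a hypothetical helper, not the repository's implementation:

```python
import numpy as np

def add_pixel_noise(images, percent_pixels=100.0, std=1.0, seed=0):
    """Add zero-mean Gaussian noise with standard deviation `std` to a
    random percent_pixels % of the pixels of each image."""
    rng = np.random.default_rng(seed)
    images = np.array(images, dtype=np.float32)  # work on a copy
    if percent_pixels >= 100.0:
        return images + rng.normal(0.0, std, size=images.shape)
    flat = images.reshape(len(images), -1)
    n = int(flat.shape[1] * percent_pixels / 100.0)
    for img in flat:  # rows are views into `images`, modified in place
        idx = rng.choice(flat.shape[1], size=n, replace=False)
        img[idx] += rng.normal(0.0, std, size=n)
    return images
```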
This project is licensed under the MIT License; see the LICENSE.md file for details.