This is the code for the paper Generalized Variational Continual Learning [1]. This repository is originally based on the HAT [2] repository.
Continual learning deals with training models on new tasks and datasets in an online fashion. One strand of research has used probabilistic regularization for continual learning, with two of the main approaches in this vein being Online Elastic Weight Consolidation (Online EWC) and Variational Continual Learning (VCL). VCL employs variational inference, which in other settings has been improved empirically by applying likelihood-tempering. We show that applying this modification to VCL recovers Online EWC as a limiting case, allowing for interpolation between the two approaches. We term the general algorithm Generalized VCL (GVCL). In order to mitigate the observed overpruning effect of VI, we take inspiration from a common multi-task architecture, neural networks with task-specific FiLM layers, and find that this addition leads to significant performance gains, specifically for variational methods. In the small-data regime, GVCL strongly outperforms existing baselines. In larger datasets, GVCL with FiLM layers outperforms or is competitive with existing baselines in terms of accuracy, whilst also providing significantly better calibration.
Noel Loo, Siddharth Swaroop, Richard E. Turner
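As a rough sketch of the connection described above (notation simplified; see the paper for the exact objective and the limiting argument): likelihood-tempering amounts to down-weighting the KL term of the VCL objective by a factor β, with β = 1 giving VCL and β → 0 recovering Online EWC.

```latex
% Sketch of the beta-tempered VCL objective for task t (simplified notation).
% q_t is the approximate posterior after task t, q_{t-1} the previous one,
% and D_t the data for task t; beta = 1 recovers VCL, beta -> 0 recovers Online EWC.
\mathcal{L}_t^{\beta}(q_t) =
  \mathbb{E}_{q_t(\theta)}\!\left[ \log p(\mathcal{D}_t \mid \theta) \right]
  - \beta \, \mathrm{KL}\!\left( q_t(\theta) \,\|\, q_{t-1}(\theta) \right)
```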
Create a Python 3 conda environment (check the requirements.txt file).
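For example (the environment name "gvcl" and the exact Python version below are placeholders; match them to requirements.txt):

```bash
# Create and activate a fresh environment, then install the pinned dependencies.
conda create -n gvcl python=3
conda activate gvcl
pip install -r requirements.txt
```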
To run the CHASY experiments, first run src/dataloaders/hasy_utils.py to download the dataset.
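For example, from the repository root:

```bash
# Download the HASY data used by the CHASY benchmarks.
python src/dataloaders/hasy_utils.py
```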
The following folder structure is expected at runtime, relative to the repository root (missing folders can be created as shown below the list):
- src/ : Where all the scripts live (already included in the repository)
- dat/ : Place to put/download all data sets
- res/ : Place to save results
- tmp/ : Place to store temporary files
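If the dat/, res/, and tmp/ folders are not created automatically by the scripts, they can be made by hand:

```bash
# Create the data, results, and temporary-file folders next to src/.
mkdir -p dat res tmp
```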
The main script is src/run.py. To run multiple experiments, we use src/run_multi.py or src/work.py; to run the compression experiment, we use src/run_compression.sh.
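For example (each script defines its own command-line arguments, which are omitted here):

```bash
# Single experiment; see the script for the available arguments.
python src/run.py

# Batches of experiments.
python src/run_multi.py
python src/work.py

# Compression experiment.
bash src/run_compression.sh
```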
- The original HAT repository had MNIST twice (instead of Fashion-MNIST) for the mixed vision tasks, so results on that benchmark may differ
- The implementation of EWC and IMM-mode differs from the original repository, since the original repository calculated the Fisher information matrix (FIM) using batches instead of individual samples
- The two IPython notebooks are for the toy examples in Appendices A and B
[1] Loo, N., Swaroop, S., & Turner, R. E. (2021). Generalized Variational Continual Learning. In International Conference on Learning Representations.
[2] Serrà, J., Surís, D., Miron, M., & Karatzoglou, A. (2018). Overcoming Catastrophic Forgetting with Hard Attention to the Task. In Proceedings of the 35th International Conference on Machine Learning, PMLR 80:4548-4557.