When the test distribution differs from the training distribution, machine learning models can perform poorly and wrongly overestimate their performance. In this work, we aim to better estimate the model's performance under distribution shift, without supervision. To do so, we use a set of domain-invariant predictors as a proxy for the unknown, true target labels. The error of this performance estimation is bounded by the target risk of the proxy model.
Estimating Generalization under Distribution Shifts via Domain-Invariant Representations
Ching-Yao Chuang,
Antonio Torralba, and
Stefanie Jegelka
In International Conference on Machine Learning (ICML), 2020.
- Python 3.7
- PyTorch 1.3.1
- PIL
We evaluate our method on two datasets: MNIST (source) and MNIST-M (target), where we assume the labels of MNIST-M are not accessible during estimation. The goal is to estimate how well models trained on MNIST generalize to MNIST-M.
Download the MNIST-M dataset from Google Drive and extract it:
mkdir dataset
cd dataset
tar -zvxf mnist_m.tar.gz
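Loading MNIST-M usually requires a small custom dataset class, since it is not bundled with torchvision. The sketch below is only an illustration: the directory layout and label-file format are assumptions, so check the repository's data-loading code for the actual loader and adapt the paths to whatever the archive contains.

```python
import os
from PIL import Image
from torch.utils.data import Dataset

class MNISTM(Dataset):
    """Minimal MNIST-M loader sketch.

    Assumes a directory of RGB images and a text file whose lines look like
    '<filename> <label>' -- adjust to the actual layout produced by the archive.
    """

    def __init__(self, root, label_file, transform=None):
        self.root = root
        self.transform = transform
        with open(label_file) as f:
            # each non-empty line: "00000000.png 5"
            self.samples = [line.split() for line in f if line.strip()]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        fname, label = self.samples[idx]
        img = Image.open(os.path.join(self.root, fname)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, int(label)
```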
The main idea of this work is to use domain adaptation models as a proxy to unknown labels. In particular, we first train a domain adversarial neural network (DANN) with the following command:
python pretrain.py
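For readers unfamiliar with DANN, its core ingredient is a gradient reversal layer: the feature extractor is trained to fit the source labels while fooling a domain classifier, which pushes the features toward domain invariance. The snippet below is only an illustrative sketch of that mechanism, not the architecture used in pretrain.py (layer sizes and module names are assumptions):

```python
import torch
from torch import nn
from torch.autograd import Function

class GradReverse(Function):
    """Identity in the forward pass; gradient scaled by -alpha in the backward pass."""
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.alpha, None

class DANN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # Illustrative feature extractor for 3x28x28 inputs; the real one differs.
        self.features = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 256), nn.ReLU())
        self.classifier = nn.Linear(256, num_classes)   # label predictor
        self.domain_head = nn.Linear(256, 2)            # domain discriminator

    def forward(self, x, alpha=1.0):
        z = self.features(x)
        # Domain logits see a reversed gradient, so the features learn to fool the discriminator.
        return self.classifier(z), self.domain_head(GradReverse.apply(z, alpha))
```

During training, the total loss is the classification cross-entropy on source labels plus a domain-classification cross-entropy on source and target batches; because the gradient is reversed, minimizing the latter turns into an adversarial signal for the feature extractor.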
After training, the check model will be saved as checkpoints/model_check.pth. Equipped with this pretrained check model, we can estimate the proxy risk of the check model itself or of other hypotheses by maximizing the disagreement (Algorithm 1 in the paper).
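As a rough Python sketch of this procedure (function names, the exact form of the domain-invariance constraint, and all hyperparameters are assumptions; proxy_risk.py is the authoritative implementation): the pretrained check model is fine-tuned to disagree with the frozen candidate model on unlabeled target data, while a penalty weighted by lam keeps its domain-invariant loss below roughly eps; the resulting disagreement rate on the target set is the proxy risk.

```python
import torch
import torch.nn.functional as F

def estimate_proxy_risk(candidate, check, src_loader, tgt_loader,
                        eps=0.01, lam=1.0, epochs=5, lr=1e-4):
    """Sketch of disagreement maximization (assumed names and hyperparameters).

    candidate: frozen model whose target performance we want to estimate.
    check:     pretrained domain-invariant (DANN) check model, fine-tuned here.
    Both are assumed to map a batch of images to class logits; target labels
    are never used for the estimate.
    """
    candidate.eval()
    opt = torch.optim.Adam(check.parameters(), lr=lr)
    for _ in range(epochs):
        for (xs, ys), (xt, _) in zip(src_loader, tgt_loader):
            with torch.no_grad():
                cand_labels = candidate(xt).argmax(dim=1)
            # Maximize disagreement with the candidate on (unlabeled) target data.
            disagreement = -F.cross_entropy(check(xt), cand_labels)
            # Keep the check model domain-invariant: only the source risk is shown
            # here; the paper additionally constrains the domain-adversarial term.
            invariance_loss = F.cross_entropy(check(xs), ys)
            loss = disagreement + lam * torch.clamp(invariance_loss - eps, min=0.0)
            opt.zero_grad()
            loss.backward()
            opt.step()
    # Proxy risk = disagreement rate between check and candidate on the target set.
    mismatch, total = 0, 0
    with torch.no_grad():
        for xt, _ in tgt_loader:
            mismatch += (check(xt).argmax(1) != candidate(xt).argmax(1)).sum().item()
            total += xt.size(0)
    return mismatch / total
```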
Flags:
- --model_path: path to the candidate model.
- --check_model_path: path to the pretrained check model.
- --eps: constraint on the domain-invariant loss of the check model.
- --lam: trade-off parameter for maximizing disagreement.
For instance, to estimate the proxy risk of the check model itself (DANN) with the default settings, run
python proxy_risk.py --model_path checkpoints/model_check.pth --check_model_path checkpoints/model_check.pth
Next, we evaluate our method by estimating the proxy risk of non-adaptive models trained only on the source domain, i.e., with standard supervised learning. Pretrain the supervised model on MNIST:
python suptrain.py
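This baseline is plain empirical risk minimization on MNIST. A minimal, hedged sketch is shown below (architecture, hyperparameters, and the channel-replication transform are assumptions; suptrain.py is the actual script):

```python
import os
import torch
from torch import nn
from torchvision import datasets, transforms

def train_source_only(model, epochs=10, lr=1e-3, device='cpu'):
    """Ordinary supervised training on MNIST; no adaptation to MNIST-M."""
    tfm = transforms.Compose([
        transforms.ToTensor(),
        # Replicate the grayscale channel so the same model can later be
        # evaluated on 3-channel MNIST-M images (an assumed preprocessing step).
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    ])
    loader = torch.utils.data.DataLoader(
        datasets.MNIST('dataset', train=True, download=True, transform=tfm),
        batch_size=128, shuffle=True)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(device).train()
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            loss = nn.functional.cross_entropy(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(model.state_dict(), 'checkpoints/model_source.pth')
```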
Estimate proxy risk:
python proxy_risk.py --model_path checkpoints/model_source.pth --check_model_path checkpoints/model_check.pth
If you find this repo useful for your research, please consider citing the paper:
@article{chuang2020estimating,
title={Estimating Generalization under Distribution Shifts via Domain-Invariant Representations},
author={Chuang, Ching-Yao and Torralba, Antonio and Jegelka, Stefanie},
journal={International Conference on Machine Learning},
year={2020}
}
For any questions, please contact Ching-Yao Chuang (cychuang@mit.edu).
Part of this code is inspired by fungtion/DANN.