Newbeeer/TRM

Transfer Risk Minimization (TRM)

Code for Learning Representations that Support Robust Transfer of Predictors

Yilun Xu, Tommi Jaakkola

TL;DR: We introduce a simple robust estimation criterion -- transfer risk -- that is specifically geared towards optimizing transfer to new environments. Effectively, the criterion amounts to finding a representation that minimizes the risk of applying any optimal predictor trained on one environment to another. The transfer risk essentially decomposes into two terms: a direct transfer term and a weighted gradient-matching term arising from the optimality of per-environment predictors.
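To make the criterion concrete, here is a toy sketch of a transfer risk for a fixed linear representation. It is not the repository's implementation: the squared loss, the scalar least-squares predictor, the worst-case aggregation over environment pairs, and all function names are assumptions chosen for illustration.

```python
from itertools import permutations


def features(phi, xs):
    """Apply a linear representation phi (a weight vector) to each input row."""
    return [sum(p * x for p, x in zip(phi, row)) for row in xs]


def fit_predictor(zs, ys):
    """Closed-form least-squares scalar predictor on 1-D features."""
    denom = sum(z * z for z in zs)
    return sum(z * y for z, y in zip(zs, ys)) / denom if denom else 0.0


def risk(w, zs, ys):
    """Mean squared error of the scalar predictor w on features zs."""
    return sum((w * z - y) ** 2 for z, y in zip(zs, ys)) / len(ys)


def transfer_risk(phi, envs):
    """Worst risk of applying one environment's optimal predictor to another.

    envs maps an environment name to a pair (inputs, labels).
    Taking the maximum over ordered pairs is an assumption of this sketch.
    """
    feats = {name: features(phi, xs) for name, (xs, _) in envs.items()}
    optimal = {name: fit_predictor(feats[name], envs[name][1]) for name in envs}
    return max(
        risk(optimal[src], feats[dst], envs[dst][1])
        for src, dst in permutations(envs, 2)
    )


# Toy data: feature 0 predicts the label in both environments,
# feature 1 flips sign between environments (a spurious feature).
envs = {
    "A": ([[1, 1], [2, 2]], [1, 2]),
    "B": ([[1, -1], [2, -2]], [1, 2]),
}

if __name__ == "__main__":
    print(transfer_risk((1, 0), envs))  # representation keeps the stable feature
    print(transfer_risk((0, 1), envs))  # representation keeps the spurious feature
```

In this toy setting the representation that keeps the stable feature transfers without loss, while the one that keeps the sign-flipping feature incurs a large risk when a predictor is moved between environments.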

Prepare the Datasets

Download the PACS, Office-Home, and MNIST datasets:

python scripts/download.py --data_dir {data_dir}

The Places dataset can be downloaded from:

http://data.csail.mit.edu/places/places365/train_256_places365standard.tar

The COCO dataset can be downloaded from:

http://images.cocodataset.org/annotations/annotations_trainval2017.zip
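The two archives can also be fetched from a script. The sketch below uses only the Python standard library; the helper names and the choice to store each archive under its URL basename in `data_dir` are assumptions, not something the repository prescribes.

```python
import os
import urllib.request
from urllib.parse import urlparse

# URLs as listed above.
ARCHIVE_URLS = [
    "http://data.csail.mit.edu/places/places365/train_256_places365standard.tar",
    "http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
]


def local_path(url, data_dir):
    """Destination file: the URL's basename inside data_dir (an assumed layout)."""
    return os.path.join(data_dir, os.path.basename(urlparse(url).path))


def download_all(data_dir, urls=ARCHIVE_URLS):
    """Download each archive unless a file with that name already exists."""
    os.makedirs(data_dir, exist_ok=True)
    paths = []
    for url in urls:
        dest = local_path(url, data_dir)
        if not os.path.exists(dest):
            urllib.request.urlretrieve(url, dest)
        paths.append(dest)
    return paths


if __name__ == "__main__":
    # "data" is a placeholder directory name.
    for path in download_all("data"):
        print(path)
```

The archives still need to be unpacked before preprocessing; where the preprocessing scripts expect to find them is not stated here, so check the scripts before choosing `data_dir`.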

Preprocess the SceneCOCO dataset:

# preprocess COCO
python coco.py
# preprocess Places
python places.py

# generate SceneCOCO dataset
python cocoplaces.py
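The three steps above can be chained so that a failure in an earlier step stops the pipeline. This is a minimal sketch: it assumes the scripts are run from the directory that contains them and that they take no arguments, as in the commands above. The `run_steps` helper and its `runner` parameter are hypothetical.

```python
import subprocess

# Same order as the commands above: COCO, then Places, then the combined dataset.
PREPROCESS_STEPS = [
    ("preprocess COCO", ["python", "coco.py"]),
    ("preprocess Places", ["python", "places.py"]),
    ("generate SceneCOCO", ["python", "cocoplaces.py"]),
]


def run_steps(steps=PREPROCESS_STEPS, runner=subprocess.run):
    """Run each step in order; check=True raises if a script exits non-zero."""
    for label, cmd in steps:
        print(f"[{label}] {' '.join(cmd)}")
        runner(cmd, check=True)


if __name__ == "__main__":
    run_steps()
```

Passing a different `runner` makes it possible to dry-run the sequence without executing the scripts.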

Running the Experiments

  • Datasets:

    • Synthetic datasets for controlled experiments: ColoredMNIST / SceneCOCO
    • Real-world datasets: PACS / Office-Home

Train with the following command (arguments in parentheses are optional):

python -m domainbed.scripts.train --data_dir {root} --algorithm {alg} \
    --dataset {dataset} --trial_seed {t_seed} --epochs {epochs} (--shift {shift}) (--resnet50) (--test_eval)

root: root directory for the data
alg: ERM, VREx, IRM, GroupDRO, Fish, MLDG, TRM
t_seed: seed for data splitting
dataset: PACS, OfficeHome, ColoredMNIST, or SceneCOCO
epochs: training epochs
resnet50: set ResNet50 as the backbone (default: ResNet18)
shift: for ColoredMNIST and SceneCOCO only; 0: label-correlated shift, 1: label-uncorrelated shift, 2: combined shift
test_eval: use the test-domain validation set (default: train-domain validation set)
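When running many configurations, it can help to assemble the command line programmatically. The sketch below builds the argument list from the parameters described above and rejects combinations the list rules out. The helper names, the default seeds, and the assumption that `--resnet50` and `--test_eval` are value-less switches are illustrative, not taken from the training script.

```python
import itertools
import shlex

ALGORITHMS = ("ERM", "VREx", "IRM", "GroupDRO", "Fish", "MLDG", "TRM")
DATASETS = ("PACS", "OfficeHome", "ColoredMNIST", "SceneCOCO")
SHIFT_DATASETS = ("ColoredMNIST", "SceneCOCO")  # only these take --shift


def build_train_command(root, alg, dataset, t_seed, epochs,
                        shift=None, resnet50=False, test_eval=False):
    """Return the training command as an argument list."""
    if alg not in ALGORITHMS:
        raise ValueError(f"unknown algorithm: {alg}")
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    if shift is not None:
        if dataset not in SHIFT_DATASETS:
            raise ValueError("--shift applies to ColoredMNIST and SceneCOCO only")
        if shift not in (0, 1, 2):
            raise ValueError("shift must be 0, 1, or 2")
    cmd = ["python", "-m", "domainbed.scripts.train",
           "--data_dir", str(root), "--algorithm", alg,
           "--dataset", dataset, "--trial_seed", str(t_seed),
           "--epochs", str(epochs)]
    if shift is not None:
        cmd += ["--shift", str(shift)]
    if resnet50:
        cmd.append("--resnet50")  # assumed to be a value-less switch
    if test_eval:
        cmd.append("--test_eval")  # assumed to be a value-less switch
    return cmd


def sweep(root, dataset, epochs, algs=ALGORITHMS, seeds=(0, 1, 2), **options):
    """One command per (algorithm, seed) pair; the seeds are placeholders."""
    return [build_train_command(root, alg, dataset, seed, epochs, **options)
            for alg, seed in itertools.product(algs, seeds)]


if __name__ == "__main__":
    # "/path/to/data" and the epoch count are placeholder values.
    for cmd in sweep("/path/to/data", "ColoredMNIST", 10, shift=0):
        print(shlex.join(cmd))
```

Each printed line can be pasted into a shell or passed to `subprocess.run` as a list.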

This implementation is based on / inspired by:
