This repository contains code to replicate experiments in the 2022 NeurIPS paper, “Learning Concept Credible Models for Mitigating Shortcuts”.

Given access to a representation based on domain knowledge (i.e., known concepts), we want to learn a model that is accurate regardless of whether the training data is biased (i.e., contains shortcuts that do not hold in practice) and whether the known concepts alone are sufficient for accurate predictions. We call such a model a concept credible model (CCM). To that end, we propose two methods, CCM EYE and CCM RES, that are provably concept credible in some linear settings and that empirically mitigate shortcut learning even when those assumptions are violated.
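As a toy illustration of the residual idea that the name CCM RES suggests, the sketch below fits a one-feature linear model on a known concept first, then fits only the leftover residual with an additional input feature. This is a simplified assumption written in plain Python, not the code in scripts/ccm_r.py; the data and the helper `fit_slope` are hypothetical.

```python
def fit_slope(feature, target):
    """Least-squares slope (no intercept) of target regressed on one feature."""
    return sum(f * t for f, t in zip(feature, target)) / sum(f * f for f in feature)

# Hypothetical toy data: c is a known concept, x is an extra input feature.
c = [1.0, -1.0, 1.0, -1.0]
x = [1.0, 1.0, -1.0, -1.0]
y = [2.0 * ci + 0.5 * xi for ci, xi in zip(c, x)]

# Stage 1: explain as much of y as possible with the known concept alone.
w_c = fit_slope(c, y)

# Stage 2: fit only what the concept model could not explain.
residual = [yi - w_c * ci for yi, ci in zip(y, c)]
w_x = fit_slope(x, residual)

print(w_c, w_x)  # prints 2.0 0.5
```

Because the concept is fit first, the extra feature can only account for variation that the concept leaves unexplained.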

The code directories are organized as follows:

- mimic_scripts/ contains the training code for reproducing experiments on the MIMIC-CXR dataset.
- scripts/ contains the training code for reproducing experiments on the CUB birds dataset.
- notebooks/ contains IPython notebooks for visualizing the results.

Dependencies are listed in Pipfile and can be installed with pipenv.

To run the baseline and CCM models on the CUB dataset:

For the concept model C

python scripts/concept_model.py --transform flip --lr_step 1000 -t 0 -s noise --n_shortcuts 10

For the oracle CBM used to generate the shortcut

python scripts/cbm.py --lr_step 15 -s noise -t 1 --n_shortcuts 10 --c_model_path <path to C>/concepts

For CBM

python scripts/cbm.py --lr_step 15 -s <path to oracle CBM>/cbm.pt -t 1 --n_shortcuts 10 --c_model_path <path to C>/concepts

For STD(X)

python scripts/standard_model.py -s <path to oracle CBM>/cbm.pt --n_shortcuts 10 -t 1

For STD(C, X)

python scripts/ccm.py --lr_step 15 --alpha 0 -s <path to oracle CBM>/cbm.pt -t 1 --n_shortcuts 10 --u_model_path <path to STD(X)>/standard --c_model_path <path to C>/concepts

For CCM RES

python scripts/ccm_r.py --lr_step 15 -s <path to oracle CBM>/cbm.pt -t 1 --n_shortcuts 10 --u_model_path <path to STD(X)>/standard --c_model_path <path to CBM>/cbm

For CCM EYE

python scripts/ccm.py --lr_step 15 --alpha 0.001 -s <path to oracle CBM>/cbm.pt -t 1 --n_shortcuts 10 --u_model_path <path to STD(X)>/standard --c_model_path <path to C>/concepts
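In the commands above, STD(C, X) and CCM EYE both call scripts/ccm.py and differ in the value passed to --alpha. The sketch below makes that explicit by assembling both argument lists from one helper. The helper name `ccm_command` is hypothetical, the flags are copied from the commands listed here, and the angle-bracket paths are placeholders to replace with your own output directories.

```python
import shlex

def ccm_command(alpha, shortcut, u_model_path, c_model_path, n_shortcuts=10):
    """Return the argument list for scripts/ccm.py with the flags used in this README."""
    return [
        "python", "scripts/ccm.py",
        "--lr_step", "15",
        "--alpha", str(alpha),
        "-s", shortcut,
        "-t", "1",
        "--n_shortcuts", str(n_shortcuts),
        "--u_model_path", u_model_path,
        "--c_model_path", c_model_path,
    ]

# Placeholder paths: substitute the directories produced by the earlier steps.
args = dict(
    shortcut="<path to oracle CBM>/cbm.pt",
    u_model_path="<path to STD(X)>/standard",
    c_model_path="<path to C>/concepts",
)
std_cx = ccm_command(0, **args)       # STD(C, X)
ccm_eye = ccm_command(0.001, **args)  # CCM EYE

# With identical paths, the two runs differ only in the --alpha value.
diff = [(a, b) for a, b in zip(std_cx, ccm_eye) if a != b]
print(diff)  # prints [('0', '0.001')]

# Shell-quoted command line, ready to paste once the placeholders are filled in.
print(shlex.join(ccm_eye))
```

Passing the list to `subprocess.run` instead of printing it would launch the run directly, which avoids shell-quoting issues for paths that contain spaces.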

Logging

I log all the commands I run using

track log

To see how to use the command tracker, run

track -h