MLD3/ConceptCredibleModel

This repository contains code to replicate experiments in the 2022 NeurIPS paper, “Learning Concept Credible Models for Mitigating Shortcuts”.

Given access to a representation based on domain knowledge (i.e., known concepts), we want to learn a model that is accurate regardless of whether the training data is biased (i.e., contains shortcuts that do not hold in practice) and regardless of whether the known concepts alone are sufficient for accurate predictions. We call such a model a concept credible model (CCM). To that end, we propose two methods, CCM EYE and CCM RES, that are provably concept credible in some linear settings and that empirically mitigate shortcut learning even when those assumptions are violated.

The code directories are organized as follows:

mimic_scripts/ contains the training code for reproducing experiments on the MIMIC-CXR dataset.

scripts/ contains the training code for reproducing experiments on the CUB birds dataset.

notebooks/ contains IPython notebooks for visualizing the results.

Dependencies are listed in Pipfile and can be installed with pipenv.

To run baseline models for the CUB dataset:

Getting concept C:

python scripts/concept_model.py --transform flip --lr_step 1000 -t 0 -s noise --n_shortcuts 10

Oracle CBM used to generate shortcuts

python scripts/cbm.py --lr_step 15 -s noise -t 1 --n_shortcuts 10 --c_model_path <path to C>/concept

For CBM

python scripts/cbm.py --lr_step 15 -s <path to oracle CBM>/cbm.pt -t 1 --n_shortcuts 10 --c_model_path outputs/9843d41ae4c711ebb773ac1f6b24a434/concepts

For STD(X)

python scripts/standard_model.py -s <path to oracle CBM>/cbm.pt --n_shortcuts 10 -t 1

For STD(C, X)

python scripts/ccm.py --lr_step 15 --alpha 0 -s <path to oracle CBM>/cbm.pt -t 1 --n_shortcuts 10 --u_model_path <path to STD(X)> --c_model_path outputs/9843d41ae4c711ebb773ac1f6b24a434/concepts

For CCM RES

python scripts/ccm_r.py --lr_step 15 -s <path to oracle CBM>/cbm.pt -t 1 --n_shortcuts 10 --u_model_path <path to STD(X)>/standard --c_model_path <path to CBM>/cbm

For CCM EYE

python scripts/ccm.py --lr_step 15 --alpha 0.001 -s <path to oracle CBM>/cbm.pt -t 1 --n_shortcuts 10 --u_model_path <path to STD(X)>/standard --c_model_path <path to C>/concepts
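
Because each step above consumes output directories from earlier steps, it can help to assemble the commands in one place before running them. The following is a minimal dry-run sketch, not part of the repository: the function name build_commands, the step labels, and the example directories are hypothetical. It also assumes that the hard-coded outputs/9843d41ae4c711ebb773ac1f6b24a434 directory in the CBM and STD(C, X) commands is the output of the concept model step, and so substitutes <path to C> for it. All flags and remaining path suffixes are copied unchanged from the commands above.

```python
import shlex

# Command templates copied from the steps above. The "<path to ...>" tokens
# are placeholders for output directories produced by earlier steps.
# Assumption: "<path to C>" stands in for the hard-coded run directory in the
# CBM and STD(C, X) commands.
TEMPLATES = [
    ("concept model C",
     "python scripts/concept_model.py --transform flip --lr_step 1000 "
     "-t 0 -s noise --n_shortcuts 10"),
    ("oracle CBM",
     "python scripts/cbm.py --lr_step 15 -s noise -t 1 --n_shortcuts 10 "
     "--c_model_path <path to C>/concept"),
    ("CBM",
     "python scripts/cbm.py --lr_step 15 -s <path to oracle CBM>/cbm.pt "
     "-t 1 --n_shortcuts 10 --c_model_path <path to C>/concepts"),
    ("STD(X)",
     "python scripts/standard_model.py -s <path to oracle CBM>/cbm.pt "
     "--n_shortcuts 10 -t 1"),
    ("STD(C, X)",
     "python scripts/ccm.py --lr_step 15 --alpha 0 "
     "-s <path to oracle CBM>/cbm.pt -t 1 --n_shortcuts 10 "
     "--u_model_path <path to STD(X)> --c_model_path <path to C>/concepts"),
    ("CCM RES",
     "python scripts/ccm_r.py --lr_step 15 -s <path to oracle CBM>/cbm.pt "
     "-t 1 --n_shortcuts 10 --u_model_path <path to STD(X)>/standard "
     "--c_model_path <path to CBM>/cbm"),
    ("CCM EYE",
     "python scripts/ccm.py --lr_step 15 --alpha 0.001 "
     "-s <path to oracle CBM>/cbm.pt -t 1 --n_shortcuts 10 "
     "--u_model_path <path to STD(X)>/standard "
     "--c_model_path <path to C>/concepts"),
]


def build_commands(paths):
    """Fill in the placeholders and return (label, argv) pairs in order."""
    commands = []
    for label, template in TEMPLATES:
        for placeholder, directory in paths.items():
            # Quote each directory so paths containing spaces survive splitting.
            template = template.replace(placeholder, shlex.quote(directory))
        if "<path to" in template:
            raise ValueError(f"unresolved placeholder in step {label!r}")
        commands.append((label, shlex.split(template)))
    return commands


if __name__ == "__main__":
    # Hypothetical directories; replace with the ones your runs produced.
    example_paths = {
        "<path to C>": "outputs/concept_run",
        "<path to oracle CBM>": "outputs/oracle_run",
        "<path to STD(X)>": "outputs/standard_run",
        "<path to CBM>": "outputs/cbm_run",
    }
    for label, argv in build_commands(example_paths):
        # Dry run: print each command instead of executing it.
        print(f"# {label}")
        print(shlex.join(argv))
```

The sketch only prints the commands. In practice each output directory is known only after the step that creates it has finished, so the paths would be filled in one step at a time.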

Logging

I log all the commands I run using

track log

To see how to use the command tracker:

track -h
