Official code of "Discovering Invariant Rationales for Graph Neural Networks" (ICLR 2022)

Wuyxin/DIR-GNN


Discovering Invariant Rationales for Graph Neural Networks 🔥

Overview

DIR (ICLR 2022) aims to train intrinsically interpretable Graph Neural Networks that are robust and generalize to out-of-distribution data. The core of this work is the construction of interventional distributions, from which causal features are identified. See the quick Q&A below.

  • Q: What are interventional distributions?

    They are the distributions that arise when we intervene on one variable or a set of variables in the data generation process. For example, we could intervene on the base graph (highlighted in green or blue), which gives us multiple distributions.

  • Q: How do we construct the interventional distributions?
    We design a model structure that performs the intervention in the representation space. The distribution intervener is in charge of sampling one subgraph from the non-causal pool and fixing it at one end of the rationale generator.

  • Q: How can these interventional distributions help us approach the causal features for rationalization?

    The philosophy is simple: no matter what values we assign to the non-causal part, the class label remains invariant as long as we observe the causal part. Intuitively, interventional distributions give us "multiple eyes" to discover the features that keep the label invariant under interventions. We propose the DIR objective to achieve this goal.

    See our paper for the formal description and the principle behind it.
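As we read the paper, the DIR objective combines the average risk over interventional distributions with a penalty on how much that risk varies across them (weighted by a coefficient λ). The stdlib-only sketch below illustrates that idea on toy class scores. It is not the repository's implementation: the function names, the representation of causal and non-causal parts as per-graph score vectors, and the rule of simply adding the two score vectors are all assumptions made for illustration.

```python
import math
import random
from statistics import fmean, pvariance


def cross_entropy(logits, label):
    """Softmax cross-entropy for a single example, computed stably."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[label]


def dir_objective(causal_logits, noncausal_logits, labels,
                  lam=1.0, num_interventions=3, seed=0):
    """Hypothetical sketch of a DIR-style objective (not the repo's code).

    causal_logits[i]    -- class scores from the causal part of graph i
    noncausal_logits[j] -- class scores from the non-causal part of graph j
    Returns (objective, per-intervention risks), where
    objective = mean(risks) + lam * variance(risks).
    """
    rng = random.Random(seed)
    n = len(labels)
    risks = []
    for _ in range(num_interventions):
        # One intervention do(S = s): each graph keeps its causal part but is
        # paired with a non-causal part drawn (without replacement) from the batch pool.
        assignment = rng.sample(range(n), n)
        losses = []
        for i in range(n):
            s = noncausal_logits[assignment[i]]
            # Hypothetical combination rule: add the two score vectors.
            combined = [c + z for c, z in zip(causal_logits[i], s)]
            losses.append(cross_entropy(combined, labels[i]))
        risks.append(fmean(losses))
    # Penalizing the variance favors predictors whose risk stays the same
    # across interventional distributions.
    return fmean(risks) + lam * pvariance(risks), risks


if __name__ == "__main__":
    causal = [[2.0, 0.0], [0.0, 1.0]]
    noncausal = [[3.0, -3.0], [-3.0, 3.0]]
    labels = [0, 1]
    loss, risks = dir_objective(causal, noncausal, labels, lam=1.0, num_interventions=4)
    print(f"objective={loss:.4f} risks={[round(r, 4) for r in risks]}")
```

If every non-causal part produces the same scores, all interventional risks coincide and the variance term vanishes; the penalty only bites when the prediction depends on which non-causal part was plugged in.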

Installation

Note that we require 1.7.0 <= torch_geometric <= 2.0.2. Simply run the commands below to install the Python environment (you may need to change the cudatoolkit version to match your CUDA version), or see requirements.txt for the list of packages.

sh setup_env.sh
conda activate dir
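After activating the environment, a quick way to confirm that the installed torch_geometric falls inside the supported range is sketched below. The helper names are hypothetical; only the version bounds come from this README, and the parser deliberately ignores pre-release and local suffixes (for example, `2.0.2+cu113` is treated as `2.0.2`).

```python
import re
from importlib import metadata

# Supported range stated in this README: 1.7.0 <= torch_geometric <= 2.0.2
MIN_VERSION = (1, 7, 0)
MAX_VERSION = (2, 0, 2)


def parse_release(version: str) -> tuple:
    """Keep only the leading numeric release segment, e.g. "2.0.2+cu113" -> (2, 0, 2)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    parts = tuple(int(p) for p in match.group(0).split("."))
    # Pad (or truncate) to three components so "2.0" compares like "2.0.0".
    return (parts + (0, 0, 0))[:3]


def is_supported(version: str) -> bool:
    return MIN_VERSION <= parse_release(version) <= MAX_VERSION


if __name__ == "__main__":
    try:
        installed = metadata.version("torch_geometric")
    except metadata.PackageNotFoundError:
        print("torch_geometric is not installed in this environment")
    else:
        status = "OK" if is_supported(installed) else "outside the supported range"
        print(f"torch_geometric {installed}: {status}")
```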

Data download

  • Spurious-Motif: this dataset can be generated via spmotif_gen/spmotif.ipynb.
  • Graph-SST2: this dataset can be downloaded here.
  • MNIST-75sp: this dataset can be downloaded here. Download mnist_75sp_train.pkl, mnist_75sp_test.pkl, and mnist_75sp_color_noise.pt to the directory data/MNISTSP/raw/.
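Before training on MNIST-75sp, it can help to confirm that the three raw files listed above are where the instructions say they should be. The helper below is a hypothetical convenience, not part of the repository; it only uses the file names and the data/MNISTSP/raw/ directory given above, and assumes it is run from the repository root.

```python
from pathlib import Path

# File names and directory taken from the download instructions above.
MNISTSP_RAW_FILES = (
    "mnist_75sp_train.pkl",
    "mnist_75sp_test.pkl",
    "mnist_75sp_color_noise.pt",
)


def missing_mnistsp_files(repo_root="."):
    """Return the expected MNIST-75sp raw files not yet present in data/MNISTSP/raw/."""
    raw_dir = Path(repo_root) / "data" / "MNISTSP" / "raw"
    return [name for name in MNISTSP_RAW_FILES if not (raw_dir / name).is_file()]


if __name__ == "__main__":
    missing = missing_mnistsp_files()
    if missing:
        print("Missing from data/MNISTSP/raw/:", ", ".join(missing))
    else:
        print("All MNIST-75sp raw files are in place.")
```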

Run DIR

The hyperparameters used to train the intrinsically interpretable models are set as defaults in the argparse.ArgumentParser of each training file; feel free to change them if needed. Each dataset has its own training file.

Simply run python -m train.{dataset}_dir to reproduce the results in the paper.
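If you want to launch runs from a script rather than the shell, the command above can be assembled as sketched below. The function names are hypothetical, and the dataset name `spmotif` in the example is an assumption; check the train/ package for the actual `{dataset}_dir` module names and for the argparse flags each one accepts before passing overrides.

```python
import subprocess
import sys


def dir_train_command(dataset, *extra_args):
    """Build `python -m train.{dataset}_dir [extra args]` for one dataset.

    `extra_args` are forwarded unchanged to the training module's argparse
    parser, where they override its default hyperparameters.
    """
    return [sys.executable, "-m", f"train.{dataset}_dir", *extra_args]


def run_dir(dataset, *extra_args):
    # Must run from the repository root so that the `train` package is importable.
    return subprocess.run(dir_train_command(dataset, *extra_args), check=True)


if __name__ == "__main__":
    # "spmotif" is a hypothetical dataset name used only for illustration.
    print(" ".join(dir_train_command("spmotif")))
```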

Common Questions:

How does the rationale generator update its parameters? See issue #7.

Reference

@inproceedings{
    wu2022dir,
    title={Discovering Invariant Rationales for Graph Neural Networks},
    author={Ying-Xin Wu and Xiang Wang and An Zhang and Xiangnan He and Tat-seng Chua},
    booktitle={ICLR},
    year={2022},
}