Towards self-explainable graph neural network #15

Open
Tracked by #13
2nazero opened this issue Oct 29, 2024 · 9 comments
Assignees
Labels
code (Code Analysis), paper (This issue is about a paper (to read))

Comments

@2nazero
Collaborator

2nazero commented Oct 29, 2024

Towards self-explainable graph neural network

@inproceedings{dai2021towards,
  title={Towards self-explainable graph neural network},
  author={Dai, Enyan and Wang, Suhang},
  booktitle={Proceedings of the 30th ACM International Conference on Information \& Knowledge Management},
  pages={302--311},
  year={2021}
}
@2nazero
Collaborator Author

2nazero commented Oct 29, 2024

@2nazero 2nazero self-assigned this Oct 31, 2024
@2nazero
Collaborator Author

2nazero commented Oct 31, 2024

Overall Summary

The paper addresses the need for self-explainable Graph Neural Networks (GNNs) that not only predict labels for nodes in a graph but also provide clear explanations for why each node is classified the way it is.

Traditional post-hoc methods attempt to explain GNN predictions after the model has made them, which may lead to explanations that don’t fully or accurately reflect the true reasoning behind predictions.

Contributions: SE-GNN (Self-Explainable GNN)

  • L-Nearest Labeled Nodes: SE-GNN identifies the closest L labeled nodes for each target node, allowing the model to base predictions on the similarity of features and structures in the local graph.
  • Structural Similarity Matching: SE-GNN performs edge matching to calculate the structural similarity between nodes, providing more robust explanations based on graph topology.
  • Contrastive Learning: By using contrastive learning, SE-GNN distinguishes similar nodes from dissimilar ones, enhancing the model's ability to interpret node relationships.
[Screenshot: 2024-10-31 21:03]

@2nazero
Collaborator Author

2nazero commented Oct 31, 2024

Similarity Modeling

Node Similarity

Node similarity captures the feature similarity of nodes. SE-GNN first applies an MLP to the node features $X$ to obtain initial embeddings $H_W$, followed by one GCN layer that aggregates information from neighbors.

$$H_W = \text{MLP}(X), \quad H = \sigma(\tilde{A} H_W W) + H_W$$

$\tilde{A} = D^{-1/2} (A + I) D^{-1/2}$ is the normalized adjacency matrix and $H$ is the final node embedding. Node similarity between nodes $v_i$ and $v_j$ is then calculated using a similarity function, such as cosine similarity:

$$S_N(v_i, v_j) = \text{cosine}(h_i, h_j)$$
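The embedding and node-similarity step can be sketched in a few lines of numpy. This is a minimal illustration with random weights; the graph, the shapes, and the single-hidden-layer MLP are assumptions of the sketch, not the paper's actual configuration:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy graph: 4 nodes, 8-dim features (sizes chosen arbitrarily for illustration).
n, d, h_dim = 4, 8, 16
X = rng.normal(size=(n, d))
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 1],
              [1, 0, 0, 1],
              [0, 1, 1, 0]], dtype=float)

# Normalized adjacency with self-loops: A_tilde = D^{-1/2} (A + I) D^{-1/2}
A_self = A + np.eye(n)
d_inv_sqrt = 1.0 / np.sqrt(A_self.sum(axis=1))
A_tilde = A_self * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

# H_W = MLP(X): a single ReLU layer stands in for the MLP here.
W1 = rng.normal(size=(d, h_dim))
H_W = np.maximum(X @ W1, 0.0)

# H = sigma(A_tilde @ H_W @ W) + H_W  (GCN layer with a residual connection)
W = rng.normal(size=(h_dim, h_dim))
H = np.maximum(A_tilde @ H_W @ W, 0.0) + H_W

def node_similarity(h_i, h_j, eps=1e-12):
    """S_N(v_i, v_j) = cosine(h_i, h_j)."""
    return float(h_i @ h_j / (np.linalg.norm(h_i) * np.linalg.norm(h_j) + eps))
```

`node_similarity(H[0], H[1])` then gives $S_N$ between nodes 0 and 1.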

Local Structure Similarity

Structural similarity is computed by matching the edges within each node’s 2-hop subgraph.

$$S_E(v_i, v_j) = \frac{1}{K} \sum_{k=1}^{K} \text{cosine}(e_{ik}, e_{jk})$$

$e_{ik}$ and $e_{jk}$ are the $k$-th matched pair of edge embeddings from the local subgraphs of nodes $v_i$ and $v_j$, and $K$ is the number of matched edge pairs.

Overall Similarity

$$S(v_i, v_j) = \alpha S_N(v_i, v_j) + (1 - \alpha) S_E(v_i, v_j)$$
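The structural and combined scores can be sketched the same way. Producing the $K$ matched edge pairs is the paper's edge-matching procedure and is only assumed to have happened here; the inputs below are random stand-ins:

```python
import numpy as np

def cosine(u, v, eps=1e-12):
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v) + eps))

def structure_similarity(edges_i, edges_j):
    """S_E(v_i, v_j): mean cosine over K matched edge-embedding pairs.
    edges_i and edges_j are (K, dim) arrays of already-matched edge embeddings
    from the two nodes' 2-hop subgraphs (the matching itself is assumed done)."""
    return float(np.mean([cosine(a, b) for a, b in zip(edges_i, edges_j)]))

def overall_similarity(s_n, s_e, alpha=0.5):
    """S(v_i, v_j) = alpha * S_N + (1 - alpha) * S_E."""
    return alpha * s_n + (1 - alpha) * s_e

# Random stand-ins for K = 3 matched edge pairs with 8-dim embeddings.
rng = np.random.default_rng(1)
E_i, E_j = rng.normal(size=(3, 8)), rng.normal(size=(3, 8))
s_e = structure_similarity(E_i, E_j)
s = overall_similarity(0.9, s_e, alpha=0.5)
```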

Self-Explainable Classification

Prediction with L-Nearest Labeled Nodes

For each target node $v$, the predicted label $\hat{y}_v$ is a weighted average of the labels of its L-nearest labeled neighbors,

$$\hat{y}_v = \sum_{k=1}^{L} w_k\, y_k$$

where $w_k$ is the weight assigned based on the similarity $S(v, v_k)$ and $y_k$ is the label of neighbor $v_k$.
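The prediction step can be sketched as below. Using a softmax over the similarities as the weights $w_k$ is an assumption of this sketch; the paper derives the exact weights from $S(v, v_k)$:

```python
import numpy as np

def predict_label(sims, labels, L=3):
    """Weighted average of the labels of the L most similar labeled nodes.
    sims[k] = S(v, v_k) for each labeled node v_k; labels are one-hot rows.
    The softmax weighting over similarities is an assumption of this sketch."""
    sims = np.asarray(sims)
    top = np.argsort(sims)[-L:]           # indices of the L nearest labeled nodes
    w = np.exp(sims[top])
    w /= w.sum()                          # normalized weights w_k
    return w @ np.asarray(labels)[top]    # y_hat_v = sum_k w_k * y_k

# Four labeled nodes, two classes; the two most similar nodes carry class 0.
sims = [0.9, 0.2, 0.7, 0.1]
labels = np.eye(2)[[0, 1, 0, 1]]
y_hat = predict_label(sims, labels, L=3)
```

Because the most similar labeled nodes both carry class 0, the weighted average favors class 0, and those L nodes themselves serve as the explanation.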

Self-Supervision

SE-GNN applies contrastive learning to reinforce node and edge similarity for better interpretability, optimizing the separation of similar and dissimilar nodes or edges.

$$L_N = -\log \frac{\exp(\text{cosine}(h_i, h_i^+) / \tau)}{\sum_{j=1}^{N} \exp(\text{cosine}(h_i, h_j) / \tau)}$$

$h_i$ is the representation of node $i$, $h_i^+$ is a positive sample, and $h_j$ are negative samples.
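A direct numpy transcription of $L_N$ for a single anchor node. Whether the positive pair also appears in the denominator sum is ambiguous in the formula as written; this sketch includes it, which is one common reading. All vectors here are random stand-ins:

```python
import numpy as np

def contrastive_loss(h_i, h_pos, negatives, tau=0.5, eps=1e-12):
    """L_N = -log( exp(cos(h_i, h_i+)/tau) / sum_j exp(cos(h_i, h_j)/tau) ).
    The positive pair is included in the denominator sum (a reading choice);
    negatives is an iterable of the h_j vectors."""
    def cos(u, v):
        return u @ v / (np.linalg.norm(u) * np.linalg.norm(v) + eps)
    logits = np.array([cos(h_i, h) for h in [h_pos, *negatives]]) / tau
    m = logits.max()                                   # log-sum-exp for stability
    log_denom = m + np.log(np.exp(logits - m).sum())
    return float(log_denom - logits[0])                # -log softmax of the positive

rng = np.random.default_rng(2)
h = rng.normal(size=8)
# A slightly perturbed copy of h as the positive, random vectors as negatives.
loss = contrastive_loss(h, h + 0.05 * rng.normal(size=8), rng.normal(size=(4, 8)))
```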

Overall Objective Function

The final loss function for SE-GNN combines the classification loss $L_C$, the node contrastive loss $L_N$, and the edge contrastive loss $L_E$:

$$L = L_C + \lambda L_N + \eta L_E$$

$\lambda$ and $\eta$ are hyperparameters that control the contributions of the self-supervised loss terms.
[Screenshot: 2024-10-31 21:42]

@2nazero
Collaborator Author

2nazero commented Oct 31, 2024

Experiments

Datasets

  • Real-World Datasets: Cora, Citeseer, Pubmed
  • Synthetic Datasets: Syn-Cora, BA-Shapes

Classification and Explanation Quality

  • Classification Accuracy
[Screenshots: 2024-11-01 17:32 and 17:39]
  • Explanation Quality (Precision@k)
[Screenshot: 2024-11-01 17:34]
  • Syn-Cora
[Screenshot: 2024-11-01 17:38]
  • BA-Shapes
[Screenshot: 2024-11-01 17:38]

Robustness to Noise

[Screenshot: 2024-11-01 17:41]

Ablation Study

  • Local Structure Similarity: The performance of SE-GNN decreases significantly when only 1-hop neighbors are used, highlighting the importance of rich structural information.
  • Self-Supervision Effects: Removing self-supervision on node or structural similarity reduces accuracy and explanation quality, indicating that these components are essential for SE-GNN’s interpretability.

Parameter Sensitivity Analysis

SE-GNN achieves the best performance when hyperparameters are within a moderate range, as too little or too much emphasis on self-supervision can decrease classification and explanation quality.

@2nazero 2nazero added the paper (This issue is about a paper (to read)) and code (Code Analysis) labels Nov 4, 2024
@2nazero
Collaborator Author

2nazero commented Nov 7, 2024

Code Comparison 1

[Screenshot: 2024-11-07 15:17]
root@e69f0e0aedb4:~/SEGNN# bash train_real.sh
Namespace(K=20, T=1.0, alpha=0.5, attr_mask=0.15, batch_size=32, beta1=0.01, beta2=0.01, cuda=False, dataset='Pubmed', debug=True, epochs=200, hidden=64, hop=2, init=False, lr=0.01, model='DeGNN', nlayer=3, no_cuda=False, seed=12, weight_decay=0.0005)
=== training model ===
Epoch 1, cls loss: 1.1331895589828491 cont_loss: 0.0806 acc_val: 0.6980 best_acc_val: 0.6980
Epoch 2, cls loss: 0.7374804615974426 cont_loss: 0.0806 acc_val: 0.7180 best_acc_val: 0.7180
Epoch 3, cls loss: 0.6413862705230713 cont_loss: 0.0802 acc_val: 0.7400 best_acc_val: 0.7400
Epoch 4, cls loss: 0.58443284034729 cont_loss: 0.0796 acc_val: 0.7380 best_acc_val: 0.7400
Epoch 5, cls loss: 0.5255063772201538 cont_loss: 0.0786 acc_val: 0.7220 best_acc_val: 0.7400
Epoch 6, cls loss: 0.45011571049690247 cont_loss: 0.0769 acc_val: 0.7300 best_acc_val: 0.7400
Epoch 7, cls loss: 0.3780854642391205 cont_loss: 0.0744 acc_val: 0.7160 best_acc_val: 0.7400
Epoch 8, cls loss: 0.2936973571777344 cont_loss: 0.0707 acc_val: 0.7040 best_acc_val: 0.7400
Epoch 9, cls loss: 0.20843859016895294 cont_loss: 0.0676 acc_val: 0.6920 best_acc_val: 0.7400
Epoch 10, cls loss: 0.13450011610984802 cont_loss: 0.0637 acc_val: 0.6920 best_acc_val: 0.7400
Epoch 11, cls loss: 0.0759669616818428 cont_loss: 0.0609 acc_val: 0.6940 best_acc_val: 0.7400
Epoch 12, cls loss: 0.03820379450917244 cont_loss: 0.0589 acc_val: 0.6960 best_acc_val: 0.7400
Epoch 13, cls loss: 0.017903033643960953 cont_loss: 0.0542 acc_val: 0.6940 best_acc_val: 0.7400
...

@2nazero
Collaborator Author

2nazero commented Nov 7, 2024

Code Comparison 2

[Screenshot: 2024-11-07 15:26]
root@e69f0e0aedb4:~/SEGNN# bash train_syn.sh
/opt/conda/lib/python3.8/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at  /pytorch/c10/cuda/CUDAFunctions.cpp:100.)
  return torch._C._cuda_getDeviceCount() > 0
Namespace(K=5, T=1, alpha=0.5, attr_mask=0.2, beta1=1, beta2=1, cuda=False, debug=False, dropout=0.5, epochs=50, hidden=64, hop=2, init=False, lr=0.01, model='DeGNN', nlayer=2, no_cuda=False, sample_size=256, seed=11, weight_decay=0.0005)
Accuracy: 1.0000
Node Pair accuracy: 0.9913
Edge Pair accuracy: 0.8334
/opt/conda/lib/python3.8/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at  /pytorch/c10/cuda/CUDAFunctions.cpp:100.)
  return torch._C._cuda_getDeviceCount() > 0
Namespace(K=5, T=1, alpha=0.5, attr_mask=0.2, beta1=1, beta2=1, cuda=False, debug=False, dropout=0.5, epochs=50, hidden=64, hop=2, init=False, lr=0.01, model='DeGNN', nlayer=2, no_cuda=False, sample_size=256, seed=12, weight_decay=0.0005)
...

@2nazero
Collaborator Author

2nazero commented Nov 7, 2024

Code Comparison 3

[Screenshot: 2024-11-07 15:28]
root@e69f0e0aedb4:~/SEGNN# bash train_BAshape.sh
/opt/conda/lib/python3.8/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at  /pytorch/c10/cuda/CUDAFunctions.cpp:100.)
  return torch._C._cuda_getDeviceCount() > 0
Namespace(K=10, T=1, alpha=0.5, attr_mask=0.5, beta1=1, beta2=1, cuda=False, debug=True, dropout=0.5, epochs=100, hidden=128, hop=1, init=False, lr=0.01, model='DeGNN', nlayer=2, no_cuda=False, sample_size=128, seed=11, weight_decay=0.0005)
=== training model ===
Epoch 1, cls loss: 0.06207555532455444 cont_loss: 15.9162 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 2, cls loss: 0.07821442931890488 cont_loss: 14.3210 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 3, cls loss: 0.10167255252599716 cont_loss: 13.8908 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 4, cls loss: 0.12009401619434357 cont_loss: 13.4723 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 5, cls loss: 0.13597513735294342 cont_loss: 11.6072 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 6, cls loss: 0.13402707874774933 cont_loss: 11.4642 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 7, cls loss: 0.14052826166152954 cont_loss: 10.2032 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 8, cls loss: 0.1241062581539154 cont_loss: 10.7389 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 9, cls loss: 0.14182981848716736 cont_loss: 11.9345 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 10, cls loss: 0.12085629999637604 cont_loss: 11.7807 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 11, cls loss: 0.15116392076015472 cont_loss: 11.1117 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 12, cls loss: 0.1615379899740219 cont_loss: 10.2897 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 13, cls loss: 0.16297703981399536 cont_loss: 10.4220 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 14, cls loss: 0.14840374886989594 cont_loss: 10.2818 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 15, cls loss: 0.12788088619709015 cont_loss: 11.4748 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 16, cls loss: 0.13422787189483643 cont_loss: 9.9767 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 17, cls loss: 0.13471050560474396 cont_loss: 10.3638 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 18, cls loss: 0.15220044553279877 cont_loss: 10.6026 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 19, cls loss: 0.14824332296848297 cont_loss: 9.7765 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 20, cls loss: 0.15234188735485077 cont_loss: 10.2859 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 21, cls loss: 0.14792807400226593 cont_loss: 10.1910 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 22, cls loss: 0.14518018066883087 cont_loss: 10.4645 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 23, cls loss: 0.14033383131027222 cont_loss: 8.9243 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 24, cls loss: 0.1567702740430832 cont_loss: 8.9019 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 25, cls loss: 0.15553423762321472 cont_loss: 9.9690 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 26, cls loss: 0.1321064978837967 cont_loss: 8.9667 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 27, cls loss: 0.1406252533197403 cont_loss: 9.8881 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 28, cls loss: 0.15064963698387146 cont_loss: 8.9274 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 29, cls loss: 0.14569571614265442 cont_loss: 9.9802 acc_val: 1.0000 best_acc_val: 1.0000
Epoch 30, cls loss: 0.13596659898757935 cont_loss: 10.0723 acc_val: 1.0000 best_acc_val: 1.0000
...

@2nazero
Collaborator Author

2nazero commented Nov 7, 2024

Code Comparison 4

[Screenshot: 2024-11-07 15:33]
root@e69f0e0aedb4:~/SEGNN# bash test_real.sh
Testing
Accuracy: 0.8000, mAP: 0.7930
cls test results: accuracy= 0.7860
Accuracy of GCN-K: 0.7820, Accuracy of CLS: 0.7860, MAP: 0.7952
cls test results: accuracy= 0.7580
Accuracy of GIN-K: 0.7560, Accuracy of CLS: 0.7580, MAP: 0.7749
cls test results: accuracy= 0.7370
Accuracy of MLP-K: 0.7390, Accuracy of CLS: 0.7370, MAP: 0.7458
Loading citeseer dataset...
Selecting 1 largest connected components
Testing
Accuracy: 0.7494, mAP: 0.7124
cls test results: accuracy= 0.7340
Accuracy of GCN-K: 0.7210, Accuracy of CLS: 0.7340, MAP: 0.7324
cls test results: accuracy= 0.6961
Accuracy of GIN-K: 0.6659, Accuracy of CLS: 0.6961, MAP: 0.6857
cls test results: accuracy= 0.6286
Accuracy of MLP-K: 0.6315, Accuracy of CLS: 0.6286, MAP: 0.6525
Testing
...

@2nazero
Collaborator Author

2nazero commented Nov 7, 2024

Code Comparison 5 - Results w/o Code

These results cannot be reproduced from the released code; some of them, such as the human evaluation, are not produced by code at all.

[Screenshots: 2024-11-07 15:35, 15:36, 15:36]
