Extending GNNExplainer for graph classification. #2597
Conversation
This is super awesome. Thanks a lot! I already fixed the failing ...
Hello, I got an error: "type object 'GNNExplainer' has no attribute 'explain_graph'". Am I using the wrong version?
This feature just got merged. You need to install PyG from master to access it, e.g. via pip install git+https://github.com/pyg-team/pytorch_geometric.git
Thanks a lot!
My code: x1, edge_index1 = testing_dataset[1].x, testing_dataset[1].edge_index
@Luoyunsong, could you please share ...? This should help me debug this better.
I was doing a brain network classification.
@Luoyunsong, thanks for bringing this up. This is a bug; I'll create a PR to fix this ASAP. In the meantime, could you try setting ...
Thanks for your help; it's a perfect piece of work.
@panisson, are you referring to this line in the loss function?
Yes, this does indeed seem to be a crucial bug. We might need to hot-fix the new release.
Hi all! Is this feature still available in version ...?
Yes, graph classification within ...
Hi, I am a beginner at GNNs. Just a quick question: does the node_feat_mask of explain_graph directly give the node feature importance for the entire graph classification? Or do we need to follow other steps to find the feature importance after obtaining node_feat_mask?
You should be able to use it directly; it will contain the feature importance of the entire graph.
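To illustrate, here is a minimal sketch (plain Python; the mask values and feature names are made up for this example) of turning such a feature mask into a ranked importance list:

```python
# Hypothetical output of the explainer: one importance score per input
# feature, each in [0, 1]. Values and feature names are invented.
node_feat_mask = [0.05, 0.92, 0.33, 0.71]
feature_names = ["degree", "charge", "mass", "valence"]

# Rank features by their mask score, most important first.
ranking = sorted(zip(feature_names, node_feat_mask),
                 key=lambda pair: pair[1], reverse=True)
print(ranking[0])  # the single most important feature
```

The same idea applies when the mask comes back as a tensor; converting it to a list (or using argsort) gives the same ranking.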
Hi, is there any tutorial or example code available for explain_graph? Thank you.
No example yet, but hopefully the test case gets you going: https://github.com/pyg-team/pytorch_geometric/blob/master/test/explain/algorithm/test_gnn_explainer.py#L80
@rusty1s thanks a lot. I'm new to XAI methods. Can you please tell me what the input would be here? The training set on which the model (e.g. GCNConv) is trained (to see how the model can differentiate between the two classes of a binary classification problem), or the held-out test set?
It depends on which explanations you want to receive (explaining training data or explaining held-out data). In general, I think it is more common to run explanations on held-out data. You can then find common substructures by looking at the edge attribution.
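As a sketch of that last step (pure Python; the edge list and attribution scores are invented for illustration), one can threshold the edge mask to expose the high-attribution substructure:

```python
# Invented toy graph: edges as (source, target) pairs, with one
# attribution score per edge as produced by an explainer.
edges = [(0, 1), (1, 2), (2, 3), (3, 0), (1, 3)]
edge_mask = [0.91, 0.12, 0.88, 0.07, 0.95]

# Keep only the edges the model relied on heavily.
threshold = 0.5
substructure = [e for e, score in zip(edges, edge_mask) if score > threshold]
# substructure now holds the high-attribution edges of this graph
```

Repeating this over many explained graphs of the same class and counting which edges (or motifs) recur is one simple way to surface common substructures.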
@rusty1s thank you, I'm still working on it. I got the error: 'GNNExplainer' object has no attribute 'explain_graph', even though I installed PyG from master as linked above. Am I missing something? I am using torch 2.2.0 and torch_geometric 2.5.0.
Are you using ...?
@rusty1s yes. This is the code I am running:

x1, edge_index1 = test_dataset[1].x, test_dataset[1].edge_index
explainer = GNNExplainer(model=model, epochs=400, lr=0.0001)
graph_feat_mask, edge_mask = explainer.explain_graph(x1, edge_index1)

Following is the error:

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1688, in Module.__getattr__(self, name)
AttributeError: 'GNNExplainer' object has no attribute 'explain_graph'

The model I am using is GraphConv from PyG. I couldn't fix this error.
That would be an incorrect use of ...
Sorry, I couldn't understand it from the example. For the other code in https://github.com/pyg-team/pytorch_geometric/blob/master/examples/explain/gnn_explainer.py (which I assume is for node classification), I didn't understand which parameters differ for graph-level classification.
Here is how you would use it:

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='log_probs',
    ),
)
Thank you so much for the clarification. I just saw that explain_graph was in the deprecated version. However, the batch_index argument of my model's forward function is causing the problem. I think I didn't understand how to load a test graph into the explainer. Below is the error:

Cell In[108], line 1
File ~/.local/lib/python3.10/site-packages/torch_geometric/explain/explainer.py:196, in Explainer.__call__(self, x, edge_index, target, index, **kwargs)
File ~/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
File ~/.local/lib/python3.10/site-packages/torch_geometric/explain/explainer.py:115, in Explainer.get_prediction(self, *args, **kwargs)
TypeError: GCN.forward() missing 1 required positional argument: 'batch_index'
If your model's forward requires batch_index, you can pass it through the explainer call: explanation = explainer(x=x0, edge_index=edge_index0, batch_index=..., target=None)
@rusty1s thank you so much, it finally worked. However, it generates a slightly different graph every time I re-run the explanation.visualize_graph function. Could you let me know why that happens? Also, instead of a directed graph, is there a way to get an undirected graph, so that later I can map the nodes to atoms to visualize molecules with the functional groups or substructures important to the prediction? Thanks a ton, once again <3.
Hi @rusty1s, I have one more doubt, please. For my binary classification problem, I ran GNNExplainer as guided by you. I noticed that for graphs predicted as class '0' with a low probability score (<0.1, obtained by the GCN model), most edges have low edge_mask scores; there are some edges with high edge_mask scores, but those are rare in this class. I hope you get the point of what I'm asking. Sorry for bothering you again, but this is part of my thesis. I hope you're doing great. Thanks a lot, once again.
You would interpret it as case 2.
@rusty1s thank you. I was in doubt because a lot of my class-'0' graphs have low edge_masks. Case 2 makes sense, as the explainer also works on multiclass classification.
Dear @rusty1s, I have one more doubt, please. I did not provide any threshold for edge_masks or node_masks in the explainer, but the explainer can still provide the subgraph, and I get evaluation metric results as well, e.g. unfaithfulness, fidelity score, etc. How is the subgraph generated in my case?
In this case, node and edge importance is just given as a continuous value between 0 and 1.
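A small sketch of how such continuous scores can still yield a discrete subgraph without an explicit threshold (the scores are made up; keeping the top-k edges is just one common convention, not the library's fixed behavior):

```python
# Continuous edge importances in [0, 1], as described above (invented values).
edge_mask = [0.91, 0.12, 0.88, 0.07, 0.95]

# Select the indices of the k highest-scoring edges to form a hard subgraph.
k = 2
top_edges = sorted(range(len(edge_mask)),
                   key=lambda i: edge_mask[i], reverse=True)[:k]
```

Metrics such as fidelity can then be computed either on the soft masks directly or on a hard subgraph derived this way.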
#1374
Updated GNNExplainer to support graph classification. The key updates:
- Added explain_graph, which should be called for graph classification. It shares a lot of code with explain_node. Another solution would be a single explain function that handles both graph and node explanation; however, that would mean explain_node would have to be retired/deprecated. Let me know if you feel that's a better solution.
- In visualize_subgraph, setting node_idx to -1 now implies a graph classification task.

I'll add an example for this in a day or two, but feel free to review the code.