-
Notifications
You must be signed in to change notification settings - Fork 3.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Upgrade visualize_graph for explain module #8743
base: master
Are you sure you want to change the base?
Conversation
for more information, see https://pre-commit.ci
(I will add this to another PR, but not here)
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great, thanks for improving the visualization.
Could you please post some images of how things looked before and after your change.
I've left some comments.
@@ -233,7 +233,7 @@ def visualize_feature_importance( | |||
return _visualize_score(score, feat_labels, path, top_k) | |||
|
|||
def visualize_graph(self, path: Optional[str] = None, | |||
backend: Optional[str] = None): | |||
backend: Optional[str] = None, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
backend: Optional[str] = None, **kwargs): | |
backend: Optional[str] = None, | |
nodel_label: Optional[Tensor] = None, | |
colors_dict: Optional[Dict[int, str] = None, | |
target_idx: Optional[int]=None): |
Lets add the new arguments as optional arguemnts and add documentation for them. That way the end user is aware of the options available to them.
if target_node != None: | ||
for i, node_id in enumerate(list(g.nodes)): | ||
if node_id == target_node: | ||
print("kylin") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
print("kylin") |
@@ -127,10 +134,26 @@ def _visualize_graph_via_networkx( | |||
), | |||
) | |||
|
|||
node_color = ['white'] * len(g.nodes) | |||
if node_label != None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So if node_label
is None
won't all nodes be white?
For current
visualize_graph
, the node colors are default 'white', which is not as intuitive as in GNNExplainer: target node is red, different node class with different node color.We add three optional params:
We add lines at two examples
examples/explain/gnn_explainer.py
andexamples/explain/gnn_explainer_ba_shapes.py
I am not sure about how to add optional params. Currently we add them at **kwargs.