Skip to content

Commit

Permalink
Asteroid plot edges (#170)
Browse files Browse the repository at this point in the history
* fixes show_edges for the astroid graph

* fixes show_edges for the asteroid plot

* removed an unnecessary else

* minor linting

* adds an interactive legend for the edges

* Update changelog

Co-authored-by: Arian Jamasb <arjamasb@gmail.com>
  • Loading branch information
avivko and a-r-j authored May 13, 2022
1 parent 91c5f59 commit 1c87afd
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 30 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
### 1.5.0 - UNRELEASED

* [Feature] - [#170](https://github.com/a-r-j/graphein/pull/170) Adds support for viewing edges in `graphein.protein.visualisation.asteroid_plot`. Contribution by @avivko.
* [Feature] - #163 Adds support for conformer generation for SMILE inputs to molecule graph construction.
* [Feature] - #163 Adds support for molecule graph generation from an RDKit.Chem.Mol input.
* [Feature] - #163 Adds support for multiprocess molecule graph construction.
Expand Down
81 changes: 51 additions & 30 deletions graphein/protein/visualisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import networkx as nx
import numpy as np
import plotly.express as px
import plotly.colors as co
import plotly.graph_objects as go
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
Expand Down Expand Up @@ -118,7 +119,9 @@ def colour_edges(
G: nx.Graph,
colour_map: matplotlib.colors.ListedColormap,
colour_by: str = "kind",
) -> List[Tuple[float, float, float, float]]:
set_alpha: float = 1.0,
return_as_rgba: bool = False,
) -> List[Tuple[float, float, float, float]] or List[str]:
"""
Computes edge colours based on the kind of bond/interaction.
Expand All @@ -128,8 +131,12 @@ def colour_edges(
:type colour_map: matplotlib.colors.ListedColormap
:param colour_by: Edge attribute to colour by. Currently only ``"kind"`` is supported.
:type colour_by: str
:param set_alpha: Sets a given alpha value between 0.0 and 1.0 for all the edge colours.
:type set_alpha: float
:param return_as_rgba: Returns a list of rgba strings instead of tuples.
:type return_as_rgba: bool
:return: List of edge colours.
:rtype: List[Tuple[float, float, float, float]]
:rtype: List[Tuple[float, float, float, float]] or List[str]
"""
if colour_by == "kind":
edge_types = set(
Expand All @@ -147,6 +154,13 @@ def colour_edges(
raise NotImplementedError(
"Other edge colouring methods not implemented."
)

assert (0.0 <= set_alpha <= 1.0), f"Alpha value {set_alpha} must be between 0.0 and 1.0"
colors = [c[:3] + (set_alpha,) for c in colors]
if return_as_rgba:
return [
f"rgba{tuple(list(co.convert_to_RGB_255(c[:3])) + [c[3]])}" for c in colors
]
return colors


Expand Down Expand Up @@ -654,12 +668,14 @@ def asteroid_plot(
colour_nodes_by: str = "shell", # residue_name
colour_edges_by: str = "kind",
edge_colour_map: plt.cm.Colormap = plt.cm.plasma,
edge_alpha: float = 1.0,
show_labels: bool = True,
title: Optional[str] = None,
width: int = 600,
height: int = 500,
use_plotly: bool = True,
show_edges: bool = False,
show_legend: bool = True,
node_size_multiplier: float = 10,
) -> Union[plotly.graph_objects.Figure, matplotlib.figure.Figure]:
"""Plots a k-hop subgraph around a node as concentric shells.
Expand All @@ -678,6 +694,8 @@ def asteroid_plot(
:type colour_edges_by: str
:param edge_colour_map: Colour map for edges. Defaults to ``plt.cm.plasma``.
:type edge_colour_map: plt.cm.Colormap
:param edge_alpha: Sets a given alpha value between 0.0 and 1.0 for all the edge colours.
:type edge_alpha: float
:param title: Title of the plot. Defaults to ``None``.
:type title: str
:param width: Width of the plot. Defaults to ``600``.
Expand All @@ -686,8 +704,10 @@ def asteroid_plot(
:type height: int
:param use_plotly: Use plotly to render the graph. Defaults to ``True``.
:type use_plotly: bool
:param show_edges: Whether or not to show edges in the plot. Defaults to ``False``.
:param show_edges: Whether to show edges in the plot. Defaults to ``False``.
:type show_edges: bool
:param show_legend: Whether to show the legend of the edges. Fefaults to `True``.
:type show_legend: bool
:param node_size_multiplier: Multiplier for the size of the nodes. Defaults to ``10``.
:type node_size_multiplier: float.
:returns: Plotly figure or matplotlib figure.
Expand Down Expand Up @@ -715,34 +735,28 @@ def asteroid_plot(

if show_edges:
edge_colors = colour_edges(
subgraph, colour_map=edge_colour_map, colour_by=colour_edges_by
subgraph, colour_map=edge_colour_map, colour_by=colour_edges_by,
set_alpha=edge_alpha, return_as_rgba=True
)

edge_x: List[str] = []
edge_y: List[str] = []
edge_type: List[str] = []
for u, v in subgraph.edges():
show_legend_bools = [(True if x not in edge_colors[:i] else False)
for i, x in enumerate(edge_colors)]
edge_trace = []
for i, (u, v) in enumerate(subgraph.edges()):
x0, y0 = subgraph.nodes[u]["pos"]
x1, y1 = subgraph.nodes[v]["pos"]
edge_x.append(x0)
edge_x.append(x1)
edge_x.append(None)
edge_y.append(y0)
edge_y.append(y1)
edge_y.append(None)
edge_trace = go.Scatter(
x=edge_x,
y=edge_y,
line=dict(width=1, color=edge_colors),
hoverinfo="text",
mode="lines",
text=[
" / ".join(list(edge_type))
for edge_type in nx.get_edge_attributes(
subgraph, "kind"
).values()
],
)
bond_kind = " / ".join(list(subgraph[u][v]["kind"]))
tr = go.Scatter(
x=(x0, x1),
y=(y0, y1),
mode="lines",
line=dict(width=1, color=edge_colors[i]),
hoverinfo="text",
text=[bond_kind],
name=bond_kind,
legendgroup=bond_kind,
showlegend=show_legend_bools[i],
)
edge_trace.append(tr)

node_x: List[str] = []
node_y: List[str] = []
Expand Down Expand Up @@ -773,6 +787,7 @@ def asteroid_plot(
mode="markers+text" if show_labels else "markers",
hoverinfo="text",
textposition="bottom center",
showlegend=False,
marker=dict(
colorscale="YlGnBu",
reversescale=True,
Expand All @@ -789,15 +804,21 @@ def asteroid_plot(
),
)

data = [edge_trace, node_trace] if show_edges else [node_trace]
data = edge_trace + [node_trace] if show_edges else [node_trace]
fig = go.Figure(
data=data,
layout=go.Layout(
title=title if title else f'Asteroid Plot - {g.graph["name"]}',
width=width,
height=height,
titlefont_size=16,
showlegend=False,
legend=dict(
yanchor="top",
y=1,
xanchor="left",
x=1.10
),
showlegend=True if show_legend else False,
hovermode="closest",
margin=dict(b=20, l=5, r=5, t=40),
xaxis=dict(
Expand Down

0 comments on commit 1c87afd

Please sign in to comment.