Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Heatmap updates #750

Merged
merged 29 commits into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0a8a7ee
test pytest ini
loftusa Apr 3, 2021
3ae1f60
add binary_heatmap, updates to tutorial notebook
loftusa Apr 4, 2021
4254a05
test plotting
loftusa Apr 4, 2021
88cc245
re-add pytest.ini
loftusa Apr 5, 2021
1cf81b1
Merge branch 'dev' into heatmap-updates
bdpedigo Apr 6, 2021
1d3c93d
Merge branch 'dev' into heatmap-updates
loftusa Apr 7, 2021
116cf17
Merge branch 'dev' into heatmap-updates
loftusa Apr 11, 2021
3489094
Update plot.py
loftusa Apr 12, 2021
7570404
add new tests
loftusa Apr 12, 2021
0c776ae
Merge branch 'dev' into heatmap-updates
bdpedigo May 12, 2021
faba722
Update setup.cfg
loftusa Jul 14, 2021
0e0e47f
Merge remote-tracking branch 'upstream/dev' into dev
loftusa Aug 2, 2021
0957812
Merge branch 'dev' of https://github.com/microsoft/graspologic into dev
loftusa Aug 24, 2021
0f5e123
merge
loftusa Aug 24, 2021
90ad691
remove binary_heatmap
loftusa Aug 24, 2021
50dcb7d
mpl version req back to normal
loftusa Aug 24, 2021
8bb555f
Merge branch 'dev' into heatmap-updates
bdpedigo Sep 15, 2021
74554a5
remove binary_heatmap import
loftusa Sep 15, 2021
7a82f10
Merge branch 'dev' into heatmap-updates
bdpedigo Sep 15, 2021
44f372e
Merge branch 'dev' into heatmap-updates
loftusa Oct 28, 2021
2fec091
black
loftusa Oct 29, 2021
da0232e
Merge branch 'dev' into heatmap-updates
bdpedigo Nov 2, 2021
a0250cf
Merge branch 'dev' into heatmap-updates
loftusa Dec 7, 2021
830b1ed
Merge branch 'dev' into heatmap-updates
bdpedigo Dec 13, 2021
92bb95d
Merge branch 'dev' into heatmap-updates
loftusa Jan 18, 2022
cb00d80
formatting fix
loftusa Jan 18, 2022
0aaadc9
Merge branch 'heatmap-updates' of https://github.com/loftusa/graspolo…
loftusa Jan 18, 2022
69b8745
Merge branch 'dev' into heatmap-updates
bdpedigo Jun 27, 2022
05077e6
Merge branch 'dev' into heatmap-updates
loftusa Jun 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 79 additions & 26 deletions docs/tutorials/plotting/heatmaps.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,29 @@
"outputs": [],
"source": [
"import graspologic\n",
"\n",
"import numpy as np\n",
"%matplotlib inline"
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plotting graphs using heatmap\n",
"\n",
"### Simulate graphs using weighted stochastic block models\n",
"The 2-block model is defined as below:\n",
"## Plotting Simple Graphs using heatmap"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A 2-block Stochastic Block Model is defined as below:\n",
"\n",
"\\begin{align*}\n",
"P = \\begin{bmatrix}0.8 & 0.2 \\\\\n",
"0.2 & 0.8 \n",
"\\end{bmatrix}\n",
"\\end{align*}\n",
"\n",
"We generate two weight SBMs where the weights are distributed from a Poisson(3) and Normal(5, 1)."
"In simple cases, the model is unweighted. Below, we plot an unweighted SBM."
]
},
{
Expand All @@ -46,25 +48,28 @@
"outputs": [],
"source": [
"from graspologic.simulations import sbm\n",
"from graspologic.plot import heatmap\n",
"\n",
"n_communities = [50, 50]\n",
"p = [[0.8, 0.2], \n",
" [0.2, 0.8]]\n",
"\n",
"wt = np.random.poisson\n",
"wtargs = dict(lam=3)\n",
"A_poisson= sbm(n_communities, p, wt=wt, wtargs=wtargs)\n",
"\n",
"wt = np.random.normal\n",
"wtargs = dict(loc=5, scale=1)\n",
"A_normal = sbm(n_communities, p, wt=wt, wtargs=wtargs)"
"A, labels = sbm(n_communities, p, return_labels=True)\n",
"heatmap(A, title=\"Basic Heatmap function\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plotting with Hierarchy Labels"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot the simulated weighted SBMs"
"If we have labels, we can use them to show communities on a Heatmap."
]
},
{
Expand All @@ -73,10 +78,54 @@
"metadata": {},
"outputs": [],
"source": [
"from graspologic.plot import heatmap\n",
"\n",
"title = 'Weighted Stochastic Block Model with Poisson(3)'\n",
"heatmap(A, inner_hier_labels=labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can plot outer hierarchy labels in addition to inner hierarchy labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"outer_labels = [\"Outer Labels\"] * 100\n",
"heatmap(A, inner_hier_labels=labels,\n",
" outer_hier_labels=outer_labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Weighted SBMs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also use heatmap when our graph is weighted. Here, we generate two weighted SBMs where the weights are distributed from a Poisson(3) and Normal(5, 1)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Draw weights from a Poisson(3) distribution\n",
"wt = np.random.poisson\n",
"wtargs = dict(lam=3)\n",
"A_poisson= sbm(n_communities, p, wt=wt, wtargs=wtargs)\n",
"\n",
"# Plot\n",
"title = 'Weighted Stochastic Block Model with \\n weights drawn from a Poisson(3) distribution'\n",
"fig= heatmap(A_poisson, title=title)"
]
},
Expand All @@ -86,18 +135,23 @@
"metadata": {},
"outputs": [],
"source": [
"title = 'Weighted Stochastic Block Model with Normal(5, 1)'\n",
"# Draw weights from a Normal(5, 1) distribution\n",
"wt = np.random.normal\n",
"wtargs = dict(loc=5, scale=1)\n",
"A_normal = sbm(n_communities, p, wt=wt, wtargs=wtargs)\n",
"\n",
"fig= heatmap(A_normal, title=title)"
"# Plot\n",
"title = 'Weighted Stochastic Block Model with \\n weights drawn from a Normal(5, 1) distribution'\n",
"fig = heatmap(A_normal, title=title)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### You can also change color maps\n",
"### Colormaps\n",
"\n",
"See [here](https://matplotlib.org/tutorials/colors/colormaps.html) for a list of colormaps"
"You can change colormaps. See [here](https://matplotlib.org/tutorials/colors/colormaps.html) for a list of colormaps."
]
},
{
Expand All @@ -107,8 +161,7 @@
"outputs": [],
"source": [
"title = 'Weighted Stochastic Block Model with Poisson(3)'\n",
"\n",
"fig= heatmap(A_poisson, title=title, transform=None, cmap=\"binary\", center=None)"
"fig = heatmap(A_poisson, title=title, transform=None, cmap=\"binary\", center=None)"
]
},
{
Expand Down Expand Up @@ -203,7 +256,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.8.5"
}
},
"nbformat": 4,
Expand Down
22 changes: 17 additions & 5 deletions graspologic/plot/plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation and contributors.
# Licensed under the MIT License.

import warnings
from typing import Any, Collection, Optional, Union

import matplotlib as mpl
Expand Down Expand Up @@ -286,18 +287,22 @@ def heatmap(
if len(xticklabels) != X.shape[1]:
msg = "xticklabels must have same length {}.".format(X.shape[1])
raise ValueError(msg)
elif not isinstance(xticklabels, bool):
msg = "xticklabels must be a bool or a list, not {}".format(type(xticklabels))

elif not isinstance(xticklabels, (bool, int)):
msg = "xticklabels must be a bool, int, or a list, not {}".format(
type(xticklabels)
)
bdpedigo marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError(msg)

if isinstance(yticklabels, list):
if len(yticklabels) != X.shape[0]:
msg = "yticklabels must have same length {}.".format(X.shape[0])
raise ValueError(msg)
elif not isinstance(yticklabels, bool):
msg = "yticklabels must be a bool or a list, not {}".format(type(yticklabels))
elif not isinstance(yticklabels, (bool, int)):
msg = "yticklabels must be a bool, int, or a list, not {}".format(
type(yticklabels)
)
raise TypeError(msg)

# Handle cmap
if not isinstance(cmap, (str, list, Colormap)):
msg = "cmap must be a string, list of colors, or matplotlib.colors.Colormap,"
Expand All @@ -315,6 +320,11 @@ def heatmap(
msg = "cbar must be a bool, not {}.".format(type(center))
raise TypeError(msg)

# Warning on labels
if (inner_hier_labels is None) and (outer_hier_labels is not None):
msg = "outer_hier_labels requires inner_hier_labels to be used."
warnings.warn(msg)

bdpedigo marked this conversation as resolved.
Show resolved Hide resolved
arr = import_graph(X)

arr = _process_graphs(
Expand Down Expand Up @@ -1583,6 +1593,8 @@ def _plot_groups(
fontsize: int = 30,
) -> matplotlib.pyplot.Axes:
inner_labels_arr = np.array(inner_labels)
if outer_labels is not None:
outer_labels_arr = np.array(outer_labels)
plot_outer = True
if outer_labels is None:
outer_labels_arr = np.ones_like(inner_labels)
Expand Down