Skip to content

Commit

Permalink
Updated heatmap (#750)
Browse files Browse the repository at this point in the history
* test pytest ini

* add binary_heatmap, updates to tutorial notebook

* test plotting

* re-add pytest.ini

* Update plot.py

Fix a few bugs

* add new tests

* Update setup.cfg

* remove binary_heatmap

* mpl version req back to normal

* remove binary_heatmap import

* black

* formatting fix

---------

Co-authored-by: Benjamin Pedigo <benjamindpedigo@gmail.com>
  • Loading branch information
loftusa and bdpedigo authored Jun 13, 2023
1 parent 39f83ae commit 4499f7c
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 31 deletions.
105 changes: 79 additions & 26 deletions docs/tutorials/plotting/heatmaps.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,29 @@
"outputs": [],
"source": [
"import graspologic\n",
"\n",
"import numpy as np\n",
"%matplotlib inline"
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plotting graphs using heatmap\n",
"\n",
"### Simulate graphs using weighted stochastic block models\n",
"The 2-block model is defined as below:\n",
"## Plotting Simple Graphs using heatmap"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A 2-block Stochastic Block Model is defined as below:\n",
"\n",
"\\begin{align*}\n",
"P = \\begin{bmatrix}0.8 & 0.2 \\\\\n",
"0.2 & 0.8 \n",
"\\end{bmatrix}\n",
"\\end{align*}\n",
"\n",
"We generate two weight SBMs where the weights are distributed from a Poisson(3) and Normal(5, 1)."
"In simple cases, the model is unweighted. Below, we plot an unweighted SBM."
]
},
{
Expand All @@ -46,25 +48,28 @@
"outputs": [],
"source": [
"from graspologic.simulations import sbm\n",
"from graspologic.plot import heatmap\n",
"\n",
"n_communities = [50, 50]\n",
"p = [[0.8, 0.2], \n",
" [0.2, 0.8]]\n",
"\n",
"wt = np.random.poisson\n",
"wtargs = dict(lam=3)\n",
"A_poisson= sbm(n_communities, p, wt=wt, wtargs=wtargs)\n",
"\n",
"wt = np.random.normal\n",
"wtargs = dict(loc=5, scale=1)\n",
"A_normal = sbm(n_communities, p, wt=wt, wtargs=wtargs)"
"A, labels = sbm(n_communities, p, return_labels=True)\n",
"heatmap(A, title=\"Basic Heatmap function\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plotting with Hierarchy Labels"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot the simulated weighted SBMs"
"If we have labels, we can use them to show communities on a Heatmap."
]
},
{
Expand All @@ -73,10 +78,54 @@
"metadata": {},
"outputs": [],
"source": [
"from graspologic.plot import heatmap\n",
"\n",
"title = 'Weighted Stochastic Block Model with Poisson(3)'\n",
"heatmap(A, inner_hier_labels=labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can plot outer hierarchy labels in addition to inner hierarchy labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"outer_labels = [\"Outer Labels\"] * 100\n",
"heatmap(A, inner_hier_labels=labels,\n",
" outer_hier_labels=outer_labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Weighted SBMs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also use heatmap when our graph is weighted. Here, we generate two weighted SBMs where the weights are distributed from a Poisson(3) and Normal(5, 1)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Draw weights from a Poisson(3) distribution\n",
"wt = np.random.poisson\n",
"wtargs = dict(lam=3)\n",
"A_poisson= sbm(n_communities, p, wt=wt, wtargs=wtargs)\n",
"\n",
"# Plot\n",
"title = 'Weighted Stochastic Block Model with \\n weights drawn from a Poisson(3) distribution'\n",
"fig= heatmap(A_poisson, title=title)"
]
},
Expand All @@ -86,18 +135,23 @@
"metadata": {},
"outputs": [],
"source": [
"title = 'Weighted Stochastic Block Model with Normal(5, 1)'\n",
"# Draw weights from a Normal(5, 1) distribution\n",
"wt = np.random.normal\n",
"wtargs = dict(loc=5, scale=1)\n",
"A_normal = sbm(n_communities, p, wt=wt, wtargs=wtargs)\n",
"\n",
"fig= heatmap(A_normal, title=title)"
"# Plot\n",
"title = 'Weighted Stochastic Block Model with \\n weights drawn from a Normal(5, 1) distribution'\n",
"fig = heatmap(A_normal, title=title)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### You can also change color maps\n",
"### Colormaps\n",
"\n",
"See [here](https://matplotlib.org/tutorials/colors/colormaps.html) for a list of colormaps"
"You can change colormaps. See [here](https://matplotlib.org/tutorials/colors/colormaps.html) for a list of colormaps."
]
},
{
Expand All @@ -107,8 +161,7 @@
"outputs": [],
"source": [
"title = 'Weighted Stochastic Block Model with Poisson(3)'\n",
"\n",
"fig= heatmap(A_poisson, title=title, transform=None, cmap=\"binary\", center=None)"
"fig = heatmap(A_poisson, title=title, transform=None, cmap=\"binary\", center=None)"
]
},
{
Expand Down Expand Up @@ -203,7 +256,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.8.5"
}
},
"nbformat": 4,
Expand Down
22 changes: 17 additions & 5 deletions graspologic/plot/plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation and contributors.
# Licensed under the MIT License.

import warnings
from typing import Any, Collection, Optional, Union

import matplotlib as mpl
Expand Down Expand Up @@ -286,18 +287,22 @@ def heatmap(
if len(xticklabels) != X.shape[1]:
msg = "xticklabels must have same length {}.".format(X.shape[1])
raise ValueError(msg)
elif not isinstance(xticklabels, bool):
msg = "xticklabels must be a bool or a list, not {}".format(type(xticklabels))

elif not isinstance(xticklabels, (bool, int)):
msg = "xticklabels must be a bool, int, or a list, not {}".format(
type(xticklabels)
)
raise TypeError(msg)

if isinstance(yticklabels, list):
if len(yticklabels) != X.shape[0]:
msg = "yticklabels must have same length {}.".format(X.shape[0])
raise ValueError(msg)
elif not isinstance(yticklabels, bool):
msg = "yticklabels must be a bool or a list, not {}".format(type(yticklabels))
elif not isinstance(yticklabels, (bool, int)):
msg = "yticklabels must be a bool, int, or a list, not {}".format(
type(yticklabels)
)
raise TypeError(msg)

# Handle cmap
if not isinstance(cmap, (str, list, Colormap)):
msg = "cmap must be a string, list of colors, or matplotlib.colors.Colormap,"
Expand All @@ -315,6 +320,11 @@ def heatmap(
msg = "cbar must be a bool, not {}.".format(type(center))
raise TypeError(msg)

# Warning on labels
if (inner_hier_labels is None) and (outer_hier_labels is not None):
msg = "outer_hier_labels requires inner_hier_labels to be used."
warnings.warn(msg)

arr = import_graph(X)

arr = _process_graphs(
Expand Down Expand Up @@ -1583,6 +1593,8 @@ def _plot_groups(
fontsize: int = 30,
) -> matplotlib.pyplot.Axes:
inner_labels_arr = np.array(inner_labels)
if outer_labels is not None:
outer_labels_arr = np.array(outer_labels)
plot_outer = True
if outer_labels is None:
outer_labels_arr = np.ones_like(inner_labels)
Expand Down

0 comments on commit 4499f7c

Please sign in to comment.