diff --git a/docs/tutorials/plotting/heatmaps.ipynb b/docs/tutorials/plotting/heatmaps.ipynb index be05c85e4..0c93e54de 100644 --- a/docs/tutorials/plotting/heatmaps.ipynb +++ b/docs/tutorials/plotting/heatmaps.ipynb @@ -16,19 +16,21 @@ "outputs": [], "source": [ "import graspologic\n", - "\n", - "import numpy as np\n", - "%matplotlib inline" + "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Plotting graphs using heatmap\n", - "\n", - "### Simulate graphs using weighted stochastic block models\n", - "The 2-block model is defined as below:\n", + "## Plotting Simple Graphs using heatmap" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A 2-block Stochastic Block Model is defined as below:\n", "\n", "\\begin{align*}\n", "P = \\begin{bmatrix}0.8 & 0.2 \\\\\n", @@ -36,7 +38,7 @@ "\\end{bmatrix}\n", "\\end{align*}\n", "\n", - "We generate two weight SBMs where the weights are distributed from a Poisson(3) and Normal(5, 1)." + "In simple cases, the model is unweighted. Below, we plot an unweighted SBM." ] }, { @@ -46,25 +48,28 @@ "outputs": [], "source": [ "from graspologic.simulations import sbm\n", + "from graspologic.plot import heatmap\n", "\n", "n_communities = [50, 50]\n", "p = [[0.8, 0.2], \n", " [0.2, 0.8]]\n", "\n", - "wt = np.random.poisson\n", - "wtargs = dict(lam=3)\n", - "A_poisson= sbm(n_communities, p, wt=wt, wtargs=wtargs)\n", - "\n", - "wt = np.random.normal\n", - "wtargs = dict(loc=5, scale=1)\n", - "A_normal = sbm(n_communities, p, wt=wt, wtargs=wtargs)" + "A, labels = sbm(n_communities, p, return_labels=True)\n", + "heatmap(A, title=\"Basic Heatmap function\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plotting with Hierarchy Labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Plot the simulated weighted SBMs" + "If we have labels, we can use them to show communities on a Heatmap." ] }, { @@ -73,10 +78,54 @@ "metadata": {}, "outputs": [], "source": [ - "from graspologic.plot import heatmap\n", - "\n", - "title = 'Weighted Stochastic Block Model with Poisson(3)'\n", + "heatmap(A, inner_hier_labels=labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can plot outer hierarchy labels in addition to inner hierarchy labels." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "outer_labels = [\"Outer Labels\"] * 100\n", + "heatmap(A, inner_hier_labels=labels,\n", + " outer_hier_labels=outer_labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Weighted SBMs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also use heatmap when our graph is weighted. Here, we generate two weighted SBMs where the weights are distributed from a Poisson(3) and Normal(5, 1)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Draw weights from a Poisson(3) distribution\n", + "wt = np.random.poisson\n", + "wtargs = dict(lam=3)\n", + "A_poisson= sbm(n_communities, p, wt=wt, wtargs=wtargs)\n", "\n", + "# Plot\n", + "title = 'Weighted Stochastic Block Model with \\n weights drawn from a Poisson(3) distribution'\n", "fig= heatmap(A_poisson, title=title)" ] }, @@ -86,18 +135,23 @@ "metadata": {}, "outputs": [], "source": [ - "title = 'Weighted Stochastic Block Model with Normal(5, 1)'\n", + "# Draw weights from a Normal(5, 1) distribution\n", + "wt = np.random.normal\n", + "wtargs = dict(loc=5, scale=1)\n", + "A_normal = sbm(n_communities, p, wt=wt, wtargs=wtargs)\n", "\n", - "fig= heatmap(A_normal, title=title)" + "# Plot\n", + "title = 'Weighted Stochastic Block Model with \\n weights drawn from a Normal(5, 1) distribution'\n", + "fig = heatmap(A_normal, title=title)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### You can also change color maps\n", + "### Colormaps\n", "\n", - "See [here](https://matplotlib.org/tutorials/colors/colormaps.html) for a list of colormaps" + "You can change colormaps. See [here](https://matplotlib.org/tutorials/colors/colormaps.html) for a list of colormaps." ] }, { @@ -107,8 +161,7 @@ "outputs": [], "source": [ "title = 'Weighted Stochastic Block Model with Poisson(3)'\n", - "\n", - "fig= heatmap(A_poisson, title=title, transform=None, cmap=\"binary\", center=None)" + "fig = heatmap(A_poisson, title=title, transform=None, cmap=\"binary\", center=None)" ] }, { @@ -203,7 +256,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.0" + "version": "3.8.5" } }, "nbformat": 4, diff --git a/graspologic/plot/plot.py b/graspologic/plot/plot.py index 3480419dd..e8fcf4981 100644 --- a/graspologic/plot/plot.py +++ b/graspologic/plot/plot.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation and contributors. # Licensed under the MIT License. +import warnings from typing import Any, Collection, Optional, Union import matplotlib as mpl @@ -286,18 +287,22 @@ def heatmap( if len(xticklabels) != X.shape[1]: msg = "xticklabels must have same length {}.".format(X.shape[1]) raise ValueError(msg) - elif not isinstance(xticklabels, bool): - msg = "xticklabels must be a bool or a list, not {}".format(type(xticklabels)) + + elif not isinstance(xticklabels, (bool, int)): + msg = "xticklabels must be a bool, int, or a list, not {}".format( + type(xticklabels) + ) raise TypeError(msg) if isinstance(yticklabels, list): if len(yticklabels) != X.shape[0]: msg = "yticklabels must have same length {}.".format(X.shape[0]) raise ValueError(msg) - elif not isinstance(yticklabels, bool): - msg = "yticklabels must be a bool or a list, not {}".format(type(yticklabels)) + elif not isinstance(yticklabels, (bool, int)): + msg = "yticklabels must be a bool, int, or a list, not {}".format( + type(yticklabels) + ) raise TypeError(msg) - # Handle cmap if not isinstance(cmap, (str, list, Colormap)): msg = "cmap must be a string, list of colors, or matplotlib.colors.Colormap," @@ -315,6 +320,11 @@ def heatmap( msg = "cbar must be a bool, not {}.".format(type(center)) raise TypeError(msg) + # Warning on labels + if (inner_hier_labels is None) and (outer_hier_labels is not None): + msg = "outer_hier_labels requires inner_hier_labels to be used." + warnings.warn(msg) + arr = import_graph(X) arr = _process_graphs( @@ -1583,6 +1593,8 @@ def _plot_groups( fontsize: int = 30, ) -> matplotlib.pyplot.Axes: inner_labels_arr = np.array(inner_labels) + if outer_labels is not None: + outer_labels_arr = np.array(outer_labels) plot_outer = True if outer_labels is None: outer_labels_arr = np.ones_like(inner_labels)