Skip to content

Commit

Permalink
make GMM happen internally in pairplot_with_gmm
Browse files Browse the repository at this point in the history
  • Loading branch information
loftusa committed Apr 16, 2021
1 parent 7570404 commit 11dfc85
Showing 1 changed file with 82 additions and 66 deletions.
148 changes: 82 additions & 66 deletions graspologic/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from scipy import linalg
from sklearn.preprocessing import Binarizer
from sklearn.utils import check_array, check_consistent_length, check_X_y
from sklearn.mixture import GaussianMixture

from ..embed import selectSVD
from ..embed import selectSVD, select_dimension
from ..utils import import_graph, pass_to_ranks


Expand Down Expand Up @@ -384,7 +385,8 @@ def binary_heatmap(
"cmap is not allowed in a binary heatmap. To change colors, use the `colors` parameter."
)
if not (
isinstance(colorbar_ticklabels, (list, tuple)) and len(colorbar_ticklabels) == 2):
isinstance(colorbar_ticklabels, (list, tuple)) and len(colorbar_ticklabels) == 2
):
raise ValueError("colorbar_ticklabels must be list-like and length 2.")

cmap = mpl.colors.ListedColormap(colors)
Expand Down Expand Up @@ -795,6 +797,8 @@ def _plot_ellipse_and_data(
def pairplot_with_gmm(
X,
gmm,
covariance_type="full",
n_components=None,
labels=None,
cluster_palette="Set1",
label_palette="Set1",
Expand All @@ -807,66 +811,74 @@ def pairplot_with_gmm(
histplot_kws={},
):
r"""
Plot pairwise relationships in a dataset, also showing a clustering predicted by
a Gaussian mixture model.
By default, this function will create a grid of axes such that each dimension
in data will by shared in the y-axis across a single row and in the x-axis
across a single column.
The off-diagonal axes show the pairwise relationships displayed as scatterplot.
The diagonal axes show the univariate distribution of the data for that
dimension displayed as either a histogram or kernel density estimates (KDEs).
Read more in the `Pairplot with GMM: Visualizing High Dimensional Data and
Clustering Tutorial
<https://microsoft.github.io/graspologic/tutorials/plotting/pairplot_with_gmm.html>`_
Parameters
----------
X : array-like, shape (n_samples, n_features)
Input data.
gmm: GaussianMixture object
A fit :class:`sklearn.mixture.GaussianMixture` object.
Gaussian mixture models (GMMs) are probabilistic models for representing data
based on normally distributed subpopulations, GMM clusters each data point into
a corresponding subpopulation.
labels : array-like or list, shape (n_samples), optional
Labels that correspond to each sample in ``X``.
If labels are not passed in then labels are predicted by ``gmm``.
label_palette : str or dict, optional, default: 'Set1'
Palette used to color points if ``labels`` are passed in.
cluster_palette : str or dict, optional, default: 'Set1'
Palette used to color GMM ellipses (and points if no ``labels`` are passed).
title : string, default: ""
Title of the plot.
legend_name : string, default: None
Name to put above the legend.
If ``None``, will be "Cluster" if no custom ``labels`` are passed, and ""
otherwise.
context : None, or one of {talk (default), paper, notebook, poster}
Seaborn plotting context
font_scale : float, optional, default: 1
Separate scaling factor to independently scale the size of the font
elements.
alpha : float, optional, default: 0.7
Opacity value of plotter markers between 0 and 1
figsize : tuple
The size of the 2d subplots configuration
histplot_kws : dict, default: {}
Keyword arguments passed down to :func:`seaborn.histplot`
Returns
-------
fig : matplotlib Figure
axes : np.ndarray
Array of matplotlib Axes
See Also
--------
graspologic.plot.pairplot
graspologic.cluster.AutoGMMCluster
sklearn.mixture.GaussianMixture
Plot pairwise relationships in a dataset, also showing a clustering predicted by
a Gaussian mixture model.
By default, this function will create a grid of axes such that each dimension
in data will by shared in the y-axis across a single row and in the x-axis
across a single column.
The off-diagonal axes show the pairwise relationships displayed as scatterplot.
The diagonal axes show the univariate distribution of the data for that
dimension displayed as either a histogram or kernel density estimates (KDEs).
Read more in the `Pairplot with GMM: Visualizing High Dimensional Data and
Clustering Tutorial
<https://microsoft.github.io/graspologic/tutorials/plotting/pairplot_with_gmm.html>`_
Parameters
----------
X : array-like, shape (n_samples, n_features)
Input data.
covariance_type : str, default: 'full'
{‘full’, ‘tied’, ‘diag’, ‘spherical’}
String describing the type of covariance parameters to use. Must be one of:
‘full’
each component has its own general covariance matrix
‘tied’
all components share the same general covariance matrix
‘diag’
each component has its own diagonal covariance matrix
‘spherical’
each component has its own single variancee
n_components : int or None, default: None
Desired dimensionality of output data. If None, selects an embedding dimension.
labels : array-like or list, shape (n_samples), optional
Labels that correspond to each sample in ``X``.
If labels are not passed in then labels are predicted by ``gmm``.
label_palette : str or dict, optional, default: 'Set1'
Palette used to color points if ``labels`` are passed in.
cluster_palette : str or dict, optional, default: 'Set1'
Palette used to color GMM ellipses (and points if no ``labels`` are passed).
title : string, default: ""
Title of the plot.
legend_name : string, default: None
Name to put above the legend.
If ``None``, will be "Cluster" if no custom ``labels`` are passed, and ""
otherwise.
context : None, or one of {talk (default), paper, notebook, poster}
Seaborn plotting context
font_scale : float, optional, default: 1
Separate scaling factor to independently scale the size of the font
elements.
alpha : float, optional, default: 0.7
Opacity value of plotter markers between 0 and 1
figsize : tuple
The size of the 2d subplots configuration
histplot_kws : dict, default: {}
Keyword arguments passed down to :func:`seaborn.histplot`
Returns
-------
fig : matplotlib Figure
axes : np.ndarray
Array of matplotlib Axes
See Also
--------
graspologic.plot.pairplot
graspologic.cluster.AutoGMMCluster
sklearn.mixture.GaussianMixture
"""
# Handle X and labels
if labels is not None:
Expand All @@ -877,10 +889,14 @@ def pairplot_with_gmm(
else:
# sets default if no custom labels passed
legend_name = "Cluster"
# Handle gmm
if gmm is None:
msg = "You must input a sklearn.mixture.GaussianMixture"
raise NameError(msg)

if n_components is None:
elbows, _ = select_dimension(X, n_elbows=2, threshold=None)
n_components = elbows[-1]

gmm = GaussianMixture(
n_components=n_components, covariance_type=covariance_type
).fit(X)
Y_, means, covariances = gmm.predict(X), gmm.means_, gmm.covariances_
data = pd.DataFrame(data=X)
n_components = gmm.n_components
Expand Down

0 comments on commit 11dfc85

Please sign in to comment.