Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return figures instead of showing them #441

Merged
merged 5 commits into from
Aug 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gtda/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def transform_plot(self, X, sample=0, **plot_params):
sample : int
Sample to be plotted.

plot_params : dict
**plot_params
Optional plotting parameters.

Returns
Expand All @@ -140,6 +140,6 @@ def transform_plot(self, X, sample=0, **plot_params):

"""
Xt = self.transform(X[sample:sample+1])
self.plot(Xt, sample=0, **plot_params)
self.plot(Xt, sample=0, **plot_params).show()

return Xt
54 changes: 48 additions & 6 deletions gtda/diagrams/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def transform(self, X, y=None):
return Xt

@staticmethod
def plot(Xt, sample=0):
def plot(Xt, sample=0, plotly_params=None):
"""Plot a sample from a collection of persistence diagrams.

Parameters
Expand All @@ -101,9 +101,23 @@ def plot(Xt, sample=0):
sample : int, optional, default: ``0``
Index of the sample in `Xt` to be plotted.

plotly_params : dict or None, optional, default: ``None``
Custom parameters to configure the plotly figure. Allowed keys are
``"traces"`` and ``"layout"``, and the corresponding values should
be dictionaries containing keyword arguments as would be fed to the
:meth:`update_traces` and :meth:`update_layout` methods of
:class:`plotly.graph_objects.Figure`.

Returns
-------
fig : :class:`plotly.graph_objects.Figure` object
Plotly figure.

"""
return plot_diagram(
Xt[sample], homology_dimensions=[np.inf])
Xt[sample], homology_dimensions=[np.inf],
plotly_params=plotly_params
)


@adapt_fit_transform_docs
Expand Down Expand Up @@ -282,7 +296,7 @@ def inverse_transform(self, X):
Xs[:, :, :2] *= self.scale_
return Xs

def plot(self, Xt, sample=0, homology_dimensions=None):
def plot(self, Xt, sample=0, homology_dimensions=None, plotly_params=None):
"""Plot a sample from a collection of persistence diagrams, with
homology in multiple dimensions.

Expand All @@ -299,14 +313,28 @@ def plot(self, Xt, sample=0, homology_dimensions=None):
Which homology dimensions to include in the plot. ``None`` is
equivalent to passing :attr:`homology_dimensions_`.

plotly_params : dict or None, optional, default: ``None``
Custom parameters to configure the plotly figure. Allowed keys are
``"traces"`` and ``"layout"``, and the corresponding values should
be dictionaries containing keyword arguments as would be fed to the
:meth:`update_traces` and :meth:`update_layout` methods of
:class:`plotly.graph_objects.Figure`.

Returns
-------
fig : :class:`plotly.graph_objects.Figure` object
Plotly figure.

"""
if homology_dimensions is None:
_homology_dimensions = self.homology_dimensions_
else:
_homology_dimensions = homology_dimensions

return plot_diagram(
Xt[sample], homology_dimensions=_homology_dimensions)
Xt[sample], homology_dimensions=_homology_dimensions,
plotly_params=plotly_params
)


@adapt_fit_transform_docs
Expand Down Expand Up @@ -419,7 +447,7 @@ def transform(self, X, y=None):
Xt = _filter(X, self.homology_dimensions_, self.epsilon)
return Xt

def plot(self, Xt, sample=0, homology_dimensions=None):
def plot(self, Xt, sample=0, homology_dimensions=None, plotly_params=None):
"""Plot a sample from a collection of persistence diagrams, with
homology in multiple dimensions.

Expand All @@ -436,11 +464,25 @@ def plot(self, Xt, sample=0, homology_dimensions=None):
Which homology dimensions to include in the plot. ``None`` is
equivalent to passing :attr:`homology_dimensions_`.

plotly_params : dict or None, optional, default: ``None``
Custom parameters to configure the plotly figure. Allowed keys are
``"traces"`` and ``"layout"``, and the corresponding values should
be dictionaries containing keyword arguments as would be fed to the
:meth:`update_traces` and :meth:`update_layout` methods of
:class:`plotly.graph_objects.Figure`.

Returns
-------
fig : :class:`plotly.graph_objects.Figure` object
Plotly figure.

"""
if homology_dimensions is None:
_homology_dimensions = self.homology_dimensions_
else:
_homology_dimensions = homology_dimensions

return plot_diagram(
Xt[sample], homology_dimensions=_homology_dimensions)
Xt[sample], homology_dimensions=_homology_dimensions,
plotly_params=plotly_params
)
83 changes: 69 additions & 14 deletions gtda/diagrams/representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ def plot(self, Xt, sample=0, homology_dimensions=None):
Which homology dimensions to include in the plot. ``None`` means
plotting all dimensions present in :attr:`homology_dimensions_`.

Returns
-------
fig : :class:`plotly.graph_objects.Figure` object
Plotly figure.

"""
check_is_fitted(self)

Expand Down Expand Up @@ -224,7 +229,7 @@ def plot(self, Xt, sample=0, homology_dimensions=None):
mode='lines', showlegend=True,
name=f"H{int(dim)}"))

fig.show()
return fig


@adapt_fit_transform_docs
Expand Down Expand Up @@ -383,6 +388,11 @@ def plot(self, Xt, sample=0, homology_dimensions=None):
``None`` means plotting all dimensions present in
:attr:`homology_dimensions_`.

Returns
-------
fig : :class:`plotly.graph_objects.Figure` object
Plotly figure.

"""
check_is_fitted(self)

Expand Down Expand Up @@ -445,7 +455,7 @@ def plot(self, Xt, sample=0, homology_dimensions=None):
hoverinfo='none',
name=f"Layer {layer + 1}"))

fig.show()
return fig


@adapt_fit_transform_docs
Expand Down Expand Up @@ -595,7 +605,8 @@ def transform(self, X, y=None):
transpose((1, 0, 2, 3))
return Xt

def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale='blues'):
def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale='blues',
plotly_params=None):
"""Plot a single channel – corresponding to a given homology
dimension – in a sample from a collection of heat kernel images.

Expand All @@ -619,12 +630,26 @@ def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale='blues'):
Color scale to be used in the heat map. Can be anything allowed by
:class:`plotly.graph_objects.Heatmap`.

plotly_params : dict or None, optional, default: ``None``
Custom parameters to configure the plotly figure. Allowed keys are
``"trace"`` and ``"layout"``, and the corresponding values should
be dictionaries containing keyword arguments as would be fed to the
:meth:`update_traces` and :meth:`update_layout` methods of
:class:`plotly.graph_objects.Figure`.

Returns
-------
fig : :class:`plotly.graph_objects.Figure` object
Plotly figure.

"""
check_is_fitted(self)
return plot_heatmap(Xt[sample][homology_dimension_ix],
x=self.samplings_[homology_dimension_ix],
y=self.samplings_[homology_dimension_ix],
colorscale=colorscale)
return plot_heatmap(
Xt[sample][homology_dimension_ix],
x=self.samplings_[homology_dimension_ix],
y=self.samplings_[homology_dimension_ix],
colorscale=colorscale, plotly_params=plotly_params
)


@adapt_fit_transform_docs
Expand Down Expand Up @@ -805,7 +830,8 @@ def transform(self, X, y=None):
transpose((1, 0, 2, 3))
return Xt

def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale='blues'):
def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale='blues',
plotly_params=None):
"""Plot a single channel – corresponding to a given homology
dimension – in a sample from a collection of persistence images.

Expand All @@ -829,13 +855,25 @@ def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale='blues'):
Color scale to be used in the heat map. Can be anything allowed by
:class:`plotly.graph_objects.Heatmap`.

plotly_params : dict or None, optional, default: ``None``
Custom parameters to configure the plotly figure. Allowed keys are
``"trace"`` and ``"layout"``, and the corresponding values should
be dictionaries containing keyword arguments as would be fed to the
:meth:`update_traces` and :meth:`update_layout` methods of
:class:`plotly.graph_objects.Figure`.

Returns
-------
fig : :class:`plotly.graph_objects.Figure` object
Plotly figure.

"""
check_is_fitted(self)
samplings_x, samplings_y = self.samplings_[homology_dimension_ix]
return plot_heatmap(Xt[sample][homology_dimension_ix],
x=samplings_x,
y=samplings_y,
colorscale=colorscale)
return plot_heatmap(
Xt[sample][homology_dimension_ix], x=samplings_x, y=samplings_y,
colorscale=colorscale, plotly_params=plotly_params
)


@adapt_fit_transform_docs
Expand Down Expand Up @@ -982,7 +1020,7 @@ def transform(self, X, y=None):
transpose((1, 0, 2))
return Xt

def plot(self, Xt, sample=0, homology_dimensions=None):
def plot(self, Xt, sample=0, homology_dimensions=None, plotly_params=None):
"""Plot a sample from a collection of silhouettes arranged as in
the output of :meth:`transform`. Include homology in multiple
dimensions.
Expand All @@ -999,6 +1037,18 @@ def plot(self, Xt, sample=0, homology_dimensions=None):
Which homology dimensions to include in the plot. ``None`` means
plotting all dimensions present in :attr:`homology_dimensions_`.

plotly_params : dict or None, optional, default: ``None``
Custom parameters to configure the plotly figure. Allowed keys are
``"traces"`` and ``"layout"``, and the corresponding values should
be dictionaries containing keyword arguments as would be fed to the
:meth:`update_traces` and :meth:`update_layout` methods of
:class:`plotly.graph_objects.Figure`.

Returns
-------
fig : :class:`plotly.graph_objects.Figure` object
Plotly figure.

"""
check_is_fitted(self)

Expand Down Expand Up @@ -1055,4 +1105,9 @@ def plot(self, Xt, sample=0, homology_dimensions=None):
hoverinfo="none",
name=f"H{int(dim)}"))

fig.show()
# Update trace and layout according to user input
if plotly_params:
fig.update_traces(plotly_params.get("traces", None))
fig.update_layout(plotly_params.get("layout", None))

return fig
18 changes: 16 additions & 2 deletions gtda/graphs/geodesic_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def transform(self, X, y=None):
return Xt

@staticmethod
def plot(Xt, sample=0, colorscale='blues'):
def plot(Xt, sample=0, colorscale='blues', plotly_params=None):
"""Plot a sample from a collection of distance matrices.

Parameters
Expand All @@ -197,5 +197,19 @@ def plot(Xt, sample=0, colorscale='blues'):
Color scale to be used in the heat map. Can be anything allowed by
:class:`plotly.graph_objects.Heatmap`.

plotly_params : dict or None, optional, default: ``None``
Custom parameters to configure the plotly figure. Allowed keys are
``"trace"`` and ``"layout"``, and the corresponding values should
be dictionaries containing keyword arguments as would be fed to the
:meth:`update_traces` and :meth:`update_layout` methods of
:class:`plotly.graph_objects.Figure`.

Returns
-------
fig : :class:`plotly.graph_objects.Figure` object
Plotly figure.

"""
return plot_heatmap(Xt[sample], colorscale=colorscale)
return plot_heatmap(
Xt[sample], colorscale=colorscale, plotly_params=plotly_params
)
18 changes: 14 additions & 4 deletions gtda/graphs/tests/test_geodesic_distance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Testing for GraphGeodesicDistance."""

import warnings

import numpy as np
import plotly.io as pio
import pytest
Expand Down Expand Up @@ -85,20 +87,28 @@ def test_ggd_not_fitted():

def test_ggd_fit_transform_plot():
X = X_ggd[0][0]
GraphGeodesicDistance().fit_transform_plot(X, sample=0)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="Methods .*")
GraphGeodesicDistance().fit_transform_plot(X, sample=0)


@pytest.mark.parametrize("X, X_res", X_ggd)
@pytest.mark.parametrize("method", ["auto", "FW", "D", "J", "BF"])
def test_ggd_transform(X, X_res, method):
ggd = GraphGeodesicDistance(directed=False, method=method)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="Methods .*")
ggd = GraphGeodesicDistance(directed=False, method=method)
X_ft = ggd.fit_transform(X)

assert_almost_equal(ggd.fit_transform(X), X_res)
assert_almost_equal(X_ft, X_res)


def test_parallel_ggd_transform():
X = X_ggd[0][0]
ggd = GraphGeodesicDistance(n_jobs=1)
ggd_parallel = GraphGeodesicDistance(n_jobs=2)

assert_almost_equal(ggd.fit_transform(X), ggd_parallel.fit_transform(X))
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="Methods .*")
assert_almost_equal(ggd.fit_transform(X),
ggd_parallel.fit_transform(X))
19 changes: 17 additions & 2 deletions gtda/homology/cubical.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def transform(self, X, y=None):
:math:`\\sum_q n_q`, where :math:`n_q` is the maximum number of
topological features in dimension :math:`q` across all samples in
`X`.

"""
check_is_fitted(self)
Xt = check_array(X, allow_nd=True)
Expand All @@ -221,7 +222,7 @@ def transform(self, X, y=None):
return Xt

@staticmethod
def plot(Xt, sample=0, homology_dimensions=None):
def plot(Xt, sample=0, homology_dimensions=None, plotly_params=None):
"""Plot a sample from a collection of persistence diagrams, with
homology in multiple dimensions.

Expand All @@ -238,6 +239,20 @@ def plot(Xt, sample=0, homology_dimensions=None):
Which homology dimensions to include in the plot. ``None`` means
plotting all dimensions present in ``Xt[sample]``.

plotly_params : dict or None, optional, default: ``None``
Custom parameters to configure the plotly figure. Allowed keys are
``"traces"`` and ``"layout"``, and the corresponding values should
be dictionaries containing keyword arguments as would be fed to the
:meth:`update_traces` and :meth:`update_layout` methods of
:class:`plotly.graph_objects.Figure`.

Returns
-------
fig : :class:`plotly.graph_objects.Figure` object
Plotly figure.

"""
return plot_diagram(
Xt[sample], homology_dimensions=homology_dimensions)
Xt[sample], homology_dimensions=homology_dimensions,
plotly_params=plotly_params
)
Loading