From f5c19f2f3023b33361c72703754d4d109038f7ca Mon Sep 17 00:00:00 2001 From: Umberto Lupo Date: Fri, 31 Jul 2020 20:31:03 +0200 Subject: [PATCH 1/5] Refactor of plotting API - Make all functions in gtda/plotting return figures (or tuples of figures for betti_surfaces) instead of showing them - `plot` class methods *show* figures as before so the transformer-level plotting API is largely unchanged - Add `plot_params` kwarg throughout to allow user customisability of output figures (subtlety: one of the key can be either "trace" when the output figure only has one trace, or "traces" when it has several) --- gtda/base.py | 2 +- gtda/diagrams/preprocessing.py | 45 +++++++++++--- gtda/diagrams/representations.py | 52 ++++++++++++---- gtda/graphs/geodesic_distance.py | 13 +++- gtda/homology/cubical.py | 16 ++++- gtda/homology/simplicial.py | 60 +++++++++++++++---- gtda/images/filtrations.py | 75 ++++++++++++++++++++---- gtda/images/preprocessing.py | 59 +++++++++++++++---- gtda/plotting/diagram_representations.py | 54 +++++++++++++++-- gtda/plotting/images.py | 21 ++++++- gtda/plotting/persistence_diagrams.py | 22 ++++++- gtda/plotting/point_clouds.py | 30 +++++++--- gtda/point_clouds/rescaling.py | 26 ++++++-- gtda/time_series/embedding.py | 11 +++- 14 files changed, 404 insertions(+), 82 deletions(-) diff --git a/gtda/base.py b/gtda/base.py index 25a981e63..edfd480d0 100644 --- a/gtda/base.py +++ b/gtda/base.py @@ -130,7 +130,7 @@ def transform_plot(self, X, sample=0, **plot_params): sample : int Sample to be plotted. - plot_params : dict + **plot_params Optional plotting parameters. Returns diff --git a/gtda/diagrams/preprocessing.py b/gtda/diagrams/preprocessing.py index 56276ac89..196dc1609 100644 --- a/gtda/diagrams/preprocessing.py +++ b/gtda/diagrams/preprocessing.py @@ -89,7 +89,7 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0): + def plot(Xt, sample=0, plotly_params=None): """Plot a sample from a collection of persistence diagrams. Parameters @@ -101,9 +101,18 @@ def plot(Xt, sample=0): sample : int, optional, default: ``0`` Index of the sample in `Xt` to be plotted. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"traces"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_diagram( - Xt[sample], homology_dimensions=[np.inf]) + plot_diagram( + Xt[sample], homology_dimensions=[np.inf], + plotly_params=plotly_params + ).show() @adapt_fit_transform_docs @@ -282,7 +291,7 @@ def inverse_transform(self, X): Xs[:, :, :2] *= self.scale_ return Xs - def plot(self, Xt, sample=0, homology_dimensions=None): + def plot(self, Xt, sample=0, homology_dimensions=None, plotly_params=None): """Plot a sample from a collection of persistence diagrams, with homology in multiple dimensions. @@ -299,14 +308,23 @@ def plot(self, Xt, sample=0, homology_dimensions=None): Which homology dimensions to include in the plot. ``None`` is equivalent to passing :attr:`homology_dimensions_`. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"traces"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ if homology_dimensions is None: _homology_dimensions = self.homology_dimensions_ else: _homology_dimensions = homology_dimensions - return plot_diagram( - Xt[sample], homology_dimensions=_homology_dimensions) + plot_diagram( + Xt[sample], homology_dimensions=_homology_dimensions, + plotly_params=plotly_params + ).show() @adapt_fit_transform_docs @@ -419,7 +437,7 @@ def transform(self, X, y=None): Xt = _filter(X, self.homology_dimensions_, self.epsilon) return Xt - def plot(self, Xt, sample=0, homology_dimensions=None): + def plot(self, Xt, sample=0, homology_dimensions=None, plotly_params=None): """Plot a sample from a collection of persistence diagrams, with homology in multiple dimensions. @@ -436,11 +454,20 @@ def plot(self, Xt, sample=0, homology_dimensions=None): Which homology dimensions to include in the plot. ``None`` is equivalent to passing :attr:`homology_dimensions_`. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"traces"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ if homology_dimensions is None: _homology_dimensions = self.homology_dimensions_ else: _homology_dimensions = homology_dimensions - return plot_diagram( - Xt[sample], homology_dimensions=_homology_dimensions) + plot_diagram( + Xt[sample], homology_dimensions=_homology_dimensions, + plotly_params=plotly_params + ).show() diff --git a/gtda/diagrams/representations.py b/gtda/diagrams/representations.py index ead22a33b..fbc2ced05 100644 --- a/gtda/diagrams/representations.py +++ b/gtda/diagrams/representations.py @@ -595,7 +595,8 @@ def transform(self, X, y=None): transpose((1, 0, 2, 3)) return Xt - def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale='blues'): + def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale='blues', + plotly_params=None): """Plot a single channel – corresponding to a given homology dimension – in a sample from a collection of heat kernel images. @@ -619,12 +620,21 @@ def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale='blues'): Color scale to be used in the heat map. Can be anything allowed by :class:`plotly.graph_objects.Heatmap`. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ check_is_fitted(self) - return plot_heatmap(Xt[sample][homology_dimension_ix], - x=self.samplings_[homology_dimension_ix], - y=self.samplings_[homology_dimension_ix], - colorscale=colorscale) + plot_heatmap( + Xt[sample][homology_dimension_ix], + x=self.samplings_[homology_dimension_ix], + y=self.samplings_[homology_dimension_ix], + colorscale=colorscale, plotly_params=plotly_params + ).show() @adapt_fit_transform_docs @@ -805,7 +815,8 @@ def transform(self, X, y=None): transpose((1, 0, 2, 3)) return Xt - def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale='blues'): + def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale='blues', + plotly_params=None): """Plot a single channel – corresponding to a given homology dimension – in a sample from a collection of persistence images. @@ -829,13 +840,20 @@ def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale='blues'): Color scale to be used in the heat map. Can be anything allowed by :class:`plotly.graph_objects.Heatmap`. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ check_is_fitted(self) samplings_x, samplings_y = self.samplings_[homology_dimension_ix] - return plot_heatmap(Xt[sample][homology_dimension_ix], - x=samplings_x, - y=samplings_y, - colorscale=colorscale) + plot_heatmap( + Xt[sample][homology_dimension_ix], x=samplings_x, y=samplings_y, + colorscale=colorscale, plotly_params=plotly_params + ).show() @adapt_fit_transform_docs @@ -982,7 +1000,7 @@ def transform(self, X, y=None): transpose((1, 0, 2)) return Xt - def plot(self, Xt, sample=0, homology_dimensions=None): + def plot(self, Xt, sample=0, homology_dimensions=None, plotly_params=None): """Plot a sample from a collection of silhouettes arranged as in the output of :meth:`transform`. Include homology in multiple dimensions. @@ -999,6 +1017,13 @@ def plot(self, Xt, sample=0, homology_dimensions=None): Which homology dimensions to include in the plot. ``None`` means plotting all dimensions present in :attr:`homology_dimensions_`. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"traces"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ check_is_fitted(self) @@ -1055,4 +1080,9 @@ def plot(self, Xt, sample=0, homology_dimensions=None): hoverinfo="none", name=f"H{int(dim)}")) + # Update trace and layout according to user input + if plotly_params: + fig.update_traces(plotly_params.get("traces", None)) + fig.update_layout(plotly_params.get("layout", None)) + fig.show() diff --git a/gtda/graphs/geodesic_distance.py b/gtda/graphs/geodesic_distance.py index f64a883d6..77f1f5394 100644 --- a/gtda/graphs/geodesic_distance.py +++ b/gtda/graphs/geodesic_distance.py @@ -180,7 +180,7 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, colorscale='blues'): + def plot(Xt, sample=0, colorscale='blues', plotly_params=None): """Plot a sample from a collection of distance matrices. Parameters @@ -197,5 +197,14 @@ def plot(Xt, sample=0, colorscale='blues'): Color scale to be used in the heat map. Can be anything allowed by :class:`plotly.graph_objects.Heatmap`. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_heatmap(Xt[sample], colorscale=colorscale) + plot_heatmap( + Xt[sample], colorscale=colorscale, plotly_params=plotly_params + ).show() diff --git a/gtda/homology/cubical.py b/gtda/homology/cubical.py index e85140625..c2b06bc39 100644 --- a/gtda/homology/cubical.py +++ b/gtda/homology/cubical.py @@ -197,6 +197,7 @@ def transform(self, X, y=None): :math:`\\sum_q n_q`, where :math:`n_q` is the maximum number of topological features in dimension :math:`q` across all samples in `X`. + """ check_is_fitted(self) Xt = check_array(X, allow_nd=True) @@ -221,7 +222,7 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, homology_dimensions=None): + def plot(Xt, sample=0, homology_dimensions=None, plotly_params=None): """Plot a sample from a collection of persistence diagrams, with homology in multiple dimensions. @@ -238,6 +239,15 @@ def plot(Xt, sample=0, homology_dimensions=None): Which homology dimensions to include in the plot. ``None`` means plotting all dimensions present in ``Xt[sample]``. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"traces"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_diagram( - Xt[sample], homology_dimensions=homology_dimensions) + plot_diagram( + Xt[sample], homology_dimensions=homology_dimensions, + plotly_params=plotly_params + ).show() diff --git a/gtda/homology/simplicial.py b/gtda/homology/simplicial.py index 7d6cdee30..ee5deaf08 100644 --- a/gtda/homology/simplicial.py +++ b/gtda/homology/simplicial.py @@ -233,7 +233,7 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, homology_dimensions=None): + def plot(Xt, sample=0, homology_dimensions=None, plotly_params=None): """Plot a sample from a collection of persistence diagrams, with homology in multiple dimensions. @@ -250,9 +250,18 @@ def plot(Xt, sample=0, homology_dimensions=None): Which homology dimensions to include in the plot. ``None`` means plotting all dimensions present in ``Xt[sample]``. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"traces"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_diagram( - Xt[sample], homology_dimensions=homology_dimensions) + plot_diagram( + Xt[sample], homology_dimensions=homology_dimensions, + plotly_params=plotly_params + ).show() @adapt_fit_transform_docs @@ -480,7 +489,7 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, homology_dimensions=None): + def plot(Xt, sample=0, homology_dimensions=None, plotly_params=None): """Plot a sample from a collection of persistence diagrams, with homology in multiple dimensions. @@ -497,9 +506,18 @@ def plot(Xt, sample=0, homology_dimensions=None): Which homology dimensions to include in the plot. ``None`` means plotting all dimensions present in ``Xt[sample]``. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"traces"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_diagram( - Xt[sample], homology_dimensions=homology_dimensions) + plot_diagram( + Xt[sample], homology_dimensions=homology_dimensions, + plotly_params=plotly_params + ).show() @adapt_fit_transform_docs @@ -683,7 +701,7 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, homology_dimensions=None): + def plot(Xt, sample=0, homology_dimensions=None, plotly_params=None): """Plot a sample from a collection of persistence diagrams, with homology in multiple dimensions. @@ -700,9 +718,18 @@ def plot(Xt, sample=0, homology_dimensions=None): Which homology dimensions to include in the plot. ``None`` means plotting all dimensions present in ``Xt[sample]``. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"traces"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_diagram( - Xt[sample], homology_dimensions=homology_dimensions) + plot_diagram( + Xt[sample], homology_dimensions=homology_dimensions, + plotly_params=plotly_params + ).show() @adapt_fit_transform_docs @@ -947,7 +974,7 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, homology_dimensions=None): + def plot(Xt, sample=0, homology_dimensions=None, plotly_params=None): """Plot a sample from a collection of persistence diagrams, with homology in multiple dimensions. @@ -964,6 +991,15 @@ def plot(Xt, sample=0, homology_dimensions=None): Which homology dimensions to include in the plot. ``None`` means plotting all dimensions present in ``Xt[sample]``. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"traces"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_diagram( - Xt[sample], homology_dimensions=homology_dimensions) + plot_diagram( + Xt[sample], homology_dimensions=homology_dimensions, + plotly_params=plotly_params + ).show() diff --git a/gtda/images/filtrations.py b/gtda/images/filtrations.py index 624da3f4e..663cb14a9 100644 --- a/gtda/images/filtrations.py +++ b/gtda/images/filtrations.py @@ -180,7 +180,8 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, colorscale='greys', origin='upper'): + def plot(Xt, sample=0, colorscale='greys', origin='upper', + plotly_params=None): """Plot a sample from a collection of 2D greyscale images. Parameters @@ -201,8 +202,18 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper'): left corner. The convention ``'upper'`` is typically used for matrices and images. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_heatmap(Xt[sample], colorscale=colorscale, origin=origin) + plot_heatmap( + Xt[sample], colorscale=colorscale, origin=origin, + plotly_params=plotly_params + ).show() @adapt_fit_transform_docs @@ -406,7 +417,8 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, colorscale='greys', origin='upper'): + def plot(Xt, sample=0, colorscale='greys', origin='upper', + plotly_params=None): """Plot a sample from a collection of 2D greyscale images. Parameters @@ -427,8 +439,18 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper'): left corner. The convention ``'upper'`` is typically used for matrices and images. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_heatmap(Xt[sample], colorscale=colorscale, origin=origin) + plot_heatmap( + Xt[sample], colorscale=colorscale, origin=origin, + plotly_params=plotly_params + ).show() @adapt_fit_transform_docs @@ -579,7 +601,8 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, colorscale='greys', origin='upper'): + def plot(Xt, sample=0, colorscale='greys', origin='upper', + plotly_params=None): """Plot a sample from a collection of 2D greyscale images. Parameters @@ -600,8 +623,18 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper'): left corner. The convention ``'upper'`` is typically used for matrices and images. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_heatmap(Xt[sample], colorscale=colorscale, origin=origin) + plot_heatmap( + Xt[sample], colorscale=colorscale, origin=origin, + plotly_params=plotly_params + ).show() @adapt_fit_transform_docs @@ -751,7 +784,8 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, colorscale='greys', origin='upper'): + def plot(Xt, sample=0, colorscale='greys', origin='upper', + plotly_params=None): """Plot a sample from a collection of 2D greyscale images. Parameters @@ -772,8 +806,18 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper'): left corner. The convention ``'upper'`` is typically used for matrices and images. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_heatmap(Xt[sample], colorscale=colorscale, origin=origin) + plot_heatmap( + Xt[sample], colorscale=colorscale, origin=origin, + plotly_params=plotly_params + ).show() @adapt_fit_transform_docs @@ -933,7 +977,8 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, colorscale='greys', origin='upper'): + def plot(Xt, sample=0, colorscale='greys', origin='upper', + plotly_params=None): """Plot a sample from a collection of 2D greyscale images. Parameters @@ -954,5 +999,15 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper'): left corner. The convention ``'upper'`` is typically used for matrices and images. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_heatmap(Xt[sample], colorscale=colorscale, origin=origin) + plot_heatmap( + Xt[sample], colorscale=colorscale, origin=origin, + plotly_params=plotly_params + ).show() diff --git a/gtda/images/preprocessing.py b/gtda/images/preprocessing.py index dc996f8d5..c8b53a6c4 100644 --- a/gtda/images/preprocessing.py +++ b/gtda/images/preprocessing.py @@ -141,7 +141,8 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, colorscale='greys', origin='upper'): + def plot(Xt, sample=0, colorscale='greys', origin='upper', + plotly_params=None): """Plot a sample from a collection of 2D binary images. Parameters @@ -162,9 +163,18 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper'): left corner. The convention ``'upper'`` is typically used for matrices and images. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_heatmap( - Xt[sample] * 1, colorscale=colorscale, origin=origin) + plot_heatmap( + Xt[sample] * 1, colorscale=colorscale, origin=origin, + plotly_params=plotly_params + ) @adapt_fit_transform_docs @@ -249,7 +259,8 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, colorscale='greys', origin='upper'): + def plot(Xt, sample=0, colorscale='greys', origin='upper', + plotly_params=None): """Plot a sample from a collection of 2D binary images. Parameters @@ -270,9 +281,18 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper'): left corner. The convention ``'upper'`` is typically used for matrices and images. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_heatmap( - Xt[sample] * 1, colorscale=colorscale, origin=origin) + plot_heatmap( + Xt[sample] * 1, colorscale=colorscale, origin=origin, + plotly_params=plotly_params + ).show() @adapt_fit_transform_docs @@ -401,7 +421,8 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, colorscale='greys', origin='upper'): + def plot(Xt, sample=0, colorscale='greys', origin='upper', + plotly_params=None): """Plot a sample from a collection of 2D binary images. Parameters @@ -422,9 +443,18 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper'): left corner. The convention ``'upper'`` is typically used for matrices and images. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_heatmap( - Xt[sample] * 1, colorscale=colorscale, origin=origin) + plot_heatmap( + Xt[sample] * 1, colorscale=colorscale, origin=origin, + plotly_params=plotly_params + ).show() @adapt_fit_transform_docs @@ -531,7 +561,7 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0): + def plot(Xt, sample=0, plotly_params=None): """Plot a sample from a collection of point clouds. If the point cloud is in more than three dimensions, only the first three are plotted. @@ -544,5 +574,12 @@ def plot(Xt, sample=0): sample : int, optional, default: ``0`` Index of the sample in `Xt` to be plotted. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_point_cloud(Xt[sample]) + plot_point_cloud(Xt[sample], plotly_params=plotly_params).show() diff --git a/gtda/plotting/diagram_representations.py b/gtda/plotting/diagram_representations.py index fe25570b0..8d0db7f70 100644 --- a/gtda/plotting/diagram_representations.py +++ b/gtda/plotting/diagram_representations.py @@ -5,7 +5,8 @@ import plotly.graph_objs as gobj -def plot_betti_curves(betti_numbers, samplings, homology_dimensions=None): +def plot_betti_curves(betti_numbers, samplings, homology_dimensions=None, + plotly_params=None): """Plot Betti curves by homology dimension. Parameters @@ -23,6 +24,18 @@ def plot_betti_curves(betti_numbers, samplings, homology_dimensions=None): Which homology dimensions to include in the plot. If ``None``, all available homology dimensions will be used. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"traces"`` and ``"layout"``, and the corresponding values should be + dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Figure representing the Betti curves. + """ if homology_dimensions is None: _homology_dimensions = list(range(betti_numbers.shape[0])) @@ -67,11 +80,16 @@ def plot_betti_curves(betti_numbers, samplings, homology_dimensions=None): hoverinfo='none', name=f'H{int(dim)}')) - fig.show() + # Update trace and layout according to user input + if plotly_params: + fig.update_traces(plotly_params.get("traces", None)) + fig.update_layout(plotly_params.get("layout", None)) + + return fig def plot_betti_surfaces(betti_curves, samplings=None, - homology_dimensions=None): + homology_dimensions=None, plotly_params=None): """Plot Betti surfaces (Betti numbers against "time" and filtration parameter) by homology dimension. @@ -97,6 +115,22 @@ def plot_betti_surfaces(betti_curves, samplings=None, on the x-axis against the corresponding values in `betti_curves` on the y-axis. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"traces"`` and ``"layout"``, and the corresponding values should be + dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + + Returns + ------- + figs/fig : tuple of :class:`plotly.graph_objects.Figure`/\ + :class:`plotly.graph_objects.Figure` object + If ``n_samples > 1``, a tuple of figures representing the Betti + surfaces, with one figure per dimension in `homology_dimensions`. + Otherwise, a single figure representing the Betti curve of the + single sample present. + """ if homology_dimensions is None: _homology_dimensions = list(range(betti_curves.shape[1])) @@ -124,8 +158,11 @@ def plot_betti_surfaces(betti_curves, samplings=None, } } if betti_curves.shape[0] == 1: - plot_betti_curves(betti_curves[0], samplings, homology_dimensions) + return plot_betti_curves( + betti_curves[0], samplings, homology_dimensions, plotly_params + ) else: + figs = [] for dim in _homology_dimensions: fig = gobj.Figure() fig.update_layout(scene=scene, @@ -136,4 +173,11 @@ def plot_betti_surfaces(betti_curves, samplings=None, z=betti_curves[:, dim], connectgaps=True, hoverinfo='none')) - fig.show() + # Update trace and layout according to user input + if plotly_params: + fig.update_traces(plotly_params.get("traces", None)) + fig.update_layout(plotly_params.get("layout", None)) + + figs.append(fig) + + return tuple(figs) diff --git a/gtda/plotting/images.py b/gtda/plotting/images.py index be8d7e25c..eeef06b47 100644 --- a/gtda/plotting/images.py +++ b/gtda/plotting/images.py @@ -5,7 +5,7 @@ def plot_heatmap(data, x=None, y=None, colorscale='greys', origin='upper', - title=None): + title=None, plotly_params=None): """Plot a 2D single-channel image, as a heat map from 2D array data. Parameters @@ -31,6 +31,18 @@ def plot_heatmap(data, x=None, y=None, colorscale='greys', origin='upper', title : str or None, optional, default: ``None`` Title of the resulting figure. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should be + dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Figure representing the 2D single-channel image. + """ autorange = True if origin == 'lower' else 'reversed' layout = dict( @@ -44,4 +56,9 @@ def plot_heatmap(data, x=None, y=None, colorscale='greys', origin='upper', z=data, x=x, y=y, colorscale=colorscale )) - fig.show() + # Update trace and layout according to user input + if plotly_params: + fig.update_traces(plotly_params.get("trace", None)) + fig.update_layout(plotly_params.get("layout", None)) + + return fig diff --git a/gtda/plotting/persistence_diagrams.py b/gtda/plotting/persistence_diagrams.py index cc37fa34a..cf611be79 100644 --- a/gtda/plotting/persistence_diagrams.py +++ b/gtda/plotting/persistence_diagrams.py @@ -5,7 +5,7 @@ import plotly.graph_objs as gobj -def plot_diagram(diagram, homology_dimensions=None): +def plot_diagram(diagram, homology_dimensions=None, plotly_params=None): """Plot a single persistence diagram. Parameters @@ -19,8 +19,19 @@ def plot_diagram(diagram, homology_dimensions=None): Homology dimensions which will appear on the plot. If ``None``, all homology dimensions which appear in `diagram` will be plotted. - """ + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"traces"`` and ``"layout"``, and the corresponding values should be + dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Figure representing the persistence diagram. + """ # TODO: increase the marker size if homology_dimensions is None: homology_dimensions = np.unique(diagram[:, 2]) @@ -91,4 +102,9 @@ def plot_diagram(diagram, homology_dimensions=None): plot_bgcolor='white' ) - fig.show() + # Update trace and layout according to user input + if plotly_params: + fig.update_traces(plotly_params.get("traces", None)) + fig.update_layout(plotly_params.get("layout", None)) + + return fig diff --git a/gtda/plotting/point_clouds.py b/gtda/plotting/point_clouds.py index fef2010a2..374a7f38b 100644 --- a/gtda/plotting/point_clouds.py +++ b/gtda/plotting/point_clouds.py @@ -5,10 +5,10 @@ import plotly.graph_objs as gobj -def plot_point_cloud(point_cloud, dimension=None): +def plot_point_cloud(point_cloud, dimension=None, plotly_params=None): """Plot the first 2 or 3 coordinates of a point cloud. - This function will not work on 1D arrays. + Note: this function does not work on 1D arrays. Parameters ---------- @@ -20,6 +20,18 @@ def plot_point_cloud(point_cloud, dimension=None): Sets the dimension of the resulting plot. If ``None``, the dimension will be chosen between 2 and 3 depending on the shape of `point_cloud`. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should be + dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Figure representing a point cloud in 2D or 3D. + """ # TODO: increase the marker size if dimension is None: @@ -30,7 +42,9 @@ def plot_point_cloud(point_cloud, dimension=None): raise ValueError("Not enough dimensions available in the input point " "cloud.") - if dimension == 2: + if dimension not in [2, 3]: + raise ValueError("The value of the dimension is different from 2 or 3") + elif dimension == 2: layout = { "width": 800, "height": 800, @@ -73,7 +87,6 @@ def plot_point_cloud(point_cloud, dimension=None): point_cloud.shape[0])), colorscale='Viridis', opacity=0.8))) - fig.show() elif dimension == 3: scene = { "xaxis": { @@ -109,6 +122,9 @@ def plot_point_cloud(point_cloud, dimension=None): colorscale='Viridis', opacity=0.8))) - fig.show() - else: - raise ValueError("The value of the dimension is different from 2 or 3") + # Update trace and layout according to user input + if plotly_params: + fig.update_traces(plotly_params.get("trace", None)) + fig.update_layout(plotly_params.get("layout", None)) + + return fig diff --git a/gtda/point_clouds/rescaling.py b/gtda/point_clouds/rescaling.py index 7b6284f70..6e67789ed 100644 --- a/gtda/point_clouds/rescaling.py +++ b/gtda/point_clouds/rescaling.py @@ -192,7 +192,7 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, colorscale='blues'): + def plot(Xt, sample=0, colorscale='blues', plotly_params=None): """Plot a sample from a collection of distance matrices. Parameters @@ -208,8 +208,17 @@ def plot(Xt, sample=0, colorscale='blues'): Color scale to be used in the heat map. Can be anything allowed by :class:`plotly.graph_objects.Heatmap`. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_heatmap(Xt[sample], colorscale=colorscale) + plot_heatmap( + Xt[sample], colorscale=colorscale, plotly_params=plotly_params + ).show() @adapt_fit_transform_docs @@ -371,7 +380,7 @@ def transform(self, X, y=None): return Xt @staticmethod - def plot(Xt, sample=0, colorscale='blues'): + def plot(Xt, sample=0, colorscale='blues', plotly_params=None): """Plot a sample from a collection of distance matrices. Parameters @@ -387,5 +396,14 @@ def plot(Xt, sample=0, colorscale='blues'): Color scale to be used in the heat map. Can be anything allowed by :class:`plotly.graph_objects.Heatmap`. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_heatmap(Xt[sample], colorscale=colorscale) + plot_heatmap( + Xt[sample], colorscale=colorscale, plotly_params=plotly_params + ).show() diff --git a/gtda/time_series/embedding.py b/gtda/time_series/embedding.py index a909c7e71..ad7bc1304 100644 --- a/gtda/time_series/embedding.py +++ b/gtda/time_series/embedding.py @@ -173,7 +173,7 @@ def resample(self, y, X=None): return yr @staticmethod - def plot(Xt, sample=0): + def plot(Xt, sample=0, plotly_params=None): """Plot a sample from a collection of sliding windows, as a point cloud in 2D or 3D. If points in the window have more than three dimensions, only the first three are plotted. @@ -192,8 +192,15 @@ def plot(Xt, sample=0): sample : int, optional, default: ``0`` Index of the sample in `Xt` to be plotted. + plotly_params : dict or None, optional, default: ``None`` + Custom parameters to configure the plotly figure. Allowed keys are + ``"trace"`` and ``"layout"``, and the corresponding values should + be dictionaries containing keyword arguments as would be fed to the + :meth:`update_traces` and :meth:`update_layout` methods of + :class:`plotly.graph_objects.Figure`. + """ - return plot_point_cloud(Xt[sample]) + plot_point_cloud(Xt[sample], plotly_params=plotly_params).show() @adapt_fit_transform_docs From bc5fee8335e0f4e81766360bedc5cfebc1a7206a Mon Sep 17 00:00:00 2001 From: Umberto Lupo Date: Fri, 31 Jul 2020 20:46:35 +0200 Subject: [PATCH 2/5] Suppress user warnings on graph geodesic distance algorithms in tests --- gtda/graphs/tests/test_geodesic_distance.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/gtda/graphs/tests/test_geodesic_distance.py b/gtda/graphs/tests/test_geodesic_distance.py index 1530edb3e..9988267d0 100644 --- a/gtda/graphs/tests/test_geodesic_distance.py +++ b/gtda/graphs/tests/test_geodesic_distance.py @@ -1,5 +1,7 @@ """Testing for GraphGeodesicDistance.""" +import warnings + import numpy as np import plotly.io as pio import pytest @@ -85,15 +87,20 @@ def test_ggd_not_fitted(): def test_ggd_fit_transform_plot(): X = X_ggd[0][0] - GraphGeodesicDistance().fit_transform_plot(X, sample=0) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="Methods .*") + GraphGeodesicDistance().fit_transform_plot(X, sample=0) @pytest.mark.parametrize("X, X_res", X_ggd) @pytest.mark.parametrize("method", ["auto", "FW", "D", "J", "BF"]) def test_ggd_transform(X, X_res, method): - ggd = GraphGeodesicDistance(directed=False, method=method) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="Methods .*") + ggd = GraphGeodesicDistance(directed=False, method=method) + X_ft = ggd.fit_transform(X) - assert_almost_equal(ggd.fit_transform(X), X_res) + assert_almost_equal(X_ft, X_res) def test_parallel_ggd_transform(): @@ -101,4 +108,7 @@ def test_parallel_ggd_transform(): ggd = GraphGeodesicDistance(n_jobs=1) ggd_parallel = GraphGeodesicDistance(n_jobs=2) - assert_almost_equal(ggd.fit_transform(X), ggd_parallel.fit_transform(X)) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="Methods .*") + assert_almost_equal(ggd.fit_transform(X), + ggd_parallel.fit_transform(X)) From 3008f92191c89306ac4a0569dd2ccbec68a6282e Mon Sep 17 00:00:00 2001 From: Umberto Lupo Date: Fri, 31 Jul 2020 23:25:32 +0200 Subject: [PATCH 3/5] Resolve overflow warnings in mapper filter tests --- gtda/mapper/tests/test_filter.py | 119 +++++++++++++++---------------- 1 file changed, 58 insertions(+), 61 deletions(-) diff --git a/gtda/mapper/tests/test_filter.py b/gtda/mapper/tests/test_filter.py index 6c7815b90..8e2c20e00 100644 --- a/gtda/mapper/tests/test_filter.py +++ b/gtda/mapper/tests/test_filter.py @@ -1,24 +1,25 @@ +import warnings + import numpy as np from hypothesis import given from hypothesis.extra.numpy import array_shapes, arrays from hypothesis.strategies import integers, floats from numpy.testing import assert_almost_equal from scipy.spatial.distance import pdist, squareform +from sklearn.neighbors import KernelDensity from gtda.mapper import Eccentricity, Entropy, Projection from gtda.mapper.utils._list_feature_union import ListFeatureUnion from gtda.mapper.utils.decorators import method_to_transform -from sklearn.neighbors import KernelDensity - -@given( - X=arrays(dtype=np.float, - elements=floats(allow_nan=False, - allow_infinity=False), - shape=array_shapes(min_dims=2, max_dims=2)), - exponent=integers(min_value=1, max_value=100) -) +@given(X=arrays(dtype=np.float, + elements=floats(allow_nan=False, + allow_infinity=False, + min_value=-1e3, + max_value=1e3), + shape=array_shapes(min_dims=2, max_dims=2)), + exponent=integers(min_value=1, max_value=10)) def test_eccentricity_shape_equals_number_of_samples(X, exponent): """Verify that eccentricity preserves the nb of samples in the input.""" eccentricity = Eccentricity(exponent=exponent) @@ -28,7 +29,9 @@ def test_eccentricity_shape_equals_number_of_samples(X, exponent): @given(X=arrays(dtype=np.float, elements=floats(allow_nan=False, - allow_infinity=False), + allow_infinity=False, + min_value=-1e3, + max_value=1e3), shape=array_shapes(min_dims=2, max_dims=2))) def test_eccentricity_values_with_infinity_norm_equals_max_row_values(X): eccentricity = Eccentricity(exponent=np.inf) @@ -37,35 +40,33 @@ def test_eccentricity_values_with_infinity_norm_equals_max_row_values(X): assert_almost_equal(Xt, np.max(distance_matrix, axis=1).reshape(-1, 1)) -@given(X=arrays( - dtype=np.float, - elements=floats(allow_nan=False, - allow_infinity=False, - min_value=-1e3, - max_value=-1), - shape=array_shapes(min_dims=2, max_dims=2, min_side=2) -)) +@given(X=arrays(dtype=np.float, + elements=floats(allow_nan=False, + allow_infinity=False, + min_value=-1e3, + max_value=-1), + shape=array_shapes(min_dims=2, max_dims=2, min_side=2))) def test_entropy_values_for_negative_inputs(X): """Verify the numerical results of entropy (does it have the correct logic), on a collection of **negative** inputs.""" entropy = Entropy() - Xt = entropy.fit_transform(X) - probs = X / X.sum(axis=1, keepdims=True) - entropies = - np.einsum('ij,ij->i', probs, - np.where(probs != 0, np.log2(probs), 0)) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + Xt = entropy.fit_transform(X) + probs = X / X.sum(axis=1, keepdims=True) + entropies = - np.einsum('ij,ij->i', probs, + np.where(probs != 0, np.log2(probs), 0)) assert_almost_equal(Xt, entropies[:, None]) -@given(X=arrays( - dtype=np.float, - elements=floats(allow_nan=False, - allow_infinity=False, - min_value=1, - max_value=1e3), - shape=array_shapes(min_dims=2, max_dims=2, min_side=2) -)) +@given(X=arrays(dtype=np.float, + elements=floats(allow_nan=False, + allow_infinity=False, + min_value=1, + max_value=1e3), + shape=array_shapes(min_dims=2, max_dims=2, min_side=2))) def test_entropy_values_for_positive_inputs(X): - """Verify the numerical results of entropy (does it have the correct logic), + """Verify the numerical results of entropy (does it have the correct logic) on a collection of **positive** inputs.""" entropy = Entropy() Xt = entropy.fit_transform(X) @@ -77,25 +78,25 @@ def test_entropy_values_for_positive_inputs(X): @given(X=arrays(dtype=np.float, elements=floats(allow_nan=False, - allow_infinity=False), - shape=array_shapes(min_dims=2, max_dims=2))) + allow_infinity=False, + min_value=-1e3, + max_value=1e3), + shape=array_shapes(min_dims=2, max_dims=2, min_side=2))) def test_projection_values_equal_slice(X): """Test the logic of the ``Projection`` transformer.""" columns = np.random.choice( - X.shape[1], 1 + np.random.randint(X.shape[1])) + X.shape[1], 1 + np.random.randint(X.shape[1] - 1)) Xt = Projection(columns=columns).fit_transform(X) assert_almost_equal(Xt, X[:, columns]) -@given(X=arrays( - dtype=np.float, - elements=floats(allow_nan=False, - allow_infinity=False, - min_value=1, - max_value=1e3), - shape=array_shapes(min_dims=2, max_dims=2, min_side=2), - unique=True -)) +@given(X=arrays(dtype=np.float, + elements=floats(allow_nan=False, + allow_infinity=False, + min_value=1, + max_value=1e3), + shape=array_shapes(min_dims=2, max_dims=2, min_side=2), + unique=True)) def test_gaussian_density_values(X): """Check that ``fit_transform`` and ``fit + score_samples`` of ``KernelDensity`` are the same.""" @@ -107,15 +108,13 @@ def test_gaussian_density_values(X): assert_almost_equal(Xt_actual, Xt_desired) -@given(X=arrays( - dtype=np.float, - elements=floats(allow_nan=False, - allow_infinity=False, - min_value=1, - max_value=1e3), - shape=array_shapes(min_dims=2, max_dims=2, min_side=2), - unique=True -)) +@given(X=arrays(dtype=np.float, + elements=floats(allow_nan=False, + allow_infinity=False, + min_value=1, + max_value=1e3), + shape=array_shapes(min_dims=2, max_dims=2, min_side=2), + unique=True)) def test_list_feature_union_transform(X): """Check that a ``ListFeatureUnion`` of two projections gives the same result as stacking the projections.""" @@ -131,15 +130,13 @@ def test_list_feature_union_transform(X): assert_almost_equal(x_12, x_1_2) -@given(X=arrays( - dtype=np.float, - elements=floats(allow_nan=False, - allow_infinity=False, - min_value=1, - max_value=1e3), - shape=array_shapes(min_dims=2, max_dims=2, min_side=2), - unique=True -)) +@given(X=arrays(dtype=np.float, + elements=floats(allow_nan=False, + allow_infinity=False, + min_value=1, + max_value=1e3), + shape=array_shapes(min_dims=2, max_dims=2, min_side=2), + unique=True)) def test_list_feature_union_drops(X): """Check the the drop of ``ListFeatureUnion`` keeps the correct number of samples""" From 4b795cf86b198d42a1a672683c953fb9100485d8 Mon Sep 17 00:00:00 2001 From: Umberto Lupo Date: Fri, 31 Jul 2020 23:25:51 +0200 Subject: [PATCH 4/5] Resolve numpy DeprecationWarning in test_validation --- gtda/utils/tests/test_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtda/utils/tests/test_validation.py b/gtda/utils/tests/test_validation.py index fe4968642..67da87321 100644 --- a/gtda/utils/tests/test_validation.py +++ b/gtda/utils/tests/test_validation.py @@ -142,7 +142,7 @@ def test_check_point_clouds_value_err_finite(): # Check that we error on 1d array input with pytest.raises(ValueError): - check_point_clouds(np.asarray(ex.X_list_tot)) + check_point_clouds(np.asarray(ex.X_list_tot, dtype=object)) # Check that we error on 2d array input with pytest.raises(ValueError): From 5be8bf8ce4cb41dc781dd402fcab3e1c9e33979f Mon Sep 17 00:00:00 2001 From: Umberto Lupo Date: Sat, 1 Aug 2020 00:39:30 +0200 Subject: [PATCH 5/5] Make `plot` class methods return figures, call .show() only in PlotterMixin.transform_plot --- gtda/base.py | 2 +- gtda/diagrams/preprocessing.py | 27 ++++++++++++++----- gtda/diagrams/representations.py | 39 ++++++++++++++++++++++----- gtda/graphs/geodesic_distance.py | 9 +++++-- gtda/homology/cubical.py | 9 +++++-- gtda/homology/simplicial.py | 36 +++++++++++++++++++------ gtda/images/filtrations.py | 45 +++++++++++++++++++++++++------- gtda/images/preprocessing.py | 32 ++++++++++++++++++----- gtda/point_clouds/rescaling.py | 18 ++++++++++--- gtda/time_series/embedding.py | 7 ++++- 10 files changed, 177 insertions(+), 47 deletions(-) diff --git a/gtda/base.py b/gtda/base.py index edfd480d0..6e65101ae 100644 --- a/gtda/base.py +++ b/gtda/base.py @@ -140,6 +140,6 @@ def transform_plot(self, X, sample=0, **plot_params): """ Xt = self.transform(X[sample:sample+1]) - self.plot(Xt, sample=0, **plot_params) + self.plot(Xt, sample=0, **plot_params).show() return Xt diff --git a/gtda/diagrams/preprocessing.py b/gtda/diagrams/preprocessing.py index 196dc1609..8f8119ee1 100644 --- a/gtda/diagrams/preprocessing.py +++ b/gtda/diagrams/preprocessing.py @@ -108,11 +108,16 @@ def plot(Xt, sample=0, plotly_params=None): :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_diagram( + return plot_diagram( Xt[sample], homology_dimensions=[np.inf], plotly_params=plotly_params - ).show() + ) @adapt_fit_transform_docs @@ -315,16 +320,21 @@ def plot(self, Xt, sample=0, homology_dimensions=None, plotly_params=None): :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ if homology_dimensions is None: _homology_dimensions = self.homology_dimensions_ else: _homology_dimensions = homology_dimensions - plot_diagram( + return plot_diagram( Xt[sample], homology_dimensions=_homology_dimensions, plotly_params=plotly_params - ).show() + ) @adapt_fit_transform_docs @@ -461,13 +471,18 @@ def plot(self, Xt, sample=0, homology_dimensions=None, plotly_params=None): :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ if homology_dimensions is None: _homology_dimensions = self.homology_dimensions_ else: _homology_dimensions = homology_dimensions - plot_diagram( + return plot_diagram( Xt[sample], homology_dimensions=_homology_dimensions, plotly_params=plotly_params - ).show() + ) diff --git a/gtda/diagrams/representations.py b/gtda/diagrams/representations.py index fbc2ced05..4bc0d6598 100644 --- a/gtda/diagrams/representations.py +++ b/gtda/diagrams/representations.py @@ -168,6 +168,11 @@ def plot(self, Xt, sample=0, homology_dimensions=None): Which homology dimensions to include in the plot. ``None`` means plotting all dimensions present in :attr:`homology_dimensions_`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ check_is_fitted(self) @@ -224,7 +229,7 @@ def plot(self, Xt, sample=0, homology_dimensions=None): mode='lines', showlegend=True, name=f"H{int(dim)}")) - fig.show() + return fig @adapt_fit_transform_docs @@ -383,6 +388,11 @@ def plot(self, Xt, sample=0, homology_dimensions=None): ``None`` means plotting all dimensions present in :attr:`homology_dimensions_`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ check_is_fitted(self) @@ -445,7 +455,7 @@ def plot(self, Xt, sample=0, homology_dimensions=None): hoverinfo='none', name=f"Layer {layer + 1}")) - fig.show() + return fig @adapt_fit_transform_docs @@ -627,14 +637,19 @@ def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale='blues', :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ check_is_fitted(self) - plot_heatmap( + return plot_heatmap( Xt[sample][homology_dimension_ix], x=self.samplings_[homology_dimension_ix], y=self.samplings_[homology_dimension_ix], colorscale=colorscale, plotly_params=plotly_params - ).show() + ) @adapt_fit_transform_docs @@ -847,13 +862,18 @@ def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale='blues', :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ check_is_fitted(self) samplings_x, samplings_y = self.samplings_[homology_dimension_ix] - plot_heatmap( + return plot_heatmap( Xt[sample][homology_dimension_ix], x=samplings_x, y=samplings_y, colorscale=colorscale, plotly_params=plotly_params - ).show() + ) @adapt_fit_transform_docs @@ -1024,6 +1044,11 @@ def plot(self, Xt, sample=0, homology_dimensions=None, plotly_params=None): :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ check_is_fitted(self) @@ -1085,4 +1110,4 @@ def plot(self, Xt, sample=0, homology_dimensions=None, plotly_params=None): fig.update_traces(plotly_params.get("traces", None)) fig.update_layout(plotly_params.get("layout", None)) - fig.show() + return fig diff --git a/gtda/graphs/geodesic_distance.py b/gtda/graphs/geodesic_distance.py index 77f1f5394..c006eaa15 100644 --- a/gtda/graphs/geodesic_distance.py +++ b/gtda/graphs/geodesic_distance.py @@ -204,7 +204,12 @@ def plot(Xt, sample=0, colorscale='blues', plotly_params=None): :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_heatmap( + return plot_heatmap( Xt[sample], colorscale=colorscale, plotly_params=plotly_params - ).show() + ) diff --git a/gtda/homology/cubical.py b/gtda/homology/cubical.py index c2b06bc39..baa505430 100644 --- a/gtda/homology/cubical.py +++ b/gtda/homology/cubical.py @@ -246,8 +246,13 @@ def plot(Xt, sample=0, homology_dimensions=None, plotly_params=None): :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_diagram( + return plot_diagram( Xt[sample], homology_dimensions=homology_dimensions, plotly_params=plotly_params - ).show() + ) diff --git a/gtda/homology/simplicial.py b/gtda/homology/simplicial.py index ee5deaf08..472bc228f 100644 --- a/gtda/homology/simplicial.py +++ b/gtda/homology/simplicial.py @@ -257,11 +257,16 @@ def plot(Xt, sample=0, homology_dimensions=None, plotly_params=None): :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_diagram( + return plot_diagram( Xt[sample], homology_dimensions=homology_dimensions, plotly_params=plotly_params - ).show() + ) @adapt_fit_transform_docs @@ -513,11 +518,16 @@ def plot(Xt, sample=0, homology_dimensions=None, plotly_params=None): :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_diagram( + return plot_diagram( Xt[sample], homology_dimensions=homology_dimensions, plotly_params=plotly_params - ).show() + ) @adapt_fit_transform_docs @@ -725,11 +735,16 @@ def plot(Xt, sample=0, homology_dimensions=None, plotly_params=None): :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_diagram( + return plot_diagram( Xt[sample], homology_dimensions=homology_dimensions, plotly_params=plotly_params - ).show() + ) @adapt_fit_transform_docs @@ -998,8 +1013,13 @@ def plot(Xt, sample=0, homology_dimensions=None, plotly_params=None): :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_diagram( + return plot_diagram( Xt[sample], homology_dimensions=homology_dimensions, plotly_params=plotly_params - ).show() + ) diff --git a/gtda/images/filtrations.py b/gtda/images/filtrations.py index 663cb14a9..ba975e593 100644 --- a/gtda/images/filtrations.py +++ b/gtda/images/filtrations.py @@ -209,11 +209,16 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper', :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_heatmap( + return plot_heatmap( Xt[sample], colorscale=colorscale, origin=origin, plotly_params=plotly_params - ).show() + ) @adapt_fit_transform_docs @@ -446,11 +451,16 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper', :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_heatmap( + return plot_heatmap( Xt[sample], colorscale=colorscale, origin=origin, plotly_params=plotly_params - ).show() + ) @adapt_fit_transform_docs @@ -630,11 +640,16 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper', :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_heatmap( + return plot_heatmap( Xt[sample], colorscale=colorscale, origin=origin, plotly_params=plotly_params - ).show() + ) @adapt_fit_transform_docs @@ -813,11 +828,16 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper', :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_heatmap( + return plot_heatmap( Xt[sample], colorscale=colorscale, origin=origin, plotly_params=plotly_params - ).show() + ) @adapt_fit_transform_docs @@ -1006,8 +1026,13 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper', :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_heatmap( + return plot_heatmap( Xt[sample], colorscale=colorscale, origin=origin, plotly_params=plotly_params - ).show() + ) diff --git a/gtda/images/preprocessing.py b/gtda/images/preprocessing.py index c8b53a6c4..a31775a5b 100644 --- a/gtda/images/preprocessing.py +++ b/gtda/images/preprocessing.py @@ -170,8 +170,13 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper', :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_heatmap( + return plot_heatmap( Xt[sample] * 1, colorscale=colorscale, origin=origin, plotly_params=plotly_params ) @@ -288,11 +293,16 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper', :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_heatmap( + return plot_heatmap( Xt[sample] * 1, colorscale=colorscale, origin=origin, plotly_params=plotly_params - ).show() + ) @adapt_fit_transform_docs @@ -450,11 +460,16 @@ def plot(Xt, sample=0, colorscale='greys', origin='upper', :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_heatmap( + return plot_heatmap( Xt[sample] * 1, colorscale=colorscale, origin=origin, plotly_params=plotly_params - ).show() + ) @adapt_fit_transform_docs @@ -581,5 +596,10 @@ def plot(Xt, sample=0, plotly_params=None): :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_point_cloud(Xt[sample], plotly_params=plotly_params).show() + return plot_point_cloud(Xt[sample], plotly_params=plotly_params) diff --git a/gtda/point_clouds/rescaling.py b/gtda/point_clouds/rescaling.py index 6e67789ed..18c418ac9 100644 --- a/gtda/point_clouds/rescaling.py +++ b/gtda/point_clouds/rescaling.py @@ -215,10 +215,15 @@ def plot(Xt, sample=0, colorscale='blues', plotly_params=None): :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_heatmap( + return plot_heatmap( Xt[sample], colorscale=colorscale, plotly_params=plotly_params - ).show() + ) @adapt_fit_transform_docs @@ -403,7 +408,12 @@ def plot(Xt, sample=0, colorscale='blues', plotly_params=None): :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_heatmap( + return plot_heatmap( Xt[sample], colorscale=colorscale, plotly_params=plotly_params - ).show() + ) diff --git a/gtda/time_series/embedding.py b/gtda/time_series/embedding.py index ad7bc1304..978f8fd6c 100644 --- a/gtda/time_series/embedding.py +++ b/gtda/time_series/embedding.py @@ -199,8 +199,13 @@ def plot(Xt, sample=0, plotly_params=None): :meth:`update_traces` and :meth:`update_layout` methods of :class:`plotly.graph_objects.Figure`. + Returns + ------- + fig : :class:`plotly.graph_objects.Figure` object + Plotly figure. + """ - plot_point_cloud(Xt[sample], plotly_params=plotly_params).show() + return plot_point_cloud(Xt[sample], plotly_params=plotly_params) @adapt_fit_transform_docs