diff --git a/gtda/diagrams/representations.py b/gtda/diagrams/representations.py index 0149220e6..95855dfdc 100644 --- a/gtda/diagrams/representations.py +++ b/gtda/diagrams/representations.py @@ -206,43 +206,41 @@ def plot(self, Xt, sample=0, homology_dimensions=None, plotly_params=None): ix = np.flatnonzero(homology_dimensions_arr == dim)[0] _homology_dimensions.append((ix, dim)) + layout_axes_common = { + "type": "linear", + "ticks": "outside", + "showline": True, + "zeroline": True, + "linewidth": 1, + "linecolor": "black", + "mirror": False, + "showexponent": "all", + "exponentformat": "e" + } layout = { "xaxis1": { "title": "Filtration parameter", "side": "bottom", - "type": "linear", - "ticks": "outside", - "anchor": "x1", - "showline": True, - "zeroline": True, - "showexponent": "all", - "exponentformat": "e" + "anchor": "y1", + **layout_axes_common }, "yaxis1": { "title": "Betti number", "side": "left", - "type": "linear", - "ticks": "outside", - "anchor": "y1", - "showline": True, - "zeroline": True, - "showexponent": "all", - "exponentformat": "e" + "anchor": "x1", + **layout_axes_common }, "plot_bgcolor": "white", "title": f"Betti curves from diagram {sample}" } fig = Figure(layout=layout) - fig.update_xaxes(zeroline=True, linewidth=1, linecolor="black", - mirror=False) - fig.update_yaxes(zeroline=True, linewidth=1, linecolor="black", - mirror=False) for ix, dim in _homology_dimensions: fig.add_trace(Scatter(x=self.samplings_[dim], y=Xt[sample][ix], - mode="lines", showlegend=True, + mode="lines", + showlegend=True, name=f"H{int(dim)}")) # Update traces and layout according to user input @@ -452,7 +450,7 @@ def plot(self, Xt, sample=0, homology_dimensions=None, plotly_params=None): inv_idx = np.flatnonzero(homology_dimensions_arr == dim)[0] _homology_dimensions.append((inv_idx, dim)) - layout_axes_comm = { + layout_axes_common = { "type": "linear", "ticks": "outside", "showline": True, @@ -467,12 +465,12 @@ def plot(self, Xt, sample=0, homology_dimensions=None, plotly_params=None): "xaxis1": { "side": "bottom", "anchor": "y1", - **layout_axes_comm + **layout_axes_common }, "yaxis1": { "side": "left", "anchor": "x1", - **layout_axes_comm + **layout_axes_common }, "plot_bgcolor": "white", } @@ -1172,37 +1170,34 @@ def plot(self, Xt, sample=0, homology_dimensions=None, plotly_params=None): ix = np.flatnonzero(homology_dimensions_arr == dim)[0] _homology_dimensions.append((ix, dim)) + layout_axes_common = { + "type": "linear", + "ticks": "outside", + "showline": True, + "zeroline": True, + "linewidth": 1, + "linecolor": "black", + "mirror": False, + "showexponent": "all", + "exponentformat": "e" + } layout = { "xaxis1": { "title": "Filtration parameter", "side": "bottom", - "type": "linear", - "ticks": "outside", - "anchor": "x1", - "showline": True, - "zeroline": True, - "showexponent": "all", - "exponentformat": "e" + "anchor": "y1", + **layout_axes_common }, "yaxis1": { "side": "left", - "type": "linear", - "ticks": "outside", - "anchor": "y1", - "showline": True, - "zeroline": True, - "showexponent": "all", - "exponentformat": "e" + "anchor": "x1", + **layout_axes_common }, "plot_bgcolor": "white", "title": f"Silhouette representation of diagram {sample}" } fig = Figure(layout=layout) - fig.update_xaxes(zeroline=True, linewidth=1, linecolor="black", - mirror=False) - fig.update_yaxes(zeroline=True, linewidth=1, linecolor="black", - mirror=False) for ix, dim in _homology_dimensions: fig.add_trace(Scatter(x=self.samplings_[dim],