Skip to content

Commit

Permalink
Fix bug in use of homology_dimension_ix with samplings_, rename homol…
Browse files Browse the repository at this point in the history
…ogy_dimension_ix (#452)

* Rename homology_dimension_ix to homology_dimension_idx, fix bug
  • Loading branch information
ulupo authored Aug 10, 2020
1 parent f65f899 commit e0ceda5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
34 changes: 17 additions & 17 deletions gtda/diagrams/representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ def transform(self, X, y=None):
transpose((1, 0, 2, 3))
return Xt

def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale="blues",
def plot(self, Xt, sample=0, homology_dimension_idx=0, colorscale="blues",
plotly_params=None):
"""Plot a single channel – corresponding to a given homology
dimension – in a sample from a collection of heat kernel images.
Expand All @@ -672,11 +672,11 @@ def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale="blues",
sample : int, optional, default: ``0``
Index of the sample in `Xt` to be selected.
homology_dimension_ix : int, optional, default: ``0``
Index of the channel in the selected sample to be plotted. If
`Xt` is the result of a call to :meth:`transform` and this
index is i, the plot corresponds to the homology dimension given by
the i-th entry in :attr:`homology_dimensions_`.
homology_dimension_idx : int, optional, default: ``0``
Index of the channel in the selected sample to be plotted. If `Xt`
is the result of a call to :meth:`transform` and this index is i,
the plot corresponds to the homology dimension given by the i-th
entry in :attr:`homology_dimensions_`.
colorscale : str, optional, default: ``"blues"``
Color scale to be used in the heat map. Can be anything allowed by
Expand All @@ -696,10 +696,9 @@ def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale="blues",
"""
check_is_fitted(self)
x = self.samplings_[self.homology_dimensions_[homology_dimension_idx]]
return plot_heatmap(
Xt[sample][homology_dimension_ix],
x=self.samplings_[homology_dimension_ix],
y=self.samplings_[homology_dimension_ix],
Xt[sample][homology_dimension_idx], x=x, y=x,
colorscale=colorscale, plotly_params=plotly_params
)

Expand Down Expand Up @@ -891,7 +890,7 @@ def transform(self, X, y=None):
transpose((1, 0, 2, 3))
return Xt

def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale="blues",
def plot(self, Xt, sample=0, homology_dimension_idx=0, colorscale="blues",
plotly_params=None):
"""Plot a single channel – corresponding to a given homology
dimension – in a sample from a collection of persistence images.
Expand All @@ -906,11 +905,11 @@ def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale="blues",
sample : int, optional, default: ``0``
Index of the sample in `Xt` to be selected.
homology_dimension_ix : int, optional, default: ``0``
Index of the channel in the selected sample to be plotted. If
`Xt` is the result of a call to :meth:`transform` and this
index is i, the plot corresponds to the homology dimension given by
the i-th entry in :attr:`homology_dimensions_`.
homology_dimension_idx : int, optional, default: ``0``
Index of the channel in the selected sample to be plotted. If `Xt`
is the result of a call to :meth:`transform` and this index is i,
the plot corresponds to the homology dimension given by the i-th
entry in :attr:`homology_dimensions_`.
colorscale : str, optional, default: ``"blues"``
Color scale to be used in the heat map. Can be anything allowed by
Expand All @@ -930,9 +929,10 @@ def plot(self, Xt, sample=0, homology_dimension_ix=0, colorscale="blues",
"""
check_is_fitted(self)
samplings_x, samplings_y = self.samplings_[homology_dimension_ix]
samplings_x, samplings_y = \
self.samplings_[self.homology_dimensions_[homology_dimension_idx]]
return plot_heatmap(
Xt[sample][homology_dimension_ix], x=samplings_x, y=samplings_y,
Xt[sample][homology_dimension_idx], x=samplings_x, y=samplings_y,
colorscale=colorscale, plotly_params=plotly_params
)

Expand Down
8 changes: 4 additions & 4 deletions gtda/diagrams/tests/test_features_representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def test_not_fitted():
Silhouette().transform(X)


@pytest.mark.parametrize('hom_dim_ix', [0, 1])
def test_fit_transform_plot_one_hom_dim(hom_dim_ix):
@pytest.mark.parametrize('hom_dim_idx', [0, 1])
def test_fit_transform_plot_one_hom_dim(hom_dim_idx):
HeatKernel().fit_transform_plot(
X, sample=0, homology_dimension_ix=hom_dim_ix)
X, sample=0, homology_dimension_idx=hom_dim_idx)
PersistenceImage().fit_transform_plot(
X, sample=0, homology_dimension_ix=hom_dim_ix)
X, sample=0, homology_dimension_idx=hom_dim_idx)


@pytest.mark.parametrize('hom_dims', [None, (0,), (1,), (0, 1)])
Expand Down

0 comments on commit e0ceda5

Please sign in to comment.