-
Notifications
You must be signed in to change notification settings - Fork 177
/
Copy pathpersistence_diagrams.py
86 lines (72 loc) · 2.96 KB
/
persistence_diagrams.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Persistent-homology–related plotting functions and classes."""
# License: GNU AGPLv3
import numpy as np
import plotly.graph_objs as gobj
def plot_diagram(diagram, homology_dimensions=None, **input_layout):
"""Plot a single persistence diagram.
Parameters
----------
diagram : ndarray of shape (n_points, 3)
The persistence diagram to plot, where the third dimension along axis 1
contains homology dimensions, and the first two contain (birth, death)
pairs to be used as coordinates in the two-dimensional plot.
homology_dimensions : list of int or None, optional, default: ``None``
Homology dimensions which will appear on the plot. If ``None``, all
homology dimensions which appear in `diagram` will be plotted.
"""
from ..diagrams._utils import _subdiagrams # To avoid circular imports
# TODO: increase the marker size
if homology_dimensions is None:
homology_dimensions = np.unique(diagram[:, 2])
max_filt_param = np.where(np.isinf(diagram), -np.inf, diagram).max()
layout = dict(
width=500,
height=500,
xaxis1=dict(
title='Birth',
side='bottom',
type='linear',
range=[0, 1.1 * max_filt_param],
ticks='outside',
anchor='y1',
showline=True,
zeroline=True,
showexponent='all',
exponentformat='e'
),
yaxis1=dict(
title='Death',
side='left',
type='linear',
range=[0, 1.1 * max_filt_param],
ticks='outside',
anchor='x1',
showline=True,
zeroline=True,
showexponent='all',
exponentformat='e'
),
plot_bgcolor='white'
)
layout.update(input_layout)
fig = gobj.Figure(layout=layout)
fig.update_xaxes(zeroline=True, linewidth=1, linecolor='black',
mirror=False)
fig.update_yaxes(zeroline=True, linewidth=1, linecolor='black',
mirror=False)
fig.add_trace(gobj.Scatter(x=np.array([-100 * max_filt_param,
100 * max_filt_param]),
y=np.array([-100 * max_filt_param,
100 * max_filt_param]),
mode='lines',
line=dict(dash='dash', width=1, color='black'),
showlegend=False, hoverinfo='none'))
for dim in homology_dimensions:
name = f'H{int(dim)}' if dim != np.inf else 'Any homology dimension'
subdiagram = _subdiagrams(np.asarray([diagram]), [dim],
remove_dim=True)[0]
diff = (subdiagram[:, 1] != subdiagram[:, 0])
subdiagram = subdiagram[diff]
fig.add_trace(gobj.Scatter(x=subdiagram[:, 0], y=subdiagram[:, 1],
mode='markers', name=name))
fig.show()