Skip to content

Commit

Permalink
Inherited from plotter instead of wrapping it (#138)
Browse files Browse the repository at this point in the history
* Inherited from plotter instead of wrapping it

* changelog

* precommit

* revert

* add ignore

* nitpick

* remove init and see if its ok

* bump min vista
  • Loading branch information
nabobalis authored Jul 9, 2024
1 parent 2f7ad58 commit 9d29fe2
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 59 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@ jobs:
core:
uses: OpenAstronomy/github-actions-workflows/.github/workflows/tox.yml@main
with:
display: true
submodules: false
coverage: codecov
display: true
envs: |
- linux: py312
test:
needs: [core]
uses: OpenAstronomy/github-actions-workflows/.github/workflows/tox.yml@main
with:
display: true
submodules: false
coverage: codecov
display: true
envs: |
- macos: py311
- windows: py310
Expand All @@ -44,9 +44,9 @@ jobs:
needs: [core]
uses: OpenAstronomy/github-actions-workflows/.github/workflows/tox.yml@main
with:
display: true
default_python: "3.12"
pytest: false
display: true
libraries: |
apt:
- graphviz
Expand Down
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,6 @@ repos:
rev: v4.0.0-alpha.8
hooks:
- id: prettier
ci:
autofix_prs: false
autoupdate_schedule: "quarterly"
2 changes: 2 additions & 0 deletions changelog/138.breaking.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Instead of wrapping `pyvsita.Plotter`, we now inherit from it.
This allows us to drop the ``plotter.plotter`` lines in examples and user facing API.
3 changes: 3 additions & 0 deletions docs/nitpick-exceptions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ py:class array-like
py:obj parfive
py:class string
py:class floats
py:class pyvista.plotting.plotter.BasePlotter
py:class pyvista.plotting.picking.PickingMethods
py:class pyvista.plotting.picking.PickingInterface
4 changes: 2 additions & 2 deletions examples/floating_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
# sphinx_gallery_defer_figures

cloud = plotter.coordinates_to_polydata(point_cloud)
_ = plotter.plotter.add_points(cloud, point_size=0.7, color="cyan", style="points_gaussian")
_ = plotter.add_points(cloud, point_size=0.7, color="cyan", style="points_gaussian")

################################################################################
# Next we want to build a surface from these points.
Expand All @@ -75,7 +75,7 @@
# sphinx_gallery_defer_figures

surf = cloud.delaunay_3d()
_ = plotter.plotter.add_mesh(surf)
_ = plotter.add_mesh(surf)

################################################################################
# Finally set up the camera position and focus.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Visualization",
]
dependencies = [
'pyvista[all]>= 0.43.9',
'pyvista[all]>= 0.44.0',
'sunpy[map]>=5.0.0',
]

Expand Down
7 changes: 7 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ minversion = 7.0
testpaths =
sunkit_pyvista
docs
norecursedirs =
docs/_build
docs/generated
doctest_plus = enabled
doctest_optionflags = NORMALIZE_WHITESPACE FLOAT_CMP ELLIPSIS
text_file_format = rst
Expand All @@ -20,3 +23,7 @@ filterwarnings =
ignore:unclosed transport
ignore:The loop argument is deprecated
ignore:Event loop is closed
# Comes from sunkit-magex
ignore:`row_stack` alias is deprecated. Use `np.vstack` directly.
# Need to fix when sunpy 6.0 is out
ignore:The assume_spherical_screen function is deprecated and may be removed in a future version
44 changes: 14 additions & 30 deletions sunkit_pyvista/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
__all__ = ["SunpyPlotter"]


class SunpyPlotter:
class SunpyPlotter(pv.Plotter):
"""
A plotter for 3D data.
This class wraps `pyvsita.Plotter` to provide coordinate-aware plotting.
This class inherits `pyvsita.Plotter` so we can provide coordinate-aware plotting.
For now, all coordinates are converted to
a specific frame (`~sunpy.coordinates.HeliocentricInertial` by default),
and distance units are such that :math:`R_{sun} = 1`.
Expand All @@ -45,15 +45,14 @@ class SunpyPlotter:
Stores a reference to all the plotted meshes in a dictionary.
"""

def __init__(self, *, coordinate_frame=None, obstime=None, **kwargs):
def __init__(self, *args, coordinate_frame=None, obstime=None, **kwargs):
super().__init__(*args, **kwargs)
if coordinate_frame is not None and obstime is not None:
msg = "Only coordinate_frame or obstime can be specified, not both."
raise ValueError(msg)
if coordinate_frame is None:
coordinate_frame = HeliocentricInertial(obstime=obstime)
self._coordinate_frame = coordinate_frame
self._plotter = pv.Plotter(**kwargs)
self.camera = self._plotter.camera
self.all_meshes = {}

@property
Expand All @@ -63,21 +62,6 @@ def coordinate_frame(self):
"""
return self._coordinate_frame

@property
def plotter(self):
"""
`pyvista.Plotter`.
"""
return self._plotter

def show(self, *args, **kwargs):
"""
Show the plot.
See `pyvista.Plotter.show` for accepted arguments.
"""
self.plotter.show(*args, **kwargs)

def _extract_color(self, mesh_kwargs):
"""
Converts a given color string to it's equivalent rgb tuple.
Expand Down Expand Up @@ -169,7 +153,7 @@ def set_camera_coordinate(self, coord):
"""
camera_position = self._coords_to_xyz(coord)
pos = tuple(camera_position)
self.plotter.camera.position = pos
self.camera.position = pos

def set_camera_focus(self, coord):
"""
Expand All @@ -182,7 +166,7 @@ def set_camera_focus(self, coord):
"""
camera_position = self._coords_to_xyz(coord)
pos = tuple(camera_position)
self.plotter.set_focus(pos)
self.set_focus(pos)

@u.quantity_input
def set_view_angle(self, angle: u.deg):
Expand All @@ -199,7 +183,7 @@ def set_view_angle(self, angle: u.deg):
msg = "specified view angle must be " "0 deg < angle <= 180 deg"
raise ValueError(msg)
zoom_value = self.camera.view_angle / view_angle
self.plotter.camera.zoom(zoom_value)
self.camera.zoom(zoom_value)

def _map_to_mesh(self, m, *, assume_spherical=True):
"""
Expand Down Expand Up @@ -312,7 +296,7 @@ def plot_map(
clim = [0, 1]
cmap = self._get_cmap(kwargs, m)
kwargs.setdefault("show_scalar_bar", False)
self.plotter.add_mesh(map_mesh, cmap=cmap, clim=clim, **kwargs)
self.add_mesh(map_mesh, cmap=cmap, clim=clim, **kwargs)
map_mesh.add_field_data([cmap], "cmap")
self._add_mesh_to_dict(block_name="maps", mesh=map_mesh)

Expand Down Expand Up @@ -360,7 +344,7 @@ def plot_coordinates(self, coords, *, radius=0.05, **kwargs):
point_mesh.add_field_data(color, "color")

kwargs["render_lines_as_tubes"] = kwargs.pop("render_lines_as_tubes", True)
self.plotter.add_mesh(point_mesh, color=color, smooth_shading=True, **kwargs)
self.add_mesh(point_mesh, color=color, smooth_shading=True, **kwargs)
self._add_mesh_to_dict(block_name="coordinates", mesh=point_mesh)

def plot_solar_axis(self, *, length=2.5, arrow_kwargs=None, **kwargs):
Expand Down Expand Up @@ -390,7 +374,7 @@ def plot_solar_axis(self, *, length=2.5, arrow_kwargs=None, **kwargs):
)
color = self._extract_color(kwargs)
arrow_mesh.add_field_data(color, "color")
self.plotter.add_mesh(arrow_mesh, color=color, **kwargs)
self.add_mesh(arrow_mesh, color=color, **kwargs)
self._add_mesh_to_dict(block_name="solar_axis", mesh=arrow_mesh)

def plot_quadrangle(
Expand Down Expand Up @@ -455,7 +439,7 @@ def plot_quadrangle(
quad_block = quad_block.tube(radius=radius)
color = self._extract_color(kwargs)
quad_block.add_field_data(color, "color")
self.plotter.add_mesh(quad_block, color=color, **kwargs)
self.add_mesh(quad_block, color=color, **kwargs)
self._add_mesh_to_dict(block_name="quadrangles", mesh=quad_block)

def plot_field_lines(self, field_lines, *, color_func=None, **kwargs):
Expand Down Expand Up @@ -504,7 +488,7 @@ def color_func(field_line):

kwargs["render_lines_as_tubes"] = kwargs.pop("render_lines_as_tubes", True)
kwargs["line_width"] = kwargs.pop("line_width", 2)
self.plotter.add_mesh(spline, color=color, **kwargs)
self.add_mesh(spline, color=color, **kwargs)
field_line_meshes.append(spline)

self._add_mesh_to_dict(block_name="field_lines", mesh=spline)
Expand Down Expand Up @@ -557,7 +541,7 @@ def _loop_through_meshes(self, mesh_block):
else:
color = dict(block.field_data).get("color", None)
cmap = dict(block.field_data).get("cmap", [None])[0]
self.plotter.add_mesh(block, color=color, cmap=cmap)
self.add_mesh(block, color=color, cmap=cmap)

def load(self, filepath):
"""
Expand Down Expand Up @@ -596,5 +580,5 @@ def plot_limb(self, m, *, radius=0.02, **kwargs):
color = self._extract_color(mesh_kwargs=kwargs)
limb_block = limb_block.tube(radius=radius)
limb_block.add_field_data(color, "color")
self.plotter.add_mesh(limb_block, color=color, **kwargs)
self.add_mesh(limb_block, color=color, **kwargs)
self._add_mesh_to_dict(block_name="limbs", mesh=limb_block)
46 changes: 23 additions & 23 deletions sunkit_pyvista/tests/test_pyvista.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@pytest.mark.display_server()
def test_basic(plotter):
assert isinstance(plotter.plotter, pv.Plotter)
assert isinstance(plotter, pv.Plotter)
plotter.show()


Expand Down Expand Up @@ -44,14 +44,14 @@ def test_set_view_angle(plotter):

def test_plot_map(aia171_test_map, plotter):
plotter.plot_map(aia171_test_map)
assert plotter.plotter.mesh.n_cells == 16384
assert plotter.plotter.mesh.n_points == 16641
assert plotter.mesh.n_cells == 16384
assert plotter.mesh.n_points == 16641


def test_plot_solar_axis(plotter):
plotter.plot_solar_axis()
assert plotter.plotter.mesh.n_cells == 43
assert plotter.plotter.mesh.n_points == 101
assert plotter.mesh.n_cells == 43
assert plotter.mesh.n_points == 101


def test_plot_quadrangle(aia171_test_map, plotter):
Expand All @@ -67,8 +67,8 @@ def test_plot_quadrangle(aia171_test_map, plotter):
height=60 * u.deg,
color="blue",
)
assert plotter.plotter.mesh.n_cells == 22
assert plotter.plotter.mesh.n_points == 80060
assert plotter.mesh.n_cells == 22
assert plotter.mesh.n_points == 80060


def test_plot_coordinates(aia171_test_map, plotter):
Expand All @@ -80,8 +80,8 @@ def test_plot_coordinates(aia171_test_map, plotter):
frame="heliocentricinertial",
)
plotter.plot_coordinates(line)
assert plotter.plotter.mesh.n_cells == 1
assert plotter.plotter.mesh.n_points == 3
assert plotter.mesh.n_cells == 1
assert plotter.mesh.n_points == 3

# Tests plotting of a small sphere
sphere = SkyCoord(
Expand All @@ -91,30 +91,30 @@ def test_plot_coordinates(aia171_test_map, plotter):
frame="heliocentricinertial",
)
plotter.plot_coordinates(sphere)
assert plotter.plotter.mesh.n_cells == 1680
assert plotter.plotter.mesh.n_points == 842
assert plotter.mesh.n_cells == 1680
assert plotter.mesh.n_points == 842
expected_center = [-0.5000000149011612, -0.5, 0.7071067690849304]
assert np.allclose(plotter.plotter.mesh.center, expected_center)
assert np.allclose(plotter.mesh.center, expected_center)

pixel_pos = np.argwhere(aia171_test_map.data == aia171_test_map.data.max()) * u.pixel
hpc_max = aia171_test_map.pixel_to_world(pixel_pos[:, 1], pixel_pos[:, 0])
plotter.plot_coordinates(hpc_max, color="blue")
assert plotter.plotter.mesh.n_cells == 1680
assert plotter.plotter.mesh.n_points == 842
assert plotter.mesh.n_cells == 1680
assert plotter.mesh.n_points == 842


def test_clip_interval(aia171_test_map, plotter):
plotter.plot_map(aia171_test_map, clip_interval=(1, 99) * u.percent)
clim = plotter._get_clim( # NOQA: SLF001
data=plotter.plotter.mesh["data"],
data=plotter.mesh["data"],
clip_interval=(1, 99) * u.percent,
)
expected_clim = [0.006716044038535769, 0.8024368512284383]
assert np.allclose(clim, expected_clim)

expected_clim = [0, 1]
clim = plotter._get_clim( # NOQA: SLF001
data=plotter.plotter.mesh["data"],
data=plotter.mesh["data"],
clip_interval=(0, 100) * u.percent,
)
assert np.allclose(clim, expected_clim)
Expand All @@ -138,12 +138,12 @@ def test_save_and_load(aia171_test_map, plotter, tmp_path):
filepath = tmp_path / "save_data.vtm"
plotter.save(filepath=filepath)

plotter.plotter.clear()
plotter.clear()
plotter.load(filepath)

assert plotter.plotter.mesh.n_cells == 16384
assert plotter.plotter.mesh.n_points == 16641
assert dict(plotter.plotter.mesh.field_data)["cmap"][0] == "sdoaia171"
assert plotter.mesh.n_cells == 16384
assert plotter.mesh.n_points == 16641
assert dict(plotter.mesh.field_data)["cmap"][0] == "sdoaia171"

with pytest.raises(ValueError, match="VTM file"):
plotter.save(filepath=filepath)
Expand All @@ -160,10 +160,10 @@ def test_loop_through_meshes(plotter):
outer_block = pv.MultiBlock([inner_block, sphere2])
plotter._loop_through_meshes(outer_block) # NOQA: SLF001

assert plotter.plotter.mesh.center == [0, 1, 1]
assert plotter.mesh.center == [0, 1, 1]


def test_plot_limb(aia171_test_map, plotter):
plotter.plot_limb(aia171_test_map)
assert plotter.plotter.mesh.n_cells == 22
assert plotter.plotter.mesh.n_points == 20040
assert plotter.mesh.n_cells == 22
assert plotter.mesh.n_points == 20040

0 comments on commit 9d29fe2

Please sign in to comment.