From 9d29fe23f4342275840430b55627d96aa47e51aa Mon Sep 17 00:00:00 2001 From: Nabil Freij Date: Tue, 9 Jul 2024 09:17:55 -0700 Subject: [PATCH] Inherited from plotter instead of wrapping it (#138) * Inherited from plotter instead of wrapping it * changelog * precommit * revert * add ignore * nitpick * remove init and see if its ok * bump min vista --- .github/workflows/ci.yml | 6 ++-- .pre-commit-config.yaml | 3 ++ changelog/138.breaking.rst | 2 ++ docs/nitpick-exceptions.txt | 3 ++ examples/floating_sphere.py | 4 +-- pyproject.toml | 2 +- pytest.ini | 7 +++++ sunkit_pyvista/plotter.py | 44 +++++++++----------------- sunkit_pyvista/tests/test_pyvista.py | 46 ++++++++++++++-------------- 9 files changed, 58 insertions(+), 59 deletions(-) create mode 100644 changelog/138.breaking.rst diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2445334..c28ae59 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,9 +22,9 @@ jobs: core: uses: OpenAstronomy/github-actions-workflows/.github/workflows/tox.yml@main with: + display: true submodules: false coverage: codecov - display: true envs: | - linux: py312 @@ -32,9 +32,9 @@ jobs: needs: [core] uses: OpenAstronomy/github-actions-workflows/.github/workflows/tox.yml@main with: + display: true submodules: false coverage: codecov - display: true envs: | - macos: py311 - windows: py310 @@ -44,9 +44,9 @@ jobs: needs: [core] uses: OpenAstronomy/github-actions-workflows/.github/workflows/tox.yml@main with: + display: true default_python: "3.12" pytest: false - display: true libraries: | apt: - graphviz diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 48d64e5..c57612b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,3 +44,6 @@ repos: rev: v4.0.0-alpha.8 hooks: - id: prettier +ci: + autofix_prs: false + autoupdate_schedule: "quarterly" diff --git a/changelog/138.breaking.rst b/changelog/138.breaking.rst new file mode 100644 index 0000000..a294106 --- /dev/null +++ b/changelog/138.breaking.rst @@ -0,0 +1,2 @@ +Instead of wrapping `pyvsita.Plotter`, we now inherit from it. +This allows us to drop the ``plotter.plotter`` lines in examples and user facing API. diff --git a/docs/nitpick-exceptions.txt b/docs/nitpick-exceptions.txt index d5de9d0..56aee0a 100644 --- a/docs/nitpick-exceptions.txt +++ b/docs/nitpick-exceptions.txt @@ -16,3 +16,6 @@ py:class array-like py:obj parfive py:class string py:class floats +py:class pyvista.plotting.plotter.BasePlotter +py:class pyvista.plotting.picking.PickingMethods +py:class pyvista.plotting.picking.PickingInterface diff --git a/examples/floating_sphere.py b/examples/floating_sphere.py index 1c903a4..8b85b25 100644 --- a/examples/floating_sphere.py +++ b/examples/floating_sphere.py @@ -65,7 +65,7 @@ # sphinx_gallery_defer_figures cloud = plotter.coordinates_to_polydata(point_cloud) -_ = plotter.plotter.add_points(cloud, point_size=0.7, color="cyan", style="points_gaussian") +_ = plotter.add_points(cloud, point_size=0.7, color="cyan", style="points_gaussian") ################################################################################ # Next we want to build a surface from these points. @@ -75,7 +75,7 @@ # sphinx_gallery_defer_figures surf = cloud.delaunay_3d() -_ = plotter.plotter.add_mesh(surf) +_ = plotter.add_mesh(surf) ################################################################################ # Finally set up the camera position and focus. diff --git a/pyproject.toml b/pyproject.toml index 847ff59..11b7892 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Visualization", ] dependencies = [ - 'pyvista[all]>= 0.43.9', + 'pyvista[all]>= 0.44.0', 'sunpy[map]>=5.0.0', ] diff --git a/pytest.ini b/pytest.ini index a6006b2..7eb8ba8 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,6 +3,9 @@ minversion = 7.0 testpaths = sunkit_pyvista docs +norecursedirs = + docs/_build + docs/generated doctest_plus = enabled doctest_optionflags = NORMALIZE_WHITESPACE FLOAT_CMP ELLIPSIS text_file_format = rst @@ -20,3 +23,7 @@ filterwarnings = ignore:unclosed transport ignore:The loop argument is deprecated ignore:Event loop is closed + # Comes from sunkit-magex + ignore:`row_stack` alias is deprecated. Use `np.vstack` directly. + # Need to fix when sunpy 6.0 is out + ignore:The assume_spherical_screen function is deprecated and may be removed in a future version diff --git a/sunkit_pyvista/plotter.py b/sunkit_pyvista/plotter.py index 3416028..8f9508d 100644 --- a/sunkit_pyvista/plotter.py +++ b/sunkit_pyvista/plotter.py @@ -18,11 +18,11 @@ __all__ = ["SunpyPlotter"] -class SunpyPlotter: +class SunpyPlotter(pv.Plotter): """ A plotter for 3D data. - This class wraps `pyvsita.Plotter` to provide coordinate-aware plotting. + This class inherits `pyvsita.Plotter` so we can provide coordinate-aware plotting. For now, all coordinates are converted to a specific frame (`~sunpy.coordinates.HeliocentricInertial` by default), and distance units are such that :math:`R_{sun} = 1`. @@ -45,15 +45,14 @@ class SunpyPlotter: Stores a reference to all the plotted meshes in a dictionary. """ - def __init__(self, *, coordinate_frame=None, obstime=None, **kwargs): + def __init__(self, *args, coordinate_frame=None, obstime=None, **kwargs): + super().__init__(*args, **kwargs) if coordinate_frame is not None and obstime is not None: msg = "Only coordinate_frame or obstime can be specified, not both." raise ValueError(msg) if coordinate_frame is None: coordinate_frame = HeliocentricInertial(obstime=obstime) self._coordinate_frame = coordinate_frame - self._plotter = pv.Plotter(**kwargs) - self.camera = self._plotter.camera self.all_meshes = {} @property @@ -63,21 +62,6 @@ def coordinate_frame(self): """ return self._coordinate_frame - @property - def plotter(self): - """ - `pyvista.Plotter`. - """ - return self._plotter - - def show(self, *args, **kwargs): - """ - Show the plot. - - See `pyvista.Plotter.show` for accepted arguments. - """ - self.plotter.show(*args, **kwargs) - def _extract_color(self, mesh_kwargs): """ Converts a given color string to it's equivalent rgb tuple. @@ -169,7 +153,7 @@ def set_camera_coordinate(self, coord): """ camera_position = self._coords_to_xyz(coord) pos = tuple(camera_position) - self.plotter.camera.position = pos + self.camera.position = pos def set_camera_focus(self, coord): """ @@ -182,7 +166,7 @@ def set_camera_focus(self, coord): """ camera_position = self._coords_to_xyz(coord) pos = tuple(camera_position) - self.plotter.set_focus(pos) + self.set_focus(pos) @u.quantity_input def set_view_angle(self, angle: u.deg): @@ -199,7 +183,7 @@ def set_view_angle(self, angle: u.deg): msg = "specified view angle must be " "0 deg < angle <= 180 deg" raise ValueError(msg) zoom_value = self.camera.view_angle / view_angle - self.plotter.camera.zoom(zoom_value) + self.camera.zoom(zoom_value) def _map_to_mesh(self, m, *, assume_spherical=True): """ @@ -312,7 +296,7 @@ def plot_map( clim = [0, 1] cmap = self._get_cmap(kwargs, m) kwargs.setdefault("show_scalar_bar", False) - self.plotter.add_mesh(map_mesh, cmap=cmap, clim=clim, **kwargs) + self.add_mesh(map_mesh, cmap=cmap, clim=clim, **kwargs) map_mesh.add_field_data([cmap], "cmap") self._add_mesh_to_dict(block_name="maps", mesh=map_mesh) @@ -360,7 +344,7 @@ def plot_coordinates(self, coords, *, radius=0.05, **kwargs): point_mesh.add_field_data(color, "color") kwargs["render_lines_as_tubes"] = kwargs.pop("render_lines_as_tubes", True) - self.plotter.add_mesh(point_mesh, color=color, smooth_shading=True, **kwargs) + self.add_mesh(point_mesh, color=color, smooth_shading=True, **kwargs) self._add_mesh_to_dict(block_name="coordinates", mesh=point_mesh) def plot_solar_axis(self, *, length=2.5, arrow_kwargs=None, **kwargs): @@ -390,7 +374,7 @@ def plot_solar_axis(self, *, length=2.5, arrow_kwargs=None, **kwargs): ) color = self._extract_color(kwargs) arrow_mesh.add_field_data(color, "color") - self.plotter.add_mesh(arrow_mesh, color=color, **kwargs) + self.add_mesh(arrow_mesh, color=color, **kwargs) self._add_mesh_to_dict(block_name="solar_axis", mesh=arrow_mesh) def plot_quadrangle( @@ -455,7 +439,7 @@ def plot_quadrangle( quad_block = quad_block.tube(radius=radius) color = self._extract_color(kwargs) quad_block.add_field_data(color, "color") - self.plotter.add_mesh(quad_block, color=color, **kwargs) + self.add_mesh(quad_block, color=color, **kwargs) self._add_mesh_to_dict(block_name="quadrangles", mesh=quad_block) def plot_field_lines(self, field_lines, *, color_func=None, **kwargs): @@ -504,7 +488,7 @@ def color_func(field_line): kwargs["render_lines_as_tubes"] = kwargs.pop("render_lines_as_tubes", True) kwargs["line_width"] = kwargs.pop("line_width", 2) - self.plotter.add_mesh(spline, color=color, **kwargs) + self.add_mesh(spline, color=color, **kwargs) field_line_meshes.append(spline) self._add_mesh_to_dict(block_name="field_lines", mesh=spline) @@ -557,7 +541,7 @@ def _loop_through_meshes(self, mesh_block): else: color = dict(block.field_data).get("color", None) cmap = dict(block.field_data).get("cmap", [None])[0] - self.plotter.add_mesh(block, color=color, cmap=cmap) + self.add_mesh(block, color=color, cmap=cmap) def load(self, filepath): """ @@ -596,5 +580,5 @@ def plot_limb(self, m, *, radius=0.02, **kwargs): color = self._extract_color(mesh_kwargs=kwargs) limb_block = limb_block.tube(radius=radius) limb_block.add_field_data(color, "color") - self.plotter.add_mesh(limb_block, color=color, **kwargs) + self.add_mesh(limb_block, color=color, **kwargs) self._add_mesh_to_dict(block_name="limbs", mesh=limb_block) diff --git a/sunkit_pyvista/tests/test_pyvista.py b/sunkit_pyvista/tests/test_pyvista.py index 6446790..78e3163 100644 --- a/sunkit_pyvista/tests/test_pyvista.py +++ b/sunkit_pyvista/tests/test_pyvista.py @@ -11,7 +11,7 @@ @pytest.mark.display_server() def test_basic(plotter): - assert isinstance(plotter.plotter, pv.Plotter) + assert isinstance(plotter, pv.Plotter) plotter.show() @@ -44,14 +44,14 @@ def test_set_view_angle(plotter): def test_plot_map(aia171_test_map, plotter): plotter.plot_map(aia171_test_map) - assert plotter.plotter.mesh.n_cells == 16384 - assert plotter.plotter.mesh.n_points == 16641 + assert plotter.mesh.n_cells == 16384 + assert plotter.mesh.n_points == 16641 def test_plot_solar_axis(plotter): plotter.plot_solar_axis() - assert plotter.plotter.mesh.n_cells == 43 - assert plotter.plotter.mesh.n_points == 101 + assert plotter.mesh.n_cells == 43 + assert plotter.mesh.n_points == 101 def test_plot_quadrangle(aia171_test_map, plotter): @@ -67,8 +67,8 @@ def test_plot_quadrangle(aia171_test_map, plotter): height=60 * u.deg, color="blue", ) - assert plotter.plotter.mesh.n_cells == 22 - assert plotter.plotter.mesh.n_points == 80060 + assert plotter.mesh.n_cells == 22 + assert plotter.mesh.n_points == 80060 def test_plot_coordinates(aia171_test_map, plotter): @@ -80,8 +80,8 @@ def test_plot_coordinates(aia171_test_map, plotter): frame="heliocentricinertial", ) plotter.plot_coordinates(line) - assert plotter.plotter.mesh.n_cells == 1 - assert plotter.plotter.mesh.n_points == 3 + assert plotter.mesh.n_cells == 1 + assert plotter.mesh.n_points == 3 # Tests plotting of a small sphere sphere = SkyCoord( @@ -91,22 +91,22 @@ def test_plot_coordinates(aia171_test_map, plotter): frame="heliocentricinertial", ) plotter.plot_coordinates(sphere) - assert plotter.plotter.mesh.n_cells == 1680 - assert plotter.plotter.mesh.n_points == 842 + assert plotter.mesh.n_cells == 1680 + assert plotter.mesh.n_points == 842 expected_center = [-0.5000000149011612, -0.5, 0.7071067690849304] - assert np.allclose(plotter.plotter.mesh.center, expected_center) + assert np.allclose(plotter.mesh.center, expected_center) pixel_pos = np.argwhere(aia171_test_map.data == aia171_test_map.data.max()) * u.pixel hpc_max = aia171_test_map.pixel_to_world(pixel_pos[:, 1], pixel_pos[:, 0]) plotter.plot_coordinates(hpc_max, color="blue") - assert plotter.plotter.mesh.n_cells == 1680 - assert plotter.plotter.mesh.n_points == 842 + assert plotter.mesh.n_cells == 1680 + assert plotter.mesh.n_points == 842 def test_clip_interval(aia171_test_map, plotter): plotter.plot_map(aia171_test_map, clip_interval=(1, 99) * u.percent) clim = plotter._get_clim( # NOQA: SLF001 - data=plotter.plotter.mesh["data"], + data=plotter.mesh["data"], clip_interval=(1, 99) * u.percent, ) expected_clim = [0.006716044038535769, 0.8024368512284383] @@ -114,7 +114,7 @@ def test_clip_interval(aia171_test_map, plotter): expected_clim = [0, 1] clim = plotter._get_clim( # NOQA: SLF001 - data=plotter.plotter.mesh["data"], + data=plotter.mesh["data"], clip_interval=(0, 100) * u.percent, ) assert np.allclose(clim, expected_clim) @@ -138,12 +138,12 @@ def test_save_and_load(aia171_test_map, plotter, tmp_path): filepath = tmp_path / "save_data.vtm" plotter.save(filepath=filepath) - plotter.plotter.clear() + plotter.clear() plotter.load(filepath) - assert plotter.plotter.mesh.n_cells == 16384 - assert plotter.plotter.mesh.n_points == 16641 - assert dict(plotter.plotter.mesh.field_data)["cmap"][0] == "sdoaia171" + assert plotter.mesh.n_cells == 16384 + assert plotter.mesh.n_points == 16641 + assert dict(plotter.mesh.field_data)["cmap"][0] == "sdoaia171" with pytest.raises(ValueError, match="VTM file"): plotter.save(filepath=filepath) @@ -160,10 +160,10 @@ def test_loop_through_meshes(plotter): outer_block = pv.MultiBlock([inner_block, sphere2]) plotter._loop_through_meshes(outer_block) # NOQA: SLF001 - assert plotter.plotter.mesh.center == [0, 1, 1] + assert plotter.mesh.center == [0, 1, 1] def test_plot_limb(aia171_test_map, plotter): plotter.plot_limb(aia171_test_map) - assert plotter.plotter.mesh.n_cells == 22 - assert plotter.plotter.mesh.n_points == 20040 + assert plotter.mesh.n_cells == 22 + assert plotter.mesh.n_points == 20040