From 967415d2db8e2db47d038194983c47d8c6dc305b Mon Sep 17 00:00:00 2001 From: Thomas Dorch Date: Thu, 13 Oct 2022 18:13:12 -0700 Subject: [PATCH] add type hints --- python/visualization.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/python/visualization.py b/python/visualization.py index 9f12073f6..2dd46529d 100644 --- a/python/visualization.py +++ b/python/visualization.py @@ -14,7 +14,7 @@ ## Typing imports from matplotlib.axes import Axes from matplotlib.figure import Figure -from typing import Callable, Union, Any, Literal, Tuple +from typing import Callable, Union, Any, Literal, Tuple, List # ------------------------------------------------------- # # Visualization @@ -108,7 +108,13 @@ def filter_dict(dict_to_filter: dict, func_with_kwargs: Callable) -> dict: def place_label( - ax: Axes, label_text: str, x, y, centerx, centery, label_parameters: dict = None + ax: Axes, + label_text: str, + x: float, + y: float, + centerx: float, + centery: float, + label_parameters: dict = None, ) -> Axes: if label_parameters is None: @@ -148,7 +154,7 @@ def place_label( # Returns the intersection points of two Volumes. # Volumes must be a line, plane, or rectangular prism # (since they are volume objects) -def intersect_volume_volume(volume1: Volume, volume2: Volume) -> list: +def intersect_volume_volume(volume1: Volume, volume2: Volume) -> List[Vector3]: # volume1 ............... [volume] # volume2 ............... [volume] @@ -214,7 +220,7 @@ def intersect_volume_volume(volume1: Volume, volume2: Volume) -> list: # Not only do we need to check for all of these possibilities, but we also need # to check if the user accidentally specifies a plane that stretches beyond the # simulation domain. -def get_2D_dimensions(sim: Simulation, output_plane: Volume) -> (Vector3, Vector3): +def get_2D_dimensions(sim: Simulation, output_plane: Volume) -> Tuple[Vector3, Vector3]: # Pull correct plane from user if output_plane: plane_center, plane_size = (output_plane.center, output_plane.size) @@ -259,7 +265,7 @@ def get_2D_dimensions(sim: Simulation, output_plane: Volume) -> (Vector3, Vector def box_vertices( box_center: Vector3, box_size: Vector3, is_cylindrical: bool = False -) -> (float, float, float, float, float, float): +) -> Tuple[float, float, float, float, float, float]: # in cylindrical coordinates, radial (R) axis # is in the range (0,R) rather than (-R/2,+R/2) # as in Cartesian coordinates. @@ -671,8 +677,7 @@ def plot_sources( output_plane: Volume = None, labels: bool = False, source_parameters: dict = None, -): - +) -> Axes: # consolidate plotting parameters if source_parameters is None: source_parameters = default_source_parameters @@ -923,7 +928,7 @@ def visualize_chunks(sim: Simulation): vols = sim.structure.get_chunk_volumes() owners = sim.structure.get_chunk_owners() - def plot_box(box, proc, fig, ax): + def plot_box(box, proc, fig, ax: Axes): if sim.structure.gv.dim == 2: low = Vector3(box.low.x, box.low.y, box.low.z) high = Vector3(box.high.x, box.high.y, box.high.z) @@ -1028,16 +1033,14 @@ def display_figure_immediately(fig: Figure) -> None: # ------------------------------------------------------- # # A helper class used to make jshtml animations embed # seamlessly within Jupyter notebooks. - - class JS_Animation: - def __init__(self, jshtml): + def __init__(self, jshtml: str): self.jshtml = jshtml - def _repr_html_(self): + def _repr_html_(self) -> str: return self.jshtml - def get_jshtml(self): + def get_jshtml(self) -> str: return self.jshtml @@ -1251,7 +1254,6 @@ def __call__(self, sim: Simulation, todo: Literal["step", "finish"]) -> None: self.grab_frame() return elif todo == "finish": - # Normalize the frames, if requested, and export if self.normalize and mp.am_master(): if mp.verbosity.meep > 0: @@ -1263,7 +1265,6 @@ def __call__(self, sim: Simulation, todo: Literal["step", "finish"]) -> None: self.ax.images[-1].set_data(fields[k, :, :]) self.ax.images[-1].set_clim(vmin=-0.8, vmax=0.8) self.grab_frame() - return @property