Skip to content

Commit

Permalink
add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasdorch committed Oct 14, 2022
1 parent 511a380 commit 967415d
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions python/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
## Typing imports
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from typing import Callable, Union, Any, Literal, Tuple
from typing import Callable, Union, Any, Literal, Tuple, List

# ------------------------------------------------------- #
# Visualization
Expand Down Expand Up @@ -108,7 +108,13 @@ def filter_dict(dict_to_filter: dict, func_with_kwargs: Callable) -> dict:


def place_label(
ax: Axes, label_text: str, x, y, centerx, centery, label_parameters: dict = None
ax: Axes,
label_text: str,
x: float,
y: float,
centerx: float,
centery: float,
label_parameters: dict = None,
) -> Axes:

if label_parameters is None:
Expand Down Expand Up @@ -148,7 +154,7 @@ def place_label(
# Returns the intersection points of two Volumes.
# Volumes must be a line, plane, or rectangular prism
# (since they are volume objects)
def intersect_volume_volume(volume1: Volume, volume2: Volume) -> list:
def intersect_volume_volume(volume1: Volume, volume2: Volume) -> List[Vector3]:
# volume1 ............... [volume]
# volume2 ............... [volume]

Expand Down Expand Up @@ -214,7 +220,7 @@ def intersect_volume_volume(volume1: Volume, volume2: Volume) -> list:
# Not only do we need to check for all of these possibilities, but we also need
# to check if the user accidentally specifies a plane that stretches beyond the
# simulation domain.
def get_2D_dimensions(sim: Simulation, output_plane: Volume) -> (Vector3, Vector3):
def get_2D_dimensions(sim: Simulation, output_plane: Volume) -> Tuple[Vector3, Vector3]:
# Pull correct plane from user
if output_plane:
plane_center, plane_size = (output_plane.center, output_plane.size)
Expand Down Expand Up @@ -259,7 +265,7 @@ def get_2D_dimensions(sim: Simulation, output_plane: Volume) -> (Vector3, Vector

def box_vertices(
box_center: Vector3, box_size: Vector3, is_cylindrical: bool = False
) -> (float, float, float, float, float, float):
) -> Tuple[float, float, float, float, float, float]:
# in cylindrical coordinates, radial (R) axis
# is in the range (0,R) rather than (-R/2,+R/2)
# as in Cartesian coordinates.
Expand Down Expand Up @@ -671,8 +677,7 @@ def plot_sources(
output_plane: Volume = None,
labels: bool = False,
source_parameters: dict = None,
):

) -> Axes:
# consolidate plotting parameters
if source_parameters is None:
source_parameters = default_source_parameters
Expand Down Expand Up @@ -923,7 +928,7 @@ def visualize_chunks(sim: Simulation):
vols = sim.structure.get_chunk_volumes()
owners = sim.structure.get_chunk_owners()

def plot_box(box, proc, fig, ax):
def plot_box(box, proc, fig, ax: Axes):
if sim.structure.gv.dim == 2:
low = Vector3(box.low.x, box.low.y, box.low.z)
high = Vector3(box.high.x, box.high.y, box.high.z)
Expand Down Expand Up @@ -1028,16 +1033,14 @@ def display_figure_immediately(fig: Figure) -> None:
# ------------------------------------------------------- #
# A helper class used to make jshtml animations embed
# seamlessly within Jupyter notebooks.


class JS_Animation:
def __init__(self, jshtml):
def __init__(self, jshtml: str):
self.jshtml = jshtml

def _repr_html_(self):
def _repr_html_(self) -> str:
return self.jshtml

def get_jshtml(self):
def get_jshtml(self) -> str:
return self.jshtml


Expand Down Expand Up @@ -1251,7 +1254,6 @@ def __call__(self, sim: Simulation, todo: Literal["step", "finish"]) -> None:
self.grab_frame()
return
elif todo == "finish":

# Normalize the frames, if requested, and export
if self.normalize and mp.am_master():
if mp.verbosity.meep > 0:
Expand All @@ -1263,7 +1265,6 @@ def __call__(self, sim: Simulation, todo: Literal["step", "finish"]) -> None:
self.ax.images[-1].set_data(fields[k, :, :])
self.ax.images[-1].set_clim(vmin=-0.8, vmax=0.8)
self.grab_frame()

return

@property
Expand Down

0 comments on commit 967415d

Please sign in to comment.