Skip to content

Commit

Permalink
Merge pull request #68 from devbrones/master
Browse files Browse the repository at this point in the history
Support for returning plots instead of showing them directly
  • Loading branch information
flaport authored Dec 17, 2023
2 parents 3f4df26 + 8561dbb commit ff36719
Showing 1 changed file with 32 additions and 10 deletions.
42 changes: 32 additions & 10 deletions fdtd/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ def visualize(
srccolor="C0",
detcolor="C2",
norm="linear",
show=False, # default False to allow animate to be true
animate=False, # True to see frame by frame states of grid while running simulation
index=None, # index for each frame of animation (visualize fn runs in a loop, loop variable is passed as index)
save=False, # True to save frames (requires parameters index, folder)
folder=None, # folder path to save frames
show=False, # default False to allow animate to be true
style=None,
):
"""visualize a projection of the grid and the optical energy inside the grid
Expand All @@ -61,7 +62,10 @@ def visualize(
index: index for each frame of animation (typically a loop variable is passed)
save: save frames in a folder
folder: path to folder to save frames
style: Matplotlib style sheet to use for plotting. e.g. "https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle".
"""
if style is not None:
plt.style.use(style)
if norm not in ("linear", "lin", "log"):
raise ValueError("Color map normalization should be 'linear' or 'log'.")
# imports (placed here to circumvent circular imports)
Expand Down Expand Up @@ -116,7 +120,7 @@ def visualize(
plt.plot([], lw=3, color=detcolor, label="Detectors")

# Grid energy
grid_energy = bd.sum(grid.E ** 2 + grid.H ** 2, -1)
grid_energy = bd.sum(grid.E**2 + grid.H**2, -1)
if x is not None:
assert grid.Ny > 1 and grid.Nz > 1
xlabel, ylabel = "y", "z"
Expand Down Expand Up @@ -309,7 +313,9 @@ def visualize(
cmap_norm = None
if norm == "log":
cmap_norm = LogNorm(vmin=1e-4, vmax=grid_energy.max() + 1e-4)
plt.imshow(abs(bd.numpy(grid_energy)), cmap=cmap, interpolation="sinc", norm=cmap_norm)
plt.imshow(
abs(bd.numpy(grid_energy)), cmap=cmap, interpolation="sinc", norm=cmap_norm
)

# finalize the plot
plt.ylabel(xlabel)
Expand All @@ -327,8 +333,12 @@ def visualize(
if show:
plt.show()

return plt.gcf() # return figure for gradio support


def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"):
def dB_map_2D(
block_det=None, choose_axis=2, interpolation="spline16", show=True, style=None
):
"""
Displays detector readings from an 'fdtd.BlockDetector' in a decibel map spanning a 2D slice region inside the BlockDetector.
Compatible with continuous sources (not pulse).
Expand All @@ -338,6 +348,8 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"):
block_det (numpy array): 5 axes numpy array (timestep, row, column, height, {x, y, z} parameter) created by 'fdtd.BlockDetector'.
(optional) choose_axis (int): Choose between {0, 1, 2} to display {x, y, z} data. Default 2 (-> z).
(optional) interpolation (string): Preferred 'matplotlib.pyplot.imshow' interpolation. Default "spline16".
show (bool): automatically call plt.show at the end of the plotting function
style (string): Matplotlib style sheet to use for plotting. e.g. "https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle".
"""
if block_det is None:
raise ValueError(
Expand All @@ -349,7 +361,8 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"):
)

# TODO: convert all 2D slices (y-z, x-z plots) into x-y plot data structure

if style is not None:
plt.style.use(style)
plt.ioff()
plt.close()
a = [] # array to store wave intensities
Expand All @@ -360,24 +373,28 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"):
a[i].append(max(temp) - min(temp))

peakVal, minVal = max(map(max, a)), min(map(min, a))
#print(
# print(
# "Peak at:",
# [
# [[i, j] for j, y in enumerate(x) if y == peakVal]
# for i, x in enumerate(a)
# if peakVal in x
# ],
#)
# )
a = 10 * log10([[y / minVal for y in x] for x in a])

plt.title("dB map of Electrical waves in detector region")
plt.imshow(a, cmap="inferno", interpolation=interpolation)
cbar = plt.colorbar()
cbar.ax.set_ylabel("dB scale", rotation=270)
plt.show()

if show:
plt.show()

return plt.gcf()

def plot_detection(detector_dict=None, specific_plot=None):

def plot_detection(detector_dict=None, specific_plot=None, show=True, style=None):
"""
1. Plots intensity readings on array of 'fdtd.LineDetector' as a function of timestep.
2. Plots time of arrival of pulse at different LineDetector in array.
Expand All @@ -387,6 +404,8 @@ def plot_detection(detector_dict=None, specific_plot=None):
detector_dict (dictionary): Dictionary of detector readings, as created by 'fdtd.Grid.save_data()'.
(optional) specific_plot (string): Plot for a specific axis data. Choose from {"Ex", "Ey", "Ez", "Hx", "Hy", "Hz"}.
"""
if style is not None:
plt.style.use(style)
if detector_dict is None:
raise Exception(
"Function plotDetection() requires a dictionary of detector readings as 'detector_dict' parameter."
Expand Down Expand Up @@ -464,7 +483,10 @@ def plot_detection(detector_dict=None, specific_plot=None):
plt.xlabel("Time of arrival (time steps)")
plt.legend()
plt.suptitle("Time-of-arrival plot")
plt.show()
if show:
plt.show()

return plt.gcf()


#
Expand Down

0 comments on commit ff36719

Please sign in to comment.