Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modifies QuickPlot.plot to take a list of times and show superimposed plots in case of 1D variables. #4529

Merged
merged 15 commits into from
Nov 22, 2024
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
- Added phase-dependent particle options to LAM ([#4369](https://github.com/pybamm-team/PyBaMM/pull/4369))
- Added a lithium ion equivalent circuit model with split open circuit voltages for each electrode (`SplitOCVR`). ([#4330](https://github.com/pybamm-team/PyBaMM/pull/4330))
- Added the `pybamm.DiscreteTimeSum` expression node to sum an expression over a sequence of data times, and accompanying `pybamm.DiscreteTimeData` class to store the data times and values ([#4501](https://github.com/pybamm-team/PyBaMM/pull/4501))
- Modified `quick_plot.plot` to accept a list of times and generate superimposed graphs for specified time points.([#4529](https://github.com/pybamm-team/PyBaMM/pull/4529))
medha-14 marked this conversation as resolved.
Show resolved Hide resolved

## Optimizations

Expand Down
211 changes: 95 additions & 116 deletions src/pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def set_output_variables(self, output_variables, solutions):
if variable.domain != domain:
raise ValueError(
"Mismatching variable domains. "
f"'{variable_tuple[0]}' has domain '{domain}', but '{variable_tuple[idx]}' has domain '{variable.domain}'"
f"'{variable_tuple[0]}' has domain '{domain}', but '{variable_tuple[ idx]}' has domain '{variable.domain}'"
medha-14 marked this conversation as resolved.
Show resolved Hide resolved
)
self.spatial_variable_dict[variable_tuple] = {}

Expand Down Expand Up @@ -479,24 +479,24 @@ def reset_axis(self):
): # pragma: no cover
raise ValueError(f"Axis limits cannot be NaN for variables '{key}'")

def plot(self, t, dynamic=False):
def plot(self, t: float | list[float], dynamic: bool = False):
medha-14 marked this conversation as resolved.
Show resolved Hide resolved
"""Produces a quick plot with the internal states at time t.

Parameters
----------
t : float
Dimensional time (in 'time_units') at which to plot.
t : float or list of float
Dimensional time (in 'time_units') at which to plot. Can be a single time or a list of times.
agriyakhetarpal marked this conversation as resolved.
Show resolved Hide resolved
dynamic : bool, optional
Determine whether to allocate space for a slider at the bottom of the plot when generating a dynamic plot.
If True, creates a dynamic plot with a slider.
"""

plt = import_optional_dependency("matplotlib.pyplot")
gridspec = import_optional_dependency("matplotlib.gridspec")
cm = import_optional_dependency("matplotlib", "cm")
colors = import_optional_dependency("matplotlib", "colors")

t_in_seconds = t * self.time_scaling_factor
if not isinstance(t, list):
t = [t]

self.fig = plt.figure(figsize=self.figsize)

self.gridspec = gridspec.GridSpec(self.n_rows, self.n_cols)
Expand All @@ -508,6 +508,11 @@ def plot(self, t, dynamic=False):
# initialize empty handles, to be created only if the appropriate plots are made
solution_handles = []

# Generate distinct colors for each time point
time_colors = plt.cm.coolwarm(
np.linspace(0, 1, len(t))
) # Use a colormap for distinct colors

for k, (key, variable_lists) in enumerate(self.variables.items()):
ax = self.fig.add_subplot(self.gridspec[k])
self.axes.add(key, ax)
Expand All @@ -518,19 +523,17 @@ def plot(self, t, dynamic=False):
ax.xaxis.set_major_locator(plt.MaxNLocator(3))
self.plots[key] = defaultdict(dict)
variable_handles = []
# Set labels for the first subplot only (avoid repetition)

if variable_lists[0][0].dimensions == 0:
# 0D plot: plot as a function of time, indicating time t with a line
# 0D plot: plot as a function of time, indicating multiple times with lines
ax.set_xlabel(f"Time [{self.time_unit}]")
for i, variable_list in enumerate(variable_lists):
for j, variable in enumerate(variable_list):
if len(variable_list) == 1:
# single variable -> use linestyle to differentiate model
linestyle = self.linestyles[i]
else:
# multiple variables -> use linestyle to differentiate
# variables (color differentiates models)
linestyle = self.linestyles[j]
linestyle = (
self.linestyles[i]
if len(variable_list) == 1
else self.linestyles[j]
)
full_t = self.ts_seconds[i]
(self.plots[key][i][j],) = ax.plot(
full_t / self.time_scaling_factor,
Expand All @@ -542,128 +545,104 @@ def plot(self, t, dynamic=False):
solution_handles.append(self.plots[key][i][0])
y_min, y_max = ax.get_ylim()
ax.set_ylim(y_min, y_max)
(self.time_lines[key],) = ax.plot(
[
t_in_seconds / self.time_scaling_factor,
t_in_seconds / self.time_scaling_factor,
],
[y_min, y_max],
"k--",
lw=1.5,
)

# Add vertical lines for each time in the list, using different colors for each time
for idx, t_single in enumerate(t):
t_in_seconds = t_single * self.time_scaling_factor
(self.time_lines[key],) = ax.plot(
[
t_in_seconds / self.time_scaling_factor,
t_in_seconds / self.time_scaling_factor,
],
[y_min, y_max],
"--", # Dashed lines
lw=1.5,
color=time_colors[idx], # Different color for each time
label=f"t = {t_single:.2f} {self.time_unit}",
)
ax.legend()

elif variable_lists[0][0].dimensions == 1:
# 1D plot: plot as a function of x at time t
# Read dictionary of spatial variables
# 1D plot: plot as a function of x at different times
spatial_vars = self.spatial_variable_dict[key]
spatial_var_name = next(iter(spatial_vars.keys()))
ax.set_xlabel(
f"{spatial_var_name} [{self.spatial_unit}]",
)
for i, variable_list in enumerate(variable_lists):
for j, variable in enumerate(variable_list):
if len(variable_list) == 1:
# single variable -> use linestyle to differentiate model
linestyle = self.linestyles[i]
else:
# multiple variables -> use linestyle to differentiate
# variables (color differentiates models)
linestyle = self.linestyles[j]
(self.plots[key][i][j],) = ax.plot(
self.first_spatial_variable[key],
variable(t_in_seconds, **spatial_vars),
color=self.colors[i],
linestyle=linestyle,
zorder=10,
)
variable_handles.append(self.plots[key][0][j])
solution_handles.append(self.plots[key][i][0])
# add lines for boundaries between subdomains
for boundary in variable_lists[0][0].internal_boundaries:
boundary_scaled = boundary * self.spatial_factor
ax.axvline(boundary_scaled, color="0.5", lw=1, zorder=0)
ax.set_xlabel(f"{spatial_var_name} [{self.spatial_unit}]")

for idx, t_single in enumerate(t):
t_in_seconds = t_single * self.time_scaling_factor

for i, variable_list in enumerate(variable_lists):
for j, variable in enumerate(variable_list):
linestyle = (
self.linestyles[i]
if len(variable_list) == 1
else self.linestyles[j]
)
(self.plots[key][i][j],) = ax.plot(
self.first_spatial_variable[key],
variable(t_in_seconds, **spatial_vars),
color=time_colors[idx], # Different color for each time
linestyle=linestyle,
label=f"t = {t_single:.2f} {self.time_unit}", # Add time label
zorder=10,
)
variable_handles.append(self.plots[key][0][j])
solution_handles.append(self.plots[key][i][0])

# Add a legend to indicate which plot corresponds to which time
ax.legend()

elif variable_lists[0][0].dimensions == 2:
# Read dictionary of spatial variables
# 2D plot: superimpose plots at different times
spatial_vars = self.spatial_variable_dict[key]
# there can only be one entry in the variable list
variable = variable_lists[0][0]
# different order based on whether the domains are x-r, x-z or y-z, etc
if self.x_first_and_y_second[key] is False:
x_name = list(spatial_vars.keys())[1][0]
y_name = next(iter(spatial_vars.keys()))[0]
x = self.second_spatial_variable[key]
y = self.first_spatial_variable[key]
var = variable(t_in_seconds, **spatial_vars)
else:
x_name = next(iter(spatial_vars.keys()))[0]
y_name = list(spatial_vars.keys())[1][0]

for t_single in t:
t_in_seconds = t_single * self.time_scaling_factor
x = self.first_spatial_variable[key]
y = self.second_spatial_variable[key]
var = variable(t_in_seconds, **spatial_vars).T
ax.set_xlabel(f"{x_name} [{self.spatial_unit}]")
ax.set_ylabel(f"{y_name} [{self.spatial_unit}]")
vmin, vmax = self.variable_limits[key]
# store the plot and the var data (for testing) as cant access
# z data from QuadMesh or QuadContourSet object
if self.is_y_z[key] is True:
self.plots[key][0][0] = ax.pcolormesh(
x,
y,
var,
vmin=vmin,
vmax=vmax,
shading=self.shading,

ax.set_xlabel(
f"{next(iter(spatial_vars.keys()))[0]} [{self.spatial_unit}]"
)
else:
self.plots[key][0][0] = ax.contourf(
x, y, var, levels=100, vmin=vmin, vmax=vmax
ax.set_ylabel(
f"{list(spatial_vars.keys())[1][0]} [{self.spatial_unit}]"
)
self.plots[key][0][1] = var
if vmin is None and vmax is None:
vmin = ax_min(var)
vmax = ax_max(var)
self.colorbars[key] = self.fig.colorbar(
cm.ScalarMappable(colors.Normalize(vmin=vmin, vmax=vmax)),
ax=ax,
)
# Set either y label or legend entries
if len(key) == 1:
title = split_long_string(key[0])
ax.set_title(title, fontsize="medium")
else:
ax.legend(
variable_handles,
[split_long_string(s, 6) for s in key],
bbox_to_anchor=(0.5, 1),
loc="lower center",
)
vmin, vmax = self.variable_limits[key]

# Use contourf and colorbars to represent the values
contour_plot = ax.contourf(
x, y, var, levels=100, vmin=vmin, vmax=vmax, cmap="coolwarm"
)
self.plots[key][0][0] = contour_plot
self.colorbars[key] = self.fig.colorbar(contour_plot, ax=ax)

# Set global legend
self.plots[key][0][1] = var

ax.set_title(f"t = {t_single:.2f} {self.time_unit}")

# Set global legend if there are multiple models
if len(self.labels) > 1:
fig_legend = self.fig.legend(
solution_handles, self.labels, loc="lower right"
)
# Get the position of the top of the legend in relative figure units
# There may be a better way ...
try:
legend_top_inches = fig_legend.get_window_extent(
renderer=self.fig.canvas.get_renderer()
).get_points()[1, 1]
fig_height_inches = (self.fig.get_size_inches() * self.fig.dpi)[1]
legend_top = legend_top_inches / fig_height_inches
except AttributeError: # pragma: no cover
# When testing the examples we set the matplotlib backend to "Template"
# which means that the above code doesn't work. Since this is just for
# that particular test we can just skip it
legend_top = 0
else:
legend_top = 0
fig_legend = None

# Fix layout
# Fix layout for sliders if dynamic
if dynamic:
slider_top = 0.05
else:
slider_top = 0
bottom = max(legend_top, slider_top)
bottom = max(
fig_legend.get_window_extent(
renderer=self.fig.canvas.get_renderer()
).get_points()[1, 1]
if fig_legend
else 0,
slider_top,
)
self.gridspec.tight_layout(self.fig, rect=[0, bottom, 1, 1])

def dynamic_plot(self, show_plot=True, step=None):
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/test_plotting/test_quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,41 @@ def test_simple_ode_model(self, solver):
solution, ["a", "b broadcasted"], variable_limits="bad variable limits"
)

# Test with a list of times
# Test for a 0D variable
quick_plot = pybamm.QuickPlot(solution, ["a"])
quick_plot.plot(t=[0, 1, 2])
assert len(quick_plot.plots[("a",)]) == 1

# Test for a 1D variable
quick_plot = pybamm.QuickPlot(solution, ["c broadcasted"])

time_list = [0.5, 1.5, 2.5]
quick_plot.plot(time_list)

ax = quick_plot.fig.axes[0]
lines = ax.get_lines()

variable_key = ("c broadcasted",)
if variable_key in quick_plot.variables:
num_variables = len(quick_plot.variables[variable_key][0])
else:
raise KeyError(
f"'{variable_key}' is not in quick_plot.variables. Available keys: "
+ str(quick_plot.variables.keys())
)

expected_lines = len(time_list) * num_variables

assert (
len(lines) == expected_lines
), f"Expected {expected_lines} superimposed lines, but got {len(lines)}"

# Test for a 2D variable
quick_plot = pybamm.QuickPlot(solution, ["2D variable"])
quick_plot.plot(t=[0, 1, 2])
assert len(quick_plot.plots[("2D variable",)]) == 1

# Test errors
with pytest.raises(ValueError, match="Mismatching variable domains"):
pybamm.QuickPlot(solution, [["a", "b broadcasted"]])
Expand Down
Loading