Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PhononDosPlotter.plot_dos() add support for existing plt.Axes #3487

Merged
merged 4 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,7 @@ def make_supergraph(graph, multiplicity, periodicity_vectors):
connecting_edges.append((n1, n2, key, new_data))
else:
if not np.all(np.array(data["delta"]) == 0):
print(
"delta not equal to periodicity nor 0 ... : ",
n1,
n2,
key,
data["delta"],
data,
)
print("delta not equal to periodicity nor 0 ... : ", n1, n2, key, data["delta"], data)
input("Are we ok with this ?")
other_edges.append((n1, n2, key, data))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1129,22 +1129,22 @@ def _get_map(self, isite):
target_cns = [cg.coordination_number for cg in target_cgs]
for ii in range(min([len(maps_and_surfaces), self.max_nabundant])):
my_map_and_surface = maps_and_surfaces[order[ii]]
mymap = my_map_and_surface["map"]
cn = mymap[0]
my_map = my_map_and_surface["map"]
cn = my_map[0]
if cn not in target_cns or cn > 12 or cn == 0:
continue
all_conditions = [params[2] for params in my_map_and_surface["parameters_indices"]]
if self._additional_condition not in all_conditions:
continue
cg, cgdict = self.structure_environments.ce_list[self.structure_environments.sites_map[isite]][mymap[0]][
mymap[1]
cg, cgdict = self.structure_environments.ce_list[self.structure_environments.sites_map[isite]][my_map[0]][
my_map[1]
].minimum_geometry(symmetry_measure_type=self._symmetry_measure_type)
if (
cg in self.target_environments
and cgdict["symmetry_measure"] <= self.max_csm
and cgdict["symmetry_measure"] < current_target_env_csm
):
current_map = mymap
current_map = my_map
current_target_env_csm = cgdict["symmetry_measure"]
if current_map is not None:
return current_map
Expand Down
4 changes: 2 additions & 2 deletions pymatgen/analysis/pourbaix_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,15 +473,15 @@ def __init__(
for entry in ion_entries:
ion_elts = list(set(entry.elements) - ELEMENTS_HO)
# TODO: the logic here for ion concentration setting is in two
# places, in PourbaixEntry and here, should be consolidated
# places, in PourbaixEntry and here, should be consolidated
if len(ion_elts) == 1:
entry.concentration = conc_dict[ion_elts[0].symbol] * entry.normalization_factor
elif len(ion_elts) > 1 and not entry.concentration:
raise ValueError("Elemental concentration not compatible with multi-element ions")

self._unprocessed_entries = solid_entries + ion_entries

if not len(solid_entries + ion_entries) == len(entries):
if len(solid_entries + ion_entries) != len(entries):
raise ValueError('All supplied entries must have a phase type of either "Solid" or "Ion"')

if self.filter_solids:
Expand Down
27 changes: 13 additions & 14 deletions pymatgen/phonon/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(self, stack: bool = False, sigma: float | None = None) -> None:
)
self.stack = stack
self.sigma = sigma
self._doses: dict = {}
self._doses: dict[str, dict[Literal["frequencies", "densities"], np.ndarray]] = {}

def add_dos(self, label: str, dos: PhononDos) -> None:
"""Adds a dos for plotting.
Expand Down Expand Up @@ -138,6 +138,7 @@ def get_plot(
ylim: float | None = None,
units: Literal["thz", "ev", "mev", "ha", "cm-1", "cm^-1"] = "thz",
legend: dict | None = None,
ax: Axes | None = None,
) -> Axes:
"""Get a matplotlib plot showing the DOS.

Expand All @@ -149,6 +150,8 @@ def get_plot(
legend: dict with legend options. For example, {"loc": "upper right"}
will place the legend in the upper right corner. Defaults to
{"fontsize": 30}.
ax (Axes): An existing axes object onto which the plot will be
added. If None, a new figure will be created.
"""
legend = legend or {"fontsize": 30}
unit = freq_units(units)
Expand All @@ -161,7 +164,7 @@ def get_plot(
y = None
all_densities = []
all_frequencies = []
ax = pretty_plot(12, 8)
ax = pretty_plot(12, 8, ax=ax)

# Note that this complicated processing of frequencies is to allow for
# stacked plots in matplotlib.
Expand Down Expand Up @@ -516,30 +519,27 @@ def show(
"""Show the plot using matplotlib.

Args:
ylim: Specify the y-axis (frequency) limits; by default None let
the code choose.
units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
ylim (float): Specifies the y-axis limits.
units ("thz" | "ev" | "mev" | "ha" | "cm-1" | "cm^-1"): units for the frequencies.
"""
self.get_plot(ylim, units=units)
plt.show()

def save_plot(
self,
filename: str | PathLike,
img_format: str = "eps",
ylim: float | None = None,
units: Literal["thz", "ev", "mev", "ha", "cm-1", "cm^-1"] = "thz",
) -> None:
"""Save matplotlib plot to a file.

Args:
filename: Filename to write to.
img_format: Image format to use. Defaults to EPS.
ylim: Specifies the y-axis limits.
units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
filename (str | Path): Filename to write to.
ylim (float): Specifies the y-axis limits.
units ("thz" | "ev" | "mev" | "ha" | "cm-1" | "cm^-1"): units for the frequencies.
"""
self.get_plot(ylim=ylim, units=units)
plt.savefig(filename, format=img_format)
plt.savefig(filename)
plt.close()

def show_proj(
Expand Down Expand Up @@ -598,9 +598,8 @@ def get_ticks(self) -> dict[str, list]:
elif point.label.startswith("\\") or point.label.find("_") != -1:
tick_labels.append(f"${point.label}$")
else:
label = point.label
if label == "GAMMA":
label = r"$\Gamma$"
# map atomate2 all-upper-case point.labels to pretty LaTeX
label = dict(GAMMA=r"$\Gamma$", DELTA=r"$\Delta$").get(point.label, point.label)
tick_labels.append(label)
previous_label = point.label
previous_branch = this_branch
Expand Down
Loading