diff --git a/pysisyphus/calculators/OverlapCalculator.py b/pysisyphus/calculators/OverlapCalculator.py index 4fbc0db2f..829e0271c 100644 --- a/pysisyphus/calculators/OverlapCalculator.py +++ b/pysisyphus/calculators/OverlapCalculator.py @@ -2,6 +2,9 @@ # Plasser, 2016 # [2] https://doi.org/10.1002/jcc.25800 # Garcia, Campetella, 2019 +# [3] https://doi.org/10.1021/acs.jctc.0c00295 +# First Principles Nonadiabatic Excited-State Molecular Dynamics in NWChem +# Song, Fischer, Zhang, Cramer, Mukamel, Govind, Tretiak, 2020 from collections import namedtuple from pathlib import Path, PosixPath @@ -10,6 +13,7 @@ import h5py import numpy as np +from scipy.optimize import linear_sum_assignment from pysisyphus import logger from pysisyphus.calculators.Calculator import Calculator @@ -123,7 +127,8 @@ def __init__( conf_thresh=1e-3, # dyn_roots=0, mos_ref="cur", - mos_renorm=True, + mos_renorm: bool = True, + min_cost: bool = False, **kwargs, ): super().__init__(*args, **kwargs) @@ -187,6 +192,7 @@ def __init__( self.mos_ref = mos_ref assert self.mos_ref in ("cur", "ref") self.mos_renorm = bool(mos_renorm) + self.min_cost = bool(min_cost) # assert self.ncore >= 0, "ncore must be a >= 0!" @@ -764,7 +770,7 @@ def track_root(self, ovlp_type=None): overlaps = self.get_nto_overlaps(S_AO=S_AO, org=True) elif ovlp_type == "top": top_rs = self.get_top_differences(S_AO=S_AO) - overlaps = 1 - top_rs + overlaps = 1.0 - top_rs else: raise Exception( "Invalid overlap type key! Use one of " + ", ".join(self.VALID_KEYS) @@ -790,13 +796,28 @@ def track_root(self, ovlp_type=None): row_ind += 1 self.row_inds.append(row_ind) self.ref_cycles.append(self.ref_cycle) - self.log( - f"Reference is cycle {self.ref_cycle}, root {ref_root}. " - f"Analyzing row {row_ind} of the overlap matrix." - ) + self.log(f"Reference is cycle {self.ref_cycle}, root {ref_root}.") + + # As described in [3]. + # Code contributed by PT. + if self.min_cost: + # Match all excited state of the current and the reference step to make the + # assignment more reasonable and avoid any double assignments. + self.log("Assigned roots using Kuhn-Munkres algorithm.") + _, col_inds = linear_sum_assignment(-overlaps) + ref_root_row = overlaps[row_ind] + new_root = col_inds[row_ind] + if ref_root_row.argmax() != new_root: + self.log( + "The newly assigned root is not root of highest overlap because " + "another row mapping to the same state had a higher overlap!" + ) + else: + # Match the best root row wise/just pick the root w/ the highest overlap. + self.log(f"Analyzing row {row_ind} of the overlap matrix.") + ref_root_row = overlaps[row_ind] + new_root = ref_root_row.argmax() - ref_root_row = overlaps[row_ind] - new_root = ref_root_row.argmax() max_overlap = ref_root_row[new_root] if self.ovlp_type == "wf": new_root -= 1 diff --git a/pysisyphus/plot.py b/pysisyphus/plot.py index 827884b62..97dd23bf3 100644 --- a/pysisyphus/plot.py +++ b/pysisyphus/plot.py @@ -254,7 +254,10 @@ def plot_cos_forces(h5_fn="optimization.h5", h5_group="opt", last=15): results = load_h5( h5_fn, h5_group, - datasets=("energies", "forces",), + datasets=( + "energies", + "forces", + ), attrs=("is_cos", "coord_type", "max_force_thresh", "rms_force_thresh"), ) cycles = len(results["energies"]) @@ -267,7 +270,7 @@ def plot_cos_forces(h5_fn="optimization.h5", h5_group="opt", last=15): last_axis = forces.ndim - 1 max_ = np.nanmax(np.abs(forces), axis=last_axis) - rms = np.sqrt(np.mean(forces ** 2, axis=last_axis)) + rms = np.sqrt(np.mean(forces**2, axis=last_axis)) hei_indices = energies.argmax(axis=1) force_unit = get_force_unit(coord_type) @@ -277,7 +280,9 @@ def plot_cos_forces(h5_fn="optimization.h5", h5_group="opt", last=15): cycle = last_cycles[i] hei_max = max_[i, hei_index] hei_rms = rms[i, hei_index] - print(f"\tCycle {cycle:03d}: max(forces)={hei_max:{fmt}}, rms(forces)={hei_rms:{fmt}}") + print( + f"\tCycle {cycle:03d}: max(forces)={hei_max:{fmt}}, rms(forces)={hei_rms:{fmt}}" + ) fig, (ax0, ax1) = plt.subplots(sharex=True, nrows=2) @@ -285,7 +290,7 @@ def plot(ax, data, title): num = data.shape[0] alphas = np.linspace(0.125, 1, num=num) colors = matplotlib.cm.Greys(np.linspace(0, 1, num=num)) - colors[-1] = (1., 0., 0., 1.) # use red for latest cycle + colors[-1] = (1.0, 0.0, 0.0, 1.0) # use red for latest cycle for row, color, alpha in zip(data, colors, alphas): ax.plot(row, "o-", color=color, alpha=alpha) ax.set_ylabel(force_unit) @@ -525,8 +530,8 @@ def draw(i): o = np.abs(overlaps[i]) ax.imshow(o, vmin=0, vmax=1) ax.grid(color="#CCCCCC", linestyle="--", linewidth=1) - ax.set_xticks(np.arange(n_states, dtype=np.int)) - ax.set_yticks(np.arange(n_states, dtype=np.int)) + ax.set_xticks(np.arange(n_states, dtype=int)) + ax.set_yticks(np.arange(n_states, dtype=int)) # set_ylim is needed, otherwise set_yticks drastically shrinks the plot ax.set_ylim(n_states - 0.5, -0.5) ax.set_xlabel("new roots") @@ -546,7 +551,7 @@ def draw(i): ref_overlaps = o[ref_ind] argmax = np.nanargmax(ref_overlaps) xy = (argmax - 0.5, ref_ind - 0.5) - highlight = Rectangle(xy, 1, 1, fill=False, color="red", lw="4") + highlight = Rectangle(xy, 1, 1, fill=False, color="red", lw=4) ax.add_artist(highlight) if ax1: ax1.imshow(cdd_imgs[i]) @@ -638,7 +643,6 @@ def render_cdds(h5): def plot_afir(h5_fn="afir.h5", h5_group="afir"): - h5_fns = (h5_fn, Path(OUT_DIR_DEFAULT) / h5_fn) for h5_fn in h5_fns: print(f"Trying to open '{h5_fn}' ... ", end="") @@ -656,7 +660,6 @@ def plot_afir(h5_fn="afir.h5", h5_group="afir"): print("file not found.") continue - en_conv, en_unit = get_en_conv() afir_ens *= en_conv afir_ens -= afir_ens.min() @@ -876,7 +879,7 @@ def plot_irc_h5(h5, title=None): energies *= en_conv cds = np.linalg.norm(mw_coords - mw_coords[0], axis=1) - rms_grads = np.sqrt(np.mean(gradients ** 2, axis=1)) + rms_grads = np.sqrt(np.mean(gradients**2, axis=1)) max_grads = np.abs(gradients).max(axis=1) fig, (ax0, ax1, ax2) = plt.subplots(nrows=3, sharex=True)