Skip to content

Commit

Permalink
add: Solve min-cost-problem for better ES overlaps
Browse files Browse the repository at this point in the history
as originally described in #214.
fix: pysisplot -o was broken
  • Loading branch information
Johannes Steinmetzer committed Dec 20, 2023
1 parent 2417247 commit 8056442
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 18 deletions.
37 changes: 29 additions & 8 deletions pysisyphus/calculators/OverlapCalculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# Plasser, 2016
# [2] https://doi.org/10.1002/jcc.25800
# Garcia, Campetella, 2019
# [3] https://doi.org/10.1021/acs.jctc.0c00295
# First Principles Nonadiabatic Excited-State Molecular Dynamics in NWChem
# Song, Fischer, Zhang, Cramer, Mukamel, Govind, Tretiak, 2020

from collections import namedtuple
from pathlib import Path, PosixPath
Expand All @@ -10,6 +13,7 @@

import h5py
import numpy as np
from scipy.optimize import linear_sum_assignment

from pysisyphus import logger
from pysisyphus.calculators.Calculator import Calculator
Expand Down Expand Up @@ -123,7 +127,8 @@ def __init__(
conf_thresh=1e-3,
# dyn_roots=0,
mos_ref="cur",
mos_renorm=True,
mos_renorm: bool = True,
min_cost: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -187,6 +192,7 @@ def __init__(
self.mos_ref = mos_ref
assert self.mos_ref in ("cur", "ref")
self.mos_renorm = bool(mos_renorm)
self.min_cost = bool(min_cost)

# assert self.ncore >= 0, "ncore must be a >= 0!"

Expand Down Expand Up @@ -764,7 +770,7 @@ def track_root(self, ovlp_type=None):
overlaps = self.get_nto_overlaps(S_AO=S_AO, org=True)
elif ovlp_type == "top":
top_rs = self.get_top_differences(S_AO=S_AO)
overlaps = 1 - top_rs
overlaps = 1.0 - top_rs
else:
raise Exception(
"Invalid overlap type key! Use one of " + ", ".join(self.VALID_KEYS)
Expand All @@ -790,13 +796,28 @@ def track_root(self, ovlp_type=None):
row_ind += 1
self.row_inds.append(row_ind)
self.ref_cycles.append(self.ref_cycle)
self.log(
f"Reference is cycle {self.ref_cycle}, root {ref_root}. "
f"Analyzing row {row_ind} of the overlap matrix."
)
self.log(f"Reference is cycle {self.ref_cycle}, root {ref_root}.")

# As described in [3].
# Code contributed by PT.
if self.min_cost:
# Match all excited state of the current and the reference step to make the
# assignment more reasonable and avoid any double assignments.
self.log("Assigned roots using Kuhn-Munkres algorithm.")
_, col_inds = linear_sum_assignment(-overlaps)
ref_root_row = overlaps[row_ind]
new_root = col_inds[row_ind]
if ref_root_row.argmax() != new_root:
self.log(
"The newly assigned root is not root of highest overlap because "
"another row mapping to the same state had a higher overlap!"
)
else:
# Match the best root row wise/just pick the root w/ the highest overlap.
self.log(f"Analyzing row {row_ind} of the overlap matrix.")
ref_root_row = overlaps[row_ind]
new_root = ref_root_row.argmax()

ref_root_row = overlaps[row_ind]
new_root = ref_root_row.argmax()
max_overlap = ref_root_row[new_root]
if self.ovlp_type == "wf":
new_root -= 1
Expand Down
23 changes: 13 additions & 10 deletions pysisyphus/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,10 @@ def plot_cos_forces(h5_fn="optimization.h5", h5_group="opt", last=15):
results = load_h5(
h5_fn,
h5_group,
datasets=("energies", "forces",),
datasets=(
"energies",
"forces",
),
attrs=("is_cos", "coord_type", "max_force_thresh", "rms_force_thresh"),
)
cycles = len(results["energies"])
Expand All @@ -267,7 +270,7 @@ def plot_cos_forces(h5_fn="optimization.h5", h5_group="opt", last=15):

last_axis = forces.ndim - 1
max_ = np.nanmax(np.abs(forces), axis=last_axis)
rms = np.sqrt(np.mean(forces ** 2, axis=last_axis))
rms = np.sqrt(np.mean(forces**2, axis=last_axis))
hei_indices = energies.argmax(axis=1)
force_unit = get_force_unit(coord_type)

Expand All @@ -277,15 +280,17 @@ def plot_cos_forces(h5_fn="optimization.h5", h5_group="opt", last=15):
cycle = last_cycles[i]
hei_max = max_[i, hei_index]
hei_rms = rms[i, hei_index]
print(f"\tCycle {cycle:03d}: max(forces)={hei_max:{fmt}}, rms(forces)={hei_rms:{fmt}}")
print(
f"\tCycle {cycle:03d}: max(forces)={hei_max:{fmt}}, rms(forces)={hei_rms:{fmt}}"
)

fig, (ax0, ax1) = plt.subplots(sharex=True, nrows=2)

def plot(ax, data, title):
num = data.shape[0]
alphas = np.linspace(0.125, 1, num=num)
colors = matplotlib.cm.Greys(np.linspace(0, 1, num=num))
colors[-1] = (1., 0., 0., 1.) # use red for latest cycle
colors[-1] = (1.0, 0.0, 0.0, 1.0) # use red for latest cycle
for row, color, alpha in zip(data, colors, alphas):
ax.plot(row, "o-", color=color, alpha=alpha)
ax.set_ylabel(force_unit)
Expand Down Expand Up @@ -525,8 +530,8 @@ def draw(i):
o = np.abs(overlaps[i])
ax.imshow(o, vmin=0, vmax=1)
ax.grid(color="#CCCCCC", linestyle="--", linewidth=1)
ax.set_xticks(np.arange(n_states, dtype=np.int))
ax.set_yticks(np.arange(n_states, dtype=np.int))
ax.set_xticks(np.arange(n_states, dtype=int))
ax.set_yticks(np.arange(n_states, dtype=int))
# set_ylim is needed, otherwise set_yticks drastically shrinks the plot
ax.set_ylim(n_states - 0.5, -0.5)
ax.set_xlabel("new roots")
Expand All @@ -546,7 +551,7 @@ def draw(i):
ref_overlaps = o[ref_ind]
argmax = np.nanargmax(ref_overlaps)
xy = (argmax - 0.5, ref_ind - 0.5)
highlight = Rectangle(xy, 1, 1, fill=False, color="red", lw="4")
highlight = Rectangle(xy, 1, 1, fill=False, color="red", lw=4)
ax.add_artist(highlight)
if ax1:
ax1.imshow(cdd_imgs[i])
Expand Down Expand Up @@ -638,7 +643,6 @@ def render_cdds(h5):


def plot_afir(h5_fn="afir.h5", h5_group="afir"):

h5_fns = (h5_fn, Path(OUT_DIR_DEFAULT) / h5_fn)
for h5_fn in h5_fns:
print(f"Trying to open '{h5_fn}' ... ", end="")
Expand All @@ -656,7 +660,6 @@ def plot_afir(h5_fn="afir.h5", h5_group="afir"):
print("file not found.")
continue


en_conv, en_unit = get_en_conv()
afir_ens *= en_conv
afir_ens -= afir_ens.min()
Expand Down Expand Up @@ -876,7 +879,7 @@ def plot_irc_h5(h5, title=None):
energies *= en_conv

cds = np.linalg.norm(mw_coords - mw_coords[0], axis=1)
rms_grads = np.sqrt(np.mean(gradients ** 2, axis=1))
rms_grads = np.sqrt(np.mean(gradients**2, axis=1))
max_grads = np.abs(gradients).max(axis=1)

fig, (ax0, ax1, ax2) = plt.subplots(nrows=3, sharex=True)
Expand Down

0 comments on commit 8056442

Please sign in to comment.