diff --git a/iodata/overlap.py b/iodata/overlap.py index 6e23f90e..a4a40ed2 100644 --- a/iodata/overlap.py +++ b/iodata/overlap.py @@ -28,7 +28,9 @@ from .basis import HORTON2_CONVENTIONS as OVERLAP_CONVENTIONS -__all__ = ['OVERLAP_CONVENTIONS', 'compute_overlap', 'gob_cart_normalization'] +__all__ = [ + 'OVERLAP_CONVENTIONS', 'compute_overlap', 'gob_cart_normalization', "convert_vector_basis" +] def compute_overlap(obasis: MolecularBasis, atcoords: np.ndarray) -> np.ndarray: @@ -159,3 +161,63 @@ def gob_cart_normalization(alpha: np.ndarray, n: np.ndarray) -> np.ndarray: vfac2 = np.vectorize(factorialk) return np.sqrt((4 * alpha)**sum(n) * (2 * alpha / np.pi)**1.5 / np.prod(vfac2(2 * n - 1, 2))) + + +def convert_vector_basis( + coeffs1: np.ndarray, + basis2_overlap: np.ndarray, + basis21_overlap: np.ndarray +) -> np.ndarray: + r""" + Convert a vector from basis 1 of size M to another basis 2 of size N. + + Basis vectors are defined to be linearly independent wrt to one another + and need not be orthogonal. + + Parameters + ---------- + coeffs1: + Coefficients of the vector expanded in basis set 1. Shape is (M,). + basis2_overlap: + Symmetric matrix whose entries are the inner-product between basis vectors + inside basis set 2. Shape is (N, N). + basis21_overlap: + The overlap matrix between basis set 1 and basis set 2. Shape is (N, M). + + Returns + ------- + coeffs2 : + Coefficients of the vector expanded in basis set 2. Shape is (N,). + + Raises + ------ + ValueError : + If shapes of the matrices don't match the requirements in the docs. + LinAlgError : + If least squared solution does not converge. + + Notes + ----- + - `basis2_overlap` is the matrix with (i, j)th entries + :math:`\braket{\psi_i, \psi_j}` where :math:`\psi_i` are in basis set 2. + + - `basis21_overlap` is the matrix whose (i, j)th entries are + :math:`\braket{\psi_i, \phi_j}` where :math:`\phi_j` are in basis set 1 and + :math:`\psi_i` are in basis set 2. + + - If `basis2_overlap` is not full rank, then least-squared solution is solved instead. + + """ + if basis2_overlap.shape[0] != basis2_overlap.shape[1]: + raise ValueError("The `basis2_overlap` should be a square matrix.") + if np.any(np.abs(basis2_overlap.T - basis2_overlap) > 1e-10): + raise ValueError("The `basis2_overlap` should be a symmetric matrix.") + + b = basis21_overlap.dot(coeffs1) + try: + # Try solving exact solution if `basis12_overlap` is full rank. + coeffs2 = np.linalg.solve(basis2_overlap, b) + except np.linalg.LinAlgError: + # Try least-squared solution. + coeffs2, _, _, _ = np.linalg.lstsq(basis2_overlap, b) + return coeffs2