From 5d6d4238b2fffc37842f1ba50a28bb94b96b1f60 Mon Sep 17 00:00:00 2001
From: marvinfriede <51965259+marvinfriede@users.noreply.github.com>
Date: Wed, 17 Jan 2024 11:26:10 +0100
Subject: [PATCH 1/7] Refactor
---
examples/eeq-batch.py | 2 +-
examples/eeq-single.py | 4 +-
src/tad_multicharge/__init__.py | 14 +-
src/tad_multicharge/model/__init__.py | 26 ++
.../{model.py => model/base.py} | 42 ++-
src/tad_multicharge/{ => model}/eeq.py | 255 +++++++++---------
test/test_charge/test_charges.py | 10 +-
test/test_charge/test_general.py | 36 ++-
test/test_grad/test_charge.py | 6 +-
test/test_grad/test_dedr.py | 2 +-
test/test_grad/test_dqdr.py | 2 +-
test/test_grad/test_param.py | 6 +-
12 files changed, 245 insertions(+), 160 deletions(-)
create mode 100644 src/tad_multicharge/model/__init__.py
rename src/tad_multicharge/{model.py => model/base.py} (65%)
rename src/tad_multicharge/{ => model}/eeq.py (57%)
diff --git a/examples/eeq-batch.py b/examples/eeq-batch.py
index a341b87..e1ac1ef 100644
--- a/examples/eeq-batch.py
+++ b/examples/eeq-batch.py
@@ -3,7 +3,7 @@
from tad_mctc.batch import pack
from tad_mctc.convert import symbol_to_number
-from tad_multicharge import eeq
+from tad_multicharge.model import eeq
# S22 system 4: formamide dimer
numbers = pack(
diff --git a/examples/eeq-single.py b/examples/eeq-single.py
index 494e49f..eac14ac 100644
--- a/examples/eeq-single.py
+++ b/examples/eeq-single.py
@@ -1,7 +1,7 @@
# SPDX-Identifier: CC0-1.0
import torch
-from tad_multicharge import eeq
+from tad_multicharge.model import eeq
numbers = torch.tensor([7, 7, 1, 1, 1, 1, 1, 1])
@@ -23,7 +23,7 @@
cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
eeq_model = eeq.EEQModel.param2019()
-energy, qat = eeq.solve(numbers, positions, total_charge, eeq_model, cn)
+energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
print(torch.sum(energy, -1))
# tensor(-0.1750)
diff --git a/src/tad_multicharge/__init__.py b/src/tad_multicharge/__init__.py
index 8e5a8dc..590acc6 100644
--- a/src/tad_multicharge/__init__.py
+++ b/src/tad_multicharge/__init__.py
@@ -70,18 +70,16 @@
>>> # total charge of both systems
>>> charge = torch.tensor([0.0, 0.0])
>>>
->>> # calculate dispersion energy in Hartree
->>> energy = torch.sum(d4.dftd4(numbers, positions, charge, param), -1)
+>>> # calculate electrostatic energy in Hartree
+>>> energy = torch.sum(eeq.get_energy(numbers, positions, charge), -1)
>>>
>>> torch.set_printoptions(precision=10)
>>> print(energy)
-tensor([-0.0088341432, -0.0027013607])
->>> print(energy[0] - 2*energy[1])
-tensor(-0.0034314217)
+>>> # tensor([-0.2086755037, -0.0972094536])
+>>> print(energy[0] - 2 * energy[1])
+>>> # tensor(-0.0142565966)
"""
import torch
-from . import eeq, model
from .__version__ import __version__
-from .eeq import get_charges as get_eeq_charges
-from .eeq import get_eeq
+from .model.eeq import get_charges as get_eeq_charges
diff --git a/src/tad_multicharge/model/__init__.py b/src/tad_multicharge/model/__init__.py
new file mode 100644
index 0000000..5ed0e88
--- /dev/null
+++ b/src/tad_multicharge/model/__init__.py
@@ -0,0 +1,26 @@
+# This file is part of tad-multicharge.
+#
+# SPDX-Identifier: LGPL-3.0
+# Copyright (C) 2023 Marvin Friede
+#
+# tad-multicharge is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# tad-multicharge is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with tad-multicharge. If not, see .
+"""
+Model
+=====
+
+This module contains all available charge models. Currently, only the
+electronegativity equilibration model (EEQ) is implemented.
+"""
+from .base import *
+from .eeq import *
diff --git a/src/tad_multicharge/model.py b/src/tad_multicharge/model/base.py
similarity index 65%
rename from src/tad_multicharge/model.py
rename to src/tad_multicharge/model/base.py
index 1f405bd..639d51c 100644
--- a/src/tad_multicharge/model.py
+++ b/src/tad_multicharge/model/base.py
@@ -16,17 +16,17 @@
# You should have received a copy of the GNU Lesser General Public License
# along with tad-multicharge. If not, see .
"""
-Charge Model
-============
+Model: Base Charge Model
+========================
-Implementation of the electronegativity equlibration model for obtaining
-atomic partial charges as well as atom-resolved electrostatic energies.
+Implementation of a base class for charge models.
"""
from __future__ import annotations
import torch
+from abc import abstractmethod
-from .typing import Tensor, TensorLike
+from ..typing import Tensor, TensorLike
__all__ = ["ChargeModel"]
@@ -76,3 +76,35 @@ def __init__(
for tensor in (self.chi, self.kcn, self.eta, self.rad)
):
raise RuntimeError("All tensors must have the same dtype!")
+
+ @abstractmethod
+ def solve(
+ self,
+ numbers: Tensor,
+ positions: Tensor,
+ total_charge: Tensor,
+ cn: Tensor,
+ ) -> tuple[Tensor, Tensor]:
+ """
+ Solve the electronegativity equilibration for the partial charges
+ minimizing the electrostatic energy.
+
+ Parameters
+ ----------
+ numbers : Tensor
+ Atomic numbers of all atoms in the system.
+ positions : Tensor
+ Cartesian coordinates of the atoms in the system (batch, natoms, 3).
+ total_charge : Tensor
+ Total charge of the system.
+ model : ChargeModel
+ Charge model to use.
+ cn : Tensor
+ Coordination numbers for all atoms in the system.
+
+ Returns
+ -------
+ (Tensor, Tensor)
+ Tuple of electrostatic energies and partial charges.
+ """
+ ...
diff --git a/src/tad_multicharge/eeq.py b/src/tad_multicharge/model/eeq.py
similarity index 57%
rename from src/tad_multicharge/eeq.py
rename to src/tad_multicharge/model/eeq.py
index a4e19ef..7b11140 100644
--- a/src/tad_multicharge/eeq.py
+++ b/src/tad_multicharge/model/eeq.py
@@ -40,7 +40,7 @@
>>> total_charge = torch.tensor(0.0)
>>> cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
>>> eeq_model = eeq.EEQModel.param2019()
->>> energy, qat = eeq.solve(numbers, positions, total_charge, eeq_model, cn)
+>>> energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
>>> print(torch.sum(energy, -1))
tensor(-0.1750)
>>> print(qat)
@@ -55,12 +55,12 @@
from tad_mctc.batch import real_atoms, real_pairs
from tad_mctc.ncoord import cn_eeq, erf_count
-from . import defaults
-from .model import ChargeModel
-from .param import eeq2019
-from .typing import DD, Any, CountingFunction, Tensor, get_default_dtype
+from .. import defaults
+from .base import ChargeModel
+from ..param import eeq2019
+from ..typing import DD, Any, CountingFunction, Tensor, get_default_dtype
-__all__ = ["EEQModel", "solve", "get_charges"]
+__all__ = ["EEQModel", "get_charges"]
class EEQModel(ChargeModel):
@@ -106,135 +106,136 @@ def param2019(
**dd,
)
+ def solve(
+ self,
+ numbers: Tensor,
+ positions: Tensor,
+ total_charge: Tensor,
+ cn: Tensor,
+ ) -> tuple[Tensor, Tensor]:
+ """
+ Solve the electronegativity equilibration for the partial charges
+ minimizing the electrostatic energy.
-def solve(
- numbers: Tensor,
- positions: Tensor,
- total_charge: Tensor,
- model: ChargeModel,
- cn: Tensor,
-) -> tuple[Tensor, Tensor]:
- """
- Solve the electronegativity equilibration for the partial charges minimizing
- the electrostatic energy.
-
- Parameters
- ----------
- numbers : Tensor
- Atomic numbers of all atoms in the system.
- positions : Tensor
- Cartesian coordinates of the atoms in the system (batch, natoms, 3).
- total_charge : Tensor
- Total charge of the system.
- model : ChargeModel
- Charge model to use.
- cn : Tensor
- Coordination numbers for all atoms in the system.
-
- Returns
- -------
- (Tensor, Tensor)
- Tuple of electrostatic energies and partial charges.
+ Parameters
+ ----------
+ numbers : Tensor
+ Atomic numbers of all atoms in the system.
+ positions : Tensor
+ Cartesian coordinates of the atoms in the system (batch, natoms, 3).
+ total_charge : Tensor
+ Total charge of the system.
+ model : ChargeModel
+ Charge model to use.
+ cn : Tensor
+ Coordination numbers for all atoms in the system.
- Example
- -------
- >>> import torch
- >>> from tad_multicharge import eeq
- >>> numbers = torch.tensor([7, 1, 1, 1])
- >>> positions = torch.tensor([
- ... [+0.00000000000000, +0.00000000000000, -0.54524837997150],
- ... [-0.88451840382282, +1.53203081565085, +0.18174945999050],
- ... [-0.88451840382282, -1.53203081565085, +0.18174945999050],
- ... [+1.76903680764564, +0.00000000000000, +0.18174945999050],
- ... ], requires_grad=True)
- >>> total_charge = torch.tensor(0.0, requires_grad=True)
- >>> cn = torch.tensor([3.0, 1.0, 1.0, 1.0])
- >>> eeq_model = eeq.EEQModel.param2019()
- >>> energy = torch.sum(eeq.solve(numbers, positions, total_charge, eeq_model, cn)[0], -1)
- >>> energy.backward()
- >>> print(positions.grad)
- tensor([[-9.3132e-09, 7.4506e-09, -4.8064e-02],
- [-1.2595e-02, 2.1816e-02, 1.6021e-02],
- [-1.2595e-02, -2.1816e-02, 1.6021e-02],
- [ 2.5191e-02, -6.9849e-10, 1.6021e-02]])
- >>> print(total_charge.grad)
- tensor(0.6312)
- """
- dd: DD = {"device": positions.device, "dtype": positions.dtype}
+ Returns
+ -------
+ (Tensor, Tensor)
+ Tuple of electrostatic energies and partial charges.
- if model.device != positions.device:
- name = model.__class__.__name__
- raise RuntimeError(
- f"All tensors of '{name}' must be on the same device!\n"
- f"Use `{name}.param2019(device=device)` to correctly set the it."
+ Example
+ -------
+ >>> import torch
+ >>> from tad_multicharge import eeq
+ >>> numbers = torch.tensor([7, 1, 1, 1])
+ >>> positions = torch.tensor([
+ ... [+0.00000000000000, +0.00000000000000, -0.54524837997150],
+ ... [-0.88451840382282, +1.53203081565085, +0.18174945999050],
+ ... [-0.88451840382282, -1.53203081565085, +0.18174945999050],
+ ... [+1.76903680764564, +0.00000000000000, +0.18174945999050],
+ ... ], requires_grad=True)
+ >>> total_charge = torch.tensor(0.0, requires_grad=True)
+ >>> cn = torch.tensor([3.0, 1.0, 1.0, 1.0])
+ >>> eeq_model = eeq.EEQModel.param2019()
+ >>> e = eeq_model.solve(numbers, positions, total_charge, cn)[0]
+ >>> energy = torch.sum(e, -1)
+ >>> energy.backward()
+ >>> print(positions.grad)
+ tensor([[-9.3132e-09, 7.4506e-09, -4.8064e-02],
+ [-1.2595e-02, 2.1816e-02, 1.6021e-02],
+ [-1.2595e-02, -2.1816e-02, 1.6021e-02],
+ [ 2.5191e-02, -6.9849e-10, 1.6021e-02]])
+ >>> print(total_charge.grad)
+ tensor(0.6312)
+ """
+ if self.device != positions.device:
+ name = self.__class__.__name__
+ raise RuntimeError(
+ f"All tensors of '{name}' must be on the same device!\n"
+ f"Use `{name}.param2019(device=device)` to correctly set it."
+ )
+
+ if self.dtype != positions.dtype:
+ name = self.__class__.__name__
+ raise RuntimeError(
+ f"All tensors of '{name}' must have the same dtype!\n"
+ f"Use `{name}.param2019(dtype=dtype)` to correctly set it."
+ )
+
+ eps = torch.tensor(torch.finfo(positions.dtype).eps, **self.dd)
+ zero = torch.tensor(0.0, **self.dd)
+ stop = torch.sqrt(torch.tensor(2.0 / math.pi, **self.dd)) # sqrt(2/pi)
+
+ real = real_atoms(numbers)
+ mask = real_pairs(numbers, mask_diagonal=True)
+
+ distances = torch.where(
+ mask,
+ storch.cdist(positions, positions, p=2),
+ eps,
)
+ diagonal = mask.new_zeros(mask.shape)
+ diagonal.diagonal(dim1=-2, dim2=-1).fill_(True)
- if model.dtype != positions.dtype:
- name = model.__class__.__name__
- raise RuntimeError(
- f"All tensors of '{name}' must have the same dtype!\n"
- f"Use `{name}.param2019(dtype=dtype)` to correctly set it."
+ cc = torch.where(
+ real,
+ -self.chi[numbers] + storch.sqrt(cn) * self.kcn[numbers],
+ zero,
+ )
+ rhs = torch.concat((cc, total_charge.unsqueeze(-1)), dim=-1)
+
+ # radii
+ rad = self.rad[numbers]
+ rads = rad.unsqueeze(-1) ** 2 + rad.unsqueeze(-2) ** 2
+ gamma = torch.where(mask, 1.0 / storch.sqrt(rads), zero)
+
+ # hardness
+ eta = torch.where(
+ real,
+ self.eta[numbers] + stop / rad,
+ torch.tensor(1.0, **self.dd),
)
- eps = torch.tensor(torch.finfo(positions.dtype).eps, **dd)
- zero = torch.tensor(0.0, **dd)
- stop = torch.sqrt(torch.tensor(2.0 / math.pi, **dd)) # sqrt(2/pi)
-
- real = real_atoms(numbers)
- mask = real_pairs(numbers, mask_diagonal=True)
-
- distances = torch.where(mask, storch.cdist(positions, positions, p=2), eps)
- diagonal = mask.new_zeros(mask.shape)
- diagonal.diagonal(dim1=-2, dim2=-1).fill_(True)
-
- cn_sqrt = torch.sqrt(torch.clamp(cn, min=eps))
- cc = torch.where(
- real,
- -model.chi[numbers] + cn_sqrt * model.kcn[numbers],
- zero,
- )
- rhs = torch.concat((cc, total_charge.unsqueeze(-1)), dim=-1)
-
- # radii
- rad = model.rad[numbers]
- rads = torch.clamp(rad.unsqueeze(-1) ** 2 + rad.unsqueeze(-2) ** 2, min=eps)
- gamma = torch.where(mask, 1.0 / torch.sqrt(rads), zero)
-
- # hardness
- eta = torch.where(
- real,
- model.eta[numbers] + stop / rad,
- torch.tensor(1.0, **dd),
- )
-
- coulomb = torch.where(
- diagonal,
- eta.unsqueeze(-1),
- torch.where(
- mask,
- torch.erf(distances * gamma) / distances,
- zero,
- ),
- )
+ coulomb = torch.where(
+ diagonal,
+ eta.unsqueeze(-1),
+ torch.where(
+ mask,
+ torch.erf(distances * gamma) / distances,
+ zero,
+ ),
+ )
- constraint = torch.where(
- real,
- torch.ones(numbers.shape, **dd),
- torch.zeros(numbers.shape, **dd),
- )
- zeros = torch.zeros(numbers.shape[:-1], **dd)
-
- matrix = torch.concat(
- (
- torch.concat((coulomb, constraint.unsqueeze(-1)), dim=-1),
- torch.concat((constraint, zeros.unsqueeze(-1)), dim=-1).unsqueeze(-2),
- ),
- dim=-2,
- )
+ constraint = torch.where(
+ real,
+ torch.ones(numbers.shape, **self.dd),
+ torch.zeros(numbers.shape, **self.dd),
+ )
+ zeros = torch.zeros(numbers.shape[:-1], **self.dd)
+
+ matrix = torch.concat(
+ (
+ torch.concat((coulomb, constraint.unsqueeze(-1)), dim=-1),
+ torch.concat((constraint, zeros.unsqueeze(-1)), dim=-1).unsqueeze(-2),
+ ),
+ dim=-2,
+ )
- x = torch.linalg.solve(matrix, rhs)
- e = x * (0.5 * torch.einsum("...ij,...j->...i", matrix, x) - rhs)
- return e[..., :-1], x[..., :-1]
+ x = torch.linalg.solve(matrix, rhs)
+ e = x * (0.5 * torch.einsum("...ij,...j->...i", matrix, x) - rhs)
+ return e[..., :-1], x[..., :-1]
def get_eeq(
@@ -287,7 +288,7 @@ def get_eeq(
kcn=kcn,
**kwargs,
)
- return solve(numbers, positions, chrg, eeq, cn)
+ return eeq.solve(numbers, positions, chrg, cn)
def get_charges(
diff --git a/test/test_charge/test_charges.py b/test/test_charge/test_charges.py
index 8fde772..dfc910d 100644
--- a/test/test_charge/test_charges.py
+++ b/test/test_charge/test_charges.py
@@ -42,7 +42,7 @@
from tad_mctc.batch import pack
from tad_mctc.ncoord import cn_eeq
-from tad_multicharge import eeq
+from tad_multicharge.model import eeq
from tad_multicharge.typing import DD
from ..conftest import DEVICE
@@ -64,7 +64,7 @@ def test_single(dtype: torch.dtype) -> None:
cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], **dd)
eeq_model = eeq.EEQModel.param2019(**dd)
- energy, qat = eeq.solve(numbers, positions, total_charge, eeq_model, cn)
+ energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
tot = torch.sum(qat, -1)
assert qat.dtype == energy.dtype == dtype
@@ -89,7 +89,7 @@ def test_single_with_cn(dtype: torch.dtype, name: str) -> None:
cn = cn_eeq(numbers, positions)
eeq_model = eeq.EEQModel.param2019(**dd)
- energy, qat = eeq.solve(numbers, positions, total_charge, eeq_model, cn)
+ energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
tot = torch.sum(qat, -1)
assert qat.dtype == energy.dtype == dtype
@@ -138,7 +138,7 @@ def test_ghost(dtype: torch.dtype) -> None:
)
eeq_model = eeq.EEQModel.param2019(**dd)
- energy, qat = eeq.solve(numbers, positions, total_charge, eeq_model, cn)
+ energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
tot = torch.sum(qat, -1)
assert qat.dtype == energy.dtype == dtype
@@ -228,7 +228,7 @@ def test_batch(dtype: torch.dtype) -> None:
**dd,
)
eeq_model = eeq.EEQModel.param2019(**dd)
- energy, qat = eeq.solve(numbers, positions, total_charge, eeq_model, cn)
+ energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
tot = torch.sum(qat, -1)
assert qat.dtype == energy.dtype == dtype
diff --git a/test/test_charge/test_general.py b/test/test_charge/test_general.py
index c87c34e..c696179 100644
--- a/test/test_charge/test_general.py
+++ b/test/test_charge/test_general.py
@@ -30,12 +30,12 @@
precision. For double precision, however the results are identical.
"""
from __future__ import annotations
-
import pytest
import torch
from tad_mctc.convert import str_to_device
+from tad_mctc.typing import MockTensor
-from tad_multicharge import eeq
+from tad_multicharge.model import eeq, ChargeModel
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64])
@@ -99,7 +99,7 @@ def test_solve_dtype_fail() -> None:
# all tensor must have the same type
with pytest.raises(RuntimeError):
- eeq.solve(t, t.type(torch.float16), t, model, t)
+ model.solve(t, t.type(torch.float16), t, t)
@pytest.mark.cuda
@@ -115,4 +115,32 @@ def test_solve_device_fail() -> None:
# all tensor must be on the same device
with pytest.raises(RuntimeError):
- eeq.solve(t, t2, t, model, t)
+ model.solve(t, t2, t, t)
+
+
+def test_model_device_different() -> None:
+ cuda_tensor = MockTensor([4, 5, 6])
+ cuda_tensor.device = torch.device("cuda")
+
+ cpu_tensor = MockTensor([1, 2, 3])
+ cpu_tensor.device = torch.device("cpu")
+ with pytest.raises(RuntimeError) as exc:
+ ChargeModel(cpu_tensor, cpu_tensor, cpu_tensor, cuda_tensor)
+
+ assert "All tensors must be on the same device!" in str(exc.value)
+
+
+def test_solve_device_different() -> None:
+ model = eeq.EEQModel.param2019()
+
+ cuda_tensor = MockTensor([4, 5, 6])
+ cuda_tensor.device = torch.device("cuda")
+
+ cpu_tensor = MockTensor([1, 2, 3])
+ cpu_tensor.device = torch.device("cpu")
+
+ # all tensor must be on the same device
+ with pytest.raises(RuntimeError) as exc:
+ model.solve(cpu_tensor, cuda_tensor, cpu_tensor, cpu_tensor)
+
+ assert "must be on the same device!" in str(exc.value)
diff --git a/test/test_grad/test_charge.py b/test/test_grad/test_charge.py
index c5db967..917f245 100644
--- a/test/test_grad/test_charge.py
+++ b/test/test_grad/test_charge.py
@@ -38,7 +38,7 @@
from tad_mctc.data.molecules import mols as samples
from tad_mctc.ncoord import cn_eeq
-from tad_multicharge import eeq
+from tad_multicharge.model import eeq
from tad_multicharge.typing import DD, Callable, Tensor
from ..conftest import DEVICE, FAST_MODE
@@ -69,7 +69,7 @@ def gradchecker(
def func(pos: Tensor, tchrg: Tensor) -> Tensor:
cn = cn_eeq(numbers, positions)
- return eeq.solve(numbers, pos, tchrg, eeq_model, cn)[0]
+ return eeq_model.solve(numbers, pos, tchrg, cn)[0]
return func, (positions, total_charge)
@@ -127,7 +127,7 @@ def gradchecker_batch(
def func(pos: Tensor, tchrg: Tensor) -> Tensor:
cn = cn_eeq(numbers, positions)
- return eeq.solve(numbers, pos, tchrg, eeq_model, cn)[0]
+ return eeq_model.solve(numbers, pos, tchrg, cn)[0]
return func, (positions, total_charge)
diff --git a/test/test_grad/test_dedr.py b/test/test_grad/test_dedr.py
index 3feb79f..aa8844a 100644
--- a/test/test_grad/test_dedr.py
+++ b/test/test_grad/test_dedr.py
@@ -26,7 +26,7 @@
from tad_mctc.batch import pack
from tad_mctc.convert import tensor_to_numpy
-from tad_multicharge import eeq
+from tad_multicharge.model import eeq
from tad_multicharge.typing import DD, Callable, Tensor
from ..conftest import DEVICE, FAST_MODE
diff --git a/test/test_grad/test_dqdr.py b/test/test_grad/test_dqdr.py
index 07a9456..8a9167c 100644
--- a/test/test_grad/test_dqdr.py
+++ b/test/test_grad/test_dqdr.py
@@ -26,7 +26,7 @@
from tad_mctc.batch import pack
from tad_mctc.convert import reshape_fortran, tensor_to_numpy
-from tad_multicharge import eeq
+from tad_multicharge.model import eeq
from tad_multicharge.typing import DD, Callable, Tensor
from ..conftest import DEVICE, FAST_MODE
diff --git a/test/test_grad/test_param.py b/test/test_grad/test_param.py
index eb4fc98..f55f242 100644
--- a/test/test_grad/test_param.py
+++ b/test/test_grad/test_param.py
@@ -27,7 +27,7 @@
from tad_mctc.data.molecules import mols as samples
from tad_mctc.ncoord import cn_eeq
-from tad_multicharge import eeq
+from tad_multicharge.model import eeq
from tad_multicharge.typing import DD, Callable, Tensor
from ..conftest import DEVICE, FAST_MODE
@@ -57,7 +57,7 @@ def gradchecker(
def func(_chi: Tensor) -> Tensor:
model.chi = _chi
- return eeq.solve(numbers, positions, charge, model, cn)[1]
+ return model.solve(numbers, positions, charge, cn)[1]
return func, chi
@@ -114,7 +114,7 @@ def gradchecker_batch(
def func(_chi: Tensor) -> Tensor:
model.chi = _chi
- return eeq.solve(numbers, positions, charge, model, cn)[1]
+ return model.solve(numbers, positions, charge, cn)[1]
return func, chi
From 4d3e37a9313be28602a66420c2943656c89264f0 Mon Sep 17 00:00:00 2001
From: marvinfriede <51965259+marvinfriede@users.noreply.github.com>
Date: Wed, 17 Jan 2024 11:26:16 +0100
Subject: [PATCH 2/7] Adapt docs
---
README.rst | 2 +-
docs/index.rst | 2 +-
docs/modules/eeq.rst | 2 --
docs/modules/index.rst | 3 +--
docs/modules/model.rst | 2 --
docs/modules/model/base.rst | 2 ++
docs/modules/model/eeq.rst | 2 ++
docs/modules/model/index.rst | 8 ++++++++
8 files changed, 15 insertions(+), 8 deletions(-)
delete mode 100644 docs/modules/eeq.rst
delete mode 100644 docs/modules/model.rst
create mode 100644 docs/modules/model/base.rst
create mode 100644 docs/modules/model/eeq.rst
create mode 100644 docs/modules/model/index.rst
diff --git a/README.rst b/README.rst
index 2ac7552..4803542 100644
--- a/README.rst
+++ b/README.rst
@@ -159,7 +159,7 @@ The following example shows how to calculate the EEQ partial charges and the cor
cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
eeq_model = eeq.EEQModel.param2019()
- energy, qat = eeq.solve(numbers, positions, total_charge, eeq_model, cn)
+ energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
print(torch.sum(energy, -1))
# tensor(-0.1750)
diff --git a/docs/index.rst b/docs/index.rst
index 60f4290..a48c475 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -80,7 +80,7 @@ The following example shows how to calculate the EEQ partial charges and the cor
cn = torch.tensor([3.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
eeq_model = eeq.EEQModel.param2019()
- energy, qat = eeq.solve(numbers, positions, total_charge, eeq_model, cn)
+ energy, qat = eeq_model.solve(numbers, positions, total_charge, cn)
print(torch.sum(energy, -1))
# tensor(-0.1750)
diff --git a/docs/modules/eeq.rst b/docs/modules/eeq.rst
deleted file mode 100644
index f0251e4..0000000
--- a/docs/modules/eeq.rst
+++ /dev/null
@@ -1,2 +0,0 @@
-.. automodule:: tad_multicharge.eeq
- :members:
diff --git a/docs/modules/index.rst b/docs/modules/index.rst
index 6eb7454..b503806 100644
--- a/docs/modules/index.rst
+++ b/docs/modules/index.rst
@@ -9,6 +9,5 @@ The following modules are contained with `tad_multicharge`.
param/index
defaults
- eeq
- model
+ model/index
typing/index
diff --git a/docs/modules/model.rst b/docs/modules/model.rst
deleted file mode 100644
index 5e5d254..0000000
--- a/docs/modules/model.rst
+++ /dev/null
@@ -1,2 +0,0 @@
-.. automodule:: tad_multicharge.model
- :members:
diff --git a/docs/modules/model/base.rst b/docs/modules/model/base.rst
new file mode 100644
index 0000000..541e945
--- /dev/null
+++ b/docs/modules/model/base.rst
@@ -0,0 +1,2 @@
+.. automodule:: tad_multicharge.model.base
+ :members:
diff --git a/docs/modules/model/eeq.rst b/docs/modules/model/eeq.rst
new file mode 100644
index 0000000..a8bd5a5
--- /dev/null
+++ b/docs/modules/model/eeq.rst
@@ -0,0 +1,2 @@
+.. automodule:: tad_multicharge.model.eeq
+ :members:
diff --git a/docs/modules/model/index.rst b/docs/modules/model/index.rst
new file mode 100644
index 0000000..43b715a
--- /dev/null
+++ b/docs/modules/model/index.rst
@@ -0,0 +1,8 @@
+.. _model:
+
+.. automodule:: tad_multicharge.model
+
+.. toctree::
+
+ base
+ eeq
From 001108cd577da930eff1a153ab5b59f16e533fb0 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 17 Jan 2024 10:26:46 +0000
Subject: [PATCH 3/7] [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---
src/tad_multicharge/model/__init__.py | 2 +-
src/tad_multicharge/model/base.py | 3 ++-
src/tad_multicharge/model/eeq.py | 4 ++--
test/test_charge/test_general.py | 3 ++-
4 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/src/tad_multicharge/model/__init__.py b/src/tad_multicharge/model/__init__.py
index 5ed0e88..c7ab93d 100644
--- a/src/tad_multicharge/model/__init__.py
+++ b/src/tad_multicharge/model/__init__.py
@@ -19,7 +19,7 @@
Model
=====
-This module contains all available charge models. Currently, only the
+This module contains all available charge models. Currently, only the
electronegativity equilibration model (EEQ) is implemented.
"""
from .base import *
diff --git a/src/tad_multicharge/model/base.py b/src/tad_multicharge/model/base.py
index 639d51c..32a82e3 100644
--- a/src/tad_multicharge/model/base.py
+++ b/src/tad_multicharge/model/base.py
@@ -23,9 +23,10 @@
"""
from __future__ import annotations
-import torch
from abc import abstractmethod
+import torch
+
from ..typing import Tensor, TensorLike
__all__ = ["ChargeModel"]
diff --git a/src/tad_multicharge/model/eeq.py b/src/tad_multicharge/model/eeq.py
index 7b11140..c83d6d1 100644
--- a/src/tad_multicharge/model/eeq.py
+++ b/src/tad_multicharge/model/eeq.py
@@ -56,9 +56,9 @@
from tad_mctc.ncoord import cn_eeq, erf_count
from .. import defaults
-from .base import ChargeModel
from ..param import eeq2019
from ..typing import DD, Any, CountingFunction, Tensor, get_default_dtype
+from .base import ChargeModel
__all__ = ["EEQModel", "get_charges"]
@@ -173,7 +173,7 @@ def solve(
f"All tensors of '{name}' must have the same dtype!\n"
f"Use `{name}.param2019(dtype=dtype)` to correctly set it."
)
-
+
eps = torch.tensor(torch.finfo(positions.dtype).eps, **self.dd)
zero = torch.tensor(0.0, **self.dd)
stop = torch.sqrt(torch.tensor(2.0 / math.pi, **self.dd)) # sqrt(2/pi)
diff --git a/test/test_charge/test_general.py b/test/test_charge/test_general.py
index c696179..3956185 100644
--- a/test/test_charge/test_general.py
+++ b/test/test_charge/test_general.py
@@ -30,12 +30,13 @@
precision. For double precision, however the results are identical.
"""
from __future__ import annotations
+
import pytest
import torch
from tad_mctc.convert import str_to_device
from tad_mctc.typing import MockTensor
-from tad_multicharge.model import eeq, ChargeModel
+from tad_multicharge.model import ChargeModel, eeq
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64])
From 2b73ca7c9d27deb243704018a05756d90894ace7 Mon Sep 17 00:00:00 2001
From: marvinfriede <51965259+marvinfriede@users.noreply.github.com>
Date: Wed, 17 Jan 2024 11:28:04 +0100
Subject: [PATCH 4/7] Remove mypy from pre-commit
---
.pre-commit-config.yaml | 10 ----------
1 file changed, 10 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b30f71e..ace90f0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,9 +15,6 @@
#
# You should have received a copy of the GNU Lesser General Public License
# along with tad-multicharge. If not, see .
-ci:
- skip: [mypy]
-
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
@@ -62,10 +59,3 @@ repos:
hooks:
- id: black
stages: [commit]
-
- - repo: https://github.com/pre-commit/mirrors-mypy
- rev: v1.8.0
- hooks:
- - id: mypy
- additional_dependencies: [types-all]
- exclude: 'test/conftest.py'
From 4cb3f6cc41c6353593a1f577426ea9e9a8e58cba Mon Sep 17 00:00:00 2001
From: marvinfriede <51965259+marvinfriede@users.noreply.github.com>
Date: Wed, 17 Jan 2024 13:23:44 +0100
Subject: [PATCH 5/7] Remove ellipsis
---
src/tad_multicharge/model/base.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/src/tad_multicharge/model/base.py b/src/tad_multicharge/model/base.py
index 32a82e3..eda911a 100644
--- a/src/tad_multicharge/model/base.py
+++ b/src/tad_multicharge/model/base.py
@@ -108,4 +108,3 @@ def solve(
(Tensor, Tensor)
Tuple of electrostatic energies and partial charges.
"""
- ...
From 9133d5e289612a6a6dc806663ca61a6d752d9eaf Mon Sep 17 00:00:00 2001
From: marvinfriede <51965259+marvinfriede@users.noreply.github.com>
Date: Wed, 17 Jan 2024 13:27:08 +0100
Subject: [PATCH 6/7] Move defaults to param
---
docs/modules/index.rst | 1 -
docs/modules/{ => param}/defaults.rst | 0
docs/modules/param/index.rst | 1 +
src/tad_multicharge/model/eeq.py | 2 +-
src/tad_multicharge/{ => param}/defaults.py | 11 +++++++----
5 files changed, 9 insertions(+), 6 deletions(-)
rename docs/modules/{ => param}/defaults.rst (100%)
rename src/tad_multicharge/{ => param}/defaults.py (83%)
diff --git a/docs/modules/index.rst b/docs/modules/index.rst
index b503806..e70b6bd 100644
--- a/docs/modules/index.rst
+++ b/docs/modules/index.rst
@@ -8,6 +8,5 @@ The following modules are contained with `tad_multicharge`.
.. toctree::
param/index
- defaults
model/index
typing/index
diff --git a/docs/modules/defaults.rst b/docs/modules/param/defaults.rst
similarity index 100%
rename from docs/modules/defaults.rst
rename to docs/modules/param/defaults.rst
diff --git a/docs/modules/param/index.rst b/docs/modules/param/index.rst
index 1cd42db..d79500b 100644
--- a/docs/modules/param/index.rst
+++ b/docs/modules/param/index.rst
@@ -4,4 +4,5 @@
.. toctree::
+ defaults
eeq2019
diff --git a/src/tad_multicharge/model/eeq.py b/src/tad_multicharge/model/eeq.py
index c83d6d1..9e11a52 100644
--- a/src/tad_multicharge/model/eeq.py
+++ b/src/tad_multicharge/model/eeq.py
@@ -55,7 +55,7 @@
from tad_mctc.batch import real_atoms, real_pairs
from tad_mctc.ncoord import cn_eeq, erf_count
-from .. import defaults
+from ..param import defaults
from ..param import eeq2019
from ..typing import DD, Any, CountingFunction, Tensor, get_default_dtype
from .base import ChargeModel
diff --git a/src/tad_multicharge/defaults.py b/src/tad_multicharge/param/defaults.py
similarity index 83%
rename from src/tad_multicharge/defaults.py
rename to src/tad_multicharge/param/defaults.py
index d7e3a9e..d957a34 100644
--- a/src/tad_multicharge/defaults.py
+++ b/src/tad_multicharge/param/defaults.py
@@ -16,14 +16,17 @@
# You should have received a copy of the GNU Lesser General Public License
# along with tad-multicharge. If not, see .
"""
-Defaults
-========
+Parameters: Defaults
+====================
Default global parameters of the charge models.
-- EEQ: real-space cutoffs for the coordination number
+For the EEQ model, the following defaults are set:
+- real-space cutoffs for the coordination number
-- EEQ: Steepness of CN counting function
+- maximum coordination number
+
+- steepness of CN counting function
"""
EEQ_CN_CUTOFF = 25.0
From c8327730af57fe02c90f7b88acd956444a874075 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 17 Jan 2024 12:27:20 +0000
Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---
src/tad_multicharge/model/eeq.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/src/tad_multicharge/model/eeq.py b/src/tad_multicharge/model/eeq.py
index 9e11a52..a78002c 100644
--- a/src/tad_multicharge/model/eeq.py
+++ b/src/tad_multicharge/model/eeq.py
@@ -55,8 +55,7 @@
from tad_mctc.batch import real_atoms, real_pairs
from tad_mctc.ncoord import cn_eeq, erf_count
-from ..param import defaults
-from ..param import eeq2019
+from ..param import defaults, eeq2019
from ..typing import DD, Any, CountingFunction, Tensor, get_default_dtype
from .base import ChargeModel