Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Magnetostatics in TorchPME #133

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/torchpme/calculators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from .calculator import Calculator
from .calculator_dipole import CalculatorDipole
from .ewald import EwaldCalculator
from .p3m import P3MCalculator
from .pme import PMECalculator

__all__ = ["Calculator", "EwaldCalculator", "P3MCalculator", "PMECalculator"]
__all__ = [
"Calculator",
"EwaldCalculator",
"P3MCalculator",
"PMECalculator",
"CalculatorDipole",
]
220 changes: 220 additions & 0 deletions src/torchpme/calculators/calculator_dipole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import torch
from torch import profiler

from ..potentials import PotentialDipole


class CalculatorDipole(torch.nn.Module):
"""TODO: Add docstring"""

def __init__(
self,
potential: PotentialDipole,
full_neighbor_list: bool = False,
prefactor: float = 1.0,
):
super().__init__()
# TorchScript requires to initialize all attributes in __init__
self._device = torch.device("cpu")
self._dtype = torch.float32

self.potential = potential

self.full_neighbor_list = full_neighbor_list

self.prefactor = prefactor

def _compute_rspace(
self,
dipoles: torch.Tensor,
neighbor_indices: torch.Tensor,
neighbor_vectors: torch.Tensor,
) -> torch.Tensor:
"""TODO: Add docstring"""
# Compute the pair potential terms V(r_ij) for each pair of atoms (i,j)
# contained in the neighbor list
with profiler.record_function("compute bare potential"):
if self.potential.smearing is None:
potentials_bare_scalar, potentials_bare_tensor = (
self.potential.from_dist(neighbor_vectors)
)
else:
raise NotImplementedError(
"TODO: Implement smearing for `compute_rspace`"
)

# Multiply the bare potential terms V(r_ij) with the corresponding dipoles
# of ``atom j'' to obtain q_j*V(r_ij). Since each atom j can be a neighbor of
# multiple atom i's, we need to access those from neighbor_indices
atom_is = neighbor_indices[:, 0]
atom_js = neighbor_indices[:, 1]
with profiler.record_function("compute real potential"):
contributions_is = dipoles[atom_js] * potentials_bare_scalar - torch.einsum(
"ij,ijk->ik", dipoles[atom_js], potentials_bare_tensor
)

# For each atom i, add up all contributions of the form q_j*V(r_ij) for j
# ranging over all of its neighbors.
with profiler.record_function("assign potential"):
potential = torch.zeros_like(dipoles)
potential.index_add_(0, atom_is, contributions_is)
# If we are using a half neighbor list, we need to add the contributions
# from the "inverse" pairs (j, i) to the atoms i
if not self.full_neighbor_list:
contributions_js = dipoles[
atom_is
] * potentials_bare_scalar - torch.einsum(
"ij,ijk->ik", dipoles[atom_is], potentials_bare_tensor
)
potential.index_add_(0, atom_js, contributions_js)

# Compensate for double counting of pairs (i,j) and (j,i)
return potential / 2

def _compute_kspace(
self,
dipoles: torch.Tensor,
cell: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError(
f"`compute_kspace` not implemented for {self.__class__.__name__}"
)

def forward(
self,
dipoles: torch.Tensor,
cell: torch.Tensor,
positions: torch.Tensor,
neighbor_indices: torch.Tensor,
neighbor_vectors: torch.Tensor,
):
r"""TODO: Add docstring"""
# self._validate_compute_parameters(
# charges=charges,
# cell=cell,
# positions=positions,
# neighbor_indices=neighbor_indices,
# neighbor_distances=neighbor_distances,
# smearing=self.potential.smearing,
# )

# Compute short-range (SR) part using a real space sum
potential_sr = self._compute_rspace(
dipoles=dipoles,
neighbor_indices=neighbor_indices,
neighbor_vectors=neighbor_vectors,
)

if self.potential.smearing is None:
return self.prefactor * potential_sr
return None
# Compute long-range (LR) part using a Fourier / reciprocal space sum
# potential_lr = self._compute_kspace(
# charges=charges,
# cell=cell,
# positions=positions,
# )

# return self.prefactor * (potential_sr + potential_lr)

# @staticmethod
# def _validate_compute_parameters(
# dipoles: torch.Tensor,
# cell: torch.Tensor,
# positions: torch.Tensor,
# neighbor_indices: torch.Tensor,
# neighbor_vectors: torch.Tensor,
# smearing: Optional[float],
# ) -> None:
# device = positions.device
# dtype = positions.dtype

# # check shape, dtype and device of positions
# num_atoms = len(positions)
# if list(positions.shape) != [len(positions), 3]:
# raise ValueError(
# "`positions` must be a tensor with shape [n_atoms, 3], got tensor "
# f"with shape {list(positions.shape)}"
# )

# # check shape, dtype and device of cell
# if list(cell.shape) != [3, 3]:
# raise ValueError(
# "`cell` must be a tensor with shape [3, 3], got tensor with shape "
# f"{list(cell.shape)}"
# )

# if cell.dtype != dtype:
# raise ValueError(
# f"type of `cell` ({cell.dtype}) must be same as `positions` ({dtype})"
# )

# if cell.device != device:
# raise ValueError(
# f"device of `cell` ({cell.device}) must be same as `positions` "
# f"({device})"
# )

# if smearing is not None and torch.equal(
# cell.det(), torch.tensor(0.0, dtype=cell.dtype, device=cell.device)
# ):
# raise ValueError(
# "provided `cell` has a determinant of 0 and therefore is not valid for "
# "periodic calculation"
# )

# # check shape, dtype & device of `charges`
# if charges.dim() != 2:
# raise ValueError(
# "`charges` must be a 2-dimensional tensor, got "
# f"tensor with {charges.dim()} dimension(s) and shape "
# f"{list(charges.shape)}"
# )

# if list(charges.shape) != [num_atoms, charges.shape[1]]:
# raise ValueError(
# "`charges` must be a tensor with shape [n_atoms, n_channels], with "
# "`n_atoms` being the same as the variable `positions`. Got tensor with "
# f"shape {list(charges.shape)} where positions contains "
# f"{len(positions)} atoms"
# )

# if charges.dtype != dtype:
# raise ValueError(
# f"type of `charges` ({charges.dtype}) must be same as `positions` "
# f"({dtype})"
# )

# if charges.device != device:
# raise ValueError(
# f"device of `charges` ({charges.device}) must be same as `positions` "
# f"({device})"
# )

# # check shape, dtype & device of `neighbor_indices` and `neighbor_distances`
# if neighbor_indices.shape[1] != 2:
# raise ValueError(
# "neighbor_indices is expected to have shape [num_neighbors, 2]"
# f", but got {list(neighbor_indices.shape)} for one "
# "structure"
# )

# if neighbor_indices.device != device:
# raise ValueError(
# f"device of `neighbor_indices` ({neighbor_indices.device}) must be "
# f"same as `positions` ({device})"
# )

# if neighbor_distances.shape != neighbor_indices[:, 0].shape:
# raise ValueError(
# "`neighbor_indices` and `neighbor_distances` need to have shapes "
# "[num_neighbors, 2] and [num_neighbors], but got "
# f"{list(neighbor_indices.shape)} and {list(neighbor_distances.shape)}"
# )

# if neighbor_distances.device != device:
# raise ValueError(
# f"device of `neighbor_distances` ({neighbor_distances.device}) must be "
# f"same as `positions` ({device})"
# )
2 changes: 2 additions & 0 deletions src/torchpme/potentials/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .coulomb import CoulombPotential
from .inversepowerlaw import InversePowerLawPotential
from .potential import Potential
from .potential_dipole import PotentialDipole
from .spline import SplinePotential

__all__ = [
Expand All @@ -10,4 +11,5 @@
"InversePowerLawPotential",
"Potential",
"SplinePotential",
"PotentialDipole",
]
49 changes: 49 additions & 0 deletions src/torchpme/potentials/potential_dipole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Optional

import torch

from .potential import Potential


class PotentialDipole(Potential):
"""TODO: Add docstring"""

def __init__(
self,
smearing: Optional[float] = None,
exclusion_radius: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
super().__init__(smearing, exclusion_radius, dtype, device)
if dtype is None:
dtype = torch.get_default_dtype()
if device is None:
device = torch.device("cpu")

def from_dist(self, vector: torch.Tensor) -> torch.Tensor:
"""TODO: Add docstring"""
r_mag = torch.norm(vector, dim=1, keepdim=True)
scalar_potential = 1.0 / (r_mag**3)
r_outer = torch.einsum(
"bi,bj->bij", vector, vector
) # outer product shape (batch, 3, 3)
tensor_potential = (3.0 / (r_mag**5)).unsqueeze(-1) * r_outer
return scalar_potential, tensor_potential

def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
raise NotImplementedError("TODO: Implement smearing for `lr_from_dist`")

def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
raise NotImplementedError("TODO: Implement smearing for `lr_from_k_sq`")

def self_contribution(self) -> torch.Tensor:
raise NotImplementedError("TODO: Implement smearing for `self_contribution`")

def background_correction(self) -> torch.Tensor:
raise NotImplementedError(
"TODO: Implement smearing for `background_correction`"
)

self_contribution.__doc__ = Potential.self_contribution.__doc__
background_correction.__doc__ = Potential.background_correction.__doc__
26 changes: 26 additions & 0 deletions tests/test_magnetostatics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch

from torchpme.calculators import CalculatorDipole
from torchpme.potentials import PotentialDipole


def test_magnetostatics():
calculator = CalculatorDipole(
potential=PotentialDipole(),
full_neighbor_list=False,
)
dipoles = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 0.0]])
pot = calculator(
dipoles=dipoles,
cell=torch.tensor([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 4.0]]),
neighbor_indices=torch.tensor([[1, 0], [1, 2], [0, 2]]),
neighbor_vectors=torch.tensor(
[[0.0, 2.0, 0.0], [0.0, 2.0, 0.0], [0.0, 4.0, 0.0]]
),
)
result = torch.einsum("ij,ij->i", pot, dipoles).sum()
expected_result = torch.tensor(-0.2656)
assert torch.isclose(
result, expected_result, atol=1e-4
), f"Expected {expected_result}, but got {result}"
Loading