From d67ff35fa3533657c33e5cf5f6a08bf0fb2a1edc Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Mon, 16 Dec 2024 17:04:11 +0100 Subject: [PATCH 1/2] Add general integer exponents up to 6 --- src/torchpme/potentials/inversepowerlaw.py | 29 +++++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py index bd44236e..eeffba82 100644 --- a/src/torchpme/potentials/inversepowerlaw.py +++ b/src/torchpme/potentials/inversepowerlaw.py @@ -1,7 +1,7 @@ from typing import Optional import torch -from torch.special import gammainc, gammaincc, gammaln +from torch.special import gammaln from .potential import Potential @@ -17,6 +17,27 @@ def gamma(x: torch.Tensor) -> torch.Tensor: return torch.exp(gammaln(x)) +# Auxilary function for stable Fourier transform implementation +def gammainc_upper_over_powerlaw(exponent, zz): + if exponent == 1: + return torch.exp(-zz) / zz + if exponent == 2: + return torch.sqrt(torch.pi / zz) * torch.erfc(torch.sqrt(zz)) + if exponent == 3: + return -torch.expi(-zz) + if exponent == 4: + return 2 * ( + torch.exp(-zz) - torch.sqrt(torch.pi * zz) * torch.erfc(torch.sqrt(zz)) + ) + if exponent == 5: + return torch.exp(-zz) + zz * torch.expi(-zz) + if exponent == 6: + return ( + (2 - 4 * zz) * torch.exp(-zz) + + 4 * torch.sqrt(torch.pi) * zz**1.5 * torch.erfc(torch.sqrt(zz)) + ) / 3 + + class InversePowerLawPotential(Potential): """ Inverse power-law potentials of the form :math:`1/r^p`. @@ -46,7 +67,7 @@ class InversePowerLawPotential(Potential): def __init__( self, - exponent: float, + exponent: int, smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, dtype: Optional[torch.dtype] = None, @@ -103,7 +124,7 @@ def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor: x = 0.5 * dist**2 / smearing**2 peff = exponent / 2 prefac = 1.0 / (2 * smearing**2) ** peff - return prefac * gammainc(peff, x) / x**peff + return self.from_dist(dist) - prefac * gammainc_upper_over_powerlaw(exponent, x) @torch.jit.export def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor: @@ -136,7 +157,7 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor: return torch.where( k_sq == 0, 0.0, - prefac * gammaincc(peff, masked) / masked**peff * gamma(peff), + prefac * gammainc_upper_over_powerlaw(exponent, masked), ) def self_contribution(self) -> torch.Tensor: From ad9f1949af0fa495d812ee1d7b47f81f5cc777f6 Mon Sep 17 00:00:00 2001 From: E-Rum Date: Fri, 20 Dec 2024 12:10:47 +0000 Subject: [PATCH 2/2] Updated helper function for new Vesin version, started to change tests --- src/torchpme/potentials/inversepowerlaw.py | 7 +++++-- tests/calculators/test_values_ewald.py | 6 +++--- tests/helpers.py | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py index eeffba82..88e82e13 100644 --- a/src/torchpme/potentials/inversepowerlaw.py +++ b/src/torchpme/potentials/inversepowerlaw.py @@ -19,6 +19,9 @@ def gamma(x: torch.Tensor) -> torch.Tensor: # Auxilary function for stable Fourier transform implementation def gammainc_upper_over_powerlaw(exponent, zz): + if exponent not in [1, 2, 3, 4, 5, 6]: + raise ValueError(f"Unsupported exponent: {exponent}") + if exponent == 1: return torch.exp(-zz) / zz if exponent == 2: @@ -79,8 +82,8 @@ def __init__( if device is None: device = torch.device("cpu") - if exponent <= 0 or exponent > 3: - raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p <= 3") + # function call to check the validity of the exponent + gammainc_upper_over_powerlaw(exponent, torch.tensor(1.0, dtype=dtype, device=device)) self.register_buffer( "exponent", torch.tensor(exponent, dtype=dtype, device=device) ) diff --git a/tests/calculators/test_values_ewald.py b/tests/calculators/test_values_ewald.py index 208d937d..6b405011 100644 --- a/tests/calculators/test_values_ewald.py +++ b/tests/calculators/test_values_ewald.py @@ -100,7 +100,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name): lr_wavelength = 0.5 * smearing calc = EwaldCalculator( InversePowerLawPotential( - exponent=1.0, + exponent=1, smearing=smearing, ), lr_wavelength=lr_wavelength, @@ -111,7 +111,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name): smearing = sr_cutoff / 5.0 calc = PMECalculator( InversePowerLawPotential( - exponent=1.0, + exponent=1, smearing=smearing, ), mesh_spacing=smearing / 8, @@ -198,7 +198,7 @@ def test_wigner(crystal_name, scaling_factor): # Compute potential and compare against reference calc = EwaldCalculator( InversePowerLawPotential( - exponent=1.0, + exponent=1, smearing=smeareff, ), lr_wavelength=smeareff / 2, diff --git a/tests/helpers.py b/tests/helpers.py index a4d14a86..6322f0ee 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -257,7 +257,7 @@ def neighbor_list( nl = NeighborList(cutoff=cutoff, full_list=full_neighbor_list) neighbor_indices, d, S = nl.compute( - points=positions, box=box, periodic=periodic, quantities="PdS" + points=positions, box=box, periodic=periodic, quantities="pdS" ) neighbor_indices = torch.from_numpy(neighbor_indices.astype(int)).to(