
Commit 3540a81

Merge pull request #128 from lab-cosmo/exponents
Add general integer exponents
2 parents e5b1ce8 + 0f8d7e5 commit 3540a81

9 files changed: +162 -51 lines changed

docs/src/references/changelog.rst (+2)

@@ -32,6 +32,8 @@ Added
 Fixed
 #####
 
+* Refactor the ``InversePowerLawPotential`` class to restrict the exponent to integer
+  values
 * Ensured consistency of ``dtype`` and ``device`` in the ``Potential`` and
   ``Calculator`` classes
 * Fixed consistency of ``dtype`` and ``device`` in the ``SplinePotential`` class

examples/8-combined-potential.py (+2 -2)

@@ -65,8 +65,8 @@
 # evaluation, and so one has to set it also for the combined potential, even if it is
 # not used explicitly in the evaluation of the combination.
 
-pot_1 = InversePowerLawPotential(exponent=1.0, smearing=smearing)
-pot_2 = InversePowerLawPotential(exponent=2.0, smearing=smearing)
+pot_1 = InversePowerLawPotential(exponent=1, smearing=smearing)
+pot_2 = InversePowerLawPotential(exponent=2, smearing=smearing)
 
 potential = CombinedPotential(potentials=[pot_1, pot_2], smearing=smearing)
 
src/torchpme/calculators/ewald.py (-1)

@@ -138,6 +138,5 @@ def _compute_kspace(
         charge_tot = torch.sum(charges, dim=0)
         prefac = self.potential.background_correction()
         energy -= 2 * prefac * charge_tot * ivolume
-
         # Compensate for double counting of pairs (i,j) and (j,i)
         return energy / 2

src/torchpme/lib/__init__.py (+5)

@@ -4,6 +4,7 @@
     generate_kvectors_for_mesh,
     get_ns_mesh,
 )
+from .math import CustomExp1, gamma, gammaincc_over_powerlaw, torch_exp1
 from .mesh_interpolator import MeshInterpolator
 
 __all__ = [
@@ -16,4 +17,8 @@
     "generate_kvectors_for_ewald",
     "generate_kvectors_for_mesh",
     "get_ns_mesh",
+    "gamma",
+    "CustomExp1",
+    "gammaincc_over_powerlaw",
+    "torch_exp1",
 ]
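
With these re-exports, the new helpers become part of the public torchpme.lib namespace; the updated inversepowerlaw.py module below consumes them the same way. A one-line usage sketch:

from torchpme.lib import CustomExp1, gamma, gammaincc_over_powerlaw, torch_exp1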

src/torchpme/lib/math.py (+56, new file)

@@ -0,0 +1,56 @@
+import torch
+from scipy.special import exp1
+from torch.special import gammaln
+
+
+def gamma(x: torch.Tensor) -> torch.Tensor:
+    """
+    (Complete) Gamma function.
+
+    pytorch has not implemented the commonly used (complete) Gamma function. We define
+    it in a custom way to make autograd work as in
+    https://discuss.pytorch.org/t/is-there-a-gamma-function-in-pytorch/17122
+    """
+    return torch.exp(gammaln(x))
+
+
+class CustomExp1(torch.autograd.Function):
+    """Custom exponential integral function Exp1(x) to have an autograd-compatible version."""
+
+    @staticmethod
+    def forward(ctx, input):
+        ctx.save_for_backward(input)
+        input_numpy = input.cpu().numpy() if not input.is_cpu else input.numpy()
+        return torch.tensor(exp1(input_numpy), device=input.device, dtype=input.dtype)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (input,) = ctx.saved_tensors
+        return -grad_output * torch.exp(-input) / input
+
+
+def torch_exp1(input):
+    """Wrapper for the custom exponential integral function."""
+    return CustomExp1.apply(input)
+
+
+def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
+    """Compute Gamma(nu, z) / z**nu with nu = (3 - exponent) / 2, i.e. the (non-regularized) upper incomplete Gamma function divided by a power law, for integer exponents."""
+    if exponent == 1:
+        return torch.exp(-z) / z
+    if exponent == 2:
+        return torch.sqrt(torch.pi / z) * torch.erfc(torch.sqrt(z))
+    if exponent == 3:
+        return torch_exp1(z)
+    if exponent == 4:
+        return 2 * (
+            torch.exp(-z) - torch.sqrt(torch.pi * z) * torch.erfc(torch.sqrt(z))
+        )
+    if exponent == 5:
+        return torch.exp(-z) - z * torch_exp1(z)
+    if exponent == 6:
+        return (
+            (2 - 4 * z) * torch.exp(-z)
+            + 4 * torch.sqrt(torch.pi * z**3) * torch.erfc(torch.sqrt(z))
+        ) / 3
+    raise ValueError(f"Unsupported exponent: {exponent}")

src/torchpme/potentials/inversepowerlaw.py (+12 -19)

@@ -1,20 +1,11 @@
 from typing import Optional
 
 import torch
-from torch.special import gammainc, gammaincc, gammaln
+from torch.special import gammainc
 
-from .potential import Potential
-
-
-def gamma(x: torch.Tensor) -> torch.Tensor:
-    """
-    (Complete) Gamma function.
+from torchpme.lib import gamma, gammaincc_over_powerlaw
 
-    pytorch has not implemented the commonly used (complete) Gamma function. We define
-    it in a custom way to make autograd work as in
-    https://discuss.pytorch.org/t/is-there-a-gamma-function-in-pytorch/17122
-    """
-    return torch.exp(gammaln(x))
+from .potential import Potential
 
 
 class InversePowerLawPotential(Potential):
@@ -46,16 +37,16 @@ class InversePowerLawPotential(Potential):
 
     def __init__(
         self,
-        exponent: float,
+        exponent: int,
         smearing: Optional[float] = None,
         exclusion_radius: Optional[float] = None,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
     ):
         super().__init__(smearing, exclusion_radius, dtype, device)
 
-        if exponent <= 0 or exponent > 3:
-            raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p <= 3")
+        # function call to check the validity of the exponent
+        gammaincc_over_powerlaw(exponent, torch.tensor(1.0, dtype=dtype, device=device))
         self.register_buffer(
             "exponent", torch.tensor(exponent, dtype=self.dtype, device=self.device)
         )
@@ -130,9 +121,7 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
         # for consistency reasons.
         masked = torch.where(x == 0, 1.0, x)  # avoid NaNs in backwards, see Coulomb
         return torch.where(
-            k_sq == 0,
-            0.0,
-            prefac * gammaincc(peff, masked) / masked**peff * gamma(peff),
+            k_sq == 0, 0.0, prefac * gammaincc_over_powerlaw(exponent, masked)
         )
 
     def self_contribution(self) -> torch.Tensor:
@@ -145,7 +134,11 @@ def self_contribution(self) -> torch.Tensor:
         return 1 / gamma(phalf + 1) / (2 * self.smearing**2) ** phalf
 
     def background_correction(self) -> torch.Tensor:
-        # "charge neutrality" correction for 1/r^p potential
+        # the "charge neutrality" correction for the 1/r^p potential diverges for
+        # exponent p = 3 and is not needed for p > 3, so we set it to zero (see the
+        # SI of https://doi.org/10.48550/arXiv.2412.03281)
+        if self.exponent >= 3:
+            return torch.tensor(0.0, dtype=self.dtype, device=self.device)
         if self.smearing is None:
             raise ValueError(
                 "Cannot compute background correction without specifying `smearing`."

tests/calculators/test_values_ewald.py (+3 -3)

@@ -100,7 +100,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
     lr_wavelength = 0.5 * smearing
     calc = EwaldCalculator(
         InversePowerLawPotential(
-            exponent=1.0,
+            exponent=1,
             smearing=smearing,
         ),
         lr_wavelength=lr_wavelength,
@@ -111,7 +111,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name):
     smearing = sr_cutoff / 5.0
     calc = PMECalculator(
         InversePowerLawPotential(
-            exponent=1.0,
+            exponent=1,
             smearing=smearing,
         ),
         mesh_spacing=smearing / 8,
@@ -198,7 +198,7 @@ def test_wigner(crystal_name, scaling_factor):
     # Compute potential and compare against reference
     calc = EwaldCalculator(
         InversePowerLawPotential(
-            exponent=1.0,
+            exponent=1,
             smearing=smeareff,
         ),
         lr_wavelength=smeareff / 2,

tests/lib/test_math.py (+25, new file)

@@ -0,0 +1,25 @@
+import numpy as np
+import torch
+from scipy.special import exp1
+
+from torchpme.lib import torch_exp1
+
+
+def finite_difference_derivative(func, x, h=1e-5):
+    return (func(x + h) - func(x - h)) / (2 * h)
+
+
+def test_torch_exp1_consistency_with_scipy():
+    x = torch.rand(1000, dtype=torch.float64)
+    torch_result = torch_exp1(x)
+    scipy_result = exp1(x.numpy())
+    assert np.allclose(torch_result.numpy(), scipy_result, atol=1e-6)
+
+
+def test_torch_exp1_derivative():
+    x = torch.rand(1, dtype=torch.float64, requires_grad=True)
+    torch_result = torch_exp1(x)
+    torch_result.backward()
+    torch_exp1_prime = x.grad
+    finite_diff_result = finite_difference_derivative(exp1, x.detach().numpy())
+    assert np.allclose(torch_exp1_prime.numpy(), finite_diff_result, atol=1e-6)
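
The derivative test above compares a single random point against a central finite difference. An equivalent, stricter check (a sketch, not part of the commit) is torch.autograd.gradcheck, which validates CustomExp1's analytic backward against finite differences at many float64 points:

import torch

from torchpme.lib import torch_exp1

# shift points away from x = 0, where Exp1 diverges and finite
# differences become unstable
x = (torch.rand(10, dtype=torch.float64) + 0.1).requires_grad_()
assert torch.autograd.gradcheck(torch_exp1, (x,))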
