Skip to content

Commit 99421a3

Browse files
committed
Moved mathematical functions from inversepowerlaw.py to a separate math.py script
1 parent f077ba6 commit 99421a3

File tree

4 files changed

+65
-56
lines changed

4 files changed

+65
-56
lines changed

src/torchpme/lib/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
generate_kvectors_for_mesh,
55
get_ns_mesh,
66
)
7+
from .math import CustomExp1, gamma, gammaincc_over_powerlaw, torch_exp1
78
from .mesh_interpolator import MeshInterpolator
89

910
__all__ = [
@@ -16,4 +17,8 @@
1617
"generate_kvectors_for_ewald",
1718
"generate_kvectors_for_mesh",
1819
"get_ns_mesh",
20+
"gamma",
21+
"CustomExp1",
22+
"gammaincc_over_powerlaw",
23+
"torch_exp1",
1924
]

src/torchpme/lib/math.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import torch
2+
from scipy.special import exp1
3+
from torch.special import gammaln
4+
5+
6+
def gamma(x: torch.Tensor) -> torch.Tensor:
7+
"""
8+
(Complete) Gamma function.
9+
10+
pytorch has not implemented the commonly used (complete) Gamma function. We define
11+
it in a custom way to make autograd work as in
12+
https://discuss.pytorch.org/t/is-there-a-gamma-function-in-pytorch/17122
13+
"""
14+
return torch.exp(gammaln(x))
15+
16+
17+
class CustomExp1(torch.autograd.Function):
18+
"""Custom exponential integral function Exp1(x) to have an autograd-compatible version."""
19+
20+
@staticmethod
21+
def forward(ctx, input):
22+
ctx.save_for_backward(input)
23+
input_numpy = input.cpu().numpy() if not input.is_cpu else input.numpy()
24+
return torch.tensor(exp1(input_numpy), device=input.device, dtype=input.dtype)
25+
26+
@staticmethod
27+
def backward(ctx, grad_output):
28+
(input,) = ctx.saved_tensors
29+
return -grad_output * torch.exp(-input) / input
30+
31+
32+
def torch_exp1(input):
33+
"""Wrapper for the custom exponential integral function."""
34+
return CustomExp1.apply(input)
35+
36+
37+
def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
38+
"""Function to compute the regularized incomplete gamma function complement for integer exponents."""
39+
if exponent == 1:
40+
return torch.exp(-z) / z
41+
if exponent == 2:
42+
return torch.sqrt(torch.pi / z) * torch.erfc(torch.sqrt(z))
43+
if exponent == 3:
44+
return torch_exp1(z)
45+
if exponent == 4:
46+
return 2 * (
47+
torch.exp(-z) - torch.sqrt(torch.pi * z) * torch.erfc(torch.sqrt(z))
48+
)
49+
if exponent == 5:
50+
return torch.exp(-z) - z * torch_exp1(z)
51+
if exponent == 6:
52+
return (
53+
(2 - 4 * z) * torch.exp(-z)
54+
+ 4 * torch.sqrt(torch.pi * z**3) * torch.erfc(torch.sqrt(z))
55+
) / 3
56+
raise ValueError(f"Unsupported exponent: {exponent}")

src/torchpme/potentials/inversepowerlaw.py

Lines changed: 3 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,11 @@
11
from typing import Optional
22

33
import torch
4-
from scipy.special import exp1
5-
from torch.special import gammainc, gammaln
4+
from torch.special import gammainc
65

7-
from .potential import Potential
8-
9-
10-
def gamma(x: torch.Tensor) -> torch.Tensor:
11-
"""
12-
(Complete) Gamma function.
13-
14-
pytorch has not implemented the commonly used (complete) Gamma function. We define
15-
it in a custom way to make autograd work as in
16-
https://discuss.pytorch.org/t/is-there-a-gamma-function-in-pytorch/17122
17-
"""
18-
return torch.exp(gammaln(x))
19-
20-
21-
class CustomExp1(torch.autograd.Function):
22-
"""Custom exponential integral function Exp1(x) to have an autograd-compatible version."""
6+
from torchpme.lib import gamma, gammaincc_over_powerlaw
237

24-
@staticmethod
25-
def forward(ctx, input):
26-
ctx.save_for_backward(input)
27-
input_numpy = input.cpu().numpy() if not input.is_cpu else input.numpy()
28-
return torch.tensor(exp1(input_numpy), device=input.device, dtype=input.dtype)
29-
30-
@staticmethod
31-
def backward(ctx, grad_output):
32-
(input,) = ctx.saved_tensors
33-
return -grad_output * torch.exp(-input) / input
34-
35-
36-
def torch_exp1(input):
37-
"""Wrapper for the custom exponential integral function."""
38-
return CustomExp1.apply(input)
39-
40-
41-
def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
42-
"""Function to compute the regularized incomplete gamma function complement for integer exponents."""
43-
if exponent == 1:
44-
return torch.exp(-z) / z
45-
if exponent == 2:
46-
return torch.sqrt(torch.pi / z) * torch.erfc(torch.sqrt(z))
47-
if exponent == 3:
48-
return torch_exp1(z)
49-
if exponent == 4:
50-
return 2 * (
51-
torch.exp(-z) - torch.sqrt(torch.pi * z) * torch.erfc(torch.sqrt(z))
52-
)
53-
if exponent == 5:
54-
return torch.exp(-z) - z * torch_exp1(z)
55-
if exponent == 6:
56-
return (
57-
(2 - 4 * z) * torch.exp(-z)
58-
+ 4 * torch.sqrt(torch.pi * z**3) * torch.erfc(torch.sqrt(z))
59-
) / 3
60-
raise ValueError(f"Unsupported exponent: {exponent}")
8+
from .potential import Potential
619

6210

6311
class InversePowerLawPotential(Potential):

tests/utils/test_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
from scipy.special import exp1
44

5-
from torchpme.potentials.inversepowerlaw import torch_exp1
5+
from torchpme.lib import torch_exp1
66

77

88
def finite_difference_derivative(func, x, h=1e-5):

0 commit comments

Comments
 (0)