@@ -1,20 +1,11 @@
 from typing import Optional
 
 import torch
-from torch.special import gammainc, gammaincc, gammaln
+from torch.special import gammainc
 
-from .potential import Potential
-
-
-def gamma(x: torch.Tensor) -> torch.Tensor:
-    """
-    (Complete) Gamma function.
+from torchpme.lib import gamma, gammaincc_over_powerlaw
 
-    pytorch has not implemented the commonly used (complete) Gamma function. We define
-    it in a custom way to make autograd work as in
-    https://discuss.pytorch.org/t/is-there-a-gamma-function-in-pytorch/17122
-    """
-    return torch.exp(gammaln(x))
+from .potential import Potential
 
 
 class InversePowerLawPotential(Potential):
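Context for the import shuffle: the module-level `gamma` helper deleted above now ships from `torchpme.lib`, alongside the new `gammaincc_over_powerlaw`. For reference, a minimal sketch of an equivalent autograd-friendly Gamma, mirroring the deleted helper (illustrative only, not necessarily the library's exact internals):

import torch

def gamma(x: torch.Tensor) -> torch.Tensor:
    # Complete Gamma function. torch has no direct Gamma, so compute
    # exp(ln Γ(x)); gammaln is differentiable, which keeps autograd
    # working (the same trick as the helper removed in this diff).
    return torch.exp(torch.special.gammaln(x))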
@@ -46,16 +37,16 @@ class InversePowerLawPotential(Potential):
 
     def __init__(
         self,
-        exponent: float,
+        exponent: int,
         smearing: Optional[float] = None,
         exclusion_radius: Optional[float] = None,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
     ):
         super().__init__(smearing, exclusion_radius, dtype, device)
 
-        if exponent <= 0 or exponent > 3:
-            raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p <= 3")
+        # function call to check the validity of the exponent
+        gammaincc_over_powerlaw(exponent, torch.tensor(1.0, dtype=dtype, device=device))
 
         self.register_buffer(
             "exponent", torch.tensor(exponent, dtype=self.dtype, device=self.device)
         )
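With the explicit `0 < p <= 3` range check gone, construction-time validation is delegated to `gammaincc_over_powerlaw`, which raises for exponents it cannot handle. A hypothetical usage sketch (the top-level import path is an assumption; adjust it to match the package):

from torchpme import InversePowerLawPotential  # assumed import path

# p = 2 is valid before and after this change; unsupported integer exponents
# now fail inside gammaincc_over_powerlaw instead of an explicit range check.
pot = InversePowerLawPotential(exponent=2, smearing=1.0)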
@@ -130,9 +121,7 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
         # for consistency reasons.
         masked = torch.where(x == 0, 1.0, x)  # avoid NaNs in backwards, see Coulomb
         return torch.where(
-            k_sq == 0,
-            0.0,
-            prefac * gammaincc(peff, masked) / masked**peff * gamma(peff),
+            k_sq == 0, 0.0, prefac * gammaincc_over_powerlaw(exponent, masked)
         )
 
     def self_contribution(self) -> torch.Tensor:
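The collapsed `torch.where` leans on the identity Γ(s)·Q(s, x) = Γ(s, x): the old product of the regularized `gammaincc`, the power law, and `gamma` is the un-regularized upper incomplete gamma divided by x**s, which the helper bundles into one call. A quick numerical spot check, assuming `peff = (3 - exponent) / 2` as defined earlier in `lr_from_k_sq` (not shown in this hunk):

import torch
from torch.special import gammaincc, gammaln
from torchpme.lib import gammaincc_over_powerlaw

exponent = 1
x = torch.tensor([0.5, 1.0, 2.0], dtype=torch.float64)
peff = torch.tensor((3 - exponent) / 2, dtype=torch.float64)

# Old expression: regularized Q(peff, x), un-regularized by Γ(peff), over x**peff.
old = gammaincc(peff, x) / x**peff * torch.exp(gammaln(peff))

# The new helper is expected to bundle the same quantity (assumption based on
# this diff, which swaps one expression for the other with no extra factors).
torch.testing.assert_close(old, gammaincc_over_powerlaw(exponent, x))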
@@ -145,7 +134,11 @@ def self_contribution(self) -> torch.Tensor:
         return 1 / gamma(phalf + 1) / (2 * self.smearing**2) ** phalf
 
     def background_correction(self) -> torch.Tensor:
-        # "charge neutrality" correction for 1/r^p potential
+        # The "charge neutrality" correction for the 1/r^p potential diverges
+        # for exponent p = 3 and is not needed for p > 3, so we set it to zero
+        # (see the SI of https://doi.org/10.48550/arXiv.2412.03281)
+        if self.exponent >= 3:
+            return torch.tensor(0.0, dtype=self.dtype, device=self.device)
         if self.smearing is None:
             raise ValueError(
                 "Cannot compute background correction without specifying `smearing`."