Skip to content

Commit e5b1ce8

Browse files
authored
Add Explicit dtype and device Support for Calculators and Ensure Compatibility with Potentials (#143)
* Refactor parameter handling in calculators and potentials for improved dtype and device management * Updated docstrings and changelog, added an assertion to check for an instance of the potential, and resolved the TorchScript Potential/Calculator incompatibility. * Update changelog and add test for potential and calculator compatibility
1 parent 04edb22 commit e5b1ce8

File tree

16 files changed

+140
-79
lines changed

16 files changed

+140
-79
lines changed

docs/extensions/versions_list.py

-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def run(self):
4343
:margin: 0 0 0 0\n"""
4444

4545
for group_i, (version_short, group) in enumerate(grouped_versions.items()):
46-
4746
if group_i < 3:
4847
generated_content += f"""
4948
.. grid-item::

docs/src/references/changelog.rst

+7
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,16 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
2424
`Unreleased <https://github.com/lab-cosmo/torch-pme/>`_
2525
-------------------------------------------------------
2626

27+
Added
28+
#####
29+
30+
* Added ``dtype`` and ``device`` for ``Calculator`` classses
31+
2732
Fixed
2833
#####
2934

35+
* Ensured consistency of ``dtype`` and ``device`` in the ``Potential`` and
36+
``Calculator`` classses
3037
* Fixed consistency of ``dtype`` and ``device`` in the ``SplinePotential`` class
3138
* Fix inconsistent ``cutoff`` in neighbor list example
3239
* All calculators now check if the cell is zero if the potential is range-separated

examples/5-autograd-demo.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -463,16 +463,16 @@ def forward(self, positions, cell, charges):
463463

464464
print(
465465
f"""
466-
Delta-Value: {value-jit_value}
466+
Delta-Value: {value - jit_value}
467467
468468
Delta-Position gradients:
469-
{positions.grad.T-jit_positions.grad.T}
469+
{positions.grad.T - jit_positions.grad.T}
470470
471471
Delta-Cell gradients:
472-
{cell.grad-jit_cell.grad}
472+
{cell.grad - jit_cell.grad}
473473
474474
Delta-Charges gradients:
475-
{charges.grad.T-jit_charges.grad.T}
475+
{charges.grad.T - jit_charges.grad.T}
476476
"""
477477
)
478478

src/torchpme/calculators/calculator.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,37 @@ class Calculator(torch.nn.Module):
2626
will come from a full (True) or half (False, default) neighbor list.
2727
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
2828
common values.
29+
:param dtype: type used for the internal buffers and parameters
30+
:param device: device used for the internal buffers and parameters
2931
"""
3032

3133
def __init__(
3234
self,
3335
potential: Potential,
3436
full_neighbor_list: bool = False,
3537
prefactor: float = 1.0,
38+
dtype: Optional[torch.dtype] = None,
39+
device: Optional[torch.device] = None,
3640
):
3741
super().__init__()
38-
# TorchScript requires to initialize all attributes in __init__
39-
self._device = torch.device("cpu")
40-
self._dtype = torch.float32
4142

43+
assert isinstance(potential, Potential), (
44+
f"Potential must be an instance of Potential, got {type(potential)}"
45+
)
46+
47+
self.device = "cpu" if device is None else device
48+
self.dtype = torch.get_default_dtype() if dtype is None else dtype
4249
self.potential = potential
4350

51+
assert self.dtype == self.potential.dtype, (
52+
f"Potential and Calculator must have the same dtype, got {self.dtype} and "
53+
f"{self.potential.dtype}"
54+
)
55+
assert self.device == self.potential.device, (
56+
f"Potential and Calculator must have the same device, got {self.device} and "
57+
f"{self.potential.device}"
58+
)
59+
4460
self.full_neighbor_list = full_neighbor_list
4561

4662
self.prefactor = prefactor

src/torchpme/calculators/ewald.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import torch
24

35
from ..lib import generate_kvectors_for_ewald
@@ -53,6 +55,8 @@ class EwaldCalculator(Calculator):
5355
:obj:`False`, a "half" neighbor list is expected.
5456
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
5557
common values.
58+
:param dtype: type used for the internal buffers and parameters
59+
:param device: device used for the internal buffers and parameters
5660
"""
5761

5862
def __init__(
@@ -61,11 +65,15 @@ def __init__(
6165
lr_wavelength: float,
6266
full_neighbor_list: bool = False,
6367
prefactor: float = 1.0,
68+
dtype: Optional[torch.dtype] = None,
69+
device: Optional[torch.device] = None,
6470
):
6571
super().__init__(
6672
potential=potential,
6773
full_neighbor_list=full_neighbor_list,
6874
prefactor=prefactor,
75+
dtype=dtype,
76+
device=device,
6977
)
7078
if potential.smearing is None:
7179
raise ValueError(

src/torchpme/calculators/p3m.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import torch
24

35
from ..lib.kspace_filter import P3MKSpaceFilter
@@ -40,6 +42,8 @@ class P3MCalculator(PMECalculator):
4042
set to :py:obj:`False`, a "half" neighbor list is expected.
4143
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
4244
common values.
45+
:param dtype: type used for the internal buffers and parameters
46+
:param device: device used for the internal buffers and parameters
4347
4448
For an **example** on the usage for any calculator refer to :ref:`userdoc-how-to`.
4549
"""
@@ -51,6 +55,8 @@ def __init__(
5155
interpolation_nodes: int = 4,
5256
full_neighbor_list: bool = False,
5357
prefactor: float = 1.0,
58+
dtype: Optional[torch.dtype] = None,
59+
device: Optional[torch.device] = None,
5460
):
5561
self.mesh_spacing: float = mesh_spacing
5662

@@ -62,6 +68,8 @@ def __init__(
6268
potential=potential,
6369
full_neighbor_list=full_neighbor_list,
6470
prefactor=prefactor,
71+
dtype=dtype,
72+
device=device,
6573
)
6674

6775
self.kspace_filter: P3MKSpaceFilter = P3MKSpaceFilter(

src/torchpme/calculators/pme.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import torch
24
from torch import profiler
35

@@ -45,6 +47,8 @@ class PMECalculator(Calculator):
4547
set to :obj:`False`, a "half" neighbor list is expected.
4648
:param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
4749
common values.
50+
:param dtype: type used for the internal buffers and parameters
51+
:param device: device used for the internal buffers and parameters
4852
"""
4953

5054
def __init__(
@@ -54,11 +58,15 @@ def __init__(
5458
interpolation_nodes: int = 4,
5559
full_neighbor_list: bool = False,
5660
prefactor: float = 1.0,
61+
dtype: Optional[torch.dtype] = None,
62+
device: Optional[torch.device] = None,
5763
):
5864
super().__init__(
5965
potential=potential,
6066
full_neighbor_list=full_neighbor_list,
6167
prefactor=prefactor,
68+
dtype=dtype,
69+
device=device,
6270
)
6371

6472
if potential.smearing is None:
@@ -69,8 +77,8 @@ def __init__(
6977
self.mesh_spacing: float = mesh_spacing
7078

7179
self.kspace_filter: KSpaceFilter = KSpaceFilter(
72-
cell=torch.eye(3),
73-
ns_mesh=torch.ones(3, dtype=int),
80+
cell=torch.eye(3, dtype=self.dtype, device=self.device),
81+
ns_mesh=torch.ones(3, dtype=int, device=self.device),
7482
kernel=self.potential,
7583
fft_norm="backward",
7684
ifft_norm="forward",
@@ -79,8 +87,8 @@ def __init__(
7987
self.interpolation_nodes: int = interpolation_nodes
8088

8189
self.mesh_interpolator: MeshInterpolator = MeshInterpolator(
82-
cell=torch.eye(3),
83-
ns_mesh=torch.ones(3, dtype=int),
90+
cell=torch.eye(3, dtype=self.dtype, device=self.device),
91+
ns_mesh=torch.ones(3, dtype=int, device=self.device),
8492
interpolation_nodes=self.interpolation_nodes,
8593
method="Lagrange", # convention for classic PME
8694
)

src/torchpme/lib/mesh_interpolator.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,7 @@ def mesh_to_points(self, mesh_vals: torch.Tensor) -> torch.Tensor:
432432
"""
433433
if mesh_vals.dim() != 4:
434434
raise ValueError(
435-
f"`mesh_vals` of dimension {mesh_vals.dim()} has to be of "
436-
"dimension 4"
435+
f"`mesh_vals` of dimension {mesh_vals.dim()} has to be of dimension 4"
437436
)
438437

439438
return (

src/torchpme/potentials/combined.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,7 @@ def __init__(
4747
dtype=dtype,
4848
device=device,
4949
)
50-
if dtype is None:
51-
dtype = torch.get_default_dtype()
52-
if device is None:
53-
device = torch.device("cpu")
50+
5451
smearings = [pot.smearing for pot in potentials]
5552
if not all(smearings) and any(smearings):
5653
raise ValueError(
@@ -76,7 +73,9 @@ def __init__(
7673
"The number of initial weights must match the number of potentials being combined"
7774
)
7875
else:
79-
initial_weights = torch.ones(len(potentials), dtype=dtype, device=device)
76+
initial_weights = torch.ones(
77+
len(potentials), dtype=self.dtype, device=self.device
78+
)
8079
# for torchscript
8180
self.potentials = torch.nn.ModuleList(potentials)
8281
if learnable_weights:

src/torchpme/potentials/coulomb.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,17 @@ def __init__(
3838
device: Optional[torch.device] = None,
3939
):
4040
super().__init__(smearing, exclusion_radius, dtype, device)
41-
if dtype is None:
42-
dtype = torch.get_default_dtype()
43-
if device is None:
44-
device = torch.device("cpu")
4541

4642
# constants used in the forwward
4743
self.register_buffer(
4844
"_rsqrt2",
49-
torch.rsqrt(torch.tensor(2.0, dtype=dtype, device=device)),
45+
torch.rsqrt(torch.tensor(2.0, dtype=self.dtype, device=self.device)),
5046
)
5147
self.register_buffer(
5248
"_sqrt_2_on_pi",
53-
torch.sqrt(torch.tensor(2.0 / torch.pi, dtype=dtype, device=device)),
49+
torch.sqrt(
50+
torch.tensor(2.0 / torch.pi, dtype=self.dtype, device=self.device)
51+
),
5452
)
5553

5654
def from_dist(self, dist: torch.Tensor) -> torch.Tensor:

src/torchpme/potentials/inversepowerlaw.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,11 @@ def __init__(
5353
device: Optional[torch.device] = None,
5454
):
5555
super().__init__(smearing, exclusion_radius, dtype, device)
56-
if dtype is None:
57-
dtype = torch.get_default_dtype()
58-
if device is None:
59-
device = torch.device("cpu")
6056

6157
if exponent <= 0 or exponent > 3:
6258
raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p <= 3")
6359
self.register_buffer(
64-
"exponent", torch.tensor(exponent, dtype=dtype, device=device)
60+
"exponent", torch.tensor(exponent, dtype=self.dtype, device=self.device)
6561
)
6662

6763
@torch.jit.export

src/torchpme/potentials/potential.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,18 @@ def __init__(
4242
device: Optional[torch.device] = None,
4343
):
4444
super().__init__()
45-
if dtype is None:
46-
dtype = torch.get_default_dtype()
47-
if device is None:
48-
device = torch.device("cpu")
45+
self.dtype = torch.get_default_dtype() if dtype is None else dtype
46+
self.device = "cpu" if device is None else device
4947
if smearing is not None:
5048
self.register_buffer(
51-
"smearing", torch.tensor(smearing, device=device, dtype=dtype)
49+
"smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype)
5250
)
5351
else:
5452
self.smearing = None
5553
if exclusion_radius is not None:
5654
self.register_buffer(
5755
"exclusion_radius",
58-
torch.tensor(exclusion_radius, device=device, dtype=dtype),
56+
torch.tensor(exclusion_radius, device=self.device, dtype=self.dtype),
5957
)
6058
else:
6159
self.exclusion_radius = None

src/torchpme/potentials/spline.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,12 @@ def __init__(
6666
dtype=dtype,
6767
device=device,
6868
)
69-
if dtype is None:
70-
dtype = torch.get_default_dtype()
71-
if device is None:
72-
device = torch.device("cpu")
7369

7470
if len(y_grid) != len(r_grid):
7571
raise ValueError("Length of radial grid and value array mismatch.")
7672

77-
r_grid = r_grid.to(dtype=dtype, device=device)
78-
y_grid = y_grid.to(dtype=dtype, device=device)
73+
r_grid = r_grid.to(dtype=self.dtype, device=self.device)
74+
y_grid = y_grid.to(dtype=self.dtype, device=self.device)
7975

8076
if reciprocal:
8177
if torch.min(r_grid) <= 0.0:
@@ -93,7 +89,7 @@ def __init__(
9389
else:
9490
k_grid = r_grid.clone()
9591
else:
96-
k_grid = k_grid.to(dtype=dtype, device=device)
92+
k_grid = k_grid.to(dtype=self.dtype, device=self.device)
9793

9894
if yhat_grid is None:
9995
# computes automatically!
@@ -104,7 +100,7 @@ def __init__(
104100
compute_second_derivatives(r_grid, y_grid),
105101
)
106102
else:
107-
yhat_grid = yhat_grid.to(dtype=dtype, device=device)
103+
yhat_grid = yhat_grid.to(dtype=self.dtype, device=self.device)
108104

109105
# the function is defined for k**2, so we define the grid accordingly
110106
if reciprocal:
@@ -115,13 +111,15 @@ def __init__(
115111
self._krn_spline = CubicSpline(k_grid**2, yhat_grid)
116112

117113
if y_at_zero is None:
118-
self._y_at_zero = self._spline(torch.zeros(1, dtype=dtype, device=device))
114+
self._y_at_zero = self._spline(
115+
torch.zeros(1, dtype=self.dtype, device=self.device)
116+
)
119117
else:
120118
self._y_at_zero = y_at_zero
121119

122120
if yhat_at_zero is None:
123121
self._yhat_at_zero = self._krn_spline(
124-
torch.zeros(1, dtype=dtype, device=device)
122+
torch.zeros(1, dtype=self.dtype, device=self.device)
125123
)
126124
else:
127125
self._yhat_at_zero = yhat_at_zero

0 commit comments

Comments
 (0)