Skip to content

Commit 6e9fc68

Browse files
committed
Implement Fitsnap interface for SNAP and ACE
1 parent ab5021d commit 6e9fc68

File tree

3 files changed

+339
-0
lines changed

3 files changed

+339
-0
lines changed

structuretoolkit/analyse/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@
2929
get_snap_descriptor_derivatives,
3030
)
3131

32+
try:
33+
from structuretoolkit.analyse.fitsnap import (
34+
get_snap_descriptor_derivatives as get_snap_descriptor_derivatives_fitsnap,
35+
get_ace_descriptor_derivatives as get_ace_descriptor_derivatives_fitsnap,
36+
get_ace_descriptor_derivatives,
37+
)
38+
except ImportError:
39+
pass
40+
3241

3342
def get_symmetry(
3443
structure, use_magmoms=False, use_elements=True, symprec=1e-5, angle_tolerance=-1.0

structuretoolkit/analyse/fitsnap.py

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
from ase.atoms import Atoms
2+
from typing import Optional, Union
3+
import random
4+
import numpy as np
5+
from fitsnap3lib.fitsnap import FitSnap
6+
7+
8+
def get_ace_descriptor_derivatives(
9+
structure: Atoms,
10+
atom_types: list[str],
11+
ranks: list[int] = [1, 2, 3, 4, 5, 6],
12+
lmax: list[int] = [1, 2, 2, 2, 1, 1],
13+
nmax: list[int] = [22, 2, 2, 2, 1, 1],
14+
nmaxbase: int = 22,
15+
rcutfac: float = 4.604694451,
16+
lambda_value: float = 3.059235105,
17+
lmin: list[int] = [1, 1, 1, 1, 1, 1],
18+
bzeroflag: bool = True,
19+
cutoff: float = 10.0,
20+
) -> np.ndarray:
21+
"""
22+
Calculate per atom ACE descriptors using FitSNAP https://fitsnap.github.io
23+
24+
Args:
25+
structure (ase.atoms.Atoms): atomistic structure as ASE atoms object
26+
atom_types (list[str]): list of element types
27+
ranks (list):
28+
lmax (list):
29+
nmax (list):
30+
nmaxbase (int):
31+
rcutfac (float):
32+
lambda_value (float):
33+
lmin (list):
34+
cutoff (float): cutoff radius for the construction of the neighbor list
35+
36+
Returns:
37+
np.ndarray: Numpy array with the calculated descriptor derivatives
38+
"""
39+
settings = {
40+
"ACE": {
41+
"numTypes": len(atom_types),
42+
"ranks": " ".join([str(r) for r in ranks]),
43+
"lmax": " ".join([str(l) for l in lmax]),
44+
"nmax": " ".join([str(n) for n in nmax]),
45+
"nmaxbase": nmaxbase,
46+
"rcutfac": rcutfac,
47+
"lambda": lambda_value,
48+
"type": " ".join(atom_types),
49+
"lmin": " ".join([str(l) for l in lmin]),
50+
"bzeroflag": True,
51+
"bikflag": True,
52+
},
53+
"CALCULATOR": {
54+
"calculator": "LAMMPSPACE",
55+
"energy": 1,
56+
"force": 1,
57+
"stress": 0,
58+
},
59+
"REFERENCE": {
60+
"units": "metal",
61+
"atom_style": "atomic",
62+
"pair_style": "zero " + str(cutoff),
63+
"pair_coeff": "* *",
64+
},
65+
}
66+
fs = FitSnap(settings, comm=None, arglist=["--overwrite"])
67+
a, b, w = fs.calculator.process_single(_ase_scraper(data=[structure])[0])
68+
return a
69+
70+
71+
def get_snap_descriptor_derivatives(
72+
structure: Atoms,
73+
atom_types: list[str],
74+
twojmax: int = 6,
75+
element_radius: list[int] = [4.0],
76+
rcutfac: float = 1.0,
77+
rfac0: float = 0.99363,
78+
rmin0: float = 0.0,
79+
bzeroflag: bool = False,
80+
quadraticflag: bool = False,
81+
weights: Optional[Union[list, np.ndarray]] = None,
82+
cutoff: float = 10.0,
83+
) -> np.ndarray:
84+
"""
85+
Calculate per atom SNAP descriptors using FitSNAP https://fitsnap.github.io
86+
87+
Args:
88+
structure (ase.atoms.Atoms): atomistic structure as ASE atoms object
89+
atom_types (list[str]): list of element types
90+
twojmax (int): band limit for bispectrum components (non-negative integer)
91+
element_radius (list[int]): list of radii for the individual elements
92+
rcutfac (float): scale factor applied to all cutoff radii (positive real)
93+
rfac0 (float): parameter in distance to angle conversion (0 < rcutfac < 1)
94+
rmin0 (float): parameter in distance to angle conversion (distance units)
95+
bzeroflag (bool): subtract B0
96+
quadraticflag (bool): generate quadratic terms
97+
weights (list/np.ndarry/None): list of neighbor weights, one for each type
98+
cutoff (float): cutoff radius for the construction of the neighbor list
99+
100+
Returns:
101+
np.ndarray: Numpy array with the calculated descriptor derivatives
102+
"""
103+
if weights is None:
104+
weights = [1.0] * len(atom_types)
105+
settings = {
106+
"BISPECTRUM": {
107+
"numTypes": len(atom_types),
108+
"twojmax": twojmax,
109+
"rcutfac": rcutfac,
110+
"rfac0": rfac0,
111+
"rmin0": rmin0,
112+
"wj": " ".join([str(w) for w in weights]),
113+
"radelem": " ".join([str(r) for r in element_radius]),
114+
"type": " ".join(atom_types),
115+
"wselfallflag": 0,
116+
"chemflag": 0,
117+
"bzeroflag": bzeroflag,
118+
"quadraticflag": quadraticflag,
119+
},
120+
"CALCULATOR": {
121+
"calculator": "LAMMPSSNAP",
122+
"energy": 1,
123+
"force": 1,
124+
"stress": 0,
125+
},
126+
"REFERENCE": {
127+
"units": "metal",
128+
"atom_style": "atomic",
129+
"pair_style": "zero " + str(cutoff),
130+
"pair_coeff": "* *",
131+
},
132+
}
133+
fs = FitSnap(settings, comm=None, arglist=["--overwrite"])
134+
a, b, w = fs.calculator.process_single(_ase_scraper(data=[structure])[0])
135+
return a
136+
137+
138+
def _assign_validation(group_table):
139+
"""
140+
Given a dictionary of group info, add another key for test bools.
141+
142+
Args:
143+
group_table: Dictionary of group names. Must have keys "nconfigs" and "testing_size".
144+
145+
Modifies the dictionary in place by adding another key "test_bools".
146+
"""
147+
148+
for name in group_table:
149+
nconfigs = group_table[name]["nconfigs"]
150+
assert "testing_size" in group_table[name]
151+
assert group_table[name]["testing_size"] <= 1.0
152+
test_bools = [
153+
random.random() < group_table[name]["testing_size"]
154+
for i in range(0, nconfigs)
155+
]
156+
157+
group_table[name]["test_bools"] = test_bools
158+
159+
160+
def _ase_scraper(data) -> list:
161+
"""
162+
Function to organize groups and allocate shared arrays used in Calculator. For now when using
163+
ASE frames, we don't have groups.
164+
165+
Args:
166+
s: fitsnap instance.
167+
data: List of ASE frames or dictionary group table containing frames.
168+
169+
Returns a list of data dictionaries suitable for fitsnap descriptor calculator.
170+
If running in parallel, this list will be distributed over procs, so that each proc will have a
171+
portion of the list.
172+
"""
173+
174+
# Simply collate data from Atoms objects if we have a list of Atoms objects.
175+
if type(data) == list:
176+
# s.data = [collate_data(atoms) for atoms in data]
177+
return [_collate_data(atoms) for atoms in data]
178+
# If we have a dictionary, assume we are dealing with groups.
179+
elif type(data) == dict:
180+
_assign_validation(data)
181+
# s.data = []
182+
ret = []
183+
for name in data:
184+
frames = data[name]["frames"]
185+
# Extend the fitsnap data list with this group.
186+
# s.data.extend([collate_data(atoms, name, data[name]) for atoms in frames])
187+
ret.extend([_collate_data(atoms, name, data[name]) for atoms in frames])
188+
return ret
189+
else:
190+
raise Exception("Argument must be list or dictionary for ASE scraper.")
191+
192+
193+
def _get_apre(cell):
194+
"""
195+
Calculate transformed ASE cell for LAMMPS calculations. Thank you Jan Janssen!
196+
197+
Args:
198+
cell: ASE atoms cell.
199+
200+
Returns transformed cell as np.array which is suitable for LAMMPS.
201+
"""
202+
a, b, c = cell
203+
an, bn, cn = [np.linalg.norm(v) for v in cell]
204+
205+
alpha = np.arccos(np.dot(b, c) / (bn * cn))
206+
beta = np.arccos(np.dot(a, c) / (an * cn))
207+
gamma = np.arccos(np.dot(a, b) / (an * bn))
208+
209+
xhi = an
210+
xyp = np.cos(gamma) * bn
211+
yhi = np.sin(gamma) * bn
212+
xzp = np.cos(beta) * cn
213+
yzp = (bn * cn * np.cos(alpha) - xyp * xzp) / yhi
214+
zhi = np.sqrt(cn**2 - xzp**2 - yzp**2)
215+
216+
return np.array(((xhi, 0, 0), (xyp, yhi, 0), (xzp, yzp, zhi)))
217+
218+
219+
def _collate_data(atoms, name: str = None, group_dict: dict = None) -> dict:
220+
"""
221+
Function to organize fitting data for FitSNAP from ASE atoms objects.
222+
223+
Args:
224+
atoms: ASE atoms object for a single configuration of atoms.
225+
name: Optional name of this configuration.
226+
group_dict: Optional dictionary containing group information.
227+
228+
Returns a data dictionary for a single configuration.
229+
"""
230+
231+
# Transform ASE cell to be appropriate for LAMMPS.
232+
apre = _get_apre(cell=atoms.cell)
233+
R = np.dot(np.linalg.inv(atoms.cell), apre)
234+
positions = np.matmul(atoms.get_positions(), R)
235+
cell = apre.T
236+
237+
# Make a data dictionary for this config.
238+
239+
data = {}
240+
data["Group"] = name # 'ASE' # TODO: Make this customizable for ASE groups.
241+
data["File"] = None
242+
data["Stress"] = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
243+
data["Positions"] = positions
244+
data["Energy"] = 0.0
245+
data["AtomTypes"] = atoms.get_chemical_symbols()
246+
data["NumAtoms"] = len(atoms)
247+
data["Forces"] = np.array([0.0, 0.0, 0.0] * len(atoms))
248+
data["QMLattice"] = cell
249+
data["test_bool"] = 0
250+
data["Lattice"] = cell
251+
data["Rotation"] = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
252+
data["Translation"] = np.zeros((len(atoms), 3))
253+
# Inject the weights.
254+
if group_dict is not None:
255+
data["eweight"] = group_dict["eweight"] if "eweight" in group_dict else 1.0
256+
data["fweight"] = group_dict["fweight"] if "fweight" in group_dict else 1.0
257+
data["vweight"] = group_dict["vweight"] if "vweight" in group_dict else 1.0
258+
else:
259+
data["eweight"] = 1.0
260+
data["fweight"] = 1.0
261+
data["vweight"] = 1.0
262+
263+
return data

tests/test_fitsnap.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from ase.build.bulk import bulk
2+
import structuretoolkit as stk
3+
import unittest
4+
5+
6+
try:
7+
from structuretoolkit.analyse.fitsnap import (
8+
get_ace_descriptor_derivatives,
9+
get_snap_descriptor_derivatives,
10+
)
11+
12+
skip_snap_test = False
13+
except ImportError:
14+
skip_snap_test = True
15+
16+
17+
@unittest.skipIf(
18+
skip_snap_test, "LAMMPS is not installed, so the SNAP tests are skipped."
19+
)
20+
class TestSNAP(unittest.TestCase):
21+
@classmethod
22+
def setUpClass(cls):
23+
cls.structure = bulk("Cu", cubic=True)
24+
cls.numtypes = 1
25+
cls.twojmax = 6
26+
cls.rcutfac = 1.0
27+
cls.rfac0 = 0.99363
28+
cls.rmin0 = 0.0
29+
cls.bzeroflag = False
30+
cls.quadraticflag = False
31+
cls.radelem = [4.0]
32+
cls.type = ['Cu']
33+
cls.wj = [1.0]
34+
35+
def test_get_snap_descriptor_derivatives(self):
36+
n_coeff = len(stk.analyse.get_snap_descriptor_names(
37+
twojmax=self.twojmax
38+
))
39+
mat_a = get_snap_descriptor_derivatives(
40+
structure=self.structure,
41+
atom_types=self.type,
42+
twojmax=self.twojmax,
43+
element_radius=self.radelem,
44+
rcutfac=self.rcutfac,
45+
rfac0=self.rfac0,
46+
rmin0=self.rmin0,
47+
bzeroflag=self.bzeroflag,
48+
quadraticflag=self.quadraticflag,
49+
weights=self.wj,
50+
cutoff=10.0,
51+
)
52+
self.assertEqual(mat_a.shape, (len(self.structure) * 3 + 7, n_coeff + 1))
53+
54+
def test_get_ace_descriptor_derivatives(self):
55+
mat_a = get_ace_descriptor_derivatives(
56+
structure=self.structure,
57+
atom_types=self.type,
58+
ranks=[1, 2, 3, 4, 5, 6],
59+
lmax=[1, 2, 2, 2, 1, 1],
60+
nmax=[22, 2, 2, 2, 1, 1],
61+
nmaxbase=22,
62+
rcutfac=4.604694451,
63+
lambda_value=3.059235105,
64+
lmin=[1, 1, 1, 1, 1, 1],
65+
cutoff=10.0,
66+
)
67+
self.assertEqual(mat_a.shape, (16, 68))

0 commit comments

Comments
 (0)