Skip to content

Commit 8618e9d

Browse files
committed
clean up interface
1 parent fac8881 commit 8618e9d

File tree

2 files changed

+31
-14
lines changed

2 files changed

+31
-14
lines changed

structuretoolkit/analyse/fitsnap.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,18 @@
99
def get_ace_descriptor_derivatives(
1010
structure: Atoms,
1111
atom_types: list[str],
12-
ranks: list[int] = [1, 2, 3, 4, 5, 6],
13-
lmax: list[int] = [1, 2, 2, 2, 1, 1],
14-
nmax: list[int] = [22, 2, 2, 2, 1, 1],
12+
ranks: list[int] = [1, 2, 3, 4],
13+
lmax: list[int] = [0, 5, 2, 1],
14+
nmax: list[int] = [22, 5, 3, 1],
15+
mumax: int = 1,
1516
nmaxbase: int = 22,
16-
rcutfac: float = 4.604694451,
17-
lambda_value: float = 3.059235105,
18-
lmin: list[int] = [1, 1, 1, 1, 1, 1],
17+
erefs: list[float] = [0.0],
18+
rcutfac: float = 4.5,
19+
rcinner: float = 1.2,
20+
drcinner: float = 0.01,
21+
RPI_heuristic: str = "root_SO3_span",
22+
lambda_value: float = 1.275,
23+
lmin: list[int] = [0, 0, 1, 1],
1924
bzeroflag: bool = True,
2025
cutoff: float = 10.0,
2126
) -> np.ndarray:
@@ -43,13 +48,19 @@ def get_ace_descriptor_derivatives(
4348
"ranks": " ".join([str(r) for r in ranks]),
4449
"lmax": " ".join([str(l) for l in lmax]),
4550
"nmax": " ".join([str(n) for n in nmax]),
51+
"mumax": mumax,
4652
"nmaxbase": nmaxbase,
4753
"rcutfac": rcutfac,
54+
"erefs": " ".join([str(e) for e in erefs]),
55+
"rcinner": rcinner,
56+
"drcinner": drcinner,
57+
"RPI_heuristic": RPI_heuristic,
4858
"lambda": lambda_value,
4959
"type": " ".join(atom_types),
5060
"lmin": " ".join([str(l) for l in lmin]),
51-
"bzeroflag": True,
61+
"bzeroflag": bzeroflag,
5262
"bikflag": True,
63+
"dgradflag": True,
5364
},
5465
"CALCULATOR": {
5566
"calculator": "LAMMPSPACE",

tests/test_fitsnap.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,19 @@ def test_get_ace_descriptor_derivatives(self):
5555
mat_a = get_ace_descriptor_derivatives(
5656
structure=self.structure,
5757
atom_types=self.type,
58-
ranks=[1, 2, 3, 4, 5, 6],
59-
lmax=[1, 2, 2, 2, 1, 1],
60-
nmax=[22, 2, 2, 2, 1, 1],
58+
ranks=[1, 2, 3, 4],
59+
lmax=[0, 5, 2, 1],
60+
nmax=[22, 5, 3, 1],
61+
mumax=1,
6162
nmaxbase=22,
62-
rcutfac=4.604694451,
63-
lambda_value=3.059235105,
64-
lmin=[1, 1, 1, 1, 1, 1],
63+
erefs=[0.0],
64+
rcutfac=4.5,
65+
rcinner=1.2,
66+
drcinner=0.01,
67+
RPI_heuristic="root_SO3_span",
68+
lambda_value=1.275,
69+
lmin=[0, 0, 1, 1],
70+
bzeroflag=True,
6571
cutoff=10.0,
6672
)
67-
self.assertEqual(mat_a.shape, (16, 68))
73+
self.assertEqual(mat_a.shape, (16, 141))

0 commit comments

Comments
 (0)