1
- # coding: utf-8
2
1
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
3
2
# Distributed under the terms of "New BSD License", see the LICENSE file.
4
3
5
4
import itertools
6
5
import warnings
7
- from typing import Dict , List , Optional , Tuple , Union
6
+ from typing import Optional , Union
8
7
9
8
import numpy as np
10
9
from ase .atoms import Atoms
@@ -104,11 +103,11 @@ def _set_mode(self, new_mode: str) -> None:
104
103
Raises:
105
104
KeyError: If the new mode is not found in the available modes.
106
105
"""
107
- if new_mode not in self ._mode . keys () :
106
+ if new_mode not in self ._mode :
108
107
raise KeyError (
109
108
f"{ new_mode } not found. Available modes: { ', ' .join (self ._mode .keys ())} "
110
109
)
111
- self ._mode = {key : False for key in self ._mode . keys () }
110
+ self ._mode = {key : False for key in self ._mode }
112
111
self ._mode [new_mode ] = True
113
112
114
113
def __repr__ (self ) -> str :
@@ -366,7 +365,7 @@ def _get_distances_and_indices(
366
365
num_neighbors : Optional [int ] = None ,
367
366
cutoff_radius : float = np .inf ,
368
367
width_buffer : float = 1.2 ,
369
- ) -> Tuple [np .ndarray , np .ndarray ]:
368
+ ) -> tuple [np .ndarray , np .ndarray ]:
370
369
"""
371
370
Get the distances and indices of the neighbors for the given positions.
372
371
@@ -406,7 +405,8 @@ def _get_distances_and_indices(
406
405
warnings .warn (
407
406
"Number of neighbors found within the cutoff_radius is equal to (estimated) "
408
407
+ "num_neighbors. Increase num_neighbors (or set it to None) or "
409
- + "width_buffer to find all neighbors within cutoff_radius."
408
+ + "width_buffer to find all neighbors within cutoff_radius." ,
409
+ stacklevel = 2 ,
410
410
)
411
411
self ._extended_indices = indices .copy ()
412
412
indices [distances < np .inf ] = self ._get_wrapped_indices ()[
@@ -508,7 +508,8 @@ def _estimate_num_neighbors(
508
508
if num_neighbors > self .num_neighbors :
509
509
warnings .warn (
510
510
"Taking a larger search area after initialization has the risk of "
511
- + "missing neighborhood atoms"
511
+ + "missing neighborhood atoms" ,
512
+ stacklevel = 2 ,
512
513
)
513
514
return num_neighbors
514
515
@@ -632,15 +633,14 @@ def _check_width(self, width: float, pbc: list[bool, bool, bool]) -> bool:
632
633
bool: True if the width exceeds the specified value, False otherwise.
633
634
634
635
"""
635
- if any (pbc ) and np .prod (self .filled .distances .shape ) > 0 :
636
- if (
637
- np .linalg .norm (
638
- self .flattened .vecs [..., pbc ], axis = - 1 , ord = self .norm_order
639
- ).max ()
640
- > width
641
- ):
642
- return True
643
- return False
636
+ return bool (
637
+ any (pbc )
638
+ and np .prod (self .filled .distances .shape ) > 0
639
+ and np .linalg .norm (
640
+ self .flattened .vecs [..., pbc ], axis = - 1 , ord = self .norm_order
641
+ ).max ()
642
+ > width
643
+ )
644
644
645
645
def get_spherical_harmonics (
646
646
self ,
@@ -811,9 +811,9 @@ def __getattr__(self, name):
811
811
def __dir__ (self ):
812
812
"""Show value names which are available for different filling modes."""
813
813
return list (
814
- set (
815
- [ "distances" , "vecs" , "indices" , "shells" , "atom_numbers" ]
816
- ). intersection ( self . ref_neigh . __dir__ ())
814
+ { "distances" , "vecs" , "indices" , "shells" , "atom_numbers" }. intersection (
815
+ self . ref_neigh . __dir__ ()
816
+ )
817
817
)
818
818
819
819
@@ -1008,7 +1008,7 @@ def get_global_shells(
1008
1008
1009
1009
def get_shell_matrix (
1010
1010
self ,
1011
- chemical_pair : Optional [List [str ]] = None ,
1011
+ chemical_pair : Optional [list [str ]] = None ,
1012
1012
cluster_by_distances : bool = False ,
1013
1013
cluster_by_vecs : bool = False ,
1014
1014
):
@@ -1225,7 +1225,7 @@ def reset_clusters(self, vecs: bool = True, distances: bool = True):
1225
1225
1226
1226
def cluster_analysis (
1227
1227
self , id_list : list , return_cluster_sizes : bool = False
1228
- ) -> Union [Dict [int , List [int ]], Tuple [ Dict [int , List [int ]], List [int ]]]:
1228
+ ) -> Union [dict [int , list [int ]], tuple [ dict [int , list [int ]], list [int ]]]:
1229
1229
"""
1230
1230
Perform cluster analysis on a list of atom IDs.
1231
1231
@@ -1240,11 +1240,8 @@ def cluster_analysis(
1240
1240
"""
1241
1241
self ._cluster = [0 ] * len (self ._ref_structure )
1242
1242
c_count = 1
1243
- # element_list = self.get_atomic_numbers()
1244
1243
for ia in id_list :
1245
- # el0 = element_list[ia]
1246
1244
nbrs = self .ragged .indices [ia ]
1247
- # print ("nbrs: ", ia, nbrs)
1248
1245
if self ._cluster [ia ] == 0 :
1249
1246
self ._cluster [ia ] = c_count
1250
1247
self .__probe_cluster (c_count , nbrs , id_list )
@@ -1261,7 +1258,7 @@ def cluster_analysis(
1261
1258
return cluster_dict # sizes
1262
1259
1263
1260
def __probe_cluster (
1264
- self , c_count : int , neighbors : List [int ], id_list : List [int ]
1261
+ self , c_count : int , neighbors : list [int ], id_list : list [int ]
1265
1262
) -> None :
1266
1263
"""
1267
1264
Recursively probe the cluster and assign cluster IDs to neighbors.
@@ -1275,19 +1272,20 @@ def __probe_cluster(
1275
1272
None
1276
1273
"""
1277
1274
for nbr_id in neighbors :
1278
- if self ._cluster [nbr_id ] == 0 :
1279
- if nbr_id in id_list : # TODO: check also for ordered structures
1280
- self ._cluster [nbr_id ] = c_count
1281
- nbrs = self .ragged .indices [nbr_id ]
1282
- self .__probe_cluster (c_count , nbrs , id_list )
1275
+ if (
1276
+ self ._cluster [nbr_id ] == 0 and nbr_id in id_list
1277
+ ): # TODO: check also for ordered structures
1278
+ self ._cluster [nbr_id ] = c_count
1279
+ nbrs = self .ragged .indices [nbr_id ]
1280
+ self .__probe_cluster (c_count , nbrs , id_list )
1283
1281
1284
1282
# TODO: combine with corresponding routine in plot3d
1285
1283
def get_bonds (
1286
1284
self ,
1287
1285
radius : float = np .inf ,
1288
1286
max_shells : Optional [int ] = None ,
1289
1287
prec : float = 0.1 ,
1290
- ) -> List [ Dict [str , List [ List [int ]]]]:
1288
+ ) -> list [ dict [str , list [ list [int ]]]]:
1291
1289
"""
1292
1290
Get the bonds in the structure.
1293
1291
@@ -1303,7 +1301,7 @@ def get_bonds(
1303
1301
1304
1302
def get_cluster (
1305
1303
dist_vec : np .ndarray , ind_vec : np .ndarray , prec : float = prec
1306
- ) -> List [np .ndarray ]:
1304
+ ) -> list [np .ndarray ]:
1307
1305
"""
1308
1306
Get clusters from a distance vector and index vector.
1309
1307
@@ -1326,7 +1324,6 @@ def get_cluster(
1326
1324
ind_shell = []
1327
1325
for d , i in zip (dist , ind ):
1328
1326
id_list = get_cluster (d [d < radius ], i [d < radius ])
1329
- # print ("id: ", d[d<radius], id_list, dist_lst)
1330
1327
ia_shells_dict = {}
1331
1328
for i_shell_list in id_list :
1332
1329
ia_shell_dict = {}
@@ -1338,9 +1335,11 @@ def get_cluster(
1338
1335
for el , ia_lst in ia_shell_dict .items ():
1339
1336
if el not in ia_shells_dict :
1340
1337
ia_shells_dict [el ] = []
1341
- if max_shells is not None :
1342
- if len (ia_shells_dict [el ]) + 1 > max_shells :
1343
- continue
1338
+ if (
1339
+ max_shells is not None
1340
+ and len (ia_shells_dict [el ]) + 1 > max_shells
1341
+ ):
1342
+ continue
1344
1343
ia_shells_dict [el ].append (ia_lst )
1345
1344
ind_shell .append (ia_shells_dict )
1346
1345
return ind_shell
@@ -1457,7 +1456,8 @@ def _get_neighbors(
1457
1456
if neigh ._check_width (width = width , pbc = structure .pbc ):
1458
1457
warnings .warn (
1459
1458
"width_buffer may have been too small - "
1460
- "most likely not all neighbors properly assigned"
1459
+ "most likely not all neighbors properly assigned" ,
1460
+ stacklevel = 2 ,
1461
1461
)
1462
1462
return neigh
1463
1463
0 commit comments