Skip to content

Commit dd56d66

Browse files
authored
extend ruff linter (#315)
* extend ruff linter * fix broken comparison
1 parent d1cc0a4 commit dd56d66

20 files changed

+329
-163
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ repos:
44
hooks:
55
- id: ruff
66
name: ruff lint
7-
args: ["--select", "I", "--fix"]
7+
args: ["--fix"]
88
files: ^structuretoolkit/
99
- id: ruff-format
1010
name: ruff format

pyproject.toml

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,53 @@ include = ["structuretoolkit*"]
6363
[tool.setuptools.dynamic]
6464
version = {attr = "structuretoolkit.__version__"}
6565

66+
[tool.ruff]
67+
exclude = [".ci_support", "tests", "setup.py", "_version.py"]
68+
69+
[tool.ruff.lint]
70+
select = [
71+
# pycodestyle
72+
"E",
73+
# Pyflakes
74+
"F",
75+
# pyupgrade
76+
"UP",
77+
# flake8-bugbear
78+
"B",
79+
# flake8-simplify
80+
"SIM",
81+
# isort
82+
"I",
83+
# flake8-comprehensions
84+
"C4",
85+
# eradicate
86+
"ERA",
87+
# pylint
88+
"PL",
89+
]
90+
ignore = [
91+
# ignore functions in argument defaults
92+
"B008",
93+
# ignore exception naming
94+
"B904",
95+
# ignore line-length violations
96+
"E501",
97+
# ignore equality comparisons for numpy arrays
98+
"E712",
99+
# ignore bare except
100+
"E722",
101+
# ignore ambiguous variable name
102+
"E741",
103+
# Too many arguments in function definition
104+
"PLR0913",
105+
# Magic value used in comparison
106+
"PLR2004",
107+
# Too many branches
108+
"PLR0912",
109+
# Too many statements
110+
"PLR0915",
111+
]
112+
66113
[tool.versioneer]
67114
VCS = "git"
68115
style = "pep440-pre"

structuretoolkit/__init__.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,62 @@
9898

9999
from . import _version
100100

101+
__all__ = [
102+
"find_mic",
103+
"find_solids",
104+
"get_adaptive_cna_descriptors",
105+
"get_average_of_unique_labels",
106+
"get_centro_symmetry_descriptors",
107+
"get_cluster_positions",
108+
"get_delaunay_neighbors",
109+
"get_diamond_structure_descriptors",
110+
"get_distances_array",
111+
"get_equivalent_atoms",
112+
"get_interstitials",
113+
"get_layers",
114+
"get_mean_positions",
115+
"get_neighborhood",
116+
"get_neighbors",
117+
"get_steinhardt_parameters",
118+
"get_strain",
119+
"get_symmetry",
120+
"get_voronoi_neighbors",
121+
"get_voronoi_vertices",
122+
"get_voronoi_volumes",
123+
"analyse_find_solids",
124+
"analyse_cna_adaptive",
125+
"analyse_centro_symmetry",
126+
"cluster_positions",
127+
"analyse_diamond_structure",
128+
"analyse_phonopy_equivalent_atoms",
129+
"get_steinhardt_parameter_structure",
130+
"analyse_voronoi_volume",
131+
"B2",
132+
"C14",
133+
"C15",
134+
"C36",
135+
"D03",
136+
"create_mesh",
137+
"get_grainboundary_info",
138+
"get_high_index_surface_info",
139+
"grainboundary",
140+
"high_index_surface",
141+
"sqs_structures",
142+
"grainboundary_info",
143+
"high_index_surface_info",
144+
"grainboundary_build",
145+
"get_sqs_structures",
146+
"SymmetryError",
147+
"apply_strain",
148+
"ase_to_pymatgen",
149+
"ase_to_pyscal",
150+
"center_coordinates_in_unit_cell",
151+
"get_cell",
152+
"get_extended_positions",
153+
"get_vertical_length",
154+
"get_wrapped_coordinates",
155+
"pymatgen_to_ase",
156+
"select_index",
157+
"plot3d",
158+
]
101159
__version__ = _version.get_versions()["version"]

structuretoolkit/analyse/__init__.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,39 @@ def get_ir_reciprocal_mesh(
254254
is_shift=is_shift,
255255
is_time_reversal=is_time_reversal,
256256
)
257+
258+
259+
__all__ = [
260+
"find_mic",
261+
"get_distances_array",
262+
"soap_descriptor_per_atom",
263+
"get_neighborhood",
264+
"get_neighbors",
265+
"get_equivalent_atoms",
266+
"find_solids",
267+
"get_adaptive_cna_descriptors",
268+
"get_centro_symmetry_descriptors",
269+
"get_diamond_structure_descriptors",
270+
"get_steinhardt_parameters",
271+
"get_voronoi_volumes",
272+
"get_snap_descriptor_derivatives",
273+
"get_snap_descriptor_names",
274+
"get_snap_descriptors_per_atom",
275+
"get_average_of_unique_labels",
276+
"get_cluster_positions",
277+
"get_delaunay_neighbors",
278+
"get_interstitials",
279+
"get_layers",
280+
"get_mean_positions",
281+
"get_voronoi_neighbors",
282+
"get_voronoi_vertices",
283+
"get_strain",
284+
"get_ir_reciprocal_mesh",
285+
"get_symmetry",
286+
"symmetrize_vectors",
287+
"group_points_by_symmetry",
288+
"get_primitive_cell",
289+
"get_spacegroup",
290+
"get_symmetry_dataset",
291+
"get_equivalent_points",
292+
]

structuretoolkit/analyse/dscribe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def soap_descriptor_per_atom(
1313
rbf: str = "gto",
1414
weighting: Optional[np.ndarray] = None,
1515
average: str = "off",
16-
compression: dict = {"mode": "off", "species_weighting": None},
16+
compression: dict = None,
1717
species: Optional[list] = None,
1818
periodic: bool = True,
1919
sparse: bool = False,
@@ -50,6 +50,8 @@ def soap_descriptor_per_atom(
5050
"""
5151
from dscribe.descriptors import SOAP
5252

53+
if compression is None:
54+
compression = {"mode": "off", "species_weighting": None}
5355
if species is None:
5456
species = list(set(structure.get_chemical_symbols()))
5557
periodic_soap = SOAP(

structuretoolkit/analyse/neighbors.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
# coding: utf-8
21
# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department
32
# Distributed under the terms of "New BSD License", see the LICENSE file.
43

54
import itertools
65
import warnings
7-
from typing import Dict, List, Optional, Tuple, Union
6+
from typing import Optional, Union
87

98
import numpy as np
109
from ase.atoms import Atoms
@@ -104,11 +103,11 @@ def _set_mode(self, new_mode: str) -> None:
104103
Raises:
105104
KeyError: If the new mode is not found in the available modes.
106105
"""
107-
if new_mode not in self._mode.keys():
106+
if new_mode not in self._mode:
108107
raise KeyError(
109108
f"{new_mode} not found. Available modes: {', '.join(self._mode.keys())}"
110109
)
111-
self._mode = {key: False for key in self._mode.keys()}
110+
self._mode = {key: False for key in self._mode}
112111
self._mode[new_mode] = True
113112

114113
def __repr__(self) -> str:
@@ -366,7 +365,7 @@ def _get_distances_and_indices(
366365
num_neighbors: Optional[int] = None,
367366
cutoff_radius: float = np.inf,
368367
width_buffer: float = 1.2,
369-
) -> Tuple[np.ndarray, np.ndarray]:
368+
) -> tuple[np.ndarray, np.ndarray]:
370369
"""
371370
Get the distances and indices of the neighbors for the given positions.
372371
@@ -406,7 +405,8 @@ def _get_distances_and_indices(
406405
warnings.warn(
407406
"Number of neighbors found within the cutoff_radius is equal to (estimated) "
408407
+ "num_neighbors. Increase num_neighbors (or set it to None) or "
409-
+ "width_buffer to find all neighbors within cutoff_radius."
408+
+ "width_buffer to find all neighbors within cutoff_radius.",
409+
stacklevel=2,
410410
)
411411
self._extended_indices = indices.copy()
412412
indices[distances < np.inf] = self._get_wrapped_indices()[
@@ -508,7 +508,8 @@ def _estimate_num_neighbors(
508508
if num_neighbors > self.num_neighbors:
509509
warnings.warn(
510510
"Taking a larger search area after initialization has the risk of "
511-
+ "missing neighborhood atoms"
511+
+ "missing neighborhood atoms",
512+
stacklevel=2,
512513
)
513514
return num_neighbors
514515

@@ -632,15 +633,14 @@ def _check_width(self, width: float, pbc: list[bool, bool, bool]) -> bool:
632633
bool: True if the width exceeds the specified value, False otherwise.
633634
634635
"""
635-
if any(pbc) and np.prod(self.filled.distances.shape) > 0:
636-
if (
637-
np.linalg.norm(
638-
self.flattened.vecs[..., pbc], axis=-1, ord=self.norm_order
639-
).max()
640-
> width
641-
):
642-
return True
643-
return False
636+
return bool(
637+
any(pbc)
638+
and np.prod(self.filled.distances.shape) > 0
639+
and np.linalg.norm(
640+
self.flattened.vecs[..., pbc], axis=-1, ord=self.norm_order
641+
).max()
642+
> width
643+
)
644644

645645
def get_spherical_harmonics(
646646
self,
@@ -811,9 +811,9 @@ def __getattr__(self, name):
811811
def __dir__(self):
812812
"""Show value names which are available for different filling modes."""
813813
return list(
814-
set(
815-
["distances", "vecs", "indices", "shells", "atom_numbers"]
816-
).intersection(self.ref_neigh.__dir__())
814+
{"distances", "vecs", "indices", "shells", "atom_numbers"}.intersection(
815+
self.ref_neigh.__dir__()
816+
)
817817
)
818818

819819

@@ -1008,7 +1008,7 @@ def get_global_shells(
10081008

10091009
def get_shell_matrix(
10101010
self,
1011-
chemical_pair: Optional[List[str]] = None,
1011+
chemical_pair: Optional[list[str]] = None,
10121012
cluster_by_distances: bool = False,
10131013
cluster_by_vecs: bool = False,
10141014
):
@@ -1225,7 +1225,7 @@ def reset_clusters(self, vecs: bool = True, distances: bool = True):
12251225

12261226
def cluster_analysis(
12271227
self, id_list: list, return_cluster_sizes: bool = False
1228-
) -> Union[Dict[int, List[int]], Tuple[Dict[int, List[int]], List[int]]]:
1228+
) -> Union[dict[int, list[int]], tuple[dict[int, list[int]], list[int]]]:
12291229
"""
12301230
Perform cluster analysis on a list of atom IDs.
12311231
@@ -1240,11 +1240,8 @@ def cluster_analysis(
12401240
"""
12411241
self._cluster = [0] * len(self._ref_structure)
12421242
c_count = 1
1243-
# element_list = self.get_atomic_numbers()
12441243
for ia in id_list:
1245-
# el0 = element_list[ia]
12461244
nbrs = self.ragged.indices[ia]
1247-
# print ("nbrs: ", ia, nbrs)
12481245
if self._cluster[ia] == 0:
12491246
self._cluster[ia] = c_count
12501247
self.__probe_cluster(c_count, nbrs, id_list)
@@ -1261,7 +1258,7 @@ def cluster_analysis(
12611258
return cluster_dict # sizes
12621259

12631260
def __probe_cluster(
1264-
self, c_count: int, neighbors: List[int], id_list: List[int]
1261+
self, c_count: int, neighbors: list[int], id_list: list[int]
12651262
) -> None:
12661263
"""
12671264
Recursively probe the cluster and assign cluster IDs to neighbors.
@@ -1275,19 +1272,20 @@ def __probe_cluster(
12751272
None
12761273
"""
12771274
for nbr_id in neighbors:
1278-
if self._cluster[nbr_id] == 0:
1279-
if nbr_id in id_list: # TODO: check also for ordered structures
1280-
self._cluster[nbr_id] = c_count
1281-
nbrs = self.ragged.indices[nbr_id]
1282-
self.__probe_cluster(c_count, nbrs, id_list)
1275+
if (
1276+
self._cluster[nbr_id] == 0 and nbr_id in id_list
1277+
): # TODO: check also for ordered structures
1278+
self._cluster[nbr_id] = c_count
1279+
nbrs = self.ragged.indices[nbr_id]
1280+
self.__probe_cluster(c_count, nbrs, id_list)
12831281

12841282
# TODO: combine with corresponding routine in plot3d
12851283
def get_bonds(
12861284
self,
12871285
radius: float = np.inf,
12881286
max_shells: Optional[int] = None,
12891287
prec: float = 0.1,
1290-
) -> List[Dict[str, List[List[int]]]]:
1288+
) -> list[dict[str, list[list[int]]]]:
12911289
"""
12921290
Get the bonds in the structure.
12931291
@@ -1303,7 +1301,7 @@ def get_bonds(
13031301

13041302
def get_cluster(
13051303
dist_vec: np.ndarray, ind_vec: np.ndarray, prec: float = prec
1306-
) -> List[np.ndarray]:
1304+
) -> list[np.ndarray]:
13071305
"""
13081306
Get clusters from a distance vector and index vector.
13091307
@@ -1326,7 +1324,6 @@ def get_cluster(
13261324
ind_shell = []
13271325
for d, i in zip(dist, ind):
13281326
id_list = get_cluster(d[d < radius], i[d < radius])
1329-
# print ("id: ", d[d<radius], id_list, dist_lst)
13301327
ia_shells_dict = {}
13311328
for i_shell_list in id_list:
13321329
ia_shell_dict = {}
@@ -1338,9 +1335,11 @@ def get_cluster(
13381335
for el, ia_lst in ia_shell_dict.items():
13391336
if el not in ia_shells_dict:
13401337
ia_shells_dict[el] = []
1341-
if max_shells is not None:
1342-
if len(ia_shells_dict[el]) + 1 > max_shells:
1343-
continue
1338+
if (
1339+
max_shells is not None
1340+
and len(ia_shells_dict[el]) + 1 > max_shells
1341+
):
1342+
continue
13441343
ia_shells_dict[el].append(ia_lst)
13451344
ind_shell.append(ia_shells_dict)
13461345
return ind_shell
@@ -1457,7 +1456,8 @@ def _get_neighbors(
14571456
if neigh._check_width(width=width, pbc=structure.pbc):
14581457
warnings.warn(
14591458
"width_buffer may have been too small - "
1460-
"most likely not all neighbors properly assigned"
1459+
"most likely not all neighbors properly assigned",
1460+
stacklevel=2,
14611461
)
14621462
return neigh
14631463

0 commit comments

Comments
 (0)