Commit d24562d

Merge pull request #37261 from ajjackson/abins-pydantic
Try using pydantic to cut down on type validation boilerplate
2 parents 6c27d44 + e6629e3, commit d24562d

18 files changed: +200 -283 lines
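
The pattern behind the whole diff: pydantic's `validate_call` decorator enforces annotated parameter types at call time, so hand-written `isinstance` checks can be deleted. A minimal sketch of the before/after (illustrative names only, not code from this PR):

```python
from pydantic import validate_call

# Before: hand-written type-checking boilerplate
def scale(data: list, factor: float):
    if not isinstance(data, list):
        raise TypeError("Invalid type of data: {}".format(type(data)))
    return [x * factor for x in data]

# After: pydantic enforces the annotations on every call
@validate_call
def scale_validated(data: list[float], factor: float):
    return [x * factor for x in data]

scale_validated([1.0, 2.0], 0.5)  # ok
scale_validated("oops", 0.5)      # raises pydantic.ValidationError
```

By default pydantic also coerces compatible values (e.g. ints to floats); the Abins changes mostly pass `strict=True` to keep the old fail-fast behaviour.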

conda/recipes/mantid/meta.yaml

Lines changed: 2 additions & 0 deletions

@@ -74,6 +74,8 @@ requirements:
 - joblib
 - orsopy {{ orsopy }}
 - quasielasticbayes
+- pydantic
+
 run_constrained:
 - matplotlib {{ matplotlib }}

mantid-developer-linux.yml

Lines changed: 1 addition & 0 deletions

@@ -51,6 +51,7 @@ dependencies:
 - versioningit>=2.1
 - joblib
 - orsopy==1.2.1 # Fix the version to avoid updates being pulled in automatically, which might change the Reflectometry ORSO file content or layout and cause tests to fail.
+- pydantic
 
 # Not Windows, OpenGL implementation:
 - mesa-libgl-devel-cos7-x86_64>=18.3.4

mantid-developer-osx.yml

Lines changed: 1 addition & 0 deletions

@@ -49,6 +49,7 @@ dependencies:
 - versioningit>=2.1
 - joblib
 - orsopy==1.2.1 # Fix the version to avoid updates being pulled in automatically, which might change the Reflectometry ORSO file content or layout and cause tests to fail.
+- pydantic
 
 # Needed only for development
 - black # may be out of sync with pre-commit

mantid-developer-win.yml

Lines changed: 3 additions & 0 deletions

@@ -48,7 +48,10 @@ dependencies:
 - versioningit>=2.1
 - joblib
 - orsopy==1.2.1 # Fix the version to avoid updates being pulled in automatically, which might change the Reflectometry ORSO file content or layout and cause tests to fail.
+- pydantic
+
 # Needed only for development
 - black # may be out of sync with pre-commit
 - cppcheck==2.14.2
 - pre-commit>=2.12.0
+

scripts/abins/abinsdata.py

Lines changed: 3 additions & 5 deletions

@@ -6,6 +6,8 @@
 # SPDX - License - Identifier: GPL - 3.0 +
 from typing import Any, Dict, Type, TypedDict, TypeVar
 
+from pydantic import validate_call
+
 import mantid
 from abins.kpointsdata import KpointsData
 from abins.atomsdata import AtomsData
@@ -23,13 +25,9 @@ class AbinsData:
 
     """
 
+    @validate_call(config=dict(arbitrary_types_allowed=True, strict=True))
     def __init__(self, *, k_points_data: KpointsData, atoms_data: AtomsData) -> None:
-        if not isinstance(k_points_data, KpointsData):
-            raise TypeError("Invalid type of k-points data.: {}".format(type(k_points_data)))
         self._k_points_data = k_points_data
-
-        if not isinstance(atoms_data, AtomsData):
-            raise TypeError("Invalid type of atoms data.")
         self._atoms_data = atoms_data
         self._check_consistent_dimensions()
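
`KpointsData` and `AtomsData` are ordinary Python classes rather than pydantic models, so the decorator needs `arbitrary_types_allowed=True` (validate them by `isinstance`) alongside `strict=True` (no coercion). A sketch of the same configuration with a stand-in class:

```python
from pydantic import validate_call

class KPoints:  # stand-in for a non-pydantic class such as KpointsData
    pass

@validate_call(config=dict(arbitrary_types_allowed=True, strict=True))
def build(*, k_points: KPoints) -> None:
    print(f"got {type(k_points).__name__}")

build(k_points=KPoints())       # passes the isinstance check
build(k_points="not k-points")  # raises pydantic.ValidationError
```

One behavioural trade-off: the old code raised `TypeError`, while pydantic raises `ValidationError`.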

scripts/abins/atomsdata.py

Lines changed: 4 additions & 5 deletions

@@ -6,8 +6,9 @@
 # SPDX - License - Identifier: GPL - 3.0 +
 import collections.abc
 import numbers
-from typing import Dict, List, Optional, overload, Union, TypedDict
+from typing import Dict, List, Optional, overload, TypedDict, Union
 import re
+
 import numpy as np
 
 import abins
@@ -146,12 +147,10 @@ def __len__(self) -> int:
         return len(self._data)
 
     @overload
-    def __getitem__(self, item: int) -> _AtomData:
-        ...
+    def __getitem__(self, item: int) -> _AtomData: ...
 
     @overload
-    def __getitem__(self, item: slice) -> List[_AtomData]:
-        ...
+    def __getitem__(self, item: slice) -> List[_AtomData]: ...
 
     def __getitem__(self, item):
         return self._data[item]

scripts/abins/constants.py

Lines changed: 5 additions & 0 deletions

@@ -5,6 +5,7 @@
 # Institut Laue - Langevin & CSNS, Institute of High Energy Physics, CAS
 # SPDX - License - Identifier: GPL - 3.0 +
 import math
+from typing import Literal
 import warnings
 
 import numpy as np
@@ -164,6 +165,7 @@
 
 # ALL_SAMPLE_FORMS = ["SingleCrystal", "Powder"] # valid forms of samples
 ALL_SAMPLE_FORMS = ["Powder"] # valid forms of samples
+ALL_SAMPLE_FORMS_TYPE = Literal["Powder"]
 
 # keywords which define data structure of KpointsData
 ALL_KEYWORDS_K_DATA = ["weights", "k_vectors", "frequencies", "atomic_displacements", "unit_cell"]
@@ -191,6 +193,9 @@
 INT_ID = np.dtype(np.uint32).num
 INT_TYPE = np.dtype(np.uint32)
 
+# Valid types for hdf5 attr read/write
+HDF5_ATTR_TYPE = np.int64 | int | np.float64 | float | str | bytes | bool
+
 HIGHER_ORDER_QUANTUM_EVENTS = 3 # number of quantum order effects taken into account
 HIGHER_ORDER_QUANTUM_EVENTS_DIM = HIGHER_ORDER_QUANTUM_EVENTS
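
`ALL_SAMPLE_FORMS_TYPE` is a `typing.Literal`, so an annotated parameter accepts only the listed strings, and `HDF5_ATTR_TYPE` is a PEP 604 runtime union (Python 3.10+) that can be used directly as an annotation. A sketch of how the two aliases behave under `validate_call` (hypothetical function, not from this PR):

```python
from typing import Literal

import numpy as np
from pydantic import validate_call

ALL_SAMPLE_FORMS_TYPE = Literal["Powder"]
HDF5_ATTR_TYPE = np.int64 | int | np.float64 | float | str | bytes | bool

@validate_call(config=dict(arbitrary_types_allowed=True))
def record(sample_form: ALL_SAMPLE_FORMS_TYPE, value: HDF5_ATTR_TYPE) -> None:
    print(sample_form, value)

record("Powder", np.float64(300.0))  # accepted
record("SingleCrystal", 1)           # ValidationError: not a permitted Literal value
```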

scripts/abins/io.py

Lines changed: 54 additions & 70 deletions

@@ -10,46 +10,41 @@
 import os
 import subprocess
 import shutil
+from typing import List, Optional
+
 import h5py
+
 import numpy as np
+from pydantic import BaseModel, ConfigDict, Field, validate_call
 
 import abins
-from abins.constants import AB_INITIO_FILE_EXTENSIONS, BUF
+from abins.constants import AB_INITIO_FILE_EXTENSIONS, BUF, HDF5_ATTR_TYPE
 from mantid.kernel import logger, ConfigService
 
 
-class IO(object):
+class IO(BaseModel):
     """
     Class for Abins I/O HDF file operations.
    """
 
-    def __init__(self, input_filename=None, group_name=None, setting="", autoconvolution: bool = False, temperature: float = None):
-        self._setting = setting
-        self._autoconvolution = autoconvolution
-        self._temperature = temperature
-
-        if isinstance(input_filename, str):
-            self._input_filename = input_filename
-            try:
-                self._hash_input_filename = self.calculate_ab_initio_file_hash()
-            except IOError as err:
-                logger.error(str(err))
-            except ValueError as err:
-                logger.error(str(err))
+    model_config = ConfigDict(strict=True)
 
-            # extract name of file from the full path in the platform independent way
-            filename = os.path.basename(self._input_filename)
+    input_filename: str = Field(min_length=1)
+    group_name: str = Field(min_length=1)
+    setting: str = ""
+    autoconvolution: int = 10
+    temperature: float = None
 
-            if filename.strip() == "":
-                raise ValueError("Name of the file cannot be an empty string.")
+    def model_post_init(self, __context):
+        try:
+            self._hash_input_filename = self.calculate_ab_initio_file_hash()
+        except IOError as err:
+            logger.error(str(err))
+        except ValueError as err:
+            logger.error(str(err))
 
-        else:
-            raise ValueError("Invalid name of input file. String was expected.")
-
-        if isinstance(group_name, str):
-            self._group_name = group_name
-        else:
-            raise ValueError("Invalid name of the group. String was expected.")
+        # extract name of file from the full path in the platform independent way
+        filename = os.path.basename(self.input_filename)
 
         if filename.split(".")[-1] in AB_INITIO_FILE_EXTENSIONS:
             core_name = filename[0 : filename.rfind(".")]  # e.g. NaCl.phonon -> NaCl (core_name) -> NaCl.hdf5
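
Converting `IO` to a pydantic `BaseModel` moves the constructor checks into declared fields: `ConfigDict(strict=True)` disables coercion, `Field(min_length=1)` rejects empty names, and `model_post_init` runs the remaining setup once the fields have validated. A condensed sketch of the same shape (stand-in class; `Optional[float]` is used here where the diff writes `temperature: float = None`):

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field

class ClerkSketch(BaseModel):
    model_config = ConfigDict(strict=True)

    input_filename: str = Field(min_length=1)
    group_name: str = Field(min_length=1)
    setting: str = ""
    temperature: Optional[float] = None

    def model_post_init(self, __context) -> None:
        # runs after field validation, like the tail of the old __init__
        print(f"validated {self.input_filename!r}")

ClerkSketch(input_filename="NaCl.phonon", group_name="CASTEP")  # ok
ClerkSketch(input_filename="", group_name="CASTEP")   # ValidationError: too short
ClerkSketch(input_filename=123, group_name="CASTEP")  # ValidationError: not a str
```

Note that `temperature: float = None` in the diff relies on pydantic not validating default values; passing `temperature=None` explicitly would fail under `strict=True` without an `Optional` annotation.
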
@@ -85,15 +80,15 @@ def _valid_setting(self):
         :returns: True if consistent, otherwise False.
         """
         saved_setting = self.load(list_of_attributes=["setting"])
-        return self._setting == saved_setting["attributes"]["setting"]
+        return self.setting == saved_setting["attributes"]["setting"]
 
     def _valid_autoconvolution(self):
         """
         Check if autoconvolution setting matches content of HDF file
         :returns: True if consistent, otherwise False
         """
         saved_autoconvolution = self.load(list_of_attributes=["autoconvolution"])
-        return self._autoconvolution == saved_autoconvolution["attributes"]["autoconvolution"]
+        return self.autoconvolution == saved_autoconvolution["attributes"]["autoconvolution"]
 
     def _valid_temperature(self):
         """
@@ -104,7 +99,7 @@ def _valid_temperature(self):
 
         :returns: True if consistent or temperature not set for Clerk, otherwise False
         """
-        if self._temperature is None:
+        if self.temperature is None:
             return True
 
         else:
@@ -115,7 +110,7 @@
             return False
 
         else:
-            return np.abs(self._temperature - saved_temperature["attributes"]["temperature"]) < T_THRESHOLD
+            return np.abs(self.temperature - saved_temperature["attributes"]["temperature"]) < T_THRESHOLD
 
     @classmethod
     def _close_enough(cls, previous, new):
@@ -215,7 +210,8 @@ def erase_hdf_file(self):
         with h5py.File(self._hdf_filename, "w") as hdf_file:
             hdf_file.close()
 
-    def add_attribute(self, name=None, value=None):
+    @validate_call(config=ConfigDict(arbitrary_types_allowed=True, strict=True))
+    def add_attribute(self, name: str, value: HDF5_ATTR_TYPE | None) -> None:
         """
         Adds attribute to the dictionary with other attributes.
         :param name: name of the attribute
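
`validate_call` also works on methods: the unannotated `self` is treated as `Any`, and `arbitrary_types_allowed=True` is needed because the `HDF5_ATTR_TYPE` union contains NumPy scalar types. A sketch of the decorated method in isolation (stand-in class):

```python
import numpy as np
from pydantic import ConfigDict, validate_call

HDF5_ATTR_TYPE = np.int64 | int | np.float64 | float | str | bytes | bool

class AttrStore:  # stand-in for the attribute-handling part of IO
    def __init__(self):
        self._attributes = {}

    @validate_call(config=ConfigDict(arbitrary_types_allowed=True, strict=True))
    def add_attribute(self, name: str, value: HDF5_ATTR_TYPE | None) -> None:
        self._attributes[name] = value

store = AttrStore()
store.add_attribute("temperature", np.float64(300.0))  # ok
store.add_attribute("temperature", [300.0])            # ValidationError: list not in union
```

Validating values on the way in is also what lets `_save_attributes` below drop its own type check: invalid values can no longer reach `self._attributes`.
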
@@ -228,10 +224,10 @@ def add_file_attributes(self):
         Add attributes for input data filename, hash of file, advanced parameters to data for HDF5 file
         """
         self.add_attribute("hash", self._hash_input_filename)
-        self.add_attribute("setting", self._setting)
-        self.add_attribute("autoconvolution", self._autoconvolution)
-        self.add_attribute("temperature", self._temperature)
-        self.add_attribute("filename", self._input_filename)
+        self.add_attribute("setting", self.setting)
+        self.add_attribute("autoconvolution", self.autoconvolution)
+        self.add_attribute("temperature", self.temperature)
+        self.add_attribute("filename", self.input_filename)
         self.add_attribute("advanced_parameters", json.dumps(abins.parameters.non_performance_parameters))
 
     def add_data(self, name=None, value=None):
@@ -250,15 +246,10 @@ def _save_attributes(self, group=None):
         :param group: group to which attributes should be saved.
         """
         for name in self._attributes:
-            if isinstance(self._attributes[name], (np.int64, int, np.float64, float, str, bytes, bool)):
-                group.attrs[name] = self._attributes[name]
-            elif self._attributes[name] is None:
+            if self._attributes[name] is None:
                 group.attrs[name] = "None"
             else:
-                raise ValueError(
-                    "Invalid value of attribute. String, "
-                    "int, bool or bytes was expected! " + name + "= (invalid type : %s) " % type(self._attributes[name])
-                )
+                group.attrs[name] = self._attributes[name]
 
     def _recursively_save_structured_data_to_group(self, hdf_file=None, path=None, dic=None):
         """
@@ -309,17 +300,17 @@ def _save_data(self, hdf_file=None, group=None):
             elif isinstance(self._data[item], dict):
                 self._recursively_save_structured_data_to_group(hdf_file=hdf_file, path=group.name + "/" + item + "/", dic=self._data[item])
             else:
-                raise ValueError("Invalid structured dataset. Cannot save %s type" % type(item))
+                raise TypeError("Invalid structured dataset. Cannot save %s type" % type(item))
 
     def save(self):
         """
         Saves datasets and attributes to an hdf file.
         """
 
         with h5py.File(self._hdf_filename, "a") as hdf_file:
-            if self._group_name not in hdf_file:
-                hdf_file.create_group(self._group_name)
-            group = hdf_file[self._group_name]
+            if self.group_name not in hdf_file:
+                hdf_file.create_group(self.group_name)
+            group = hdf_file[self.group_name]
 
             if len(self._attributes.keys()) > 0:
                 self._save_attributes(group=group)
@@ -342,7 +333,8 @@ def save(self):
                 pass
 
     @staticmethod
-    def _list_of_str(list_str=None):
+    @validate_call
+    def _list_of_str(list_str: Optional[List[str]]) -> bool:
         """
         Checks if all elements of the list are strings.
         :param list_str: list to check
@@ -351,9 +343,6 @@ def _list_of_str(list_str=None):
         if list_str is None:
             return False
 
-        if not (isinstance(list_str, list) and all([isinstance(list_str[item], str) for item in range(len(list_str))])):
-            raise ValueError("Invalid list of items to load!")
-
         return True
 
     def _load_attributes(self, list_of_attributes=None, group=None):
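
For the static helper the decorator order matters: `@validate_call` sits under `@staticmethod`, and the `Optional[List[str]]` annotation now performs the list-of-strings check that the removed comprehension did by hand. A sketch (stand-in class):

```python
from typing import List, Optional

from pydantic import validate_call

class Loader:  # stand-in
    @staticmethod
    @validate_call
    def _list_of_str(list_str: Optional[List[str]]) -> bool:
        # None means "nothing requested", mirroring the original helper
        return list_str is not None

print(Loader._list_of_str(["atoms", "frequencies"]))  # True
print(Loader._list_of_str(None))                      # False
Loader._list_of_str([1, 2])  # ValidationError: pydantic v2 does not coerce int to str
```
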
@@ -422,39 +411,34 @@ def _encode_utf8_if_text(item):
         else:
             return item
 
-    def _load_dataset(self, hdf_file=None, name=None, group=None):
+    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
+    def _load_dataset(self, *, hdf_file: h5py.File, name: str, group: h5py.Group):
         """
         Loads one structured dataset.
         :param hdf_file: hdf file object from which structured dataset should be loaded.
         :param name: name of dataset
         :param group: name of the main group
         :returns: loaded dataset
         """
-        if not isinstance(name, str):
-            raise ValueError("Invalid name of the dataset.")
-
         if name in group:
             hdf_group = group[name]
         else:
-            raise ValueError("Invalid name of the dataset.")
+            raise ValueError(f"Dataset {name} not found in group {group.name}.")
 
         # noinspection PyUnresolvedReferences,PyProtectedMember
-        if isinstance(hdf_group, h5py._hl.dataset.Dataset):
+        if isinstance(hdf_group, h5py.Dataset):
             return hdf_group[()]
         elif all([self._get_subgrp_name(hdf_group[el].name).isdigit() for el in hdf_group.keys()]):
-            structured_dataset_list = []
-            # here we make an assumption about keys which have a numeric values; we assume that always : 1, 2, 3... Max
-            num_keys = len(hdf_group.keys())
-            for item in range(num_keys):
-                structured_dataset_list.append(
-                    self._recursively_load_dict_contents_from_group(hdf_file=hdf_file, path=hdf_group.name + "/%s" % item)
-                )
+            structured_dataset_list = [
+                self._recursively_load_dict_contents_from_group(hdf_file=hdf_file, path=f"{hdf_group.name}/{key}")
+                for key in sorted(hdf_group.keys(), key=int)
+            ]
             return structured_dataset_list
         else:
             return self._recursively_load_dict_contents_from_group(hdf_file=hdf_file, path=hdf_group.name + "/")
 
     @classmethod
-    def _recursively_load_dict_contents_from_group(cls, hdf_file=None, path=None):
+    def _recursively_load_dict_contents_from_group(cls, *, hdf_file: h5py.File, path: str):
         """
         Loads structure dataset which has form of Python dictionary.
         :param hdf_file: hdf file object from which dataset is loaded
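
A behavioural fix hides in this hunk: the old loop assumed the numeric subgroup keys were exactly `0..N-1`, whereas the new comprehension numerically sorts whatever keys actually exist. A plain-Python illustration of the difference:

```python
# Numeric subgroup names as they might appear in an HDF5 group
keys = ["0", "1", "10", "2"]

# Old approach: fabricate the range 0..N-1, which can name missing subgroups
old_order = [str(i) for i in range(len(keys))]  # ['0', '1', '2', '3'] -- '10' is never read
# New approach: numeric sort of the keys that are actually present
new_order = sorted(keys, key=int)               # ['0', '1', '2', '10']
print(old_order, new_order)
```
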
@@ -468,7 +452,7 @@ def _recursively_load_dict_contents_from_group(cls, hdf_file=None, path=None):
             if isinstance(item, h5py._hl.dataset.Dataset):
                 ans[key] = item[()]
             elif isinstance(item, h5py._hl.group.Group):
-                ans[key] = cls._recursively_load_dict_contents_from_group(hdf_file, path + key + "/")
+                ans[key] = cls._recursively_load_dict_contents_from_group(hdf_file=hdf_file, path=f"{path}{key}/")
         return ans
 
     def load(self, list_of_attributes=None, list_of_datasets=None):
@@ -484,10 +468,10 @@ def load(self, list_of_attributes=None, list_of_datasets=None):
 
         results = {}
         with h5py.File(self._hdf_filename, "r") as hdf_file:
-            if self._group_name not in hdf_file:
-                raise ValueError("No group %s in hdf file." % self._group_name)
+            if self.group_name not in hdf_file:
+                raise ValueError("No group %s in hdf file." % self.group_name)
 
-            group = hdf_file[self._group_name]
+            group = hdf_file[self.group_name]
 
             if self._list_of_str(list_str=list_of_attributes):
                 results["attributes"] = self._load_attributes(list_of_attributes=list_of_attributes, group=group)
@@ -521,7 +505,7 @@ def _calculate_hash(filename=None):
         return hash_calculator.hexdigest()
 
     def get_input_filename(self):
-        return self._input_filename
+        return self.input_filename
 
     def calculate_ab_initio_file_hash(self):
         """
@@ -530,4 +514,4 @@ def calculate_ab_initio_file_hash(self):
         :returns: string representation of hash for file with vibrational data which contains only hexadecimal digits
         """
 
-        return self._calculate_hash(filename=self._input_filename)
+        return self._calculate_hash(filename=self.input_filename)
