10
10
import os
11
11
import subprocess
12
12
import shutil
13
+ from typing import List , Optional
14
+
13
15
import h5py
16
+
14
17
import numpy as np
18
+ from pydantic import BaseModel , ConfigDict , Field , validate_call
15
19
16
20
import abins
17
- from abins .constants import AB_INITIO_FILE_EXTENSIONS , BUF
21
+ from abins .constants import AB_INITIO_FILE_EXTENSIONS , BUF , HDF5_ATTR_TYPE
18
22
from mantid .kernel import logger , ConfigService
19
23
20
24
21
- class IO (object ):
25
+ class IO (BaseModel ):
22
26
"""
23
27
Class for Abins I/O HDF file operations.
24
28
"""
25
29
26
- def __init__ (self , input_filename = None , group_name = None , setting = "" , autoconvolution : bool = False , temperature : float = None ):
27
- self ._setting = setting
28
- self ._autoconvolution = autoconvolution
29
- self ._temperature = temperature
30
-
31
- if isinstance (input_filename , str ):
32
- self ._input_filename = input_filename
33
- try :
34
- self ._hash_input_filename = self .calculate_ab_initio_file_hash ()
35
- except IOError as err :
36
- logger .error (str (err ))
37
- except ValueError as err :
38
- logger .error (str (err ))
30
+ model_config = ConfigDict (strict = True )
39
31
40
- # extract name of file from the full path in the platform independent way
41
- filename = os .path .basename (self ._input_filename )
32
+ input_filename : str = Field (min_length = 1 )
33
+ group_name : str = Field (min_length = 1 )
34
+ setting : str = ""
35
+ autoconvolution : int = 10
36
+ temperature : float = None
42
37
43
- if filename .strip () == "" :
44
- raise ValueError ("Name of the file cannot be an empty string." )
38
+ def model_post_init (self , __context ):
39
+ try :
40
+ self ._hash_input_filename = self .calculate_ab_initio_file_hash ()
41
+ except IOError as err :
42
+ logger .error (str (err ))
43
+ except ValueError as err :
44
+ logger .error (str (err ))
45
45
46
- else :
47
- raise ValueError ("Invalid name of input file. String was expected." )
48
-
49
- if isinstance (group_name , str ):
50
- self ._group_name = group_name
51
- else :
52
- raise ValueError ("Invalid name of the group. String was expected." )
46
+ # extract name of file from the full path in the platform independent way
47
+ filename = os .path .basename (self .input_filename )
53
48
54
49
if filename .split ("." )[- 1 ] in AB_INITIO_FILE_EXTENSIONS :
55
50
core_name = filename [0 : filename .rfind ("." )] # e.g. NaCl.phonon -> NaCl (core_name) -> NaCl.hdf5
@@ -85,15 +80,15 @@ def _valid_setting(self):
85
80
:returns: True if consistent, otherwise False.
86
81
"""
87
82
saved_setting = self .load (list_of_attributes = ["setting" ])
88
- return self ._setting == saved_setting ["attributes" ]["setting" ]
83
+ return self .setting == saved_setting ["attributes" ]["setting" ]
89
84
90
85
def _valid_autoconvolution (self ):
91
86
"""
92
87
Check if autoconvolution setting matches content of HDF file
93
88
:returns: True if consistent, otherwise False
94
89
"""
95
90
saved_autoconvolution = self .load (list_of_attributes = ["autoconvolution" ])
96
- return self ._autoconvolution == saved_autoconvolution ["attributes" ]["autoconvolution" ]
91
+ return self .autoconvolution == saved_autoconvolution ["attributes" ]["autoconvolution" ]
97
92
98
93
def _valid_temperature (self ):
99
94
"""
@@ -104,7 +99,7 @@ def _valid_temperature(self):
104
99
105
100
:returns: True if consistent or temperature not set for Clerk, otherwise False
106
101
"""
107
- if self ._temperature is None :
102
+ if self .temperature is None :
108
103
return True
109
104
110
105
else :
@@ -115,7 +110,7 @@ def _valid_temperature(self):
115
110
return False
116
111
117
112
else :
118
- return np .abs (self ._temperature - saved_temperature ["attributes" ]["temperature" ]) < T_THRESHOLD
113
+ return np .abs (self .temperature - saved_temperature ["attributes" ]["temperature" ]) < T_THRESHOLD
119
114
120
115
@classmethod
121
116
def _close_enough (cls , previous , new ):
@@ -215,7 +210,8 @@ def erase_hdf_file(self):
215
210
with h5py .File (self ._hdf_filename , "w" ) as hdf_file :
216
211
hdf_file .close ()
217
212
218
- def add_attribute (self , name = None , value = None ):
213
+ @validate_call (config = ConfigDict (arbitrary_types_allowed = True , strict = True ))
214
+ def add_attribute (self , name : str , value : HDF5_ATTR_TYPE | None ) -> None :
219
215
"""
220
216
Adds attribute to the dictionary with other attributes.
221
217
:param name: name of the attribute
@@ -228,10 +224,10 @@ def add_file_attributes(self):
228
224
Add attributes for input data filename, hash of file, advanced parameters to data for HDF5 file
229
225
"""
230
226
self .add_attribute ("hash" , self ._hash_input_filename )
231
- self .add_attribute ("setting" , self ._setting )
232
- self .add_attribute ("autoconvolution" , self ._autoconvolution )
233
- self .add_attribute ("temperature" , self ._temperature )
234
- self .add_attribute ("filename" , self ._input_filename )
227
+ self .add_attribute ("setting" , self .setting )
228
+ self .add_attribute ("autoconvolution" , self .autoconvolution )
229
+ self .add_attribute ("temperature" , self .temperature )
230
+ self .add_attribute ("filename" , self .input_filename )
235
231
self .add_attribute ("advanced_parameters" , json .dumps (abins .parameters .non_performance_parameters ))
236
232
237
233
def add_data (self , name = None , value = None ):
@@ -250,15 +246,10 @@ def _save_attributes(self, group=None):
250
246
:param group: group to which attributes should be saved.
251
247
"""
252
248
for name in self ._attributes :
253
- if isinstance (self ._attributes [name ], (np .int64 , int , np .float64 , float , str , bytes , bool )):
254
- group .attrs [name ] = self ._attributes [name ]
255
- elif self ._attributes [name ] is None :
249
+ if self ._attributes [name ] is None :
256
250
group .attrs [name ] = "None"
257
251
else :
258
- raise ValueError (
259
- "Invalid value of attribute. String, "
260
- "int, bool or bytes was expected! " + name + "= (invalid type : %s) " % type (self ._attributes [name ])
261
- )
252
+ group .attrs [name ] = self ._attributes [name ]
262
253
263
254
def _recursively_save_structured_data_to_group (self , hdf_file = None , path = None , dic = None ):
264
255
"""
@@ -309,17 +300,17 @@ def _save_data(self, hdf_file=None, group=None):
309
300
elif isinstance (self ._data [item ], dict ):
310
301
self ._recursively_save_structured_data_to_group (hdf_file = hdf_file , path = group .name + "/" + item + "/" , dic = self ._data [item ])
311
302
else :
312
- raise ValueError ("Invalid structured dataset. Cannot save %s type" % type (item ))
303
+ raise TypeError ("Invalid structured dataset. Cannot save %s type" % type (item ))
313
304
314
305
def save (self ):
315
306
"""
316
307
Saves datasets and attributes to an hdf file.
317
308
"""
318
309
319
310
with h5py .File (self ._hdf_filename , "a" ) as hdf_file :
320
- if self ._group_name not in hdf_file :
321
- hdf_file .create_group (self ._group_name )
322
- group = hdf_file [self ._group_name ]
311
+ if self .group_name not in hdf_file :
312
+ hdf_file .create_group (self .group_name )
313
+ group = hdf_file [self .group_name ]
323
314
324
315
if len (self ._attributes .keys ()) > 0 :
325
316
self ._save_attributes (group = group )
@@ -342,7 +333,8 @@ def save(self):
342
333
pass
343
334
344
335
@staticmethod
345
- def _list_of_str (list_str = None ):
336
+ @validate_call
337
+ def _list_of_str (list_str : Optional [List [str ]]) -> bool :
346
338
"""
347
339
Checks if all elements of the list are strings.
348
340
:param list_str: list to check
@@ -351,9 +343,6 @@ def _list_of_str(list_str=None):
351
343
if list_str is None :
352
344
return False
353
345
354
- if not (isinstance (list_str , list ) and all ([isinstance (list_str [item ], str ) for item in range (len (list_str ))])):
355
- raise ValueError ("Invalid list of items to load!" )
356
-
357
346
return True
358
347
359
348
def _load_attributes (self , list_of_attributes = None , group = None ):
@@ -422,39 +411,34 @@ def _encode_utf8_if_text(item):
422
411
else :
423
412
return item
424
413
425
- def _load_dataset (self , hdf_file = None , name = None , group = None ):
414
    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
    def _load_dataset(self, *, hdf_file: h5py.File, name: str, group: h5py.Group):
        """
        Loads one structured dataset.

        :param hdf_file: open hdf file object from which the structured dataset is loaded.
        :param name: name of dataset within *group*
        :param group: h5py group that is expected to contain the dataset
        :returns: loaded dataset — a raw value for a plain Dataset, a list for a
            group whose children are all numeric names, or a nested dict otherwise
        :raises ValueError: if *name* is not present in *group*
        """
        if name in group:
            hdf_group = group[name]
        else:
            raise ValueError(f"Dataset {name} not found in group {group.name}.")

        # noinspection PyUnresolvedReferences,PyProtectedMember
        if isinstance(hdf_group, h5py.Dataset):
            # Plain dataset: read the whole thing into memory with the [()] idiom.
            return hdf_group[()]
        elif all([self._get_subgrp_name(hdf_group[el].name).isdigit() for el in hdf_group.keys()]):
            # A group whose children are all digit-named subgroups encodes a Python
            # list; rebuild it in numeric key order (sorted by int, not lexically,
            # so "10" follows "9").
            structured_dataset_list = [
                self._recursively_load_dict_contents_from_group(hdf_file=hdf_file, path=f"{hdf_group.name}/{key}")
                for key in sorted(hdf_group.keys(), key=int)
            ]
            return structured_dataset_list
        else:
            # Otherwise the group encodes a dictionary; recurse into it.
            return self._recursively_load_dict_contents_from_group(hdf_file=hdf_file, path=hdf_group.name + "/")
455
439
456
440
@classmethod
457
- def _recursively_load_dict_contents_from_group (cls , hdf_file = None , path = None ):
441
+ def _recursively_load_dict_contents_from_group (cls , * , hdf_file : h5py . File , path : str ):
458
442
"""
459
443
Loads structure dataset which has form of Python dictionary.
460
444
:param hdf_file: hdf file object from which dataset is loaded
@@ -468,7 +452,7 @@ def _recursively_load_dict_contents_from_group(cls, hdf_file=None, path=None):
468
452
if isinstance (item , h5py ._hl .dataset .Dataset ):
469
453
ans [key ] = item [()]
470
454
elif isinstance (item , h5py ._hl .group .Group ):
471
- ans [key ] = cls ._recursively_load_dict_contents_from_group (hdf_file , path + key + " /" )
455
+ ans [key ] = cls ._recursively_load_dict_contents_from_group (hdf_file = hdf_file , path = f" { path } { key } /" )
472
456
return ans
473
457
474
458
def load (self , list_of_attributes = None , list_of_datasets = None ):
@@ -484,10 +468,10 @@ def load(self, list_of_attributes=None, list_of_datasets=None):
484
468
485
469
results = {}
486
470
with h5py .File (self ._hdf_filename , "r" ) as hdf_file :
487
- if self ._group_name not in hdf_file :
488
- raise ValueError ("No group %s in hdf file." % self ._group_name )
471
+ if self .group_name not in hdf_file :
472
+ raise ValueError ("No group %s in hdf file." % self .group_name )
489
473
490
- group = hdf_file [self ._group_name ]
474
+ group = hdf_file [self .group_name ]
491
475
492
476
if self ._list_of_str (list_str = list_of_attributes ):
493
477
results ["attributes" ] = self ._load_attributes (list_of_attributes = list_of_attributes , group = group )
@@ -521,7 +505,7 @@ def _calculate_hash(filename=None):
521
505
return hash_calculator .hexdigest ()
522
506
523
507
    def get_input_filename(self):
        """Return the path of the ab initio input file (legacy accessor for the
        public ``input_filename`` field)."""
        return self.input_filename
525
509
526
510
def calculate_ab_initio_file_hash (self ):
527
511
"""
@@ -530,4 +514,4 @@ def calculate_ab_initio_file_hash(self):
530
514
:returns: string representation of hash for file with vibrational data which contains only hexadecimal digits
531
515
"""
532
516
533
- return self ._calculate_hash (filename = self ._input_filename )
517
+ return self ._calculate_hash (filename = self .input_filename )
0 commit comments