3
3
import tensorflow as tf
4
4
import pandas as pd
5
5
import os
6
- import typing as t
7
- from typing import Union , List , Callable , Dict
8
- from collections .abc import MutableSequence
6
+ # import typing as t
7
+ from typing import Union , List , Callable , Dict , Optional
8
+ # from collections.abc import MutableSequence
9
9
10
10
from kgcnn .data .utils import save_pickle_file , load_pickle_file , ragged_tensor_from_nested_numpy
11
11
from kgcnn .graph .base import GraphDict
12
12
13
- logging .basicConfig () # Module logger
13
+ # Module logger
14
+ logging .basicConfig ()
14
15
module_logger = logging .getLogger (__name__ )
15
16
module_logger .setLevel (logging .INFO )
16
17
17
18
18
- class MemoryGraphList (MutableSequence ):
19
+ class MemoryGraphList (list ):
19
20
r"""Class to store a list of graph dictionaries in memory.
20
21
21
- Contains a python list as property :obj:`_list` . The graph properties are defined by tensor-like (numpy) arrays
22
+ Inherits from a python list. The graph properties are defined by tensor-like (numpy) arrays
22
23
for indices, attributes, labels, symbol etc. in :obj:`GraphDict`, which are the items of the list.
23
24
Access to items via `[]` indexing operator.
24
25
@@ -45,22 +46,26 @@ class MemoryGraphList(MutableSequence):
45
46
data.clean("range_indices") # Returns cleaned graph indices
46
47
print(len(data))
47
48
print(data[0])
49
+
48
50
"""
49
51
50
- def __init__ (self , input_list : list = None ):
52
+ _require_validate = True
53
+
54
+ def __init__ (self , iterable : list = None ):
51
55
r"""Initialize an empty :obj:`MemoryGraphList` instance.
52
56
53
57
Args:
54
- input_list (list, MemoryGraphList): A list or :obj:`MemoryGraphList` of :obj:`GraphDict` items.
58
+ iterable (list, MemoryGraphList): A list or :obj:`MemoryGraphList` of :obj:`GraphDict` items.
55
59
"""
56
- self ._list = []
60
+ iterable = iterable if iterable is not None else []
61
+ super (MemoryGraphList , self ).__init__ (iterable )
57
62
self .logger = module_logger
58
- if input_list is None :
59
- input_list = []
60
- if isinstance ( input_list , list ):
61
- self . _list = [ GraphDict ( x ) for x in input_list ]
62
- if isinstance (input_list , MemoryGraphList ):
63
- self . _list = [ GraphDict (x ) for x in input_list . _list ]
63
+ self . validate ()
64
+
65
+ def validate ( self ):
66
+ for i , x in enumerate ( self ):
67
+ if not isinstance (x , GraphDict ):
68
+ self [ i ] = GraphDict (x )
64
69
65
70
def assign_property (self , key : str , value : list ):
66
71
"""Assign a list of numpy arrays of a property to :obj:`GraphDict` s in this list.
@@ -77,24 +82,23 @@ def assign_property(self, key: str, value: list):
77
82
return self
78
83
if not isinstance (value , list ):
79
84
raise TypeError ("Expected type 'list' to assign graph properties." )
80
- if len (self . _list ) == 0 :
85
+ if len (self ) == 0 :
81
86
self .empty (len (value ))
82
- if len (self . _list ) != len (value ):
87
+ if len (self ) != len (value ):
83
88
raise ValueError ("Can only store graph attributes from list with same length." )
84
89
for i , x in enumerate (value ):
85
- self . _list [i ].assign_property (key , x )
90
+ self [i ].assign_property (key , x )
86
91
return self
87
92
88
- def obtain_property (self , key : str ) -> Union [list , None ]:
93
+ def obtain_property (self , key : str ) -> Union [List , None ]:
89
94
r"""Returns a list with the values of all the graphs defined for the string property name `key`. If none of
90
95
the graphs in the list have this property, returns None.
91
96
92
97
Args:
93
98
key (str): The string name of the property to be retrieved for all the graphs contained in this list
94
99
"""
95
- # "_list" is a list of GraphDicts, which means "prop_list" here will be a list of all the property
96
100
# values for teach of the graphs which make up this list.
97
- prop_list = [x .obtain_property (key ) for x in self . _list ]
101
+ prop_list = [x .obtain_property (key ) for x in self ]
98
102
99
103
# If a certain string property is not set for a GraphDict, it will still return None. Here we check:
100
104
# If all the items for our given property name are None then we know that this property is generally not
@@ -105,70 +109,42 @@ def obtain_property(self, key: str) -> Union[list, None]:
105
109
106
110
return prop_list
107
111
108
- def __len__ (self ):
109
- """Return the current length of this instance."""
110
- return len (self ._list )
111
-
112
- def __getitem__ (self , item ):
112
+ def __getitem__ (self , item ) -> Union [GraphDict , List ]:
113
113
# Does not make a copy of the data, as a python list does.
114
114
if isinstance (item , int ):
115
- return self ._list [item ]
116
- new_list = MemoryGraphList ()
115
+ return super (MemoryGraphList , self ).__getitem__ (item )
117
116
if isinstance (item , slice ):
118
- return new_list . _set_internal_list ( self . _list [ item ] )
119
- if isinstance (item , list ):
120
- return new_list . _set_internal_list ([ self . _list [ int (i )] for i in item ])
117
+ return MemoryGraphList ( super ( MemoryGraphList , self ). __getitem__ ( item ) )
118
+ if isinstance (item , ( list , tuple ) ):
119
+ return MemoryGraphList ([ super ( MemoryGraphList , self ). __getitem__ ( int (i )) for i in item ])
121
120
if isinstance (item , np .ndarray ):
122
- return new_list . _set_internal_list ([ self . _list [ int (i )] for i in item ])
121
+ return MemoryGraphList ([ super ( MemoryGraphList , self ). __getitem__ ( int (i )) for i in item ])
123
122
raise TypeError ("Unsupported type for `MemoryGraphList` items." )
124
123
125
124
def __setitem__ (self , key , value ):
126
125
if not isinstance (value , GraphDict ):
127
126
raise TypeError ("Require a GraphDict as list item." )
128
- self ._list [key ] = value
129
-
130
- def __delitem__ (self , key ):
131
- value = self ._list .__delitem__ (key )
132
- return value
133
-
134
- def __iter__ (self ):
135
- return iter (self ._list )
127
+ super (MemoryGraphList , self ).__setitem__ (key , value )
136
128
137
129
def __repr__ (self ):
138
130
return "<{} [{}]>" .format (type (self ).__name__ , "" if len (self ) == 0 else self [0 ].__repr__ () + " ..." )
139
131
140
132
def append (self , graph ):
141
133
assert isinstance (graph , GraphDict ), "Must append `GraphDict` to self."
142
- self . _list .append (graph )
134
+ super ( MemoryGraphList , self ) .append (graph )
143
135
144
136
def insert (self , index : int , value ) -> None :
145
137
assert isinstance (value , GraphDict ), "Must insert `GraphDict` to self."
146
- self . _list .insert (index , value )
138
+ super ( MemoryGraphList , self ) .insert (index , value )
147
139
148
140
def __add__ (self , other ):
149
141
assert isinstance (other , MemoryGraphList ), "Must add `MemoryGraphList` to self."
150
- new_list = MemoryGraphList ()
151
- new_list ._set_internal_list (self ._list + other ._list )
152
- return new_list
153
-
154
- def _set_internal_list (self , value : list ):
155
- if not isinstance (value , list ):
156
- raise TypeError ("Must set list for `MemoryGraphList` internal assignment." )
157
- self ._list = value
158
- return self
142
+ return MemoryGraphList (super (MemoryGraphList , self ).__add__ (other ))
159
143
160
144
def copy (self ):
161
145
"""Copy data in the list."""
162
146
return MemoryGraphList ([x .copy () for x in self ])
163
147
164
- def clear (self ):
165
- """Clear internal list.
166
-
167
- Returns:
168
- None
169
- """
170
- self ._list .clear ()
171
-
172
148
def empty (self , length : int ):
173
149
"""Create an empty list in place. Overwrites existing list.
174
150
@@ -182,7 +158,9 @@ def empty(self, length: int):
182
158
return self
183
159
if length < 0 :
184
160
raise ValueError ("Length of empty list must be >=0." )
185
- self ._list = [GraphDict () for _ in range (length )]
161
+ self .clear ()
162
+ for _ in range (length ):
163
+ self .append (GraphDict ())
186
164
return self
187
165
188
166
def update (self , other ) -> None :
@@ -194,7 +172,7 @@ def update(self, other) -> None:
194
172
@property
195
173
def length (self ):
196
174
"""Length of list."""
197
- return len (self . _list )
175
+ return len (self )
198
176
199
177
@length .setter
200
178
def length (self , value : int ):
@@ -257,7 +235,7 @@ def map_list(self, method: Union[str, Callable], **kwargs):
257
235
# Can add progress info here.
258
236
# Method by name.
259
237
if isinstance (method , str ):
260
- for i , x in enumerate (self . _list ):
238
+ for i , x in enumerate (self ):
261
239
# If this is a class method.
262
240
if hasattr (x , method ):
263
241
getattr (x , method )(** kwargs )
@@ -268,7 +246,7 @@ def map_list(self, method: Union[str, Callable], **kwargs):
268
246
raise NotImplementedError ("Serialization for method in `map_list` is not yet supported" )
269
247
else :
270
248
# For any callable method to map.
271
- for i , x in enumerate (self . _list ):
249
+ for i , x in enumerate (self ):
272
250
method (x , ** kwargs )
273
251
return self
274
252
@@ -319,7 +297,7 @@ def clean(self, inputs: Union[list, str]):
319
297
self .logger .info ("No invalid graphs for assigned properties found." )
320
298
# Remove from the end via pop().
321
299
for i in invalid_graphs :
322
- self ._list . pop (int (i ))
300
+ self .pop (int (i ))
323
301
return invalid_graphs
324
302
325
303
# Alias of internal assign and obtain property.
@@ -336,14 +314,17 @@ class MemoryGraphDataset(MemoryGraphList):
336
314
337
315
.. code-block:: python
338
316
317
+ import numpy as np
339
318
from kgcnn.data.base import MemoryGraphDataset
340
319
dataset = MemoryGraphDataset(data_directory="", dataset_name="Example")
341
320
# Methods of MemoryGraphList
342
321
dataset.set("edge_indices", [np.array([[1, 0], [0, 1]])])
343
322
dataset.set("edge_labels", [np.array([[0], [1]])])
344
323
dataset.save()
324
+ dataset.load()
325
+
326
+ The file directory and file name are used in child classes and in :obj:`save` and :obj:`load` .
345
327
346
- The file directory and file name are used in child classes and in :obj:`save` and :obj:`load`.
347
328
"""
348
329
349
330
fits_in_memory = True
@@ -420,20 +401,20 @@ def error(self, *args, **kwargs):
420
401
421
402
def save (self , filepath : str = None ):
422
403
r"""Save all graph properties to python dictionary as pickled file. By default, saves a file named
423
- :obj:`dataset_name.kgcnn.pickle` in :obj:`data_directory`.
404
+ :obj:`dataset_name.kgcnn.pickle` in :obj:`data_directory` .
424
405
425
406
Args:
426
407
filepath (str): Full path of output file. Default is None.
427
408
"""
428
409
if filepath is None :
429
410
filepath = os .path .join (self .data_directory , self .dataset_name + ".kgcnn.pickle" )
430
411
self .info ("Pickle dataset..." )
431
- save_pickle_file ([x .to_dict () for x in self . _list ], filepath )
412
+ save_pickle_file ([x .to_dict () for x in self ], filepath )
432
413
return self
433
414
434
415
def load (self , filepath : str = None ):
435
416
r"""Load graph properties from a pickled file. By default, loads a file named
436
- :obj:`dataset_name.kgcnn.pickle` in :obj:`data_directory`.
417
+ :obj:`dataset_name.kgcnn.pickle` in :obj:`data_directory` .
437
418
438
419
Args:
439
420
filepath (str): Full path of input file.
@@ -442,7 +423,9 @@ def load(self, filepath: str = None):
442
423
filepath = os .path .join (self .data_directory , self .dataset_name + ".kgcnn.pickle" )
443
424
self .info ("Load pickled dataset..." )
444
425
in_list = load_pickle_file (filepath )
445
- self ._list = [GraphDict (x ) for x in in_list ]
426
+ self .clear ()
427
+ for x in in_list :
428
+ self .append (GraphDict (x ))
446
429
return self
447
430
448
431
def read_in_table_file (self , file_path : str = None , ** kwargs ):
@@ -519,7 +502,7 @@ def message_warning(msg):
519
502
for x in hyper_input :
520
503
if not isinstance (x , dict ):
521
504
message_error (
522
- "Wrong type of list item in `assert_valid_model_input`. Found %s but must be `dict`" % type (x ))
505
+ "Wrong type of list item in `assert_valid_model_input`. Found '%s' but must be `dict` . " % type (x ))
523
506
524
507
for x in hyper_input :
525
508
if "name" not in x :
@@ -681,7 +664,7 @@ def check_and_extend_splits(to_split):
681
664
def get_train_test_indices (self ,
682
665
train : str = "train" ,
683
666
test : str = "test" ,
684
- valid : t . Optional [str ] = None ,
667
+ valid : Optional [str ] = None ,
685
668
split_index : Union [int , list ] = 1 ,
686
669
shuffle : bool = False ,
687
670
seed : int = None
@@ -715,15 +698,15 @@ def get_train_test_indices(self,
715
698
"""
716
699
out_indices = []
717
700
if not isinstance (split_index , (list , tuple )):
718
- split_index_list : t . List [int ] = [split_index ]
701
+ split_index_list : List [int ] = [split_index ]
719
702
else :
720
- split_index_list : t . List [int ] = split_index
703
+ split_index_list : List [int ] = split_index
721
704
722
705
for split_index in split_index_list :
723
706
724
707
# This list will itself contain numpy arrays which are filled with graph indices of the dataset
725
708
# each element of this list will correspond to one property name (train, test...)
726
- graph_index_split_list : t . List [np .ndarray ] = []
709
+ graph_index_split_list : List [np .ndarray ] = []
727
710
728
711
for property_name in [train , test , valid ]:
729
712
@@ -735,14 +718,14 @@ def get_train_test_indices(self,
735
718
# This list will contain all the indices of the dataset elements (graphs) which are
736
719
# associated with the current iteration's split index for the current iteration's
737
720
# property name (train, test...)
738
- graph_index_list : t . List [int ] = []
721
+ graph_index_list : List [int ] = []
739
722
740
723
# "obtain_property" returns a list which contains only the property values corresponding to
741
724
# the given property name for each graph inside the dataset in the same order.
742
725
# In this case, this is supposed to be a split list, which is a list that contains integer
743
726
# indices, each representing one particular dataset split. The split list of each graph
744
727
# only contains those split indices to which that graph is associated.
745
- split_prop : t . List [t . List [int ]] = self .obtain_property (property_name )
728
+ split_prop : List [List [int ]] = self .obtain_property (property_name )
746
729
for index , split_list in enumerate (split_prop ):
747
730
if split_list is not None :
748
731
if split_index in split_list :
0 commit comments