Skip to content

Commit 1ec0bae

Browse files
committed
Changed MemoryGraphList base class to be a python list.
1 parent fd017d8 commit 1ec0bae

File tree

4 files changed

+62
-79
lines changed

4 files changed

+62
-79
lines changed

kgcnn/data/base.py

Lines changed: 59 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,23 @@
33
import tensorflow as tf
44
import pandas as pd
55
import os
6-
import typing as t
7-
from typing import Union, List, Callable, Dict
8-
from collections.abc import MutableSequence
6+
# import typing as t
7+
from typing import Union, List, Callable, Dict, Optional
8+
# from collections.abc import MutableSequence
99

1010
from kgcnn.data.utils import save_pickle_file, load_pickle_file, ragged_tensor_from_nested_numpy
1111
from kgcnn.graph.base import GraphDict
1212

13-
logging.basicConfig() # Module logger
13+
# Module logger
14+
logging.basicConfig()
1415
module_logger = logging.getLogger(__name__)
1516
module_logger.setLevel(logging.INFO)
1617

1718

18-
class MemoryGraphList(MutableSequence):
19+
class MemoryGraphList(list):
1920
r"""Class to store a list of graph dictionaries in memory.
2021
21-
Contains a python list as property :obj:`_list`. The graph properties are defined by tensor-like (numpy) arrays
22+
Inherits from a python list. The graph properties are defined by tensor-like (numpy) arrays
2223
for indices, attributes, labels, symbol etc. in :obj:`GraphDict`, which are the items of the list.
2324
Access to items via `[]` indexing operator.
2425
@@ -45,22 +46,26 @@ class MemoryGraphList(MutableSequence):
4546
data.clean("range_indices") # Returns cleaned graph indices
4647
print(len(data))
4748
print(data[0])
49+
4850
"""
4951

50-
def __init__(self, input_list: list = None):
52+
_require_validate = True
53+
54+
def __init__(self, iterable: list = None):
5155
r"""Initialize an empty :obj:`MemoryGraphList` instance.
5256
5357
Args:
54-
input_list (list, MemoryGraphList): A list or :obj:`MemoryGraphList` of :obj:`GraphDict` items.
58+
iterable (list, MemoryGraphList): A list or :obj:`MemoryGraphList` of :obj:`GraphDict` items.
5559
"""
56-
self._list = []
60+
iterable = iterable if iterable is not None else []
61+
super(MemoryGraphList, self).__init__(iterable)
5762
self.logger = module_logger
58-
if input_list is None:
59-
input_list = []
60-
if isinstance(input_list, list):
61-
self._list = [GraphDict(x) for x in input_list]
62-
if isinstance(input_list, MemoryGraphList):
63-
self._list = [GraphDict(x) for x in input_list._list]
63+
self.validate()
64+
65+
def validate(self):
66+
for i, x in enumerate(self):
67+
if not isinstance(x, GraphDict):
68+
self[i] = GraphDict(x)
6469

6570
def assign_property(self, key: str, value: list):
6671
"""Assign a list of numpy arrays of a property to :obj:`GraphDict` s in this list.
@@ -77,24 +82,23 @@ def assign_property(self, key: str, value: list):
7782
return self
7883
if not isinstance(value, list):
7984
raise TypeError("Expected type 'list' to assign graph properties.")
80-
if len(self._list) == 0:
85+
if len(self) == 0:
8186
self.empty(len(value))
82-
if len(self._list) != len(value):
87+
if len(self) != len(value):
8388
raise ValueError("Can only store graph attributes from list with same length.")
8489
for i, x in enumerate(value):
85-
self._list[i].assign_property(key, x)
90+
self[i].assign_property(key, x)
8691
return self
8792

88-
def obtain_property(self, key: str) -> Union[list, None]:
93+
def obtain_property(self, key: str) -> Union[List, None]:
8994
r"""Returns a list with the values of all the graphs defined for the string property name `key`. If none of
9095
the graphs in the list have this property, returns None.
9196
9297
Args:
9398
key (str): The string name of the property to be retrieved for all the graphs contained in this list
9499
"""
95-
# "_list" is a list of GraphDicts, which means "prop_list" here will be a list of all the property
96100
# values for teach of the graphs which make up this list.
97-
prop_list = [x.obtain_property(key) for x in self._list]
101+
prop_list = [x.obtain_property(key) for x in self]
98102

99103
# If a certain string property is not set for a GraphDict, it will still return None. Here we check:
100104
# If all the items for our given property name are None then we know that this property is generally not
@@ -105,70 +109,42 @@ def obtain_property(self, key: str) -> Union[list, None]:
105109

106110
return prop_list
107111

108-
def __len__(self):
109-
"""Return the current length of this instance."""
110-
return len(self._list)
111-
112-
def __getitem__(self, item):
112+
def __getitem__(self, item) -> Union[GraphDict, List]:
113113
# Does not make a copy of the data, as a python list does.
114114
if isinstance(item, int):
115-
return self._list[item]
116-
new_list = MemoryGraphList()
115+
return super(MemoryGraphList, self).__getitem__(item)
117116
if isinstance(item, slice):
118-
return new_list._set_internal_list(self._list[item])
119-
if isinstance(item, list):
120-
return new_list._set_internal_list([self._list[int(i)] for i in item])
117+
return MemoryGraphList(super(MemoryGraphList, self).__getitem__(item))
118+
if isinstance(item, (list, tuple)):
119+
return MemoryGraphList([super(MemoryGraphList, self).__getitem__(int(i)) for i in item])
121120
if isinstance(item, np.ndarray):
122-
return new_list._set_internal_list([self._list[int(i)] for i in item])
121+
return MemoryGraphList([super(MemoryGraphList, self).__getitem__(int(i)) for i in item])
123122
raise TypeError("Unsupported type for `MemoryGraphList` items.")
124123

125124
def __setitem__(self, key, value):
126125
if not isinstance(value, GraphDict):
127126
raise TypeError("Require a GraphDict as list item.")
128-
self._list[key] = value
129-
130-
def __delitem__(self, key):
131-
value = self._list.__delitem__(key)
132-
return value
133-
134-
def __iter__(self):
135-
return iter(self._list)
127+
super(MemoryGraphList, self).__setitem__(key, value)
136128

137129
def __repr__(self):
138130
return "<{} [{}]>".format(type(self).__name__, "" if len(self) == 0 else self[0].__repr__() + " ...")
139131

140132
def append(self, graph):
141133
assert isinstance(graph, GraphDict), "Must append `GraphDict` to self."
142-
self._list.append(graph)
134+
super(MemoryGraphList, self).append(graph)
143135

144136
def insert(self, index: int, value) -> None:
145137
assert isinstance(value, GraphDict), "Must insert `GraphDict` to self."
146-
self._list.insert(index, value)
138+
super(MemoryGraphList, self).insert(index, value)
147139

148140
def __add__(self, other):
149141
assert isinstance(other, MemoryGraphList), "Must add `MemoryGraphList` to self."
150-
new_list = MemoryGraphList()
151-
new_list._set_internal_list(self._list + other._list)
152-
return new_list
153-
154-
def _set_internal_list(self, value: list):
155-
if not isinstance(value, list):
156-
raise TypeError("Must set list for `MemoryGraphList` internal assignment.")
157-
self._list = value
158-
return self
142+
return MemoryGraphList(super(MemoryGraphList, self).__add__(other))
159143

160144
def copy(self):
161145
"""Copy data in the list."""
162146
return MemoryGraphList([x.copy() for x in self])
163147

164-
def clear(self):
165-
"""Clear internal list.
166-
167-
Returns:
168-
None
169-
"""
170-
self._list.clear()
171-
172148
def empty(self, length: int):
173149
"""Create an empty list in place. Overwrites existing list.
174150
@@ -182,7 +158,9 @@ def empty(self, length: int):
182158
return self
183159
if length < 0:
184160
raise ValueError("Length of empty list must be >=0.")
185-
self._list = [GraphDict() for _ in range(length)]
161+
self.clear()
162+
for _ in range(length):
163+
self.append(GraphDict())
186164
return self
187165

188166
def update(self, other) -> None:
@@ -194,7 +172,7 @@ def update(self, other) -> None:
194172
@property
195173
def length(self):
196174
"""Length of list."""
197-
return len(self._list)
175+
return len(self)
198176

199177
@length.setter
200178
def length(self, value: int):
@@ -257,7 +235,7 @@ def map_list(self, method: Union[str, Callable], **kwargs):
257235
# Can add progress info here.
258236
# Method by name.
259237
if isinstance(method, str):
260-
for i, x in enumerate(self._list):
238+
for i, x in enumerate(self):
261239
# If this is a class method.
262240
if hasattr(x, method):
263241
getattr(x, method)(**kwargs)
@@ -268,7 +246,7 @@ def map_list(self, method: Union[str, Callable], **kwargs):
268246
raise NotImplementedError("Serialization for method in `map_list` is not yet supported")
269247
else:
270248
# For any callable method to map.
271-
for i, x in enumerate(self._list):
249+
for i, x in enumerate(self):
272250
method(x, **kwargs)
273251
return self
274252

@@ -319,7 +297,7 @@ def clean(self, inputs: Union[list, str]):
319297
self.logger.info("No invalid graphs for assigned properties found.")
320298
# Remove from the end via pop().
321299
for i in invalid_graphs:
322-
self._list.pop(int(i))
300+
self.pop(int(i))
323301
return invalid_graphs
324302

325303
# Alias of internal assign and obtain property.
@@ -336,14 +314,17 @@ class MemoryGraphDataset(MemoryGraphList):
336314
337315
.. code-block:: python
338316
317+
import numpy as np
339318
from kgcnn.data.base import MemoryGraphDataset
340319
dataset = MemoryGraphDataset(data_directory="", dataset_name="Example")
341320
# Methods of MemoryGraphList
342321
dataset.set("edge_indices", [np.array([[1, 0], [0, 1]])])
343322
dataset.set("edge_labels", [np.array([[0], [1]])])
344323
dataset.save()
324+
dataset.load()
325+
326+
The file directory and file name are used in child classes and in :obj:`save` and :obj:`load` .
345327
346-
The file directory and file name are used in child classes and in :obj:`save` and :obj:`load`.
347328
"""
348329

349330
fits_in_memory = True
@@ -420,20 +401,20 @@ def error(self, *args, **kwargs):
420401

421402
def save(self, filepath: str = None):
422403
r"""Save all graph properties to python dictionary as pickled file. By default, saves a file named
423-
:obj:`dataset_name.kgcnn.pickle` in :obj:`data_directory`.
404+
:obj:`dataset_name.kgcnn.pickle` in :obj:`data_directory` .
424405
425406
Args:
426407
filepath (str): Full path of output file. Default is None.
427408
"""
428409
if filepath is None:
429410
filepath = os.path.join(self.data_directory, self.dataset_name + ".kgcnn.pickle")
430411
self.info("Pickle dataset...")
431-
save_pickle_file([x.to_dict() for x in self._list], filepath)
412+
save_pickle_file([x.to_dict() for x in self], filepath)
432413
return self
433414

434415
def load(self, filepath: str = None):
435416
r"""Load graph properties from a pickled file. By default, loads a file named
436-
:obj:`dataset_name.kgcnn.pickle` in :obj:`data_directory`.
417+
:obj:`dataset_name.kgcnn.pickle` in :obj:`data_directory` .
437418
438419
Args:
439420
filepath (str): Full path of input file.
@@ -442,7 +423,9 @@ def load(self, filepath: str = None):
442423
filepath = os.path.join(self.data_directory, self.dataset_name + ".kgcnn.pickle")
443424
self.info("Load pickled dataset...")
444425
in_list = load_pickle_file(filepath)
445-
self._list = [GraphDict(x) for x in in_list]
426+
self.clear()
427+
for x in in_list:
428+
self.append(GraphDict(x))
446429
return self
447430

448431
def read_in_table_file(self, file_path: str = None, **kwargs):
@@ -519,7 +502,7 @@ def message_warning(msg):
519502
for x in hyper_input:
520503
if not isinstance(x, dict):
521504
message_error(
522-
"Wrong type of list item in `assert_valid_model_input`. Found %s but must be `dict`" % type(x))
505+
"Wrong type of list item in `assert_valid_model_input`. Found '%s' but must be `dict` ." % type(x))
523506

524507
for x in hyper_input:
525508
if "name" not in x:
@@ -681,7 +664,7 @@ def check_and_extend_splits(to_split):
681664
def get_train_test_indices(self,
682665
train: str = "train",
683666
test: str = "test",
684-
valid: t.Optional[str] = None,
667+
valid: Optional[str] = None,
685668
split_index: Union[int, list] = 1,
686669
shuffle: bool = False,
687670
seed: int = None
@@ -715,15 +698,15 @@ def get_train_test_indices(self,
715698
"""
716699
out_indices = []
717700
if not isinstance(split_index, (list, tuple)):
718-
split_index_list: t.List[int] = [split_index]
701+
split_index_list: List[int] = [split_index]
719702
else:
720-
split_index_list: t.List[int] = split_index
703+
split_index_list: List[int] = split_index
721704

722705
for split_index in split_index_list:
723706

724707
# This list will itself contain numpy arrays which are filled with graph indices of the dataset
725708
# each element of this list will correspond to one property name (train, test...)
726-
graph_index_split_list: t.List[np.ndarray] = []
709+
graph_index_split_list: List[np.ndarray] = []
727710

728711
for property_name in [train, test, valid]:
729712

@@ -735,14 +718,14 @@ def get_train_test_indices(self,
735718
# This list will contain all the indices of the dataset elements (graphs) which are
736719
# associated with the current iteration's split index for the current iteration's
737720
# property name (train, test...)
738-
graph_index_list: t.List[int] = []
721+
graph_index_list: List[int] = []
739722

740723
# "obtain_property" returns a list which contains only the property values corresponding to
741724
# the given property name for each graph inside the dataset in the same order.
742725
# In this case, this is supposed to be a split list, which is a list that contains integer
743726
# indices, each representing one particular dataset split. The split list of each graph
744727
# only contains those split indices to which that graph is associated.
745-
split_prop: t.List[t.List[int]] = self.obtain_property(property_name)
728+
split_prop: List[List[int]] = self.obtain_property(property_name)
746729
for index, split_list in enumerate(split_prop):
747730
if split_list is not None:
748731
if split_index in split_list:

kgcnn/data/datasets/QM9Dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def remove_uncharacterized(self):
224224
indices = np.array([x[0] for x in data], dtype="int") - 1
225225
indices_backward = np.flip(np.sort(indices))
226226
for i in indices_backward:
227-
self._list.pop(int(i)) # Ideally use public pop() here.
227+
self.pop(int(i))
228228
self.info("Removed %s uncharacterized molecules." % len(indices_backward))
229229
self.__removed_uncharacterized = True
230230
return indices_backward

kgcnn/model/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,4 +139,4 @@ def update_wrapper(*args, **kwargs):
139139

140140
return update_wrapper
141141

142-
return model_update_decorator
142+
return model_update_decorator

kgcnn/mol/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class MolGraphInterface:
1111
r"""The `MolGraphInterface` defines the base class interface to handle a molecular graph.
1212
1313
The method implementation to generate a molecule-instance from smiles etc. can be obtained from different backends
14-
like `RDkit`. The mol-instance of a chemical informatics package like `RDkit` is treated via composition.
14+
like `RDkit` . The mol-instance of a chemical informatics package like `RDkit` is treated via composition.
1515
The interface is designed to extract a graph from a mol instance, not to make a mol object from a graph.
1616
1717
"""

0 commit comments

Comments
 (0)