Skip to content

Commit af864ac

Browse files
committed
work on kgcnn.io.file
1 parent fd3352b commit af864ac

File tree

1 file changed

+124
-22
lines changed

1 file changed

+124
-22
lines changed

kgcnn/io/file.py

Lines changed: 124 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,120 @@
11
import os.path
2-
32
import numpy as np
3+
import tensorflow as tf
44
import h5py
55
from typing import List, Union
66

77

8+
def _check_for_inner_shape(array_list: List[np.ndarray]) -> Union[None, tuple, list]:
9+
"""Simple function to verify inner shape for list of numpy arrays."""
10+
# Cannot find inner shape for empty list.
11+
if len(array_list) == 0:
12+
return None
13+
# For fast check all items must be numpy arrays to get the inner shape easily.
14+
if not all(isinstance(x, np.ndarray) for x in array_list):
15+
return None
16+
shapes = [x.shape for x in array_list]
17+
# Must have all same rank.
18+
if not all(len(x) == len(shapes[0]) for x in shapes):
19+
return None
20+
# All Empty. No inner shape.
21+
if len(shapes[0]) == 0:
22+
return None
23+
# Empty inner shape.
24+
if len(shapes[0]) <= 1:
25+
return tuple([])
26+
# If all same inner shape.
27+
if all(x[1:] == shapes[0][1:] for x in shapes):
28+
return shapes[0][1:]
29+
30+
831
class RaggedTensorNumpyFile:
    """Read and write a ragged array of ``ragged_rank=1`` as a single NPZ file.

    The file stores the flat ``values``, the ``row_splits`` and the meta
    entries ``shape``, ``rank`` and ``ragged_rank``.

    Note:
        `np.savez` and `np.savez_compressed` append ``.npz`` to a file name
        that does not already end with it. `read` and `exists` use
        `file_path` unchanged, so `file_path` should end with ``.npz``.
    """

    # Device scope used whenever a tensorflow ragged tensor is constructed.
    _device = '/cpu:0'

    def __init__(self, file_path: str, compressed: bool = False):
        """Make class for a NPZ file.

        Args:
            file_path (str): Path to file on disk.
            compressed (bool): Whether to use compression.
        """
        self.file_path = file_path
        self.compressed = compressed

    # The annotation is quoted so that it is not evaluated when the class is defined.
    def write(self, ragged_array: "Union[tf.RaggedTensor, List[np.ndarray], list]"):
        """Write ragged array to file.

        .. code-block:: python

            from kgcnn.io.file import RaggedTensorNumpyFile
            import numpy as np
            data = [np.array([[0, 1],[0, 2]]), np.array([[1, 1]]), np.array([[0, 1],[2, 2], [0, 3]])]
            f = RaggedTensorNumpyFile("test.npz")
            f.write(data)
            print(f.read())

        Args:
            ragged_array (list, tf.RaggedTensor): Ragged tensor or (nested) list, e.g. list of numpy arrays.

        Returns:
            None.

        Raises:
            AssertionError: If the ragged rank of the input is not 1.
        """
        if not isinstance(ragged_array, tf.RaggedTensor):
            with tf.device(self._device):
                ragged_array = tf.ragged.constant(ragged_array, inner_shape=_check_for_inner_shape(ragged_array))
        # Raised explicitly (instead of an `assert` statement) so that the check is not stripped under `-O`.
        if ragged_array.ragged_rank != 1:
            raise AssertionError("Only support for ragged_rank=1 at the moment.")
        values = np.array(ragged_array.values)
        row_splits = np.array(ragged_array.row_splits)
        # Dimensions reported as `None` are stored as 0.
        shape = np.array([x if x is not None else 0 for x in ragged_array.shape], dtype="uint64")
        ragged_rank = np.array(ragged_array.ragged_rank)
        rank = np.array(len(shape))
        out = {"values": values,
               "row_splits": row_splits,
               "shape": shape,
               "ragged_rank": ragged_rank,
               "rank": rank}
        if self.compressed:
            np.savez_compressed(self.file_path, **out)
        else:
            np.savez(self.file_path, **out)

    def read(self, return_as_tensor: bool = False):
        """Read the file into memory.

        Args:
            return_as_tensor: Whether to return tf.RaggedTensor.

        Returns:
            tf.RaggedTensor or list: Ragged tensor from file if `return_as_tensor` is set,
            otherwise a list with one numpy array per row.
        """
        # Use the archive as a context manager so that the file handle is closed again.
        # Keys are indexed directly; a missing entry raises instead of yielding `None`.
        with np.load(self.file_path) as data:
            values = data["values"]
            row_splits = data["row_splits"]
        if return_as_tensor:
            with tf.device(self._device):
                out = tf.RaggedTensor.from_row_splits(values, row_splits)
            return out
        # `np.split` without split indices returns `[values]`, so zero rows are handled separately.
        if row_splits.shape[0] < 2:
            return []
        return np.split(values, row_splits[1:-1])

    def __getitem__(self, item):
        raise NotImplementedError("Not implemented for file reference item load.")

    def exists(self):
        """Check whether `file_path` exists on disk."""
        return os.path.exists(self.file_path)
33105

34106

35107
class RaggedTensorHDFile:
36108

109+
_device = '/cpu:0'
110+
37111
def __init__(self, file_path: str, compressed: bool = None):
112+
"""Make class for a HDF5 file.
113+
114+
Args:
115+
file_path (str): Path to file on disk.
116+
compressed: Compression to use. Not used at the moment.
117+
"""
38118
self.file_path = file_path
39119
self.compressed = compressed
40120

@@ -43,10 +123,10 @@ def write(self, ragged_array: List[np.ndarray]):
43123
44124
.. code-block:: python
45125
46-
from kgcnn.io.file import RaggedArrayHDFile
126+
from kgcnn.io.file import RaggedTensorHDFile
47127
import numpy as np
48128
data = [np.array([[0, 1],[0, 2]]), np.array([[1, 1]]), np.array([[0, 1],[2, 2], [0, 3]])]
49-
f = RaggedArrayHDFile("test.hdf5")
129+
f = RaggedTensorHDFile("test.hdf5")
50130
f.write(data)
51131
print(f.read())
52132
@@ -56,24 +136,46 @@ def write(self, ragged_array: List[np.ndarray]):
56136
Returns:
57137
None.
58138
"""
59-
inner_shape = ragged_array[0].shape
60-
values = np.concatenate([x for x in ragged_array], axis=0)
61-
row_splits = np.cumsum(np.array([len(x) for x in ragged_array], dtype="int64"), dtype="int64")
62-
row_splits = np.pad(row_splits, [1, 0])
139+
if not isinstance(ragged_array, tf.RaggedTensor):
140+
with tf.device(self._device):
141+
ragged_array = tf.ragged.constant(ragged_array, inner_shape=_check_for_inner_shape(ragged_array))
142+
assert ragged_array.ragged_rank == 1, "Only support for ragged_rank=1 at the moment."
143+
values = np.array(ragged_array.values)
144+
row_splits = np.array(ragged_array.row_splits)
145+
shape = np.array([x if x is not None else 0 for x in ragged_array.shape], dtype="uint64")
146+
ragged_rank = np.array(ragged_array.ragged_rank)
147+
rank = np.array(len(shape))
63148
with h5py.File(self.file_path, "w") as file:
64-
file.create_dataset("values", data=values, maxshape=[None] + list(inner_shape)[1:])
149+
file.create_dataset("values", data=values,
150+
maxshape=[x if i > 0 else None for i, x in enumerate(values.shape)])
65151
file.create_dataset("row_splits", data=row_splits, maxshape=(None, ))
66-
file.create_dataset("shape", data=np.array([]))
152+
file.create_dataset("shape", data=shape)
153+
file.create_dataset("rank", data=rank)
154+
file.create_dataset("ragged_rank", data=ragged_rank)
155+
156+
def read(self, return_as_tensor: bool = False):
    """Read the file into memory.

    Args:
        return_as_tensor: Whether to return tf.RaggedTensor.

    Returns:
        tf.RaggedTensor or list: Ragged tensor from file if `return_as_tensor` is set,
        otherwise a list with one numpy array per row.
    """
    # Load both datasets into numpy arrays while the file is open, so that
    # nothing below depends on an open file handle.
    with h5py.File(self.file_path, "r") as file:
        values = file["values"][()]
        row_splits = file["row_splits"][()]
    if return_as_tensor:
        with tf.device(self._device):
            out = tf.RaggedTensor.from_row_splits(values, row_splits)
        return out
    # `np.split` without split indices returns `[values]`, so zero rows are handled separately.
    if row_splits.shape[0] < 2:
        return []
    return np.split(values, row_splits[1:-1])
72174

73175
def __getitem__(self, item: int):
    """Load a single row from the file without reading all values.

    Args:
        item (int): Index of the row. Negative values count from the last row.

    Returns:
        np.ndarray: The part of the stored `values` that belongs to row `item`.

    Raises:
        IndexError: If `item` is outside the range of stored rows.
    """
    with h5py.File(self.file_path, "r") as file:
        row_splits = file["row_splits"]
        # There is one more split point than rows.
        num_rows = int(row_splits.shape[0]) - 1
        # Normalise negative indices first. Otherwise `item=-1` would select
        # `row_splits[-1]:row_splits[0]` and not the last row.
        index = item + num_rows if item < 0 else item
        if not 0 <= index < num_rows:
            raise IndexError("Row index %s is out of range for %s rows." % (item, num_rows))
        # Read start and stop of the row with a single access.
        start, stop = (int(x) for x in row_splits[index:index + 2])
        out_data = np.array(file["values"][start:stop])
    return out_data
78180

79181
def append(self, item):
@@ -106,7 +208,7 @@ def append_multiple(self, items: list):
106208

107209
def __len__(self):
    """Return the number of rows, i.e. the number of stored row splits minus one."""
    with h5py.File(self.file_path, "r") as file:
        split_count = int(file["row_splits"].shape[0])
    # The row splits hold one more entry than there are rows.
    return split_count - 1
112214

0 commit comments

Comments
 (0)