Skip to content

Commit 46721d9

Browse files
committed
[Util] a few util fxn fp16 fixes that escaped the fp16 PR
1 parent 235ad4f commit 46721d9

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/qonnx/util/basic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def random_string(stringLength=6):
130130
def interleave_matrix_outer_dim_from_partitions(matrix, n_partitions):
131131
"""Interleave the outermost dimension of a matrix from given
132132
partitions (n_partitions)."""
133-
if type(matrix) != np.ndarray or matrix.dtype != np.float32:
133+
if type(matrix) != np.ndarray or matrix.dtype not in [np.float32, np.float16]:
134134
# try to convert to a float numpy array (container dtype is float)
135135
matrix = np.asarray(matrix, dtype=np.float32)
136136
shp = matrix.shape
@@ -179,7 +179,7 @@ def pad_tensor_to_multiple_of(ndarray, pad_to_dims, val=0, distr_pad=False):
179179
will be inserted after the existing values; otherwise it will be split
180180
evenly between before and after the existing values, with one extra value
181181
inserted after if the padding amount is not divisible by two."""
182-
if type(ndarray) != np.ndarray or ndarray.dtype != np.float32:
182+
if type(ndarray) != np.ndarray or ndarray.dtype not in [np.float32, np.float16]:
183183
# try to convert to a float numpy array (container dtype is float)
184184
ndarray = np.asarray(ndarray, dtype=np.float32)
185185
assert ndarray.ndim == len(

0 commit comments

Comments
 (0)