Skip to content

Commit a1799a1

Browse files
committed
add numpy type itself transporting (for OneHotEncoder)
1 parent 44e5954 commit a1799a1

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

pymilo/transporters/general_data_structure_transporter.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,16 @@ def serialize(self, data, key, model_type):
9393
:type model_type: str
9494
:return: pymilo serialized output of data[key]
9595
"""
96+
if isinstance(data[key], type):
97+
raw_type = str(data[key])
98+
raw_type = "numpy" + str(raw_type).split("numpy")[-1][:-2]
99+
if raw_type in NUMPY_TYPE_DICT.keys():
100+
data[key] = {
101+
"np-type": "numpy.dtype",
102+
"value": raw_type
103+
}
96104
# 1. Handling numpy infinity, ransac
97-
if isinstance(data[key], np.float64):
105+
elif isinstance(data[key], np.float64):
98106
if np.inf == data[key]:
99107
data[key] = {
100108
"np-type": "numpy.infinity",
@@ -209,7 +217,7 @@ def get_deserialized_dict(self, content):
209217
return self.deep_deserialize_ndarray(content)
210218

211219
if check_str_in_iterable("np-type", content) and check_str_in_iterable("value", content):
212-
return NUMPY_TYPE_DICT[content["np-type"]](content["value"])
220+
return self.get_deserialized_regular_primary_types(content)
213221

214222
for key in content:
215223

@@ -271,6 +279,8 @@ def get_deserialized_regular_primary_types(self, content):
271279
:return: the associated np.int32|np.int64|np.inf
272280
"""
273281
if "np-type" in content:
282+
if content["np-type"] == "numpy.dtype":
283+
return NUMPY_TYPE_DICT[content["np-type"]](NUMPY_TYPE_DICT[content['value']])
274284
return NUMPY_TYPE_DICT[content["np-type"]](content['value'])
275285

276286
def is_numpy_primary_type(self, content):
@@ -359,8 +369,7 @@ def deserialize_primitive_type(self, primitive):
359369
if is_primitive(primitive):
360370
return primitive
361371
elif check_str_in_iterable("np-type", primitive):
362-
return NUMPY_TYPE_DICT[primitive["np-type"]
363-
](primitive['value'])
372+
return self.get_deserialized_regular_primary_types(primitive)
364373
else:
365374
return primitive
366375

0 commit comments

Comments
 (0)