Skip to content

Commit 44e5954

Browse files
committed
generalize OneHotEncoder Transporting
1 parent ea25aa4 commit 44e5954

File tree

1 file changed

+1
-21
lines changed

1 file changed

+1
-21
lines changed

pymilo/transporters/preprocessing_transporter.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -76,22 +76,11 @@ def serialize_pre_module(self, pre_module):
7676
:type pre_module: sklearn.preprocessing
7777
:return: pymilo serialized pre_module
7878
"""
79-
_type = get_sklearn_type(pre_module)
80-
associated_class = SKLEARN_PREPROCESSING_TABLE[_type]
81-
if _type == "OneHotEncoder":
82-
return {
83-
"pymilo-bypass": True,
84-
"pymilo-preprocessing-type": _type,
85-
"pymilo-preprocessing-data": {
86-
"sparse_output": pre_module.sparse_output if has_named_parameter(associated_class, "sparse_output") else pre_module.sparse
87-
}
88-
}
89-
9079
gdst = GeneralDataStructureTransporter()
9180
gdst.transport(pre_module, Command.SERIALIZE, False)
9281
return {
9382
"pymilo-bypass": True,
94-
"pymilo-preprocessing-type": _type,
83+
"pymilo-preprocessing-type": get_sklearn_type(pre_module),
9584
"pymilo-preprocessing-data": pre_module.__dict__
9685
}
9786

@@ -105,16 +94,7 @@ def deserialize_pre_module(self, serialized_pre_module):
10594
:return: retrieved associated sklearn.preprocessing module
10695
"""
10796
data = serialized_pre_module["pymilo-preprocessing-data"]
108-
_type = serialized_pre_module["pymilo-preprocessing-type"]
10997
associated_type = SKLEARN_PREPROCESSING_TABLE[serialized_pre_module["pymilo-preprocessing-type"]]
110-
111-
if _type == "OneHotEncoder":
112-
if has_named_parameter(associated_type, "sparse_output"):
113-
return associated_type(
114-
sparse_output=serialized_pre_module["pymilo-preprocessing-data"]["sparse_output"])
115-
elif has_named_parameter(associated_type, "sparse"):
116-
return associated_type(sparse=serialized_pre_module["pymilo-preprocessing-data"]["sparse_output"])
117-
11898
retrieved_pre_module = associated_type()
11999
gdst = GeneralDataStructureTransporter()
120100
for key in data:

0 commit comments

Comments
 (0)