diff --git a/CHANGELOG.md b/CHANGELOG.md
index 30b79ae5..dc5d84af 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,13 +6,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
### Added
+- `StandardScaler` Transformer in `pymilo_param.py`
+- `PreprocessingTransporter` Transporter
- ndarray shape config in `GeneralDataStructure` Transporter
- `util.py` in chains
- `BinMapperTransporter` Transporter
- `BunchTransporter` Transporter
- `GeneratorTransporter` Transporter
-- `LabelEncoderTransporter` Transporter
-- `OneHotEncoderTransporter` Transporter
- `TreePredictorTransporter` Transporter
- `AdaboostClassifier` model
- `AdaboostRegressor` model
@@ -37,6 +37,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Ensemble chain
- `SECURITY.md`
### Changed
+- `Pipeline` test updated
+- `LabelBinarizer`, `LabelEncoder` and `OneHotEncoder` embedded in `PreprocessingTransporter`
+- Preprocessing support added to Ensemble chain
+- Preprocessing params initialized in `pymilo_param`
- `util.py` in utils updated
- `test_pymilo.py` updated
- `pymilo_func.py` updated
diff --git a/SUPPORTED_MODELS.md b/SUPPORTED_MODELS.md
index 73730db5..bafc2537 100644
--- a/SUPPORTED_MODELS.md
+++ b/SUPPORTED_MODELS.md
@@ -630,4 +630,9 @@
LabelEncoder |
>=0.8 |
+
+ 4 |
+ StandardScaler |
+ >=0.8 |
+
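Note: a minimal sketch of what the newly listed `StandardScaler` row enables end to end, assuming PyMilo's documented `Export`/`Import` entry points; the file name and toy dataset are illustrative, not part of this diff.

```python
# Sketch only: assumes the Export/Import API from PyMilo's README; "pipeline.json" and the
# toy dataset are illustrative.
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from pymilo import Export, Import

x_train = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]
y_train = [0, 0, 1, 1]

model = Pipeline([("scaler", StandardScaler()), ("svc", SVC())]).fit(x_train, y_train)

Export(model).save("pipeline.json")            # serializes the fitted pipeline, scaler included
restored = Import("pipeline.json").to_model()  # rebuilds it from the JSON file
print(restored.predict(x_train))
```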
diff --git a/pymilo/chains/clustering_chain.py b/pymilo/chains/clustering_chain.py
index 958374f4..beb09fc4 100644
--- a/pymilo/chains/clustering_chain.py
+++ b/pymilo/chains/clustering_chain.py
@@ -5,6 +5,7 @@
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
from ..transporters.function_transporter import FunctionTransporter
from ..transporters.cfnode_transporter import CFNodeTransporter
+from ..transporters.preprocessing_transporter import PreprocessingTransporter
from ..utils.util import get_sklearn_type
@@ -15,6 +16,7 @@
bisecting_kmeans_support = SKLEARN_CLUSTERING_TABLE["BisectingKMeans"] != NOT_SUPPORTED
CLUSTERING_CHAIN = {
+ "PreprocessingTransporter": PreprocessingTransporter(),
"GeneralDataStructureTransporter": GeneralDataStructureTransporter(),
"FunctionTransporter": FunctionTransporter(),
"CFNodeTransporter": CFNodeTransporter(),
diff --git a/pymilo/chains/decision_tree_chain.py b/pymilo/chains/decision_tree_chain.py
index 886112a8..951f708c 100644
--- a/pymilo/chains/decision_tree_chain.py
+++ b/pymilo/chains/decision_tree_chain.py
@@ -5,6 +5,7 @@
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
from ..transporters.tree_transporter import TreeTransporter
from ..transporters.randomstate_transporter import RandomStateTransporter
+from ..transporters.preprocessing_transporter import PreprocessingTransporter
from ..utils.util import get_sklearn_type
@@ -16,6 +17,7 @@
DECISION_TREE_CHAIN = {
+ "PreprocessingTransporter": PreprocessingTransporter(),
"GeneralDataStructureTransporter": GeneralDataStructureTransporter(),
"RandomStateTransporter": RandomStateTransporter(),
"TreeTransporter": TreeTransporter(),
diff --git a/pymilo/chains/ensemble_chain.py b/pymilo/chains/ensemble_chain.py
index 32af5085..5ae373cc 100644
--- a/pymilo/chains/ensemble_chain.py
+++ b/pymilo/chains/ensemble_chain.py
@@ -4,12 +4,11 @@
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
from ..transporters.randomstate_transporter import RandomStateTransporter
from ..transporters.lossfunction_transporter import LossFunctionTransporter
-from ..transporters.onehotencoder_transporter import OneHotEncoderTransporter
from ..transporters.bunch_transporter import BunchTransporter
-from ..transporters.labelencoder_transporter import LabelEncoderTransporter
from ..transporters.generator_transporter import GeneratorTransporter
from ..transporters.treepredictor_transporter import TreePredictorTransporter
from ..transporters.binmapper_transporter import BinMapperTransporter
+from ..transporters.preprocessing_transporter import PreprocessingTransporter
from ..pymilo_param import SKLEARN_ENSEMBLE_TABLE
@@ -27,14 +26,13 @@
import copy
ENSEMBLE_CHAIN = {
+ "PreprocessingTransporter": PreprocessingTransporter(),
"GeneralDataStructureTransporter": GeneralDataStructureTransporter(),
"TreePredictorTransporter": TreePredictorTransporter(),
"BinMapperTransporter": BinMapperTransporter(),
"GeneratorTransporter": GeneratorTransporter(),
"RandomStateTransporter": RandomStateTransporter(),
"LossFunctionTransporter": LossFunctionTransporter(),
- "OneHotEncoderTransporter": OneHotEncoderTransporter(),
- "LabelEncoderTransporter": LabelEncoderTransporter(),
"BunchTransporter": BunchTransporter(),
}
@@ -166,14 +164,18 @@ def serialize_ensemble(ensemble_object):
for key, value in ensemble_object.__dict__.items():
if isinstance(value, list):
has_inner_tuple_with_ml_model = False
+ pt = PreprocessingTransporter()
for idx, item in enumerate(value):
if isinstance(item, tuple):
listed_tuple = list(item)
for inner_idx, inner_item in enumerate(listed_tuple):
- has_inner_model, result = serialize_possible_ml_model(inner_item)
- if has_inner_model:
- has_inner_tuple_with_ml_model = True
- listed_tuple[inner_idx] = result
+ if pt.is_preprocessing_module(inner_item):
+ listed_tuple[inner_idx] = pt.serialize_pre_module(inner_item)
+ else:
+ has_inner_model, result = serialize_possible_ml_model(inner_item)
+ if has_inner_model:
+ has_inner_tuple_with_ml_model = True
+ listed_tuple[inner_idx] = result
value[idx] = listed_tuple
else:
value[idx] = serialize_possible_ml_model(item)[1]
@@ -325,12 +327,16 @@ def deserialize_ensemble(ensemble, is_inner_model=False):
value) and value["pymiloed-data-structure"] == "list of (str, estimator) tuples":
listed_tuples = value["pymiloed-data"]
list_of_tuples = []
+ pt = PreprocessingTransporter()
for listed_tuple in listed_tuples:
- name, serialized_ml_model = listed_tuple
+ name, serialized_model = listed_tuple
+ retrieved_model = pt.deserialize_pre_module(serialized_model) if pt.is_preprocessing_module(
+ serialized_model) else deserialize_possible_ml_model(serialized_model)[1]
list_of_tuples.append(
- (name, deserialize_possible_ml_model(serialized_ml_model)[1])
+ (name, retrieved_model)
)
data[key] = list_of_tuples
+
elif GeneralDataStructureTransporter().is_deserialized_ndarray(value):
has_inner_model, result = deserialize_models_in_ndarray(value)
if has_inner_model:
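The change above routes preprocessing members of `(name, estimator)` tuples through the new transporter instead of `serialize_possible_ml_model`. A standalone sketch of that branch, with illustrative toy data:

```python
# Standalone sketch of the new branching in serialize_ensemble: preprocessing steps are handled
# by PreprocessingTransporter, anything else keeps the serialize_possible_ml_model path.
from sklearn.preprocessing import StandardScaler
from pymilo.transporters.preprocessing_transporter import PreprocessingTransporter

pt = PreprocessingTransporter()
scaler = StandardScaler().fit([[0.0], [1.0], [2.0]])

name, item = ("scaler", scaler)
if pt.is_preprocessing_module(item):
    serialized = pt.serialize_pre_module(item)   # dict carrying "pymilo-preprocessing-type"/"-data"
else:
    serialized = item                            # real estimators go through serialize_possible_ml_model

print(serialized["pymilo-preprocessing-type"])   # -> "StandardScaler"
```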
diff --git a/pymilo/chains/linear_model_chain.py b/pymilo/chains/linear_model_chain.py
index 9efefa31..02522bd1 100644
--- a/pymilo/chains/linear_model_chain.py
+++ b/pymilo/chains/linear_model_chain.py
@@ -5,7 +5,7 @@
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
from ..transporters.baseloss_transporter import BaseLossTransporter
from ..transporters.lossfunction_transporter import LossFunctionTransporter
-from ..transporters.labelbinarizer_transporter import LabelBinarizerTransporter
+from ..transporters.preprocessing_transporter import PreprocessingTransporter
from ..pymilo_param import SKLEARN_LINEAR_MODEL_TABLE
from ..utils.util import get_sklearn_type, is_iterable
@@ -16,10 +16,11 @@
LINEAR_MODEL_CHAIN = {
+ "PreprocessingTransporter": PreprocessingTransporter(),
"GeneralDataStructureTransporter": GeneralDataStructureTransporter(),
"BaseLossTransporter": BaseLossTransporter(),
"LossFunctionTransporter": LossFunctionTransporter(),
- "LabelBinarizerTransporter": LabelBinarizerTransporter()}
+}
def is_linear_model(model):
@@ -101,9 +102,9 @@ def serialize_linear_model(linear_model_object):
for key in linear_model_object.__dict__:
if is_linear_model(linear_model_object.__dict__[key]):
linear_model_object.__dict__[key] = {
- "pymilo-inner-model-data": transport_linear_model(linear_model_object.__dict__[key], Command.SERIALIZE),
+ "pymilo-inner-model-data": transport_linear_model(linear_model_object.__dict__[key], Command.SERIALIZE, True),
"pymilo-inner-model-type": get_sklearn_type(linear_model_object.__dict__[key]),
- "by-pass": True
+ "pymilo-by-pass": True
}
# now serializing non-linear model fields
for transporter in LINEAR_MODEL_CHAIN:
diff --git a/pymilo/chains/naive_bayes_chain.py b/pymilo/chains/naive_bayes_chain.py
index c43c420d..e065eb5c 100644
--- a/pymilo/chains/naive_bayes_chain.py
+++ b/pymilo/chains/naive_bayes_chain.py
@@ -3,6 +3,7 @@
from ..transporters.transporter import Command
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
+from ..transporters.preprocessing_transporter import PreprocessingTransporter
from ..pymilo_param import SKLEARN_NAIVE_BAYES_TABLE
from ..exceptions.serialize_exception import PymiloSerializationException, SerilaizatoinErrorTypes
@@ -13,6 +14,7 @@
from traceback import format_exc
NAIVE_BAYES_CHAIN = {
+ "PreprocessingTransporter": PreprocessingTransporter(),
"GeneralDataStructureTransporter": GeneralDataStructureTransporter(),
}
diff --git a/pymilo/chains/neighbours_chain.py b/pymilo/chains/neighbours_chain.py
index 120c9844..2891678c 100644
--- a/pymilo/chains/neighbours_chain.py
+++ b/pymilo/chains/neighbours_chain.py
@@ -4,6 +4,7 @@
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
from ..transporters.neighbors_tree_transporter import NeighborsTreeTransporter
+from ..transporters.preprocessing_transporter import PreprocessingTransporter
from ..pymilo_param import SKLEARN_NEIGHBORS_TABLE
from ..exceptions.serialize_exception import PymiloSerializationException, SerilaizatoinErrorTypes
@@ -14,6 +15,7 @@
from traceback import format_exc
NEIGHBORS_CHAIN = {
+ "PreprocessingTransporter": PreprocessingTransporter(),
"GeneralDataStructureTransporter": GeneralDataStructureTransporter(),
"NeighborsTreeTransporter": NeighborsTreeTransporter(),
}
diff --git a/pymilo/chains/neural_network_chain.py b/pymilo/chains/neural_network_chain.py
index 70430b99..ed73c664 100644
--- a/pymilo/chains/neural_network_chain.py
+++ b/pymilo/chains/neural_network_chain.py
@@ -6,7 +6,7 @@
from ..transporters.randomstate_transporter import RandomStateTransporter
from ..transporters.sgdoptimizer_transporter import SGDOptimizerTransporter
from ..transporters.adamoptimizer_transporter import AdamOptimizerTransporter
-from ..transporters.labelbinarizer_transporter import LabelBinarizerTransporter
+from ..transporters.preprocessing_transporter import PreprocessingTransporter
from ..pymilo_param import SKLEARN_NEURAL_NETWORK_TABLE
@@ -19,11 +19,11 @@
NEURAL_NETWORK_CHAIN = {
+ "PreprocessingTransporter": PreprocessingTransporter(),
"GeneralDataStructureTransporter": GeneralDataStructureTransporter(),
"RandomStateTransporter": RandomStateTransporter(),
"SGDOptimizer": SGDOptimizerTransporter(),
"AdamOptimizerTransporter": AdamOptimizerTransporter(),
- "LabelBinarizerTransporter": LabelBinarizerTransporter(),
}
diff --git a/pymilo/chains/svm_chain.py b/pymilo/chains/svm_chain.py
index d2c1bea4..f3985c81 100644
--- a/pymilo/chains/svm_chain.py
+++ b/pymilo/chains/svm_chain.py
@@ -4,6 +4,7 @@
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
from ..transporters.randomstate_transporter import RandomStateTransporter
+from ..transporters.preprocessing_transporter import PreprocessingTransporter
from ..pymilo_param import SKLEARN_SVM_TABLE
from ..exceptions.serialize_exception import PymiloSerializationException, SerilaizatoinErrorTypes
@@ -14,6 +15,7 @@
from traceback import format_exc
SVM_CHAIN = {
+ "PreprocessingTransporter": PreprocessingTransporter(),
"GeneralDataStructureTransporter": GeneralDataStructureTransporter(),
"RandomStateTransporter": RandomStateTransporter(),
}
diff --git a/pymilo/pymilo_param.py b/pymilo/pymilo_param.py
index cacf5d4b..93c759c5 100644
--- a/pymilo/pymilo_param.py
+++ b/pymilo/pymilo_param.py
@@ -1,14 +1,6 @@
# -*- coding: utf-8 -*-
"""Parameters and constants."""
-from numpy import uint8
-from numpy import intc
-from numpy import inf
-from numpy import float64
-from numpy import int32
-from numpy import int64
-from numpy import uint64
-from sklearn.preprocessing import LabelBinarizer
-
+import numpy as np
import sklearn.linear_model as linear_model
import sklearn.neural_network as neural_network
import sklearn.tree as tree
@@ -20,7 +12,7 @@
import sklearn.dummy as dummy
import sklearn.ensemble as ensemble
import sklearn.pipeline as pipeline
-
+import sklearn.preprocessing as preprocessing
quantile_regressor_support = False
try:
@@ -205,10 +197,17 @@
"Pipeline": pipeline.Pipeline,
}
+SKLEARN_PREPROCESSING_TABLE = {
+ "StandardScaler": preprocessing.StandardScaler,
+ "OneHotEncoder": preprocessing.OneHotEncoder,
+ "LabelBinarizer": preprocessing.LabelBinarizer,
+ "LabelEncoder": preprocessing.LabelEncoder,
+}
+
KEYS_NEED_PREPROCESSING_BEFORE_DESERIALIZATION = {
- "_label_binarizer": LabelBinarizer, # in Ridge Classifier
- "active_": int32, # in Lasso Lars
- "n_nonzero_coefs_": int64, # in OMP-CV
+ "_label_binarizer": preprocessing.LabelBinarizer, # in Ridge Classifier
+ "active_": np.int32, # in Lasso Lars
+ "n_nonzero_coefs_": np.int64, # in OMP-CV
"scores_": dict, # in Logistic Regression CV,
"_base_loss": {}, # BaseLoss in Logistic Regression,
"loss_function_": {}, # LossFunction in SGD Classifier,
@@ -216,13 +215,14 @@
}
NUMPY_TYPE_DICT = {
- "numpy.intc": intc,
- "numpy.int32": int32,
- "numpy.int64": int64,
- "numpy.float64": float64,
- "numpy.infinity": lambda _: inf,
- "numpy.uint8": uint8,
- "numpy.uint64": uint64,
+ "numpy.intc": np.intc,
+ "numpy.int32": np.int32,
+ "numpy.int64": np.int64,
+ "numpy.float64": np.float64,
+ "numpy.infinity": lambda _: np.inf,
+ "numpy.uint8": np.uint8,
+ "numpy.uint64": np.uint64,
+ "numpy.dtype": np.dtype,
}
EXPORTED_MODELS_PATH = {
diff --git a/pymilo/transporters/general_data_structure_transporter.py b/pymilo/transporters/general_data_structure_transporter.py
index 9ab3e8ee..b39765ad 100644
--- a/pymilo/transporters/general_data_structure_transporter.py
+++ b/pymilo/transporters/general_data_structure_transporter.py
@@ -93,8 +93,16 @@ def serialize(self, data, key, model_type):
:type model_type: str
:return: pymilo serialized output of data[key]
"""
+ if isinstance(data[key], type):
+ raw_type = str(data[key])
+ raw_type = "numpy" + raw_type.split("numpy")[-1][:-2]
+ if raw_type in NUMPY_TYPE_DICT.keys():
+ data[key] = {
+ "np-type": "numpy.dtype",
+ "value": raw_type
+ }
# 1. Handling numpy infinity, ransac
- if isinstance(data[key], np.float64):
+ elif isinstance(data[key], np.float64):
if np.inf == data[key]:
data[key] = {
"np-type": "numpy.infinity",
@@ -209,7 +217,7 @@ def get_deserialized_dict(self, content):
return self.deep_deserialize_ndarray(content)
if check_str_in_iterable("np-type", content) and check_str_in_iterable("value", content):
- return NUMPY_TYPE_DICT[content["np-type"]](content["value"])
+ return self.get_deserialized_regular_primary_types(content)
for key in content:
@@ -271,6 +279,8 @@ def get_deserialized_regular_primary_types(self, content):
:return: the associated np.int32|np.int64|np.inf
"""
if "np-type" in content:
+ if content["np-type"] == "numpy.dtype":
+ return NUMPY_TYPE_DICT[content["np-type"]](NUMPY_TYPE_DICT[content['value']])
return NUMPY_TYPE_DICT[content["np-type"]](content['value'])
def is_numpy_primary_type(self, content):
@@ -359,8 +369,7 @@ def deserialize_primitive_type(self, primitive):
if is_primitive(primitive):
return primitive
elif check_str_in_iterable("np-type", primitive):
- return NUMPY_TYPE_DICT[primitive["np-type"]
- ](primitive['value'])
+ return self.get_deserialized_regular_primary_types(primitive)
else:
return primitive
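A short sketch of the round trip added above for bare numpy type objects; the `"dtype"` key is illustrative:

```python
# Sketch of the new numpy-type handling: a field that stores a numpy type object is serialized
# as {"np-type": "numpy.dtype", "value": "numpy.<name>"} and restored as np.dtype(np.<name>).
import numpy as np
from pymilo.transporters.general_data_structure_transporter import GeneralDataStructureTransporter

gdst = GeneralDataStructureTransporter()
data = {"dtype": np.int32}                  # the "dtype" key is illustrative

gdst.serialize(data, "dtype", "")
print(data["dtype"])                        # {'np-type': 'numpy.dtype', 'value': 'numpy.int32'}

restored = gdst.get_deserialized_regular_primary_types(data["dtype"])
print(restored)                             # dtype('int32')
```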
diff --git a/pymilo/transporters/labelbinarizer_transporter.py b/pymilo/transporters/labelbinarizer_transporter.py
deleted file mode 100644
index 456b3517..00000000
--- a/pymilo/transporters/labelbinarizer_transporter.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# -*- coding: utf-8 -*-
-"""PyMilo LabelBinarizer transporter."""
-from ..pymilo_param import KEYS_NEED_PREPROCESSING_BEFORE_DESERIALIZATION
-from sklearn import preprocessing
-import numpy as np
-from .transporter import AbstractTransporter
-from .general_data_structure_transporter import GeneralDataStructureTransporter
-
-
-class LabelBinarizerTransporter(AbstractTransporter):
- """Customized PyMilo Transporter developed to handle LabelBinarizer field(for Ridge Classifier(+[CV]))."""
-
- def serialize(self, data, key, model_type):
- """
- Serialize the LabelBinarizer field(if there is).
-
- serialize the data[key] of the given model which type is model_type.
- basically in order to fully serialize a model, we should traverse over all the keys of its data dictionary and
- pass it through the chain of associated transporters to get fully serialized.
-
- :param data: the internal data dictionary of the given model
- :type data: dict
- :param key: the special key of the data param, which we're going to serialize its value(data[key])
- :type key: object
- :param model_type: the model type of the ML model, which data dictionary is given as the data param
- :type model_type: str
- :return: pymilo serialized output of data[key]
- """
- if isinstance(data[key], preprocessing.LabelBinarizer):
- data[key] = self.get_serialized_label_binarizer(data[key])
- return data[key]
-
- def get_serialized_label_binarizer(self, label_binarizer):
- """
- Serialize a LabelBinarizer object.
-
- :param label_binarizer: a label_binarizer object
- :type label_binarizer: sklearn.preprocessing.LabelBinarizer
- :return: pymilo serialized output of label_binarizer object
- """
- data = label_binarizer.__dict__
- for key in data:
- if isinstance(data[key], np.ndarray):
- data[key] = GeneralDataStructureTransporter().deep_serialize_ndarray(data[key])
- return data
-
- def deserialize(self, data, key, model_type):
- """
- Deserialize the LabelBinarizer field(if there is).
-
- deserialize the data[key] of the given model which type is model_type.
- basically in order to fully deserialize a model, we should traverse over all the keys of its serialized data dictionary and
- pass it through the chain of associated transporters to get fully deserialized.
-
- :param data: the internal data dictionary of the associated json file
- of the ML model which is generated previously by pymilo export.
- :type data: dict
- :param key: the special key of the data param, which we're going to deserialize its value(data[key])
- :type key: object
- :param model_type: the model type of the ML model, which internal serialized data dictionary is given as the data param
- :type model_type: str
- :return: pymilo deserialized output of data[key]
- """
- content = data[key]
- if key != "_label_binarizer":
- return content
- return self.get_deserialized_label_binarizer(content)
-
- def get_deserialized_label_binarizer(self, content):
- """
- Deserialize the pymilo serialized labelBinarizer field of the associated ML model.
-
- :param content: a label_binarizer object
- :type content: sklearn.preprocessing.LabelBinarizer
- :return: a sklearn.preprocessing.LabelBinarizer instance derived from the
- pymilo deserialized output of the previously pymilo serialized label_binarizer field.
- """
- raw_lb = KEYS_NEED_PREPROCESSING_BEFORE_DESERIALIZATION["_label_binarizer"](
- )
- for item in content:
- setattr(raw_lb, item, content[item])
- return raw_lb
diff --git a/pymilo/transporters/labelencoder_transporter.py b/pymilo/transporters/labelencoder_transporter.py
deleted file mode 100644
index b24da762..00000000
--- a/pymilo/transporters/labelencoder_transporter.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# -*- coding: utf-8 -*-
-"""PyMilo LabelEncoder transporter."""
-from sklearn.preprocessing import LabelEncoder
-from ..utils.util import is_primitive, check_str_in_iterable
-from .transporter import AbstractTransporter
-from .general_data_structure_transporter import GeneralDataStructureTransporter
-
-
-class LabelEncoderTransporter(AbstractTransporter):
- """Customized PyMilo Transporter developed to LabelEncoder objects."""
-
- def serialize(self, data, key, model_type):
- """
- Serialize LabelEncoder object.
-
- serialize the data[key] of the given model which type is model_type.
- basically in order to fully serialize a model, we should traverse over all the keys of its data dictionary and
- pass it through the chain of associated transporters to get fully serialized.
-
- :param data: the internal data dictionary of the given model
- :type data: dict
- :param key: the special key of the data param, which we're going to serialize its value(data[key])
- :type key: object
- :param model_type: the model type of the ML model, which data dictionary is given as the data param
- :type model_type: str
- :return: pymilo serialized output of data[key]
- """
- if isinstance(data[key], LabelEncoder):
- label_encoder = data[key]
- data[key] = {
- "pymilo-bypass": True, "pymilo-labelencoder": {
- "classes_": GeneralDataStructureTransporter().deep_serialize_ndarray(
- label_encoder.__dict__["classes_"])}}
- return data[key]
-
- def deserialize(self, data, key, model_type):
- """
- Deserialize previously pymilo serialized LabelEncoder object.
-
- deserialize the data[key] of the given model which type is model_type.
- basically in order to fully deserialize a model, we should traverse over all the keys of its serialized data dictionary and
- pass it through the chain of associated transporters to get fully deserialized.
-
- :param data: the internal data dictionary of the associated json file of the ML model which is generated previously by
- pymilo export.
- :type data: dict
- :param key: the special key of the data param, which we're going to deserialize its value(data[key])
- :type key: object
- :param model_type: the model type of the ML model, which internal serialized data dictionary is given as the data param
- :type model_type: str
- :return: pymilo deserialized output of data[key]
- """
- content = data[key]
- if is_primitive(content) or content is None:
- return content
-
- if check_str_in_iterable("pymilo-labelencoder", content):
- serialized_le = content["pymilo-labelencoder"]
- label_encoder = LabelEncoder()
- setattr(
- label_encoder,
- "classes_",
- GeneralDataStructureTransporter().deep_deserialize_ndarray(
- serialized_le["classes_"]))
- return label_encoder
-
- return content
diff --git a/pymilo/transporters/onehotencoder_transporter.py b/pymilo/transporters/onehotencoder_transporter.py
deleted file mode 100644
index 0e29d5fd..00000000
--- a/pymilo/transporters/onehotencoder_transporter.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# -*- coding: utf-8 -*-
-"""PyMilo OneHotEncoder transporter."""
-from sklearn.preprocessing import OneHotEncoder
-from ..utils.util import is_primitive, check_str_in_iterable, has_named_parameter
-from .transporter import AbstractTransporter
-
-
-class OneHotEncoderTransporter(AbstractTransporter):
- """Customized PyMilo Transporter developed to handle OneHotEncoder objects."""
-
- def serialize(self, data, key, model_type):
- """
- Serialize OneHotEncoder object.
-
- serialize the data[key] of the given model which type is model_type.
- basically in order to fully serialize a model, we should traverse over all the keys of its data dictionary and
- pass it through the chain of associated transporters to get fully serialized.
-
- :param data: the internal data dictionary of the given model
- :type data: dict
- :param key: the special key of the data param, which we're going to serialize its value(data[key])
- :type key: object
- :param model_type: the model type of the ML model, which data dictionary is given as the data param
- :type model_type: str
- :return: pymilo serialized output of data[key]
- """
- if isinstance(data[key], OneHotEncoder):
- data[key] = {
- "pymilo-onehotencoder": {
- "sparse_output": data["sparse_output"]
- }
- }
- return data[key]
-
- def deserialize(self, data, key, model_type):
- """
- Deserialize previously pymilo serialized OneHotEncoder object.
-
- deserialize the data[key] of the given model which type is model_type.
- basically in order to fully deserialize a model, we should traverse over all the keys of its serialized data dictionary and
- pass it through the chain of associated transporters to get fully deserialized.
-
- :param data: the internal data dictionary of the associated json file of the ML model which is generated previously by
- pymilo export.
- :type data: dict
- :param key: the special key of the data param, which we're going to deserialize its value(data[key])
- :type key: object
- :param model_type: the model type of the ML model, which internal serialized data dictionary is given as the data param
- :type model_type: str
- :return: pymilo deserialized output of data[key]
- """
- content = data[key]
- if is_primitive(content) or content is None:
- return content
-
- if check_str_in_iterable("pymilo-onehotencoder", content):
- if has_named_parameter(OneHotEncoder, "sparse_output"):
- return OneHotEncoder(sparse_output=content["pymilo-onehotencoder"]["sparse_output"])
- elif has_named_parameter(OneHotEncoder, "sparse"):
- return OneHotEncoder(sparse=content["pymilo-onehotencoder"]["sparse_output"])
- return content
diff --git a/pymilo/transporters/preprocessing_transporter.py b/pymilo/transporters/preprocessing_transporter.py
new file mode 100644
index 00000000..1557fd8a
--- /dev/null
+++ b/pymilo/transporters/preprocessing_transporter.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+"""PyMilo Preprocessing transporter."""
+from ..pymilo_param import SKLEARN_PREPROCESSING_TABLE
+from ..utils.util import check_str_in_iterable, get_sklearn_type
+from .transporter import AbstractTransporter, Command
+from .general_data_structure_transporter import GeneralDataStructureTransporter
+
+
+class PreprocessingTransporter(AbstractTransporter):
+ """Preprocessing object dedicated Transporter."""
+
+ def serialize(self, data, key, model_type):
+ """
+ Serialize Preprocessing object.
+
+ serialize the data[key] of the given model which type is model_type.
+ basically in order to fully serialize a model, we should traverse over all the keys of its data dictionary and
+ pass it through the chain of associated transporters to get fully serialized.
+
+ :param data: the internal data dictionary of the given model
+ :type data: dict
+ :param key: the special key of the data param, which we're going to serialize its value(data[key])
+ :type key: object
+ :param model_type: the model type of the ML model, which data dictionary is given as the data param
+ :type model_type: str
+ :return: pymilo serialized output of data[key]
+ """
+ if self.is_preprocessing_module(data[key]):
+ return self.serialize_pre_module(data[key])
+ return data[key]
+
+
+ def deserialize(self, data, key, model_type):
+ """
+ Deserialize previously pymilo serialized preprocessing object.
+
+ deserialize the data[key] of the given model which type is model_type.
+ basically in order to fully deserialize a model, we should traverse over all the keys of its serialized data dictionary and
+ pass it through the chain of associated transporters to get fully deserialized.
+
+ :param data: the internal data dictionary of the associated json file of the ML model which is generated previously by
+ pymilo export.
+ :type data: dict
+ :param key: the special key of the data param, which we're going to deserialize its value(data[key])
+ :type key: object
+ :param model_type: the model type of the ML model, which internal serialized data dictionary is given as the data param
+ :type model_type: str
+ :return: pymilo deserialized output of data[key]
+ """
+ content = data[key]
+ if self.is_preprocessing_module(content):
+ return self.deserialize_pre_module(content)
+ return content
+
+
+ def is_preprocessing_module(self, pre_module):
+ """
+ Check whether the given module is a sklearn Preprocessing module or not.
+
+ :param pre_module: given object
+ :type pre_module: any
+ :return: bool
+ """
+ if isinstance(pre_module, dict):
+ return check_str_in_iterable(
+ "pymilo-preprocessing-type",
+ pre_module) and pre_module["pymilo-preprocessing-type"] in SKLEARN_PREPROCESSING_TABLE.keys()
+ return get_sklearn_type(pre_module) in SKLEARN_PREPROCESSING_TABLE.keys()
+
+
+ def serialize_pre_module(self, pre_module):
+ """
+ Serialize Preprocessing object.
+
+ :param pre_module: given sklearn preprocessing module
+ :type pre_module: sklearn.preprocessing
+ :return: pymilo serialized pre_module
+ """
+ gdst = GeneralDataStructureTransporter()
+ gdst.transport(pre_module, Command.SERIALIZE, False)
+ return {
+ "pymilo-bypass": True,
+ "pymilo-preprocessing-type": get_sklearn_type(pre_module),
+ "pymilo-preprocessing-data": pre_module.__dict__
+ }
+
+
+ def deserialize_pre_module(self, serialized_pre_module):
+ """
+ Deserialize Preprocessing object.
+
+ :param serialized_pre_module: serialized preprocessing module (by pymilo)
+ :type serialized_pre_module: dict
+ :return: retrieved associated sklearn.preprocessing module
+ """
+ data = serialized_pre_module["pymilo-preprocessing-data"]
+ associated_type = SKLEARN_PREPROCESSING_TABLE[serialized_pre_module["pymilo-preprocessing-type"]]
+ retrieved_pre_module = associated_type()
+ gdst = GeneralDataStructureTransporter()
+ for key in data:
+ setattr(retrieved_pre_module, key, gdst.deserialize(data, key, ""))
+ return retrieved_pre_module
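A round-trip sketch of the new transporter in isolation (no JSON file involved); the toy data is illustrative:

```python
# Round-trip sketch for PreprocessingTransporter: serialize a fitted StandardScaler to a plain
# dict and rebuild an equivalent instance from it.
import numpy as np
from sklearn.preprocessing import StandardScaler
from pymilo.transporters.preprocessing_transporter import PreprocessingTransporter

pt = PreprocessingTransporter()
scaler = StandardScaler().fit(np.array([[0.0, 10.0], [1.0, 20.0], [2.0, 30.0]]))

serialized = pt.serialize_pre_module(scaler)        # {"pymilo-preprocessing-type": "StandardScaler", ...}
assert pt.is_preprocessing_module(serialized)       # detection also works on the serialized dict

restored = pt.deserialize_pre_module(serialized)    # fresh StandardScaler with its fields set back
print(restored.transform(np.array([[1.0, 20.0]])))  # matches the original scaler's output
```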
diff --git a/tests/test_ensembles/pipeline.py b/tests/test_ensembles/pipeline.py
index de2e7961..39c2f765 100644
--- a/tests/test_ensembles/pipeline.py
+++ b/tests/test_ensembles/pipeline.py
@@ -1,4 +1,5 @@
from sklearn.svm import SVC
+from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from pymilo.utils.test_pymilo import pymilo_classification_test
from pymilo.utils.data_exporter import prepare_simple_classification_datasets
@@ -8,6 +9,6 @@
def pipeline():
x_train, y_train, x_test, y_test = prepare_simple_classification_datasets()
pipeline = Pipeline([
- #('scaler', StandardScaler()),
+ ('scaler', StandardScaler()),
('svc', SVC())]).fit(x_train, y_train)
pymilo_classification_test(pipeline, MODEL_NAME, (x_test, y_test))