Skip to content

Commit aa3b4bb

Browse files
authored
fix: support sklearn BaseEstimators (#407)
The transformers and threshold model inherit BaseEstimator class. Fixes: > mlflow saving of models --------- Signed-off-by: Kushal Batra <i.kushalbatra@gmail.com>
1 parent eb9ceb1 commit aa3b4bb

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

numalogic/registry/mlflow_registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
from mlflow.exceptions import RestException
2222
from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST
2323
from mlflow.tracking import MlflowClient
24-
from sklearn.base import BaseEstimator
2524
from torch import nn
2625

26+
from numalogic.base import BaseThresholdModel, BaseTransformer
2727
from numalogic.registry import ArtifactManager, ArtifactData
2828
from numalogic.registry.artifact import ArtifactCache
2929
from numalogic.tools.exceptions import ModelVersionError
@@ -104,7 +104,7 @@ def __init__(
104104
def handler_from_obj(artifact: artifact_t):
105105
if isinstance(artifact, nn.Module):
106106
return mlflow.pytorch
107-
if isinstance(artifact, BaseEstimator):
107+
if isinstance(artifact, (BaseThresholdModel, BaseTransformer)):
108108
return mlflow.sklearn
109109
return mlflow.pyfunc
110110

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "numalogic"
3-
version = "0.12.4"
3+
version = "0.13.0"
44
description = "Collection of operational Machine Learning models and tools."
55
authors = ["Numalogic Developers"]
66
packages = [{ include = "numalogic" }]

tests/registry/_mlflow_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from mlflow.entities.model_registry import ModelVersion
88
from mlflow.models.model import ModelInfo
99
from mlflow.store.entities import PagedList
10-
from sklearn.ensemble import RandomForestRegressor
1110
from sklearn.preprocessing import StandardScaler
1211
from torch import tensor
1312

13+
from numalogic.models.threshold import StdDevThreshold
14+
1415

1516
def create_model():
1617
x = torch.linspace(-math.pi, math.pi, 2000)
@@ -33,8 +34,7 @@ def create_model():
3334

3435

3536
def model_sklearn():
36-
params = {"n_estimators": 5, "random_state": 42}
37-
return RandomForestRegressor(**params)
37+
return StdDevThreshold()
3838

3939

4040
def mock_log_state_dict(*_, **__):

0 commit comments

Comments
 (0)