Skip to content

Fixes sklearn pickling issues #103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nextmv-gurobipy/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ classifiers = [
]
dependencies = [
"gurobipy>=12.0.1",
"nextmv>=0.25.0"
"nextmv>=0.28.1"
]
description = "An SDK for integrating Gurobi with the Nextmv platform"
dynamic = [
Expand Down
2 changes: 1 addition & 1 deletion nextmv-scikit-learn/nextmv_sklearn/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "v0.3.0"
__version__ = "v0.3.1.dev1"
4 changes: 4 additions & 0 deletions nextmv-scikit-learn/nextmv_sklearn/linear_model/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,25 @@
nextmv.Option(
name="fit_intercept",
option_type=bool,
default=True,
description="Whether to calculate the intercept for this model.",
),
nextmv.Option(
name="copy_X",
option_type=bool,
default=True,
description="If True, X will be copied; else, it may be overwritten.",
),
nextmv.Option(
name="n_jobs",
option_type=int,
default=1,
description="The number of jobs to use for the computation.",
),
nextmv.Option(
name="positive",
option_type=bool,
default=False,
description="When set to True, forces the coefficients to be positive.",
),
]
Expand Down
2 changes: 1 addition & 1 deletion nextmv-scikit-learn/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ classifiers = [
]
dependencies = [
"scikit-learn>=1.6.1",
"nextmv>=0.25.0"
"nextmv>=0.28.1"
]
description = "An SDK for integrating scikit-learn with the Nextmv platform"
dynamic = [
Expand Down
60 changes: 60 additions & 0 deletions nextmv-scikit-learn/tests/test_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os
import unittest

from nextmv_sklearn.linear_model import LinearRegression, LinearRegressionOptions

import nextmv


class MLRegressorModel(nextmv.Model):
def solve(self, input: nextmv.Input) -> nextmv.Output:
if input.options.mode == "linear":
model = LinearRegression(input.options)
_ = model
return nextmv.Output(solution={}, options=input.options)
else:
raise ValueError(f"Unsupported mode: {input.options.mode}")


class TestPickle(unittest.TestCase):
def tearDown(self):
"""Removes the mlflow elements created during the test."""
model_configuration = nextmv.ModelConfiguration(
name="reg",
)
nextmv.model._cleanup_python_model(model_dir="export", model_configuration=model_configuration)

def test_options(self):
model = MLRegressorModel()
# Define options (custom and sklearn).
sklearn_opts = LinearRegressionOptions().to_nextmv()
custom_options = nextmv.Options(
nextmv.Option(
name="mode",
option_type=str,
default="linear",
description="ML mode (linear or xgboost).",
required=False,
)
)
options = custom_options.merge(sklearn_opts)
# Create a model configuration so that we can pickle the model.
model_configuration = nextmv.ModelConfiguration(
name="reg",
requirements=[
"nextmv",
"nextmv-scikit-learn",
],
options=options,
)

# Run the model with some input data.
input = nextmv.Input(data={}, options=options)
output = model.solve(input)
self.assertIsInstance(output, nextmv.Output)

# Save (pickle) the model to a directory.
os.makedirs("export", exist_ok=True)
model.save("export", model_configuration)
# Assert that the "export" directory is not empty
self.assertTrue(len(os.listdir("export")) > 0)
Loading