Skip to content

Commit e99164e

Browse files
authored
Fix missing and fragile scikit-learn imports in Keras sklearn wrappers (#21387)
* Fix sklearn imports * remove sklearn import * fix formatting and imports * Remove internal API import * Resolve review comments
1 parent f85e044 commit e99164e

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

keras/src/wrappers/fixes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def _raise_or_return(target_type):
3434
else:
3535
return target_type
3636

37-
target_type = sklearn.utils.multiclass.type_of_target(
38-
y, input_name=input_name
39-
)
37+
from sklearn.utils.multiclass import type_of_target as sk_type_of_target
38+
39+
target_type = sk_type_of_target(y, input_name=input_name)
4040
return _raise_or_return(target_type)
4141

4242

keras/src/wrappers/sklearn_wrapper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,9 @@ def fit(self, X, y, **kwargs):
172172

173173
def predict(self, X):
174174
"""Predict using the model."""
175-
sklearn.base.check_is_fitted(self)
175+
from sklearn.utils.validation import check_is_fitted
176+
177+
check_is_fitted(self)
176178
X = _validate_data(self, X, reset=False)
177179
raw_output = self.model_.predict(X)
178180
return self._reverse_process_target(raw_output)
@@ -474,7 +476,9 @@ def transform(self, X):
474476
X_transformed: array-like, shape=(n_samples, n_features)
475477
The transformed data.
476478
"""
477-
sklearn.base.check_is_fitted(self)
479+
from sklearn.utils.validation import check_is_fitted
480+
481+
check_is_fitted(self)
478482
X = _validate_data(self, X, reset=False)
479483
return self.model_.predict(X)
480484

keras/src/wrappers/utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
try:
24
import sklearn
35
from sklearn.base import BaseEstimator
@@ -25,8 +27,8 @@ def _check_model(model):
2527
# compile model if user gave us an un-compiled model
2628
if not model.compiled or not model.loss or not model.optimizer:
2729
raise RuntimeError(
28-
"Given model needs to be compiled, and have a loss and an "
29-
"optimizer."
30+
"Given model needs to be compiled, and have a loss "
31+
"and an optimizer."
3032
)
3133

3234

@@ -80,8 +82,9 @@ def inverse_transform(self, y):
8082
is passed, it will be squeezed back to 1D. Otherwise, it
8183
will eb left untouched.
8284
"""
83-
sklearn.base.check_is_fitted(self)
84-
xp, _ = sklearn.utils._array_api.get_namespace(y)
85+
from sklearn.utils.validation import check_is_fitted
86+
87+
check_is_fitted(self)
8588
if self.ndim_ == 1 and y.ndim == 2:
86-
return xp.squeeze(y, axis=1)
89+
return np.squeeze(y, axis=1)
8790
return y

0 commit comments

Comments
 (0)