Skip to content

Commit 1b6192c

Browse files
[MNT] Update scikit-learn requirement from <1.7.0,>=1.0.0 to >=1.0.0,<1.8.0 in the python-packages group (#2907)
* [MNT] Update scikit-learn requirement in the python-packages group Updates the requirements on [scikit-learn](https://github.com/scikit-learn/scikit-learn) to permit the latest version. Updates `scikit-learn` to 1.7.0 - [Release notes](https://github.com/scikit-learn/scikit-learn/releases) - [Commits](scikit-learn/scikit-learn@1.0...1.7.0) --- updated-dependencies: - dependency-name: scikit-learn dependency-version: 1.7.0 dependency-type: direct:production dependency-group: python-packages ... Signed-off-by: dependabot[bot] <support@github.com> * new validation * Update test_rockad.py * Update _rockad.py --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: MatthewMiddlehurst <pfm15hbu@gmail.com>
1 parent 184994b commit 1b6192c

File tree

6 files changed

+14
-13
lines changed

6 files changed

+14
-13
lines changed

aeon/anomaly_detection/series/distance_based/_rockad.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ def _inner_fit(self, X: np.ndarray) -> None:
176176

177177
if self.power_transform:
178178
self.power_transformer_ = PowerTransformer()
179+
# todo check if this is still an issue with scikit-learn >= 1.7.0
180+
# when lower bound is raised
179181
try:
180182
Xtp = self.power_transformer_.fit_transform(Xt)
181183

aeon/anomaly_detection/series/distance_based/tests/test_rockad.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,6 @@ def test_rockad_incorrect_input():
7373
):
7474
ad = ROCKAD(stride=1, window_size=100)
7575
ad.fit(train_series)
76-
with pytest.warns(
77-
UserWarning, match=r"Power Transform failed and thus has been disabled."
78-
):
79-
ad = ROCKAD(stride=1, window_size=5)
80-
ad.fit(train_series)
8176
with pytest.raises(
8277
ValueError, match=r"window shape cannot be larger than input array shape"
8378
):

aeon/classification/sklearn/_continuous_interval_tree.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sklearn.exceptions import NotFittedError
2222
from sklearn.utils import check_random_state
2323
from sklearn.utils.multiclass import check_classification_targets
24+
from sklearn.utils.validation import validate_data
2425

2526

2627
class _TreeNode:
@@ -374,7 +375,8 @@ def fit(self, X, y):
374375
"""
375376
# data processing
376377
X = self._check_X(X)
377-
X, y = self._validate_data(
378+
X, y = validate_data(
379+
self,
378380
X=X,
379381
y=y,
380382
ensure_min_samples=2,
@@ -464,8 +466,8 @@ def predict_proba(self, X):
464466

465467
# data processing
466468
X = self._check_X(X)
467-
X = self._validate_data(
468-
X=X, reset=False, force_all_finite="allow-nan", accept_sparse=False
469+
X = validate_data(
470+
self, X=X, reset=False, force_all_finite="allow-nan", accept_sparse=False
469471
)
470472

471473
dists = np.zeros((X.shape[0], self.n_classes_))

aeon/classification/sklearn/_rotation_forest_classifier.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sklearn.tree import DecisionTreeClassifier
2020
from sklearn.utils import check_random_state
2121
from sklearn.utils.multiclass import check_classification_targets
22+
from sklearn.utils.validation import validate_data
2223

2324
from aeon.base._base import _clone_estimator
2425
from aeon.utils.validation import check_n_jobs
@@ -192,7 +193,7 @@ def predict_proba(self, X) -> np.ndarray:
192193

193194
# data processing
194195
X = self._check_X(X)
195-
X = self._validate_data(X=X, reset=False, accept_sparse=False)
196+
X = validate_data(self, X=X, reset=False, accept_sparse=False)
196197

197198
# replace missing values with 0 and remove useless attributes
198199
X = X[:, self._useful_atts]
@@ -299,7 +300,7 @@ def fit_predict_proba(self, X, y) -> np.ndarray:
299300
def _fit_rotf(self, X, y, save_transformed_data: bool = False):
300301
# data processing
301302
X = self._check_X(X)
302-
X, y = self._validate_data(X=X, y=y, ensure_min_samples=2, accept_sparse=False)
303+
X, y = validate_data(self, X=X, y=y, ensure_min_samples=2, accept_sparse=False)
303304
check_classification_targets(y)
304305

305306
self._n_jobs = check_n_jobs(self.n_jobs)

aeon/regression/sklearn/_rotation_forest_regressor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sklearn.exceptions import NotFittedError
2020
from sklearn.tree import DecisionTreeRegressor
2121
from sklearn.utils import check_random_state
22+
from sklearn.utils.validation import validate_data
2223

2324
from aeon.base._base import _clone_estimator
2425
from aeon.utils.validation import check_n_jobs
@@ -168,7 +169,7 @@ def predict(self, X) -> np.ndarray:
168169

169170
# data processing
170171
X = self._check_X(X)
171-
X = self._validate_data(X=X, reset=False, accept_sparse=False)
172+
X = validate_data(self, X=X, reset=False, accept_sparse=False)
172173

173174
# replace missing values with 0 and remove useless attributes
174175
X = X[:, self._useful_atts]
@@ -222,7 +223,7 @@ def fit_predict(self, X, y) -> np.ndarray:
222223
def _fit_rotf(self, X, y, save_transformed_data: bool = False):
223224
# data processing
224225
X = self._check_X(X)
225-
X, y = self._validate_data(X=X, y=y, ensure_min_samples=2, accept_sparse=False)
226+
X, y = validate_data(self, X=X, y=y, ensure_min_samples=2, accept_sparse=False)
226227

227228
self._label_average = np.mean(y)
228229

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ dependencies = [
5050
"numpy>=1.21.0,<2.3.0",
5151
"packaging>=20.0",
5252
"pandas>=2.0.0,<2.4.0",
53-
"scikit-learn>=1.0.0,<1.7.0",
53+
"scikit-learn>=1.0.0,<1.8.0",
5454
"scipy>=1.9.0,<1.16.0",
5555
"typing-extensions>=4.6.0",
5656
]

0 commit comments

Comments
 (0)