Skip to content

Commit b387f15

Browse files
authored
Merge pull request #8 from LinearBoost/v.0.1.2
V.0.1.2
2 parents c838802 + d3af90d commit b387f15

File tree

7 files changed

+545
-408
lines changed

7 files changed

+545
-408
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,8 @@ dmypy.json
130130

131131
#PyCharm
132132
.idea/
133+
catboost_info/catboost_training.json
134+
catboost_info/learn/events.out.tfevents
135+
catboost_info/learn_error.tsv
136+
catboost_info/time_left.tsv
137+
*.ipynb

README.md

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# LinearBoost Classifier
22

3-
![Lastest Release](https://img.shields.io/badge/release-v0.1.1-green)
3+
![Lastest Release](https://img.shields.io/badge/release-v0.1.2-green)
44
[![PyPI Version](https://img.shields.io/pypi/v/linearboost)](https://pypi.org/project/linearboost/)
55
![Python Versions](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue)
66

@@ -19,10 +19,11 @@ Key Features:
1919
- Exceptional Speed: Blazing fast training and inference times
2020
- Resource Efficient: Low memory usage, ideal for large datasets
2121

22-
## 🚀 New Major Release (v0.1.1)
23-
Version 0.1.1 of **LinearBoost Classifier** is released, with a pull request from [@msamsami](https://github.com/msamsami). Here are the changes:
22+
## 🚀 New Major Release (v0.1.2)
23+
Version 0.1.2 of **LinearBoost Classifier** is released. Here are the changes:
2424

2525
- The codebase is refactored into a new structure.
26+
- SAMME.R algorithm is returned to the classifier.
2627
- Both SEFR and LinearBoostClassifier classes are refactored to fully adhere to Scikit-learn's conventions and API. Now, they are standard Scikit-learn estimators that can be used in Scikit-learn pipelines, grid search, etc.
2728
- Added unit tests (using pytest) to ensure the estimators adhere to Scikit-learn conventions.
2829
- Added fit_intercept parameter to SEFR similar to other linear estimators in Scikit-learn (e.g., LogisticRegression, LinearRegression, etc.).
@@ -35,16 +36,6 @@ Version 0.1.1 of **LinearBoost Classifier** is released, with a pull request fro
3536
- Improved Scikit-learn compatibility.
3637

3738

38-
## 🚀 New Release (v0.0.5)
39-
Version 0.0.5 of the **LinearBoost Classifier** is released! This new version introduces several exciting features and improvements:
40-
41-
- 🛠️ Support of custom loss function
42-
- ✅ Enhanced handling of class weights
43-
- 🎨 Customized handling of the data scalers
44-
- ⚡ Optimized boosting
45-
- 🕒 Improved runtime and scalability
46-
47-
4839
Get Started and Documentation
4940
-----------------------------
5041

@@ -228,3 +219,12 @@ License
228219
-------
229220

230221
This project is licensed under the terms of the MIT license. See [LICENSE](https://github.com/LinearBoost/linearboost-classifier/blob/main/LICENSE) for additional details.
222+
223+
## Acknowledgments
224+
225+
Some portions of this code are adapted from the scikit-learn project
226+
(https://scikit-learn.org), which is licensed under the BSD 3-Clause License.
227+
See the `licenses/` folder for details. The modifications and additions made to the original code are licensed under the MIT License © 2025 Hamidreza Keshavarz, Reza Rawassizadeh.
228+
Special Thanks to:
229+
- **Mehdi Samsami** – for software engineering, refactoring, and ensuring compatibility.
230+
The original code from scikit-learn is available at [scikit-learn GitHub repository](https://github.com/scikit-learn/scikit-learn)

src/linearboost/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.1.1"
1+
__version__ = "0.1.2"
22

33
from .linear_boost import LinearBoostClassifier
44
from .sefr import SEFR
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
BSD 3-Clause License
2+
3+
Copyright (c) 2007-2024 The scikit-learn developers.
4+
All rights reserved.
5+
6+
Redistribution and use in source and binary forms, with or without
7+
modification, are permitted provided that the following conditions are met:
8+
9+
* Redistributions of source code must retain the above copyright notice, this
10+
list of conditions and the following disclaimer.
11+
12+
* Redistributions in binary form must reproduce the above copyright notice,
13+
this list of conditions and the following disclaimer in the documentation
14+
and/or other materials provided with the distribution.
15+
16+
* Neither the name of the copyright holder nor the names of its
17+
contributors may be used to endorse or promote products derived from
18+
this software without specific prior written permission.
19+
20+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

src/linearboost/linear_boost.py

Lines changed: 153 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# This file is part of the LinearBoost project.
2+
#
3+
# Portions of this file are derived from scikit-learn
4+
# Copyright (c) 2007–2024, scikit-learn developers (version 1.5)
5+
# Licensed under the BSD 3-Clause License
6+
# See https://github.com/scikit-learn/scikit-learn/blob/main/COPYING for details.
7+
#
8+
# Additional code and modifications:
9+
# - Hamidreza Keshavarz (hamid9@outlook.com) — machine learning logic, design, and new algorithms
10+
# - Mehdi Samsami (mehdisamsami@live.com) — software refactoring, compatibility with scikit-learn framework, and packaging
11+
#
12+
# The combined work is licensed under the MIT License.
13+
114
from __future__ import annotations
215

316
import sys
@@ -23,7 +36,7 @@
2336
StandardScaler,
2437
)
2538
from sklearn.utils import compute_sample_weight
26-
from sklearn.utils._param_validation import Hidden, Interval, StrOptions
39+
from sklearn.utils._param_validation import Interval, StrOptions
2740
from sklearn.utils.multiclass import check_classification_targets, type_of_target
2841
from sklearn.utils.validation import check_is_fitted
2942

@@ -73,18 +86,10 @@ class LinearBoostClassifier(AdaBoostClassifier):
7386
algorithm : {'SAMME', 'SAMME.R'}, default='SAMME'
7487
If 'SAMME' then use the SAMME discrete boosting algorithm.
7588
If 'SAMME.R' then use the SAMME.R real boosting algorithm
76-
(only available in scikit-learn < 1.6).
89+
(implemented from scikit-learn = 1.5).
7790
The SAMME.R algorithm typically converges faster than SAMME,
7891
achieving a lower test error with fewer boosting iterations.
7992
80-
.. deprecated:: scikit-learn 1.4
81-
`"SAMME.R"` is deprecated and will be removed in scikit-learn 1.6.
82-
'"SAMME"' will become the default.
83-
84-
.. deprecated:: scikit-learn 1.6
85-
`algorithm` is deprecated and will be removed in scikit-learn 1.8.
86-
This estimator only implements the 'SAMME' algorithm in scikit-learn >= 1.6.
87-
8893
scaler : str, default='minmax'
8994
Specifies the scaler to apply to the data. Options include:
9095
@@ -188,9 +193,7 @@ class LinearBoostClassifier(AdaBoostClassifier):
188193
_parameter_constraints: dict = {
189194
"n_estimators": [Interval(Integral, 1, None, closed="left")],
190195
"learning_rate": [Interval(Real, 0, None, closed="neither")],
191-
"algorithm": [StrOptions({"SAMME"}), Hidden(StrOptions({"deprecated"}))]
192-
if SKLEARN_V1_6_OR_LATER
193-
else [StrOptions({"SAMME", "SAMME.R"})],
196+
"algorithm": [StrOptions({"SAMME", "SAMME.R"})],
194197
"scaler": [StrOptions({s for s in _scalers})],
195198
"class_weight": [
196199
StrOptions({"balanced_subsample", "balanced"}),
@@ -206,18 +209,15 @@ def __init__(
206209
n_estimators=200,
207210
*,
208211
learning_rate=1.0,
209-
algorithm="SAMME",
212+
algorithm="SAMME.R",
210213
scaler="minmax",
211214
class_weight=None,
212215
loss_function=None,
213216
):
214217
super().__init__(
215-
estimator=SEFR(),
216-
n_estimators=n_estimators,
217-
learning_rate=learning_rate,
218-
algorithm=algorithm,
218+
estimator=SEFR(), n_estimators=n_estimators, learning_rate=learning_rate
219219
)
220-
220+
self.algorithm = algorithm
221221
self.scaler = scaler
222222
self.class_weight = class_weight
223223
self.loss_function = loss_function
@@ -241,7 +241,11 @@ def _more_tags(self) -> dict[str, bool]:
241241
"check_sample_weight_equivalence_on_dense_data": (
242242
"In LinearBoostClassifier, setting a sample's weight to 0 can produce a different "
243243
"result than omitting the sample. Such samples intentionally still affect the data scaling process."
244-
)
244+
),
245+
"check_sample_weights_invariance": (
246+
"In LinearBoostClassifier, a zero sample_weight is not equivalent to removing the sample, "
247+
"as samples with zero weight intentionally still affect the data scaling process."
248+
),
245249
},
246250
}
247251

@@ -269,9 +273,8 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]:
269273
return X, y
270274

271275
def fit(self, X, y, sample_weight=None) -> Self:
272-
X, y = self._check_X_y(X, y)
273-
self.classes_ = np.unique(y)
274-
self.n_classes_ = self.classes_.shape[0]
276+
if self.algorithm not in {"SAMME", "SAMME.R"}:
277+
raise ValueError("algorithm must be 'SAMME' or 'SAMME.R'")
275278

276279
if self.scaler not in _scalers:
277280
raise ValueError('Invalid scaler provided; got "%s".' % self.scaler)
@@ -283,6 +286,25 @@ def fit(self, X, y, sample_weight=None) -> Self:
283286
clone(_scalers[self.scaler]), clone(_scalers["minmax"])
284287
)
285288
X_transformed = self.scaler_.fit_transform(X)
289+
y = np.asarray(y)
290+
291+
if sample_weight is not None:
292+
sample_weight = np.asarray(sample_weight)
293+
if sample_weight.shape[0] != X_transformed.shape[0]:
294+
raise ValueError(
295+
f"sample_weight.shape == {sample_weight.shape} is incompatible with X.shape == {X_transformed.shape}"
296+
)
297+
nonzero_mask = (
298+
sample_weight.sum(axis=1) != 0
299+
if sample_weight.ndim > 1
300+
else sample_weight != 0
301+
)
302+
X_transformed = X_transformed[nonzero_mask]
303+
y = y[nonzero_mask]
304+
sample_weight = sample_weight[nonzero_mask]
305+
X_transformed, y = self._check_X_y(X_transformed, y)
306+
self.classes_ = np.unique(y)
307+
self.n_classes_ = self.classes_.shape[0]
286308

287309
if self.class_weight is not None:
288310
valid_presets = ("balanced", "balanced_subsample")
@@ -307,50 +329,131 @@ def fit(self, X, y, sample_weight=None) -> Self:
307329
warnings.filterwarnings(
308330
"ignore",
309331
category=FutureWarning,
310-
message=".*parameter 'algorithm' is deprecated.*",
332+
message=".*parameter 'algorithm' may change in the future.*",
311333
)
312334
return super().fit(X_transformed, y, sample_weight)
313335

336+
def _samme_proba(self, estimator, n_classes, X):
337+
"""Calculate algorithm 4, step 2, equation c) of Zhu et al [1].
338+
339+
References
340+
----------
341+
.. [1] J. Zhu, H. Zou, S. Rosset, T. Hastie, "Multi-class AdaBoost", 2009.
342+
343+
"""
344+
proba = estimator.predict_proba(X)
345+
346+
# Displace zero probabilities so the log is defined.
347+
# Also fix negative elements which may occur with
348+
# negative sample weights.
349+
np.clip(proba, np.finfo(proba.dtype).eps, None, out=proba)
350+
log_proba = np.log(proba)
351+
352+
return (n_classes - 1) * (
353+
log_proba - (1.0 / n_classes) * log_proba.sum(axis=1)[:, np.newaxis]
354+
)
355+
314356
def _boost(self, iboost, X, y, sample_weight, random_state):
315357
estimator = self._make_estimator(random_state=random_state)
316358
estimator.fit(X, y, sample_weight=sample_weight)
317359

318-
y_pred = estimator.predict(X)
319-
missclassified = y_pred != y
360+
if self.algorithm == "SAMME.R":
361+
y_pred = estimator.predict(X)
320362

321-
if self.loss_function:
322-
estimator_error = self.loss_function(y, y_pred, sample_weight)
323-
else:
363+
incorrect = y_pred != y
324364
estimator_error = np.mean(
325-
np.average(missclassified, weights=sample_weight, axis=0)
365+
np.average(incorrect, weights=sample_weight, axis=0)
326366
)
327367

328-
if estimator_error <= 0:
329-
return sample_weight, 1.0, 0.0
368+
if estimator_error <= 0:
369+
return sample_weight, 1.0, 0.0
370+
elif estimator_error >= 0.5:
371+
if len(self.estimators_) > 1:
372+
self.estimators_.pop(-1)
373+
return None, None, None
330374

331-
if estimator_error >= 0.5:
332-
self.estimators_.pop(-1)
333-
if len(self.estimators_) == 0:
334-
raise ValueError(
335-
"BaseClassifier in AdaBoostClassifier ensemble is worse than random, ensemble can not be fit."
375+
# Compute SEFR-specific weight update
376+
estimator_weight = self.learning_rate * np.log(
377+
(1 - estimator_error) / estimator_error
378+
)
379+
380+
if iboost < self.n_estimators - 1:
381+
sample_weight = np.exp(
382+
np.log(sample_weight)
383+
+ estimator_weight * incorrect * (sample_weight > 0)
336384
)
337-
return None, None, None
338385

339-
estimator_weight = (
340-
self.learning_rate
341-
* 0.5
342-
* np.log((1.0 - estimator_error) / max(estimator_error, 1e-10))
343-
)
386+
return sample_weight, estimator_weight, estimator_error
387+
388+
else: # standard SAMME
389+
y_pred = estimator.predict(X)
390+
incorrect = y_pred != y
391+
estimator_error = np.mean(np.average(incorrect, weights=sample_weight))
392+
393+
if estimator_error <= 0:
394+
return sample_weight, 1.0, 0.0
395+
if estimator_error >= 0.5:
396+
self.estimators_.pop(-1)
397+
if len(self.estimators_) == 0:
398+
raise ValueError(
399+
"BaseClassifier in AdaBoostClassifier ensemble is worse than random, ensemble cannot be fit."
400+
)
401+
return None, None, None
402+
403+
estimator_weight = self.learning_rate * np.log(
404+
(1.0 - estimator_error) / max(estimator_error, 1e-10)
405+
)
344406

345-
sample_weight *= np.exp(
346-
estimator_weight
347-
* missclassified
348-
* ((sample_weight > 0) | (estimator_weight < 0))
349-
)
407+
sample_weight *= np.exp(estimator_weight * incorrect)
408+
409+
# Normalize sample weights
410+
sample_weight /= np.sum(sample_weight)
350411

351-
return sample_weight, estimator_weight, estimator_error
412+
return sample_weight, estimator_weight, estimator_error
352413

353414
def decision_function(self, X):
354415
check_is_fitted(self)
355416
X_transformed = self.scaler_.transform(X)
356-
return super().decision_function(X_transformed)
417+
418+
if self.algorithm == "SAMME.R":
419+
# Proper SAMME.R decision function
420+
classes = self.classes_
421+
n_classes = len(classes)
422+
423+
pred = sum(
424+
self._samme_proba(estimator, n_classes, X_transformed)
425+
for estimator in self.estimators_
426+
)
427+
pred /= self.estimator_weights_.sum()
428+
if n_classes == 2:
429+
pred[:, 0] *= -1
430+
return pred.sum(axis=1)
431+
return pred
432+
433+
else:
434+
# Standard SAMME algorithm from AdaBoostClassifier (discrete)
435+
return super().decision_function(X_transformed)
436+
437+
def predict(self, X):
438+
"""Predict classes for X.
439+
440+
The predicted class of an input sample is computed as the weighted mean
441+
prediction of the classifiers in the ensemble.
442+
443+
Parameters
444+
----------
445+
X : {array-like, sparse matrix} of shape (n_samples, n_features)
446+
The training input samples. Sparse matrix can be CSC, CSR, COO,
447+
DOK, or LIL. COO, DOK, and LIL are converted to CSR.
448+
449+
Returns
450+
-------
451+
y : ndarray of shape (n_samples,)
452+
The predicted classes.
453+
"""
454+
pred = self.decision_function(X)
455+
456+
if self.n_classes_ == 2:
457+
return self.classes_.take(pred > 0, axis=0)
458+
459+
return self.classes_.take(np.argmax(pred, axis=1), axis=0)

tests/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ def check_estimator(estimator, *args, **kwargs):
1616

1717

1818
def get_expected_failed_tests(estimator) -> dict[str, str]:
19-
return estimator._more_tags()["_xfail_checks"]
19+
return estimator._more_tags().get("_xfail_checks", {})

0 commit comments

Comments
 (0)