diff --git a/src/linearboost/linear_boost.py b/src/linearboost/linear_boost.py
index 64073ec..0967008 100644
--- a/src/linearboost/linear_boost.py
+++ b/src/linearboost/linear_boost.py
@@ -15,6 +15,7 @@
 import sys
 import warnings
+from abc import abstractmethod
 from numbers import Integral, Real
 
 if sys.version_info >= (3, 11):
@@ -40,7 +41,7 @@
 from sklearn.utils.multiclass import check_classification_targets, type_of_target
 from sklearn.utils.validation import check_is_fitted
 
-from ._utils import SKLEARN_V1_6_OR_LATER, check_X_y
+from ._utils import SKLEARN_V1_6_OR_LATER, check_X_y, validate_data
 from .sefr import SEFR
 
 __all__ = ["LinearBoostClassifier"]
@@ -63,7 +64,201 @@
 }
 
 
-class LinearBoostClassifier(AdaBoostClassifier):
+class _DenseAdaBoostClassifier(AdaBoostClassifier):
+    if SKLEARN_V1_6_OR_LATER:
+
+        def __sklearn_tags__(self):
+            tags = super().__sklearn_tags__()
+            tags.input_tags.sparse = False
+            return tags
+
+    def _check_X(self, X):
+        # Only called to validate X in non-fit methods, therefore reset=False
+        return validate_data(
+            self,
+            X,
+            accept_sparse=False,
+            ensure_2d=True,
+            allow_nd=True,
+            dtype=None,
+            reset=False,
+        )
+
+    @abstractmethod
+    def _boost(self, iboost, X, y, sample_weight, random_state):
+        """Implement a single boost.
+
+        Warning: This method needs to be overridden by subclasses.
+
+        Parameters
+        ----------
+        iboost : int
+            The index of the current boost iteration.
+
+        X : {array-like} of shape (n_samples, n_features)
+            The training input samples.
+
+        y : array-like of shape (n_samples,)
+            The target values (class labels).
+
+        sample_weight : array-like of shape (n_samples,)
+            The current sample weights.
+
+        random_state : RandomState
+            The current random number generator.
+
+        Returns
+        -------
+        sample_weight : array-like of shape (n_samples,) or None
+            The reweighted sample weights.
+            If None then boosting has terminated early.
+
+        estimator_weight : float
+            The weight for the current boost.
+            If None then boosting has terminated early.
+
+        error : float
+            The classification error for the current boost.
+            If None then boosting has terminated early.
+        """
+        pass
+
+    def staged_score(self, X, y, sample_weight=None):
+        """Return staged scores for X, y.
+
+        This generator method yields the ensemble score after each iteration of
+        boosting and therefore allows monitoring, such as to determine the
+        score on a test set after each boost.
+
+        Parameters
+        ----------
+        X : {array-like} of shape (n_samples, n_features)
+            The input samples.
+
+        y : array-like of shape (n_samples,)
+            Labels for X.
+
+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights.
+
+        Yields
+        ------
+        z : float
+        """
+        return super().staged_score(X, y, sample_weight)
+
+    def staged_predict(self, X):
+        """Return staged predictions for X.
+
+        The predicted class of an input sample is computed as the weighted mean
+        prediction of the classifiers in the ensemble.
+
+        This generator method yields the ensemble prediction after each
+        iteration of boosting and therefore allows monitoring, such as to
+        determine the prediction on a test set after each boost.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            The input samples.
+
+        Yields
+        ------
+        y : generator of ndarray of shape (n_samples,)
+            The predicted classes.
+        """
+        return super().staged_predict(X)
+
+    def staged_decision_function(self, X):
+        """Compute decision function of ``X`` for each boosting iteration.
+
+        This method allows monitoring (i.e. determining the error on a test
+        set) after each boosting iteration.
+
+        Parameters
+        ----------
+        X : {array-like} of shape (n_samples, n_features)
+            The input samples.
+
+        Yields
+        ------
+        score : generator of ndarray of shape (n_samples, k)
+            The decision function of the input samples. The order of
+            outputs is the same as that of the :term:`classes_` attribute.
+            Binary classification is a special case with ``k == 1``,
+            otherwise ``k == n_classes``. For binary classification,
+            values closer to -1 or 1 mean more like the first or second
+            class in ``classes_``, respectively.
+        """
+        return super().staged_decision_function(X)
+
+    def predict_proba(self, X):
+        """Predict class probabilities for X.
+
+        The predicted class probabilities of an input sample are computed as
+        the weighted mean predicted class probabilities of the classifiers
+        in the ensemble.
+
+        Parameters
+        ----------
+        X : {array-like} of shape (n_samples, n_features)
+            The input samples.
+
+        Returns
+        -------
+        p : ndarray of shape (n_samples, n_classes)
+            The class probabilities of the input samples. The order of
+            outputs is the same as that of the :term:`classes_` attribute.
+        """
+        return super().predict_proba(X)
+
+    def staged_predict_proba(self, X):
+        """Predict class probabilities for X.
+
+        The predicted class probabilities of an input sample are computed as
+        the weighted mean predicted class probabilities of the classifiers
+        in the ensemble.
+
+        This generator method yields the ensemble predicted class probabilities
+        after each iteration of boosting and therefore allows monitoring, such
+        as to determine the predicted class probabilities on a test set after
+        each boost.
+
+        Parameters
+        ----------
+        X : {array-like} of shape (n_samples, n_features)
+            The input samples.
+
+        Yields
+        ------
+        p : generator of ndarray of shape (n_samples, n_classes)
+            The class probabilities of the input samples. The order of
+            outputs is the same as that of the :term:`classes_` attribute.
+        """
+        return super().staged_predict_proba(X)
+
+    def predict_log_proba(self, X):
+        """Predict class log-probabilities for X.
+
+        The predicted class log-probabilities of an input sample are computed
+        as the weighted mean predicted class log-probabilities of the
+        classifiers in the ensemble.
+
+        Parameters
+        ----------
+        X : {array-like} of shape (n_samples, n_features)
+            The input samples.
+
+        Returns
+        -------
+        p : ndarray of shape (n_samples, n_classes)
+            The class log-probabilities of the input samples. The order of
+            outputs is the same as that of the :term:`classes_` attribute.
+        """
+        return super().predict_log_proba(X)
+
+
+class LinearBoostClassifier(_DenseAdaBoostClassifier):
     """A LinearBoost classifier.
 
     A LinearBoost classifier is a meta-estimator based on AdaBoost and SEFR.
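The `_DenseAdaBoostClassifier` base above pins the estimator to dense input (the `sparse=False` tag plus `accept_sparse=False` in `_check_X`) and re-documents AdaBoost's staged interface. A minimal monitoring sketch using that interface (illustrative only, not part of the patch; it assumes the package imports as `linearboost` and uses synthetic data):

    from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split

    from linearboost import LinearBoostClassifier

    X, y = make_classification(n_samples=400, n_features=10, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # staged_score yields one accuracy value per boosting round,
    # so generalization can be tracked as estimators are added.
    clf = LinearBoostClassifier(n_estimators=20).fit(X_train, y_train)
    for i, score in enumerate(clf.staged_score(X_test, y_test), start=1):
        print(f"round {i:2d}: accuracy = {score:.3f}")
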
@@ -221,7 +416,6 @@ def __init__(
 
         def __sklearn_tags__(self):
             tags = super().__sklearn_tags__()
-            tags.input_tags.sparse = False
             tags.target_tags.required = True
             tags.classifier_tags.multi_class = False
             tags.classifier_tags.poor_score = True
@@ -268,6 +462,25 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]:
         return X, y
 
     def fit(self, X, y, sample_weight=None) -> Self:
+        """Build a LinearBoost classifier from the training set (X, y).
+
+        Parameters
+        ----------
+        X : {array-like} of shape (n_samples, n_features)
+            The training input samples.
+
+        y : array-like of shape (n_samples,)
+            The target values.
+
+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights. If None, the sample weights are initialized to
+            1 / n_samples.
+
+        Returns
+        -------
+        self : object
+            Fitted estimator.
+        """
         if self.algorithm not in {"SAMME", "SAMME.R"}:
             raise ValueError("algorithm must be 'SAMME' or 'SAMME.R'")
 
@@ -322,7 +535,8 @@ def fit(self, X, y, sample_weight=None) -> Self:
         )
         return super().fit(X_transformed, y, sample_weight)
 
-    def _samme_proba(self, estimator, n_classes, X):
+    @staticmethod
+    def _samme_proba(estimator, n_classes, X):
         """Calculate algorithm 4, step 2, equation c) of Zhu et al [1].
 
         References
@@ -401,6 +615,23 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
         return sample_weight, estimator_weight, estimator_error
 
     def decision_function(self, X):
+        """Compute the decision function of ``X``.
+
+        Parameters
+        ----------
+        X : {array-like} of shape (n_samples, n_features)
+            The input samples.
+
+        Returns
+        -------
+        score : ndarray of shape (n_samples, k)
+            The decision function of the input samples. The order of
+            outputs is the same as that of the :term:`classes_` attribute.
+            Binary classification is a special case with ``k == 1``,
+            otherwise ``k == n_classes``. For binary classification,
+            values closer to -1 or 1 mean more like the first or second
+            class in ``classes_``, respectively.
+        """
         check_is_fitted(self)
 
         X_transformed = self.scaler_.transform(X)
@@ -431,9 +662,8 @@ def predict(self, X):
 
         Parameters
         ----------
-        X : {array-like, sparse matrix} of shape (n_samples, n_features)
-            The training input samples. Sparse matrix can be CSC, CSR, COO,
-            DOK, or LIL. COO, DOK, and LIL are converted to CSR.
+        X : {array-like} of shape (n_samples, n_features)
+            The input samples.
 
         Returns
         -------
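Making `_samme_proba` a `@staticmethod` is consistent with its body: it needs only the estimator, the class count, and the data. For reference, the quantity it computes — algorithm 4, step 2, equation c) of Zhu et al. — is sketched below (modeled on the upstream scikit-learn helper of the same name; the function name and the epsilon clip here are illustrative assumptions):

    import numpy as np

    def samme_proba_sketch(estimator, n_classes, X):
        """SAMME.R contribution: (K - 1) * (log p - mean_k log p)."""
        proba = estimator.predict_proba(X)
        # Guard against log(0) for estimators that emit hard 0/1 probabilities.
        np.clip(proba, np.finfo(proba.dtype).eps, None, out=proba)
        log_proba = np.log(proba)
        return (n_classes - 1) * (
            log_proba - (1.0 / n_classes) * log_proba.sum(axis=1)[:, np.newaxis]
        )
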
diff --git a/tests/test_linearboost.py b/tests/test_linearboost.py
index f01a23f..a7ccf19 100644
--- a/tests/test_linearboost.py
+++ b/tests/test_linearboost.py
@@ -393,7 +393,9 @@ def test_invalid_algorithm_error():
     y = np.array([0, 1])
 
     clf = LinearBoostClassifier(algorithm="INVALID")
-    with pytest.raises(ValueError, match="algorithm must be 'SAMME' or 'SAMME.R'"):
+    msg1 = "algorithm must be 'SAMME' or 'SAMME.R'"
+    msg2 = r"The 'algorithm' parameter of LinearBoostClassifier must be a str among \{('SAMME', 'SAMME\.R'|'SAMME\.R', 'SAMME')\}"
+    with pytest.raises(ValueError, match=rf"({msg1}|{msg2})"):
         clf.fit(X, y)
 
 
@@ -403,7 +405,9 @@ def test_invalid_scaler_error():
     y = np.array([0, 1])
 
    clf = LinearBoostClassifier(scaler="invalid_scaler")
-    with pytest.raises(ValueError, match="Invalid scaler provided"):
+    msg1 = "Invalid scaler provided"
+    msg2 = r"The 'scaler' parameter of LinearBoostClassifier must be a str among .*\. Got 'invalid_scaler' instead\."
+    with pytest.raises(ValueError, match=rf"({msg1}|{msg2})"):
         clf.fit(X, y)
 
 
@@ -413,7 +417,9 @@ def test_invalid_class_weight_error():
     y = np.array([0, 1])
 
     clf = LinearBoostClassifier(class_weight="invalid_weight")
-    with pytest.raises(ValueError, match='Valid preset for class_weight is "balanced"'):
+    msg1 = 'Valid preset for class_weight is "balanced"'
+    msg2 = r"The 'class_weight' parameter of LinearBoostClassifier must be a str among \{'balanced'\}, an instance of 'dict', an instance of 'list' or None"
+    with pytest.raises(ValueError, match=rf"({msg1}|{msg2})"):
         clf.fit(X, y)
 
 
@@ -691,28 +697,6 @@ def test_breast_cancer_dataset():
     assert score > 0.5  # Should be better than random guessing
 
 
-def test_memory_efficiency():
-    """Test that LinearBoostClassifier doesn't consume excessive memory."""
-    X, y = make_classification(
-        n_samples=200,
-        n_features=10,
-        n_redundant=0,
-        random_state=42,
-        n_clusters_per_class=1,
-    )
-
-    clf = LinearBoostClassifier(n_estimators=10)
-    clf.fit(X, y)
-
-    # Check that the model doesn't store the training data
-    assert not hasattr(clf, "X_")
-    assert not hasattr(clf, "y_")
-
-    # Check that estimators are SEFR instances (lightweight)
-    for estimator in clf.estimators_:
-        assert estimator.__class__.__name__ == "SEFR"
-
-
 def test_different_class_labels():
     """Test with different types of class labels."""
     X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
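The widened `match` patterns accommodate scikit-learn's built-in parameter validation: on recent releases, constraints declared via `_parameter_constraints` make `fit` raise `InvalidParameterError` (a `ValueError` subclass) with its own wording before the estimator's manual check runs, so each test accepts either message. The same idea as a reusable pattern (hypothetical helper, not part of the test suite):

    import pytest

    def raises_any(*patterns):
        """Context manager matching any one of several ValueError messages."""
        combined = "|".join(f"(?:{p})" for p in patterns)
        return pytest.raises(ValueError, match=combined)

    # e.g. in a test:
    #     with raises_any(msg1, msg2):
    #         clf.fit(X, y)
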
diff --git a/tests/test_sefr.py b/tests/test_sefr.py
index ff3eaa8..8bcac13 100644
--- a/tests/test_sefr.py
+++ b/tests/test_sefr.py
@@ -457,28 +457,6 @@ def test_breast_cancer_dataset():
     assert score > 0.5  # Should be better than random guessing
 
 
-def test_memory_efficiency():
-    """Test that SEFR doesn't consume excessive memory."""
-    # This is a basic test - in practice you might want more sophisticated memory profiling
-    X, y = make_classification(
-        n_samples=1000,
-        n_features=20,
-        n_redundant=0,
-        random_state=42,
-        n_clusters_per_class=1,
-    )
-
-    sefr = SEFR()
-    sefr.fit(X, y)
-
-    # Check that the model doesn't store the training data
-    assert not hasattr(sefr, "X_")
-    assert not hasattr(sefr, "y_")
-
-    # Check that coefficients are reasonably sized
-    assert sefr.coef_.nbytes < 1000  # Should be small for 20 features
-
-
 def test_different_class_labels():
     """Test with different types of class labels."""
    X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
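A quick smoke check mirroring the retained `test_breast_cancer_dataset` tests (illustrative sketch only; the `linearboost` import path is an assumption):

    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split

    from linearboost import LinearBoostClassifier

    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

    clf = LinearBoostClassifier().fit(X_train, y_train)
    # Accuracy should comfortably beat the 0.5 chance level the tests assert.
    print(f"test accuracy: {clf.score(X_test, y_test):.3f}")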