Skip to content

Commit d8cb8b0

Browse files
committed
added zero sample weight removal
1 parent e7c2442 commit d8cb8b0

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/linearboost/linear_boost.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,11 @@ def _check_X_y(self, X, y) -> tuple[np.ndarray, np.ndarray]:
271271
def fit(self, X, y, sample_weight=None) -> Self:
272272
if self.algorithm not in {"SAMME", "SAMME.R"}:
273273
raise ValueError("algorithm must be 'SAMME' or 'SAMME.R'")
274+
if sample_weight is not None:
275+
nonzero_mask = sample_weight != 0
276+
X = X[nonzero_mask]
277+
y = y[nonzero_mask]
278+
sample_weight = sample_weight[nonzero_mask]
274279

275280
X, y = self._check_X_y(X, y)
276281
self.classes_ = np.unique(y)

0 commit comments

Comments
 (0)