Skip to content

Commit 16c68d9

Browse files
authored
Merge pull request #1 from ChaorongC/typing
Typing for MESA
2 parents a547ae7 + 63788ab commit 16c68d9

File tree

1 file changed

+35
-31
lines changed

1 file changed

+35
-31
lines changed

mesa/MESA.py

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323
from sklearn.preprocessing import Normalizer, StandardScaler
2424
from collections.abc import Sequence
2525
from sklearn.linear_model import LogisticRegression
26+
from typing import Union, Optional, List, Tuple, Any
2627

27-
def disp_mesa(txt):
28+
def disp_mesa(txt: str) -> None:
2829
"""
2930
Display a timestamped message to stderr for MESA logging.
3031
@@ -36,7 +37,7 @@ def disp_mesa(txt):
3637
print("@%s \t%s" % (time.asctime(), txt), file=sys.stderr)
3738

3839

39-
def wilcoxon(X, y):
40+
def wilcoxon(X: np.ndarray, y: np.ndarray) -> np.ndarray:
4041
"""
4142
Score function for feature selection using Wilcoxon rank-sum test.
4243
@@ -98,7 +99,7 @@ def __init__(self, n=10, **kwargs):
9899
super().__init__(**kwargs)
99100
self.n = n
100101

101-
def fit(self, X, y):
102+
def fit(self, X: np.ndarray, y: np.ndarray) -> "BorutaSelector":
102103
"""
103104
Fit the Boruta feature selection algorithm and select top n features.
104105
@@ -118,7 +119,7 @@ def fit(self, X, y):
118119
self.indices = np.argsort(self.ranking_)[: self.n]
119120
return self
120121

121-
def transform(self, X):
122+
def transform(self, X: Union[np.ndarray, pd.DataFrame]) -> Union[np.ndarray, pd.DataFrame]:
122123
"""
123124
Transform data to contain only the selected top n features.
124125
@@ -146,7 +147,7 @@ def transform(self, X):
146147
except:
147148
return X[:, self.indices]
148149

149-
def get_support(self):
150+
def get_support(self) -> np.ndarray:
150151
"""
151152
Get indices of the selected features.
152153
@@ -196,7 +197,7 @@ def __init__(self, ratio=0.9, imputer=SimpleImputer(strategy="mean")):
196197
self.ratio = ratio
197198
self.imputer = imputer
198199

199-
def fit(self, X, y=None):
200+
def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> "missing_value_processing":
200201
"""
201202
Fit the missing value processor by identifying valid features and fitting imputer.
202203
@@ -228,7 +229,7 @@ def fit(self, X, y=None):
228229
else:
229230
raise ValueError("The ratio of valid values should be greater than 0.")
230231

231-
def transform(self, X):
232+
def transform(self, X: pd.DataFrame) -> pd.DataFrame:
232233
"""
233234
Transform data by removing high-missing features and imputing remaining values.
234235
@@ -256,7 +257,7 @@ def transform(self, X):
256257
else:
257258
raise ValueError("The ratio of valid values should be greater than 0.")
258259

259-
def get_support(self):
260+
def get_support(self) -> np.ndarray:
260261
"""
261262
Get indices of features that passed the missing value filter.
262263
@@ -362,7 +363,7 @@ def __init__(
362363
for key, value in kwargs.items():
363364
setattr(self, key, value)
364365

365-
def fit(self, X, y):
366+
def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray]) -> "MESA_modality":
366367
"""
367368
Fit the complete preprocessing pipeline and classifier.
368369
@@ -396,7 +397,7 @@ def fit(self, X, y):
396397
self.classifier = self.classifier.fit(self.pipeline.transform(X), y)
397398
return self
398399

399-
def transform(self, X):
400+
def transform(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
400401
"""
401402
Apply the preprocessing pipeline to data.
402403
@@ -412,7 +413,7 @@ def transform(self, X):
412413
"""
413414
return self.pipeline.transform(X)
414415

415-
def predict(self, X):
416+
def predict(self, X: np.ndarray) -> np.ndarray:
416417
"""
417418
Predict class labels for preprocessed data.
418419
@@ -428,7 +429,7 @@ def predict(self, X):
428429
"""
429430
return self.classifier.predict(X)
430431

431-
def predict_proba(self, X):
432+
def predict_proba(self, X: np.ndarray) -> np.ndarray:
432433
"""
433434
Predict class probabilities for preprocessed data.
434435
@@ -444,7 +445,7 @@ def predict_proba(self, X):
444445
"""
445446
return self.classifier.predict_proba(X)
446447

447-
def transform_predict(self, X):
448+
def transform_predict(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
448449
"""
449450
Apply preprocessing pipeline and predict class labels.
450451
@@ -460,7 +461,7 @@ def transform_predict(self, X):
460461
"""
461462
return self.classifier.predict(self.pipeline.transform(X))
462463

463-
def transform_predict_proba(self, X):
464+
def transform_predict_proba(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
464465
"""
465466
Apply preprocessing pipeline and predict class probabilities.
466467
@@ -476,7 +477,7 @@ def transform_predict_proba(self, X):
476477
"""
477478
return self.classifier.predict_proba(self.pipeline.transform(X))
478479

479-
def get_support(self, step=None):
480+
def get_support(self, step: Optional[int] = None) -> np.ndarray:
480481
"""
481482
Get indices of features selected by pipeline components.
482483
@@ -498,7 +499,7 @@ def get_support(self, step=None):
498499
else:
499500
return self.pipeline[step].get_support(indices=True)
500501

501-
def get_params(self, deep=True):
502+
def get_params(self, deep: bool = True) -> dict:
502503
"""
503504
Get parameters of the MESA_modality instance.
504505
@@ -590,7 +591,7 @@ def __init__(
590591
for key, value in kwargs.items():
591592
setattr(self, key, value)
592593

593-
def _base_fit(self, X, y, base_estimator):
594+
def _base_fit(self, X: np.ndarray, y: Union[pd.Series, np.ndarray], base_estimator: Any) -> np.ndarray:
594595
"""
595596
Generate meta-features using cross-validation for a single modality.
596597
@@ -624,7 +625,7 @@ def _internal_cv(train_index, test_index):
624625
)
625626
return base_probability
626627

627-
def fit(self, X_list, y):
628+
def fit(self, X_list: List[Union[pd.DataFrame, np.ndarray]], y: Union[pd.Series, np.ndarray]) -> "MESA":
628629
"""
629630
Fit all modality estimators and the meta-estimator.
630631
@@ -658,7 +659,7 @@ def fit(self, X_list, y):
658659
self.meta_estimator.fit(base_probability, y_stacking)
659660
return self
660661

661-
def predict(self, X_list_test):
662+
def predict(self, X_list_test: List[Union[pd.DataFrame, np.ndarray]]) -> np.ndarray:
662663
"""
663664
Predict class labels using the fitted ensemble.
664665
@@ -677,7 +678,7 @@ def predict(self, X_list_test):
677678
)
678679
return self.meta_estimator.predict(base_probability_test)
679680

680-
def predict_proba(self, X_list_test):
681+
def predict_proba(self, X_list_test: List[Union[pd.DataFrame, np.ndarray]]) -> np.ndarray:
681682
"""
682683
Predict class probabilities using the fitted ensemble.
683684
@@ -696,7 +697,7 @@ def predict_proba(self, X_list_test):
696697
)
697698
return self.meta_estimator.predict_proba(base_probability_test)
698699

699-
def get_support(self, step=None):
700+
def get_support(self, step: Optional[int] = None) -> List[np.ndarray]:
700701
"""
701702
Get feature support information from all modalities.
702703
@@ -778,14 +779,17 @@ def __init__(
778779

779780
def _cv_iter(
780781
self,
781-
X,
782-
y,
783-
train_index,
784-
test_index,
785-
proba=True,
786-
return_feature_in=False,
787-
mesa=False
788-
):
782+
X: Union[pd.DataFrame, List[pd.DataFrame]],
783+
y: Union[pd.Series, np.ndarray],
784+
train_index: np.ndarray,
785+
test_index: np.ndarray,
786+
proba: bool = True,
787+
return_feature_in: bool = False,
788+
mesa: bool = False
789+
) -> Union[
790+
Tuple[np.ndarray, np.ndarray],
791+
Tuple[np.ndarray, np.ndarray, Any]
792+
]:
789793
"""
790794
Perform a single iteration of cross-validation.
791795
@@ -837,7 +841,7 @@ def _cv_iter(
837841
else:
838842
return y_pred, y_test
839843

840-
def fit(self, X, y):
844+
def fit(self, X: Union[pd.DataFrame, List[pd.DataFrame]], y: Union[pd.Series, np.ndarray]) -> "MESA_CV":
841845
"""
842846
Perform cross-validation on the provided data.
843847
@@ -894,7 +898,7 @@ def fit(self, X, y):
894898
)
895899
return self
896900

897-
def get_performance(self):
901+
def get_performance(self) -> float:
898902
"""
899903
Calculate the mean ROC AUC score across all cross-validation folds.
900904

0 commit comments

Comments
 (0)