本项目复现了论文Customer churn prediction using a novel meta-classifier:an investigation on transaction, Telecommunication and customer churn datasets。该项目复现了相关论文的研究成果,通过集成多个基分类器并使用Oracle元分类器来提高客户流失预测的准确性。
- 多种机器学习算法: 支持决策树、随机森林、XGBoost、AdaBoost、ExtraTrees等算法
- 两种超参数优化策略:
- SKB + GridSearchCV (SelectKBest + 网格搜索)
- PFS + BayesSearchCV (Permutation Feature Importance + 贝叶斯搜索)
- 三种模型类型:
- 单一模型 (Single Models)
- 堆叠集成模型 (Stacking Ensemble)
- Oracle元分类器 (Oracle Meta-Classifier)
- 完整的数据预处理流程: 包括缺失值处理、异常值检测、特征编码、类别不平衡处理等
- 全面的评估指标: 包含准确率、精确率、召回率、F1分数、ROC-AUC、Cohen's Kappa、MCC等
项目支持三个数据集:
- Transaction Dataset: 银行客户交易数据
- Telecommunication Dataset: 电信客户数据
- Customer Churn Dataset: 客户流失数据
Customer churn prediction using a novel meta-classifier/
├── data/ # 数据集文件夹
│ ├── customer_churn.xlsx
│ ├── telecommunication.csv
│ └── transaction.csv
├── src/ # 源代码文件夹
│ ├── main.py # 主程序入口
│ ├── train.py # 模型训练模块
│ ├── data_preprocess.py # 数据预处理模块
│ ├── hyperparameter_optimization.py # 超参数优化模块
│ ├── utils.py # 工具函数
│ └── test.py # 测试模块
├── results/ # 结果文件夹
│ ├── hyperparameter/ # 超参数优化结果
│ ├── metrics/ # 评估指标结果
│ └── models/ # 训练好的模型
├── README.md # 项目说明文档
├── requirements.txt # 依赖包列表
├── .gitignore # Git忽略文件
└── LICENSE # 开源许可证
pip install -r requirements.txt首先需要对各个模型进行超参数优化:
from src.hyperparameter_optimization import HyperparameterOptimization
# 对transaction数据集进行SKB+GridSearch优化
hpo = HyperparameterOptimization('transaction', 'label_one_hot', 'skb_grid')
hpo.save_hyperparameter_params()
# 对transaction数据集进行PFS+BayesSearch优化
hpo = HyperparameterOptimization('transaction', 'label_one_hot', 'pfs_bayes')
hpo.save_hyperparameter_params()训练不同类型的模型:
from src.train import Train
# 训练单一模型
trainer = Train('transaction', 'label_one_hot', 'single', 'skb_grid')
trainer.train()
# 训练堆叠集成模型
trainer = Train('transaction', 'label_one_hot', 'stack_ensemble', 'skb_grid')
trainer.train()
# 训练Oracle元分类器
trainer = Train('transaction', 'label_one_hot', 'oracle_ensemble', 'skb_grid')
trainer.train()运行主程序进行完整的实验:
from src.main import dataset_main
# 运行customer_churn数据集的完整实验
results = dataset_main('customer_churn', iter_nums=10)本项目实现的核心算法是基于Oracle的元分类器,该算法:
- 训练多个基分类器(决策树、随机森林、XGBoost等)
- 使用Oracle分类器作为理想情况下的元分类器
- Oracle分类器理论上能够为每个测试样本选择最优的基分类器,基于下述公式(公式的含义详见论文)进行选择:
- 这种方法在理论上能够达到所有集成分类器性能的上限
动态选择集成是一种集成策略,通过为每个待分类样本动态选择最优的单个分类器或最优的几个分类器组合来进行预测,能够根据样本特征和分类器局部性能实现个性化选择。本项目使用的动态集成策略为:
- 对于每个测试样本j都在验证集上找到K个与其最相似的验证集样本,本项目使用的是KDTree(还可以使用KNN或聚类算法)
- 评估分类器池中的各个分类器在这K个样本上的得分,得分计算公式(公式的含义详见论文)如下:
- 选择得分最高的分类器预测测试样本j的标签
- SKB (SelectKBest): 使用F统计量进行特征选择
- PFS (Permutation Feature Importance): 使用排列重要性进行特征选择
- GridSearchCV: 网格搜索交叉验证
- BayesSearchCV: 贝叶斯搜索交叉验证
项目使用多种评估指标来全面评估模型性能:
- 基础指标: 准确率、精确率、召回率、F1分数
- ROC指标: ROC-AUC
- 一致性指标: Cohen's Kappa、MCC
- 混淆矩阵指标: TPR、TNR、PPV、NPV、FPR、FNR、FDR、FOR
| 模型 | 超参数优化 | 准确率 | 精确率 | 召回率 | F1分数 | ROC-AUC | Cohen's Kappa | TPR | TNR | PPV | NPV | FPR | FNR | FDR | FOR | MCC |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| AdaBoostClassifier | PFS+Bayes | 0.5001 | 0.4948 | 0.4783 | 0.4864 | 0.5032 | -0.0003 | 0.4783 | 0.5214 | 0.4948 | 0.5049 | 0.4786 | 0.5217 | 0.5052 | 0.4951 | -0.0003 |
| AdaBoostClassifier | SKB+Grid | 0.5033 | 0.4943 | 0.1587 | 0.2402 | 0.5043 | -0.0005 | 0.1587 | 0.8409 | 0.4943 | 0.5050 | 0.1591 | 0.8413 | 0.5057 | 0.4950 | -0.0005 |
| DecisionTreeClassifier | PFS+Bayes | 0.5007 | 0.4951 | 0.4397 | 0.4657 | 0.4993 | 0.0002 | 0.4397 | 0.5605 | 0.4951 | 0.5052 | 0.4395 | 0.5603 | 0.5049 | 0.4948 | 0.0002 |
| DecisionTreeClassifier | SKB+Grid | 0.5011 | 0.4952 | 0.3940 | 0.4235 | 0.4998 | 0.0000 | 0.3940 | 0.6060 | 0.4952 | 0.5049 | 0.3940 | 0.6060 | 0.5048 | 0.4951 | 0.0000 |
| ExtraTreesClassifier | PFS+Bayes | 0.5041 | 0.4978 | 0.2160 | 0.3011 | 0.5037 | 0.0025 | 0.2160 | 0.7864 | 0.4978 | 0.5059 | 0.2136 | 0.7840 | 0.5022 | 0.4941 | 0.0025 |
| ExtraTreesClassifier | SKB+Grid | 0.5041 | 0.4975 | 0.2057 | 0.2908 | 0.5047 | 0.0021 | 0.2057 | 0.7965 | 0.4975 | 0.5057 | 0.2035 | 0.7943 | 0.5025 | 0.4943 | 0.0021 |
| RandomForestClassifier | PFS+Bayes | 0.5001 | 0.4935 | 0.3860 | 0.4331 | 0.4990 | -0.0021 | 0.3860 | 0.6119 | 0.4935 | 0.5042 | 0.3881 | 0.6140 | 0.5065 | 0.4958 | -0.0021 |
| RandomForestClassifier | SKB+Grid | 0.5011 | 0.4956 | 0.4538 | 0.4738 | 0.4995 | 0.0012 | 0.4538 | 0.5474 | 0.4956 | 0.5056 | 0.4526 | 0.5462 | 0.5044 | 0.4944 | 0.0012 |
| XGBClassifier | PFS+Bayes | 0.5046 | 0.4995 | 0.4933 | 0.4964 | 0.5018 | 0.0089 | 0.4933 | 0.5156 | 0.4995 | 0.5094 | 0.4844 | 0.5067 | 0.5005 | 0.4906 | 0.0089 |
| XGBClassifier | SKB+Grid | 0.5014 | 0.4963 | 0.4899 | 0.4931 | 0.4982 | 0.0026 | 0.4899 | 0.5127 | 0.4963 | 0.5064 | 0.4873 | 0.5101 | 0.5037 | 0.4936 | 0.0026 |
| Dynamic Selection Ensemble | PFS+Bayes | 0.5006 | 0.4950 | 0.4452 | 0.4687 | 0.5003 | 0.0001 | 0.4452 | 0.5549 | 0.4950 | 0.5051 | 0.4451 | 0.5548 | 0.5050 | 0.4949 | 0.0001 |
| Dynamic Selection Ensemble | SKB+Grid | 0.5014 | 0.4954 | 0.3887 | 0.4355 | 0.4976 | 0.0006 | 0.3887 | 0.6119 | 0.4954 | 0.5053 | 0.3881 | 0.6113 | 0.5046 | 0.4947 | 0.0006 |
| Stack Ensemble | PFS+Bayes | 0.4987 | 0.3963 | 0.6050 | 0.4587 | 0.4994 | -0.0004 | 0.6050 | 0.3946 | 0.3963 | 0.4532 | 0.6054 | 0.3950 | 0.6037 | 0.5468 | -0.0004 |
| Stack Ensemble | SKB+Grid | 0.4982 | 0.4039 | 0.6456 | 0.4487 | 0.4997 | -0.0006 | 0.6456 | 0.3538 | 0.4039 | 0.3736 | 0.6462 | 0.3544 | 0.5961 | 0.6264 | -0.0006 |
| Oracle Ensemble | PFS+Bayes | 0.9064 | 0.9439 | 0.8620 | 0.9011 | 0.9830 | 0.8126 | 0.8620 | 0.9498 | 0.9439 | 0.8754 | 0.0502 | 0.1380 | 0.0561 | 0.1246 | 0.8126 |
| Oracle Ensemble | SKB+Grid | 0.8963 | 0.9847 | 0.8028 | 0.8836 | 0.9908 | 0.7922 | 0.8028 | 0.9879 | 0.9847 | 0.8383 | 0.0121 | 0.1972 | 0.0153 | 0.1617 | 0.7922 |
| 模型 | 超参数优化 | 准确率 | 精确率 | 召回率 | F1分数 | ROC-AUC | Cohen's Kappa | TPR | TNR | PPV | NPV | FPR | FNR | FDR | FOR | MCC |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| AdaBoostClassifier | PFS+Bayes | 0.7935 | 0.6443 | 0.4701 | 0.5435 | 0.8450 | 0.4143 | 0.4701 | 0.9081 | 0.6443 | 0.8287 | 0.0919 | 0.5299 | 0.3557 | 0.1713 | 0.4143 |
| AdaBoostClassifier | SKB+Grid | 0.7879 | 0.6182 | 0.4936 | 0.5489 | 0.8402 | 0.4125 | 0.4936 | 0.8920 | 0.6182 | 0.8326 | 0.1080 | 0.5064 | 0.3818 | 0.1674 | 0.4125 |
| DecisionTreeClassifier | PFS+Bayes | 0.7809 | 0.6080 | 0.4701 | 0.5233 | 0.8193 | 0.3860 | 0.4701 | 0.8909 | 0.6080 | 0.8274 | 0.1091 | 0.5299 | 0.3920 | 0.1726 | 0.3860 |
| DecisionTreeClassifier | SKB+Grid | 0.7741 | 0.5920 | 0.4494 | 0.5081 | 0.8136 | 0.3658 | 0.4494 | 0.8891 | 0.5920 | 0.8205 | 0.1109 | 0.5506 | 0.4080 | 0.1795 | 0.3658 |
| ExtraTreesClassifier | PFS+Bayes | 0.7949 | 0.6447 | 0.4804 | 0.5505 | 0.8307 | 0.4213 | 0.4804 | 0.9062 | 0.6447 | 0.8312 | 0.0938 | 0.5196 | 0.3553 | 0.1688 | 0.4213 |
| ExtraTreesClassifier | SKB+Grid | 0.7889 | 0.6288 | 0.4704 | 0.5382 | 0.8338 | 0.4050 | 0.4704 | 0.9017 | 0.6288 | 0.8278 | 0.0983 | 0.5296 | 0.3712 | 0.1722 | 0.4050 |
| RandomForestClassifier | PFS+Bayes | 0.7952 | 0.6349 | 0.5107 | 0.5660 | 0.8364 | 0.4341 | 0.5107 | 0.8960 | 0.6349 | 0.8380 | 0.1040 | 0.4893 | 0.3651 | 0.1620 | 0.4341 |
| RandomForestClassifier | SKB+Grid | 0.7854 | 0.6270 | 0.4428 | 0.5189 | 0.8385 | 0.3861 | 0.4428 | 0.9067 | 0.6270 | 0.8213 | 0.0933 | 0.5572 | 0.3730 | 0.1787 | 0.3861 |
| XGBClassifier | PFS+Bayes | 0.7921 | 0.6402 | 0.4682 | 0.5409 | 0.8364 | 0.4107 | 0.4682 | 0.9068 | 0.6402 | 0.8281 | 0.0932 | 0.5318 | 0.3598 | 0.1719 | 0.4107 |
| XGBClassifier | SKB+Grid | 0.7893 | 0.6302 | 0.4701 | 0.5385 | 0.8395 | 0.4057 | 0.4701 | 0.9023 | 0.6302 | 0.8278 | 0.0977 | 0.5299 | 0.3698 | 0.1722 | 0.4057 |
| Dynamic Selection Ensemble | PFS+Bayes | 0.7922 | 0.6345 | 0.4846 | 0.5494 | 0.8352 | 0.4175 | 0.4846 | 0.9012 | 0.6345 | 0.8316 | 0.0988 | 0.5154 | 0.3655 | 0.1684 | 0.4175 |
| Dynamic Selection Ensemble | SKB+Grid | 0.7887 | 0.6253 | 0.4788 | 0.5423 | 0.8361 | 0.4080 | 0.4788 | 0.8984 | 0.6253 | 0.8296 | 0.1016 | 0.5212 | 0.3747 | 0.1704 | 0.4080 |
| Stack Ensemble | PFS+Bayes | 0.7991 | 0.6391 | 0.5339 | 0.5815 | 0.8421 | 0.4508 | 0.5339 | 0.8931 | 0.6391 | 0.8441 | 0.1069 | 0.4661 | 0.3609 | 0.1559 | 0.4508 |
| Stack Ensemble | SKB+Grid | 0.7924 | 0.6239 | 0.5194 | 0.5663 | 0.8399 | 0.4315 | 0.5194 | 0.8890 | 0.6239 | 0.8394 | 0.1110 | 0.4806 | 0.3761 | 0.1606 | 0.4315 |
| Oracle Ensemble | PFS+Bayes | 0.8603 | 0.8066 | 0.6156 | 0.6974 | 0.9527 | 0.6087 | 0.6156 | 0.9469 | 0.8066 | 0.8744 | 0.0531 | 0.3844 | 0.1934 | 0.1256 | 0.6087 |
| Oracle Ensemble | SKB+Grid | 0.8528 | 0.7938 | 0.5915 | 0.6774 | 0.9436 | 0.5848 | 0.5915 | 0.9454 | 0.7938 | 0.8673 | 0.0546 | 0.4085 | 0.2062 | 0.1327 | 0.5848 |
| 模型 | 超参数优化 | 准确率 | 精确率 | 召回率 | F1分数 | ROC-AUC | Cohen's Kappa | TPR | TNR | PPV | NPV | FPR | FNR | FDR | FOR | MCC |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| AdaBoostClassifier | PFS+Bayes | 0.7418 | 0.4302 | 0.8272 | 0.5660 | 0.8543 | 0.4072 | 0.8272 | 0.7199 | 0.4302 | 0.9422 | 0.2801 | 0.1728 | 0.5698 | 0.0578 | 0.4072 |
| AdaBoostClassifier | SKB+Grid | 0.7344 | 0.4226 | 0.8324 | 0.5606 | 0.8537 | 0.3981 | 0.8324 | 0.7094 | 0.4226 | 0.9431 | 0.2906 | 0.1676 | 0.5774 | 0.0569 | 0.3981 |
| DecisionTreeClassifier | PFS+Bayes | 0.7460 | 0.4251 | 0.7024 | 0.5296 | 0.7298 | 0.3698 | 0.7024 | 0.7572 | 0.4251 | 0.9088 | 0.2428 | 0.2976 | 0.5749 | 0.0912 | 0.3698 |
| DecisionTreeClassifier | SKB+Grid | 0.7525 | 0.4341 | 0.7069 | 0.5377 | 0.7355 | 0.3818 | 0.7069 | 0.7642 | 0.4341 | 0.9108 | 0.2358 | 0.2931 | 0.5659 | 0.0892 | 0.3818 |
| ExtraTreesClassifier | PFS+Bayes | 0.7685 | 0.4585 | 0.7596 | 0.5718 | 0.8435 | 0.4262 | 0.7596 | 0.7708 | 0.4585 | 0.9262 | 0.2292 | 0.2404 | 0.5415 | 0.0738 | 0.4262 |
| ExtraTreesClassifier | SKB+Grid | 0.7804 | 0.4746 | 0.7363 | 0.5771 | 0.8447 | 0.4380 | 0.7363 | 0.7916 | 0.4746 | 0.9216 | 0.2084 | 0.2637 | 0.5254 | 0.0784 | 0.4380 |
| RandomForestClassifier | PFS+Bayes | 0.7732 | 0.4636 | 0.7246 | 0.5653 | 0.8412 | 0.4218 | 0.7246 | 0.7857 | 0.4636 | 0.9178 | 0.2143 | 0.2754 | 0.5364 | 0.0822 | 0.4218 |
| RandomForestClassifier | SKB+Grid | 0.7834 | 0.4795 | 0.7454 | 0.5835 | 0.8524 | 0.4464 | 0.7454 | 0.7931 | 0.4795 | 0.9242 | 0.2069 | 0.2546 | 0.5205 | 0.0758 | 0.4464 |
| XGBClassifier | PFS+Bayes | 0.7846 | 0.4800 | 0.6928 | 0.5670 | 0.8379 | 0.4299 | 0.6928 | 0.8081 | 0.4800 | 0.9115 | 0.1919 | 0.3072 | 0.5200 | 0.0885 | 0.4299 |
| XGBClassifier | SKB+Grid | 0.7983 | 0.5032 | 0.7176 | 0.5916 | 0.8500 | 0.4631 | 0.7176 | 0.8190 | 0.5032 | 0.9190 | 0.1810 | 0.2824 | 0.4968 | 0.0810 | 0.4631 |
| Dynamic Selection Ensemble | PFS+Bayes | 0.7881 | 0.4864 | 0.7312 | 0.5841 | 0.8465 | 0.4496 | 0.7312 | 0.8027 | 0.4864 | 0.9212 | 0.1973 | 0.2688 | 0.5136 | 0.0788 | 0.4496 |
| Dynamic Selection Ensemble | SKB+Grid | 0.7851 | 0.4814 | 0.7224 | 0.5777 | 0.8186 | 0.4413 | 0.7224 | 0.8011 | 0.4814 | 0.9187 | 0.1989 | 0.2776 | 0.5186 | 0.0813 | 0.4413 |
| Stack Ensemble | PFS+Bayes | 0.7971 | 0.5012 | 0.7225 | 0.5917 | 0.8489 | 0.4625 | 0.7225 | 0.8161 | 0.5012 | 0.9201 | 0.1839 | 0.2775 | 0.4988 | 0.0799 | 0.4625 |
| Stack Ensemble | SKB+Grid | 0.7939 | 0.4960 | 0.7176 | 0.5864 | 0.8501 | 0.4552 | 0.7176 | 0.8135 | 0.4960 | 0.9185 | 0.1865 | 0.2824 | 0.5040 | 0.0815 | 0.4552 |
| Oracle Ensemble | PFS+Bayes | 0.8970 | 0.6920 | 0.8897 | 0.7785 | 0.9821 | 0.7127 | 0.8897 | 0.8988 | 0.6920 | 0.9696 | 0.1012 | 0.1103 | 0.3080 | 0.0304 | 0.7127 |
| Oracle Ensemble | SKB+Grid | 0.8889 | 0.6727 | 0.8842 | 0.7641 | 0.9781 | 0.6931 | 0.8842 | 0.8901 | 0.6727 | 0.9678 | 0.1099 | 0.1158 | 0.3273 | 0.0322 | 0.6931 |
-
Oracle元分类器性能最优:在所有三个数据集上,Oracle元分类器都表现出最佳性能,因为这是所有基分类器集成所能达到的理论上限。
-
数据集难度差异:
- Customer Churn数据集:最具挑战性,所有模型性能都接近随机水平(准确率~50%),只有Oracle元分类器能达到90%+的准确率
- Telecommunication数据集:单一模型准确率在77-80%之间
- Transaction数据集:单一模型准确率在73-80%之间
-
超参数优化策略比较:
- PFS+Bayes和SKB+Grid两种策略在不同数据集上表现各有优势
- 对于Customer Churn数据集,两种策略差异不大
- 对于Telecommunication和Transaction数据集,两种策略都能有效提升模型性能
-
模型性能排序(基于平均准确率):
- Oracle元分类器 > Stack Ensemble > Dynamic Selection Ensemble > 单一模型
- 单一模型中:XGBClassifier > RandomForestClassifier > ExtraTreesClassifier > DecisionTreeClassifier > AdaBoostClassifier
-
ROC-AUC表现:
- Oracle元分类器在所有数据集上都达到了0.95+的ROC-AUC值
- 单一模型的ROC-AUC值在0.73-0.85之间
- 集成模型相比单一模型有显著提升
完整的数据预处理流程包括:
- 删除不必要的列
- 处理缺失值
- 删除重复值
- 异常值检测和处理(IQR方法)
- 特征编码(标签编码和独热编码)
- 数据标准化
- 类别不平衡处理(SMOTEENN)
- 本项目复现了相关论文的研究成果
- Oracle分类器是理想情况下的分类器,在实际部署中无法使用
- 建议在运行完整实验前先进行超参数优化
详见 requirements.txt 文件。
本项目仅供学习和研究使用。
欢迎提交Issue和Pull Request来改进本项目。
如有问题,请通过GitHub Issues联系。