2
2
import matplotlib .pyplot as plt
3
3
from catboost import CatBoostRegressor
4
4
from shap import TreeExplainer , summary_plot
5
- from sklearn .inspection import partial_dependence , PartialDependenceDisplay
5
+ from sklearn .inspection import PartialDependenceDisplay
6
+
6
7
7
8
def train_catboost_model (X_train , y_train , params ):
8
9
"""
9
10
Trains a CatBoost regression model using the given training data and parameters.
10
-
11
+
11
12
Args:
12
13
X_train (DataFrame): Training features.
13
14
y_train (Series/DataFrame): Training target.
14
15
params (dict): Dictionary containing CatBoost parameters.
15
-
16
+
16
17
Returns:
17
18
CatBoostRegressor: Trained CatBoost model.
18
19
"""
@@ -52,7 +53,7 @@ def train_catboost_model(X_train, y_train, params):
52
53
loss_function = params ["loss_function" ],
53
54
eval_metric = params ["eval_metric" ],
54
55
random_seed = params .get ("random_state" , 42 ),
55
- verbose = params .get ("verbose_eval" , True )
56
+ verbose = params .get ("verbose_eval" , True ),
56
57
)
57
58
58
59
# Train the model with an evaluation set for monitoring
@@ -61,7 +62,7 @@ def train_catboost_model(X_train, y_train, params):
61
62
X_train_clean ,
62
63
y_train_clean ,
63
64
eval_set = [(X_train_clean , y_train_clean )],
64
- verbose = params .get ("verbose_eval" , True )
65
+ verbose = params .get ("verbose_eval" , True ),
65
66
)
66
67
67
68
# Log the completion of the training process
@@ -93,37 +94,34 @@ def explain_catboost_model(model, X_train):
93
94
summary_plot (shap_values , X_train , show = False )
94
95
fig = plt .gcf () # Get current figure
95
96
plt .close (fig )
96
-
97
+
97
98
logger .info ("SHAP summary plot created successfully." )
98
99
return fig
99
100
100
101
101
102
def plot_partial_dependence_catboost (model , X_train , features ):
102
103
"""
103
104
Generates a partial dependence plot for specified features using the trained CatBoost model.
104
-
105
+
105
106
Args:
106
107
model: Trained CatBoost model.
107
108
X_train (DataFrame): Training features.
108
109
features (list): List of feature names or indices for which to compute partial dependence.
109
-
110
+
110
111
Returns:
111
112
matplotlib.figure.Figure: The partial dependence plot figure.
112
113
"""
113
114
logger = logging .getLogger (__name__ )
114
115
logger .info ("Creating partial dependence plot for CatBoost model..." )
115
-
116
+
116
117
# Create the plot using scikit-learn's PartialDependenceDisplay
117
118
fig , ax = plt .subplots (figsize = (12 , 8 ))
118
119
display = PartialDependenceDisplay .from_estimator (
119
- model ,
120
- X_train ,
121
- features = features ,
122
- ax = ax
120
+ model , X_train , features = features , ax = ax
123
121
)
124
122
ax .set_title ("Partial Dependence Plot" )
125
123
plt .tight_layout ()
126
124
plt .close (fig )
127
-
125
+
128
126
logger .info ("Partial dependence plot created successfully." )
129
127
return fig
0 commit comments