Skip to content

Commit 8d2f5b1

Browse files
chore: autopublish 2025-02-12T13:02:03Z
1 parent 8423cc0 commit 8d2f5b1

File tree

3 files changed

+21
-16
lines changed

3 files changed

+21
-16
lines changed

src/energy_forcasting_model/pipelines/catboost_pipeline/nodes.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@
22
import matplotlib.pyplot as plt
33
from catboost import CatBoostRegressor
44
from shap import TreeExplainer, summary_plot
5-
from sklearn.inspection import partial_dependence, PartialDependenceDisplay
5+
from sklearn.inspection import PartialDependenceDisplay
6+
67

78
def train_catboost_model(X_train, y_train, params):
89
"""
910
Trains a CatBoost regression model using the given training data and parameters.
10-
11+
1112
Args:
1213
X_train (DataFrame): Training features.
1314
y_train (Series/DataFrame): Training target.
1415
params (dict): Dictionary containing CatBoost parameters.
15-
16+
1617
Returns:
1718
CatBoostRegressor: Trained CatBoost model.
1819
"""
@@ -52,7 +53,7 @@ def train_catboost_model(X_train, y_train, params):
5253
loss_function=params["loss_function"],
5354
eval_metric=params["eval_metric"],
5455
random_seed=params.get("random_state", 42),
55-
verbose=params.get("verbose_eval", True)
56+
verbose=params.get("verbose_eval", True),
5657
)
5758

5859
# Train the model with an evaluation set for monitoring
@@ -61,7 +62,7 @@ def train_catboost_model(X_train, y_train, params):
6162
X_train_clean,
6263
y_train_clean,
6364
eval_set=[(X_train_clean, y_train_clean)],
64-
verbose=params.get("verbose_eval", True)
65+
verbose=params.get("verbose_eval", True),
6566
)
6667

6768
# Log the completion of the training process
@@ -93,37 +94,34 @@ def explain_catboost_model(model, X_train):
9394
summary_plot(shap_values, X_train, show=False)
9495
fig = plt.gcf() # Get current figure
9596
plt.close(fig)
96-
97+
9798
logger.info("SHAP summary plot created successfully.")
9899
return fig
99100

100101

101102
def plot_partial_dependence_catboost(model, X_train, features):
102103
"""
103104
Generates a partial dependence plot for specified features using the trained CatBoost model.
104-
105+
105106
Args:
106107
model: Trained CatBoost model.
107108
X_train (DataFrame): Training features.
108109
features (list): List of feature names or indices for which to compute partial dependence.
109-
110+
110111
Returns:
111112
matplotlib.figure.Figure: The partial dependence plot figure.
112113
"""
113114
logger = logging.getLogger(__name__)
114115
logger.info("Creating partial dependence plot for CatBoost model...")
115-
116+
116117
# Create the plot using scikit-learn's PartialDependenceDisplay
117118
fig, ax = plt.subplots(figsize=(12, 8))
118119
display = PartialDependenceDisplay.from_estimator(
119-
model,
120-
X_train,
121-
features=features,
122-
ax=ax
120+
model, X_train, features=features, ax=ax
123121
)
124122
ax.set_title("Partial Dependence Plot")
125123
plt.tight_layout()
126124
plt.close(fig)
127-
125+
128126
logger.info("Partial dependence plot created successfully.")
129127
return fig

src/energy_forcasting_model/pipelines/catboost_pipeline/pipeline.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .nodes import (
44
train_catboost_model,
55
explain_catboost_model,
6-
plot_partial_dependence_catboost
6+
plot_partial_dependence_catboost,
77
)
88

99
from ..random_forest_pipeline.nodes import (
@@ -12,6 +12,7 @@
1212
generate_predictions,
1313
)
1414

15+
1516
def create_pipeline(**kwargs) -> Pipeline:
1617
return pipeline(
1718
[
@@ -34,7 +35,12 @@ def create_pipeline(**kwargs) -> Pipeline:
3435
],
3536
outputs="catboost_feature_importance_plot",
3637
name="plot_feature_importance_node",
37-
tags=["feature_importance", "visualization", "catboost", "model_training"],
38+
tags=[
39+
"feature_importance",
40+
"visualization",
41+
"catboost",
42+
"model_training",
43+
],
3844
),
3945
node( # Node 3: Generate Predictions
4046
func=generate_predictions,

src/energy_forcasting_model/pipelines/lightgbm_training_pipeline/nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import lightgbm as lgb
22
import logging
33

4+
45
def train_lightgbm_model(X_train, y_train, params):
56
"""
67
Trains a LightGBM regression model using the given training data and parameters.

0 commit comments

Comments
 (0)