
Commit 9898a3b

Improve pipeline feature naming and categorical grouping support (#344)
* Improve pipeline feature naming and categorical grouping
  - add strip_pipeline_prefix/feature_name_fn support for pipeline-transformed feature names
  - add auto_detect_pipeline_cats inference for onehot-expanded pipeline columns
  - accept binary-like scaled onehot columns in parse_cats validation
  - preserve index in transformed pipeline dataframes and improve pipeline fallback warning
  - add helper unit tests and extend pipeline tests; update README/docs/release notes

  Refs #213

* Update TODO for #213 pipeline support improvements
1 parent 85660c0 commit 9898a3b

File tree

10 files changed (+798, -249 lines)

README.md

Lines changed: 16 additions & 0 deletions
@@ -168,6 +168,18 @@ db = ExplainerDashboard(explainer,
db.run(port=8050)
```

+If you are passing an sklearn/imblearn `Pipeline`, you can also clean up transformed
+feature names and let the explainer infer onehot groups automatically:
+
+```python
+explainer = ClassifierExplainer(
+    pipeline_model, X_test, y_test,
+    strip_pipeline_prefix=True,        # e.g. "num__Age" -> "Age"
+    feature_name_fn=None,              # optional custom rename function
+    auto_detect_pipeline_cats=True,    # infer cats from transformed pipeline output
+)
+```
+
For a regression model you can also pass the units of the target variable (e.g.
dollars):

@@ -184,6 +196,10 @@ explainer = RegressionExplainer(model, X_test, y_test,
ExplainerDashboard(explainer).run()
```

+For pipeline-based models with post-processing/scaling, grouped categorical
+features passed through `cats` are now accepted as long as encoded columns are
+binary-like (not strictly only `0/1`).
+
`y_test` is actually optional, although some parts of the dashboard like performance
metrics will obviously not be available: `ExplainerDashboard(ClassifierExplainer(model, X_test)).run()`.
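
To illustrate the relaxed `cats` validation above, here is a minimal, hypothetical sketch with synthetic data: the pipeline layout and column names are assumptions, but the point is that the onehot columns end up scaled (binary-like rather than strictly `0/1`) and can still be grouped:

```python
# Minimal sketch (synthetic data; exact grouping behaviour depends on the
# transformed column names your pipeline produces).
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from explainerdashboard import ClassifierExplainer

rng = np.random.default_rng(0)
X = pd.DataFrame({
    "Age": rng.integers(1, 80, 200),
    "Fare": rng.uniform(5, 500, 200),
    "Sex": rng.choice(["male", "female"], 200),
    "Embarked": rng.choice(["C", "Q", "S"], 200),
})
y = pd.Series(rng.integers(0, 2, 200), name="Survival")

preprocess = ColumnTransformer([
    ("num", StandardScaler(), ["Age", "Fare"]),
    ("cat", Pipeline([
        ("onehot", OneHotEncoder(handle_unknown="ignore")),
        # scaling after onehot makes the encoded columns "binary-like"
        # rather than strictly 0/1
        ("scale", StandardScaler(with_mean=False)),
    ]), ["Sex", "Embarked"]),
])
pipeline_model = Pipeline([
    ("preprocess", preprocess),
    ("model", RandomForestClassifier(n_estimators=50, max_depth=5)),
]).fit(X, y)

explainer = ClassifierExplainer(
    pipeline_model, X, y,
    strip_pipeline_prefix=True,   # e.g. "cat__Sex_male" -> "Sex_male"
    cats=["Sex", "Embarked"],     # grouped despite the scaled (binary-like) encoding
)
```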

RELEASE_NOTES.md

Lines changed: 9 additions & 0 deletions
@@ -16,6 +16,15 @@
- Add CatBoost regression tests for classifier/regression `pdp_df(...)` with `X_row` containing missing categorical values.
- Add hub regression test for integrated hub yaml serialization to verify `pickle_type` is preserved and explainer artifacts are written.
- Add regression tests for issue #294 covering multiclass logodds consistency across prediction table, contributions, PDP highlight predictions, and XGBoost decision-path summaries.
+- Add pipeline tests for transformed feature-name cleanup (`strip_pipeline_prefix`, `feature_name_fn`) and pipeline categorical grouping autodetection.
+- Add explainer-method unit tests for binary-like onehot detection, transformed feature-name deduping, inferred pipeline cats, and pipeline extraction warning text.
+
+### Improvements
+- Add pipeline feature-name cleanup options: `strip_pipeline_prefix=True` and `feature_name_fn=...` for sklearn/imblearn pipeline transformed output columns.
+- Add optional `auto_detect_pipeline_cats=True` to infer onehot groups from transformed pipeline columns when `cats` is not provided.
+- Preserve input index in transformed pipeline dataframes produced during pipeline extraction.
+- Improve pipeline extraction warning guidance and include concrete checks (`get_feature_names_out`, transform compatibility on `X`/`X_background`).
+- Relax onehot grouping validation to also accept binary-like scaled onehot columns (not only strict `0/1`) when parsing `cats`.

### CI
- Update `explainerdashboard` GitHub Actions workflow to run a weekly scheduled full test suite (`pytest`) to detect dependency breakages earlier.
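
As a rough sketch of the `feature_name_fn` option (the exact call signature is an assumption here: a callable that receives each transformed column name and returns the display name), reusing `pipeline_model`, `X_test` and `y_test` from the README example above:

```python
from explainerdashboard import ClassifierExplainer

def clean_name(col: str) -> str:
    # hypothetical cleanup: "preprocess__cat__Embarked_S" -> "Embarked_S"
    return col.split("__")[-1]

# pipeline_model, X_test, y_test as in the README example above
explainer = ClassifierExplainer(
    pipeline_model, X_test, y_test,
    feature_name_fn=clean_name,
)
```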

TODO.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
- [M][Explainers][#198/#340] LightGBM string categorical handling across SHAP/plots.
- [S][Hub][#146/#342] hub.to_yaml integrate_dashboard_yamls honors pickle_type and dumps integrated explainer artifacts.
- [M][Explainers][#294] align/explain multiclass logodds between Contributions Plot and Prediction Box (+ PDP highlight and XGBoost decision path wording alignment).
+- [M][Explainers/Methods/Docs][#213] improve sklearn/imblearn pipeline support: feature-name cleanup (`strip_pipeline_prefix`, `feature_name_fn`), auto-detect onehot groups (`auto_detect_pipeline_cats`), accept binary-like scaled onehot columns in `cats`, preserve transformed index, add warnings/docs/tests.

**Now**
- [M][Explainers][#118] add LightGBM tree visualization support (dtreeviz).

docs/source/deployment.rst

Lines changed: 68 additions & 65 deletions
@@ -3,31 +3,31 @@ Deployment

When deploying your dashboard it is better not to use the built-in flask
development server but to use a more robust production server like ``gunicorn`` or ``waitress``.
`gunicorn <https://gunicorn.org/>`_ is probably a bit more fully featured and
faster, but it only works on unix/linux/osx, whereas
`waitress <https://docs.pylonsproject.org/projects/waitress/en/stable/>`_ also works
on Windows and has very minimal dependencies.

Install with either ``pip install gunicorn`` or ``pip install waitress``.

Storing explainer and running default dashboard with gunicorn
=============================================================

Before you start a dashboard with gunicorn you need to store both the explainer
instance and a configuration for the dashboard::

    from explainerdashboard import ClassifierExplainer, ExplainerDashboard

    explainer = ClassifierExplainer(model, X, y)
    db = ExplainerDashboard(explainer, title="Cool Title", shap_interaction=False)
    db.to_yaml("dashboard.yaml", explainerfile="explainer.joblib", dump_explainer=True)

Now you re-load your dashboard and expose a flask server as ``app`` in ``dashboard.py``::

    from explainerdashboard import ExplainerDashboard

    db = ExplainerDashboard.from_config("dashboard.yaml")
    app = db.flask_server()


.. highlight:: bash
@@ -36,13 +36,13 @@ If you named the file above ``dashboard.py``, you can now start the gunicorn ser

    $ gunicorn dashboard:app

If you want to run the server with, for example, three workers, binding to
port ``8050``, you launch gunicorn with::

    $ gunicorn -w 3 -b localhost:8050 dashboard:app

If you now point your browser to ``http://localhost:8050`` you should see your dashboard.
The next step is finding a nice url in your organization's domain, and forwarding it
to your dashboard server.

With waitress you would call::
@@ -70,19 +70,19 @@ You need to pass the Flask ``server`` instance and the ``url_base_pathname`` to
under ``db.app.index``::

    from flask import Flask

    app = Flask(__name__)

    [...]

    db = ExplainerDashboard(explainer, server=app, url_base_pathname="/dashboard/")

    @app.route('/dashboard')
    def return_dashboard():
        return db.app.index()


.. highlight:: bash

Now you can start the dashboard by::

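
As an aside, if gunicorn is not an option (for example on Windows), the same combined Flask app could be served with waitress instead; a minimal sketch:

```python
# Sketch: serve the combined Flask app defined above with waitress
# (e.g. on Windows, where gunicorn is not available).
from waitress import serve

if __name__ == "__main__":
    serve(app, host="0.0.0.0", port=8050)
```
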
@@ -95,12 +95,12 @@ Deploying to heroku
===================

You may want to deploy to `heroku <www.heroku.com>`_, which is normally
the simplest option for dash apps (see the
`dash instructions here <https://dash.plotly.com/deployment>`_). The demonstration
dashboard is also hosted on heroku at `titanicexplainer.herokuapp.com <http://titanicexplainer.herokuapp.com>`_.

In order to deploy to heroku there are a few things to keep in mind. First of
all you need to add ``explainerdashboard`` and ``gunicorn`` to
``requirements.txt`` (pinning is recommended to force a new build of your environment
whenever you upgrade versions)::

@@ -112,8 +112,8 @@ your explainer in ``runtime.txt``::

    python-3.8.6

(supported versions as of this writing are ``python-3.9.0``, ``python-3.8.6``,
``python-3.7.9`` and ``python-3.6.12``, but check the
`heroku documentation <https://devcenter.heroku.com/articles/python-support#supported-runtimes>`_
for the latest)

@@ -126,10 +126,10 @@ And you need to tell heroku how to start your server in ``Procfile``::
Graphviz buildpack
------------------

If you want to visualize individual trees inside your ``RandomForest`` or ``xgboost``
model using the ``dtreeviz`` package you will
need to make sure that ``graphviz`` is installed on your ``heroku`` dyno by
adding the following buildpack (as well as the ``python`` buildpack):
``https://github.com/weibeld/heroku-buildpack-graphviz.git``

(you can add buildpacks through the "settings" page of your heroku project)
@@ -150,11 +150,17 @@ E.g. **generate_dashboard.py**::
    X_train, y_train, X_test, y_test = titanic_survive()
    model = RandomForestClassifier(n_estimators=50, max_depth=5).fit(X_train, y_train)

    explainer = ClassifierExplainer(model, X_test, y_test,
                    cats=["Sex", 'Deck', 'Embarked'],
                    labels=['Not Survived', 'Survived'],
                    descriptions=feature_descriptions)

+    # For sklearn/imblearn pipeline models you can alternatively use:
+    # explainer = ClassifierExplainer(
+    #     pipeline_model, X_test, y_test,
+    #     strip_pipeline_prefix=True,
+    #     auto_detect_pipeline_cats=True)
+
    db = ExplainerDashboard(explainer)
    db.to_yaml("dashboard.yaml", explainerfile="explainer.joblib", dump_explainer=True)

@@ -193,45 +199,45 @@ Reducing memory usage

If you deploy the dashboard with a large dataset with a large number of rows (``n``)
and a large number of columns (``m``),
it can use up quite a bit of memory: the dataset itself, shap values,
shap interaction values and any other calculated properties are all kept in
memory in order to make the dashboard responsive. You can check the (approximate)
memory usage with ``explainer.memory_usage()``. In order to reduce the memory
footprint there are a number of things you can do:

1. Not including the shap interaction tab.
   Shap interaction values are of shape ``n*m*m``, so they can take a substantial amount
   of memory, especially if you have a significant number of columns ``m``.
2. Setting a lower precision.
   By default shap values are stored as ``'float64'``,
   but you can store them as ``'float32'`` instead and save half the space:
   ``ClassifierExplainer(model, X_test, y_test, precision='float32')``. You
   can also set a lower precision on your ``X_test`` dataset yourself of course.
3. Dropping non-positive class shap values.
   For multiclass classifiers, by default ``ClassifierExplainer`` calculates
   shap values for all classes. If you are only interested in a single class
   you can drop the other shap values with ``explainer.keep_shap_pos_label_only(pos_label)``.
4. Storing row data externally and loading it on the fly (see the sketch below).
   You can for example only store a subset of ``10.000`` rows in
   the ``explainer`` itself (enough to generate representative importance and dependence plots),
   and store the rest of your millions of rows of input data in an external file
   or database that gets loaded one row at a time with the following functions:

   - with ``explainer.set_X_row_func()`` you can set a function that takes
     an ``index`` as argument and returns a single row dataframe with model
     compatible input data for that index. This function can include a query
     to a database or a file read.
   - with ``explainer.set_y_func()`` you can set a function that takes
     an ``index`` as argument and returns the observed outcome ``y`` for
     that index.
   - with ``explainer.set_index_list_func()`` you can set a function
     that returns a list of available indexes that can be queried.

   If the number of indexes is too large to fit in a dropdown you can pass
   ``index_dropdown=False``, which turns the dropdowns into free text fields.
   Instead of an ``index_list_func`` you can also set an
   ``explainer.set_index_check_func(func)`` which should return a bool indicating whether
   the ``index`` exists or not.

Important: these functions can be called multiple times by multiple independent
components, so it is probably best to implement some kind of caching functionality.
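
A rough illustration of point 4: a sketch that serves rows from an external CSV file (the file name and column layout are made up, caching is added per the note above, and the exact expectations of these setter functions should be checked against the explainer API):

```python
# Rough sketch: serve individual rows from an external CSV instead of keeping
# all rows in the explainer. "all_data.csv" and its "Name"/"Survival" columns
# are hypothetical; in practice you would query a database instead.
from functools import lru_cache
import pandas as pd

@lru_cache(maxsize=1000)
def get_X_row(index):
    # returns a single-row dataframe of model-compatible input data
    df = pd.read_csv("all_data.csv", index_col="Name")
    return df.loc[[index]].drop(columns=["Survival"])

@lru_cache(maxsize=1000)
def get_y(index):
    # returns the observed outcome for that index
    df = pd.read_csv("all_data.csv", index_col="Name")
    return df.loc[index, "Survival"]

def get_index_list():
    # returns the list of available indexes
    return pd.read_csv("all_data.csv", usecols=["Name"])["Name"].tolist()

explainer.set_X_row_func(get_X_row)      # explainer as built above
explainer.set_y_func(get_y)
explainer.set_index_list_func(get_index_list)
```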
@@ -242,22 +248,22 @@ footprint there are a number of things you can do:
Setting logins and password
===========================

``ExplainerDashboard`` supports `dash basic auth functionality <https://dash.plotly.com/authentication>`_.
``ExplainerHub`` uses ``flask_simple_login`` for its user authentication.

You can simply add a list of logins to the ``ExplainerDashboard`` to force a login
and prevent random users from accessing the details of your model dashboard::

    ExplainerDashboard(explainer, logins=[['login1', 'password1'], ['login2', 'password2']]).run()

:ref:`ExplainerHub<ExplainerHub>` has somewhat more intricate user management
using ``FlaskLogin``, but the basic syntax is the same. See the
:ref:`ExplainerHub documentation<ExplainerHub>` for more details::

    hub = ExplainerHub([db1, db2], logins=[['login1', 'password1'], ['login2', 'password2']])

Make sure not to check these login/password pairs into version control though,
but store them somewhere safe! ``ExplainerHub`` stores passwords in a hashed
format by default.

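
A small, hypothetical sketch of keeping those credentials out of version control by reading them from environment variables (the variable names are made up; a secrets manager would work just as well):

```python
# Illustration only: ADMIN_USER / ADMIN_PASSWORD are hypothetical
# environment variable names.
import os
from explainerdashboard import ExplainerDashboard

logins = [[os.environ["ADMIN_USER"], os.environ["ADMIN_PASSWORD"]]]
ExplainerDashboard(explainer, logins=logins).run()   # explainer as built above
```
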
@@ -266,20 +272,20 @@ Automatically restart gunicorn server upon changes

We can use the ``explainerdashboard`` CLI tools to automatically rebuild our
explainer whenever there is a change to the underlying
model, dataset or explainer configuration. And we can use ``kill -HUP gunicorn.pid``
to force gunicorn to restart and reload whenever a new ``explainer.joblib``
is generated or the dashboard configuration ``dashboard.yaml`` changes. These two
processes together ensure that the dashboard automatically updates whenever there
are underlying changes.

First we store the explainer config in ``explainer.yaml`` and the dashboard
config in ``dashboard.yaml``. We also indicate which modelfiles and datafiles the
explainer depends on, and which columns in the datafile should be used as
a target and which as index::

    explainer = ClassifierExplainer(model, X, y, labels=['Not Survived', 'Survived'])
    explainer.dump("explainer.joblib")
    explainer.to_yaml("explainer.yaml",
            modelfile="model.pkl",
            datafile="data.csv",
            index_col="Name",
@@ -300,12 +306,12 @@ directly from the config file::

.. highlight:: bash

Now we would like to rebuild the ``explainer.joblib`` file whenever there is a
change to ``model.pkl``, ``data.csv`` or ``explainer.yaml`` by running
``explainerdashboard build``. And we restart the ``gunicorn`` server whenever
there is a change in ``explainer.joblib`` or ``dashboard.yaml`` by killing
the gunicorn server with ``kill -HUP pid``. To do that we need to install
the python package ``watchdog`` (``pip install watchdog[watchmedo]``). This
package can keep track of file changes and execute shell scripts when files change.

So we can start the gunicorn server and the two watchdog filechange trackers
@@ -321,17 +327,14 @@ from a shell script ``start_server.sh``::

    wait # wait till user hits ctrl-c to exit and kill all three processes

Now we can simply run ``chmod +x start_server.sh`` and ``./start_server.sh`` to
get our server up and running.

Whenever we now make a change to either one of the source files
(``model.pkl``, ``data.csv`` or ``explainer.yaml``), this produces a fresh
``explainer.joblib``. And whenever there is a change to either ``explainer.joblib``
or ``dashboard.yaml``, gunicorn restarts and rebuilds the dashboard.

So you can keep an explainerdashboard running without interruption and simply
drop an updated ``model.pkl`` or a fresh dataset ``data.csv`` into the directory and
the dashboard will automatically update.

0 commit comments
