Skip to content

Commit ff300a1

Browse files
committed
Add OOB score and feature importance chart to displayed model metrics
1 parent 39687b3 commit ff300a1

File tree

5 files changed

+61
-9
lines changed

5 files changed

+61
-9
lines changed

cesium_app/handlers/model.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,24 @@ def _build_model_compute_statistics(fset_path, model_type, model_params,
6767
model = GridSearchCV(model, params_to_optimize,
6868
n_jobs=n_jobs)
6969
model.fit(fset, data['labels'])
70-
score = model.score(fset, data['labels'])
70+
71+
metrics = {}
72+
metrics['train_score'] = model.score(fset, data['labels'])
73+
7174
best_params = model.best_params_ if params_to_optimize else {}
7275
joblib.dump(model, model_path)
7376

74-
return score, best_params
77+
if model_type == 'RandomForestClassifier':
78+
if params_to_optimize:
79+
model = model.best_estimator_
80+
if hasattr(model, 'oob_score_'):
81+
metrics['oob_score'] = model.oob_score_
82+
if hasattr(model, 'feature_importances_'):
83+
metrics['feature_importances'] = dict(zip(
84+
fset.columns.get_level_values(0).tolist(),
85+
model.feature_importances_.tolist()))
86+
87+
return metrics, best_params
7588

7689

7790
class ModelHandler(BaseHandler):
@@ -102,12 +115,12 @@ def get(self, model_id=None, action=None):
102115
@auth_or_token
103116
async def _await_model_statistics(self, model_stats_future, model):
104117
try:
105-
score, best_params = await model_stats_future
118+
model_metrics, best_params = await model_stats_future
106119

107120
model = DBSession().merge(model)
108121
model.task_id = None
109122
model.finished = datetime.datetime.now()
110-
model.train_score = score
123+
model.metrics = model_metrics
111124
model.params.update(best_params)
112125
DBSession().commit()
113126

cesium_app/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class Model(Base):
9090
file_uri = sa.Column(sa.String(), nullable=True, index=True)
9191
task_id = sa.Column(sa.String())
9292
finished = sa.Column(sa.DateTime)
93-
train_score = sa.Column(sa.Float)
93+
metrics = sa.Column(sa.JSON, nullable=True)
9494

9595
featureset = relationship('Featureset')
9696
project = relationship('Project')

package.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
"bokehjs": "^0.12.5",
1010
"bootstrap": "^3.3.7",
1111
"bootstrap-css": "^3.0.0",
12+
"chart.js": "^2.7.1",
1213
"css-loader": "^0.26.2",
1314
"exports-loader": "^0.6.4",
1415
"imports-loader": "^0.7.1",
1516
"jquery": "^3.1.1",
1617
"prop-types": "^15.5.10",
1718
"react": "^15.1.0",
19+
"react-chartjs-2": "^2.7.0",
1820
"react-dom": "^15.1.0",
1921
"react-redux": "^5.0.3",
2022
"react-tabs": "^0.8.2",
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import React from 'react';
2+
import { HorizontalBar } from 'react-chartjs-2';
3+
4+
5+
const FeatureImportancesBarchart = props => {
6+
const sorted_features = Object.keys(props.data).sort(
7+
(a, b) => props.data[b] - props.data[a]).slice(0, 15);
8+
const values = sorted_features.map(
9+
feature => props.data[feature].toFixed(3));
10+
const data = {
11+
labels: sorted_features,
12+
datasets: [
13+
{
14+
label: 'Feature Importance',
15+
backgroundColor: '#2222ff',
16+
hoverBackgroundColor: '#5555ff',
17+
data: values
18+
}
19+
]
20+
};
21+
22+
return (
23+
<div style={{ height: 300, width: 600 }}>
24+
<HorizontalBar data={data} />
25+
</div>
26+
);
27+
};
28+
29+
export default FeatureImportancesBarchart;

static/js/components/Models.jsx

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import Delete from './Delete';
1313
import Download from './Download';
1414
import { $try, reformatDatetime } from '../utils';
1515
import FoldableRow from './FoldableRow';
16+
import FeatureImportances from './FeatureImportances';
1617

1718

1819
const ModelsTab = props => (
@@ -178,7 +179,7 @@ const ModelInfo = props => (
178179
<tr>
179180
<th>Model Type</th>
180181
<th>Hyperparameters</th>
181-
<th>Training Data Score</th>
182+
{Object.keys(props.model.metrics).map(metric => <th>{metric}</th>)}
182183
</tr>
183184
</thead>
184185
<tbody>
@@ -200,9 +201,16 @@ const ModelInfo = props => (
200201
</tbody>
201202
</table>
202203
</td>
203-
<td>
204-
{props.model.train_score}
205-
</td>
204+
{
205+
Object.keys(props.model.metrics).map(metric => (
206+
<td>
207+
{
208+
metric == 'feature_importances' ?
209+
<FeatureImportances data={props.model.metrics[metric]} /> :
210+
props.model.metrics[metric]
211+
}
212+
</td>))
213+
}
206214
</tr>
207215
</tbody>
208216
</table>

0 commit comments

Comments
 (0)