Skip to content

Commit 1391810

Browse files
authored
Merge pull request #180 from bnaul/sklearn_api
Updates for cesium library refactor
2 parents 99f488a + b1dd23c commit 1391810

20 files changed

+189
-257
lines changed

.travis.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ addons:
5050
packages:
5151
- ccache
5252
- wget
53-
- libhdf5-serial-dev
54-
- libnetcdf-dev
5553
- nodejs
5654
- supervisor
5755
- nginx

cesium_app/ext/sklearn_models.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,37 @@
1-
from cesium.util import make_list
2-
from cesium.build_model import MODELS_TYPE_DICT
1+
import collections
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import (LinearRegression, SGDClassifier,
                                  RidgeClassifierCV, ARDRegression,
                                  BayesianRidge)


# Maps the model-type names used by the app's front end / database records
# to the scikit-learn estimator classes that implement them.
MODELS_TYPE_DICT = {'RandomForestClassifier': RandomForestClassifier,
                    'RandomForestRegressor': RandomForestRegressor,
                    'LinearSGDClassifier': SGDClassifier,
                    'LinearRegressor': LinearRegression,
                    'RidgeClassifierCV': RidgeClassifierCV,
                    'BayesianARDRegressor': ARDRegression,
                    'BayesianRidgeRegressor': BayesianRidge}
14+
15+
16+
def make_list(x):
    """Wrap `x` in a list if it isn't already a non-string iterable.

    Parameters
    ----------
    x : any valid object
        The parameter to be wrapped in a list.

    Returns
    -------
    list or tuple
        Returns `[x]` if `x` is not already a non-string, non-dict
        iterable (e.g. a list or tuple), otherwise returns `x`
        unchanged.

    """
    # `collections.Iterable` was deprecated in Python 3.3 and removed in
    # 3.10; the ABC lives in `collections.abc`. Import locally so this
    # function works regardless of how the module-level import is spelled.
    from collections.abc import Iterable

    # Strings and dicts are iterable but should be treated as scalar
    # values here (a dict of meta-features, a single file path, etc.).
    if isinstance(x, Iterable) and not isinstance(x, (str, dict)):
        return x
    else:
        return [x]
335

436

537
model_descriptions = [

cesium_app/handlers/dataset.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ def post(self):
6161
zipfile_path,
6262
cfg['paths']['ts_data_folder'],
6363
headerfile_path)
64-
meta_features = list(time_series.from_netcdf(ts_paths[0])
65-
.meta_features.keys())
64+
meta_features = list(time_series.load(ts_paths[0]).meta_features.keys())
6665
unique_ts_paths = [os.path.join(os.path.dirname(ts_path),
6766
str(uuid.uuid4()) + "_" +
6867
util.secure_filename(ts_path))

cesium_app/handlers/feature.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import tornado.ioloop
22

3-
import xarray as xr
43
from cesium import featurize, time_series
54
from cesium.features import dask_feature_graph
6-
from cesium import featureset
75

86
from .base import BaseHandler, AccessError
97
from ..models import Dataset, Featureset, Project, File
@@ -79,7 +77,7 @@ def post(self):
7977
return self.error('Cannot access dataset')
8078

8179
fset_path = pjoin(cfg['paths']['features_folder'],
82-
'{}_featureset.nc'.format(uuid.uuid4()))
80+
'{}_featureset.npz'.format(uuid.uuid4()))
8381

8482
fset = Featureset.create(name=featureset_name,
8583
file=File.create(uri=fset_path),
@@ -89,15 +87,18 @@ def post(self):
8987

9088
executor = yield self._get_executor()
9189

92-
all_time_series = executor.map(time_series.from_netcdf, dataset.uris)
90+
all_time_series = executor.map(time_series.load, dataset.uris)
91+
all_labels = executor.map(lambda ts: ts.label, all_time_series)
9392
all_features = executor.map(featurize.featurize_single_ts,
9493
all_time_series,
9594
features_to_use=features_to_use,
9695
custom_script_path=custom_script_path)
9796
computed_fset = executor.submit(featurize.assemble_featureset,
9897
all_features, all_time_series)
99-
imputed_fset = executor.submit(featureset.Featureset.impute, computed_fset)
100-
future = executor.submit(xr.Dataset.to_netcdf, imputed_fset, fset_path)
98+
imputed_fset = executor.submit(featurize.impute_featureset,
99+
computed_fset, inplace=False)
100+
future = executor.submit(featurize.save_featureset, imputed_fset,
101+
fset_path, labels=all_labels)
101102
fset.task_id = future.key
102103
fset.save()
103104

cesium_app/handlers/model.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,19 @@
44
from ..models import Project, Model, Featureset, File
55
from ..ext.sklearn_models import (
66
model_descriptions as sklearn_model_descriptions,
7-
check_model_param_types
7+
check_model_param_types, MODELS_TYPE_DICT
88
)
99
from ..util import robust_literal_eval
1010
from ..config import cfg
11+
from cesium import featurize
1112

1213
from os.path import join as pjoin
1314
import uuid
1415
import datetime
1516

16-
from cesium import build_model, featureset
1717
import tornado.ioloop
18+
from sklearn.model_selection import GridSearchCV
1819
import joblib
19-
import xarray as xr
2020
from distributed.client import _wait
2121

2222

@@ -27,7 +27,7 @@ def _build_model_compute_statistics(fset_path, model_type, model_params,
2727
Parameters
2828
----------
2929
fset_path : str
30-
Path to feature set NetCDF file.
30+
Path to feature set .npz file.
3131
model_type : str
3232
Type of model to be built, e.g. 'RandomForestClassifier'.
3333
model_params : dict
@@ -57,15 +57,16 @@ def _build_model_compute_statistics(fset_path, model_type, model_params,
5757
`params_to_optimize` is None or is an empty dict, this will be an empty
5858
dict.
5959
'''
60-
fset = featureset.from_netcdf(fset_path)
61-
computed_model = build_model.build_model_from_featureset(
62-
featureset=fset, model_type=model_type,
63-
model_parameters=model_params,
64-
params_to_optimize=params_to_optimize)
65-
score = build_model.score_model(computed_model, fset)
66-
best_params = computed_model.best_params_ if params_to_optimize else {}
67-
joblib.dump(computed_model, model_path)
68-
fset.close()
60+
fset, data = featurize.load_featureset(fset_path)
61+
if len(data['labels']) != len(fset):
62+
raise ValueError("Cannot build model for unlabeled feature set.")
63+
model = MODELS_TYPE_DICT[model_type](**model_params)
64+
if params_to_optimize:
65+
model = GridSearchCV(model, params_to_optimize)
66+
model.fit(fset, data['labels'])
67+
score = model.score(fset, data['labels'])
68+
best_params = model.best_params_ if params_to_optimize else {}
69+
joblib.dump(model, model_path)
6970

7071
return score, best_params
7172

cesium_app/handlers/prediction.py

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,17 @@
77
from tornado.web import RequestHandler
88
from tornado.escape import json_decode
99

10-
import cesium.time_series
11-
import cesium.featurize
12-
import cesium.predict
13-
import cesium.featureset
10+
from cesium import featurize, time_series
1411
from cesium.features import CADENCE_FEATS, GENERAL_FEATS, LOMB_SCARGLE_FEATS
1512

16-
import xarray as xr
1713
import joblib
1814
from os.path import join as pjoin
1915
import uuid
2016
import datetime
2117
import os
2218
import tempfile
19+
import numpy as np
20+
import pandas as pd
2321

2422

2523
class PredictionHandler(BaseHandler):
@@ -82,27 +80,39 @@ def post(self):
8280
if (model.finished is None) or (fset.finished is None):
8381
return self.error('Computation of model or feature set still in progress')
8482

85-
prediction_path = pjoin(cfg['paths']['predictions_folder'],
86-
'{}_prediction.nc'.format(uuid.uuid4()))
87-
prediction_file = File.create(uri=prediction_path)
83+
pred_path = pjoin(cfg['paths']['predictions_folder'],
84+
'{}_prediction.npz'.format(uuid.uuid4()))
85+
prediction_file = File.create(uri=pred_path)
8886
prediction = Prediction.create(file=prediction_file, dataset=dataset,
8987
project=dataset.project, model=model)
9088

9189
executor = yield self._get_executor()
9290

93-
all_time_series = executor.map(cesium.time_series.from_netcdf,
94-
dataset.uris)
95-
all_features = executor.map(cesium.featurize.featurize_single_ts,
91+
all_time_series = executor.map(time_series.load, dataset.uris)
92+
all_labels = executor.map(lambda ts: ts.label, all_time_series)
93+
all_features = executor.map(featurize.featurize_single_ts,
9694
all_time_series,
9795
features_to_use=fset.features_list,
9896
custom_script_path=fset.custom_features_script)
99-
fset_data = executor.submit(cesium.featurize.assemble_featureset,
97+
fset_data = executor.submit(featurize.assemble_featureset,
10098
all_features, all_time_series)
101-
fset_data = executor.submit(cesium.featureset.Featureset.impute, fset_data)
102-
model_data = executor.submit(joblib.load, model.file.uri)
103-
predset = executor.submit(cesium.predict.model_predictions,
104-
fset_data, model_data)
105-
future = executor.submit(xr.Dataset.to_netcdf, predset, prediction_path)
99+
imputed_fset = executor.submit(featurize.impute_featureset,
100+
fset_data, inplace=False)
101+
model_or_gridcv = executor.submit(joblib.load, model.file.uri)
102+
model_data = executor.submit(lambda model: model.best_estimator_
103+
if hasattr(model, 'best_estimator_') else model,
104+
model_or_gridcv)
105+
preds = executor.submit(lambda fset, model: model.predict(fset),
106+
imputed_fset, model_data)
107+
pred_probs = executor.submit(lambda fset, model: model.predict_proba(fset)
108+
if hasattr(model, 'predict_proba') else [],
109+
imputed_fset, model_data)
110+
all_classes = executor.submit(lambda model: model.classes_
111+
if hasattr(model, 'classes_') else [],
112+
model_data)
113+
future = executor.submit(featurize.save_featureset, imputed_fset,
114+
pred_path, labels=all_labels, preds=preds,
115+
pred_probs=pred_probs, all_classes=all_classes)
106116

107117
prediction.task_id = future.key
108118
prediction.save()
@@ -114,14 +124,18 @@ def post(self):
114124

115125
def get(self, prediction_id=None, action=None):
116126
if action == 'download':
117-
prediction = cesium.featureset.from_netcdf(self._get_prediction(prediction_id).file.uri)
118-
with tempfile.NamedTemporaryFile() as tf:
119-
util.prediction_to_csv(prediction, tf.name)
120-
with open(tf.name) as f:
121-
self.set_header("Content-Type", 'text/csv; charset="utf-8"')
122-
self.set_header("Content-Disposition",
123-
"attachment; filename=cesium_prediction_results.csv")
124-
self.write(f.read())
127+
pred_path = self._get_prediction(prediction_id).file.uri
128+
fset, data = featurize.load_featureset(pred_path)
129+
result = pd.DataFrame({'ts_name': fset.index,
130+
'label': data['labels'],
131+
'prediction': data['preds']},
132+
columns=['ts_name', 'label', 'prediction'])
133+
if data.get('pred_probs'):
134+
result['probability'] = np.max(data['pred_probs'], axis=1)
135+
self.set_header("Content-Type", 'text/csv; charset="utf-8"')
136+
self.set_header("Content-Disposition", "attachment; "
137+
"filename=cesium_prediction_results.csv")
138+
self.write(result.to_csv(index=False))
125139
else:
126140
if prediction_id is None:
127141
predictions = [prediction
@@ -144,20 +158,22 @@ class PredictRawDataHandler(BaseHandler):
144158
def post(self):
145159
ts_data = json_decode(self.get_argument('ts_data'))
146160
model_id = json_decode(self.get_argument('modelID'))
147-
meta_feats = json_decode(
148-
self.get_argument('meta_features', 'null'))
149-
impute_kwargs = json_decode(
150-
self.get_argument('impute_kwargs', '{}'))
161+
meta_feats = json_decode(self.get_argument('meta_features', 'null'))
162+
impute_kwargs = json_decode(self.get_argument('impute_kwargs', '{}'))
151163

152164
model = Model.get(Model.id == model_id)
153-
computed_model = joblib.load(model.file.uri)
165+
model_data = joblib.load(model.file.uri)
166+
if hasattr(model_data, 'best_estimator_'):
167+
model_data = model_data.best_estimator_
154168
features_to_use = model.featureset.features_list
155169

156-
fset_data = cesium.featurize.featurize_time_series(
157-
*ts_data, features_to_use=features_to_use, meta_features=meta_feats)
158-
fset = cesium.featureset.Featureset(fset_data).impute(**impute_kwargs)
159-
160-
predset = cesium.predict.model_predictions(fset, computed_model)
161-
predset['name'] = predset.name.astype('str')
162-
163-
return self.success(predset)
170+
fset = featurize.featurize_time_series(*ts_data,
171+
features_to_use=features_to_use,
172+
meta_features=meta_feats)
173+
fset = featurize.impute_featureset(fset, **impute_kwargs)
174+
data = {'preds': model_data.predict(fset),
175+
'all_classes': model_data.classes_}
176+
if hasattr(model_data, 'predict_proba'):
177+
data['pred_probs'] = model_data.predict_proba(fset)
178+
pred_info = Prediction.format_pred_data(fset, data)
179+
return self.success(pred_info)

cesium_app/json_util.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from datetime import datetime
22
import simplejson as json
33
import numpy as np
4+
import pandas as pd
45
import peewee
56
import six
6-
import xarray as xr
77

88

99
data_types = {
@@ -16,28 +16,6 @@
1616
}
1717

1818

19-
def dataset_row_to_dict(row):
20-
"""Semi-hacky helper function for extracting JSON for a single time series
21-
of a featureset. For now assumes single-channel data since that's what the
22-
front end can display.
23-
"""
24-
out = {}
25-
out['target'] = row.target.values.item() if 'target' in row else None
26-
if 'prediction' in row:
27-
if 'class_label' in row: # {class label: probability}
28-
out['prediction'] = {six.u(label): value for label, value
29-
in zip(row.class_label.values,
30-
row.prediction.values)}
31-
else: # just a single predicted label or target
32-
out['prediction'] = row.prediction.values.item()
33-
else:
34-
out['prediction'] = None
35-
out['features'] = {f: row[f].item()
36-
for f in row.data_vars if f != 'prediction'}
37-
38-
return out
39-
40-
4119
class Encoder(json.JSONEncoder):
4220
"""Extends json.JSONEncoder with additional capabilities/configurations."""
4321
def default(self, o):
@@ -62,9 +40,9 @@ def default(self, o):
6240
elif isinstance(o, np.ndarray):
6341
return o.tolist()
6442

65-
elif isinstance(o, xr.Dataset):
66-
return {ts_name: dataset_row_to_dict(o.sel(name=ts_name))
67-
for ts_name in o.name.values}
43+
elif isinstance(o, pd.DataFrame):
44+
o.columns = o.columns.droplevel('channel') # flatten MultiIndex
45+
return o.to_dict(orient='index')
6846

6947
elif type(o) is type and o in data_types:
7048
return data_types[o]

0 commit comments

Comments
 (0)