Skip to content

Commit 3639646

Browse files
committed
Add test for specifying TS name in predict call
1 parent c237987 commit 3639646

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

cesium_app/handlers/prediction.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ def _get_prediction(self, prediction_id):
2525
try:
2626
d = Prediction.get(Prediction.id == prediction_id)
2727
except Prediction.DoesNotExist:
28-
raise AccessError('No such dataset')
28+
raise AccessError('No such prediction')
2929

3030
if not d.is_owned_by(self.get_username()):
31-
raise AccessError('No such dataset')
31+
raise AccessError('No such prediction')
3232

3333
return d
3434

@@ -90,7 +90,9 @@ def post(self):
9090
executor = yield self._get_executor()
9191

9292
if ts_names:
93-
ts_uris = [f.uri for f in dataset.files if f.name in ts_names]
93+
ts_uris = [f.uri for f in dataset.files if os.path.basename(f.name)
94+
in ts_names or os.path.basename(f.name).split('.npz')[0]
95+
in ts_names]
9496
else:
9597
ts_uris = dataset.uris
9698

cesium_app/tests/frontend/test_predict.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from os.path import join as pjoin
88
import numpy as np
99
import numpy.testing as npt
10+
from cesium_app.config import cfg
11+
import json
12+
import requests
1013
from cesium_app.tests.fixtures import (create_test_project, create_test_dataset,
1114
create_test_featureset, create_test_model,
1215
create_test_prediction)
@@ -204,3 +207,34 @@ def test_download_prediction_csv_regr(driver):
204207
[4, 3.1, 3.1]])
205208
finally:
206209
os.remove('/tmp/cesium_prediction_results.csv')
210+
211+
212+
def test_predict_specific_ts_name():
213+
with create_test_project() as p, create_test_dataset(p) as ds,\
214+
create_test_featureset(p) as fs, create_test_model(fs) as m:
215+
ts_data = [[1, 2, 3, 4], [32.2, 53.3, 32.3, 32.52], [0.2, 0.3, 0.6, 0.3]]
216+
impute_kwargs = {'strategy': 'constant', 'value': None}
217+
data = {'datasetID': ds.id,
218+
'ts_names': ['217801'],
219+
'modelID': m.id}
220+
print('data:', data)
221+
response = requests.post('{}/predictions'.format(cfg['server']['url']),
222+
data=json.dumps(data)).json()
223+
print('response dict:', response)
224+
assert response['status'] == 'success'
225+
226+
n_secs = 0
227+
while n_secs < 5:
228+
pred_info = requests.get('{}/predictions/{}'.format(
229+
cfg['server']['url'], response['data']['id'])).json()
230+
print(pred_info)
231+
if pred_info['status'] == 'success' and pred_info['data']['finished']:
232+
assert isinstance(pred_info['data']['results']['217801']
233+
['features']['total_time'],
234+
float)
235+
assert 'Mira' in pred_info['data']['results']['217801']['prediction']
236+
break
237+
n_secs += 1
238+
time.sleep(1)
239+
else:
240+
raise Exception('test_predict_specific_ts_name timed out')

0 commit comments

Comments
 (0)