Skip to content

Commit 620ad8e

Browse files
committed
Pass n_jobs to GridSearchCV and not model object when applicable
1 parent 1ef9beb commit 620ad8e

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

cesium_app/handlers/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,12 @@ def _build_model_compute_statistics(fset_path, model_type, model_params,
6060
fset, data = featurize.load_featureset(fset_path)
6161
if len(data['labels']) != len(fset):
6262
raise ValueError("Cannot build model for unlabeled feature set.")
63+
n_jobs = (model_params.pop('n_jobs') if 'n_jobs' in model_params
64+
and params_to_optimize else -1)
6365
model = MODELS_TYPE_DICT[model_type](**model_params)
6466
if params_to_optimize:
65-
model = GridSearchCV(model, params_to_optimize)
67+
model = GridSearchCV(model, params_to_optimize,
68+
n_jobs=n_jobs)
6669
model.fit(fset, data['labels'])
6770
score = model.score(fset, data['labels'])
6871
best_params = model.best_params_ if params_to_optimize else {}

0 commit comments

Comments
 (0)