Skip to content

Commit 1ef9beb

Browse files
committed
Expose n_jobs param for relevant models, set default to n_cores
1 parent b15966e commit 1ef9beb

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

cesium_app/ext/sklearn_models.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import collections
23
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
34
from sklearn.linear_model import (LinearRegression, SGDClassifier,
@@ -12,6 +13,8 @@
1213
'BayesianARDRegressor': ARDRegression,
1314
'BayesianRidgeRegressor': BayesianRidge}
1415

16+
N_JOBS_DEFAULT = os.cpu_count()
17+
1518

1619
def make_list(x):
1720
"""Wrap `x` in a list if it isn't already a list or tuple.
@@ -49,7 +52,8 @@ def make_list(x):
4952
{"name": "bootstrap", "type": bool, "default": True},
5053
{"name": "oob_score", "type": bool, "default": False},
5154
{"name": "random_state", "type": int, "default": None},
52-
{"name": "class_weight", "type": dict, "default": None}],
55+
{"name": "class_weight", "type": dict, "default": None},
56+
{"name": "n_jobs", "type": int, "default": N_JOBS_DEFAULT}],
5357
"type": "classifier",
5458
"url": "http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html"},
5559

@@ -67,7 +71,8 @@ def make_list(x):
6771
{"name": "bootstrap", "type": bool, "default": True},
6872
{"name": "oob_score", "type": bool, "default": False},
6973
{"name": "random_state", "type": int, "default": None},
70-
{"name": "class_weight", "type": dict, "default": None}],
74+
{"name": "class_weight", "type": dict, "default": None},
75+
{"name": "n_jobs", "type": int, "default": N_JOBS_DEFAULT}],
7176
"type": "classifier",
7277
"url": "http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html"},
7378

@@ -84,7 +89,8 @@ def make_list(x):
8489
{"name": "max_leaf_nodes", "type": int, "default": None},
8590
{"name": "bootstrap", "type": bool, "default": True},
8691
{"name": "oob_score", "type": bool, "default": False},
87-
{"name": "random_state", "type": int, "default": None}],
92+
{"name": "random_state", "type": int, "default": None},
93+
{"name": "n_jobs", "type": int, "default": N_JOBS_DEFAULT}],
8894
"type": "regressor",
8995
"url": "http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html"},
9096

@@ -101,7 +107,8 @@ def make_list(x):
101107
{"name": "max_leaf_nodes", "type": int, "default": None},
102108
{"name": "bootstrap", "type": bool, "default": True},
103109
{"name": "oob_score", "type": bool, "default": False},
104-
{"name": "random_state", "type": int, "default": None}],
110+
{"name": "random_state", "type": int, "default": None},
111+
{"name": "n_jobs", "type": int, "default": N_JOBS_DEFAULT}],
105112
"type": "regressor",
106113
"url": "http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html"},
107114

@@ -122,14 +129,16 @@ def make_list(x):
122129
{"name": "eta0", "type": float, "default": 0.0},
123130
{"name": "power_t", "type": float, "default": 0.5},
124131
{"name": "class_weight", "type": [dict, str], "default": None},
125-
{"name": "average", "type": [bool, int], "default": False}],
132+
{"name": "average", "type": [bool, int], "default": False},
133+
{"name": "n_jobs", "type": int, "default": N_JOBS_DEFAULT}],
126134
"type": "classifier",
127135
"url": "http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html"},
128136

129137
{"name": "LinearRegressor",
130138
"params": [
131139
{"name": "fit_intercept", "type": bool, "default": True},
132-
{"name": "normalize", "type": bool, "default": False}],
140+
{"name": "normalize", "type": bool, "default": False},
141+
{"name": "n_jobs", "type": int, "default": N_JOBS_DEFAULT}],
133142
"type": "regressor",
134143
"url": "http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html"},
135144

cesium_app/handlers/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _build_model_compute_statistics(fset_path, model_type, model_params,
6262
raise ValueError("Cannot build model for unlabeled feature set.")
6363
model = MODELS_TYPE_DICT[model_type](**model_params)
6464
if params_to_optimize:
65-
model = GridSearchCV(model, params_to_optimize, n_jobs=-1)
65+
model = GridSearchCV(model, params_to_optimize)
6666
model.fit(fset, data['labels'])
6767
score = model.score(fset, data['labels'])
6868
best_params = model.best_params_ if params_to_optimize else {}

0 commit comments

Comments
 (0)