Skip to content

Commit 163138d

Browse files
committed
perf: partial merges from #173
1 parent 5e92657 commit 163138d

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

R/SurrogateLearner.R

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,18 @@ SurrogateLearner = R6Class("SurrogateLearner",
112112
if (!is.null(self$input_trafo)) {
113113
xdt = self$input_trafo$transform(xdt)
114114
}
115-
pred = self$learner$predict_newdata(newdata = xdt)
115+
116+
# speeding up some checks by constructing the predict task directly instead of relying on predict_newdata
117+
task = self$learner$state$train_task$clone()
118+
set(xdt, j = task$target_names, value = NA_real_) # tasks only have features and the target but we have to set the target to NA
119+
newdata = as_data_backend(xdt)
120+
task$backend = newdata
121+
task$row_roles$use = task$backend$rownames
122+
pred = self$learner$predict(task)
123+
124+
# slow
125+
#pred = self$learner$predict_newdata(newdata = xdt)
126+
116127
pred = if (self$learner$predict_type == "se") {
117128
data.table(mean = pred$response, se = pred$se)
118129
} else {

R/SurrogateLearnerCollection.R

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,32 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
130130
if (!is.null(self$input_trafo)) {
131131
xdt = self$input_trafo$transform(xdt)
132132
}
133+
134+
# speeding up some checks by constructing the predict task directly instead of relying on predict_newdata
133135
preds = lapply(self$learner, function(learner) {
134-
pred = learner$predict_newdata(newdata = xdt)
136+
task = learner$state$train_task$clone()
137+
set(xdt, j = task$target_names, value = NA_real_) # tasks only have features and the target but we have to set the target to NA
138+
newdata = as_data_backend(xdt)
139+
task$backend = newdata
140+
task$row_roles$use = task$backend$rownames
141+
pred = learner$predict(task)
135142
if (learner$predict_type == "se") {
136143
data.table(mean = pred$response, se = pred$se)
137144
} else {
138145
data.table(mean = pred$response)
139146
}
140147
})
148+
149+
# slow
150+
#preds = lapply(self$learner, function(learner) {
151+
# pred = learner$predict_newdata(newdata = xdt)
152+
# if (learner$predict_type == "se") {
153+
# data.table(mean = pred$response, se = pred$se)
154+
# } else {
155+
# data.table(mean = pred$response)
156+
# }
157+
#})
158+
141159
names(preds) = names(self$learner)
142160
if (!is.null(self$output_trafo) && self$output_trafo$invert_posterior) {
143161
preds = self$output_trafo$inverse_transform_posterior(preds)

0 commit comments

Comments
 (0)