@@ -130,14 +130,32 @@ SurrogateLearnerCollection = R6Class("SurrogateLearnerCollection",
130
130
if (! is.null(self $ input_trafo )) {
131
131
xdt = self $ input_trafo $ transform(xdt )
132
132
}
133
+
134
+ # speeding up some checks by constructing the predict task directly instead of relying on predict_newdata
133
135
preds = lapply(self $ learner , function (learner ) {
134
- pred = learner $ predict_newdata(newdata = xdt )
136
+ task = learner $ state $ train_task $ clone()
137
+ set(xdt , j = task $ target_names , value = NA_real_ ) # tasks only have features and the target but we have to set the target to NA
138
+ newdata = as_data_backend(xdt )
139
+ task $ backend = newdata
140
+ task $ row_roles $ use = task $ backend $ rownames
141
+ pred = learner $ predict(task )
135
142
if (learner $ predict_type == " se" ) {
136
143
data.table(mean = pred $ response , se = pred $ se )
137
144
} else {
138
145
data.table(mean = pred $ response )
139
146
}
140
147
})
148
+
149
+ # slow
150
+ # preds = lapply(self$learner, function(learner) {
151
+ # pred = learner$predict_newdata(newdata = xdt)
152
+ # if (learner$predict_type == "se") {
153
+ # data.table(mean = pred$response, se = pred$se)
154
+ # } else {
155
+ # data.table(mean = pred$response)
156
+ # }
157
+ # })
158
+
141
159
names(preds ) = names(self $ learner )
142
160
if (! is.null(self $ output_trafo ) && self $ output_trafo $ invert_posterior ) {
143
161
preds = self $ output_trafo $ inverse_transform_posterior(preds )
0 commit comments