Skip to content

Commit 5e92657

Browse files
authored
feat: InputTrafos and OutputTrafos, AcqFunctionEILog (#178)
includes changes to input transformations (#177)
1 parent ea67608 commit 5e92657

File tree

76 files changed

+2988
-111
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+2988
-111
lines changed

DESCRIPTION

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ Collate:
8181
'AcqFunctionEHVI.R'
8282
'AcqFunctionEHVIGH.R'
8383
'AcqFunctionEI.R'
84+
'AcqFunctionEILog.R'
8485
'AcqFunctionEIPS.R'
8586
'AcqFunctionMean.R'
8687
'AcqFunctionMulti.R'
@@ -90,10 +91,17 @@ Collate:
9091
'AcqFunctionStochasticCB.R'
9192
'AcqFunctionStochasticEI.R'
9293
'AcqOptimizer.R'
94+
'mlr_input_trafos.R'
95+
'InputTrafo.R'
96+
'InputTrafoUnitcube.R'
9397
'aaa.R'
9498
'OptimizerADBO.R'
9599
'OptimizerAsyncMbo.R'
96100
'OptimizerMbo.R'
101+
'mlr_output_trafos.R'
102+
'OutputTrafo.R'
103+
'OutputTrafoLog.R'
104+
'OutputTrafoStandardize.R'
97105
'mlr_result_assigners.R'
98106
'ResultAssigner.R'
99107
'ResultAssignerArchive.R'

NAMESPACE

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# Generated by roxygen2: do not edit by hand
22

33
S3method(as.data.table,DictionaryAcqFunction)
4+
S3method(as.data.table,DictionaryInputTrafo)
45
S3method(as.data.table,DictionaryLoopFunction)
6+
S3method(as.data.table,DictionaryOutputTrafo)
57
S3method(as.data.table,DictionaryResultAssigner)
68
S3method(print,loop_function)
79
export(AcqFunction)
@@ -10,6 +12,7 @@ export(AcqFunctionCB)
1012
export(AcqFunctionEHVI)
1113
export(AcqFunctionEHVIGH)
1214
export(AcqFunctionEI)
15+
export(AcqFunctionEILog)
1316
export(AcqFunctionEIPS)
1417
export(AcqFunctionMean)
1518
export(AcqFunctionMulti)
@@ -19,9 +22,14 @@ export(AcqFunctionSmsEgo)
1922
export(AcqFunctionStochasticCB)
2023
export(AcqFunctionStochasticEI)
2124
export(AcqOptimizer)
25+
export(InputTrafo)
26+
export(InputTrafoUnitcube)
2227
export(OptimizerADBO)
2328
export(OptimizerAsyncMbo)
2429
export(OptimizerMbo)
30+
export(OutputTrafo)
31+
export(OutputTrafoLog)
32+
export(OutputTrafoStandardize)
2533
export(ResultAssigner)
2634
export(ResultAssignerArchive)
2735
export(ResultAssignerSurrogate)
@@ -46,9 +54,13 @@ export(default_loop_function)
4654
export(default_result_assigner)
4755
export(default_rf)
4856
export(default_surrogate)
57+
export(it)
4958
export(mlr_acqfunctions)
59+
export(mlr_input_trafos)
5060
export(mlr_loop_functions)
61+
export(mlr_output_trafos)
5162
export(mlr_result_assigners)
63+
export(ot)
5264
export(ras)
5365
export(redis_available)
5466
export(srlrn)
@@ -67,6 +79,7 @@ importFrom(stats,pnorm)
6779
importFrom(stats,quantile)
6880
importFrom(stats,rexp)
6981
importFrom(stats,runif)
82+
importFrom(stats,sd)
7083
importFrom(stats,setNames)
7184
importFrom(utils,bibentry)
7285
useDynLib(mlr3mbo,c_eps_indicator)

R/AcqFunctionAEI.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ AcqFunctionAEI = R6Class("AcqFunctionAEI",
9595
#' Update the acquisition function and set `y_effective_best` and `noise_var`.
9696
update = function() {
9797
xdt = self$archive$data[, self$archive$cols_x, with = FALSE]
98-
p = self$surrogate$predict(xdt)
99-
y_effective = p$mean + (self$surrogate_max_to_min * self$constants$values$c * p$se) # pessimistic prediction
98+
pred = self$surrogate$predict(xdt)
99+
# NOTE: output_trafo_must_be_considered is not relevant to y here because y_effective_best is determined from the predictions
100+
y_effective = pred$mean + (self$surrogate_max_to_min * self$constants$values$c * pred$se) # pessimistic prediction
100101
self$y_effective_best = min(self$surrogate_max_to_min * y_effective)
101102

102103
if (!is.null(self$surrogate$learner$model) && length(self$surrogate$learner$model@covariance@nugget) == 1L) {

R/AcqFunctionEHVI.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ AcqFunctionEHVI = R6Class("AcqFunctionEHVI",
8383
stopf("'%s' only works for exactly two objectives.", format(self))
8484
}
8585
ys = self$archive$data[, self$archive$cols_y, with = FALSE]
86+
if (self$surrogate$output_trafo_must_be_considered) {
87+
ys = self$surrogate$output_trafo$transform(ys)
88+
}
8689
for (column in self$archive$cols_y) {
8790
set(ys, j = column, value = ys[[column]] * self$surrogate_max_to_min[[column]]) # assume minimization
8891
}
@@ -155,3 +158,4 @@ mlr_acqfunctions$add("ehvi", AcqFunctionEHVI)
155158
psi_function = function(a, b, mu, sigma) {
156159
(sigma * dnorm((b - mu) / sigma) + ((a - mu) * pnorm((b - mu) / sigma)))
157160
}
161+

R/AcqFunctionEHVIGH.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ AcqFunctionEHVIGH = R6Class("AcqFunctionEHVIGH",
108108
update = function() {
109109
n_obj = length(self$archive$cols_y)
110110
ys = self$archive$data[, self$archive$cols_y, with = FALSE]
111+
if (self$surrogate$output_trafo_must_be_considered) {
112+
ys = self$surrogate$output_trafo$transform(ys)
113+
}
111114
for (column in self$archive$cols_y) {
112115
set(ys, j = column, value = ys[[column]] * self$surrogate_max_to_min[[column]]) # assume minimization
113116
}

R/AcqFunctionEI.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,11 @@ AcqFunctionEI = R6Class("AcqFunctionEI",
8181
#' @description
8282
#' Update the acquisition function and set `y_best`.
8383
update = function() {
84-
self$y_best = min(self$surrogate_max_to_min * self$archive$data[[self$surrogate$cols_y]])
84+
y = self$archive$data[, self$surrogate$cols_y, with = FALSE]
85+
if (self$surrogate$output_trafo_must_be_considered) {
86+
y = self$surrogate$output_trafo$transform(y)
87+
}
88+
self$y_best = min(self$surrogate_max_to_min * y)
8589
}
8690
),
8791

R/AcqFunctionEILog.R

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#' @title Acquisition Function Expected Improvement on Log Scale
2+
#'
3+
#' @include AcqFunction.R
4+
#' @name mlr_acqfunctions_ei_log
5+
#'
6+
#' @templateVar id ei_log
7+
#' @template section_dictionary_acqfunctions
8+
#'
9+
#' @description
10+
#' Expected Improvement assuming that the target variable has been modeled on log scale.
11+
#' In general only sensible if the [SurrogateLearner] uses an [OutputTrafoLog] without inverting the posterior predictive distribution (`invert_posterior = FALSE`).
12+
#' See also the example below.
13+
#'
14+
#' @section Parameters:
15+
#' * `"epsilon"` (`numeric(1)`)\cr
16+
#' \eqn{\epsilon} value used to determine the amount of exploration.
17+
#' Higher values result in the importance of improvements predicted by the posterior mean
18+
#' decreasing relative to the importance of potential improvements in regions of high predictive uncertainty.
19+
#' Defaults to `0` (standard Expected Improvement).
20+
#'
21+
#' @family Acquisition Function
22+
#' @export
23+
#' @examples
24+
#' if (requireNamespace("mlr3learners") &
25+
#' requireNamespace("DiceKriging") &
26+
#' requireNamespace("rgenoud")) {
27+
#' library(bbotk)
28+
#' library(paradox)
29+
#' library(mlr3learners)
30+
#' library(data.table)
31+
#'
32+
#' fun = function(xs) {
33+
#' list(y = xs$x ^ 2)
34+
#' }
35+
#' domain = ps(x = p_dbl(lower = -10, upper = 10))
36+
#' codomain = ps(y = p_dbl(tags = "minimize"))
37+
#' objective = ObjectiveRFun$new(fun = fun, domain = domain, codomain = codomain)
38+
#'
39+
#' instance = OptimInstanceBatchSingleCrit$new(
40+
#' objective = objective,
41+
#' terminator = trm("evals", n_evals = 5))
42+
#'
43+
#' instance$eval_batch(data.table(x = c(-6, -5, 3, 9)))
44+
#'
45+
#' learner = default_gp()
46+
#'
47+
#' output_trafo = ot("log", invert_posterior = FALSE)
48+
#'
49+
#' surrogate = srlrn(learner, output_trafo = output_trafo, archive = instance$archive)
50+
#'
51+
#' acq_function = acqf("ei_log", surrogate = surrogate)
52+
#'
53+
#' acq_function$surrogate$update()
54+
#' acq_function$update()
55+
#' acq_function$eval_dt(data.table(x = c(-1, 0, 1)))
56+
#' }
57+
AcqFunctionEILog = R6Class("AcqFunctionEILog",
58+
inherit = AcqFunction,
59+
60+
public = list(
61+
62+
#' @field y_best (`numeric(1)`)\cr
63+
#' Best objective function value observed so far.
64+
#' In the case of maximization, this already includes the necessary change of sign.
65+
y_best = NULL,
66+
67+
#' @description
68+
#' Creates a new instance of this [R6][R6::R6Class] class.
69+
#'
70+
#' @param surrogate (`NULL` | [SurrogateLearner]).
71+
#' @param epsilon (`numeric(1)`).
72+
initialize = function(surrogate = NULL, epsilon = 0) {
73+
assert_r6(surrogate, "SurrogateLearner", null.ok = TRUE)
74+
assert_number(epsilon, lower = 0, finite = TRUE)
75+
76+
constants = ps(epsilon = p_dbl(lower = 0, default = 0))
77+
constants$values$epsilon = epsilon
78+
79+
super$initialize("acq_ei_log", constants = constants, surrogate = surrogate, requires_predict_type_se = TRUE, direction = "maximize", label = "Expected Improvement on Log Scale", man = "mlr3mbo::mlr_acqfunctions_ei_log")
80+
},
81+
82+
#' @description
83+
#' Update the acquisition function and set `y_best`.
84+
update = function() {
85+
assert_r6(self$surrogate$output_trafo, "OutputTrafoLog")
86+
assert_false(self$surrogate$output_trafo$invert_posterior)
87+
y = self$archive$data[, self$surrogate$cols_y, with = FALSE]
88+
if (self$surrogate$output_trafo_must_be_considered) {
89+
y = self$surrogate$output_trafo$transform(y)
90+
}
91+
self$y_best = min(self$surrogate_max_to_min * y)
92+
}
93+
),
94+
95+
private = list(
96+
.fun = function(xdt, ...) {
97+
if (is.null(self$y_best)) {
98+
stop("$y_best is not set. Missed to call $update()?")
99+
}
100+
assert_r6(self$surrogate$output_trafo, "OutputTrafoLog")
101+
assert_false(self$surrogate$output_trafo$invert_posterior)
102+
constants = list(...)
103+
epsilon = constants$epsilon
104+
p = self$surrogate$predict(xdt)
105+
mu = p$mean
106+
se = p$se
107+
108+
# FIXME: try to unify w.r.t minimization / maximization and the respective transformation
109+
if (self$surrogate_max_to_min == 1L) {
110+
# y is to be minimized and the OutputTrafoLog performed the transformation accordingly
111+
assert_true(self$surrogate$output_trafo$max_to_min == 1L)
112+
y_best = self$y_best
113+
d = (y_best - mu) - epsilon
114+
d_norm = d / se
115+
multiplicative_factor = (self$surrogate$output_trafo$state[[self$surrogate$output_trafo$cols_y]]$max - self$surrogate$output_trafo$state[[self$surrogate$output_trafo$cols_y]]$min)
116+
ei_log = multiplicative_factor * ((exp(y_best) * pnorm(d_norm)) - (exp((0.5 * se^2) + mu)) * pnorm(d_norm - se))
117+
} else {
118+
# y is to be maximized and the OutputTrafoLog performed the transformation accordingly
119+
y_best = - self$y_best
120+
d = (mu - y_best) - epsilon
121+
d_norm = d / se
122+
multiplicative_factor = (self$surrogate$output_trafo$state[[self$surrogate$output_trafo$cols_y]]$max - self$surrogate$output_trafo$state[[self$surrogate$output_trafo$cols_y]]$min)
123+
ei_log = multiplicative_factor * ((exp(-y_best) * pnorm(d_norm)) - (exp((0.5 * se^2) - mu) * pnorm(d_norm - se)))
124+
}
125+
ei_log = ifelse(se < 1e-20 | is.na(ei_log), 0, ei_log)
126+
data.table(acq_ei_log = ei_log)
127+
}
128+
)
129+
)
130+
131+
mlr_acqfunctions$add("ei_log", AcqFunctionEILog)
132+

R/AcqFunctionEIPS.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,11 @@ AcqFunctionEIPS = R6Class("AcqFunctionEIPS",
7878
#' @description
7979
#' Update the acquisition function and set `y_best`.
8080
update = function() {
81-
self$y_best = min(self$surrogate_max_to_min[[self$col_y]] * self$archive$data[[self$col_y]])
81+
ys = self$archive$data[, self$surrogate$cols_y, with = FALSE]
82+
if (self$surrogate$output_trafo_must_be_considered) {
83+
ys = self$surrogate$output_trafo$transform(ys)
84+
}
85+
self$y_best = min(self$surrogate_max_to_min[[self$col_y]] * ys[[self$col_y]])
8286
}
8387
),
8488

R/AcqFunctionPI.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ AcqFunctionPI = R6Class("AcqFunctionPI",
6868
#' @description
6969
#' Update the acquisition function and set `y_best`.
7070
update = function() {
71-
self$y_best = min(self$surrogate_max_to_min * self$archive$data[[self$surrogate$cols_y]])
71+
y = self$archive$data[, self$surrogate$cols_y, with = FALSE]
72+
if (self$surrogate$output_trafo_must_be_considered) {
73+
y = self$surrogate$output_trafo$transform(y)
74+
}
75+
self$y_best = min(self$surrogate_max_to_min * y)
7276
}
7377
),
7478

R/AcqFunctionSmsEgo.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ AcqFunctionSmsEgo = R6Class("AcqFunctionSmsEgo",
116116

117117
n_obj = length(self$archive$cols_y)
118118
ys = self$archive$data[, self$archive$cols_y, with = FALSE]
119+
if (self$surrogate$output_trafo_must_be_considered) {
120+
ys = self$surrogate$output_trafo$transform(ys)
121+
}
119122
for (column in self$archive$cols_y) {
120123
set(ys, j = column, value = ys[[column]] * self$surrogate_max_to_min[[column]]) # assume minimization
121124
}

0 commit comments

Comments
 (0)