Skip to content

Commit d5b394b

Browse files
authored
Merge pull request #1085 from katrinabrock/add-exe-helpers
Add parsers for cpp opts
2 parents 50b6e8d + dde1164 commit d5b394b

File tree

5 files changed

+290
-90
lines changed

5 files changed

+290
-90
lines changed

R/cpp_opts.R

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# Internal functions for handling cpp options
2+
3+
# running and parsing exe info --------------------------------
4+
# run <model> info command
5+
#' @example `.cmdstan/bin`
6+
run_info_cli <- function(exe_file) {
7+
withr::with_path(
8+
c(
9+
toolchain_PATH_env_var(),
10+
tbb_path()
11+
),
12+
ret <- wsl_compatible_run(
13+
command = wsl_safe_path(exe_file),
14+
args = "info",
15+
echo = is_verbose_mode(),
16+
error_on_status = FALSE
17+
)
18+
)
19+
ret
20+
}
21+
22+
# new (future) parser
23+
# Parse the string output of <model> `info` into an R object (list)
24+
parse_exe_info_string <- function(ret_stdout) {
25+
info <- list()
26+
info_raw <- strsplit(strsplit(ret_stdout, "\n")[[1]], "=")
27+
for (key_val in info_raw) {
28+
if (length(key_val) > 1) {
29+
key_val <- trimws(key_val)
30+
val <- key_val[2]
31+
if (!is.na(as.logical(val))) {
32+
val <- as.logical(val)
33+
}
34+
info[[tolower(key_val[1])]] <- val
35+
}
36+
}
37+
38+
info[["stan_version"]] <- paste0(
39+
info[["stan_version_major"]],
40+
".",
41+
info[["stan_version_minor"]],
42+
".", info[["stan_version_patch"]]
43+
)
44+
info[["stan_version_major"]] <- NULL
45+
info[["stan_version_minor"]] <- NULL
46+
info[["stan_version_patch"]] <- NULL
47+
48+
info
49+
}
50+
51+
# old (current) parser
52+
model_compile_info <- function(exe_file, version) {
53+
info <- NULL
54+
if (version > "2.26.1") {
55+
56+
ret <- run_info_cli(exe_file)
57+
if (ret$status == 0) {
58+
info <- list()
59+
info_raw <- strsplit(strsplit(ret$stdout, "\n")[[1]], "=")
60+
for (key_val in info_raw) {
61+
if (length(key_val) > 1) {
62+
key_val <- trimws(key_val)
63+
val <- key_val[2]
64+
if (!is.na(as.logical(val))) {
65+
val <- as.logical(val)
66+
}
67+
info[[toupper(key_val[1])]] <- val
68+
}
69+
}
70+
info[["STAN_VERSION"]] <- paste0(info[["STAN_VERSION_MAJOR"]], ".", info[["STAN_VERSION_MINOR"]], ".", info[["STAN_VERSION_PATCH"]])
71+
info[["STAN_VERSION_MAJOR"]] <- NULL
72+
info[["STAN_VERSION_MINOR"]] <- NULL
73+
info[["STAN_VERSION_PATCH"]] <- NULL
74+
}
75+
}
76+
info
77+
}
78+
79+
# convert to compile flags --------------------
80+
# from list(flag1=TRUE, flag2=FALSE) to "FLAG1=TRUE\nFLAG2=FALSE"
81+
cpp_options_to_compile_flags <- function(cpp_options) {
82+
if (length(cpp_options) == 0) {
83+
return(NULL)
84+
}
85+
cpp_built_options <- c()
86+
for (i in seq_along(cpp_options)) {
87+
option_name <- names(cpp_options)[i]
88+
if (is.null(option_name) || !nzchar(option_name)) {
89+
cpp_built_options <- c(cpp_built_options, cpp_options[[i]])
90+
} else {
91+
cpp_built_options <- c(cpp_built_options, paste0(toupper(option_name), "=", cpp_options[[i]]))
92+
}
93+
}
94+
cpp_built_options
95+
}
96+
97+
98+
# check options overall for validity ---------------------------------
99+
# takes list of options as input and returns list of options
100+
# returns list with names standardized to lowercase
101+
validate_cpp_options <- function(cpp_options) {
102+
if (is.null(cpp_options) || length(cpp_options) == 0) return(list())
103+
104+
if (
105+
!is.null(cpp_options[["user_header"]]) &&
106+
!is.null(cpp_options[["USER_HEADER"]])
107+
) {
108+
warning(
109+
"User header specified both via cpp_options[[\"USER_HEADER\"]] ",
110+
"and cpp_options[[\"user_header\"]]. Please only specify your user header in one location",
111+
call. = FALSE
112+
)
113+
}
114+
115+
names(cpp_options) <- tolower(names(cpp_options))
116+
flags_set_if_defined <- c(
117+
# cmdstan
118+
"stan_threads", "stan_mpi", "stan_opencl",
119+
"stan_no_range_checks", "stan_cpp_optims",
120+
# stan math
121+
"integrated_opencl", "tbb_lib", "tbb_inc", "tbb_interface_new"
122+
)
123+
for (flag in flags_set_if_defined) {
124+
if (isFALSE(cpp_options[[flag]])) warning(
125+
toupper(flag), " set to ", cpp_options[flag],
126+
" Since this is a non-empty value, ",
127+
"it will result in the corresponding ccp option being turned ON. To turn this",
128+
" option off, use cpp_options = list(", flag, " = NULL)."
129+
)
130+
}
131+
cpp_options
132+
}
133+
134+
# check specific options for validity ---------------------------------
135+
# no type checking for opencl_ids
136+
# cpp_options must be a list
137+
# opencl_ids returned unchanged
138+
assert_valid_opencl <- function(opencl_ids, cpp_options) {
139+
if (is.null(cpp_options[["stan_opencl"]])
140+
&& !is.null(opencl_ids)) {
141+
stop("'opencl_ids' is set but the model was not compiled for use with OpenCL.",
142+
"\nRecompile the model with 'cpp_options = list(stan_opencl = TRUE)'",
143+
call. = FALSE)
144+
}
145+
invisible(opencl_ids)
146+
}
147+
148+
# cpp_options must be a list
149+
assert_valid_threads <- function(threads, cpp_options, multiple_chains = FALSE) {
150+
threads_arg <- if (multiple_chains) "threads_per_chain" else "threads"
151+
checkmate::assert_integerish(threads, .var.name = threads_arg,
152+
null.ok = TRUE, lower = 1, len = 1)
153+
if (is.null(cpp_options[["stan_threads"]]) || !isTRUE(cpp_options[["stan_threads"]])) {
154+
if (!is.null(threads)) {
155+
warning(
156+
"'", threads_arg, "' is set but the model was not compiled with ",
157+
"'cpp_options = list(stan_threads = TRUE)' ",
158+
"so '", threads_arg, "' will have no effect!",
159+
call. = FALSE
160+
)
161+
threads <- NULL
162+
}
163+
} else if (isTRUE(cpp_options[["stan_threads"]]) && is.null(threads)) {
164+
stop(
165+
"The model was compiled with 'cpp_options = list(stan_threads = TRUE)' ",
166+
"but '", threads_arg, "' was not set!",
167+
call. = FALSE
168+
)
169+
}
170+
invisible(threads)
171+
}
172+
173+
# For two functions below
174+
# both styles are lists which should have flag names in lower case as names of the list
175+
# cpp_options style means is NULL or empty string
176+
# exe_info style means off is FALSE
177+
178+
exe_info_style_cpp_options <- function(cpp_options) {
179+
if(is.null(cpp_options)) cpp_options <- list()
180+
names(cpp_options) <- tolower(names(cpp_options))
181+
flags_reported_in_exe_info <- c(
182+
"stan_threads", "stan_mpi", "stan_opencl",
183+
"stan_no_range_checks", "stan_cpp_optims"
184+
)
185+
for (flag in flags_reported_in_exe_info) {
186+
cpp_options[[flag]] <- !(
187+
is.null(cpp_options[[flag]]) || cpp_options[[flag]] == ""
188+
)
189+
}
190+
cpp_options
191+
}
192+
193+
exe_info_reflects_cpp_options <- function(exe_info, cpp_options) {
194+
if (length(exe_info) == 0) {
195+
warning("Recompiling is recommended due to missing exe_info.")
196+
return(TRUE)
197+
}
198+
if (is.null(cpp_options)) return(TRUE)
199+
200+
cpp_options <- exe_info_style_cpp_options(cpp_options)[tolower(names(cpp_options))]
201+
overlap <- names(cpp_options)[names(cpp_options) %in% names(exe_info)]
202+
203+
if (length(overlap) == 0) TRUE else all.equal(
204+
exe_info[overlap],
205+
cpp_options[overlap]
206+
)
207+
}

R/model.R

Lines changed: 0 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -2252,41 +2252,6 @@ CmdStanModel$set("public", name = "expose_functions", value = expose_functions)
22522252

22532253

22542254
# internal ----------------------------------------------------------------
2255-
2256-
assert_valid_opencl <- function(opencl_ids, cpp_options) {
2257-
if (is.null(cpp_options[["stan_opencl"]])
2258-
&& !is.null(opencl_ids)) {
2259-
stop("'opencl_ids' is set but the model was not compiled with for use with OpenCL.",
2260-
"\nRecompile the model with 'cpp_options = list(stan_opencl = TRUE)'",
2261-
call. = FALSE)
2262-
}
2263-
invisible(opencl_ids)
2264-
}
2265-
2266-
assert_valid_threads <- function(threads, cpp_options, multiple_chains = FALSE) {
2267-
threads_arg <- if (multiple_chains) "threads_per_chain" else "threads"
2268-
checkmate::assert_integerish(threads, .var.name = threads_arg,
2269-
null.ok = TRUE, lower = 1, len = 1)
2270-
if (is.null(cpp_options[["stan_threads"]]) || !isTRUE(cpp_options[["stan_threads"]])) {
2271-
if (!is.null(threads)) {
2272-
warning(
2273-
"'", threads_arg, "' is set but the model was not compiled with ",
2274-
"'cpp_options = list(stan_threads = TRUE)' ",
2275-
"so '", threads_arg, "' will have no effect!",
2276-
call. = FALSE
2277-
)
2278-
threads <- NULL
2279-
}
2280-
} else if (isTRUE(cpp_options[["stan_threads"]]) && is.null(threads)) {
2281-
stop(
2282-
"The model was compiled with 'cpp_options = list(stan_threads = TRUE)' ",
2283-
"but '", threads_arg, "' was not set!",
2284-
call. = FALSE
2285-
)
2286-
}
2287-
invisible(threads)
2288-
}
2289-
22902255
assert_valid_stanc_options <- function(stanc_options) {
22912256
i <- 1
22922257
names <- names(stanc_options)
@@ -2313,22 +2278,6 @@ assert_stan_file_exists <- function(stan_file) {
23132278
}
23142279
}
23152280

2316-
cpp_options_to_compile_flags <- function(cpp_options) {
2317-
if (length(cpp_options) == 0) {
2318-
return(NULL)
2319-
}
2320-
cpp_built_options <- c()
2321-
for (i in seq_along(cpp_options)) {
2322-
option_name <- names(cpp_options)[i]
2323-
if (is.null(option_name) || !nzchar(option_name)) {
2324-
cpp_built_options <- c(cpp_built_options, cpp_options[[i]])
2325-
} else {
2326-
cpp_built_options <- c(cpp_built_options, paste0(toupper(option_name), "=", cpp_options[[i]]))
2327-
}
2328-
}
2329-
cpp_built_options
2330-
}
2331-
23322281
include_paths_stanc3_args <- function(include_paths = NULL, standalone_call = FALSE) {
23332282
stancflags <- NULL
23342283
if (!is.null(include_paths)) {
@@ -2385,41 +2334,6 @@ model_variables <- function(stan_file, include_paths = NULL, allow_undefined = F
23852334
variables
23862335
}
23872336

2388-
model_compile_info <- function(exe_file, version) {
2389-
info <- NULL
2390-
if (version > "2.26.1") {
2391-
withr::with_path(
2392-
c(
2393-
toolchain_PATH_env_var(),
2394-
tbb_path()
2395-
),
2396-
ret <- wsl_compatible_run(
2397-
command = wsl_safe_path(exe_file),
2398-
args = "info",
2399-
error_on_status = FALSE
2400-
)
2401-
)
2402-
if (ret$status == 0) {
2403-
info <- list()
2404-
info_raw <- strsplit(strsplit(ret$stdout, "\n")[[1]], "=")
2405-
for (key_val in info_raw) {
2406-
if (length(key_val) > 1) {
2407-
key_val <- trimws(key_val)
2408-
val <- key_val[2]
2409-
if (!is.na(as.logical(val))) {
2410-
val <- as.logical(val)
2411-
}
2412-
info[[toupper(key_val[1])]] <- val
2413-
}
2414-
}
2415-
info[["STAN_VERSION"]] <- paste0(info[["STAN_VERSION_MAJOR"]], ".", info[["STAN_VERSION_MINOR"]], ".", info[["STAN_VERSION_PATCH"]])
2416-
info[["STAN_VERSION_MAJOR"]] <- NULL
2417-
info[["STAN_VERSION_MINOR"]] <- NULL
2418-
info[["STAN_VERSION_PATCH"]] <- NULL
2419-
}
2420-
}
2421-
info
2422-
}
24232337

24242338
is_variables_method_supported <- function(mod) {
24252339
cmdstan_version() >= "2.27.0" && mod$has_stan_file() && file.exists(mod$stan_file())

tests/testthat/helper-custom-expectations.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,11 @@ expect_noninteractive_silent <- function(object) {
100100
rlang::with_interactive(value = FALSE,
101101
expect_silent(object))
102102
}
103+
104+
expect_equal_ignore_order <- function(object, expected, ...) {
105+
object <- expected[sort(names(object))]
106+
expected <- expected[sort(names(expected))]
107+
expect_equal(object, expected, ...)
108+
}
109+
110+
expect_not_true <- function(...) expect_false(isTRUE(...))

tests/testthat/test-cpp_opts.R

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
test_that("parse_exe_info_string works", {
2+
expect_equal_ignore_order(
3+
parse_exe_info_string("
4+
stan_version_major = 2
5+
stan_version_minor = 38
6+
stan_version_patch = 0
7+
STAN_THREADS=false
8+
STAN_MPI=false
9+
STAN_OPENCL=true
10+
STAN_NO_RANGE_CHECKS=false
11+
STAN_CPP_OPTIMS=false
12+
"),
13+
list(
14+
stan_version = "2.38.0",
15+
stan_threads = FALSE,
16+
stan_mpi = FALSE,
17+
stan_opencl = TRUE,
18+
stan_no_range_checks = FALSE,
19+
stan_cpp_optims = FALSE
20+
)
21+
)
22+
})
23+
24+
test_that("validate_cpp_options works", {
25+
expect_equal_ignore_order(
26+
validate_cpp_options(list(
27+
Stan_Threads = TRUE,
28+
STAN_OPENCL = NULL,
29+
aBc = FALSE
30+
)),
31+
list(
32+
stan_threads = TRUE,
33+
stan_opencl = NULL,
34+
abc = FALSE
35+
)
36+
)
37+
expect_warning(validate_cpp_options(list(STAN_OPENCL = FALSE)))
38+
})
39+
40+
41+
test_that("exe_info cpp_options comparison works", {
42+
exe_info_all_flags_off <- exe_info_style_cpp_options(list())
43+
exe_info_all_flags_off[["stan_version"]] <- "35.0.0"
44+
45+
expect_true(exe_info_reflects_cpp_options(
46+
exe_info_all_flags_off,
47+
list()
48+
))
49+
expect_true(exe_info_reflects_cpp_options(
50+
list(stan_opencl = FALSE),
51+
list(stan_opencl = NULL)
52+
))
53+
expect_not_true(exe_info_reflects_cpp_options(
54+
list(stan_opencl = FALSE),
55+
list(stan_opencl = FALSE)
56+
))
57+
expect_not_true(exe_info_reflects_cpp_options(
58+
list(stan_opencl = FALSE, stan_threads = FALSE),
59+
list(stan_opencl = NULL, stan_threads = TRUE)
60+
))
61+
expect_not_true(exe_info_reflects_cpp_options(
62+
list(stan_opencl = FALSE, stan_threads = FALSE),
63+
list(stan_opencl = NULL, stan_threads = TRUE, EXTRA_ARG = TRUE)
64+
))
65+
66+
# no exe_info -> no recompile based on cpp info
67+
expect_warning(
68+
expect_true(exe_info_reflects_cpp_options(list(), list())),
69+
"Recompiling is recommended"
70+
)
71+
})

0 commit comments

Comments
 (0)