From d74217d350857fb60f477ddae45fc2f3ef6d1f1f Mon Sep 17 00:00:00 2001
From: Drew Herren
Date: Wed, 25 Jun 2025 15:23:39 -0500
Subject: [PATCH 1/2] Fixed two issues with GFR feature subset sampling in BCF
 and updated the causal inference R vignette

---
 R/bcf.R                          | 22 +++++++++++-----------
 include/stochtree/tree_sampler.h | 33 ++++++++++++++++++++++-----------
 vignettes/CausalInference.Rmd    |  4 ++--
 3 files changed, 35 insertions(+), 24 deletions(-)

diff --git a/R/bcf.R b/R/bcf.R
index 0b266909..acede05d 100644
--- a/R/bcf.R
+++ b/R/bcf.R
@@ -503,17 +503,6 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
     X_test_raw <- X_test
     if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata)
 
-    # Set num_features_subsample to default, ncol(X_train), if not already set
-    if (is.null(num_features_subsample_mu)) {
-        num_features_subsample_mu <- ncol(X_train)
-    }
-    if (is.null(num_features_subsample_tau)) {
-        num_features_subsample_tau <- ncol(X_train)
-    }
-    if (is.null(num_features_subsample_variance)) {
-        num_features_subsample_variance <- ncol(X_train)
-    }
-
     # Convert all input data to matrices if not already converted
     if ((is.null(dim(Z_train))) && (!is.null(Z_train))) {
         Z_train <- as.matrix(as.numeric(Z_train))
@@ -722,6 +711,17 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
         variable_weights_variance <- variable_weights_variance / sum(variable_weights_variance)
     }
 
+    # Set num_features_subsample to default, ncol(X_train), if not already set
+    if (is.null(num_features_subsample_mu)) {
+        num_features_subsample_mu <- ncol(X_train)
+    }
+    if (is.null(num_features_subsample_tau)) {
+        num_features_subsample_tau <- ncol(X_train)
+    }
+    if (is.null(num_features_subsample_variance)) {
+        num_features_subsample_variance <- ncol(X_train)
+    }
+
     # Preliminary runtime checks for probit link
     if (probit_outcome_model) {
         if (!(length(unique(y_train)) == 2)) {
diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h
index f07ed79f..8810b938 100644
--- a/include/stochtree/tree_sampler.h
+++ b/include/stochtree/tree_sampler.h
@@ -716,18 +716,28 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore
   // Subsample features (if requested)
   std::vector<bool> feature_subset(p, true);
   if (num_features_subsample < p) {
-    std::vector<int> feature_indices(p);
-    std::iota(feature_indices.begin(), feature_indices.end(), 0);
-    std::vector<int> features_selected(num_features_subsample);
-    sample_without_replacement(
-      features_selected.data(), variable_weights.data(), feature_indices.data(),
-      p, num_features_subsample, gen
-    );
-    for (int i = 0; i < p; i++) {
-      feature_subset.at(i) = false;
+    // Check if the number of (meaningfully) nonzero selection probabilities is greater than num_features_subsample
+    int number_nonzero_weights = 0;
+    for (int j = 0; j < p; j++) {
+      if (std::abs(variable_weights.at(j)) > kEpsilon) {
+        number_nonzero_weights++;
+      }
     }
-    for (const auto& feat : features_selected) {
-      feature_subset.at(feat) = true;
+    if (number_nonzero_weights > num_features_subsample) {
+      // Sample without replacement according to variable_weights
+      std::vector<int> feature_indices(p);
+      std::iota(feature_indices.begin(), feature_indices.end(), 0);
+      std::vector<int> features_selected(num_features_subsample);
+      sample_without_replacement(
+        features_selected.data(), variable_weights.data(), feature_indices.data(),
+        p, num_features_subsample, gen
+      );
+      for (int i = 0; i < p; i++) {
+        feature_subset.at(i) = false;
+      }
+      for (const auto& feat : features_selected) {
+        feature_subset.at(feat) = true;
+      }
     }
   }
 
@@ -782,6 +792,7 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore
  * \param pre_initialized Whether or not `active_forest` has already been initialized (note: this parameter will be refactored out soon).
  * \param backfitting Whether or not the sampler uses "backfitting" (wherein the sampler for a given tree only depends on the other trees via
  * their effect on the residual) or the more general "blocked MCMC" (wherein the state of other trees must be more explicitly considered).
+ * \param num_features_subsample How many features to subsample when running the GFR algorithm.
  * \param leaf_suff_stat_args Any arguments which must be supplied to initialize a `LeafSuffStat` object.
  */
 template 
diff --git a/vignettes/CausalInference.Rmd b/vignettes/CausalInference.Rmd
index 50f08069..166b24dd 100644
--- a/vignettes/CausalInference.Rmd
+++ b/vignettes/CausalInference.Rmd
@@ -949,7 +949,7 @@ sqrt(mean((test_outcome_mean - y_test)^2))
 
 #### MCMC, covariate subset in $\tau(X)$
 
-Here we simulate from the model with the original MCMC sampler, using only covariate $X_1$ in the treatment effect forest.
+Here we simulate from the model with the original MCMC sampler, using only covariates $X_1$ and $X_2$ in the treatment effect forest.
 
 ```{r}
 num_gfr <- 0
@@ -1073,7 +1073,7 @@ sqrt(mean((y_test - test_outcome_mean)^2))
 
 #### Warmstart, covariate subset in $\tau(X)$
 
-Here we simulate from the model with the warm-start sampler, using only covariate $X_1$ in the treatment effect forest.
+Here we simulate from the model with the warm-start sampler, using only covariates $X_1$ and $X_2$ in the treatment effect forest.
 
 ```{r}
 num_gfr <- 10

From 778d12c562cabd9aa0cb65e1851918fc543782c4 Mon Sep 17 00:00:00 2001
From: Drew Herren
Date: Wed, 25 Jun 2025 15:34:49 -0500
Subject: [PATCH 2/2] Reflected same change through to python BCF

---
 stochtree/bcf.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/stochtree/bcf.py b/stochtree/bcf.py
index 1fa7ac00..0ccd9e31 100644
--- a/stochtree/bcf.py
+++ b/stochtree/bcf.py
@@ -1093,14 +1093,6 @@ def sample(
         else:
             variable_subset_variance = [i for i in range(X_train.shape[1])]
 
-        # Set num_features_subsample to default, ncol(X_train), if not already set
-        if num_features_subsample_mu is None:
-            num_features_subsample_mu = X_train.shape[1]
-        if num_features_subsample_tau is None:
-            num_features_subsample_tau = X_train.shape[1]
-        if num_features_subsample_variance is None:
-            num_features_subsample_variance = X_train.shape[1]
-
         # Determine whether a test set is provided
         self.has_test = X_test is not None
 
@@ -1498,6 +1490,14 @@ def sample(
         # Store propensity score requirements of the BCF forests
         self.propensity_covariate = propensity_covariate
 
+        # Set num_features_subsample to default, X_train_processed.shape[1], if not already set
+        if num_features_subsample_mu is None:
+            num_features_subsample_mu = X_train_processed.shape[1]
+        if num_features_subsample_tau is None:
+            num_features_subsample_tau = X_train_processed.shape[1]
+        if num_features_subsample_variance is None:
+            num_features_subsample_variance = X_train_processed.shape[1]
+
         # Container of variance parameter samples
         self.num_gfr = num_gfr
         self.num_burnin = num_burnin
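Reviewer note: as a rough illustration of the guard introduced in the tree_sampler.h hunk above, the standalone Python sketch below mirrors the logic. It is not part of stochtree's API; numpy's weighted `choice` stands in for `sample_without_replacement`, and the `eps` argument stands in for `kEpsilon`.

```python
import numpy as np

def select_feature_subset(variable_weights, num_features_subsample, rng, eps=1e-16):
    """Sketch of the guarded GFR feature subsampling shown in the patch above."""
    variable_weights = np.asarray(variable_weights, dtype=float)
    p = variable_weights.shape[0]
    feature_subset = np.ones(p, dtype=bool)  # default: every feature eligible
    if num_features_subsample >= p:
        return feature_subset
    # Only subsample when more than num_features_subsample features carry
    # meaningfully nonzero selection probability; otherwise keep all features.
    num_nonzero = int(np.sum(np.abs(variable_weights) > eps))
    if num_nonzero > num_features_subsample:
        probs = variable_weights / variable_weights.sum()
        selected = rng.choice(p, size=num_features_subsample, replace=False, p=probs)
        feature_subset[:] = False
        feature_subset[selected] = True
    return feature_subset

# With only 3 of 10 weights nonzero, a requested subsample of 5 keeps all features;
# raising the nonzero count above 5 would trigger the weighted draw instead.
rng = np.random.default_rng(2025)
weights = np.array([0.5, 0.3, 0.2] + [0.0] * 7)
print(select_feature_subset(weights, 5, rng))
```

The guard presumably matters because drawing num_features_subsample indices without replacement from a weight vector with fewer meaningfully nonzero entries is ill-posed; in that case the patched code, like this sketch, simply keeps every feature.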