Skip to content

Hotfix for two issues with GFR feature subset sampling in BCF #181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -503,17 +503,6 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
X_test_raw <- X_test
if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata)

# Set num_features_subsample to default, ncol(X_train), if not already set
if (is.null(num_features_subsample_mu)) {
num_features_subsample_mu <- ncol(X_train)
}
if (is.null(num_features_subsample_tau)) {
num_features_subsample_tau <- ncol(X_train)
}
if (is.null(num_features_subsample_variance)) {
num_features_subsample_variance <- ncol(X_train)
}

# Convert all input data to matrices if not already converted
if ((is.null(dim(Z_train))) && (!is.null(Z_train))) {
Z_train <- as.matrix(as.numeric(Z_train))
Expand Down Expand Up @@ -722,6 +711,17 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
variable_weights_variance <- variable_weights_variance / sum(variable_weights_variance)
}

# Set num_features_subsample to default, ncol(X_train), if not already set
if (is.null(num_features_subsample_mu)) {
num_features_subsample_mu <- ncol(X_train)
}
if (is.null(num_features_subsample_tau)) {
num_features_subsample_tau <- ncol(X_train)
}
if (is.null(num_features_subsample_variance)) {
num_features_subsample_variance <- ncol(X_train)
}

# Preliminary runtime checks for probit link
if (probit_outcome_model) {
if (!(length(unique(y_train)) == 2)) {
Expand Down
33 changes: 22 additions & 11 deletions include/stochtree/tree_sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -716,18 +716,28 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore
// Subsample features (if requested)
std::vector<bool> feature_subset(p, true);
if (num_features_subsample < p) {
std::vector<int> feature_indices(p);
std::iota(feature_indices.begin(), feature_indices.end(), 0);
std::vector<int> features_selected(num_features_subsample);
sample_without_replacement<int, double>(
features_selected.data(), variable_weights.data(), feature_indices.data(),
p, num_features_subsample, gen
);
for (int i = 0; i < p; i++) {
feature_subset.at(i) = false;
// Check if the number of (meaningfully) nonzero selection probabilities is greater than num_features_subsample
int number_nonzero_weights = 0;
for (int j = 0; j < p; j++) {
if (std::abs(variable_weights.at(j)) > kEpsilon) {
number_nonzero_weights++;
}
}
for (const auto& feat : features_selected) {
feature_subset.at(feat) = true;
if (number_nonzero_weights > num_features_subsample) {
// Sample without replacement according to variable_weights
std::vector<int> feature_indices(p);
std::iota(feature_indices.begin(), feature_indices.end(), 0);
std::vector<int> features_selected(num_features_subsample);
sample_without_replacement<int, double>(
features_selected.data(), variable_weights.data(), feature_indices.data(),
p, num_features_subsample, gen
);
for (int i = 0; i < p; i++) {
feature_subset.at(i) = false;
}
for (const auto& feat : features_selected) {
feature_subset.at(feat) = true;
}
}
}

Expand Down Expand Up @@ -782,6 +792,7 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore
* \param pre_initialized Whether or not `active_forest` has already been initialized (note: this parameter will be refactored out soon).
* \param backfitting Whether or not the sampler uses "backfitting" (wherein the sampler for a given tree only depends on the other trees via
* their effect on the residual) or the more general "blocked MCMC" (wherein the state of other trees must be more explicitly considered).
* \param num_features_subsample How many features to subsample when running the GFR algorithm.
* \param leaf_suff_stat_args Any arguments which must be supplied to initialize a `LeafSuffStat` object.
*/
template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>
Expand Down
16 changes: 8 additions & 8 deletions stochtree/bcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,14 +1093,6 @@ def sample(
else:
variable_subset_variance = [i for i in range(X_train.shape[1])]

# Set num_features_subsample to default, ncol(X_train), if not already set
if num_features_subsample_mu is None:
num_features_subsample_mu = X_train.shape[1]
if num_features_subsample_tau is None:
num_features_subsample_tau = X_train.shape[1]
if num_features_subsample_variance is None:
num_features_subsample_variance = X_train.shape[1]

# Determine whether a test set is provided
self.has_test = X_test is not None

Expand Down Expand Up @@ -1498,6 +1490,14 @@ def sample(
# Store propensity score requirements of the BCF forests
self.propensity_covariate = propensity_covariate

# Set num_features_subsample to default, the number of columns of X_train_processed, if not already set
if num_features_subsample_mu is None:
num_features_subsample_mu = X_train_processed.shape[1]
if num_features_subsample_tau is None:
num_features_subsample_tau = X_train_processed.shape[1]
if num_features_subsample_variance is None:
num_features_subsample_variance = X_train_processed.shape[1]

# Container of variance parameter samples
self.num_gfr = num_gfr
self.num_burnin = num_burnin
Expand Down
4 changes: 2 additions & 2 deletions vignettes/CausalInference.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ sqrt(mean((test_outcome_mean - y_test)^2))

#### MCMC, covariate subset in $\tau(X)$

Here we simulate from the model with the original MCMC sampler, using only covariate $X_1$ in the treatment effect forest.
Here we simulate from the model with the original MCMC sampler, using only covariates $X_1$ and $X_2$ in the treatment effect forest.

```{r}
num_gfr <- 0
Expand Down Expand Up @@ -1073,7 +1073,7 @@ sqrt(mean((y_test - test_outcome_mean)^2))

#### Warmstart, covariate subset in $\tau(X)$

Here we simulate from the model with the warm-start sampler, using only covariate $X_1$ in the treatment effect forest.
Here we simulate from the model with the warm-start sampler, using only covariates $X_1$ and $X_2$ in the treatment effect forest.

```{r}
num_gfr <- 10
Expand Down
Loading