Skip to content

Commit adf891c

Browse files
authored
Merge pull request #181 from StochasticTree/bcf-feature-subsets-hotfix
Hotfix for two issues with GFR feature subset sampling in BCF
2 parents da60316 + 778d12c commit adf891c

File tree

4 files changed

+43
-32
lines changed

4 files changed

+43
-32
lines changed

R/bcf.R

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -503,17 +503,6 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
503503
X_test_raw <- X_test
504504
if (!is.null(X_test)) X_test <- preprocessPredictionData(X_test, X_train_metadata)
505505

506-
# Set num_features_subsample to default, ncol(X_train), if not already set
507-
if (is.null(num_features_subsample_mu)) {
508-
num_features_subsample_mu <- ncol(X_train)
509-
}
510-
if (is.null(num_features_subsample_tau)) {
511-
num_features_subsample_tau <- ncol(X_train)
512-
}
513-
if (is.null(num_features_subsample_variance)) {
514-
num_features_subsample_variance <- ncol(X_train)
515-
}
516-
517506
# Convert all input data to matrices if not already converted
518507
if ((is.null(dim(Z_train))) && (!is.null(Z_train))) {
519508
Z_train <- as.matrix(as.numeric(Z_train))
@@ -722,6 +711,17 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
722711
variable_weights_variance <- variable_weights_variance / sum(variable_weights_variance)
723712
}
724713

714+
# Set num_features_subsample to default, ncol(X_train), if not already set
715+
if (is.null(num_features_subsample_mu)) {
716+
num_features_subsample_mu <- ncol(X_train)
717+
}
718+
if (is.null(num_features_subsample_tau)) {
719+
num_features_subsample_tau <- ncol(X_train)
720+
}
721+
if (is.null(num_features_subsample_variance)) {
722+
num_features_subsample_variance <- ncol(X_train)
723+
}
724+
725725
# Preliminary runtime checks for probit link
726726
if (probit_outcome_model) {
727727
if (!(length(unique(y_train)) == 2)) {

include/stochtree/tree_sampler.h

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -716,18 +716,28 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore
716716
// Subsample features (if requested)
717717
std::vector<bool> feature_subset(p, true);
718718
if (num_features_subsample < p) {
719-
std::vector<int> feature_indices(p);
720-
std::iota(feature_indices.begin(), feature_indices.end(), 0);
721-
std::vector<int> features_selected(num_features_subsample);
722-
sample_without_replacement<int, double>(
723-
features_selected.data(), variable_weights.data(), feature_indices.data(),
724-
p, num_features_subsample, gen
725-
);
726-
for (int i = 0; i < p; i++) {
727-
feature_subset.at(i) = false;
719+
// Check if the number of (meaningfully) nonzero selection probabilities is greater than num_features_subsample
720+
int number_nonzero_weights = 0;
721+
for (int j = 0; j < p; j++) {
722+
if (std::abs(variable_weights.at(j)) > kEpsilon) {
723+
number_nonzero_weights++;
724+
}
728725
}
729-
for (const auto& feat : features_selected) {
730-
feature_subset.at(feat) = true;
726+
if (number_nonzero_weights > num_features_subsample) {
727+
// Sample without replacement according to variable_weights
728+
std::vector<int> feature_indices(p);
729+
std::iota(feature_indices.begin(), feature_indices.end(), 0);
730+
std::vector<int> features_selected(num_features_subsample);
731+
sample_without_replacement<int, double>(
732+
features_selected.data(), variable_weights.data(), feature_indices.data(),
733+
p, num_features_subsample, gen
734+
);
735+
for (int i = 0; i < p; i++) {
736+
feature_subset.at(i) = false;
737+
}
738+
for (const auto& feat : features_selected) {
739+
feature_subset.at(feat) = true;
740+
}
731741
}
732742
}
733743

@@ -782,6 +792,7 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore
782792
* \param pre_initialized Whether or not `active_forest` has already been initialized (note: this parameter will be refactored out soon).
783793
* \param backfitting Whether or not the sampler uses "backfitting" (wherein the sampler for a given tree only depends on the other trees via
784794
* their effect on the residual) or the more general "blocked MCMC" (wherein the state of other trees must be more explicitly considered).
795+
* \param num_features_subsample How many features to subsample when running the GFR algorithm.
785796
* \param leaf_suff_stat_args Any arguments which must be supplied to initialize a `LeafSuffStat` object.
786797
*/
787798
template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatConstructorArgs>

stochtree/bcf.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,14 +1093,6 @@ def sample(
10931093
else:
10941094
variable_subset_variance = [i for i in range(X_train.shape[1])]
10951095

1096-
# Set num_features_subsample to default, ncol(X_train), if not already set
1097-
if num_features_subsample_mu is None:
1098-
num_features_subsample_mu = X_train.shape[1]
1099-
if num_features_subsample_tau is None:
1100-
num_features_subsample_tau = X_train.shape[1]
1101-
if num_features_subsample_variance is None:
1102-
num_features_subsample_variance = X_train.shape[1]
1103-
11041096
# Determine whether a test set is provided
11051097
self.has_test = X_test is not None
11061098

@@ -1498,6 +1490,14 @@ def sample(
14981490
# Store propensity score requirements of the BCF forests
14991491
self.propensity_covariate = propensity_covariate
15001492

1493+
# Set num_features_subsample to default, X_train_processed.shape[1], if not already set
1494+
if num_features_subsample_mu is None:
1495+
num_features_subsample_mu = X_train_processed.shape[1]
1496+
if num_features_subsample_tau is None:
1497+
num_features_subsample_tau = X_train_processed.shape[1]
1498+
if num_features_subsample_variance is None:
1499+
num_features_subsample_variance = X_train_processed.shape[1]
1500+
15011501
# Container of variance parameter samples
15021502
self.num_gfr = num_gfr
15031503
self.num_burnin = num_burnin

vignettes/CausalInference.Rmd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,7 @@ sqrt(mean((test_outcome_mean - y_test)^2))
949949

950950
#### MCMC, covariate subset in $\tau(X)$
951951

952-
Here we simulate from the model with the original MCMC sampler, using only covariate $X_1$ in the treatment effect forest.
952+
Here we simulate from the model with the original MCMC sampler, using only covariates $X_1$ and $X_2$ in the treatment effect forest.
953953

954954
```{r}
955955
num_gfr <- 0
@@ -1073,7 +1073,7 @@ sqrt(mean((y_test - test_outcome_mean)^2))
10731073

10741074
#### Warmstart, covariate subset in $\tau(X)$
10751075

1076-
Here we simulate from the model with the warm-start sampler, using only covariate $X_1$ in the treatment effect forest.
1076+
Here we simulate from the model with the warm-start sampler, using only covariates $X_1$ and $X_2$ in the treatment effect forest.
10771077

10781078
```{r}
10791079
num_gfr <- 10

0 commit comments

Comments
 (0)