@@ -1093,14 +1093,6 @@ def sample(
         else:
             variable_subset_variance = [i for i in range(X_train.shape[1])]
 
-        # Set num_features_subsample to default, ncol(X_train), if not already set
-        if num_features_subsample_mu is None:
-            num_features_subsample_mu = X_train.shape[1]
-        if num_features_subsample_tau is None:
-            num_features_subsample_tau = X_train.shape[1]
-        if num_features_subsample_variance is None:
-            num_features_subsample_variance = X_train.shape[1]
-
         # Determine whether a test set is provided
         self.has_test = X_test is not None
 
@@ -1498,6 +1490,14 @@ def sample(
         # Store propensity score requirements of the BCF forests
         self.propensity_covariate = propensity_covariate
 
+        # Set num_features_subsample to default, ncol(X_train), if not already set
+        if num_features_subsample_mu is None:
+            num_features_subsample_mu = X_train_processed.shape[1]
+        if num_features_subsample_tau is None:
+            num_features_subsample_tau = X_train_processed.shape[1]
+        if num_features_subsample_variance is None:
+            num_features_subsample_variance = X_train_processed.shape[1]
+
         # Container of variance parameter samples
         self.num_gfr = num_gfr
         self.num_burnin = num_burnin