@@ -716,18 +716,28 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore
716
716
// Subsample features (if requested)
717
717
std::vector<bool > feature_subset (p, true );
718
718
if (num_features_subsample < p) {
719
- std::vector<int > feature_indices (p);
720
- std::iota (feature_indices.begin (), feature_indices.end (), 0 );
721
- std::vector<int > features_selected (num_features_subsample);
722
- sample_without_replacement<int , double >(
723
- features_selected.data (), variable_weights.data (), feature_indices.data (),
724
- p, num_features_subsample, gen
725
- );
726
- for (int i = 0 ; i < p; i++) {
727
- feature_subset.at (i) = false ;
719
+ // Check if the number of (meaningfully) nonzero selection probabilities is greater than num_features_subsample
720
+ int number_nonzero_weights = 0 ;
721
+ for (int j = 0 ; j < p; j++) {
722
+ if (std::abs (variable_weights.at (j)) > kEpsilon ) {
723
+ number_nonzero_weights++;
724
+ }
728
725
}
729
- for (const auto & feat : features_selected) {
730
- feature_subset.at (feat) = true ;
726
+ if (number_nonzero_weights > num_features_subsample) {
727
+ // Sample with replacement according to variable_weights
728
+ std::vector<int > feature_indices (p);
729
+ std::iota (feature_indices.begin (), feature_indices.end (), 0 );
730
+ std::vector<int > features_selected (num_features_subsample);
731
+ sample_without_replacement<int , double >(
732
+ features_selected.data (), variable_weights.data (), feature_indices.data (),
733
+ p, num_features_subsample, gen
734
+ );
735
+ for (int i = 0 ; i < p; i++) {
736
+ feature_subset.at (i) = false ;
737
+ }
738
+ for (const auto & feat : features_selected) {
739
+ feature_subset.at (feat) = true ;
740
+ }
731
741
}
732
742
}
733
743
@@ -782,6 +792,7 @@ static inline void GFRSampleTreeOneIter(Tree* tree, ForestTracker& tracker, Fore
782
792
* \param pre_initialized Whether or not `active_forest` has already been initialized (note: this parameter will be refactored out soon).
783
793
* \param backfitting Whether or not the sampler uses "backfitting" (wherein the sampler for a given tree only depends on the other trees via
784
794
* their effect on the residual) or the more general "blocked MCMC" (wherein the state of other trees must be more explicitly considered).
795
+ * \param num_features_subsample How many features to subsample when running the GFR algorithm.
785
796
* \param leaf_suff_stat_args Any arguments which must be supplied to initialize a `LeafSuffStat` object.
786
797
*/
787
798
template <typename LeafModel, typename LeafSuffStat, typename ... LeafSuffStatConstructorArgs>
0 commit comments