1111
1212from modAL .models import ActiveLearner
1313from modAL .utils .data import modALinput , data_vstack
14- from modAL .utils .selection import multi_argmax
14+ from modAL .utils .selection import multi_argmax , shuffled_argmax
1515from modAL .uncertainty import _proba_uncertainty , _proba_entropy
1616
1717
1818def expected_error_reduction (learner : ActiveLearner , X : modALinput , loss : str = 'binary' ,
19- p_subsample : np .float = 1.0 , n_instances : int = 1 ) -> Tuple [np .ndarray , modALinput ]:
19+ p_subsample : np .float = 1.0 , n_instances : int = 1 ,
20+ random_tie_break : bool = False ) -> Tuple [np .ndarray , modALinput ]:
2021 """
2122 Expected error reduction query strategy.
2223
@@ -32,6 +33,8 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
3233 calculating expected error. Significantly improves runtime
3334 for large sample pools.
3435 n_instances: The number of instances to be sampled.
36+ random_tie_break: If True, shuffles the utility scores to randomize their order before
37+ selection. This can be used to break ties at random when the highest utility score is not unique.
3538
3639
3740 Returns:
@@ -73,6 +76,9 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
7376 else :
7477 expected_error [x_idx ] = np .inf
7578
76- query_idx = multi_argmax (expected_error , n_instances )
79+ if not random_tie_break :
80+ query_idx = multi_argmax (expected_error , n_instances )
81+ else :
82+ query_idx = shuffled_argmax (expected_error , n_instances )
7783
7884 return query_idx , X [query_idx ]
0 commit comments