Skip to content

Commit 94a9dcf

Browse files
badisa and mikemhenry
authored
Make MBAR bootstrapping deterministic independent of verbose flag (#551)
* Add failing test
* Makes MBAR deterministic independent of verbose flag
* Minor clean up
* Fix windoze tests
* run black

---------

Co-authored-by: Mike Henry <11765982+mikemhenry@users.noreply.github.com>
1 parent 7796672 commit 94a9dcf

File tree

2 files changed

+44
-26
lines changed

2 files changed

+44
-26
lines changed

pymbar/mbar.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,10 @@ def __init__(
186186
We usually just do steps of adaptive sampling without. "robust" would be the backup.
187187
Default: dict(method="adaptive", options=dict(min_sc_iter=0)),
188188
189+
rseed: int or None, optional, default=None
190+
Seed to use when constructing a new RNGState. If rseed is None, will use the global
191+
np.random RNG to generate a seed.
192+
189193
Notes
190194
-----
191195
The reduced potential energy ``u_kn[k,n] = u_k(x_{ln})``, where the reduced potential energy ``u_l(x)`` is
@@ -266,30 +270,34 @@ def __init__(
266270
# verbosity level -- if True, will print extra debug information
267271
self.verbose = verbose
268272

273+
if rseed is None:
274+
rseed = np.random.randint(np.iinfo(np.int32).max)
275+
self.rng = np.random.default_rng(rseed)
276+
269277
# perform consistency checks on the data.
270278

271279
self.samestates = []
280+
# if, for any set of data, all reduced potential energies are the same,
281+
# they are probably the same state.
282+
#
283+
# This can take quite a while, so we do it on just
284+
# the first few data points.
285+
#
286+
# determine if there are less than 50 points to compare energies.
287+
# If so, use that number instead of 50.
288+
# (the number 50 is pretty arbitrary)
289+
290+
maxpoint = 50
291+
if self.N < maxpoint:
292+
maxpoint = self.N
293+
# pick random indices
294+
# indices = np.arange(maxpoint) # can uncomment this if need to remove random choices in testing.
295+
# Done outside of the verbose if statement to ensure deterministic results of bootstrapping, independent of
296+
# the Verbose flag
297+
indices = self.rng.choice(np.arange(self.N), maxpoint)
298+
# this could possibly be made faster with np.unique(axis=0,return_indices=True)
299+
# but not clear if needed.
272300
if self.verbose:
273-
# if, for any set of data, all reduced potential energies are the same,
274-
# they are probably the same state.
275-
#
276-
# This can take quite a while, so we do it on just
277-
# the first few data points.
278-
#
279-
# determine if there are less than 50 points to compare energies.
280-
# If so, use that number instead of 50.
281-
# (the number 50 is pretty arbitrary)
282-
283-
maxpoint = 50
284-
if self.N < maxpoint:
285-
maxpoint = self.N
286-
287-
# this could possibly be made faster with np.unique(axis=0,return_indices=True)
288-
# but not clear if needed.
289-
290-
# pick random indices
291-
# indices = np.arange(maxpoint) # can uncomment this if need to remove random choices in testing.
292-
indices = np.random.choice(np.arange(self.N), maxpoint)
293301
for k in range(K):
294302
for l in range(k):
295303
diffsum = 0
@@ -402,11 +410,6 @@ def __init__(
402410
elif pname == "bootstrap_solver_protocol":
403411
bootstrap_solver_protocol = prot
404412

405-
if rseed != None:
406-
self.rstate = np.random.get_state()
407-
else:
408-
np.random.seed(rseed)
409-
410413
self.f_k = mbar_solvers.solve_mbar_for_all_states(
411414
self.u_kn, self.N_k, self.f_k, self.states_with_samples, solver_protocol
412415
)
@@ -426,7 +429,7 @@ def __init__(
426429
# which of the indices are equal to K
427430
k_indices = np.where(self.x_kindices == k)[0]
428431
# pick new random ones, selected of these K.
429-
new_kindices = k_indices[np.random.randint(int(N_k[k]), size=int(N_k[k]))]
432+
new_kindices = k_indices[self.rng.integers(int(N_k[k]), size=int(N_k[k]))]
430433
rints[k_indices] = new_kindices
431434
# If we initialized with BAR, then BAR, starting from the provided initial_f_k as well.
432435
if initialize == "BAR":

pymbar/tests/test_mbar.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,3 +528,18 @@ def test_mbar_compute_expectations_inner(mbar_and_test):
528528
u_n = u_kn[:2, :]
529529
state_map = np.array([[0, 0], [1, 0], [2, 0], [2, 1]], int)
530530
_ = mbar.compute_expectations_inner(A_in, u_n, state_map)
531+
532+
533+
@pytest.mark.parametrize("n_bootstrap", [1, 100])
534+
def test_mbar_bootstrap_deterministic_given_same_seed(fixed_harmonic_sample, n_bootstrap):
535+
"""Verify that providing a seed to the mbar bootstrap will produce the same values"""
536+
_, u_kn, _, _ = fixed_harmonic_sample.sample(N_k, mode="u_kn")
537+
538+
mbar_a = MBAR(u_kn, N_k, verbose=True, n_bootstraps=n_bootstrap, rseed=814)
539+
ref_fe_diff = mbar_a.compute_free_energy_differences(uncertainty_method="bootstrap")
540+
541+
# The verbose flag should have no impact on the determinism
542+
mbar_b = MBAR(u_kn, N_k, verbose=False, n_bootstraps=n_bootstrap, rseed=814)
543+
test_fe_diff = mbar_b.compute_free_energy_differences(uncertainty_method="bootstrap")
544+
np.testing.assert_equal(ref_fe_diff["Delta_f"], test_fe_diff["Delta_f"])
545+
np.testing.assert_equal(ref_fe_diff["dDelta_f"], test_fe_diff["dDelta_f"])

0 commit comments

Comments
 (0)