Skip to content

Commit c13d114

Browse files
authored
fix assertions of pc.mini_batch_em
1 parent 043e9e3 commit c13d114

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/pyjuice/model/tensorcircuit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ def mini_batch_em(self, step_size: float, pseudocount: float = 0.0, keep_zero_pa
494494
:param step_size_rescaling: whether to rescale the step size by flows
495495
:type step_size_rescaling: bool
496496
"""
497-
assert self._cum_flow > 0.0, "Please perform a backward pass before calling `mini_batch_em`."
497+
assert not step_size_rescaling or self._cum_flow > 0.0, "Please perform a backward pass before calling `mini_batch_em`."
498498
assert 0.0 < step_size <= 1.0, "`step_size` should be between 0 and 1."
499499

500500
# Apply step size rescaling according to the mini-batch EM objective derivation

0 commit comments

Comments
 (0)