Skip to content

Commit 043e9e3

Browse files
committed
fix conditional sampling
1 parent f78ec57 commit 043e9e3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/pyjuice/queries/sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo
307307
if not conditional:
308308
assert num_samples is not None, "`num_samples` should be specified when doing unconditioned sampling."
309309
else:
310-
num_samples = pc.node_mars.size(0) # Reuse the batch size
310+
num_samples = pc.node_mars.size(1) # Reuse the batch size
311311

312312
root_ns = pc.root_ns
313313
assert root_ns._output_ind_range[1] - root_ns._output_ind_range[0] == 1, "It is ambiguous to sample from multi-head PCs."

0 commit comments

Comments
 (0)