Skip to content

Commit d7ee92a

Browse files
committed
Removed useless unsqueeze for benchmark on Simformer
1 parent 6bf3111 commit d7ee92a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/bm_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def train_and_eval_amortized_inference(
196196
new_posterior_latent_idx=torch.arange(0, num_theta),
197197
new_posterior_observed_idx=torch.arange(num_theta, num_theta + num_x),
198198
)
199-
inputs = torch.cat([thetas.unsqueeze(-1), xs.unsqueeze(-1)], dim=1)
199+
inputs = torch.cat([thetas, xs], dim=1)
200200
inference.append_simulations(
201201
inputs,
202202
)

0 commit comments

Comments
 (0)