Skip to content

Commit 3c6d680

Browse files
committed
refactor: remove deprecated x_shape where not needed.
1 parent 829817b commit 3c6d680

File tree

5 files changed

+1
-10
lines changed

5 files changed

+1
-10
lines changed

sbi/inference/posteriors/ensemble_posterior.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ def __init__(
9696
potential_fn=potential_fn,
9797
theta_transform=theta_transform,
9898
device=device,
99-
x_shape=None,
10099
)
101100

102101
def ensure_same_device(self, posteriors: List) -> str:

sbi/inference/trainers/fmpe/fmpe.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
npe_msg_on_invalid_x,
2424
validate_theta_and_x,
2525
warn_if_zscoring_changes_data,
26-
x_shape_from_simulation,
2726
)
2827
from sbi.utils.sbiutils import mask_sims_from_prior
2928

@@ -199,7 +198,6 @@ def train(
199198
theta[self.train_indices].to("cpu"),
200199
x[self.train_indices].to("cpu"),
201200
)
202-
self._x_shape = x_shape_from_simulation(x.to("cpu"))
203201

204202
del theta, x
205203

sbi/inference/trainers/npse/npse.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
test_posterior_net_for_multi_d_x,
2727
validate_theta_and_x,
2828
warn_if_zscoring_changes_data,
29-
x_shape_from_simulation,
3029
)
3130
from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior
3231

@@ -282,7 +281,6 @@ def default_calibration_kernel(x):
282281
theta[self.train_indices].to("cpu"),
283282
x[self.train_indices].to("cpu"),
284283
)
285-
self._x_shape = x_shape_from_simulation(x.to("cpu"))
286284

287285
test_posterior_net_for_multi_d_x(
288286
self._neural_net,

sbi/inference/trainers/nre/nre_base.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
check_estimator_arg,
2222
check_prior,
2323
clamp_and_warn,
24-
x_shape_from_simulation,
2524
)
2625
from sbi.utils.torchutils import repeat_rows
2726

@@ -203,7 +202,6 @@ def train(
203202
theta[self.train_indices].to("cpu"),
204203
x[self.train_indices].to("cpu"),
205204
)
206-
self._x_shape = x_shape_from_simulation(x.to("cpu"))
207205
del x, theta
208206
self._neural_net.to(self._device)
209207

tests/test_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,6 @@ def __init__(
219219
potential_fn: Callable,
220220
theta_transform: Optional[TorchTransform] = None,
221221
device: Optional[str] = "cpu",
222-
x_shape: Optional[torch.Size] = None,
223222
):
224223
"""
225224
Args:
@@ -228,10 +227,9 @@ def __init__(
228227
Allows to perform, e.g. MCMC in unconstrained space.
229228
device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
230229
`potential_fn.device` is used.
231-
x_shape: Shape of the observed data.
232230
"""
233231
assert isinstance(potential_fn, PosteriorPotential)
234-
super().__init__(potential_fn, theta_transform, device, x_shape)
232+
super().__init__(potential_fn, theta_transform, device)
235233

236234
def sample(
237235
self,

0 commit comments

Comments
 (0)