@@ -373,3 +373,69 @@ def build_support(
373
373
support = constraints .interval (lower_bound , upper_bound )
374
374
375
375
return support
376
+
377
+
378
+ class OneDimPriorWrapper (Distribution ):
379
+ """Wrap batched 1D distributions to get rid of the batch dim of `.log_prob()`.
380
+
381
+ 1D pytorch distributions such as `torch.distributions.Exponential`, `.Uniform`, or
382
+ `.Normal` do not, by default return __any__ sample or batch dimension. E.g.:
383
+ ```python
384
+ dist = torch.distributions.Exponential(torch.tensor(3.0))
385
+ dist.sample((10,)).shape # (10,)
386
+ ```
387
+
388
+ `sbi` will raise an error that the sample dimension is missing. A simple solution is
389
+ to add a batch dimension to `dist` as follows:
390
+ ```python
391
+ dist = torch.distributions.Exponential(torch.tensor([3.0]))
392
+ dist.sample((10,)).shape # (10, 1)
393
+ ```
394
+
395
+ Unfortunately, this `dist` will return the batch dimension also for `.log_prob():
396
+ ```python
397
+ dist = torch.distributions.Exponential(torch.tensor([3.0]))
398
+ samples = dist.sample((10,))
399
+ dist.log_prob(samples).shape # (10, 1)
400
+ ```
401
+
402
+ This will lead to unexpected errors in `sbi`. The point of this class is to wrap
403
+ those batched 1D distributions to get rid of their batch dimension in `.log_prob()`.
404
+ """
405
+
406
+ def __init__ (
407
+ self ,
408
+ prior : Distribution ,
409
+ validate_args = None ,
410
+ ) -> None :
411
+ super ().__init__ (
412
+ batch_shape = prior .batch_shape ,
413
+ event_shape = prior .event_shape ,
414
+ validate_args = (
415
+ prior ._validate_args if validate_args is None else validate_args
416
+ ),
417
+ )
418
+ self .prior = prior
419
+
420
+ def sample (self , * args , ** kwargs ) -> Tensor :
421
+ return self .prior .sample (* args , ** kwargs )
422
+
423
+ def log_prob (self , * args , ** kwargs ) -> Tensor :
424
+ """Override the log_prob method to get rid of the additional batch dimension."""
425
+ return self .prior .log_prob (* args , ** kwargs )[..., 0 ]
426
+
427
+ @property
428
+ def arg_constraints (self ) -> Dict [str , constraints .Constraint ]:
429
+ return self .prior .arg_constraints
430
+
431
+ @property
432
+ def support (self ):
433
+ return self .prior .support
434
+
435
+ @property
436
+ def mean (self ) -> Tensor :
437
+ return self .prior .mean
438
+
439
+ @property
440
+ def variance (self ) -> Tensor :
441
+ return self .prior .variance
0 commit comments