 import warnings
+from typing import Optional

 import torch
 from joblib import Parallel, delayed
@@ -18,6 +19,7 @@ def get_posterior_samples_on_batch(
     num_workers: int = 1,
     show_progress_bar: bool = False,
     use_batched_sampling: bool = True,
+    batch_size: Optional[int] = None,
 ) -> Tensor:
     """Get posterior samples for a batch of xs.

@@ -28,22 +30,37 @@ def get_posterior_samples_on_batch(
         num_workers: number of workers to use for parallelization.
         show_progress_bars: whether to show progress bars.
         use_batched_sampling: whether to use batched sampling if possible.
-
+        batch_size: batch size for batched sampling. Useful when sampling with
+            large batches of xs, to avoid memory overflow.
     Returns:
         posterior_samples: of shape (num_samples, batch_size, dim_parameters).
     """
-    batch_size = len(xs)
+    num_xs = len(xs)
+    if batch_size is None:
+        batch_size = num_xs

-    # Try using batched sampling when implemented.
-    try:
-        # has shape (num_samples, batch_size, dim_parameters)
-        if use_batched_sampling:
-            posterior_samples = posterior.sample_batched(
-                sample_shape, x=xs, show_progress_bars=show_progress_bar
+    if use_batched_sampling:
+        try:
+            # Distribute the batch of xs into smaller batches.
+            batched_xs = xs.split(batch_size)
+            posterior_samples = torch.cat(
+                [  # has shape (num_samples, num_xs, dim_parameters)
+                    posterior.sample_batched(
+                        sample_shape, x=xs_batch, show_progress_bars=show_progress_bar
+                    )
+                    for xs_batch in batched_xs
+                ],
+                dim=1,
             )
-        else:
-            raise NotImplementedError
-    except (NotImplementedError, AssertionError):
+        except (NotImplementedError, AssertionError):
+            warnings.warn(
+                "Batched sampling not implemented for this posterior. "
+                "Falling back to non-batched sampling.",
+                stacklevel=2,
+            )
+            use_batched_sampling = False
+
+    if not use_batched_sampling:
         # We need a function with extra training step for new x for VIPosterior.
         def sample_fun(
             posterior: NeuralPosterior, sample_shape: Shape, x: Tensor, seed: int = 0
@@ -57,13 +74,13 @@ def sample_fun(
         if isinstance(posterior, (VIPosterior, MCMCPosterior)):
             warnings.warn(
                 "Using non-batched sampling. Depending on the number of different xs "
-                f"({batch_size}) and the number of parallel workers {num_workers}, "
-                "this might be slow.",
+                f"({num_xs}) and the number of parallel workers {num_workers}, "
+                "this might take a long time.",
                 stacklevel=2,
             )

         # Run in parallel with progress bar.
-        seeds = torch.randint(0, 2**32, (batch_size,))
+        seeds = torch.randint(0, 2**32, (num_xs,))
         outputs = list(
             tqdm(
                 Parallel(return_as="generator", n_jobs=num_workers)(
@@ -72,7 +89,7 @@ def sample_fun(
                 ),
                 disable=not show_progress_bar,
                 total=len(xs),
-                desc=f"Sampling {batch_size} times {sample_shape} posterior samples.",
+                desc=f"Sampling {num_xs} times {sample_shape} posterior samples.",
             )
         )  # (batch_size, num_samples, dim_parameters)
         # Transpose to shape convention: (sample_shape, batch_size, dim_parameters)
@@ -81,8 +98,8 @@ def sample_fun(
         ).permute(1, 0, 2)

     assert posterior_samples.shape[:2] == sample_shape + (
-        batch_size,
-    ), f"""Expected batched posterior samples of shape {
-        sample_shape + (batch_size,)
-    } got {posterior_samples.shape[:2]}."""
+        num_xs,
+    ), f"""Expected batched posterior samples of shape {sample_shape + (num_xs,)}, got {
+        posterior_samples.shape[:2]
+    }."""
     return posterior_samples
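Taken together, the new batched path splits `xs` into chunks of at most `batch_size` rows along dim 0, samples each chunk, and re-assembles the per-chunk samples along dim 1 (the xs dimension). Below is a minimal, self-contained sketch of that invariant; `fake_sample_batched` is a hypothetical stand-in for `posterior.sample_batched`, and all names and shapes here are illustrative, not part of the library:

```python
import torch

num_samples, num_xs, dim_x, dim_theta = 100, 10, 3, 2
sample_shape = torch.Size([num_samples])
xs = torch.randn(num_xs, dim_x)
batch_size = 4  # xs.split(4) yields chunks of size 4, 4, and 2


def fake_sample_batched(sample_shape, x):
    # Hypothetical stand-in for posterior.sample_batched: returns a tensor
    # of the expected shape (num_samples, len(x), dim_parameters).
    return torch.zeros(sample_shape + (len(x), dim_theta))


# Chunked sampling, concatenated along dim=1 (the xs dimension), as above.
chunked = torch.cat(
    [fake_sample_batched(sample_shape, x=x_b) for x_b in xs.split(batch_size)],
    dim=1,
)
# Full-batch sampling in one call.
full = fake_sample_batched(sample_shape, x=xs)
assert chunked.shape == full.shape == (num_samples, num_xs, dim_theta)
```

With the default `batch_size=None`, the function keeps the previous behavior of sampling all of `xs` in a single `sample_batched` call.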
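For the non-batched fallback, this is a condensed sketch of the joblib/tqdm pattern the function uses: one seeded sampling call per x, run in parallel, then stacked and permuted into the (sample_shape, batch_size, dim_parameters) convention. `sample_one` is a hypothetical stand-in for the `sample_fun` defined inside the function:

```python
import torch
from joblib import Parallel, delayed
from tqdm import tqdm

num_samples, dim_theta = 100, 2
xs = torch.randn(10, 3)


def sample_one(x, seed):
    # Hypothetical stand-in for sample_fun(posterior, sample_shape, x, seed):
    # seed the generator, then draw num_samples parameters for this single x.
    torch.manual_seed(seed)
    return torch.randn(num_samples, dim_theta)


seeds = torch.randint(0, 2**32, (len(xs),))
outputs = list(
    tqdm(
        Parallel(return_as="generator", n_jobs=2)(
            delayed(sample_one)(x, int(s)) for x, s in zip(xs, seeds)
        ),
        total=len(xs),
    )
)
# (num_xs, num_samples, dim_theta) -> (num_samples, num_xs, dim_theta)
posterior_samples = torch.stack(outputs).permute(1, 0, 2)
assert posterior_samples.shape == (num_samples, len(xs), dim_theta)
```

Passing one explicit seed per x keeps the parallel runs reproducible regardless of which worker picks up which job; `return_as="generator"` (joblib >= 1.3) lets tqdm report progress as results arrive.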