@@ -426,9 +426,9 @@ def sample_batched(
426
426
# warn if num_chains is larger than num requested samples
427
427
if num_chains > torch .Size (sample_shape ).numel ():
428
428
warnings .warn (
429
- f"""Passed num_chains { num_chains } is larger than the number of
430
- requested samples { torch .Size (sample_shape ).numel ()} , resetting
431
- it to { torch .Size (sample_shape ).numel ()} ."" " ,
429
+ "The passed number of MCMC chains is larger than the number of "
430
+ f"requested samples: { num_chains } > { torch .Size (sample_shape ).numel ()} ,"
431
+ f" resetting it to { torch .Size (sample_shape ).numel ()} ." ,
432
432
stacklevel = 2 ,
433
433
)
434
434
num_chains = torch .Size (sample_shape ).numel ()
@@ -453,12 +453,11 @@ def sample_batched(
453
453
num_chains_extended = batch_size * num_chains
454
454
if num_chains_extended > 100 :
455
455
warnings .warn (
456
- f"""Note that for batched sampling, we use { num_chains } for each
457
- x in the batch. With the given settings, this results in a
458
- large number of chains ({ num_chains_extended } ), This can be
459
- large number of chains ({ num_chains_extended } ), which can be
460
- slow and memory-intensive. Consider reducing the number of
461
- chains.""" ,
456
+ "Note that for batched sampling, we use num_chains many chains for each"
457
+ " x in the batch. With the given settings, this results in a large "
458
+ f"number of chains ({ num_chains_extended } ), which can be "
459
+ "slow and memory-intensive for vectorized MCMC. Consider reducing the "
460
+ "number of chains." ,
462
461
stacklevel = 2 ,
463
462
)
464
463
init_strategy_parameters ["num_return_samples" ] = num_chains_extended
0 commit comments