@@ -83,9 +83,9 @@ def __init__(
83
83
[2] : https://github.com/sbi-dev/sbi/blob/main/sbi/utils/metrics.py
84
84
"""
85
85
86
- assert (
87
- thetas . shape [ 0 ] == xs . shape [ 0 ] == posterior_samples . shape [ 0 ]
88
- ), "Number of samples must match"
86
+ assert thetas . shape [ 0 ] == xs . shape [ 0 ] == posterior_samples . shape [ 0 ], (
87
+ "Number of samples must match"
88
+ )
89
89
90
90
# set observed data for classification
91
91
self .theta_p = posterior_samples
@@ -283,9 +283,9 @@ def get_statistic_on_observed_data(
283
283
Returns:
284
284
L-C2ST statistic at `x_o`.
285
285
"""
286
- assert (
287
- self . trained_clfs is not None
288
- ), "No trained classifiers found. Run `train_on_observed_data` first."
286
+ assert self . trained_clfs is not None , (
287
+ "No trained classifiers found. Run `train_on_observed_data` first."
288
+ )
289
289
_ , scores = self .get_scores (
290
290
theta_o = theta_o ,
291
291
x_o = x_o ,
@@ -372,9 +372,9 @@ def train_under_null_hypothesis(
372
372
joint_q_perm [:, self .theta_q .shape [1 ] :],
373
373
)
374
374
else :
375
- assert (
376
- self . null_distribution is not None
377
- ), "You need to provide a null distribution"
375
+ assert self . null_distribution is not None , (
376
+ "You need to provide a null distribution"
377
+ )
378
378
theta_p_t = self .null_distribution .sample ((self .theta_p .shape [0 ],))
379
379
theta_q_t = self .null_distribution .sample ((self .theta_p .shape [0 ],))
380
380
x_p_t , x_q_t = self .x_p , self .x_q
@@ -419,9 +419,9 @@ def get_statistics_under_null_hypothesis(
419
419
Run `train_under_null_hypothesis`."
420
420
)
421
421
else :
422
- assert (
423
- len ( self . trained_clfs_null ) == self . num_trials_null
424
- ), "You need one classifier per trial."
422
+ assert len ( self . trained_clfs_null ) == self . num_trials_null , (
423
+ "You need one classifier per trial."
424
+ )
425
425
426
426
probs_null , stats_null = [], []
427
427
for t in tqdm (
@@ -433,9 +433,9 @@ def get_statistics_under_null_hypothesis(
433
433
if self .permutation :
434
434
theta_o_t = theta_o
435
435
else :
436
- assert (
437
- self . null_distribution is not None
438
- ), "You need to provide a null distribution"
436
+ assert self . null_distribution is not None , (
437
+ "You need to provide a null distribution"
438
+ )
439
439
440
440
theta_o_t = self .null_distribution .sample ((theta_o .shape [0 ],))
441
441
0 commit comments