Skip to content

Commit e1305b9

Browse files
authored
style: pin pre commit and ruff to recent versions. (#1358)
* update and pin ruff and pre-commit. * linting and formatting with new ruff * apply ruff to tests and tutorials, ignore long lines.
1 parent 597da91 commit e1305b9

File tree

60 files changed

+362
-370
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+362
-370
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
repos:
22
- repo: https://github.com/astral-sh/ruff-pre-commit
3-
rev: v0.3.3
3+
rev: v0.9.0
44
hooks:
55
- id: ruff
66
- id: ruff-format
77
args: [--diff]
88
- repo: https://github.com/pre-commit/pre-commit-hooks
9-
rev: v4.5.0
9+
rev: v5.0.0
1010
hooks:
1111
- id: check-added-large-files
1212
- id: check-merge-conflict

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ doc = [
6464
dev = [
6565
"ffmpeg",
6666
# Lint
67-
"pre-commit == 3.5.0",
67+
"pre-commit == 4.0.1",
6868
"pyyaml",
6969
"pyright",
70-
"ruff>=0.3.3",
70+
"ruff==0.9.0",
7171
# Test
7272
"pytest",
7373
"pytest-cov",
@@ -106,6 +106,7 @@ ignore = [
106106
[tool.ruff.lint.extend-per-file-ignores]
107107
"__init__.py" = ["E402", "F401", "F403"] # allow unused imports and undefined names
108108
"test_*.py" = ["F403", "F405"]
109+
"tutorials/*.ipynb" = ["E501"] # allow long lines in notebooks
109110

110111
[tool.ruff.lint.isort]
111112
case-sensitive = true

sbi/analysis/plot.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -775,9 +775,9 @@ def pairplot(
775775

776776
# checks.
777777
if fig_kwargs_filled["legend"]:
778-
assert len(fig_kwargs_filled["samples_labels"]) >= len(
779-
samples
780-
), "Provide at least as many labels as samples."
778+
assert len(fig_kwargs_filled["samples_labels"]) >= len(samples), (
779+
"Provide at least as many labels as samples."
780+
)
781781
if offdiag is not None:
782782
warn("offdiag is deprecated, use upper or lower instead.", stacklevel=2)
783783
upper = offdiag
@@ -1594,9 +1594,9 @@ def _sbc_rank_plot(
15941594
ranks_list[idx]: np.ndarray = rank.numpy() # type: ignore
15951595

15961596
plot_types = ["hist", "cdf"]
1597-
assert (
1598-
plot_type in plot_types
1599-
), "plot type {plot_type} not implemented, use one in {plot_types}."
1597+
assert plot_type in plot_types, (
1598+
"plot type {plot_type} not implemented, use one in {plot_types}."
1599+
)
16001600

16011601
if legend_kwargs is None:
16021602
legend_kwargs = dict(loc="best", handlelength=0.8)
@@ -1609,9 +1609,9 @@ def _sbc_rank_plot(
16091609
params_in_subplots = True
16101610

16111611
for ranki in ranks_list:
1612-
assert (
1613-
ranki.shape == ranks_list[0].shape
1614-
), "all ranks in list must have the same shape."
1612+
assert ranki.shape == ranks_list[0].shape, (
1613+
"all ranks in list must have the same shape."
1614+
)
16151615

16161616
num_rows = int(np.ceil(num_parameters / num_cols))
16171617
if figsize is None:
@@ -1636,9 +1636,9 @@ def _sbc_rank_plot(
16361636
)
16371637
ax = np.atleast_1d(ax) # type: ignore
16381638
else:
1639-
assert (
1640-
ax.size >= num_parameters
1641-
), "There must be at least as many subplots as parameters."
1639+
assert ax.size >= num_parameters, (
1640+
"There must be at least as many subplots as parameters."
1641+
)
16421642
num_rows = ax.shape[0] if ax.ndim > 1 else 1
16431643
assert ax is not None
16441644

@@ -2221,9 +2221,9 @@ def pairplot_dep(
22212221

22222222
# checks.
22232223
if opts["legend"]:
2224-
assert len(opts["samples_labels"]) >= len(
2225-
samples
2226-
), "Provide at least as many labels as samples."
2224+
assert len(opts["samples_labels"]) >= len(samples), (
2225+
"Provide at least as many labels as samples."
2226+
)
22272227
if opts["upper"] is not None:
22282228
opts["offdiag"] = opts["upper"]
22292229

sbi/analysis/sensitivity_analysis.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,9 @@ def train(
250250
prevent exploding gradients. Use `None` for no clipping.
251251
"""
252252

253-
assert (
254-
self._theta is not None and self._emergent_property is not None
255-
), "You must call .add_property() first."
253+
assert self._theta is not None and self._emergent_property is not None, (
254+
"You must call .add_property() first."
255+
)
256256

257257
# Get indices for permutation of the data.
258258
num_examples = len(self._theta)
@@ -433,9 +433,9 @@ def find_directions(
433433
if posterior_log_prob_as_property:
434434
predictions = self._posterior.potential(thetas, track_gradients=True)
435435
else:
436-
assert (
437-
self._regression_net is not None
438-
), "self._regression_net is None, you must call `.train()` first."
436+
assert self._regression_net is not None, (
437+
"self._regression_net is None, you must call `.train()` first."
438+
)
439439
predictions = self._regression_net.forward(thetas)
440440
loss = predictions.mean()
441441
loss.backward()

sbi/diagnostics/lc2st.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,9 @@ def __init__(
8383
[2] : https://github.com/sbi-dev/sbi/blob/main/sbi/utils/metrics.py
8484
"""
8585

86-
assert (
87-
thetas.shape[0] == xs.shape[0] == posterior_samples.shape[0]
88-
), "Number of samples must match"
86+
assert thetas.shape[0] == xs.shape[0] == posterior_samples.shape[0], (
87+
"Number of samples must match"
88+
)
8989

9090
# set observed data for classification
9191
self.theta_p = posterior_samples
@@ -283,9 +283,9 @@ def get_statistic_on_observed_data(
283283
Returns:
284284
L-C2ST statistic at `x_o`.
285285
"""
286-
assert (
287-
self.trained_clfs is not None
288-
), "No trained classifiers found. Run `train_on_observed_data` first."
286+
assert self.trained_clfs is not None, (
287+
"No trained classifiers found. Run `train_on_observed_data` first."
288+
)
289289
_, scores = self.get_scores(
290290
theta_o=theta_o,
291291
x_o=x_o,
@@ -372,9 +372,9 @@ def train_under_null_hypothesis(
372372
joint_q_perm[:, self.theta_q.shape[1] :],
373373
)
374374
else:
375-
assert (
376-
self.null_distribution is not None
377-
), "You need to provide a null distribution"
375+
assert self.null_distribution is not None, (
376+
"You need to provide a null distribution"
377+
)
378378
theta_p_t = self.null_distribution.sample((self.theta_p.shape[0],))
379379
theta_q_t = self.null_distribution.sample((self.theta_p.shape[0],))
380380
x_p_t, x_q_t = self.x_p, self.x_q
@@ -419,9 +419,9 @@ def get_statistics_under_null_hypothesis(
419419
Run `train_under_null_hypothesis`."
420420
)
421421
else:
422-
assert (
423-
len(self.trained_clfs_null) == self.num_trials_null
424-
), "You need one classifier per trial."
422+
assert len(self.trained_clfs_null) == self.num_trials_null, (
423+
"You need one classifier per trial."
424+
)
425425

426426
probs_null, stats_null = [], []
427427
for t in tqdm(
@@ -433,9 +433,9 @@ def get_statistics_under_null_hypothesis(
433433
if self.permutation:
434434
theta_o_t = theta_o
435435
else:
436-
assert (
437-
self.null_distribution is not None
438-
), "You need to provide a null distribution"
436+
assert self.null_distribution is not None, (
437+
"You need to provide a null distribution"
438+
)
439439

440440
theta_o_t = self.null_distribution.sample((theta_o.shape[0],))
441441

sbi/diagnostics/sbc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ def run_sbc(
6969
stacklevel=2,
7070
)
7171

72-
assert (
73-
thetas.shape[0] == xs.shape[0]
74-
), "Unequal number of parameters and observations."
72+
assert thetas.shape[0] == xs.shape[0], (
73+
"Unequal number of parameters and observations."
74+
)
7575

7676
if "sbc_batch_size" in kwargs:
7777
warnings.warn(

sbi/diagnostics/tarp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,9 @@ def _run_tarp(
133133
"""
134134
num_posterior_samples, num_tarp_samples, _ = posterior_samples.shape
135135

136-
assert (
137-
references.shape == thetas.shape
138-
), "references must have the same shape as thetas"
136+
assert references.shape == thetas.shape, (
137+
"references must have the same shape as thetas"
138+
)
139139

140140
if num_bins is None:
141141
num_bins = num_tarp_samples // 10

sbi/inference/abc/mcabc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ def __call__(
130130
"""
131131

132132
# Exactly one of eps or quantile need to be passed.
133-
assert (eps is not None) ^ (
134-
quantile is not None
135-
), "Eps or quantile must be passed, but not both."
133+
assert (eps is not None) ^ (quantile is not None), (
134+
"Eps or quantile must be passed, but not both."
135+
)
136136
if kde_kwargs is None:
137137
kde_kwargs = {}
138138

sbi/inference/abc/smcabc.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ def __init__(
9595
)
9696

9797
kernels = ("gaussian", "uniform")
98-
assert (
99-
kernel in kernels
100-
), f"Kernel '{kernel}' not supported. Choose one from {kernels}."
98+
assert kernel in kernels, (
99+
f"Kernel '{kernel}' not supported. Choose one from {kernels}."
100+
)
101101
self.kernel = kernel
102102

103103
algorithm_variants = ("A", "B", "C")
@@ -198,13 +198,13 @@ def __call__(
198198
if kde_kwargs is None:
199199
kde_kwargs = {}
200200
assert isinstance(epsilon_decay, float) and epsilon_decay > 0.0
201-
assert not (
202-
self.distance.requires_iid_data and lra
203-
), "Currently there is no support to run inference "
201+
assert not (self.distance.requires_iid_data and lra), (
202+
"Currently there is no support to run inference "
203+
)
204204
"on multiple observations together with lra."
205-
assert not (
206-
self.distance.requires_iid_data and sass
207-
), "Currently there is no support to run inference "
205+
assert not (self.distance.requires_iid_data and sass), (
206+
"Currently there is no support to run inference "
207+
)
208208
"on multiple observations together with sass."
209209

210210
# Pilot run for SASS.
@@ -363,9 +363,9 @@ def _set_xo_and_sample_initial_population(
363363
) -> Tuple[Tensor, float, Tensor, Tensor]:
364364
"""Return particles, epsilon and distances of initial population."""
365365

366-
assert (
367-
num_particles <= num_initial_pop
368-
), "number of initial round simulations must be greater than population size"
366+
assert num_particles <= num_initial_pop, (
367+
"number of initial round simulations must be greater than population size"
368+
)
369369

370370
assert (x_o.shape[0] == 1) or self.distance.requires_iid_data, (
371371
"Your data contain iid data-points, but the choice of "

sbi/inference/posteriors/base_posterior.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,7 @@ def __repr__(self):
288288
return desc
289289

290290
def __str__(self):
291-
desc = (
292-
f"Posterior p(θ|x) of type {self.__class__.__name__}. " f"{self._purpose}"
293-
)
291+
desc = f"Posterior p(θ|x) of type {self.__class__.__name__}. {self._purpose}"
294292
return desc
295293

296294
def __getstate__(self) -> Dict:

0 commit comments

Comments
 (0)