Skip to content

Commit 2665c14

Browse files
authored
fix: remove empty list mutable default argument (#1608)
1 parent 1f82916 commit 2665c14

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

sbi/inference/posteriors/vi_posterior.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def __init__(
6565
vi_method: str = "rKL",
6666
device: Union[str, torch.device] = "cpu",
6767
x_shape: Optional[torch.Size] = None,
68-
parameters: Iterable = [],
69-
modules: Iterable = [],
68+
parameters: Optional[Iterable] = None,
69+
modules: Optional[Iterable] = None,
7070
):
7171
"""
7272
Args:
@@ -140,8 +140,16 @@ def __init__(
140140
else:
141141
self.link_transform = theta_transform.inv
142142

143+
if parameters is None:
144+
parameters = []
145+
if modules is None:
146+
modules = []
143147
# This will set the variational distribution and VI method
144-
self.set_q(q, parameters=parameters, modules=modules)
148+
self.set_q(
149+
q,
150+
parameters=parameters,
151+
modules=modules,
152+
)
145153
self.set_vi_method(vi_method)
146154

147155
self._purpose = (
@@ -214,8 +222,8 @@ def q(
214222
def set_q(
215223
self,
216224
q: Union[str, PyroTransformedDistribution, "VIPosterior", Callable],
217-
parameters: Iterable = [],
218-
modules: Iterable = [],
225+
parameters: Optional[Iterable] = None,
226+
modules: Optional[Iterable] = None,
219227
) -> None:
220228
"""Defines the variational family.
221229
@@ -244,6 +252,10 @@ def set_q(
244252
modules: List of modules associated with the distribution object.
245253
246254
"""
255+
if parameters is None:
256+
parameters = []
257+
if modules is None:
258+
modules = []
247259
self._q_arg = (q, parameters, modules)
248260
if isinstance(q, Distribution):
249261
q = adapt_variational_distribution(

sbi/samplers/vi/vi_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,8 @@ def adapt_variational_distribution(
225225
q: PyroTransformedDistribution,
226226
prior: Distribution,
227227
link_transform: Callable,
228-
parameters: Iterable = [],
229-
modules: Iterable = [],
228+
parameters: Optional[Iterable] = None,
229+
modules: Optional[Iterable] = None,
230230
) -> Distribution:
231231
"""This will adapt a distribution to be compatible with DivergenceOptimizers.
232232
Especially it will make sure that the distribution has parameters and that it
@@ -244,6 +244,10 @@ def adapt_variational_distribution(
244244
TransformedDistribution: Compatible variational distribution.
245245
246246
"""
247+
if parameters is None:
248+
parameters = []
249+
if modules is None:
250+
modules = []
247251

248252
# Extract user define parameters
249253
def parameters_fn():

0 commit comments

Comments
 (0)