@@ -65,8 +65,8 @@ def __init__(
65
65
vi_method : str = "rKL" ,
66
66
device : Union [str , torch .device ] = "cpu" ,
67
67
x_shape : Optional [torch .Size ] = None ,
68
- parameters : Iterable = [] ,
69
- modules : Iterable = [] ,
68
+ parameters : Optional [ Iterable ] = None ,
69
+ modules : Optional [ Iterable ] = None ,
70
70
):
71
71
"""
72
72
Args:
@@ -140,8 +140,16 @@ def __init__(
140
140
else :
141
141
self .link_transform = theta_transform .inv
142
142
143
+ if parameters is None :
144
+ parameters = []
145
+ if modules is None :
146
+ modules = []
143
147
# This will set the variational distribution and VI method
144
- self .set_q (q , parameters = parameters , modules = modules )
148
+ self .set_q (
149
+ q ,
150
+ parameters = parameters ,
151
+ modules = modules ,
152
+ )
145
153
self .set_vi_method (vi_method )
146
154
147
155
self ._purpose = (
@@ -214,8 +222,8 @@ def q(
214
222
def set_q (
215
223
self ,
216
224
q : Union [str , PyroTransformedDistribution , "VIPosterior" , Callable ],
217
- parameters : Iterable = [] ,
218
- modules : Iterable = [] ,
225
+ parameters : Optional [ Iterable ] = None ,
226
+ modules : Optional [ Iterable ] = None ,
219
227
) -> None :
220
228
"""Defines the variational family.
221
229
@@ -244,6 +252,10 @@ def set_q(
244
252
modules: List of modules associated with the distribution object.
245
253
246
254
"""
255
+ if parameters is None :
256
+ parameters = []
257
+ if modules is None :
258
+ modules = []
247
259
self ._q_arg = (q , parameters , modules )
248
260
if isinstance (q , Distribution ):
249
261
q = adapt_variational_distribution (
0 commit comments