Skip to content

Commit 72ec46d

Browse files
Add as argument with docstring
1 parent 0596236 commit 72ec46d

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

sbi/neural_nets/estimators/flowmatching_estimator.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ def score(self, input: Tensor, condition: Tensor, t: Tensor) -> Tensor:
243243
score = (-(1 - t) * v - input) / (t + self.noise_scale)
244244
return score
245245

246-
def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
246+
def drift_fn(
247+
self, input: Tensor, times: Tensor, effective_t_max: float = 0.99
248+
) -> Tensor:
247249
r"""Drift function for the flow matching estimator.
248250
249251
The drift function is calculated based on [3]_ (see Equation 7):
@@ -263,16 +265,22 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
263265
Args:
264266
input: Parameters :math:`\theta_t`.
265267
times: SDE time variable in [0,1].
268+
effective_t_max: Upper bound on time to avoid numerical issues at t=1.
269+
This effectively prevents and explosion of the SDE in the beginning.
270+
Note that this does not affect the ODE sampling, which always uses
271+
times in [0,1].
266272
267273
Returns:
268274
Drift function at a given time.
269275
"""
270276
# analytical f(t) does not depend on noise_scale and is undefined at t = 1.
271-
# NOTE: We bound the singularity to avoid numerical issues i.e. 1 - t > 0.01
272-
# this effectively prevents and explosion of the SDE in the beginning.
273-
return -input / torch.maximum(1 - times, torch.tensor(1e-2).to(input))
277+
return -input / torch.maximum(
278+
1 - times, torch.tensor(1 - effective_t_max).to(input)
279+
)
274280

275-
def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
281+
def diffusion_fn(
282+
self, input: Tensor, times: Tensor, effective_t_max: float = 0.99
283+
) -> Tensor:
276284
r"""Diffusion function for the flow matching estimator.
277285
278286
The diffusion function is calculated based on [3]_ (see Equation 7):
@@ -290,17 +298,19 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
290298
Args:
291299
input: Parameters :math:`\theta_t`.
292300
times: SDE time variable in [0,1].
301+
effective_t_max: Upper bound on time to avoid numerical issues at t=1.
302+
This effectively prevents and explosion of the SDE in the beginning.
303+
Note that this does not affect the ODE sampling, which always uses
304+
times in [0,1].
293305
294306
Returns:
295307
Diffusion function at a given time.
296308
"""
297309
# analytical g(t) is undefined at t = 1.
298-
# NOTE: We bound the singularity to avoid numerical issues i.e. 1 - t > 0.01
299-
# this effectively prevents and explosion of the SDE in the beginning.
300310
return torch.sqrt(
301311
2
302312
* (times + self.noise_scale)
303-
/ torch.maximum(1 - times, torch.tensor(1e-2).to(times))
313+
/ torch.maximum(1 - times, torch.tensor(1 - effective_t_max).to(times))
304314
)
305315

306316
def mean_t_fn(self, times: Tensor) -> Tensor:

0 commit comments

Comments
 (0)