@@ -243,7 +243,9 @@ def score(self, input: Tensor, condition: Tensor, t: Tensor) -> Tensor:
243
243
score = (- (1 - t ) * v - input ) / (t + self .noise_scale )
244
244
return score
245
245
246
- def drift_fn (self , input : Tensor , times : Tensor ) -> Tensor :
246
+ def drift_fn (
247
+ self , input : Tensor , times : Tensor , effective_t_max : float = 0.99
248
+ ) -> Tensor :
247
249
r"""Drift function for the flow matching estimator.
248
250
249
251
The drift function is calculated based on [3]_ (see Equation 7):
@@ -263,16 +265,22 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
263
265
Args:
264
266
input: Parameters :math:`\theta_t`.
265
267
times: SDE time variable in [0,1].
268
+ effective_t_max: Upper bound on time to avoid numerical issues at t=1.
269
+ This effectively prevents and explosion of the SDE in the beginning.
270
+ Note that this does not affect the ODE sampling, which always uses
271
+ times in [0,1].
266
272
267
273
Returns:
268
274
Drift function at a given time.
269
275
"""
270
276
# analytical f(t) does not depend on noise_scale and is undefined at t = 1.
271
- # NOTE: We bound the singularity to avoid numerical issues i.e. 1 - t > 0.01
272
- # this effectively prevents and explosion of the SDE in the beginning.
273
- return - input / torch . maximum ( 1 - times , torch . tensor ( 1e-2 ). to ( input ) )
277
+ return - input / torch . maximum (
278
+ 1 - times , torch . tensor ( 1 - effective_t_max ). to ( input )
279
+ )
274
280
275
- def diffusion_fn (self , input : Tensor , times : Tensor ) -> Tensor :
281
+ def diffusion_fn (
282
+ self , input : Tensor , times : Tensor , effective_t_max : float = 0.99
283
+ ) -> Tensor :
276
284
r"""Diffusion function for the flow matching estimator.
277
285
278
286
The diffusion function is calculated based on [3]_ (see Equation 7):
@@ -290,17 +298,19 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
290
298
Args:
291
299
input: Parameters :math:`\theta_t`.
292
300
times: SDE time variable in [0,1].
301
+ effective_t_max: Upper bound on time to avoid numerical issues at t=1.
302
+ This effectively prevents and explosion of the SDE in the beginning.
303
+ Note that this does not affect the ODE sampling, which always uses
304
+ times in [0,1].
293
305
294
306
Returns:
295
307
Diffusion function at a given time.
296
308
"""
297
309
# analytical g(t) is undefined at t = 1.
298
- # NOTE: We bound the singularity to avoid numerical issues i.e. 1 - t > 0.01
299
- # this effectively prevents and explosion of the SDE in the beginning.
300
310
return torch .sqrt (
301
311
2
302
312
* (times + self .noise_scale )
303
- / torch .maximum (1 - times , torch .tensor (1e-2 ).to (times ))
313
+ / torch .maximum (1 - times , torch .tensor (1 - effective_t_max ).to (times ))
304
314
)
305
315
306
316
def mean_t_fn (self , times : Tensor ) -> Tensor :
0 commit comments