Skip to content

Commit b5836e8

Browse files
committed
reintroduce symbolic tensor check in log_sinkhorn
1 parent 6305870 commit b5836e8

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

bayesflow/utils/optimal_transport/log_sinkhorn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import keras
22

33
from .. import logging
4+
from ..tensor_utils import is_symbolic_tensor
45

56
from .euclidean import euclidean
67

@@ -26,6 +27,9 @@ def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8,
2627

2728
log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)
2829

30+
if is_symbolic_tensor(log_plan):
31+
return log_plan
32+
2933
def contains_nans(plan):
3034
return keras.ops.any(keras.ops.isnan(plan))
3135

0 commit comments

Comments
 (0)