Skip to content

Commit 1601757

Browse files
committed
fix optimal transport config (#429)
1 parent 57c9ad8 commit 1601757

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ class FlowMatching(InferenceNetwork):
3939
}
4040

4141
OPTIMAL_TRANSPORT_DEFAULT_CONFIG = {
42-
"method": "sinkhorn",
43-
"cost": "euclidean",
42+
"method": "log_sinkhorn",
4443
"regularization": 0.1,
4544
"max_steps": 100,
46-
"tolerance": 1e-4,
45+
"atol": 1e-5,
46+
"rtol": 1e-4,
4747
}
4848

4949
INTEGRATE_DEFAULT_CONFIG = {

bayesflow/utils/optimal_transport/log_sinkhorn.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@ def log_sinkhorn_plan(x1, x2, regularization: float = 1.0, rtol=1e-5, atol=1e-8,
2727

2828
log_plan = cost / -(regularization * keras.ops.mean(cost) + 1e-16)
2929

30-
if is_symbolic_tensor(log_plan):
31-
return log_plan
32-
3330
def contains_nans(plan):
3431
return keras.ops.any(keras.ops.isnan(plan))
3532

@@ -59,7 +56,7 @@ def do_nothing():
5956
def log_steps():
6057
msg = "Log-Sinkhorn-Knopp converged after {:d} steps."
6158

62-
logging.info(msg, steps)
59+
logging.debug(msg, steps)
6360

6461
def warn_convergence():
6562
marginals = keras.ops.logsumexp(log_plan, axis=0)

0 commit comments

Comments
 (0)