Skip to content

Commit 8885980

Browse files
committed
Fix optimal transport indexing for tensorflow
1 parent 1378467 commit 8885980

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
The base probability distribution from which samples are drawn, such as "normal".
8383
Default is "normal".
8484
use_optimal_transport : bool, optional
85-
Whether to apply optimal transport for improved training stability. Default is False.
85+
Whether to apply optimal transport for improved training stability. Default is True.
8686
loss_fn : str, optional
8787
The loss function used for training, such as "mse". Default is "mse".
8888
integrate_kwargs : dict[str, any], optional
@@ -269,7 +269,7 @@ def compute_metrics(
269269
)
270270
if conditions is not None:
271271
# conditions must be resampled along with x1
272-
conditions = conditions[assignments]
272+
conditions = keras.ops.take(conditions, assignments, axis=0)
273273

274274
t = keras.random.uniform((keras.ops.shape(x0)[0],), seed=self.seed_generator)
275275
t = expand_right_as(t, x0)

bayesflow/utils/optimal_transport/optimal_transport.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import keras
2+
13
from .log_sinkhorn import log_sinkhorn
24
from .sinkhorn import sinkhorn
35

@@ -37,7 +39,7 @@ def optimal_transport(x1, x2, method="log_sinkhorn", return_assignments=False, *
3739
x1 and x2 in optimal transport permutation order.
3840
"""
3941
assignments = methods[method.lower()](x1, x2, **kwargs)
40-
x2 = x2[assignments]
42+
x2 = keras.ops.take(x2, assignments, axis=0)
4143

4244
if return_assignments:
4345
return x1, x2, assignments

0 commit comments

Comments
 (0)