Skip to content

Commit 07d6bfa

Browse files
authored
[WIP] Sliced OT Plans (#757)
* first implementation of sliced ot plans with example * tests + temperature option in expected-sliced * fixed cell rendering in example * ref number update * skip jax and tf in expected sliced testing due to array assignment * raise NotImplementedError when expected_sliced is used with tf or jax
1 parent 85113e9 commit 07d6bfa

File tree

6 files changed

+583
-3
lines changed

6 files changed

+583
-3
lines changed

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ POT provides the following generic OT solvers:
7272
* Fused unbalanced Gromov-Wasserstein [70].
7373
* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [77]
7474
* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 77]
75+
* [Sliced Optimal Transport Plans](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_sliced_plans.html) [82, 83, 84]
7576

7677
POT provides the following Machine Learning related solvers:
7778

@@ -449,5 +450,8 @@ Artificial Intelligence.
449450

450451
[81] Xu, H., Luo, D., & Carin, L. (2019). [Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS).
451452

453+
[82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). [Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics](https://proceedings.neurips.cc/paper_files/paper/2023/hash/6f1346bac8b02f76a631400e2799b24b-Abstract-Conference.html). Advances in Neural Information Processing Systems, 36, 35350–35385.
452454

453-
```
455+
[83] Tanguy, E., Chapel, L., Delon, J. (2025). [Sliced Optimal Transport Plans](https://arxiv.org/abs/2508.01243) arXiv preprint 2506.03661.
456+
457+
[84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). [Expected Sliced Transport Plans](https://openreview.net/forum?id=P7O1Vt1BdU). International Conference on Learning Representations.

RELEASES.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Releases
22

3+
## 0.9.7dev
4+
5+
#### New features
6+
7+
- Added Sliced OT plans (PR #757)
8+
39
## 0.9.6.post1
410

511
*September 2025*
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
===============
4+
Sliced OT Plans
5+
===============
6+
7+
Compares different Sliced OT plans between two 2D point clouds. The min-Pivot
8+
Sliced plan was introduced in [82], and the Expected Sliced plan in [84], both
9+
were further studied theoretically in [83].
10+
11+
.. [82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385.
12+
13+
.. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661.
14+
15+
.. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations.
16+
"""
17+
18+
# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
19+
# License: MIT License
20+
21+
# sphinx_gallery_thumbnail_number = 1
22+
23+
##############################################################################
24+
# Setup data and imports
25+
# ----------------------
26+
import numpy as np
27+
import ot
28+
import matplotlib.pyplot as plt
29+
from ot.sliced import get_random_projections
30+
31+
seed = 0
32+
np.random.seed(seed)
33+
n = 10
34+
d = 2
35+
X = np.random.randn(n, 2)
36+
Y = np.random.randn(n, 2) + np.array([5.0, 0.0])[None, :]
37+
n_proj = 20
38+
thetas = get_random_projections(d, n_proj).T
39+
alpha = 0.3
40+
41+
##############################################################################
42+
# Compute min-Pivot Sliced permutation
43+
# ------------------------------------
44+
min_perm, min_cost, log_min = ot.min_pivot_sliced(X, Y, thetas, log=True)
45+
min_plan = np.zeros((n, n))
46+
min_plan[np.arange(n), min_perm] = 1 / n
47+
48+
##############################################################################
49+
# Compute Expected Sliced Plan
50+
# ------------------------------------
51+
expected_plan, expected_cost, log_expected = ot.expected_sliced(X, Y, thetas, log=True)
52+
53+
##############################################################################
54+
# Compute 2-Wasserstein Plan
55+
# ------------------------------------
56+
a = np.ones(n, device=X.device) / n
57+
dists = ot.dist(X, Y)
58+
W2 = ot.emd2(a, a, dists)
59+
W2_plan = ot.emd(a, a, dists)
60+
61+
##############################################################################
62+
# Plot resulting assignments
63+
# ------------------------------------
64+
fig, axs = plt.subplots(2, 3, figsize=(12, 4))
65+
fig.suptitle("Sliced plans comparison", y=0.95, fontsize=16)
66+
67+
# draw min sliced permutation
68+
axs[0, 0].set_title(f"Min Pivot Sliced: cost={min_cost:.2f}")
69+
for i in range(n):
70+
axs[0, 0].plot(
71+
[X[i, 0], Y[min_perm[i], 0]],
72+
[X[i, 1], Y[min_perm[i], 1]],
73+
color="black",
74+
alpha=alpha,
75+
label="min-Sliced perm" if i == 0 else None,
76+
)
77+
axs[1, 0].imshow(min_plan, interpolation="nearest", cmap="Blues")
78+
79+
# draw expected sliced plan
80+
axs[0, 1].set_title(f"Expected Sliced: cost={expected_cost:.2f}")
81+
for i in range(n):
82+
for j in range(n):
83+
w = alpha * expected_plan[i, j].item() * n
84+
axs[0, 1].plot(
85+
[X[i, 0], Y[j, 0]],
86+
[X[i, 1], Y[j, 1]],
87+
color="black",
88+
alpha=w,
89+
label="Expected Sliced plan" if i == 0 and j == 0 else None,
90+
)
91+
axs[1, 1].imshow(expected_plan, interpolation="nearest", cmap="Blues")
92+
93+
# draw W2 plan
94+
axs[0, 2].set_title(f"W2: cost={W2:.2f}")
95+
for i in range(n):
96+
for j in range(n):
97+
w = alpha * W2_plan[i, j].item() * n
98+
axs[0, 2].plot(
99+
[X[i, 0], Y[j, 0]],
100+
[X[i, 1], Y[j, 1]],
101+
color="black",
102+
alpha=w,
103+
label="W2 plan" if i == 0 and j == 0 else None,
104+
)
105+
axs[1, 2].imshow(W2_plan, interpolation="nearest", cmap="Blues")
106+
107+
for ax in axs[0, :]:
108+
ax.scatter(X[:, 0], X[:, 1], label="X")
109+
ax.scatter(Y[:, 0], Y[:, 1], label="Y")
110+
111+
for ax in axs.flatten():
112+
ax.set_aspect("equal")
113+
ax.set_xticks([])
114+
ax.set_yticks([])
115+
116+
fig.tight_layout()
117+
118+
##############################################################################
119+
# Compare Expected Sliced plans with different inverse-temperatures beta
120+
# ------------------------------------
121+
## As the temperature decreases, ES becomes sparser and approaches minPS
122+
betas = [0.0, 5.0, 50.0]
123+
n_plots = len(betas) + 1
124+
size = 4
125+
fig, axs = plt.subplots(2, n_plots, figsize=(size * n_plots, size))
126+
fig.suptitle(
127+
"Expected Sliced plan varying beta (inverse temperature)", y=0.95, fontsize=16
128+
)
129+
for beta_idx, beta in enumerate(betas):
130+
expected_plan, expected_cost = ot.expected_sliced(X, Y, thetas, beta=beta)
131+
print(f"beta={beta}: cost={expected_cost:.2f}")
132+
133+
axs[0, beta_idx].set_title(f"beta={beta}: cost={expected_cost:.2f}")
134+
for i in range(n):
135+
for j in range(n):
136+
w = alpha * expected_plan[i, j].item() * n
137+
axs[0, beta_idx].plot(
138+
[X[i, 0], Y[j, 0]],
139+
[X[i, 1], Y[j, 1]],
140+
color="black",
141+
alpha=w,
142+
label="Expected Sliced plan" if i == 0 and j == 0 else None,
143+
)
144+
145+
axs[0, beta_idx].scatter(X[:, 0], X[:, 1], label="X")
146+
axs[0, beta_idx].scatter(Y[:, 0], Y[:, 1], label="Y")
147+
axs[1, beta_idx].imshow(expected_plan, interpolation="nearest", cmap="Blues")
148+
149+
# draw min sliced permutation (limit when beta -> +inf)
150+
axs[0, -1].set_title(f"Min Pivot Sliced: cost={min_cost:.2f}")
151+
for i in range(n):
152+
axs[0, -1].plot(
153+
[X[i, 0], Y[min_perm[i], 0]],
154+
[X[i, 1], Y[min_perm[i], 1]],
155+
color="black",
156+
alpha=alpha,
157+
label="min-Sliced perm" if i == 0 else None,
158+
)
159+
axs[0, -1].scatter(X[:, 0], X[:, 1], label="X")
160+
axs[0, -1].scatter(Y[:, 0], Y[:, 1], label="Y")
161+
axs[1, -1].imshow(min_plan, interpolation="nearest", cmap="Blues")
162+
163+
for ax in axs.flatten():
164+
ax.set_aspect("equal")
165+
ax.set_xticks([])
166+
ax.set_yticks([])
167+
168+
fig.tight_layout()

ot/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@
5858
sliced_wasserstein_sphere,
5959
sliced_wasserstein_sphere_unif,
6060
linear_sliced_wasserstein_sphere,
61+
min_pivot_sliced,
62+
expected_sliced,
6163
)
6264
from .gromov import (
6365
gromov_wasserstein,
@@ -109,6 +111,8 @@
109111
"sliced_wasserstein_distance",
110112
"sliced_wasserstein_sphere",
111113
"linear_sliced_wasserstein_sphere",
114+
"min_pivot_sliced",
115+
"expected_sliced",
112116
"gromov_wasserstein",
113117
"gromov_wasserstein2",
114118
"gromov_barycenters",

0 commit comments

Comments
 (0)