Skip to content

Commit 805cb5c

Browse files
tvayerrflamary
andauthored
[MRG] Nystrom sinkhorn (#742)
* Create _approx_kernel.py * first implem of nystrom * Put everything in low rank + doc * Add: doc, tests for Nystrom and improve computation low rank sinkhorn * Delete test_ny.py * Update RELEASES.md * better names + docs * test+bugfix+examples * Update .gitignore * fix doc and add test * remove docs * doc fix + more coverage * Update test_lowrank.py * remove multiple targets from low rank sinkhorn + little more coverage * Update RELEASES.md * Add in solver + randperm backened + tests * number paper * fix test and plot * fix tf backened + nx.sqrt + doc * fix none seed * improve coverage and fix doc * pass seed none * move to lowrank folder --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent bec2181 commit 805cb5c

File tree

12 files changed

+934
-16
lines changed

12 files changed

+934
-16
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Source Code (MIT):
2020

2121
POT has the following main features:
2222
* A large set of differentiable solvers for optimal transport problems, including:
23-
* Exact linear OT, entropic and quadratic regularized OT,
23+
* Exact linear OT, entropic and quadratic regularized OT,
2424
* Gromov-Wasserstein (GW) distances, Fused GW distances and variants of
2525
quadratic OT,
2626
* Unbalanced and partial OT for different divergences,
@@ -444,3 +444,5 @@ Artificial Intelligence.
444444
[78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). [LCOT: Linear Circular Optimal Transport](https://openreview.net/forum?id=49z97Y9lMq). International Conference on Learning Representations.
445445

446446
[79] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). [Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data](https://openreview.net/forum?id=fgUFZAxywx). International Conference on Learning Representations.
447+
448+
[80] Altschuler, J., Bach, F., Rudi, A., Niles-Weed, J., [Massively scalable Sinkhorn distances via the Nyström method](https://proceedings.neurips.cc/paper_files/paper/2019/file/f55cadb97eaff2ba1980e001b0bd9842-Paper.pdf), Advances in Neural Information Processing Systems, 2019.

RELEASES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
- Backend implementation of `ot.dist` for (PR #701)
2121
- Updated documentation Quickstart guide and User guide with new API (PR #726)
2222
- Fix jax version for auto-grad (PR #732)
23+
- Add Nystrom kernel approximation for Sinkhorn (PR #742)
2324
- Added `ot.solver_1d.linear_circular_ot` and `ot.sliced.linear_sliced_wasserstein_sphere` (PR #736)
2425
- Implement 1d solver for partial optimal transport (PR #741)
2526
- Fix reg_div function compatibility with numpy in `ot.unbalanced.lbfgsb_unbalanced` via new function `ot.utils.fun_to_numpy` (PR #731)
@@ -48,7 +49,7 @@ This new release contains several new features, starting with
4849
a novel [Gaussian Mixture Model Optimal Transport (GMM-OT)](https://pythonot.github.io/master/gen_modules/ot.gmm.html#examples-using-ot-gmm-gmm-ot-apply-map) solver to compare GMM while enforcing the transport plan to remain a GMM, that benefits from a closed-form solution making it practical for high-dimensional matching problems. We also extended our general unbalanced OT solvers to support any non-negative reference measure in the regularization terms, before adding the novel [translation invariant UOT](https://pythonot.github.io/master/auto_examples/unbalanced-partial/plot_conv_sinkhorn_ti.html) solver showcasing a higher convergence speed. We also implemented several new solvers and enhanced existing ones to perform OT across spaces. These include a [semi-relaxed FGW barycenter](https://pythonot.github.io/master/auto_examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.html) solver, coupled with new initialization heuristics for the inner divergence computation, to perform graph partitioning or dictionary learning. Followed by novel [unbalanced FGW and Co-optimal transport](https://pythonot.github.io/master/auto_examples/others/plot_outlier_detection_with_COOT_and_unbalanced_COOT.html) solvers to promote robustness to outliers in such matching problems. And we finally updated the implementation of partial GW now supporting asymmetric structures and the KL divergence, while leveraging a new generic conditional gradient solver for partial transport problems enabling significant speed improvements. These latest updates required some modifications to the line search functions of our generic conditional gradient solver, paving the way for future improvements to other GW-based solvers. Last but not least, we implemented a pre-commit scheme to automatically correct common programming mistakes likely to be made by our future contributors.
4950

5051
This release also contains few bug fixes, concerning the support of any metric in `ot.emd_1d` / `ot.emd2_1d`, and the support of any weights in `ot.gaussian`.
51-
52+
5253
#### Breaking change
5354
- Custom functions provided as parameter `line_search` to `ot.optim.generic_conditional_gradient` must now have the signature `line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs)`, adding as input `df_G` the gradient of the regularizer evaluated at the transport plan `G`. This change aims at improving speed of solvers having quadratic polynomial functions as regularizer such as the Gromov-Wassertein loss (PR #663).
5455

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
============================
4+
Nyström approximation for OT
5+
============================
6+
7+
Shows how to use Nyström kernel approximation for approximating the Sinkhorn algorithm in linear time.
8+
9+
10+
"""
11+
12+
# Author: Titouan Vayer <titouan.vayer@inria.fr>
13+
#
14+
# License: MIT License
15+
16+
# sphinx_gallery_thumbnail_number = 2
17+
18+
import numpy as np
19+
from ot.lowrank import kernel_nystroem, sinkhorn_low_rank_kernel
20+
from ot.bregman import empirical_sinkhorn_nystroem
21+
import math
22+
import ot
23+
import matplotlib.pyplot as plt
24+
from matplotlib.colors import LogNorm
25+
26+
##############################################################################
27+
# Generate data
28+
# -------------
29+
30+
# %%
31+
offset = 1
32+
n_samples_per_blob = 500 # We use 2D ''blobs'' data
33+
random_state = 42
34+
std = 0.2 # standard deviation
35+
np.random.seed(random_state)
36+
37+
centers = np.array(
38+
[
39+
[-offset, -offset], # Class 0 - blob 1
40+
[-offset, offset], # Class 0 - blob 2
41+
[offset, -offset], # Class 1 - blob 1
42+
[offset, offset], # Class 1 - blob 2
43+
]
44+
)
45+
46+
X_list = []
47+
y_list = []
48+
49+
for i, center in enumerate(centers):
50+
blob_points = np.random.randn(n_samples_per_blob, 2) * std + center
51+
label = 0 if i < 2 else 1
52+
X_list.append(blob_points)
53+
y_list.append(np.full(n_samples_per_blob, label))
54+
55+
X = np.vstack(X_list)
56+
y = np.concatenate(y_list)
57+
Xs = X[y == 0] # source data
58+
Xt = X[y == 1] # target data
59+
60+
61+
##############################################################################
62+
# Plot data
63+
# ---------
64+
65+
# %%
66+
plt.scatter(Xs[:, 0], Xs[:, 1], label="Source")
67+
plt.scatter(Xt[:, 0], Xt[:, 1], label="Target")
68+
plt.legend()
69+
70+
##############################################################################
71+
# Compute the Nyström approximation of the Gaussian kernel
72+
# --------------------------------------------------------
73+
74+
# %%
75+
reg = 5.0 # proportional to the std of the Gaussian kernel
76+
anchors = 10 # number of anchor points for the Nyström approximation
77+
ot.tic()
78+
left_factor, right_factor = kernel_nystroem(
79+
Xs, Xt, anchors=anchors, sigma=math.sqrt(reg / 2.0), random_state=random_state
80+
)
81+
ot.toc()
82+
83+
##############################################################################
84+
# Use this approximation in a Sinkhorn algorithm with low rank kernel.
85+
# Each matrix/vector product in the Sinkhorn is accelerated
86+
# since :math:`Kv = K_1 (K_2^\top v)` can be computed in :math:`O(nr)` time
87+
# instead of :math:`O(n^2)`
88+
89+
# %%
90+
numItermax = 1000
91+
stopThr = 1e-7
92+
verbose = True
93+
a, b = None, None
94+
warn = True
95+
warmstart = None
96+
ot.tic()
97+
u, v, dict_log = sinkhorn_low_rank_kernel(
98+
K1=left_factor,
99+
K2=right_factor,
100+
a=a,
101+
b=b,
102+
numItermax=numItermax,
103+
stopThr=stopThr,
104+
verbose=verbose,
105+
log=True,
106+
warn=warn,
107+
warmstart=warmstart,
108+
)
109+
ot.toc()
110+
##############################################################################
111+
# Compare with Sinkhorn
112+
# ---------------------
113+
114+
# %%
115+
M = ot.dist(Xs, Xt)
116+
ot.tic()
117+
G, log_ = ot.sinkhorn(
118+
a=[],
119+
b=[],
120+
M=M,
121+
reg=reg,
122+
numItermax=numItermax,
123+
verbose=verbose,
124+
log=True,
125+
warn=warn,
126+
warmstart=warmstart,
127+
)
128+
ot.toc()
129+
130+
##############################################################################
131+
# Use directly ot.bregman.empirical_sinkhorn_nystroem
132+
# --------------------------------------------------
133+
134+
# %%
135+
ot.tic()
136+
G_nys = empirical_sinkhorn_nystroem(
137+
Xs,
138+
Xt,
139+
anchors=anchors,
140+
reg=reg,
141+
numItermax=numItermax,
142+
verbose=True,
143+
random_state=random_state,
144+
)[:]
145+
ot.toc()
146+
# %%
147+
ot.tic()
148+
G_sinkh = ot.bregman.empirical_sinkhorn(
149+
Xs, Xt, reg=reg, numIterMax=numItermax, verbose=True
150+
)
151+
ot.toc()
152+
153+
##############################################################################
154+
# Compare OT plans
155+
# ----------------
156+
157+
fig, ax = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)
158+
vmin = min(G_sinkh.min(), G_nys.min())
159+
vmax = max(G_sinkh.max(), G_nys.max())
160+
norm = LogNorm(vmin=vmin, vmax=vmax)
161+
im0 = ax[0].imshow(G_sinkh, norm=norm, cmap="coolwarm")
162+
im1 = ax[1].imshow(G_nys, norm=norm, cmap="coolwarm")
163+
cbar = fig.colorbar(im1, ax=ax, orientation="vertical", fraction=0.046, pad=0.04)
164+
ax[0].set_title("OT plan Sinkhorn")
165+
ax[1].set_title("OT plan Nyström Sinkhorn")
166+
for a in ax:
167+
a.set_xticks([])
168+
a.set_yticks([])
169+
plt.show()

ot/backend.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,16 @@ def randn(self, *size, type_as=None):
779779
"""
780780
raise NotImplementedError()
781781

782+
def randperm(self, size, type_as=None):
783+
r"""
784+
Returns a random permutation of integers from 0 to n-1.
785+
786+
This function follows the api from :any:`torch.randperm`
787+
788+
See: https://docs.pytorch.org/docs/stable/generated/torch.randperm.html
789+
"""
790+
raise NotImplementedError()
791+
782792
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
783793
r"""
784794
Creates a sparse tensor in COOrdinate format.
@@ -929,6 +939,16 @@ def inv(self, a):
929939
"""
930940
raise NotImplementedError()
931941

942+
def pinv(self, a, hermitian=False):
943+
r"""
944+
Computes the pseudo inverse of a matrix.
945+
946+
This function follows the api from :any:`numpy.linalg.pinv`.
947+
948+
See: https://numpy.org/devdocs/reference/generated/numpy.linalg.pinv.html
949+
"""
950+
raise NotImplementedError()
951+
932952
def sqrtm(self, a):
933953
r"""
934954
Computes the matrix square root.
@@ -1283,6 +1303,11 @@ def rand(self, *size, type_as=None):
12831303
def randn(self, *size, type_as=None):
12841304
return self.rng_.randn(*size)
12851305

1306+
def randperm(self, size, type_as=None):
1307+
if not isinstance(size, int):
1308+
raise ValueError("size must be an integer")
1309+
return self.rng_.permutation(size)
1310+
12861311
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
12871312
if type_as is None:
12881313
return coo_matrix((data, (rows, cols)), shape=shape)
@@ -1368,6 +1393,9 @@ def trace(self, a):
13681393
def inv(self, a):
13691394
return scipy.linalg.inv(a)
13701395

1396+
def pinv(self, a, hermitian=False):
1397+
return np.linalg.pinv(a, hermitian=hermitian)
1398+
13711399
def sqrtm(self, a):
13721400
L, V = np.linalg.eigh(a)
13731401
L = np.sqrt(L)
@@ -1690,6 +1718,15 @@ def randn(self, *size, type_as=None):
16901718
else:
16911719
return jax.random.normal(subkey, shape=size)
16921720

1721+
def randperm(self, size, type_as=None):
1722+
self.rng_, subkey = jax.random.split(self.rng_)
1723+
if not isinstance(size, int):
1724+
raise ValueError("size must be an integer")
1725+
if type_as is not None:
1726+
return jax.random.permutation(subkey, size).astype(type_as.dtype)
1727+
else:
1728+
return jax.random.permutation(subkey, size)
1729+
16931730
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
16941731
# Currently, JAX does not support sparse matrices
16951732
data = self.to_numpy(data)
@@ -1781,6 +1818,9 @@ def trace(self, a):
17811818
def inv(self, a):
17821819
return jnp.linalg.inv(a)
17831820

1821+
def pinv(self, a, hermitian=False):
1822+
return jnp.linalg.pinv(a, hermitian=hermitian)
1823+
17841824
def sqrtm(self, a):
17851825
L, V = jnp.linalg.eigh(a)
17861826
L = jnp.sqrt(L)
@@ -2161,7 +2201,9 @@ def reshape(self, a, shape):
21612201
return torch.reshape(a, shape)
21622202

21632203
def seed(self, seed=None):
2164-
if isinstance(seed, int):
2204+
if seed is None:
2205+
pass
2206+
elif isinstance(seed, int):
21652207
self.rng_.manual_seed(seed)
21662208
self.rng_cuda_.manual_seed(seed)
21672209
elif isinstance(seed, torch.Generator):
@@ -2200,6 +2242,22 @@ def randn(self, *size, type_as=None):
22002242
else:
22012243
return torch.randn(size=size, generator=self.rng_)
22022244

2245+
def randperm(self, size, type_as=None):
2246+
if not isinstance(size, int):
2247+
raise ValueError("size must be an integer")
2248+
if type_as is not None:
2249+
generator = (
2250+
self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
2251+
)
2252+
return torch.randperm(
2253+
n=size,
2254+
dtype=type_as.dtype,
2255+
generator=generator,
2256+
device=type_as.device,
2257+
)
2258+
else:
2259+
return torch.randperm(n=size, generator=self.rng_)
2260+
22032261
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
22042262
if type_as is None:
22052263
return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape)
@@ -2314,6 +2372,9 @@ def trace(self, a):
23142372
def inv(self, a):
23152373
return torch.linalg.inv(a)
23162374

2375+
def pinv(self, a, hermitian=False):
2376+
return torch.linalg.pinv(a, hermitian=hermitian)
2377+
23172378
def sqrtm(self, a):
23182379
L, V = torch.linalg.eigh(a)
23192380
L = torch.sqrt(L)
@@ -2624,6 +2685,15 @@ def randn(self, *size, type_as=None):
26242685
with cp.cuda.Device(type_as.device):
26252686
return self.rng_.randn(*size, dtype=type_as.dtype)
26262687

2688+
def randperm(self, size, type_as=None):
2689+
if not isinstance(size, int):
2690+
raise ValueError("size must be an integer")
2691+
if type_as is None:
2692+
return self.rng_.permutation(size)
2693+
else:
2694+
with cp.cuda.Device(type_as.device):
2695+
return self.rng_.permutation(size).astype(type_as.dtype)
2696+
26272697
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
26282698
data = self.from_numpy(data)
26292699
rows = self.from_numpy(rows)
@@ -2728,6 +2798,9 @@ def trace(self, a):
27282798
def inv(self, a):
27292799
return cp.linalg.inv(a)
27302800

2801+
def pinv(self, a, hermitian=False):
2802+
return cp.linalg.pinv(a)
2803+
27312804
def sqrtm(self, a):
27322805
L, V = cp.linalg.eigh(a)
27332806
L = cp.sqrt(L)
@@ -3048,6 +3121,19 @@ def randn(self, *size, type_as=None):
30483121
else:
30493122
return self.rng_.normal(size, dtype=type_as.dtype)
30503123

3124+
def randperm(self, size, type_as=None):
3125+
if not isinstance(size, int):
3126+
raise ValueError("size must be an integer")
3127+
local_seed = self.rng_.make_seeds(2)[0]
3128+
if type_as is None:
3129+
return tf.random.experimental.stateless_shuffle(
3130+
tf.range(size), seed=local_seed
3131+
)
3132+
else:
3133+
return tf.random.experimental.stateless_shuffle(
3134+
tf.range(size, dtype=type_as.dtype), seed=local_seed
3135+
)
3136+
30513137
def _convert_to_index_for_coo(self, tensor):
30523138
if isinstance(tensor, self.__type__):
30533139
return int(self.max(tensor)) + 1
@@ -3164,6 +3250,9 @@ def trace(self, a):
31643250
def inv(self, a):
31653251
return tf.linalg.inv(a)
31663252

3253+
def pinv(self, a, hermitian=False):
3254+
return tf.linalg.pinv(a)
3255+
31673256
def sqrtm(self, a):
31683257
L, V = tf.linalg.eigh(a)
31693258
L = tf.sqrt(L)

0 commit comments

Comments
 (0)