Skip to content

Commit 1ef8dce

Browse files
merge
2 parents b11f629 + a11b055 commit 1ef8dce

File tree

8 files changed

+125
-118
lines changed

8 files changed

+125
-118
lines changed

RELEASES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
- Added feature `grad=last_step` for `ot.solvers.solve` (PR #693)
88
- Automatic PR labeling and release file update check (PR #704)
99
- Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714)
10+
- Implement projected gradient descent solvers for entropic partial FGW (PR #702)
1011

1112
#### Closed issues
1213
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
1314
- Fixed numerical errors in `ot.gmm` (PR #690, Issue #689)
1415
- Add version number to the documentation (PR #696)
1516
- Update doc for default regularization in `ot.unbalanced` sinkhorn solvers (Issue #691, PR #700)
16-
- Clean documentation from `gromov`, `lp` and `unbalanced` folders (PR #710)
17+
- Clean documentation for `gromov`, `lp` and `unbalanced` folders (PR #710)
1718

1819
## 0.9.5
1920

ot/gromov/_lowrank.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ def lowrank_gromov_wasserstein_samples(
9292
9393
where :
9494
95-
- :math: `A` is the (`dim_a`, `dim_a`) square pairwise cost matrix of the source domain.
96-
- :math: `B` is the (`dim_a`, `dim_a`) square pairwise cost matrix of the target domain.
97-
- :math: `\mathcal{Q}_{A,B}` is quadratic objective function of the Gromov Wasserstein plan.
98-
- :math: `Q` and `R` are the low-rank matrix decomposition of the Gromov-Wasserstein plan.
99-
- :math: `g` is the weight vector for the low-rank decomposition of the Gromov-Wasserstein plan.
95+
- :math:`A` is the (`dim_a`, `dim_a`) square pairwise cost matrix of the source domain.
96+
- :math:`B` is the (`dim_a`, `dim_a`) square pairwise cost matrix of the target domain.
97+
- :math:`\mathcal{Q}_{A,B}` is quadratic objective function of the Gromov Wasserstein plan.
98+
- :math:`Q` and `R` are the low-rank matrix decomposition of the Gromov-Wasserstein plan.
99+
- :math:`g` is the weight vector for the low-rank decomposition of the Gromov-Wasserstein plan.
100100
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1).
101-
- :math: `r` is the rank of the Gromov-Wasserstein plan.
102-
- :math: `\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem.
101+
- :math:`r` is the rank of the Gromov-Wasserstein plan.
102+
- :math:`\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem.
103103
- :math:`H((Q,R,g))` is the values of the three respective entropies evaluated for each term.
104104
105105

ot/gromov/_partial.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,18 +1002,18 @@ def solve_partial_gromov_linesearch(
10021002
Parameters
10031003
----------
10041004
1005-
G : array-like, shape(ns,nt)
1005+
G : array-like, shape(ns, nt)
10061006
The transport map at a given iteration of the FW
1007-
deltaG : array-like (ns,nt)
1007+
deltaG : array-like, shape (ns, nt)
10081008
Difference between the optimal map `Gc` found by linearization in the
10091009
FW algorithm and the value at a given iteration
10101010
cost_G : float
10111011
Value of the cost at `G`
1012-
df_G : array-like (ns,nt)
1012+
df_G : array-like, shape (ns, nt)
10131013
Gradient of the GW cost at `G`
1014-
df_Gc : array-like (ns,nt)
1014+
df_Gc : array-like, shape (ns, nt)
10151015
Gradient of the GW cost at `Gc`
1016-
M : array-like (ns,nt)
1016+
M : array-like, shape (ns, nt)
10171017
Cost matrix between the features.
10181018
reg : float
10191019
Regularization parameter.
@@ -1032,7 +1032,7 @@ def solve_partial_gromov_linesearch(
10321032
nb of function call. Useless here
10331033
cost_G : float
10341034
The value of the cost for the next iteration
1035-
df_G : array-like (ns,nt)
1035+
df_G : array-like, shape (ns, nt)
10361036
Updated gradient of the GW cost
10371037
10381038
References

ot/lp/solver_1d.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def emd_1d(
160160
where :
161161
162162
- d is the metric
163-
- x_a and x_b are the samples
163+
- :math:`x_a` and :math:`x_b` are the samples
164164
- a and b are the sample weights
165165
166166
This implementation only supports metrics
@@ -170,21 +170,21 @@ def emd_1d(
170170
171171
Parameters
172172
----------
173-
x_a : (ns,) or (ns, 1) ndarray, float64
173+
x_a : ndarray of float64, shape (ns,) or (ns, 1)
174174
Source dirac locations (on the real line)
175-
x_b : (nt,) or (ns, 1) ndarray, float64
175+
x_b : ndarray of float64, shape (nt,) or (ns, 1)
176176
Target dirac locations (on the real line)
177-
a : (ns,) ndarray, float64, optional
177+
a : ndarray of float64, shape (ns,), optional
178178
Source histogram (default is uniform weight)
179-
b : (nt,) ndarray, float64, optional
179+
b : ndarray of float64, shape (nt,), optional
180180
Target histogram (default is uniform weight)
181181
metric: str, optional (default='sqeuclidean')
182182
Metric to be used. Only works with either of the strings
183183
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
184184
p: float, optional (default=1.0)
185185
The p-norm to apply for if metric='minkowski'
186186
dense: boolean, optional (default=True)
187-
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
187+
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
188188
Otherwise returns a sparse representation using scipy's `coo_matrix`
189189
format. Due to implementation details, this function runs faster when
190190
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
@@ -198,7 +198,7 @@ def emd_1d(
198198
199199
Returns
200200
-------
201-
gamma: (ns, nt) ndarray
201+
gamma: ndarray, shape (ns, nt)
202202
Optimal transportation matrix for the given parameters
203203
log: dict
204204
If input log is True, a dictionary containing the cost
@@ -318,7 +318,7 @@ def emd2_1d(
318318
where :
319319
320320
- d is the metric
321-
- x_a and x_b are the samples
321+
- :math:`x_a` and :math:`x_b` are the samples
322322
- a and b are the sample weights
323323
324324
This implementation only supports metrics
@@ -328,21 +328,21 @@ def emd2_1d(
328328
329329
Parameters
330330
----------
331-
x_a : (ns,) or (ns, 1) ndarray, float64
331+
x_a : ndarray of float64, shape (ns,) or (ns, 1)
332332
Source dirac locations (on the real line)
333-
x_b : (nt,) or (ns, 1) ndarray, float64
333+
x_b : ndarray of float64, shape (nt,) or (ns, 1)
334334
Target dirac locations (on the real line)
335-
a : (ns,) ndarray, float64, optional
335+
a : ndarray of float64, shape (ns,), optional
336336
Source histogram (default is uniform weight)
337-
b : (nt,) ndarray, float64, optional
337+
b : ndarray of float64, shape (nt,), optional
338338
Target histogram (default is uniform weight)
339339
metric: str, optional (default='sqeuclidean')
340340
Metric to be used. Only works with either of the strings
341341
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
342342
p: float, optional (default=1.0)
343343
The p-norm to apply for if metric='minkowski'
344344
dense: boolean, optional (default=True)
345-
If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
345+
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
346346
Otherwise returns a sparse representation using scipy's `coo_matrix`
347347
format. Only used if log is set to True. Due to implementation details,
348348
this function runs faster when dense is set to False.
@@ -405,9 +405,9 @@ def roll_cols(M, shifts):
405405
406406
Parameters
407407
----------
408-
M : (nr, nc) ndarray
408+
M : ndarray, shape (nr, nc)
409409
Matrix to shift
410-
shifts: int or (nr,) ndarray
410+
shifts: int or ndarray, shape (nr,)
411411
412412
Returns
413413
-------
@@ -1046,7 +1046,7 @@ def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None):
10461046
10471047
Parameters
10481048
----------
1049-
u_values: ndarray, shape (n, ...)
1049+
u_values : ndarray, shape (n, ...)
10501050
Samples
10511051
u_weights : ndarray, shape (n, ...), optional
10521052
samples weights in the source domain

ot/plot.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def plot1D_mat(
3232
r"""Plot matrix :math:`\mathbf{M}` with the source and target 1D distributions.
3333
3434
Creates a subplot with the source distribution :math:`\mathbf{a}` and target
35-
distribution :math:`\mathbf{b}`t.
35+
distribution :math:`\mathbf{b}`.
3636
In 'yx' mode (default), the source is on the left and
3737
the target on the top, and in 'xy' mode, source on the bottom (upside
3838
down) and the target on the left.
@@ -69,8 +69,9 @@ def plot1D_mat(
6969
ax2 : target plot ax
7070
ax3 : coupling plot ax
7171
72-
.. seealso::
73-
:func:`rescale_for_imshow_plot`
72+
See Also
73+
--------
74+
:func:`rescale_for_imshow_plot`
7475
"""
7576
assert plot_style in ["yx", "xy"], "plot_style should be 'yx' or 'xy'"
7677
na, nb = M.shape
@@ -188,8 +189,9 @@ def rescale_for_imshow_plot(x, y, n, m=None, a_y=None, b_y=None):
188189
yr : ndarray, shape (nx,)
189190
Rescaled y values (due to slicing, may have less elements than y)
190191
191-
.. seealso::
192-
:func:`plot1D_mat`
192+
See Also
193+
--------
194+
:func:`plot1D_mat`
193195
194196
"""
195197
# slice over the y values that are in the y range

ot/solvers.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def solve(
8383
8484
Parameters
8585
----------
86-
M : array_like, shape (dim_a, dim_b)
86+
M : array-like, shape (dim_a, dim_b)
8787
Loss matrix
8888
a : array-like, shape (dim_a,), optional
8989
Samples weights in the source domain (default is uniform)
@@ -92,10 +92,10 @@ def solve(
9292
reg : float, optional
9393
Regularization weight :math:`\lambda_r`, by default None (no reg., exact
9494
OT)
95-
c : array-like (dim_a, dim_b), optional (default=None)
95+
c : array-like, shape (dim_a, dim_b), optional (default=None)
9696
Reference measure for the regularization.
9797
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
98-
If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
98+
If :math:`\texttt{reg_type}=`'entropy', then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
9999
reg_type : str, optional
100100
Type of regularization :math:`R` either "KL", "L2", "entropy",
101101
by default "KL". a tuple of functions can be provided for general
@@ -120,9 +120,9 @@ def solve(
120120
Number of OMP threads for exact OT solver, by default 1
121121
max_iter : int, optional
122122
Maximum number of iterations, by default None (default values in each solvers)
123-
plan_init : array_like, shape (dim_a, dim_b), optional
123+
plan_init : array-like, shape (dim_a, dim_b), optional
124124
Initialization of the OT plan for iterative methods, by default None
125-
potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
125+
potentials_init : (array-like(dim_a,),array-like(dim_b,)), optional
126126
Initialization of the OT dual potentials for iterative methods, by default None
127127
tol : _type_, optional
128128
Tolerance for solution precision, by default None (default values in each solvers)
@@ -632,11 +632,11 @@ def solve_gromov(
632632
633633
Parameters
634634
----------
635-
Ca : array_like, shape (dim_a, dim_a)
635+
Ca : array-like, shape (dim_a, dim_a)
636636
Cost matrix in the source domain
637-
Cb : array_like, shape (dim_b, dim_b)
637+
Cb : array-like, shape (dim_b, dim_b)
638638
Cost matrix in the target domain
639-
M : array_like, shape (dim_a, dim_b), optional
639+
M : array-like, shape (dim_a, dim_b), optional
640640
Linear cost matrix for Fused Gromov-Wasserstein (default is None).
641641
a : array-like, shape (dim_a,), optional
642642
Samples weights in the source domain (default is uniform)
@@ -674,7 +674,7 @@ def solve_gromov(
674674
max_iter : int, optional
675675
Maximum number of iterations, by default None (default values in each
676676
solvers)
677-
plan_init : array_like, shape (dim_a, dim_b), optional
677+
plan_init : array-like, shape (dim_a, dim_b), optional
678678
Initialization of the OT plan for iterative methods, by default None
679679
tol : float, optional
680680
Tolerance for solution precision, by default None (default values in
@@ -1399,10 +1399,10 @@ def solve_sample(
13991399
reg : float, optional
14001400
Regularization weight :math:`\lambda_r`, by default None (no reg., exact
14011401
OT)
1402-
c : array-like (dim_a, dim_b), optional (default=None)
1402+
c : array-like, shape (dim_a, dim_b), optional (default=None)
14031403
Reference measure for the regularization.
14041404
If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
1405-
If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
1405+
If :math:`\texttt{reg_type}=`'entropy', then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
14061406
reg_type : str, optional
14071407
Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL"
14081408
unbalanced : float or indexable object of length 1 or 2
@@ -1431,13 +1431,13 @@ def solve_sample(
14311431
Number of OMP threads for exact OT solver, by default 1
14321432
max_iter : int, optional
14331433
Maximum number of iteration, by default None (default values in each solvers)
1434-
plan_init : array_like, shape (dim_a, dim_b), optional
1434+
plan_init : array-like, shape (dim_a, dim_b), optional
14351435
Initialization of the OT plan for iterative methods, by default None
14361436
rank : int, optional
14371437
Rank of the OT matrix for lazy solers (method='factored'), by default 100
14381438
scaling : float, optional
14391439
Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95
1440-
potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
1440+
potentials_init : (array-like(dim_a,),array-like(dim_b,)), optional
14411441
Initialization of the OT dual potentials for iterative methods, by default None
14421442
tol : _type_, optional
14431443
Tolerance for solution precision, by default None (default values in each solvers)
@@ -1568,7 +1568,7 @@ def solve_sample(
15681568
.. math::
15691569
\min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
15701570
1571-
with M_{i,j} = d(x_i,y_j)
1571+
\text{with} \ M_{i,j} = d(x_i,y_j)
15721572
15731573
can be solved with the following code:
15741574
@@ -1587,7 +1587,7 @@ def solve_sample(
15871587
.. math::
15881588
\min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})
15891589
1590-
with M_{i,j} = d(x_i,y_j)
1590+
\text{with} \ M_{i,j} = d(x_i,y_j)
15911591
15921592
can be solved with the following code:
15931593

0 commit comments

Comments
 (0)