Skip to content

Commit 94a5e37

Browse files
partial entropic fgw solvers
1 parent 8e60257 commit 94a5e37

File tree

3 files changed

+419
-6
lines changed

3 files changed

+419
-6
lines changed

ot/gromov/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@
107107
solve_partial_gromov_linesearch,
108108
entropic_partial_gromov_wasserstein,
109109
entropic_partial_gromov_wasserstein2,
110+
entropic_partial_fused_gromov_wasserstein,
111+
entropic_partial_fused_gromov_wasserstein2,
110112
)
111113

112114

@@ -180,4 +182,6 @@
180182
"solve_partial_gromov_linesearch",
181183
"entropic_partial_gromov_wasserstein",
182184
"entropic_partial_gromov_wasserstein2",
185+
"entropic_partial_fused_gromov_wasserstein",
186+
"entropic_partial_fused_gromov_wasserstein2",
183187
]

ot/gromov/_partial.py

Lines changed: 377 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,3 +1433,380 @@ def entropic_partial_gromov_wasserstein2(
14331433
return log_gw["partial_gw_dist"], log_gw
14341434
else:
14351435
return log_gw["partial_gw_dist"]
1436+
1437+
1438+
def entropic_partial_fused_gromov_wasserstein(
    M,
    C1,
    C2,
    p=None,
    q=None,
    reg=1.0,
    m=None,
    loss_fun="square_loss",
    alpha=0.5,
    G0=None,
    numItermax=1000,
    tol=1e-7,
    symmetric=None,
    log=False,
    verbose=False,
):
    r"""
    Returns the entropic partial Fused Gromov-Wasserstein transport between
    :math:`(\mathbf{C_1}, \mathbf{F_1}, \mathbf{p})` and
    :math:`(\mathbf{C_2}, \mathbf{F_2}, \mathbf{q})`, with pairwise
    distance matrix :math:`\mathbf{M}` between node feature matrices.

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_{\gamma} \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F
        + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot
        \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma)

    .. math::
        s.t. \ \gamma &\geq 0

             \gamma \mathbf{1} &\leq \mathbf{a}

             \gamma^T \mathbf{1} &\leq \mathbf{b}

             \mathbf{1}^T \gamma^T \mathbf{1} = m
             &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}

    where :

    - :math:`\mathbf{M}`: metric cost matrix between features across domains
    - :math:`\mathbf{C_1}` is the metric cost matrix in the source space
    - :math:`\mathbf{C_2}` is the metric cost matrix in the target space
    - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights
    - `L`: quadratic loss function
    - :math:`\Omega` is the entropic regularization term,
      :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - `m` is the amount of mass to be transported

    The formulation of the FGW problem has been proposed in
    :ref:`[24] <references-entropic-partial-fused-gromov-wasserstein>` and the
    partial GW in :ref:`[29] <references-entropic-partial-fused-gromov-wasserstein>`

    Parameters
    ----------
    M : array-like, shape (ns, nt)
        Metric cost matrix between features across domains
    C1 : array-like, shape (ns, ns)
        Metric cost matrix in the source space
    C2 : array-like, shape (nt, nt)
        Metric cost matrix in the target space
    p : array-like, shape (ns,), optional
        Distribution in the source space.
        If let to its default value None, uniform distribution is taken.
    q : array-like, shape (nt,), optional
        Distribution in the target space.
        If let to its default value None, uniform distribution is taken.
    reg : float, optional. Default is 1.
        entropic regularization parameter
    m : float, optional
        Amount of mass to be transported (default:
        :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
    loss_fun : str, optional
        Loss function used for the solver either 'square_loss' or 'kl_loss'.
    alpha : float, optional
        Trade-off parameter (0 < alpha < 1)
    G0 : array-like, shape (ns, nt), optional
        Initialization of the transportation matrix
    numItermax : int, optional
        Max number of iterations
    tol : float, optional
        Stop threshold on error (>0)
    symmetric : bool, optional
        Either C1 and C2 are to be assumed symmetric or not.
        If let to its default None value, a symmetry test will be conducted.
        Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
    log : bool, optional
        return log if True
    verbose : bool, optional
        Print information along iterations

    Returns
    -------
    :math:`\gamma` : (dim_a, dim_b) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if `log` is `True`


    .. _references-entropic-partial-fused-gromov-wasserstein:
    References
    ----------
    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
        and Courty Nicolas "Optimal Transport for structured data with
        application on graphs", International Conference on Machine Learning
        (ICML). 2019.

    .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
        Transport with Applications on Positive-Unlabeled Learning".
        NeurIPS.

    See Also
    --------
    ot.gromov.partial_fused_gromov_wasserstein: exact Partial Fused Gromov-Wasserstein
    """

    arr = [M, C1, C2, G0]
    if p is not None:
        p = list_to_array(p)
        arr.append(p)
    if q is not None:
        q = list_to_array(q)
        arr.append(q)

    nx = get_backend(*arr)

    # Default to uniform marginals on the same backend/dtype as the costs.
    if p is None:
        p = nx.ones(C1.shape[0], type_as=C1) / C1.shape[0]
    if q is None:
        q = nx.ones(C2.shape[0], type_as=C2) / C2.shape[0]

    if m is None:
        m = min(nx.sum(p), nx.sum(q))
    elif m < 0:
        raise ValueError("Problem infeasible. Parameter m should be greater" " than 0.")
    elif m > min(nx.sum(p), nx.sum(q)):
        raise ValueError(
            "Problem infeasible. Parameter m should lower or"
            " equal than min(|a|_1, |b|_1)."
        )

    if G0 is None:
        G0 = (
            nx.outer(p, q) * m / (nx.sum(p) * nx.sum(q))
        )  # make sure |G0|=m, G01_m\leq p, G0.T1_n\leq q.

    else:
        # Check marginals of G0.
        # NOTE(review): `nx.any` only checks that *some* coordinate satisfies
        # the marginal constraint; `nx.all` would be the strict feasibility
        # test. Kept as-is to match the sibling partial-GW solvers — confirm
        # whether a strict check is intended.
        assert nx.any(nx.sum(G0, 1) <= p)
        assert nx.any(nx.sum(G0, 0) <= q)

    if symmetric is None:
        # Backend-consistent symmetry test (np.allclose would force a
        # conversion and fail on non-CPU backend arrays).
        symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(
            C2, C2.T, atol=1e-10
        )

    # Setup gradient computation
    fC1, fC2, hC1, hC2 = _transform_matrix(C1, C2, loss_fun, nx)
    fC2t = fC2.T
    if not symmetric:
        fC1t, hC1t, hC2t = fC1.T, hC1.T, hC2.T

    ones_p = nx.ones(p.shape[0], type_as=p)
    ones_q = nx.ones(q.shape[0], type_as=q)

    def f(G):
        # Partial-FGW objective: alpha * quadratic GW loss + (1-alpha) * linear
        # feature term <G, M>.
        pG = nx.sum(G, 1)
        qG = nx.sum(G, 0)
        constC1 = nx.outer(nx.dot(fC1, pG), ones_q)
        constC2 = nx.outer(ones_p, nx.dot(qG, fC2t))
        return alpha * gwloss(constC1 + constC2, hC1, hC2, G, nx) + (
            1 - alpha
        ) * nx.sum(G * M)

    if symmetric:

        def df(G):
            pG = nx.sum(G, 1)
            qG = nx.sum(G, 0)
            constC1 = nx.outer(nx.dot(fC1, pG), ones_q)
            constC2 = nx.outer(ones_p, nx.dot(qG, fC2t))
            # Gradient of the linear term (1-alpha)*<G, M> is (1-alpha)*M —
            # NOT the scalar (1-alpha)*sum(G*M), which would make the solver
            # ignore the feature costs entirely.
            return (
                alpha * gwggrad(constC1 + constC2, hC1, hC2, G, nx)
                + (1 - alpha) * M
            )
    else:

        def df(G):
            pG = nx.sum(G, 1)
            qG = nx.sum(G, 0)
            constC1 = nx.outer(nx.dot(fC1, pG), ones_q)
            constC2 = nx.outer(ones_p, nx.dot(qG, fC2t))
            constC1t = nx.outer(nx.dot(fC1t, pG), ones_q)
            constC2t = nx.outer(ones_p, nx.dot(qG, fC2))

            # Average the gradient and its transpose-counterpart for the
            # asymmetric case; the linear-term gradient is (1-alpha)*M.
            return 0.5 * alpha * (
                gwggrad(constC1 + constC2, hC1, hC2, G, nx)
                + gwggrad(constC1t + constC2t, hC1t, hC2t, G, nx)
            ) + (1 - alpha) * M

    cpt = 0
    err = 1

    loge = {"err": []}

    while err > tol and cpt < numItermax:
        Gprev = G0
        # Linearize the problem at G0 and solve the resulting entropic
        # partial (linear) OT problem.
        M_entr = df(G0)
        G0 = entropic_partial_wasserstein(p, q, M_entr, reg, m)
        if cpt % 10 == 0:  # to speed up the computations
            err = nx.norm(G0 - Gprev)
            if log:
                loge["err"].append(err)
            if verbose:
                if cpt % 200 == 0:
                    print(
                        "{:5s}|{:12s}|{:12s}".format("It.", "Err", "Loss")
                        + "\n"
                        + "-" * 31
                    )
                print("{:5d}|{:8e}|{:8e}".format(cpt, err, f(G0)))

        cpt += 1

    if log:
        loge["partial_fgw_dist"] = f(G0)
        return G0, loge
    else:
        return G0
1668+
1669+
1670+
def entropic_partial_fused_gromov_wasserstein2(
    M,
    C1,
    C2,
    p=None,
    q=None,
    reg=1.0,
    m=None,
    loss_fun="square_loss",
    alpha=0.5,
    G0=None,
    numItermax=1000,
    tol=1e-7,
    symmetric=None,
    log=False,
    verbose=False,
):
    r"""
    Returns the entropic partial Fused Gromov-Wasserstein discrepancy between
    :math:`(\mathbf{C_1}, \mathbf{F_1}, \mathbf{p})` and
    :math:`(\mathbf{C_2}, \mathbf{F_2}, \mathbf{q})`, with pairwise
    distance matrix :math:`\mathbf{M}` between node feature matrices.

    The function solves the following optimization problem:

    .. math::
        PGW = \min_{\gamma} \quad (1 - \alpha) \langle \gamma, \mathbf{M} \rangle_F
        + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot
        \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma)

    .. math::
        s.t. \ \gamma &\geq 0

             \gamma \mathbf{1} &\leq \mathbf{a}

             \gamma^T \mathbf{1} &\leq \mathbf{b}

             \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}

    where :

    - :math:`\mathbf{M}`: metric cost matrix between features across domains
    - :math:`\mathbf{C_1}` is the metric cost matrix in the source space
    - :math:`\mathbf{C_2}` is the metric cost matrix in the target space
    - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights
    - `L`: Loss function to account for the misfit between the similarity matrices.
    - :math:`\Omega` is the entropic regularization term,
      :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - `m` is the amount of mass to be transported

    The formulation of the FGW problem has been proposed in
    :ref:`[24] <references-entropic-partial-fused-gromov-wasserstein2>` and the
    partial GW in :ref:`[29] <references-entropic-partial-fused-gromov-wasserstein2>`

    Parameters
    ----------
    M : array-like, shape (ns, nt)
        Metric cost matrix between features across domains
    C1 : ndarray, shape (ns, ns)
        Metric cost matrix in the source space
    C2 : ndarray, shape (nt, nt)
        Metric cost matrix in the target space
    p : array-like, shape (ns,), optional
        Distribution in the source space.
        If let to its default value None, uniform distribution is taken.
    q : array-like, shape (nt,), optional
        Distribution in the target space.
        If let to its default value None, uniform distribution is taken.
    reg : float
        entropic regularization parameter
    m : float, optional
        Amount of mass to be transported (default:
        :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`)
    loss_fun : str, optional
        Loss function used for the solver either 'square_loss' or 'kl_loss'.
    alpha : float, optional
        Trade-off parameter (0 < alpha < 1)
    G0 : ndarray, shape (ns, nt), optional
        Initialization of the transportation matrix
    numItermax : int, optional
        Max number of iterations
    tol : float, optional
        Stop threshold on error (>0)
    symmetric : bool, optional
        Either C1 and C2 are to be assumed symmetric or not.
        If let to its default None value, a symmetry test will be conducted.
        Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
    log : bool, optional
        return log if True
    verbose : bool, optional
        Print information along iterations


    Returns
    -------
    partial_fgw_dist : float
        Partial Entropic Fused Gromov-Wasserstein discrepancy
    log : dict
        log dictionary returned only if `log` is `True`

    .. _references-entropic-partial-fused-gromov-wasserstein2:
    References
    ----------
    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
        and Courty Nicolas "Optimal Transport for structured data with
        application on graphs", International Conference on Machine Learning
        (ICML). 2019.

    .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
        Transport with Applications on Positive-Unlabeled Learning".
        NeurIPS.
    """
    nx = get_backend(M, C1, C2)

    # Delegate to the solver, always requesting the log so the distance and
    # the plan are both available for post-processing.
    T, log_pfgw = entropic_partial_fused_gromov_wasserstein(
        M,
        C1,
        C2,
        p=p,
        q=q,
        reg=reg,
        m=m,
        loss_fun=loss_fun,
        alpha=alpha,
        G0=G0,
        numItermax=numItermax,
        tol=tol,
        symmetric=symmetric,
        log=True,
        verbose=verbose,
    )

    log_pfgw["T"] = T

    # Split the total cost into its linear (feature) and quadratic
    # (structure) contributions, as expected by ot.solve_gromov.
    lin_part = (1 - alpha) * nx.sum(T * M)
    log_pfgw["quad_loss"] = log_pfgw["partial_fgw_dist"] - lin_part
    log_pfgw["lin_loss"] = lin_part

    pfgw_dist = log_pfgw["partial_fgw_dist"]
    if log:
        return pfgw_dist, log_pfgw
    return pfgw_dist

0 commit comments

Comments
 (0)