
Commit 42a62c1

[FEAT] add the sparsity-constrained optimal transport functionality and example (#459)
* add sparsity-constrained ot functionality and example
* correct typos; add projection_sparse_simplex
* add gradcheck; merge ot.sparse into ot.smooth.
* reuse existing ot.smooth functions with a new 'sparsity_constrained' reg_type
* address pep8 error
* add backends for
* update releases

---------

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 03ca4ef commit 42a62c1
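
For orientation, here is a minimal usage sketch of the feature this commit adds, adapted from the changes to `examples/plot_OT_1D_smooth.py` shown further down. The sketch itself is not part of the commit; `ot.datasets.make_1D_gauss` and `ot.dist` are assumed to be the usual POT helpers.

```python
# Minimal sketch (adapted from the example updated in this commit):
# sparsity-constrained OT between two 1D histograms.
import numpy as np
import ot

n = 100  # number of bins
x = np.arange(n, dtype=np.float64).reshape((n, 1))

a = ot.datasets.make_1D_gauss(n, m=20, s=5)   # source histogram
b = ot.datasets.make_1D_gauss(n, m=60, s=10)  # target histogram
M = ot.dist(x, x)                             # squared Euclidean cost
M /= M.max()

# New in this commit: 'sparsity_constrained' reg_type with at most
# max_nz non-zero entries per column of the transport plan.
max_nz = 2
Gsc = ot.smooth.smooth_ot_dual(a, b, M, reg=1e-1,
                               reg_type='sparsity_constrained', max_nz=max_nz)
```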

File tree

7 files changed (+291 lines, -40 lines)


README.md

Lines changed: 2 additions & 0 deletions
@@ -308,3 +308,5 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
 [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022.
 
 [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33.
+
+[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR).

RELEASES.md

Lines changed: 1 addition & 2 deletions
@@ -3,9 +3,8 @@
 ## 0.9.1dev
 
 #### New features
-
 - Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
-
+- Added the sparsity-constrained OT solver to `ot.smooth` and added `projection_sparse_simplex` to `ot.utils` (PR #459)
 #### Closed issues
 
 - Fix circleci-redirector action and codecov (PR #460)

examples/plot_OT_1D_smooth.py

Lines changed: 19 additions & 32 deletions
@@ -1,11 +1,12 @@
 # -*- coding: utf-8 -*-
 """
 ================================
-Smooth optimal transport example
+Smooth and sparse OT example
 ================================
 
-This example illustrates the computation of EMD, Sinkhorn and smooth OT plans
-and their visualization.
+This example illustrates the computation of
+Smooth and Sparse (KL and L2 reg.) OT and
+sparsity-constrained OT, together with their visualizations.
 
 """
 
@@ -58,32 +59,6 @@
 pl.figure(2, figsize=(5, 5))
 ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
 
-##############################################################################
-# Solve EMD
-# ---------
-
-
-#%% EMD
-
-G0 = ot.emd(a, b, M)
-
-pl.figure(3, figsize=(5, 5))
-ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
-
-##############################################################################
-# Solve Sinkhorn
-# --------------
-
-
-#%% Sinkhorn
-
-lambd = 2e-3
-Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
-
-pl.figure(4, figsize=(5, 5))
-ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
-
-pl.show()
 
 ##############################################################################
 # Solve Smooth OT
@@ -95,18 +70,30 @@
 lambd = 2e-3
 Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl')
 
-pl.figure(5, figsize=(5, 5))
+pl.figure(3, figsize=(5, 5))
 ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.')
 
 pl.show()
 
 
-#%% Smooth OT with KL regularization
+#%% Smooth OT with squared l2 regularization
 
 lambd = 1e-1
 Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2')
 
-pl.figure(6, figsize=(5, 5))
+pl.figure(4, figsize=(5, 5))
 ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.')
 
 pl.show()
+
+#%% Sparsity-constrained OT
+
+lambd = 1e-1
+
+max_nz = 2  # two non-zero entries are permitted per column of the OT plan
+Gsc = ot.smooth.smooth_ot_dual(
+    a, b, M, lambd, reg_type='sparsity_constrained', max_nz=max_nz)
+pl.figure(5, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gsc, 'Sparsity constrained OT matrix; k=2.')
+
+pl.show()
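
A possible follow-up check (not part of the commit): counting the non-zero entries per column of the plan `Gsc` computed in the example above, with an arbitrary numerical tolerance.

```python
# Sketch: verify the per-column sparsity of the plan from the example above.
# Gsc and max_nz come from the sparsity-constrained block; 1e-10 is an
# arbitrary tolerance for treating an entry as zero.
import numpy as np

nz_per_col = (np.asarray(Gsc) > 1e-10).sum(axis=0)
print(nz_per_col.max())  # expected: no more than max_nz (= 2 here)
```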

ot/smooth.py

Lines changed: 75 additions & 5 deletions
@@ -24,27 +24,42 @@
 
 # Author: Mathieu Blondel
 #         Remi Flamary <remi.flamary@unice.fr>
+#         Tianlin Liu <t.liu@unibas.ch>
 
 """
-Smooth and Sparse Optimal Transport solvers (KL an L2 reg.)
+Smooth and Sparse (KL and L2 reg.) and sparsity-constrained OT solvers.
 
 Implementation of :
 Smooth and Sparse Optimal Transport.
 Mathieu Blondel, Vivien Seguy, Antoine Rolet.
 In Proc. of AISTATS 2018.
 https://arxiv.org/abs/1710.06276
 
+(Original code from https://github.com/mblondel/smooth-ot/)
+
+Sparsity-Constrained Optimal Transport.
+Liu, T., Puigcerver, J., & Blondel, M. (2023).
+Sparsity-constrained optimal transport.
+Proceedings of the Eleventh International Conference on
+Learning Representations (ICLR).
+https://arxiv.org/abs/2209.15466
+
+
 [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal
 Transport. Proceedings of the Twenty-First International Conference on
 Artificial Intelligence and Statistics (AISTATS).
 
-Original code from https://github.com/mblondel/smooth-ot/
+[50] Liu, T., Puigcerver, J., & Blondel, M. (2023).
+Sparsity-constrained optimal transport.
+Proceedings of the Eleventh International Conference on
+Learning Representations (ICLR).
 
 """
 
 import numpy as np
 from scipy.optimize import minimize
 from .backend import get_backend
+import ot
 
 
 def projection_simplex(V, z=1, axis=None):
@@ -209,6 +224,39 @@ def Omega(self, T):
         return 0.5 * self.gamma * np.sum(T ** 2)
 
 
+class SparsityConstrained(Regularization):
+    """ Squared L2 regularization with sparsity constraints """
+
+    def __init__(self, max_nz, gamma=1.0):
+        self.max_nz = max_nz
+        self.gamma = gamma
+
+    def delta_Omega(self, X):
+        # For each column of X, find entries that are not among the top max_nz.
+        non_top_indices = np.argpartition(
+            -X, self.max_nz, axis=0)[self.max_nz:]
+        # Set these entries to zero.
+        if X.ndim == 1:
+            X[non_top_indices] = 0.0
+        else:
+            X[non_top_indices, np.arange(X.shape[1])] = 0.0
+        max_X = np.maximum(X, 0)
+        val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma)
+        G = max_X / self.gamma
+        return val, G
+
+    def max_Omega(self, X, b):
+        # Project the scaled X onto the simplex with sparsity constraint.
+        G = ot.utils.projection_sparse_simplex(
+            X / (b * self.gamma), self.max_nz, axis=0)
+        val = np.sum(X * G, axis=0)
+        val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0)
+        return val, G
+
+    def Omega(self, T):
+        return 0.5 * self.gamma * np.sum(T ** 2)
+
+
 def dual_obj_grad(alpha, beta, a, b, C, regul):
     r"""
     Compute objective value and gradients of dual objective.
@@ -435,8 +483,9 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
     return regul.max_Omega(X, b)[1] * b
 
 
-def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
-                   numItermax=500, verbose=False, log=False):
+def smooth_ot_dual(a, b, M, reg, reg_type='l2',
+                   method="L-BFGS-B", stopThr=1e-9,
+                   numItermax=500, verbose=False, log=False, max_nz=None):
     r"""
     Solve the regularized OT problem in the dual and return the OT matrix
 
@@ -477,6 +526,9 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
          :ref:`[2] <references-smooth-ot-dual>`)
 
         - 'l2' : Squared Euclidean regularization
+        - 'sparsity_constrained' : Sparsity-constrained regularization [50]
+    max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan;
+        not used for other regularization types.
     method : str
         Solver to use for scipy.optimize.minimize
     numItermax : int, optional
@@ -504,6 +556,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
 
     .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
 
+    .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR).
+
     See Also
     --------
     ot.lp.emd : Unregularized OT
@@ -518,6 +572,11 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
         regul = SquaredL2(gamma=reg)
     elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
         regul = NegEntropy(gamma=reg)
+    elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']:
+        if not isinstance(max_nz, int):
+            raise ValueError(
+                f'max_nz {max_nz} must be an integer')
+        regul = SparsityConstrained(gamma=reg, max_nz=max_nz)
     else:
         raise NotImplementedError('Unknown regularization')
 
@@ -539,7 +598,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
     return G
 
 
-def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
+def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', max_nz=None,
+                        method="L-BFGS-B", stopThr=1e-9,
                         numItermax=500, verbose=False, log=False):
     r"""
     Solve the regularized OT problem in the semi-dual and return the OT matrix
@@ -583,6 +643,9 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
          :ref:`[2] <references-smooth-ot-semi-dual>`)
 
         - 'l2' : Squared Euclidean regularization
+        - 'sparsity_constrained' : Sparsity-constrained regularization [50]
+    max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan;
+        not used for other regularization types.
     method : str
         Solver to use for scipy.optimize.minimize
     numItermax : int, optional
@@ -610,6 +673,8 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
 
     .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
 
+    .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR).
+
     See Also
     --------
     ot.lp.emd : Unregularized OT
@@ -621,6 +686,11 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
         regul = SquaredL2(gamma=reg)
     elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
         regul = NegEntropy(gamma=reg)
+    elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']:
+        if not isinstance(max_nz, int):
+            raise ValueError(
+                f'max_nz {max_nz} must be an integer')
+        regul = SparsityConstrained(gamma=reg, max_nz=max_nz)
    else:
         raise NotImplementedError('Unknown regularization')
 

ot/utils.py

Lines changed: 80 additions & 1 deletion
@@ -15,7 +15,7 @@
 import sys
 import warnings
 from inspect import signature
-from .backend import get_backend, Backend, NumpyBackend
+from .backend import get_backend, Backend, NumpyBackend, JaxBackend
 
 __time_tic_toc = time.time()
 
@@ -117,6 +117,85 @@ def proj_simplex(v, z=1):
     return w
 
 
+def projection_sparse_simplex(V, max_nz, z=1, axis=None, nx=None):
+    r"""Projection of :math:`\mathbf{V}` onto the simplex with cardinality constraint (maximum number of non-zero elements) and then scaled by `z`.
+
+    .. math::
+        P\left(\mathbf{V}, max\_nz, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} \geq 0 \\ \sum_i \mathbf{y}_i = z \\ \|\mathbf{y}\|_0 \le max\_nz}} \quad \|\mathbf{y} - \mathbf{V}\|^2
+
+    Parameters
+    ----------
+    V: 1-dim or 2-dim ndarray
+    z: float or array
+        If array, len(z) must be compatible with :math:`\mathbf{V}`
+    axis: None or int
+        - axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), max\_nz, z)`
+        - axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, max\_nz, z_i)`
+        - axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, max\_nz, z_j)`
+
+    Returns
+    -------
+    projection: ndarray, shape :math:`\mathbf{V}`.shape
+
+    References:
+        Sparse projections onto the simplex
+        Anastasios Kyrillidis, Stephen Becker, Volkan Cevher, and Christoph Koch
+        ICML 2013
+        https://arxiv.org/abs/1206.1529
+    """
+    if nx is None:
+        nx = get_backend(V)
+    if V.ndim == 1:
+        return projection_sparse_simplex(
+            # V[nx.newaxis, :], max_nz, z, axis=1).ravel()
+            V[None, :], max_nz, z, axis=1).ravel()
+
+    if V.ndim > 2:
+        raise ValueError('V.ndim must be <= 2')
+
+    if axis == 1:
+        # For each row of V, find top max_nz values; arrange the
+        # corresponding column indices such that their values are
+        # in a descending order.
+        max_nz_indices = nx.argsort(V, axis=1)[:, -max_nz:]
+        max_nz_indices = nx.flip(max_nz_indices, axis=1)
+
+        row_indices = nx.arange(V.shape[0])
+        row_indices = row_indices.reshape(-1, 1)
+        print(row_indices.shape)
+        # Extract the top max_nz values for each row
+        # and then project to simplex.
+        U = V[row_indices, max_nz_indices]
+        z = nx.ones(len(U)) * z
+        cssv = nx.cumsum(U, axis=1) - z[:, None]
+        ind = nx.arange(max_nz) + 1
+        cond = U - cssv / ind > 0
+        # rho = nx.count_nonzero(cond, axis=1)
+        rho = nx.sum(cond, axis=1)
+        theta = cssv[nx.arange(len(U)), rho - 1] / rho
+        nz_projection = nx.maximum(U - theta[:, None], 0)
+
+        # Put the projection of max_nz_values to their original column indices
+        # while keeping other values zero.
+        sparse_projection = nx.zeros(V.shape, type_as=nz_projection)
+
+        if isinstance(nx, JaxBackend):
+            # in Jax, we need to use the `at` property of `jax.numpy.ndarray`
+            # to do in-place array modifications.
+            sparse_projection = sparse_projection.at[
+                row_indices, max_nz_indices].set(nz_projection)
+        else:
+            sparse_projection[row_indices, max_nz_indices] = nz_projection
+        return sparse_projection
+
+    elif axis == 0:
+        return projection_sparse_simplex(V.T, max_nz, z, axis=1).T
+
+    else:
+        V = V.ravel().reshape(1, -1)
+        return projection_sparse_simplex(V, max_nz, z, axis=1).ravel()
+
+
 def unif(n, type_as=None):
     r"""
     Return a uniform histogram of length `n` (simplex).
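
To make the behaviour of the new helper concrete, here is a small sketch (not part of the commit) that projects each column of a random matrix with `projection_sparse_simplex`; the matrix size and tolerance are illustrative only.

```python
# Sketch: project each column of a random 5x3 matrix onto the simplex while
# allowing at most 2 non-zero entries per column (axis=0, z=1 by default).
import numpy as np
import ot

rng = np.random.RandomState(0)
V = rng.rand(5, 3)
P = ot.utils.projection_sparse_simplex(V, max_nz=2, axis=0)

print((P > 0).sum(axis=0))  # expected: at most 2 non-zeros per column
print(P.sum(axis=0))        # expected: each column sums to z = 1
```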
