Skip to content

Commit fff2463

Browse files
authored
Merge branch 'master' into partial-W-and-GW
2 parents 9f63ee9 + 4cd4e09 commit fff2463

File tree

7 files changed

+102
-48
lines changed

7 files changed

+102
-48
lines changed

_config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
theme: jekyll-theme-slate

ot/bregman.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# Titouan Vayer <titouan.vayer@irisa.fr>
1010
# Hicham Janati <hicham.janati@inria.fr>
1111
# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
12+
# Alexander Tong <alexander.tong@yale.edu>
1213
#
1314
# License: MIT License
1415

@@ -1346,12 +1347,17 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
13461347
err = 1
13471348

13481349
# build the convolution operator
1350+
# this is equivalent to blurring on horizontal then vertical directions
13491351
t = np.linspace(0, 1, A.shape[1])
13501352
[Y, X] = np.meshgrid(t, t)
13511353
xi1 = np.exp(-(X - Y)**2 / reg)
13521354

1355+
t = np.linspace(0, 1, A.shape[2])
1356+
[Y, X] = np.meshgrid(t, t)
1357+
xi2 = np.exp(-(X - Y)**2 / reg)
1358+
13531359
def K(x):
1354-
return np.dot(np.dot(xi1, x), xi1)
1360+
return np.dot(np.dot(xi1, x), xi2)
13551361

13561362
while (err > stopThr and cpt < numItermax):
13571363

ot/gromov.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
433433
434434
where :
435435
- M is the (ns,nt) metric cost matrix
436-
- :math:`f` is the regularization term ( and df is its gradient)
437-
- a and b are source and target weights (sum to 1)
436+
- p and q are source and target weights (sum to 1)
438437
- L is a loss function to account for the misfit between the similarity matrices
439438
440439
The algorithm used for solving the problem is conditional gradient as discussed in [24]_
@@ -453,17 +452,13 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
453452
Distribution in the target space
454453
loss_fun : str, optional
455454
Loss function used for the solver
456-
max_iter : int, optional
457-
Max number of iterations
458-
tol : float, optional
459-
Stop threshold on error (>0)
460-
verbose : bool, optional
461-
Print information along iterations
462-
log : bool, optional
463-
record log if True
455+
alpha : float, optional
456+
Trade-off parameter (0 < alpha < 1)
464457
armijo : bool, optional
465458
If True the steps of the line-search is found via an armijo research. Else closed form is used.
466459
If there is convergence issues use False.
460+
log : bool, optional
461+
record log if True
467462
**kwargs : dict
468463
parameters can be directly passed to the ot.optim.cg solver
469464
@@ -493,11 +488,11 @@ def df(G):
493488
return gwggrad(constC, hC1, hC2, G)
494489

495490
if log:
496-
res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
491+
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
497492
log['fgw_dist'] = log['loss'][::-1][0]
498493
return res, log
499494
else:
500-
return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
495+
return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
501496

502497

503498
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -515,8 +510,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
515510
516511
where :
517512
- M is the (ns,nt) metric cost matrix
518-
- :math:`f` is the regularization term ( and df is its gradient)
519-
- a and b are source and target weights (sum to 1)
513+
- p and q are source and target weights (sum to 1)
520514
- L is a loss function to account for the misfit between the similarity matrices
521515
The algorithm used for solving the problem is conditional gradient as discussed in [1]_
522516
@@ -534,17 +528,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
534528
Distribution in the target space.
535529
loss_fun : str, optional
536530
Loss function used for the solver.
537-
max_iter : int, optional
538-
Max number of iterations
539-
tol : float, optional
540-
Stop threshold on error (>0)
541-
verbose : bool, optional
542-
Print information along iterations
543-
log : bool, optional
544-
Record log if True.
531+
alpha : float, optional
532+
Trade-off parameter (0 < alpha < 1)
545533
armijo : bool, optional
546534
If True the steps of the line-search is found via an armijo research.
547535
Else closed form is used. If there is convergence issues use False.
536+
log : bool, optional
537+
Record log if True.
548538
**kwargs : dict
549539
Parameters can be directly pased to the ot.optim.cg solver.
550540
@@ -573,7 +563,7 @@ def f(G):
573563
def df(G):
574564
return gwggrad(constC, hC1, hC2, G)
575565

576-
res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
566+
res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
577567
if log:
578568
log['fgw_dist'] = log['loss'][::-1][0]
579569
log['T'] = res
@@ -994,6 +984,16 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
994984
Whether to fix the structure of the barycenter during the updates
995985
fixed_features : bool
996986
Whether to fix the feature of the barycenter during the updates
987+
loss_fun : str
988+
Loss function used for the solver either 'square_loss' or 'kl_loss'
989+
max_iter : int, optional
990+
Max number of iterations
991+
tol : float, optional
992+
Stop threshol on error (>0).
993+
verbose : bool, optional
994+
Print information along iterations.
995+
log : bool, optional
996+
Record log if True.
997997
init_C : ndarray, shape (N,N), optional
998998
Initialization for the barycenters' structure matrix. If not set
999999
a random init is used.
@@ -1082,7 +1082,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
10821082
T_temp = [t.T for t in T]
10831083
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
10841084

1085-
T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
1085+
T = [fused_gromov_wasserstein(Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha,
10861086
numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
10871087

10881088
# T is N,ns

ot/lp/__init__.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212

1313
import multiprocessing
1414
import sys
15+
1516
import numpy as np
1617
from scipy.sparse import coo_matrix
1718

18-
from .import cvx
19-
19+
from . import cvx
20+
from .cvx import barycenter
2021
# import compiled emd
2122
from .emd_wrap import emd_c, check_result, emd_1d_sorted
22-
from ..utils import parmap
23-
from .cvx import barycenter
2423
from ..utils import dist
24+
from ..utils import parmap
2525

2626
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
2727
'emd_1d', 'emd2_1d', 'wasserstein_1d']
@@ -458,7 +458,8 @@ def f(b):
458458
return res
459459

460460

461-
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
461+
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100,
462+
stopThr=1e-7, verbose=False, log=None):
462463
"""
463464
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
464465
@@ -525,8 +526,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
525526

526527
T_sum = np.zeros((k, d))
527528

528-
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights.tolist()):
529-
529+
for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights,
530+
weights.tolist()):
530531
M_i = dist(X, measure_locations_i)
531532
T_i = emd(b, measure_weights_i, M_i)
532533
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)
@@ -651,12 +652,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
651652
if b.ndim == 0 or len(b) == 0:
652653
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
653654

654-
x_a_1d = x_a.reshape((-1, ))
655-
x_b_1d = x_b.reshape((-1, ))
655+
x_a_1d = x_a.reshape((-1,))
656+
x_b_1d = x_b.reshape((-1,))
656657
perm_a = np.argsort(x_a_1d)
657658
perm_b = np.argsort(x_b_1d)
658659

659-
G_sorted, indices, cost = emd_1d_sorted(a, b,
660+
G_sorted, indices, cost = emd_1d_sorted(a[perm_a], b[perm_b],
660661
x_a_1d[perm_a], x_b_1d[perm_b],
661662
metric=metric, p=p)
662663
G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),

setup.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,15 @@
88
import numpy
99
import re
1010
import os
11+
import sys
12+
import subprocess
1113

1214
here = path.abspath(path.dirname(__file__))
1315

16+
17+
os.environ["CC"] = "g++"
18+
os.environ["CXX"] = "g++"
19+
1420
# dirty but working
1521
__version__ = re.search(
1622
r'__version__\s*=\s*[\'"]([^\'"]*)[\'"]', # It excludes inline comment too
@@ -24,12 +30,13 @@
2430
with open(os.path.join(ROOT, 'README.md'), encoding="utf-8") as f:
2531
README = f.read()
2632

27-
# add platform dependant optional compilation argument
2833
opt_arg=["-O3"]
29-
import platform
30-
if platform.system()=='Darwin':
31-
if platform.release()=='18.0.0':
32-
opt_arg.append("-stdlib=libc++") # correspond to a compilation problem with Mojave and XCode 10
34+
35+
# add platform dependant optional compilation argument
36+
if sys.platform.startswith('darwin'):
37+
opt_arg.append("-stdlib=libc++")
38+
sdk_path = subprocess.check_output(['xcrun', '--show-sdk-path'])
39+
os.environ['CFLAGS'] = '-isysroot "{}"'.format(sdk_path.rstrip().decode("utf-8"))
3340

3441
setup(name='POT',
3542
version=__version__,

test/test_bregman.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,3 +351,10 @@ def test_screenkhorn():
351351
# check marginals
352352
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
353353
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
354+
355+
356+
def test_convolutional_barycenter_non_square():
357+
# test for image with height not equal width
358+
A = np.ones((2, 2, 3)) / (2 * 3)
359+
b = ot.bregman.convolutional_barycenter2d(A, 1e-03)
360+
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)

test/test_ot.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import warnings
88

99
import numpy as np
10+
import pytest
1011
from scipy.stats import wasserstein_distance
1112

1213
import ot
1314
from ot.datasets import make_1D_gauss as gauss
14-
import pytest
1515

1616

1717
def test_emd_dimension_mismatch():
@@ -75,12 +75,12 @@ def test_emd_1d_emd2_1d():
7575
np.testing.assert_allclose(wass, wass1d_emd2)
7676

7777
# check loss is similar to scipy's implementation for Euclidean metric
78-
wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
78+
wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
7979
np.testing.assert_allclose(wass_sp, wass1d_euc)
8080

8181
# check constraints
82-
np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
83-
np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))
82+
np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
83+
np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
8484

8585
# check G is similar
8686
np.testing.assert_allclose(G, G_1d)
@@ -92,6 +92,42 @@ def test_emd_1d_emd2_1d():
9292
ot.emd_1d(u, v, [], [])
9393

9494

95+
def test_emd_1d_emd2_1d_with_weights():
96+
# test emd1d gives similar results as emd
97+
n = 20
98+
m = 30
99+
rng = np.random.RandomState(0)
100+
u = rng.randn(n, 1)
101+
v = rng.randn(m, 1)
102+
103+
w_u = rng.uniform(0., 1., n)
104+
w_u = w_u / w_u.sum()
105+
106+
w_v = rng.uniform(0., 1., m)
107+
w_v = w_v / w_v.sum()
108+
109+
M = ot.dist(u, v, metric='sqeuclidean')
110+
111+
G, log = ot.emd(w_u, w_v, M, log=True)
112+
wass = log["cost"]
113+
G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
114+
wass1d = log["cost"]
115+
wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
116+
wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
117+
118+
# check loss is similar
119+
np.testing.assert_allclose(wass, wass1d)
120+
np.testing.assert_allclose(wass, wass1d_emd2)
121+
122+
# check loss is similar to scipy's implementation for Euclidean metric
123+
wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v)
124+
np.testing.assert_allclose(wass_sp, wass1d_euc)
125+
126+
# check constraints
127+
np.testing.assert_allclose(w_u, G.sum(1))
128+
np.testing.assert_allclose(w_v, G.sum(0))
129+
130+
95131
def test_wass_1d():
96132
# test emd1d gives similar results as emd
97133
n = 20
@@ -135,7 +171,6 @@ def test_emd_empty():
135171

136172

137173
def test_emd_sparse():
138-
139174
n = 100
140175
rng = np.random.RandomState(0)
141176

@@ -211,7 +246,6 @@ def test_emd2_multi():
211246

212247

213248
def test_lp_barycenter():
214-
215249
a1 = np.array([1.0, 0, 0])[:, None]
216250
a2 = np.array([0, 0, 1.0])[:, None]
217251

@@ -228,7 +262,6 @@ def test_lp_barycenter():
228262

229263

230264
def test_free_support_barycenter():
231-
232265
measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))]
233266
measures_weights = [np.array([1.]), np.array([1.])]
234267

@@ -244,7 +277,6 @@ def test_free_support_barycenter():
244277

245278
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
246279
def test_lp_barycenter_cvxopt():
247-
248280
a1 = np.array([1.0, 0, 0])[:, None]
249281
a2 = np.array([0, 0, 1.0])[:, None]
250282

0 commit comments

Comments
 (0)