Skip to content

Commit 9200af5

Browse files
author
ievred
committed
laplace v1
1 parent 6b8477d commit 9200af5

File tree

4 files changed

+49
-34
lines changed

4 files changed

+49
-34
lines changed

ot/bregman.py

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .utils import unif, dist
2020
from scipy.optimize import fmin_l_bfgs_b
2121

22+
2223
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
2324
stopThr=1e-9, verbose=False, log=False, **kwargs):
2425
r"""
@@ -539,12 +540,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
539540
old_v = v[i_2]
540541
v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
541542
G[:, i_2] = u * K[:, i_2] * v[i_2]
542-
#aviol = (G@one_m - a)
543-
#aviol_2 = (G.T@one_n - b)
543+
# aviol = (G@one_m - a)
544+
# aviol_2 = (G.T@one_n - b)
544545
viol += (-old_v + v[i_2]) * K[:, i_2] * u
545546
viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
546547

547-
#print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
548+
# print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
548549

549550
if stopThr_val <= stopThr:
550551
break
@@ -715,7 +716,7 @@ def get_Gamma(alpha, beta, u, v):
715716
if np.abs(u).max() > tau or np.abs(v).max() > tau:
716717
if n_hists:
717718
alpha, beta = alpha + reg * \
718-
np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
719+
np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
719720
else:
720721
alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
721722
if n_hists:
@@ -940,7 +941,7 @@ def get_reg(n): # exponential decreasing
940941
# the 10th iterations
941942
transp = G
942943
err = np.linalg.norm(
943-
(np.sum(transp, axis=0) - b))**2 + np.linalg.norm((np.sum(transp, axis=1) - a))**2
944+
(np.sum(transp, axis=0) - b)) ** 2 + np.linalg.norm((np.sum(transp, axis=1) - a)) ** 2
944945
if log:
945946
log['err'].append(err)
946947

@@ -966,7 +967,7 @@ def get_reg(n): # exponential decreasing
966967

967968
def geometricBar(weights, alldistribT):
968969
"""return the weighted geometric mean of distributions"""
969-
assert(len(weights) == alldistribT.shape[1])
970+
assert (len(weights) == alldistribT.shape[1])
970971
return np.exp(np.dot(np.log(alldistribT), weights.T))
971972

972973

@@ -1108,7 +1109,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
11081109
if weights is None:
11091110
weights = np.ones(A.shape[1]) / A.shape[1]
11101111
else:
1111-
assert(len(weights) == A.shape[1])
1112+
assert (len(weights) == A.shape[1])
11121113

11131114
if log:
11141115
log = {'err': []}
@@ -1206,7 +1207,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
12061207
if weights is None:
12071208
weights = np.ones(n_hists) / n_hists
12081209
else:
1209-
assert(len(weights) == A.shape[1])
1210+
assert (len(weights) == A.shape[1])
12101211

12111212
if log:
12121213
log = {'err': []}
@@ -1334,7 +1335,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
13341335
if weights is None:
13351336
weights = np.ones(A.shape[0]) / A.shape[0]
13361337
else:
1337-
assert(len(weights) == A.shape[0])
1338+
assert (len(weights) == A.shape[0])
13381339

13391340
if log:
13401341
log = {'err': []}
@@ -1350,11 +1351,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
13501351
# this is equivalent to blurring on horizontal then vertical directions
13511352
t = np.linspace(0, 1, A.shape[1])
13521353
[Y, X] = np.meshgrid(t, t)
1353-
xi1 = np.exp(-(X - Y)**2 / reg)
1354+
xi1 = np.exp(-(X - Y) ** 2 / reg)
13541355

13551356
t = np.linspace(0, 1, A.shape[2])
13561357
[Y, X] = np.meshgrid(t, t)
1357-
xi2 = np.exp(-(X - Y)**2 / reg)
1358+
xi2 = np.exp(-(X - Y) ** 2 / reg)
13581359

13591360
def K(x):
13601361
return np.dot(np.dot(xi1, x), xi2)
@@ -1501,6 +1502,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
15011502
else:
15021503
return np.sum(K0, axis=1)
15031504

1505+
15041506
def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
15051507
stopThr=1e-6, verbose=False, log=False, **kwargs):
15061508
r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27]
@@ -1658,6 +1660,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
16581660
else:
16591661
return couplings, bary
16601662

1663+
16611664
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
16621665
numIterMax=10000, stopThr=1e-9, verbose=False,
16631666
log=False, **kwargs):
@@ -1749,7 +1752,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
17491752
return pi
17501753

17511754

1752-
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
1755+
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
1756+
verbose=False, log=False, **kwargs):
17531757
r'''
17541758
Solve the entropic regularization optimal transport problem from empirical
17551759
data and return the OT loss
@@ -1831,14 +1835,17 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
18311835
M = dist(X_s, X_t, metric=metric)
18321836

18331837
if log:
1834-
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
1838+
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
1839+
**kwargs)
18351840
return sinkhorn_loss, log
18361841
else:
1837-
sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
1842+
sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
1843+
**kwargs)
18381844
return sinkhorn_loss
18391845

18401846

1841-
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
1847+
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
1848+
verbose=False, log=False, **kwargs):
18421849
r'''
18431850
Compute the sinkhorn divergence loss from empirical data
18441851
@@ -1924,11 +1931,14 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
19241931
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
19251932
'''
19261933
if log:
1927-
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
1934+
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
1935+
stopThr=1e-9, verbose=verbose, log=log, **kwargs)
19281936

1929-
sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
1937+
sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax,
1938+
stopThr=1e-9, verbose=verbose, log=log, **kwargs)
19301939

1931-
sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
1940+
sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
1941+
stopThr=1e-9, verbose=verbose, log=log, **kwargs)
19321942

19331943
sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
19341944

@@ -1943,11 +1953,14 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
19431953
return max(0, sinkhorn_div), log
19441954

19451955
else:
1946-
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
1956+
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
1957+
verbose=verbose, log=log, **kwargs)
19471958

1948-
sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
1959+
sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
1960+
verbose=verbose, log=log, **kwargs)
19491961

1950-
sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
1962+
sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
1963+
verbose=verbose, log=log, **kwargs)
19511964

19521965
sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
19531966
return max(0, sinkhorn_div)
@@ -2039,7 +2052,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
20392052
try:
20402053
import bottleneck
20412054
except ImportError:
2042-
warnings.warn("Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
2055+
warnings.warn(
2056+
"Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
20432057
bottleneck = np
20442058

20452059
a = np.asarray(a, dtype=np.float64)
@@ -2173,10 +2187,11 @@ def projection(u, epsilon):
21732187

21742188
# box constraints in L-BFGS-B (see Proposition 1 in [26])
21752189
bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
2176-
ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
2190+
ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
21772191

2178-
bounds_v = [(max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
2179-
epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
2192+
bounds_v = [(
2193+
max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
2194+
epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
21802195

21812196
# pre-calculated constants for the objective
21822197
vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)
@@ -2225,7 +2240,8 @@ def restricted_sinkhorn(usc, vsc, max_iter=5):
22252240
return usc, vsc
22262241

22272242
def screened_obj(usc, vsc):
2228-
part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, np.log(vsc))
2243+
part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J,
2244+
np.log(vsc))
22292245
part_IJc = np.dot(usc, vec_eps_IJc)
22302246
part_IcJ = np.dot(vec_eps_IcJ, vsc)
22312247
psi_epsilon = part_IJ + part_IJc + part_IcJ
@@ -2247,9 +2263,9 @@ def bfgspost(theta):
22472263
g = np.hstack([g_u, g_v])
22482264
return f, g
22492265

2250-
#----------------------------------------------------------------------------------------------------------------#
2266+
# ----------------------------------------------------------------------------------------------------------------#
22512267
# Step 2: L-BFGS-B solver #
2252-
#----------------------------------------------------------------------------------------------------------------#
2268+
# ----------------------------------------------------------------------------------------------------------------#
22532269

22542270
u0, v0 = restricted_sinkhorn(u0, v0)
22552271
theta0 = np.hstack([u0, v0])

ot/datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def make_1D_gauss(n, m, s):
3030
1D histogram for a gaussian distribution
3131
"""
3232
x = np.arange(n, dtype=np.float64)
33-
h = np.exp(-(x - m)**2 / (2 * s**2))
33+
h = np.exp(-(x - m) ** 2 / (2 * s ** 2))
3434
return h / h.sum()
3535

3636

@@ -80,7 +80,7 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None):
8080
return make_2D_samples_gauss(n, m, sigma, random_state=None)
8181

8282

83-
def make_data_classif(dataset, n, nz=.5, theta=0, p = .5, random_state=None, **kwargs):
83+
def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwargs):
8484
"""Dataset generation for classification problems
8585
8686
Parameters

ot/lp/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
"""
33
Solvers for the original linear program OT problem
44
5-
6-
75
"""
86

97
# Author: Remi Flamary <remi.flamary@unice.fr>
@@ -18,7 +16,7 @@
1816
from .import cvx
1917

2018
# import compiled emd
21-
from .emd_wrap import emd_c, check_result, emd_1d_sorted
19+
#from .emd_wrap import emd_c, check_result, emd_1d_sorted
2220
from ..utils import parmap
2321
from .cvx import barycenter
2422
from ..utils import dist

ot/plot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,10 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
7878
thr : float, optional
7979
threshold above which the line is drawn
8080
**kwargs : dict
81-
paameters given to the plot functions (default color is black if
81+
parameters given to the plot functions (default color is black if
8282
nothing given)
8383
"""
84+
8485
if ('color' not in kwargs) and ('c' not in kwargs):
8586
kwargs['color'] = 'k'
8687
mx = G.max()

0 commit comments

Comments (0)