Skip to content

Commit 5cd6c0a

Browse files
authored
Merge pull request #57 from LeoGautheron/master
Speed-up Sinkhorn
2 parents 7c5c880 + 0764e35 commit 5cd6c0a

File tree

2 files changed

+42
-4
lines changed

2 files changed

+42
-4
lines changed

ot/bregman.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
344344

345345
# print(reg)
346346

347-
K = np.exp(-M / reg)
347+
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
348+
K = np.empty(M.shape, dtype=M.dtype)
349+
np.divide(M, -reg, out=K)
350+
np.exp(K, out=K)
351+
348352
# print(np.min(K))
353+
tmp = np.empty(K.shape, dtype=M.dtype)
354+
tmp2 = np.empty(b.shape, dtype=M.dtype)
349355

350356
Kp = (1 / a).reshape(-1, 1) * K
351357
cpt = 0
@@ -373,8 +379,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
373379
err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
374380
np.sum((v - vprev)**2) / np.sum((v)**2)
375381
else:
376-
transp = u.reshape(-1, 1) * (K * v)
377-
err = np.linalg.norm((np.sum(transp, axis=0) - b))**2
382+
np.multiply(u.reshape(-1, 1), K, out=tmp)
383+
np.multiply(tmp, v.reshape(1, -1), out=tmp)
384+
np.sum(tmp, axis=0, out=tmp2)
385+
tmp2 -= b
386+
err = np.linalg.norm(tmp2)**2
378387
if log:
379388
log['err'].append(err)
380389

ot/utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,34 @@ def clean_zeros(a, b, M):
7777
return a2, b2, M2
7878

7979

80+
def euclidean_distances(X, Y, squared=False):
81+
"""
82+
Considering the rows of X (and Y=X) as vectors, compute the
83+
distance matrix between each pair of vectors.
84+
Parameters
85+
----------
86+
X : {array-like}, shape (n_samples_1, n_features)
87+
Y : {array-like}, shape (n_samples_2, n_features)
88+
squared : boolean, optional
89+
Return squared Euclidean distances.
90+
Returns
91+
-------
92+
distances : {array}, shape (n_samples_1, n_samples_2)
93+
"""
94+
XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
95+
YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
96+
distances = np.dot(X, Y.T)
97+
distances *= -2
98+
distances += XX
99+
distances += YY
100+
np.maximum(distances, 0, out=distances)
101+
if X is Y:
102+
# Ensure that distances between vectors and themselves are set to 0.0.
103+
# This may not be the case due to floating point rounding errors.
104+
distances.flat[::distances.shape[0] + 1] = 0.0
105+
return distances if squared else np.sqrt(distances, out=distances)
106+
107+
80108
def dist(x1, x2=None, metric='sqeuclidean'):
81109
"""Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
82110
@@ -104,7 +132,8 @@ def dist(x1, x2=None, metric='sqeuclidean'):
104132
"""
105133
if x2 is None:
106134
x2 = x1
107-
135+
if metric == "sqeuclidean":
136+
return euclidean_distances(x1, x2, squared=True)
108137
return cdist(x1, x2, metric=metric)
109138

110139

0 commit comments

Comments
 (0)