Skip to content

Commit cb6bdc5

Browse files
committed
Speed-up Sinkhorn
Speed-up in 3 places:
- the computation of pairwise distance is faster with sklearn.metrics.pairwise.euclidean_distances
- faster computation of K = np.exp(-M / reg)
- faster computation of the error every 10 iterations

Example with this little script:

    import time
    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    transport = ot.da.SinkhornTransport()
    time1 = time.time()
    Xs, ys, Xt = rng.randn(10000, 100), rng.randint(0, 2, size=10000), rng.randn(10000, 100)
    transport.fit(Xs=Xs, Xt=Xt)
    time2 = time.time()
    print("OT Computation Time {:6.2f} sec".format(time2-time1))
    transport = ot.da.SinkhornLpl1Transport()
    transport.fit(Xs=Xs, ys=ys, Xt=Xt)
    time3 = time.time()
    print("OT LpL1 Computation Time {:6.2f} sec".format(time3-time2))

Before:
    OT Computation Time  19.93 sec
    OT LpL1 Computation Time 133.43 sec

After:
    OT Computation Time   7.55 sec
    OT LpL1 Computation Time  82.25 sec
1 parent 39cbcd3 commit cb6bdc5

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

ot/bregman.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
344344

345345
# print(reg)
346346

347-
K = np.exp(-M / reg)
347+
K = np.empty(M.shape, dtype=M.dtype)
348+
np.divide(M, -reg, out=K)
349+
np.exp(K, out=K)
350+
348351
# print(np.min(K))
352+
tmp = np.empty(K.shape, dtype=M.dtype)
353+
tmp2 = np.empty(b.shape, dtype=M.dtype)
349354

350355
Kp = (1 / a).reshape(-1, 1) * K
351356
cpt = 0
@@ -373,8 +378,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
373378
err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
374379
np.sum((v - vprev)**2) / np.sum((v)**2)
375380
else:
376-
transp = u.reshape(-1, 1) * (K * v)
377-
err = np.linalg.norm((np.sum(transp, axis=0) - b))**2
381+
np.multiply(u.reshape(-1, 1), K, out=tmp)
382+
np.multiply(tmp, v.reshape(1, -1), out=tmp)
383+
np.sum(tmp, axis=0, out=tmp2)
384+
tmp2 -= b
385+
err = np.linalg.norm(tmp2)**2
378386
if log:
379387
log['err'].append(err)
380388

ot/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import numpy as np
1515
from scipy.spatial.distance import cdist
16+
from sklearn.metrics.pairwise import euclidean_distances
1617
import sys
1718
import warnings
1819
try:
@@ -104,7 +105,8 @@ def dist(x1, x2=None, metric='sqeuclidean'):
104105
"""
105106
if x2 is None:
106107
x2 = x1
107-
108+
if metric == "sqeuclidean":
109+
return euclidean_distances(x1, x2, squared=True)
108110
return cdist(x1, x2, metric=metric)
109111

110112

0 commit comments

Comments
 (0)