Skip to content

Commit 9351bfa

Browse files
authored
Merge pull request #58 from rflamary/speedup
Speedup Sinkhorn with einsum + bench
2 parents 5cd6c0a + f4bfeb7 commit 9351bfa

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

Makefile

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22

33
PYTHON=python3
4+
branch := $(shell git symbolic-ref --short -q HEAD)
45

56
help :
67
@echo "The following make targets are available:"
@@ -57,6 +58,16 @@ rdoc :
5758
notebook :
5859
ipython notebook --matplotlib=inline --notebook-dir=notebooks/
5960

61+
bench :
62+
@git stash >/dev/null 2>&1
63+
@echo 'Branch master'
64+
@git checkout master >/dev/null 2>&1
65+
python3 $(script)
66+
@echo 'Branch $(branch)'
67+
@git checkout $(branch) >/dev/null 2>&1
68+
python3 $(script)
69+
@git stash apply >/dev/null 2>&1
70+
6071
autopep8 :
6172
autopep8 -ir test ot examples --jobs -1
6273

ot/bregman.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,6 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
350350
np.exp(K, out=K)
351351

352352
# print(np.min(K))
353-
tmp = np.empty(K.shape, dtype=M.dtype)
354353
tmp2 = np.empty(b.shape, dtype=M.dtype)
355354

356355
Kp = (1 / a).reshape(-1, 1) * K
@@ -359,6 +358,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
359358
while (err > stopThr and cpt < numItermax):
360359
uprev = u
361360
vprev = v
361+
362362
KtransposeU = np.dot(K.T, u)
363363
v = np.divide(b, KtransposeU)
364364
u = 1. / np.dot(Kp, v)
@@ -379,11 +379,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
379379
err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
380380
np.sum((v - vprev)**2) / np.sum((v)**2)
381381
else:
382-
np.multiply(u.reshape(-1, 1), K, out=tmp)
383-
np.multiply(tmp, v.reshape(1, -1), out=tmp)
384-
np.sum(tmp, axis=0, out=tmp2)
385-
tmp2 -= b
386-
err = np.linalg.norm(tmp2)**2
382+
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
383+
np.einsum('i,ij,j->j', u, K, v, out=tmp2)
384+
err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
387385
if log:
388386
log['err'].append(err)
389387

@@ -398,10 +396,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
398396
log['v'] = v
399397

400398
if nbb: # return only loss
401-
res = np.zeros((nbb))
402-
for i in range(nbb):
403-
res[i] = np.sum(
404-
u[:, i].reshape((-1, 1)) * K * v[:, i].reshape((1, -1)) * M)
399+
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
405400
if log:
406401
return res, log
407402
else:

0 commit comments

Comments
 (0)