Skip to content

Commit 14c30d4

Browse files
ncassereaurflamary
andauthored
[MRG] GPU bugs solve (#288)
* gpus tests now passing * pep8 compliance * GPU tests succeeding even if b has rank higher than 1 Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 1c7e7ce commit 14c30d4

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

ot/bregman.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
312312
numItermax : int, optional
313313
Max number of iterations
314314
stopThr : float, optional
315-
Stop threshol on error (>0)
315+
Stop threshold on error (>0)
316316
verbose : bool, optional
317317
Print information along iterations
318318
log : bool, optional

ot/gpu/bregman.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
5454
numItermax : int, optional
5555
Max number of iterations
5656
stopThr : float, optional
57-
Stop threshol on error (>0)
57+
Stop threshold on error (>0)
5858
verbose : bool, optional
5959
Print information along iterations
6060
log : bool, optional
@@ -148,13 +148,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
148148
# we can speed up the process by checking for the error only all
149149
# the 10th iterations
150150
if nbb:
151-
err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
152-
np.sum((v - vprev)**2) / np.sum((v)**2)
151+
err = np.sqrt(
152+
np.sum((u - uprev)**2) / np.sum((u)**2)
153+
+ np.sum((v - vprev)**2) / np.sum((v)**2)
154+
)
153155
else:
154156
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
155157
tmp2 = np.sum(u[:, None] * K * v[None, :], 0)
156158
#tmp2=np.einsum('i,ij,j->j', u, K, v)
157-
err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
159+
err = np.linalg.norm(tmp2 - b) # violation of marginal
158160
if log:
159161
log['err'].append(err)
160162

ot/gpu/da.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
120120
labels_a2 = cp.asnumpy(labels_a)
121121
classes = npp.unique(labels_a2)
122122
for c in classes:
123-
idxc, = utils.to_gpu(npp.where(labels_a2 == c))
123+
idxc = utils.to_gpu(*npp.where(labels_a2 == c))
124124
indices_labels.append(idxc)
125125

126126
W = np.zeros(M.shape)

0 commit comments

Comments
 (0)