|
14 | 14 | # from .utils import unif, dist |
15 | 15 |
|
16 | 16 |
|
17 | | -def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', div = "TV", numItermax=1000, |
| 17 | +def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, |
18 | 18 | stopThr=1e-6, verbose=False, log=False, **kwargs): |
19 | 19 | r""" |
20 | 20 | Solve the unbalanced entropic regularization optimal transport problem |
@@ -120,20 +120,20 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', div = "TV", numI |
120 | 120 | """ |
121 | 121 |
|
122 | 122 | if method.lower() == 'sinkhorn': |
123 | | - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div, |
| 123 | + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, |
124 | 124 | numItermax=numItermax, |
125 | 125 | stopThr=stopThr, verbose=verbose, |
126 | 126 | log=log, **kwargs) |
127 | 127 |
|
128 | 128 | elif method.lower() == 'sinkhorn_stabilized': |
129 | | - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div, |
| 129 | + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, |
130 | 130 | numItermax=numItermax, |
131 | 131 | stopThr=stopThr, |
132 | 132 | verbose=verbose, |
133 | 133 | log=log, **kwargs) |
134 | 134 | elif method.lower() in ['sinkhorn_reg_scaling']: |
135 | 135 | warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') |
136 | | - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg, |
| 136 | + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, |
137 | 137 | numItermax=numItermax, |
138 | 138 | stopThr=stopThr, verbose=verbose, |
139 | 139 | log=log, **kwargs) |
@@ -261,8 +261,8 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', |
261 | 261 | else: |
262 | 262 | raise ValueError('Unknown method %s.' % method) |
263 | 263 |
|
264 | | -# TODO: update the doc |
265 | | -def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000, |
| 264 | + |
| 265 | +def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, |
266 | 266 | stopThr=1e-6, verbose=False, log=False, **kwargs): |
267 | 267 | r""" |
268 | 268 | Solve the entropic regularization unbalanced optimal transport problem and return the loss |
@@ -349,7 +349,6 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000, |
349 | 349 | """ |
350 | 350 |
|
351 | 351 | a = np.asarray(a, dtype=np.float64) |
352 | | - print(a) |
353 | 352 | b = np.asarray(b, dtype=np.float64) |
354 | 353 | M = np.asarray(M, dtype=np.float64) |
355 | 354 |
|
@@ -377,39 +376,24 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000, |
377 | 376 | else: |
378 | 377 | u = np.ones(dim_a) / dim_a |
379 | 378 | v = np.ones(dim_b) / dim_b |
380 | | - u = np.ones(dim_a) |
381 | | - v = np.ones(dim_b) |
382 | 379 |
|
383 | 380 | # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute |
384 | 381 | K = np.empty(M.shape, dtype=M.dtype) |
385 | | - np.true_divide(M, -reg, out=K) |
| 382 | + np.divide(M, -reg, out=K) |
386 | 383 | np.exp(K, out=K) |
387 | | - |
388 | | - if div == "KL": |
389 | | - fi = reg_m / (reg_m + reg) |
390 | | - elif div == "TV": |
391 | | - fi = reg_m / reg |
| 384 | + |
| 385 | + fi = reg_m / (reg_m + reg) |
392 | 386 |
|
393 | 387 | err = 1. |
394 | | - |
395 | | - dx = np.ones(dim_a) / dim_a |
396 | | - dy = np.ones(dim_b) / dim_b |
397 | | - z = 1 |
398 | 388 |
|
399 | 389 | for i in range(numItermax): |
400 | 390 | uprev = u |
401 | 391 | vprev = v |
402 | 392 |
|
403 | | - Kv = z*K.dot(v*dy) |
404 | | - u = scaling_iter_prox(Kv, a, fi, div) |
405 | | - #u = (a / Kv) ** fi |
406 | | - Ktu = z*K.T.dot(u*dx) |
407 | | - v = scaling_iter_prox(Ktu, b, fi, div) |
408 | | - #v = (b / Ktu) ** fi |
409 | | - #print(v*dy) |
410 | | - z = np.dot((u*dx).T, np.dot(K,v*dy))/0.35 |
411 | | - print(z) |
412 | | - |
| 393 | + Kv = K.dot(v) |
| 394 | + u = (a / Kv) ** fi |
| 395 | + Ktu = K.T.dot(u) |
| 396 | + v = (b / Ktu) ** fi |
413 | 397 |
|
414 | 398 | if (np.any(Ktu == 0.) |
415 | 399 | or np.any(np.isnan(u)) or np.any(np.isnan(v)) |
@@ -450,12 +434,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000, |
450 | 434 | if log: |
451 | 435 | return u[:, None] * K * v[None, :], log |
452 | 436 | else: |
453 | | - return z*u[:, None] * K * v[None, :] |
| 437 | + return u[:, None] * K * v[None, :] |
| 438 | + |
454 | 439 |
|
455 | | -# TODO: update the doc |
456 | | -def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5, |
457 | | - numItermax=1000, stopThr=1e-6, |
458 | | - verbose=False, log=False, **kwargs): |
| 440 | +def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, |
| 441 | + stopThr=1e-6, verbose=False, log=False, |
| 442 | + **kwargs): |
459 | 443 | r""" |
460 | 444 | Solve the entropic regularization unbalanced optimal transport |
461 | 445 | problem and return the loss |
@@ -580,10 +564,7 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5, |
580 | 564 | np.divide(M, -reg, out=K) |
581 | 565 | np.exp(K, out=K) |
582 | 566 |
|
583 | | - if div == "KL": |
584 | | - fi = reg_m / (reg_m + reg) |
585 | | - elif div == "TV": |
586 | | - fi = reg_m / reg |
| 567 | + fi = reg_m / (reg_m + reg) |
587 | 568 |
|
588 | 569 | cpt = 0 |
589 | 570 | err = 1. |
@@ -669,15 +650,6 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5, |
669 | 650 | else: |
670 | 651 | return ot_matrix |
671 | 652 |
|
672 | | -def scaling_iter_prox(s, p, fi, div): |
673 | | - if div == "KL": |
674 | | - return (p / s) ** fi |
675 | | - elif div == "TV": |
676 | | - return np.minimum(s*np.exp(fi), np.maximum(s*np.exp(-fi), p)) / s |
677 | | - else: |
678 | | - raise ValueError("Unknown divergence '%s'." % div) |
679 | | - |
680 | | - |
681 | 653 |
|
682 | 654 | def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, |
683 | 655 | numItermax=1000, stopThr=1e-6, |
|
0 commit comments