@@ -384,10 +384,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
384384
385385 fi = reg_m / (reg_m + reg )
386386
387- cpt = 0
388387 err = 1.
389388
390- while ( err > stopThr and cpt < numItermax ):
389+ for i in range ( numItermax ):
391390 uprev = u
392391 vprev = v
393392
@@ -401,28 +400,27 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
401400 or np .any (np .isinf (u )) or np .any (np .isinf (v ))):
402401 # we have reached the machine precision
403402 # come back to previous solution and quit loop
404- warnings .warn ('Numerical errors at iteration %s' % cpt )
403+ warnings .warn ('Numerical errors at iteration %s' % i )
405404 u = uprev
406405 v = vprev
407406 break
408- if cpt % 10 == 0 :
409- # we can speed up the process by checking for the error only all
410- # the 10th iterations
411- err_u = abs (u - uprev ).max () / max (abs (u ).max (), abs (uprev ).max (), 1. )
412- err_v = abs (v - vprev ).max () / max (abs (v ).max (), abs (vprev ).max (), 1. )
413- err = 0.5 * (err_u + err_v )
414- if log :
415- log ['err' ].append (err )
407+
408+ err_u = abs (u - uprev ).max () / max (abs (u ).max (), abs (uprev ).max (), 1. )
409+ err_v = abs (v - vprev ).max () / max (abs (v ).max (), abs (vprev ).max (), 1. )
410+ err = 0.5 * (err_u + err_v )
411+ if log :
412+ log ['err' ].append (err )
416413 if verbose :
417- if cpt % 200 == 0 :
414+ if i % 50 == 0 :
418415 print (
419416 '{:5s}|{:12s}' .format ('It.' , 'Err' ) + '\n ' + '-' * 19 )
420- print ('{:5d}|{:8e}|' .format (cpt , err ))
421- cpt += 1
417+ print ('{:5d}|{:8e}|' .format (i , err ))
418+ if err < stopThr :
419+ break
422420
423421 if log :
424- log ['logu' ] = np .log (u + 1e-16 )
425- log ['logv' ] = np .log (v + 1e-16 )
422+ log ['logu' ] = np .log (u + 1e-300 )
423+ log ['logv' ] = np .log (v + 1e-300 )
426424
427425 if n_hists : # return only loss
428426 res = np .einsum ('ik,ij,jk,ij->k' , u , K , v , M )
@@ -747,8 +745,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
747745 alpha = np .zeros (dim )
748746 beta = np .zeros (dim )
749747 q = np .ones (dim ) / dim
750- while ( err > stopThr and cpt < numItermax ):
751- qprev = q
748+ for i in range ( numItermax ):
749+ qprev = q . copy ()
752750 Kv = K .dot (v )
753751 f_alpha = np .exp (- alpha / (reg + reg_m ))
754752 f_beta = np .exp (- beta / (reg + reg_m ))
@@ -777,28 +775,29 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
777775 warnings .warn ('Numerical errors at iteration %s' % cpt )
778776 q = qprev
779777 break
780- if (cpt % 10 == 0 and not absorbing ) or cpt == 0 :
778+ if (i % 10 == 0 and not absorbing ) or i == 0 :
781779 # we can speed up the process by checking for the error only all
782780 # the 10th iterations
783781 err = abs (q - qprev ).max () / max (abs (q ).max (),
784782 abs (qprev ).max (), 1. )
785783 if log :
786784 log ['err' ].append (err )
787785 if verbose :
788- if cpt % 50 == 0 :
786+ if i % 50 == 0 :
789787 print (
790788 '{:5s}|{:12s}' .format ('It.' , 'Err' ) + '\n ' + '-' * 19 )
791- print ('{:5d}|{:8e}|' .format (cpt , err ))
789+ print ('{:5d}|{:8e}|' .format (i , err ))
790+ if err < stopThr :
791+ break
792792
793- cpt += 1
794793 if err > stopThr :
795794 warnings .warn ("Stabilized Unbalanced Sinkhorn did not converge." +
796795 "Try a larger entropy `reg` or a lower mass `reg_m`." +
797796 "Or a larger absorption threshold `tau`." )
798797 if log :
799- log ['niter' ] = cpt
800- log ['logu' ] = np .log (u + 1e-16 )
801- log ['logv' ] = np .log (v + 1e-16 )
798+ log ['niter' ] = i
799+ log ['logu' ] = np .log (u + 1e-300 )
800+ log ['logv' ] = np .log (v + 1e-300 )
802801 return q , log
803802 else :
804803 return q
@@ -882,15 +881,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
882881
883882 fi = reg_m / (reg_m + reg )
884883
885- v = np .ones ((dim , n_hists )) / dim
886- u = np .ones ((dim , 1 )) / dim
887-
888- cpt = 0
884+ v = np .ones ((dim , n_hists ))
885+ u = np .ones ((dim , 1 ))
886+ q = np .ones (dim )
889887 err = 1.
890888
891- while (err > stopThr and cpt < numItermax ):
892- uprev = u
893- vprev = v
889+ for i in range (numItermax ):
890+ uprev = u .copy ()
891+ vprev = v .copy ()
892+ qprev = q .copy ()
894893
895894 Kv = K .dot (v )
896895 u = (A / Kv ) ** fi
@@ -905,31 +904,30 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
905904 or np .any (np .isinf (u )) or np .any (np .isinf (v ))):
906905 # we have reached the machine precision
907906 # come back to previous solution and quit loop
908- warnings .warn ('Numerical errors at iteration %s' % cpt )
907+ warnings .warn ('Numerical errors at iteration %s' % i )
909908 u = uprev
910909 v = vprev
910+ q = qprev
911911 break
912- if cpt % 10 == 0 :
913- # we can speed up the process by checking for the error only all
914- # the 10th iterations
915- err_u = abs (u - uprev ).max ()
916- err_u /= max (abs (u ).max (), abs (uprev ).max (), 1. )
917- err_v = abs (v - vprev ).max ()
918- err_v /= max (abs (v ).max (), abs (vprev ).max (), 1. )
919- err = 0.5 * (err_u + err_v )
920- if log :
921- log ['err' ].append (err )
922- if verbose :
923- if cpt % 50 == 0 :
924- print (
925- '{:5s}|{:12s}' .format ('It.' , 'Err' ) + '\n ' + '-' * 19 )
926- print ('{:5d}|{:8e}|' .format (cpt , err ))
912+ # compute change in barycenter
913+ err = abs (q - qprev ).max ()
914+ err /= max (abs (q ).max (), abs (qprev ).max (), 1. )
915+ if log :
916+ log ['err' ].append (err )
917+ # if barycenter did not change + at least 10 iterations - stop
918+ if err < stopThr and i > 10 :
919+ break
920+
921+ if verbose :
922+ if i % 10 == 0 :
923+ print (
924+ '{:5s}|{:12s}' .format ('It.' , 'Err' ) + '\n ' + '-' * 19 )
925+ print ('{:5d}|{:8e}|' .format (i , err ))
927926
928- cpt += 1
929927 if log :
930- log ['niter' ] = cpt
931- log ['logu' ] = np .log (u + 1e-16 )
932- log ['logv' ] = np .log (v + 1e-16 )
928+ log ['niter' ] = i
929+ log ['logu' ] = np .log (u + 1e-300 )
930+ log ['logv' ] = np .log (v + 1e-300 )
933931 return q , log
934932 else :
935933 return q
@@ -1002,19 +1000,22 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
10021000
10031001 if method .lower () == 'sinkhorn' :
10041002 return barycenter_unbalanced_sinkhorn (A , M , reg , reg_m ,
1003+ weights = weights ,
10051004 numItermax = numItermax ,
10061005 stopThr = stopThr , verbose = verbose ,
10071006 log = log , ** kwargs )
10081007
10091008 elif method .lower () == 'sinkhorn_stabilized' :
10101009 return barycenter_unbalanced_stabilized (A , M , reg , reg_m ,
1010+ weights = weights ,
10111011 numItermax = numItermax ,
10121012 stopThr = stopThr ,
10131013 verbose = verbose ,
10141014 log = log , ** kwargs )
10151015 elif method .lower () in ['sinkhorn_reg_scaling' ]:
10161016 warnings .warn ('Method not implemented yet. Using classic Sinkhorn Knopp' )
10171017 return barycenter_unbalanced (A , M , reg , reg_m ,
1018+ weights = weights ,
10181019 numItermax = numItermax ,
10191020 stopThr = stopThr , verbose = verbose ,
10201021 log = log , ** kwargs )
0 commit comments