 from .utils import unif, dist
 from scipy.optimize import fmin_l_bfgs_b
 
+
 def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
              stopThr=1e-9, verbose=False, log=False, **kwargs):
     r"""
@@ -539,12 +540,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
             old_v = v[i_2]
             v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
             G[:, i_2] = u * K[:, i_2] * v[i_2]
-            #aviol = (G@one_m - a)
-            #aviol_2 = (G.T@one_n - b)
+            # aviol = (G@one_m - a)
+            # aviol_2 = (G.T@one_n - b)
             viol += (-old_v + v[i_2]) * K[:, i_2] * u
             viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]
 
-        #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
+        # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
 
         if stopThr_val <= stopThr:
             break
@@ -715,7 +716,7 @@ def get_Gamma(alpha, beta, u, v):
         if np.abs(u).max() > tau or np.abs(v).max() > tau:
             if n_hists:
                 alpha, beta = alpha + reg * \
-                        np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
+                    np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
             else:
                 alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
             if n_hists:
@@ -940,7 +941,7 @@ def get_reg(n): # exponential decreasing
             # the 10th iterations
             transp = G
             err = np.linalg.norm(
-                (np.sum(transp, axis=0) - b))**2 + np.linalg.norm((np.sum(transp, axis=1) - a))**2
+                (np.sum(transp, axis=0) - b)) ** 2 + np.linalg.norm((np.sum(transp, axis=1) - a)) ** 2
             if log:
                 log['err'].append(err)
 
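
As an aside on the hunk above: the `err` term is the squared deviation of the transport plan's marginals from the target histograms `a` and `b`. A minimal standalone illustration with toy values (not part of the patch, just to make the criterion concrete):

import numpy as np

a = np.array([0.5, 0.5])            # source histogram
b = np.array([0.25, 0.75])          # target histogram
G = np.array([[0.15, 0.35],         # candidate transport plan
              [0.10, 0.40]])

# same quantity as in the diff: squared norms of the two marginal violations
err = np.linalg.norm(np.sum(G, axis=0) - b) ** 2 \
    + np.linalg.norm(np.sum(G, axis=1) - a) ** 2
print(err)  # 0.0 here, since both marginals of G match a and b exactly
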
@@ -966,7 +967,7 @@ def get_reg(n): # exponential decreasing
 
 def geometricBar(weights, alldistribT):
     """return the weighted geometric mean of distributions"""
-    assert(len(weights) == alldistribT.shape[1])
+    assert (len(weights) == alldistribT.shape[1])
     return np.exp(np.dot(np.log(alldistribT), weights.T))
 
 
@@ -1108,7 +1109,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
     if weights is None:
         weights = np.ones(A.shape[1]) / A.shape[1]
     else:
-        assert(len(weights) == A.shape[1])
+        assert (len(weights) == A.shape[1])
 
     if log:
         log = {'err': []}
@@ -1206,7 +1207,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
     if weights is None:
         weights = np.ones(n_hists) / n_hists
     else:
-        assert(len(weights) == A.shape[1])
+        assert (len(weights) == A.shape[1])
 
     if log:
         log = {'err': []}
@@ -1334,7 +1335,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
     if weights is None:
         weights = np.ones(A.shape[0]) / A.shape[0]
     else:
-        assert(len(weights) == A.shape[0])
+        assert (len(weights) == A.shape[0])
 
     if log:
         log = {'err': []}
@@ -1350,11 +1351,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
     # this is equivalent to blurring on horizontal then vertical directions
     t = np.linspace(0, 1, A.shape[1])
     [Y, X] = np.meshgrid(t, t)
-    xi1 = np.exp(-(X - Y)**2 / reg)
+    xi1 = np.exp(-(X - Y) ** 2 / reg)
 
     t = np.linspace(0, 1, A.shape[2])
     [Y, X] = np.meshgrid(t, t)
-    xi2 = np.exp(-(X - Y)**2 / reg)
+    xi2 = np.exp(-(X - Y) ** 2 / reg)
 
     def K(x):
         return np.dot(np.dot(xi1, x), xi2)
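
For context on the hunk above: `K(x) = xi1 @ x @ xi2` applies the Gaussian blur separably, first along one image axis and then along the other. A short sketch of the same idea on a toy image (sizes and `reg` are illustrative, not taken from the patch):

import numpy as np

n, m, reg = 40, 60, 0.01
img = np.random.rand(n, m)          # toy 2-D histogram / image

t = np.linspace(0, 1, n)
Y, X = np.meshgrid(t, t)
xi1 = np.exp(-(X - Y) ** 2 / reg)   # 1-D Gaussian kernel acting on the rows

t = np.linspace(0, 1, m)
Y, X = np.meshgrid(t, t)
xi2 = np.exp(-(X - Y) ** 2 / reg)   # 1-D Gaussian kernel acting on the columns

blurred = xi1.dot(img).dot(xi2)     # equivalent to the K(x) helper above

Using two small matrix products (n x n and m x m) avoids ever forming the full (n*m) x (n*m) kernel, which is the point of the convolutional variant.
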
@@ -1501,6 +1502,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
     else:
         return np.sum(K0, axis=1)
 
+
 def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
                      stopThr=1e-6, verbose=False, log=False, **kwargs):
     r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27]
@@ -1658,6 +1660,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
     else:
         return couplings, bary
 
+
 def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
                        numIterMax=10000, stopThr=1e-9, verbose=False,
                        log=False, **kwargs):
@@ -1749,7 +1752,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
         return pi
 
 
-def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
+                        verbose=False, log=False, **kwargs):
     r'''
     Solve the entropic regularization optimal transport problem from empirical
     data and return the OT loss
@@ -1831,14 +1835,17 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
     M = dist(X_s, X_t, metric=metric)
 
     if log:
-        sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+                                       **kwargs)
         return sinkhorn_loss, log
     else:
-        sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+                                  **kwargs)
         return sinkhorn_loss
 
 
-def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
+                                  verbose=False, log=False, **kwargs):
     r'''
     Compute the sinkhorn divergence loss from empirical data
 
@@ -1924,11 +1931,14 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
     .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics, (AISTATS) 21, 2018
     '''
     if log:
-        sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
+                                                       stopThr=1e-9, verbose=verbose, log=log, **kwargs)
 
-        sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax,
+                                                     stopThr=1e-9, verbose=verbose, log=log, **kwargs)
 
-        sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
+                                                     stopThr=1e-9, verbose=verbose, log=log, **kwargs)
 
         sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
 
@@ -1943,11 +1953,14 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
         return max(0, sinkhorn_div), log
 
     else:
-        sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+                                               verbose=verbose, log=log, **kwargs)
 
-        sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+                                              verbose=verbose, log=log, **kwargs)
 
-        sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+        sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
+                                              verbose=verbose, log=log, **kwargs)
 
         sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b)
         return max(0, sinkhorn_div)
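
The hunk above only rewraps long calls, but it helps to recall what the function computes: the Sinkhorn divergence S(a, b) = W_reg(a, b) - (W_reg(a, a) + W_reg(b, b)) / 2, clipped at zero, exactly as the last two lines show. A toy usage sketch, assuming the usual POT import path `ot.bregman` and made-up data:

import numpy as np
import ot

rng = np.random.RandomState(0)
X_s = rng.randn(20, 2)            # source samples (hypothetical data)
X_t = rng.randn(30, 2) + 1.0      # target samples, shifted
reg = 0.5                         # entropic regularization strength

# uniform weights are used when a and b are left as None
div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg)
print(div)                        # non-negative; close to 0 when the two clouds coincide
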
@@ -2039,7 +2052,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
     try:
         import bottleneck
     except ImportError:
-        warnings.warn("Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
+        warnings.warn(
+            "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
         bottleneck = np
 
     a = np.asarray(a, dtype=np.float64)
@@ -2173,10 +2187,11 @@ def projection(u, epsilon):
 
     # box constraints in L-BFGS-B (see Proposition 1 in [26])
     bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
-                    ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
+        ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget
 
-    bounds_v = [(max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
-                     epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
+    bounds_v = [(
+        max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
+            epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget
 
     # pre-calculated constants for the objective
     vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1)
@@ -2225,7 +2240,8 @@ def restricted_sinkhorn(usc, vsc, max_iter=5):
         return usc, vsc
 
     def screened_obj(usc, vsc):
-        part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, np.log(vsc))
+        part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J,
+                                                                                                            np.log(vsc))
         part_IJc = np.dot(usc, vec_eps_IJc)
         part_IcJ = np.dot(vec_eps_IcJ, vsc)
         psi_epsilon = part_IJ + part_IJc + part_IcJ
@@ -2247,9 +2263,9 @@ def bfgspost(theta):
         g = np.hstack([g_u, g_v])
         return f, g
 
-    #----------------------------------------------------------------------------------------------------------------#
+    # ----------------------------------------------------------------------------------------------------------------#
     #                                         Step 2: L-BFGS-B solver                                                 #
-    #----------------------------------------------------------------------------------------------------------------#
+    # ----------------------------------------------------------------------------------------------------------------#
 
     u0, v0 = restricted_sinkhorn(u0, v0)
     theta0 = np.hstack([u0, v0])
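
For readers unfamiliar with the pattern in this last hunk: the screened scaling vectors are stacked into a single parameter vector `theta0` and handed, together with the box constraints built earlier, to SciPy's L-BFGS-B solver (imported at the top of the file). A self-contained sketch of that calling pattern on a toy objective (names and values here are illustrative, not the screenkhorn objective):

import numpy as np
from scipy.optimize import fmin_l_bfgs_b

def objective_and_grad(theta):
    # toy separable objective; screenkhorn's bfgspost plays this role
    u, v = theta[:2], theta[2:]
    f = np.sum((u - 1.0) ** 2) + np.sum((v - 2.0) ** 2)
    g = np.hstack([2 * (u - 1.0), 2 * (v - 2.0)])
    return f, g                                    # value and gradient, as bfgspost returns

theta0 = np.hstack([np.ones(2), np.ones(2)])       # stacked start point, like np.hstack([u0, v0])
bounds = [(1e-3, 10.0)] * 4                        # box constraints, like bounds_u + bounds_v
theta_opt, f_opt, info = fmin_l_bfgs_b(objective_and_grad, theta0, bounds=bounds)
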