@@ -1281,7 +1281,7 @@ def get_reg(n): # exponential decreasing
12811281 regi = get_reg (ii )
12821282
12831283 G , logi = sinkhorn_stabilized (a , b , M , regi ,
1284- numItermax = numInnerItermax , stopThr = 1e-9 ,
1284+ numItermax = numInnerItermax , stopThr = stopThr ,
12851285 warmstart = (alpha , beta ), verbose = False ,
12861286 print_period = 20 , tau = tau , log = True )
12871287
@@ -3306,17 +3306,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
33063306 if log :
33073307 sinkhorn_loss_ab , log_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric ,
33083308 numIterMax = numIterMax ,
3309- stopThr = 1e-9 , verbose = verbose ,
3309+ stopThr = stopThr , verbose = verbose ,
33103310 log = log , warn = warn , ** kwargs )
33113311
33123312 sinkhorn_loss_a , log_a = empirical_sinkhorn2 (X_s , X_s , reg , a , a , metric = metric ,
33133313 numIterMax = numIterMax ,
3314- stopThr = 1e-9 , verbose = verbose ,
3314+ stopThr = stopThr , verbose = verbose ,
33153315 log = log , warn = warn , ** kwargs )
33163316
33173317 sinkhorn_loss_b , log_b = empirical_sinkhorn2 (X_t , X_t , reg , b , b , metric = metric ,
33183318 numIterMax = numIterMax ,
3319- stopThr = 1e-9 , verbose = verbose ,
3319+ stopThr = stopThr , verbose = verbose ,
33203320 log = log , warn = warn , ** kwargs )
33213321
33223322 sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b )
@@ -3333,17 +3333,17 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
33333333
33343334 else :
33353335 sinkhorn_loss_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric ,
3336- numIterMax = numIterMax , stopThr = 1e-9 ,
3336+ numIterMax = numIterMax , stopThr = stopThr ,
33373337 verbose = verbose , log = log ,
33383338 warn = warn , ** kwargs )
33393339
33403340 sinkhorn_loss_a = empirical_sinkhorn2 (X_s , X_s , reg , a , a , metric = metric ,
3341- numIterMax = numIterMax , stopThr = 1e-9 ,
3341+ numIterMax = numIterMax , stopThr = stopThr ,
33423342 verbose = verbose , log = log ,
33433343 warn = warn , ** kwargs )
33443344
33453345 sinkhorn_loss_b = empirical_sinkhorn2 (X_t , X_t , reg , b , b , metric = metric ,
3346- numIterMax = numIterMax , stopThr = 1e-9 ,
3346+ numIterMax = numIterMax , stopThr = stopThr ,
33473347 verbose = verbose , log = log ,
33483348 warn = warn , ** kwargs )
33493349
0 commit comments