@@ -3173,8 +3173,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
31733173 return loss
31743174
31753175 else :
3176- M = dist (nx .to_numpy (X_s ), nx .to_numpy (X_t ), metric = metric )
3177- M = nx .from_numpy (M , type_as = a )
3176+ M = dist (X_s , X_t , metric = metric )
31783177
31793178 if log :
31803179 sinkhorn_loss , log = sinkhorn2 (a , b , M , reg , numItermax = numIterMax ,
@@ -3287,6 +3286,10 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
32873286 International Conference on Artficial Intelligence and Statistics,
32883287 (AISTATS) 21, 2018
32893288 '''
3289+ X_s , X_t = list_to_array (X_s , X_t )
3290+
3291+ nx = get_backend (X_s , X_t )
3292+
32903293 if log :
32913294 sinkhorn_loss_ab , log_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric ,
32923295 numIterMax = numIterMax ,
@@ -3313,7 +3316,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
33133316 log ['log_sinkhorn_a' ] = log_a
33143317 log ['log_sinkhorn_b' ] = log_b
33153318
3316- return max (0 , sinkhorn_div ), log
3319+ return nx . maximum (0 , sinkhorn_div ), log
33173320
33183321 else :
33193322 sinkhorn_loss_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric ,
@@ -3332,7 +3335,7 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
33323335 warn = warn , ** kwargs )
33333336
33343337 sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b )
3335- return max (0 , sinkhorn_div )
3338+ return nx . maximum (0 , sinkhorn_div )
33363339
33373340
33383341def screenkhorn (a , b , M , reg , ns_budget = None , nt_budget = None , uniform = False ,
0 commit comments