@@ -323,15 +323,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
323323 if len (b .shape ) < 2 :
324324 if method .lower () == 'sinkhorn' :
325325 res = sinkhorn_knopp (a , b , M , reg , numItermax = numItermax ,
326- stopThr = stopThr , verbose = verbose , log = log ,
326+ stopThr = stopThr , verbose = verbose ,
327+ log = log , warn = warn ,
327328 ** kwargs )
328329 elif method .lower () == 'sinkhorn_log' :
329330 res = sinkhorn_log (a , b , M , reg , numItermax = numItermax ,
330- stopThr = stopThr , verbose = verbose , log = log ,
331+ stopThr = stopThr , verbose = verbose ,
332+ log = log , warn = warn ,
331333 ** kwargs )
332334 elif method .lower () == 'sinkhorn_stabilized' :
333335 res = sinkhorn_stabilized (a , b , M , reg , numItermax = numItermax ,
334- stopThr = stopThr , verbose = verbose , log = log ,
336+ stopThr = stopThr , verbose = verbose ,
337+ log = log , warn = warn ,
335338 ** kwargs )
336339 else :
337340 raise ValueError ("Unknown method '%s'." % method )
@@ -344,15 +347,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
344347
345348 if method .lower () == 'sinkhorn' :
346349 return sinkhorn_knopp (a , b , M , reg , numItermax = numItermax ,
347- stopThr = stopThr , verbose = verbose , log = log ,
350+ stopThr = stopThr , verbose = verbose ,
351+ log = log , warn = warn ,
348352 ** kwargs )
349353 elif method .lower () == 'sinkhorn_log' :
350354 return sinkhorn_log (a , b , M , reg , numItermax = numItermax ,
351- stopThr = stopThr , verbose = verbose , log = log ,
355+ stopThr = stopThr , verbose = verbose ,
356+ log = log , warn = warn ,
352357 ** kwargs )
353358 elif method .lower () == 'sinkhorn_stabilized' :
354359 return sinkhorn_stabilized (a , b , M , reg , numItermax = numItermax ,
355- stopThr = stopThr , verbose = verbose , log = log ,
360+ stopThr = stopThr , verbose = verbose ,
361+ log = log , warn = warn ,
356362 ** kwargs )
357363 else :
358364 raise ValueError ("Unknown method '%s'." % method )
0 commit comments