1616
1717from .bregman import sinkhorn , jcpot_barycenter
1818from .lp import emd
19- from .utils import unif , dist , kernel , cost_normalization , label_normalization , laplacian
19+ from .utils import unif , dist , kernel , cost_normalization , label_normalization , laplacian , dots
2020from .utils import check_params , BaseEstimator
2121from .unbalanced import sinkhorn_unbalanced
2222from .optim import cg
@@ -748,7 +748,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
748748 return A , b
749749
750750
751- def emd_laplace (a , b , xs , xt , M , sim , eta , alpha ,
751+ def emd_laplace (a , b , xs , xt , M , sim , reg , eta , alpha ,
752752 numItermax , stopThr , numInnerItermax ,
753753 stopInnerThr , log = False , verbose = False , ** kwargs ):
754754 r"""Solve the optimal transport problem (OT) with Laplacian regularization
@@ -785,6 +785,8 @@ def emd_laplace(a, b, xs, xt, M, sim, eta, alpha,
785785 samples in the target domain
786786 M : np.ndarray (ns,nt)
787787 loss matrix
788+ reg : string
789+ Type of Laplacian regularization
788790 eta : float
789791 Regularization term for Laplacian regularization
790792 alpha : float
@@ -844,6 +846,8 @@ def emd_laplace(a, b, xs, xt, M, sim, eta, alpha,
844846 sS = (sS + sS .T ) / 2
845847 sT = kneighbors_graph (xt , kwargs ['nn' ]).toarray ()
846848 sT = (sT + sT .T ) / 2
849+ else :
850+ raise ValueError ('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".' .format (sim = sim ))
847851
848852 lS = laplacian (sS )
849853 lT = laplacian (sT )
@@ -852,9 +856,18 @@ def f(G):
852856 return alpha * np .trace (np .dot (xt .T , np .dot (G .T , np .dot (lS , np .dot (G , xt ))))) \
853857 + (1 - alpha ) * np .trace (np .dot (xs .T , np .dot (G , np .dot (lT , np .dot (G .T , xs )))))
854858
859+ ls2 = lS + lS .T
860+ lt2 = lT + lT .T
861+ xt2 = np .dot (xt , xt .T )
862+
863+ if reg == 'disp' :
864+ Cs = - eta * alpha / xs .shape [0 ] * dots (ls2 , xs , xt .T )
865+ Ct = - eta * (1 - alpha ) / xt .shape [0 ] * dots (xs , xt .T , lt2 )
866+ M = M + Cs + Ct
867+
855868 def df (G ):
856- return alpha * np .dot (lS + lS . T , np .dot (G , np . dot ( xt , xt . T ) ))\
857- + (1 - alpha ) * np .dot (xs , np .dot (xs .T , np .dot (G , lT + lT . T )))
869+ return alpha * np .dot (ls2 , np .dot (G , xt2 ))\
870+ + (1 - alpha ) * np .dot (xs , np .dot (xs .T , np .dot (G , lt2 )))
858871
859872 return cg (a , b , M , reg = eta , f = f , df = df , G0 = None , numItermax = numItermax , numItermaxEmd = numInnerItermax ,
860873 stopThr = stopThr , stopThr2 = stopInnerThr , verbose = verbose , log = log )
@@ -1694,6 +1707,9 @@ class EMDLaplaceTransport(BaseTransport):
16941707
16951708 Parameters
16961709 ----------
1710+ reg_type : string optional (default='pos')
1711+ Type of the regularization term: 'pos' and 'disp' for
1712+ regularization term defined in [2] and [6], respectively.
16971713 reg_lap : float, optional (default=1)
16981714 Laplacian regularization parameter
16991715 reg_src : float, optional (default=0.5)
@@ -1737,11 +1753,12 @@ class EMDLaplaceTransport(BaseTransport):
17371753 in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
17381754 """
17391755
1740- def __init__ (self , reg_lap = 1. , reg_src = 1. , alpha = 0.5 ,
1756+ def __init__ (self , reg_type = 'pos' , reg_lap = 1. , reg_src = 1. , alpha = 0.5 ,
17411757 metric = "sqeuclidean" , norm = None , similarity = "knn" , max_iter = 100 , tol = 1e-9 ,
17421758 max_inner_iter = 100000 , inner_tol = 1e-9 , log = False , verbose = False ,
17431759 distribution_estimation = distribution_estimation_uniform ,
17441760 out_of_sample_map = 'ferradans' ):
1761+ self .reg = reg_type
17451762 self .reg_lap = reg_lap
17461763 self .reg_src = reg_src
17471764 self .alpha = alpha
@@ -1785,7 +1802,7 @@ class label
17851802 super (EMDLaplaceTransport , self ).fit (Xs , ys , Xt , yt )
17861803
17871804 returned_ = emd_laplace (a = self .mu_s , b = self .mu_t , xs = self .xs_ ,
1788- xt = self .xt_ , M = self .cost_ , sim = self .similarity , eta = self .reg_lap , alpha = self .reg_src ,
1805+ xt = self .xt_ , M = self .cost_ , reg = self . reg , sim = self .similarity , eta = self .reg_lap , alpha = self .reg_src ,
17891806 numItermax = self .max_iter , stopThr = self .tol , numInnerItermax = self .max_inner_iter ,
17901807 stopInnerThr = self .inner_tol , log = self .log , verbose = self .verbose )
17911808
0 commit comments