|
8 | 8 | from .lp import emd |
9 | 9 | from .utils import unif,dist,kernel |
10 | 10 | from .optim import cg |
| 11 | +from .optim import gcg |
11 | 12 |
|
12 | 13 |
|
13 | 14 | def indices(a, func): |
@@ -122,6 +123,100 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter |
122 | 123 |
|
123 | 124 | return transp |
124 | 125 |
|
def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
    r"""
    Solve the entropic regularization optimal transport problem with group lasso regularization

    The function solves the following optimization problem:

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1= b

             \gamma\geq 0
    where :

    - M is the (ns,nt) metric cost matrix
    - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\Omega_g` is the group lasso regularization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|_2` (sum of per-class l2 norms, i.e. an l1/l2 mixed norm) where :math:`\mathcal{I}_c` are the indices of samples from class c in the source domain.
    - a and b are source and target weights (sum to 1)

    The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_


    Parameters
    ----------
    a : np.ndarray (ns,)
        samples weights in the source domain
    labels_a : np.ndarray (ns,)
        labels of samples in the source domain
    b : np.ndarray (nt,)
        samples in the target domain
    M : np.ndarray (ns,nt)
        loss matrix
    reg : float
        Regularization term for entropic regularization >0
    eta : float, optional
        Regularization term for group lasso regularization >0
    numItermax : int, optional
        Max number of iterations
    numInnerItermax : int, optional
        Max number of iterations (inner sinkhorn solver)
    stopInnerThr : float, optional
        Stop threshold on error (inner sinkhorn solver) (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    gamma : (ns x nt) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters


    References
    ----------

    .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
    .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.

    See Also
    --------
    ot.optim.gcg : Generalized conditional gradient for OT problems

    """
    lstlab = np.unique(labels_a)
    # Hoist the per-class boolean masks out of the inner loops: they depend
    # only on labels_a, but were previously rebuilt for every column of G in
    # both f and df (O(n_classes * ns) work repeated nt times per call).
    masks = [labels_a == lab for lab in lstlab]

    def f(G):
        # Group lasso value: sum over target columns and source classes of
        # the (unsquared) l2 norm of each class-restricted column slice.
        res = 0
        for i in range(G.shape[1]):
            for mask in masks:
                res += np.linalg.norm(G[mask, i])
        return res

    def df(G):
        # (Sub)gradient of f: each group contributes gamma_group / ||gamma_group||;
        # groups with zero norm get a zero subgradient (the `if n:` guard also
        # avoids a 0/0 division).
        W = np.zeros(G.shape)
        for i in range(G.shape[1]):
            for mask in masks:
                temp = G[mask, i]
                n = np.linalg.norm(temp)
                if n:
                    W[mask, i] = temp / n
        return W

    # Delegate to the generalized conditional gradient solver; reg weighs the
    # entropic term, eta the group lasso term f.
    return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax,
               numInnerItermax=numInnerItermax, stopThr=stopInnerThr,
               verbose=verbose, log=log)
| 217 | + |
| 218 | + |
| 219 | + |
125 | 220 | def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs): |
126 | 221 | """Joint OT and linear mapping estimation as proposed in [8] |
127 | 222 |
|
@@ -632,6 +727,27 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs): |
632 | 727 | self.M=dist(xs,xt,metric=self.metric) |
633 | 728 | self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs) |
634 | 729 | self.computed=True |
| 730 | + |
class OTDA_l1l2(OTDA):
    """Class for domain adaptation with optimal transport with entropic and group lasso regularization"""


    def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
        """ Fit regularized domain adaptation between samples in xs and xt (with optional weights), see ot.da.sinkhorn_l1l2_gl for fit parameters"""
        # Keep the raw samples for later interpolation/prediction by the base class.
        self.xs=xs
        self.xt=xt

        # Default to uniform weights when none are supplied.
        if wt is None:
            wt=unif(xt.shape[0])
        if ws is None:
            ws=unif(xs.shape[0])

        self.ws=ws
        self.wt=wt

        # Cost matrix between source and target samples, then the regularized
        # (entropy + group lasso) transport plan; kwargs are forwarded to the solver.
        self.M=dist(xs,xt,metric=self.metric)
        self.G=sinkhorn_l1l2_gl(ws,ys,wt,self.M,reg,eta,**kwargs)
        self.computed=True
635 | 751 |
|
636 | 752 | class OTDA_mapping_linear(OTDA): |
637 | 753 | """Class for optimal transport with joint linear mapping estimation as in [8]""" |
|
0 commit comments