
Commit 22036bd

Nicolas Courty authored and committed
da with GL
1 parent 0d1f3eb commit 22036bd

File tree

4 files changed: +154 -16 lines changed


examples/demo_OTDA_classes.py

Lines changed: 28 additions & 9 deletions
@@ -48,43 +48,62 @@
 da_entrop.fit(xs,xt,reg=lambd)
 xsts=da_entrop.interp()
 
-# Group lasso regularization
+# non-convex Group lasso regularization
 reg=1e-1
 eta=1e0
 da_lpl1=ot.da.OTDA_lpl1()
-da_lpl1.fit(xs,ys,xt,reg=lambd,eta=eta)
+da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta)
 xstg=da_lpl1.interp()
 
+
+# True Group lasso regularization
+reg=1e-1
+eta=1e1
+da_l1l2=ot.da.OTDA_l1l2()
+da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True)
+xstgl=da_l1l2.interp()
+
+
 #%% plot interpolated source samples
-pl.figure(4,(15,10))
+pl.figure(4,(15,8))
 
 param_img={'interpolation':'nearest','cmap':'jet'}
 
-pl.subplot(2,3,1)
+pl.subplot(2,4,1)
 pl.imshow(da_emd.G,**param_img)
 pl.title('OT matrix')
 
 
-pl.subplot(2,3,2)
+pl.subplot(2,4,2)
 pl.imshow(da_entrop.G,**param_img)
 pl.title('OT matrix sinkhorn')
 
-pl.subplot(2,3,3)
+pl.subplot(2,4,3)
 pl.imshow(da_lpl1.G,**param_img)
+pl.title('OT matrix non-convex Group Lasso')
+
+pl.subplot(2,4,4)
+pl.imshow(da_l1l2.G,**param_img)
 pl.title('OT matrix Group Lasso')
 
-pl.subplot(2,3,4)
+
+pl.subplot(2,4,5)
 pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
 pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)
 pl.title('Interp samples')
 pl.legend(loc=0)
 
-pl.subplot(2,3,5)
+pl.subplot(2,4,6)
 pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
 pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)
 pl.title('Interp samples Sinkhorn')
 
-pl.subplot(2,3,6)
+pl.subplot(2,4,7)
 pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
 pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)
+pl.title('Interp samples non-convex Group Lasso')
+
+pl.subplot(2,4,8)
+pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
+pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30)
 pl.title('Interp samples Group Lasso')
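
The updated demo exercises the new OTDA_l1l2 class alongside OTDA_lpl1. As a quick standalone illustration (not part of the commit), the sketch below fits OTDA_l1l2 on made-up two-class Gaussian data; the fit/interp/G calls and the reg/eta values mirror the diff above, while the toy data is an assumption for the example.

import numpy as np
import ot

np.random.seed(42)
n = 50  # samples per class

# source: two labelled Gaussian blobs (toy data, assumption for this sketch)
xs = np.vstack([0.5 * np.random.randn(n, 2),
                0.5 * np.random.randn(n, 2) + [2, 2]])
ys = np.hstack([np.zeros(n), np.ones(n)])

# target: the same blobs shifted; target labels are not used at fit time
xt = xs + [0.5, 0.5]

da_l1l2 = ot.da.OTDA_l1l2()
da_l1l2.fit(xs, ys, xt, reg=1e-1, eta=1e1, numItermax=20, verbose=True)

G = da_l1l2.G             # (ns, nt) coupling, columns concentrated per class
xstgl = da_l1l2.interp()  # transported (barycentric mapping) source samples
print(G.shape, xstgl.shape)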

examples/demo_optim_OTreg.py

Lines changed: 2 additions & 2 deletions
@@ -60,8 +60,8 @@ def df(G): return np.log(G)+1
 def f(G): return 0.5*np.sum(G**2)
 def df(G): return G
 
-reg1=1e-3
-reg2=1e-3
+reg1=1e-1
+reg2=1e-1
 
 Gel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True)
 
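
For context, ot.optim.gcg can also be called directly with the same quadratic regularizer this demo uses. The sketch below (not part of the commit) builds a small random problem; the toy marginals and cost matrix are assumptions, and only the gcg call and the f/df pair come from the diff.

import numpy as np
import ot

np.random.seed(0)
n = 20
xs = np.random.randn(n, 1)        # toy 1-D source positions (assumption)
xt = np.random.randn(n, 1) + 2.0  # toy 1-D target positions (assumption)

a = np.ones(n) / n  # uniform source weights
b = np.ones(n) / n  # uniform target weights
M = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(-1)  # squared Euclidean cost
M /= M.max()

def f(G): return 0.5 * np.sum(G ** 2)   # quadratic regularizer, as in the demo
def df(G): return G                     # its gradient

reg1 = 1e-1  # entropic term, handled by the inner Sinkhorn solver
reg2 = 1e-1  # weight of f, linearized by the conditional gradient step
Gel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True)
print(Gel2.shape, Gel2.sum())  # (n, n) coupling whose entries sum to 1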
ot/da.py

Lines changed: 116 additions & 0 deletions
@@ -8,6 +8,7 @@
 from .lp import emd
 from .utils import unif,dist,kernel
 from .optim import cg
+from .optim import gcg
 
 
 def indices(a, func):
@@ -122,6 +123,100 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
 
     return transp
 
+def sinkhorn_l1l2_gl(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
+    """
+    Solve the entropic regularization optimal transport problem with group lasso regularization
+
+    The function solves the following optimization problem:
+
+    .. math::
+        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma)
+
+        s.t. \gamma 1 = a
+
+             \gamma^T 1= b
+
+             \gamma\geq 0
+    where :
+
+    - M is the (ns,nt) metric cost matrix
+    - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`\Omega_g` is the group lasso regularization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^2` where :math:`\mathcal{I}_c` are the indices of samples from class c in the source domain.
+    - a and b are source and target weights (sum to 1)
+
+    The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_
+
+
+    Parameters
+    ----------
+    a : np.ndarray (ns,)
+        samples weights in the source domain
+    labels_a : np.ndarray (ns,)
+        labels of samples in the source domain
+    b : np.ndarray (nt,)
+        samples weights in the target domain
+    M : np.ndarray (ns,nt)
+        loss matrix
+    reg : float
+        Regularization term for entropic regularization >0
+    eta : float, optional
+        Regularization term for group lasso regularization >0
+    numItermax : int, optional
+        Max number of iterations
+    numInnerItermax : int, optional
+        Max number of iterations (inner sinkhorn solver)
+    stopInnerThr : float, optional
+        Stop threshold on error (inner sinkhorn solver) (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    gamma : (ns x nt) ndarray
+        Optimal transportation matrix for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+
+    References
+    ----------
+
+    .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence, vol.PP, no.99, pp.1-1
+    .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
+
+    See Also
+    --------
+    ot.optim.gcg : Generalized conditional gradient for OT problems
+
+    """
+    lstlab=np.unique(labels_a)
+
+    def f(G):
+        res=0
+        for i in range(G.shape[1]):
+            for lab in lstlab:
+                temp=G[labels_a==lab,i]
+                res+=np.linalg.norm(temp)
+        return res
+
+    def df(G):
+        W=np.zeros(G.shape)
+        for i in range(G.shape[1]):
+            for lab in lstlab:
+                temp=G[labels_a==lab,i]
+                n=np.linalg.norm(temp)
+                if n:
+                    W[labels_a==lab,i]=temp/n
+        return W
+
+
+    return gcg(a,b,M,reg,eta,f,df,G0=None,numItermax = numItermax,numInnerItermax=numInnerItermax, stopThr=stopInnerThr,verbose=verbose,log=log)
+
+
+
 
 def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs):
     """Joint OT and linear mapping estimation as proposed in [8]
@@ -632,6 +727,27 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
         self.M=dist(xs,xt,metric=self.metric)
         self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
         self.computed=True
+
+class OTDA_l1l2(OTDA):
+    """Class for domain adaptation with optimal transport with entropic and group lasso regularization"""
+
+
+    def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
+        """ Fit regularized domain adaptation between samples in xs and xt (with optional weights), see ot.da.sinkhorn_l1l2_gl for fit parameters"""
+        self.xs=xs
+        self.xt=xt
+
+        if wt is None:
+            wt=unif(xt.shape[0])
+        if ws is None:
+            ws=unif(xs.shape[0])
+
+        self.ws=ws
+        self.wt=wt
+
+        self.M=dist(xs,xt,metric=self.metric)
+        self.G=sinkhorn_l1l2_gl(ws,ys,wt,self.M,reg,eta,**kwargs)
+        self.computed=True
 
 class OTDA_mapping_linear(OTDA):
     """Class for optimal transport with joint linear mapping estimation as in [8]"""

ot/optim.py

Lines changed: 8 additions & 5 deletions
@@ -7,6 +7,7 @@
 from scipy.optimize.linesearch import scalar_search_armijo
 from .lp import emd
 from .bregman import sinkhorn_stabilized
+from .bregman import sinkhorn
 
 # The corresponding scipy function does not work for matrices
 def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99):
@@ -195,7 +196,7 @@ def cost(G):
     else:
         return G
 
-def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=False):
+def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 10,numInnerItermax = 200,stopThr=1e-9,verbose=False,log=False):
     """
     Solve the general regularized OT problem with the generalized conditional gradient
 
@@ -235,6 +236,8 @@ def gcg(a,b,M,reg1,reg2,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False
         initial guess (default is indep joint density)
     numItermax : int, optional
         Max number of iterations
+    numInnerItermax : int, optional
+        Max number of iterations of Sinkhorn
     stopThr : float, optional
         Stop threshold on error (>0)
     verbose : bool, optional
@@ -293,16 +296,16 @@ def cost(G):
 
         # problem linearization
         Mi=M+reg2*df(G)
-        # set M positive
-        Mi+=Mi.min()
 
         # solve linear program with Sinkhorn
-        Gc = sinkhorn_stabilized(a,b, Mi, reg1)
+        #Gc = sinkhorn_stabilized(a,b, Mi, reg1, numItermax = numInnerItermax)
+        Gc = sinkhorn(a,b, Mi, reg1, numItermax = numInnerItermax)
 
         deltaG=Gc-G
 
         # line search
-        alpha,fc,f_val = line_search_armijo(cost,G,deltaG,Mi,f_val)
+        dcost=Mi+reg1*np.sum(deltaG*(1+np.log(G))) #??
+        alpha,fc,f_val = line_search_armijo(cost,G,deltaG,dcost,f_val)
 
         G=G+alpha*deltaG
 
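
For readability, here is a hedged summary of the iteration that the modified loop implements, reconstructed from the code above and from the sinkhorn_l1l2_gl docstring (with f standing in for the group-lasso term, and U(a,b) denoting the couplings with marginals a and b). It is a sketch of the generalized conditional gradient step, not an authoritative statement of the library's algorithm.

% objective handled by gcg (reconstructed):
%   \min_{\gamma \in U(a,b)} \langle \gamma, M \rangle_F
%       + \mathrm{reg1}\,\Omega_e(\gamma) + \mathrm{reg2}\, f(\gamma)
\begin{align*}
  M_i &= M + \mathrm{reg2}\,\nabla f(G^{(k)})
      && \text{linearize the non-entropic term only}\\
  G_c &= \arg\min_{\gamma \in U(a,b)} \langle \gamma, M_i \rangle_F
         + \mathrm{reg1}\,\Omega_e(\gamma)
      && \text{inner Sinkhorn solve}\\
  G^{(k+1)} &= G^{(k)} + \alpha\,(G_c - G^{(k)})
      && \text{with } \alpha \text{ from the Armijo line search on the full cost}
\end{align*}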