Skip to content

Commit 8fc74ed

Browse files
author
Leo gautheron
committed
performance improvement sinkhorn lpl1
- instead of updating individually for each target example, update for all target examples at once using numpy functions. This allows for a faster computation (for me, divided by 4 on 3000*100 random matrices and random labels in [0,1]). - if I understood correctly, a value of -1 in the array labels_a meant that we didn't have a label for this example. But in machine learning, we often encounter the binary case where we say we have the positive class (+1) and negative class (-1); thus with a dataset like this, the algorithm wouldn't work as expected. I replaced the default value for 'no label' with '-99' instead of '-1', and I added a parameter to modify it.
1 parent bd325a3 commit 8fc74ed

File tree

1 file changed

+22
-23
lines changed

1 file changed

+22
-23
lines changed

ot/da.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
from .optim import gcg
1212

1313

14-
def indices(a, func):
15-
return [i for (i, val) in enumerate(a) if func(val)]
16-
17-
18-
def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False):
14+
def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,unlabelledValue=-99,verbose=False,log=False):
1915
"""
2016
Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
2117
@@ -46,7 +42,7 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
4642
labels_a : np.ndarray (ns,)
4743
labels of samples in the source domain
4844
b : np.ndarray (nt,)
49-
samples in the target domain
45+
samples weights in the target domain
5046
M : np.ndarray (ns,nt)
5147
loss matrix
5248
reg : float
@@ -59,6 +55,8 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
5955
Max number of iterations (inner sinkhorn solver)
6056
stopInnerThr : float, optional
6157
Stop threshold on error (inner sinkhorn solver) (>0)
58+
unlabelledValue : int, optional
59+
this value in array labels_a means this is an unlabelled example
6260
verbose : bool, optional
6361
Print information along iterations
6462
log : bool, optional
@@ -94,9 +92,9 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
9492
Nfin = len(b)
9593

9694
indices_labels = []
97-
idx_begin = np.min(labels_a)
98-
for c in range(idx_begin,np.max(labels_a)+1):
99-
idxc = indices(labels_a, lambda x: x==c)
95+
classes = np.unique(labels_a)
96+
for c in classes:
97+
idxc, = np.where(labels_a == c)
10098
indices_labels.append(idxc)
10199

102100
W=np.zeros(M.shape)
@@ -106,20 +104,21 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
106104
transp=sinkhorn(a,b,Mreg,reg,numItermax=numInnerItermax, stopThr=stopInnerThr)
107105
# the transport has been computed. Check if classes are really separated
108106
W = np.ones((Nini,Nfin))
109-
for t in range(Nfin):
110-
column = transp[:,t]
111-
all_maj = []
112-
for c in range(idx_begin,np.max(labels_a)+1):
113-
col_c = column[indices_labels[c-idx_begin]]
114-
if c!=-1:
115-
maj = p*((sum(col_c)+epsilon)**(p-1))
116-
W[indices_labels[c-idx_begin],t]=maj
117-
all_maj.append(maj)
118-
119-
# now we majorize the unlabelled by the min of the majorizations
120-
# do it only for unlabbled data
121-
if idx_begin==-1:
122-
W[indices_labels[0],t]=np.min(all_maj)
107+
all_majs = []
108+
idx_unlabelled = -1
109+
for (i, c) in enumerate(classes):
110+
if c != unlabelledValue:
111+
majs = np.sum(transp[indices_labels[i]], axis=0)
112+
majs = p*((majs+epsilon)**(p-1))
113+
W[indices_labels[i]] = majs
114+
all_majs.append(majs)
115+
else:
116+
idx_unlabelled = i
117+
118+
# now we majorize the unlabelled (if there are any) by the min of
119+
# the majorizations. do it only for unlabbled data
120+
if idx_unlabelled != -1:
121+
W[indices_labels[idx_unlabelled]] = np.min(all_majs, axis=0)
123122

124123
return transp
125124

0 commit comments

Comments
 (0)