|
11 | 11 | from .optim import gcg |
12 | 12 |
|
13 | 13 |
|
14 | | -def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,unlabelledValue=-99,verbose=False,log=False): |
| 14 | +def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerItermax = 200,stopInnerThr=1e-9,verbose=False,log=False): |
15 | 15 | """ |
16 | 16 | Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization |
17 | 17 |
|
@@ -55,8 +55,6 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter |
55 | 55 | Max number of iterations (inner sinkhorn solver) |
56 | 56 | stopInnerThr : float, optional |
57 | 57 | Stop threshold on error (inner sinkhorn solver) (>0) |
58 | | - unlabelledValue : int, optional |
59 | | - this value in array labels_a means this is an unlabelled example |
60 | 58 | verbose : bool, optional |
61 | 59 | Print information along iterations |
62 | 60 | log : bool, optional |
@@ -84,41 +82,28 @@ def sinkhorn_lpl1_mm(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter |
84 | 82 | ot.optim.cg : General regularized OT |
85 | 83 |
|
86 | 84 | """ |
87 | | - p=0.5 |
| 85 | + p = 0.5 |
88 | 86 | epsilon = 1e-3 |
89 | 87 |
|
90 | | - # init data |
91 | | - Nini = len(a) |
92 | | - Nfin = len(b) |
93 | | - |
94 | 88 | indices_labels = [] |
95 | 89 | classes = np.unique(labels_a) |
96 | 90 | for c in classes: |
97 | 91 | idxc, = np.where(labels_a == c) |
98 | 92 | indices_labels.append(idxc) |
99 | 93 |
|
100 | | - W=np.zeros(M.shape) |
| 94 | + W = np.zeros(M.shape) |
101 | 95 |
|
102 | 96 | for cpt in range(numItermax): |
103 | 97 | Mreg = M + eta*W |
104 | | - transp=sinkhorn(a,b,Mreg,reg,numItermax=numInnerItermax, stopThr=stopInnerThr) |
105 | | - # the transport has been computed. Check if classes are really separated |
106 | | - W = np.ones((Nini,Nfin)) |
107 | | - all_majs = [] |
108 | | - idx_unlabelled = -1 |
| 98 | + transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, |
| 99 | + stopThr=stopInnerThr) |
| 100 | + # the transport has been computed. Check if classes are really |
| 101 | + # separated |
| 102 | + W = np.ones(M.shape) |
109 | 103 | for (i, c) in enumerate(classes): |
110 | | - if c != unlabelledValue: |
111 | | - majs = np.sum(transp[indices_labels[i]], axis=0) |
112 | | - majs = p*((majs+epsilon)**(p-1)) |
113 | | - W[indices_labels[i]] = majs |
114 | | - all_majs.append(majs) |
115 | | - else: |
116 | | - idx_unlabelled = i |
117 | | - |
118 | | - # now we majorize the unlabelled (if there are any) by the min of |
119 | | - # the majorizations. do it only for unlabbled data |
120 | | - if idx_unlabelled != -1: |
121 | | - W[indices_labels[idx_unlabelled]] = np.min(all_majs, axis=0) |
| 104 | + majs = np.sum(transp[indices_labels[i]], axis=0) |
| 105 | + majs = p*((majs+epsilon)**(p-1)) |
| 106 | + W[indices_labels[i]] = majs |
122 | 107 |
|
123 | 108 | return transp |
124 | 109 |
|
|
0 commit comments