|
18 | 18 |
|
19 | 19 |
|
20 | 20 | ############################################################################# |
21 | | -# COMPUTE TRANSPORTATION MATRIX |
| 21 | +# COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM |
22 | 22 | ############################################################################# |
23 | | - |
| 23 | +print("------------SEMI-DUAL PROBLEM------------") |
24 | 24 | ############################################################################# |
25 | 25 | # DISCRETE CASE |
26 | 26 | # Sample two discrete measures for the discrete case |
|
48 | 48 | # Call the "SAG" method to find the transportation matrix in the discrete case |
49 | 49 | # --------------------------------------------- |
50 | 50 | # |
51 | | -# Define the method "SAG", call ot.transportation_matrix_entropic and plot the |
| 51 | +# Define the method "SAG", call ot.solve_semi_dual_entropic and plot the |
52 | 52 | # results. |
53 | 53 |
|
54 | 54 | method = "SAG" |
55 | | -sag_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method, |
56 | | - numItermax, lr) |
| 55 | +sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, |
| 56 | + numItermax, lr) |
57 | 57 | print(sag_pi) |
58 | 58 |
|
59 | 59 | ############################################################################# |
|
68 | 68 | n_source = 7 |
69 | 69 | n_target = 4 |
70 | 70 | reg = 1 |
71 | | -numItermax = 500000 |
| 71 | +numItermax = 100000 |
72 | 72 | lr = 1 |
| 73 | +log = True |
73 | 74 |
|
74 | 75 | a = ot.utils.unif(n_source) |
75 | 76 | b = ot.utils.unif(n_target) |
|
85 | 86 | # case |
86 | 87 | # --------------------------------------------- |
87 | 88 | # |
88 | | -# Define the method "ASGD", call ot.transportation_matrix_entropic and plot the |
| 89 | +# Define the method "ASGD", call ot.solve_semi_dual_entropic and plot the |
89 | 90 | # results. |
90 | 91 |
|
91 | 92 | method = "ASGD" |
92 | | -asgd_pi = ot.stochastic.transportation_matrix_entropic(a, b, M, reg, method, |
93 | | - numItermax, lr) |
| 93 | +asgd_pi, log = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, |
| 94 | + numItermax, lr, log) |
| 95 | +print(log['alpha'], log['beta']) |
94 | 96 | print(asgd_pi) |
95 | 97 |
|
96 | 98 | ############################################################################# |
|
100 | 102 | # |
101 | 103 | # Call the Sinkhorn algorithm from POT |
102 | 104 |
|
103 | | -sinkhorn_pi = ot.sinkhorn(a, b, M, 1) |
| 105 | +sinkhorn_pi = ot.sinkhorn(a, b, M, reg) |
104 | 106 | print(sinkhorn_pi) |
105 | 107 |
|
106 | 108 |
|
|
113 | 115 | # ---------------- |
114 | 116 |
|
115 | 117 | pl.figure(4, figsize=(5, 5)) |
116 | | -ot.plot.plot1D_mat(a, b, sag_pi, 'OT matrix SAG') |
| 118 | +ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG') |
117 | 119 | pl.show() |
118 | 120 |
|
119 | 121 |
|
|
122 | 124 | # ----------------- |
123 | 125 |
|
124 | 126 | pl.figure(4, figsize=(5, 5)) |
125 | | -ot.plot.plot1D_mat(a, b, asgd_pi, 'OT matrix ASGD') |
| 127 | +ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD') |
| 128 | +pl.show() |
| 129 | + |
| 130 | + |
| 131 | +############################################################################## |
| 132 | +# Plot Sinkhorn results |
| 133 | +# --------------------- |
| 134 | + |
| 135 | +pl.figure(4, figsize=(5, 5)) |
| 136 | +ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') |
| 137 | +pl.show() |
| 138 | + |
| 139 | +############################################################################# |
| 140 | +# COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM |
| 141 | +############################################################################# |
| 142 | +print("------------DUAL PROBLEM------------") |
| 143 | +############################################################################# |
| 144 | +# SEMICONTINOUS CASE |
| 145 | +# Sample one general measure a, one discrete measures b for the semicontinous |
| 146 | +# case |
| 147 | +# --------------------------------------------- |
| 148 | +# |
| 149 | +# Define one general measure a, one discrete measures b, the points where |
| 150 | +# are defined the source and the target measures and finally the cost matrix c. |
| 151 | + |
| 152 | +n_source = 7 |
| 153 | +n_target = 4 |
| 154 | +reg = 1 |
| 155 | +numItermax = 100000 |
| 156 | +lr = 0.1 |
| 157 | +batch_size = 3 |
| 158 | +log = True |
| 159 | + |
| 160 | +a = ot.utils.unif(n_source) |
| 161 | +b = ot.utils.unif(n_target) |
| 162 | + |
| 163 | +rng = np.random.RandomState(0) |
| 164 | +X_source = rng.randn(n_source, 2) |
| 165 | +Y_target = rng.randn(n_target, 2) |
| 166 | +M = ot.dist(X_source, Y_target) |
| 167 | + |
| 168 | +############################################################################# |
| 169 | +# |
| 170 | +# Call the "SGD" dual method to find the transportation matrix in the semicontinous |
| 171 | +# case |
| 172 | +# --------------------------------------------- |
| 173 | +# |
| 174 | +# Call ot.solve_dual_entropic and plot the results. |
| 175 | + |
| 176 | +sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, |
| 177 | + numItermax, lr, log) |
| 178 | +print(log['alpha'], log['beta']) |
| 179 | +print(sgd_dual_pi) |
| 180 | + |
| 181 | +############################################################################# |
| 182 | +# |
| 183 | +# Compare the results with the Sinkhorn algorithm |
| 184 | +# --------------------------------------------- |
| 185 | +# |
| 186 | +# Call the Sinkhorn algorithm from POT |
| 187 | + |
| 188 | +sinkhorn_pi = ot.sinkhorn(a, b, M, reg) |
| 189 | +print(sinkhorn_pi) |
| 190 | + |
| 191 | +############################################################################## |
| 192 | +# Plot SGD results |
| 193 | +# ----------------- |
| 194 | + |
| 195 | +pl.figure(4, figsize=(5, 5)) |
| 196 | +ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD') |
126 | 197 | pl.show() |
127 | 198 |
|
128 | 199 |
|
|
0 commit comments