|
10 | 10 | import numpy as np |
11 | 11 | import scipy as sp |
12 | 12 | import scipy.sparse as sps |
| 13 | +import ot |
13 | 14 |
|
14 | 15 | try: |
15 | 16 | import cvxopt |
@@ -144,3 +145,82 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po |
144 | 145 | return b, sol |
145 | 146 | else: |
146 | 147 | return b |
| 148 | + |
| 149 | + |
| 150 | + |
| 151 | + |
def free_support_barycenter(data_positions, data_weights, X_init, b_init, lamda, numItermax=100, stopThr=1e-5, verbose=False, log=False, **kwargs):
    """
    Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)

    The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
    This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:
    - we do not optimize over the weights
    - we do not do line search for the locations updates, i.e. we use theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.

    Each iteration replaces every barycenter atom by the lamda-weighted
    average of its barycentric projections onto the input measures, the
    projections being derived from the transport plans returned by ``ot.emd``.

    Parameters
    ----------
    data_positions : list of (k_i,d) np.ndarray
        The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
    data_weights : list of (k_i,) np.ndarray
        Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure

    X_init : (k,d) np.ndarray
        Initialization of the support locations (on k atoms) of the barycenter
    b_init : (k,) np.ndarray
        Weights of the barycenter (non-negatives, sum to 1). They are kept
        fixed during the optimization and must be strictly positive because
        the update divides by them.
    lamda : (N,) np.ndarray or None
        Barycentric coefficients, one per input measure (non-negatives, sum
        to 1). ``None`` selects uniform coefficients ``1 / N``.

    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the squared norm of the displacement of the support
        between two successive iterations (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        Accepted for API compatibility; currently unused (nothing is recorded
        and only X is returned).

    Returns
    -------
    X : (k,d) np.ndarray
        Support locations (on k atoms) of the barycenter

    Raises
    ------
    ValueError
        If no input measure is given, or if the lengths of data_positions,
        data_weights and lamda do not agree.

    References
    ----------

    .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.

    .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.

    """

    N = len(data_positions)
    if N == 0:
        raise ValueError("data_positions must contain at least one measure")
    if len(data_weights) != N:
        raise ValueError(
            "data_positions and data_weights must have the same length "
            "(got %d and %d)" % (N, len(data_weights)))

    if lamda is None:
        # Uniform coefficients reproduce a plain average over the measures.
        lamda = np.full(N, 1. / N)
    else:
        lamda = np.asarray(lamda, dtype=float).ravel()
        if lamda.size != N:
            raise ValueError(
                "lamda must hold one coefficient per input measure "
                "(expected %d, got %d)" % (N, lamda.size))

    X = np.asarray(X_init)
    b = np.asarray(b_init, dtype=float).ravel()

    d = X.shape[1]
    k = b.size

    # Loop-invariant: column vector used to turn each transport plan row
    # into a convex combination (barycentric projection).
    inv_b = np.reshape(1. / b, (-1, 1))

    iter_count = 0
    # Infinite sentinel so that the first iteration always runs, whatever
    # the value of stopThr.
    displacement_square_norm = np.inf

    while displacement_square_norm > stopThr and iter_count < numItermax:

        X_new = np.zeros((k, d))

        for positions_i, weights_i, lamda_i in zip(data_positions, data_weights, lamda):
            M_i = ot.dist(X, positions_i)
            T_i = ot.emd(b, weights_i, M_i)
            # Barycentric projection of measure i, weighted by its coefficient.
            X_new += lamda_i * inv_b * np.matmul(T_i, positions_i)

        displacement_square_norm = np.sum(np.square(X_new - X))
        X = X_new

        iter_count += 1

        if verbose:
            print('iteration %d, displacement_square_norm=%f'
                  % (iter_count, displacement_square_norm))

    return X
| 226 | + |
0 commit comments