1+ # -*- coding: utf-8 -*-
2+ """
3+ LP solvers for optimal transport using cvxopt
4+ """
5+
6+ # Author: Remi Flamary <remi.flamary@unice.fr>
7+ #
8+ # License: MIT License
9+
10+ import numpy as np
11+ import scipy as sp
12+ import scipy .sparse as sps
13+
14+ try :
15+ import cvxopt
16+ from cvxopt import solvers , matrix , sparse , spmatrix
17+ except ImportError :
18+ cvxopt = False
19+
20+ def scipy_sparse_to_spmatrix (A ):
21+ """Efficient conversion from scipy sparse matrix to cvxopt sparse matrix"""
22+ coo = A .tocoo ()
23+ SP = spmatrix (coo .data .tolist (), coo .row .tolist (), coo .col .tolist (), size = A .shape )
24+ return SP
25+
26+ def barycenter (A , M , weights = None , verbose = False , log = False ,solver = 'interior-point' ):
27+ """Compute the entropic regularized wasserstein barycenter of distributions A
28+
29+ The function solves the following optimization problem [16]:
30+
31+ .. math::
32+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i)
33+
34+ where :
35+
36+ - :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn)
37+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
38+
39+ The linear program is solved using the default cvxopt solver if installed.
40+ If cvxopt is not installed it uses the lp solver from scipy.optimize.
41+
42+ Parameters
43+ ----------
44+ A : np.ndarray (d,n)
45+ n training distributions of size d
46+ M : np.ndarray (d,d)
47+ loss matrix for OT
48+ reg : float
49+ Regularization term >0
50+ weights : np.ndarray (n,)
51+ Weights of each histogram i_i on the simplex
52+ verbose : bool, optional
53+ Print information along iterations
54+ log : bool, optional
55+ record log if True
56+ solver : string, optional
57+ the solver used, default 'interior-point' use the lp solver from
58+ scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt.
59+
60+ Returns
61+ -------
62+ a : (d,) ndarray
63+ Wasserstein barycenter
64+ log : dict
65+ log dictionary return only if log==True in parameters
66+
67+
68+ References
69+ ----------
70+
71+ .. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924.
72+
73+
74+
75+ """
76+
77+ if weights is None :
78+ weights = np .ones (A .shape [1 ]) / A .shape [1 ]
79+ else :
80+ assert (len (weights ) == A .shape [1 ])
81+
82+ n_distributions = A .shape [1 ]
83+ n = A .shape [0 ]
84+
85+ n2 = n * n
86+ c = np .zeros ((0 ))
87+ b_eq1 = np .zeros ((0 ))
88+ for i in range (n_distributions ):
89+ c = np .concatenate ((c ,M .ravel ()* weights [i ]))
90+ b_eq1 = np .concatenate ((b_eq1 ,A [:,i ]))
91+ c = np .concatenate ((c ,np .zeros (n )))
92+
93+ lst_idiag1 = [sps .kron (sps .eye (n ),np .ones ((1 ,n ))) for i in range (n_distributions )]
94+ # row constraints
95+ A_eq1 = sps .hstack ((sps .block_diag (lst_idiag1 ),sps .coo_matrix ((n_distributions * n ,n ))))
96+
97+ # columns constraints
98+ lst_idiag2 = []
99+ lst_eye = []
100+ for i in range (n_distributions ):
101+ if i == 0 :
102+ lst_idiag2 .append (sps .kron (np .ones ((1 ,n )),sps .eye (n )))
103+ lst_eye .append (- sps .eye (n ))
104+ else :
105+ lst_idiag2 .append (sps .kron (np .ones ((1 ,n )),sps .eye (n - 1 ,n )))
106+ lst_eye .append (- sps .eye (n - 1 ,n ))
107+
108+ A_eq2 = sps .hstack ((sps .block_diag (lst_idiag2 ),sps .vstack (lst_eye )))
109+ b_eq2 = np .zeros ((A_eq2 .shape [0 ]))
110+
111+ # full problem
112+ A_eq = sps .vstack ((A_eq1 ,A_eq2 ))
113+ b_eq = np .concatenate ((b_eq1 ,b_eq2 ))
114+
115+ if not cvxopt or solver in ['interior-point' ]: # cvxopt not installed or simplex/interior point
116+
117+ if solver is None :
118+ solver = 'interior-point'
119+
120+ options = {'sparse' :True ,'disp' : verbose }
121+ sol = sp .optimize .linprog (c ,A_eq = A_eq ,b_eq = b_eq ,method = solver ,options = options )
122+ x = sol .x
123+ b = x [- n :]
124+
125+ else :
126+
127+ h = np .zeros ((n_distributions * n2 + n ))
128+ G = - sps .eye (n_distributions * n2 + n )
129+
130+ sol = solvers .lp (matrix (c ),scipy_sparse_to_spmatrix (G ),matrix (h ),A = scipy_sparse_to_spmatrix (A_eq ),b = matrix (b_eq ),solver = solver )
131+
132+ x = np .array (sol ['x' ])
133+ b = x [- n :].ravel ()
134+
135+ if log :
136+ return b , sol
137+ else :
138+ return b
0 commit comments