66import numpy as np
77# import compiled emd
88from .emd import emd_c
9-
9+ import multiprocessing
1010
1111def emd (a , b , M ):
1212 """Solves the Earth Movers distance problem and returns the OT matrix
@@ -70,9 +70,114 @@ def emd(a, b, M):
7070 b = np .asarray (b , dtype = np .float64 )
7171 M = np .asarray (M , dtype = np .float64 )
7272
73+ # if empty array given then use unifor distributions
7374 if len (a ) == 0 :
7475 a = np .ones ((M .shape [0 ], ), dtype = np .float64 )/ M .shape [0 ]
7576 if len (b ) == 0 :
7677 b = np .ones ((M .shape [1 ], ), dtype = np .float64 )/ M .shape [1 ]
7778
7879 return emd_c (a , b , M )
80+
81+ def emd2 (a , b , M ,processes = None ):
82+ """Solves the Earth Movers distance problem and returns the loss
83+
84+ .. math::
85+ \gamma = arg\min_\gamma <\gamma,M>_F
86+
87+ s.t. \gamma 1 = a
88+ \gamma^T 1= b
89+ \gamma\geq 0
90+ where :
91+
92+ - M is the metric cost matrix
93+ - a and b are the sample weights
94+
95+ Uses the algorithm proposed in [1]_
96+
97+ Parameters
98+ ----------
99+ a : (ns,) ndarray, float64
100+ Source histogram (uniform weigth if empty list)
101+ b : (nt,) ndarray, float64
102+ Target histogram (uniform weigth if empty list)
103+ M : (ns,nt) ndarray, float64
104+ loss matrix
105+
106+ Returns
107+ -------
108+ gamma: (ns x nt) ndarray
109+ Optimal transportation matrix for the given parameters
110+
111+
112+ Examples
113+ --------
114+
115+ Simple example with obvious solution. The function emd accepts lists and
116+ perform automatic conversion to numpy arrays
117+ >>> import ot
118+ >>> a=[.5,.5]
119+ >>> b=[.5,.5]
120+ >>> M=[[0.,1.],[1.,0.]]
121+ >>> ot.emd2(a,b,M)
122+ 0.0
123+
124+ References
125+ ----------
126+
127+ .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
128+ (2011, December). Displacement interpolation using Lagrangian mass
129+ transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
130+ 158). ACM.
131+
132+ See Also
133+ --------
134+ ot.bregman.sinkhorn : Entropic regularized OT
135+ ot.optim.cg : General regularized OT"""
136+
137+ a = np .asarray (a , dtype = np .float64 )
138+ b = np .asarray (b , dtype = np .float64 )
139+ M = np .asarray (M , dtype = np .float64 )
140+
141+ # if empty array given then use unifor distributions
142+ if len (a ) == 0 :
143+ a = np .ones ((M .shape [0 ], ), dtype = np .float64 )/ M .shape [0 ]
144+ if len (b ) == 0 :
145+ b = np .ones ((M .shape [1 ], ), dtype = np .float64 )/ M .shape [1 ]
146+
147+ if len (b .shape )== 1 :
148+ return np .sum (emd_c (a , b , M )* M )
149+ else :
150+ nb = b .shape [1 ]
151+ ls = [(a ,b [:,k ],M ) for k in range (nb )]
152+ # run emd in multiprocessing
153+ res = parmap (emd2 , ls ,processes )
154+ np .array (res )
155+ # with Pool(processes) as p:
156+ # res=p.map(f, ls)
157+ # return np.array(res)
158+
159+
160+ def fun (f , q_in , q_out ):
161+ while True :
162+ i , x = q_in .get ()
163+ if i is None :
164+ break
165+ q_out .put ((i , f (x )))
166+
167+ def parmap (f , X , nprocs ):
168+ q_in = multiprocessing .Queue (1 )
169+ q_out = multiprocessing .Queue ()
170+
171+ proc = [multiprocessing .Process (target = fun , args = (f , q_in , q_out ))
172+ for _ in range (nprocs )]
173+ for p in proc :
174+ p .daemon = True
175+ p .start ()
176+
177+ sent = [q_in .put ((i , x )) for i , x in enumerate (X )]
178+ [q_in .put ((None , None )) for _ in range (nprocs )]
179+ res = [q_out .get () for _ in range (len (sent ))]
180+
181+ [p .join () for p in proc ]
182+
183+ return [x for i , x in sorted (res )]
0 commit comments