55
66import numpy as np
77# import compiled emd
8- from .emd import emd_c
8+ from .emd import emd_c , emd2_c
9+ from ..utils import parmap
910import multiprocessing
1011
12+
13+
1114def emd (a , b , M ):
1215 """Solves the Earth Movers distance problem and returns the OT matrix
1316
@@ -145,41 +148,14 @@ def emd2(a, b, M,processes=multiprocessing.cpu_count()):
145148 b = np .ones ((M .shape [1 ], ), dtype = np .float64 )/ M .shape [1 ]
146149
147150 if len (b .shape )== 1 :
148- return np . sum ( emd_c ( a , b , M ) * M )
151+ return emd2_c ( a , b , M )
149152 else :
150153 nb = b .shape [1 ]
151- ls = [(a ,b [:,k ],M ) for k in range (nb )]
152- def f (l ):
153- return emd2 (l [0 ],l [1 ],l [2 ])
154- # run emd in multiprocessing
155- res = parmap (f , ls ,processes )
154+ #res=[emd2_c(a,b[:,i].copy(),M) for i in range(nb)]
155+ def f (b ):
156+ return emd2_c (a ,b ,M )
157+ res = parmap (f , [b [:,i ] for i in range (nb )],processes )
156158 return np .array (res )
157- # with Pool(processes) as p:
158- # res=p.map(f, ls)
159- # return np.array(res)
160159
161160
162- def fun (f , q_in , q_out ):
163- while True :
164- i , x = q_in .get ()
165- if i is None :
166- break
167- q_out .put ((i , f (x )))
168-
169- def parmap (f , X , nprocs = multiprocessing .cpu_count ()):
170- q_in = multiprocessing .Queue (1 )
171- q_out = multiprocessing .Queue ()
172-
173- proc = [multiprocessing .Process (target = fun , args = (f , q_in , q_out ))
174- for _ in range (nprocs )]
175- for p in proc :
176- p .daemon = True
177- p .start ()
178-
179- sent = [q_in .put ((i , x )) for i , x in enumerate (X )]
180- [q_in .put ((None , None )) for _ in range (nprocs )]
181- res = [q_out .get () for _ in range (len (sent ))]
182-
183- [p .join () for p in proc ]
184-
185- return [x for i , x in sorted (res )]
161+
0 commit comments