Skip to content

Commit 44b8462

Browse files
authored
Merge pull request #8 from aje/master
sinkhorn GPU implementation
2 parents 48ec27d + fb95157 commit 44b8462

File tree

6 files changed

+213
-18
lines changed

6 files changed

+213
-18
lines changed

ot/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@
2020

2121
__version__ = "0.2"
2222

23-
__all__ = ["emd", "emd2", "sinkhorn", "utils", 'datasets', 'bregman', 'lp',
23+
__all__ = ["emd", "emd2", "sinkhorn", "utils", 'datasets', 'bregman', 'lp',
2424
'plot', 'tic', 'toc', 'toq',
2525
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']

ot/bregman.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,11 @@ def sinkhorn(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=Fa
112112
while (err>stopThr and cpt<numItermax):
113113
uprev = u
114114
vprev = v
115-
v = np.divide(b,np.dot(K.T,u))
115+
KtransposeU = np.dot(K.T, u)
116+
v = np.divide(b, KtransposeU)
116117
u = 1./np.dot(Kp,v)
117-
if (np.any(np.dot(K.T,u)==0) or
118+
119+
if (np.any(KtransposeU==0) or
118120
np.any(np.isnan(u)) or np.any(np.isnan(v)) or
119121
np.any(np.isinf(u)) or np.any(np.isinf(v))):
120122
# we have reached the machine precision

ot/da.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -193,30 +193,30 @@ def sinkhorn_l1l2_gl(a,labels_a, b, M, reg, eta=0.1,numItermax = 10,numInnerIter
193193
194194
"""
195195
lstlab=np.unique(labels_a)
196-
196+
197197
def f(G):
198198
res=0
199199
for i in range(G.shape[1]):
200200
for lab in lstlab:
201201
temp=G[labels_a==lab,i]
202-
res+=np.linalg.norm(temp)
202+
res+=np.linalg.norm(temp)
203203
return res
204-
204+
205205
def df(G):
206-
W=np.zeros(G.shape)
206+
W=np.zeros(G.shape)
207207
for i in range(G.shape[1]):
208208
for lab in lstlab:
209209
temp=G[labels_a==lab,i]
210210
n=np.linalg.norm(temp)
211211
if n:
212-
W[labels_a==lab,i]=temp/n
213-
return W
212+
W[labels_a==lab,i]=temp/n
213+
return W
214+
214215

215-
216216
return gcg(a,b,M,reg,eta,f,df,G0=None,numItermax = numItermax,numInnerItermax=numInnerItermax, stopThr=stopInnerThr,verbose=verbose,log=log)
217-
218-
219-
217+
218+
219+
220220
def joint_OT_mapping_linear(xs,xt,mu=1,eta=0.001,bias=False,verbose=False,verbose2=False,numItermax = 100,numInnerItermax = 10,stopInnerThr=1e-6,stopThr=1e-5,log=False,**kwargs):
221221
"""Joint OT and linear mapping estimation as proposed in [8]
222222
@@ -606,7 +606,7 @@ def __init__(self,metric='sqeuclidean'):
606606
self.computed=False
607607

608608

609-
def fit(self,xs,xt,ws=None,wt=None):
609+
def fit(self,xs,xt,ws=None,wt=None,norm=None):
610610
""" Fit domain adaptation between samples is xs and xt (with optional weights)"""
611611
self.xs=xs
612612
self.xt=xt
@@ -620,6 +620,7 @@ def fit(self,xs,xt,ws=None,wt=None):
620620
self.wt=wt
621621

622622
self.M=dist(xs,xt,metric=self.metric)
623+
self.normalize()
623624
self.G=emd(ws,wt,self.M)
624625
self.computed=True
625626

@@ -684,12 +685,25 @@ def predict(self,x,direction=1):
684685
xf=self.interp(direction)# interp the source samples
685686
return xf[idx,:]+x-x0[idx,:] # aply the delta to the interpolation
686687

688+
def normalizeM(self, norm):
689+
"""
690+
It may help to normalize the cost matrix self.M if there are numerical
691+
errors during the sinkhorn based algorithms.
692+
"""
693+
if norm == "median":
694+
self.M /= float(np.median(self.M))
695+
elif norm == "max":
696+
self.M /= float(np.max(self.M))
697+
elif norm == "log":
698+
self.M = np.log(1 + self.M)
699+
elif norm == "loglog":
700+
self.M = np.log(1 + np.log(1 + self.M))
687701

688702

689703
class OTDA_sinkhorn(OTDA):
690704
"""Class for domain adaptation with optimal transport with entropic regularization"""
691705

692-
def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
706+
def fit(self,xs,xt,reg=1,ws=None,wt=None,norm=None,**kwargs):
693707
""" Fit regularized domain adaptation between samples is xs and xt (with optional weights)"""
694708
self.xs=xs
695709
self.xt=xt
@@ -703,6 +717,7 @@ def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
703717
self.wt=wt
704718

705719
self.M=dist(xs,xt,metric=self.metric)
720+
self.normalizeM(norm)
706721
self.G=sinkhorn(ws,wt,self.M,reg,**kwargs)
707722
self.computed=True
708723

@@ -711,7 +726,7 @@ class OTDA_lpl1(OTDA):
711726
"""Class for domain adaptation with optimal transport with entropic and group regularization"""
712727

713728

714-
def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
729+
def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,norm=None,**kwargs):
715730
""" Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit parameters"""
716731
self.xs=xs
717732
self.xt=xt
@@ -725,14 +740,15 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
725740
self.wt=wt
726741

727742
self.M=dist(xs,xt,metric=self.metric)
743+
self.normalizeM(norm)
728744
self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
729745
self.computed=True
730-
746+
731747
class OTDA_l1l2(OTDA):
732748
"""Class for domain adaptation with optimal transport with entropic and group lasso regularization"""
733749

734750

735-
def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
751+
def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,norm=None,**kwargs):
736752
""" Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit parameters"""
737753
self.xs=xs
738754
self.xt=xt
@@ -746,6 +762,7 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
746762
self.wt=wt
747763

748764
self.M=dist(xs,xt,metric=self.metric)
765+
self.normalizeM(norm)
749766
self.G=sinkhorn_l1l2_gl(ws,ys,wt,self.M,reg,eta,**kwargs)
750767
self.computed=True
751768

ot/gpu/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from . import bregman
4+
from . import da
5+
from .bregman import sinkhorn
6+
7+
__all__ = ["bregman", "da", "sinkhorn"]

ot/gpu/bregman.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Bregman projections for regularized OT with GPU
4+
"""
5+
6+
import numpy as np
7+
import cudamat
8+
9+
10+
def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
11+
log=False):
12+
# init data
13+
Nini = len(a)
14+
Nfin = len(b)
15+
16+
if log:
17+
log = {'err': []}
18+
19+
# we assume that no distances are null except those of the diagonal of
20+
# distances
21+
u = (np.ones(Nini)/Nini).reshape((Nini, 1))
22+
u_GPU = cudamat.CUDAMatrix(u)
23+
a_GPU = cudamat.CUDAMatrix(a.reshape((Nini, 1)))
24+
ones_GPU = cudamat.empty(u_GPU.shape).assign(1)
25+
v = (np.ones(Nfin)/Nfin).reshape((Nfin, 1))
26+
v_GPU = cudamat.CUDAMatrix(v)
27+
b_GPU = cudamat.CUDAMatrix(b.reshape((Nfin, 1)))
28+
29+
M_GPU.divide(-reg)
30+
31+
K_GPU = cudamat.exp(M_GPU)
32+
33+
ones_GPU.divide(a_GPU, target=a_GPU)
34+
Kp_GPU = cudamat.empty(K_GPU.shape)
35+
K_GPU.mult_by_col(a_GPU, target=Kp_GPU)
36+
37+
tmp_GPU = cudamat.empty(K_GPU.shape)
38+
39+
cpt = 0
40+
err = 1
41+
while (err > stopThr and cpt < numItermax):
42+
uprev_GPU = u_GPU.copy()
43+
vprev_GPU = v_GPU.copy()
44+
45+
KtransposeU_GPU = K_GPU.transpose().dot(u_GPU)
46+
b_GPU.divide(KtransposeU_GPU, target=v_GPU)
47+
ones_GPU.divide(Kp_GPU.dot(v_GPU), target=u_GPU)
48+
49+
if (np.any(KtransposeU_GPU.asarray() == 0) or
50+
not u_GPU.allfinite() or not v_GPU.allfinite()):
51+
# we have reached the machine precision
52+
# come back to previous solution and quit loop
53+
print('Warning: numerical errors at iteration', cpt)
54+
u_GPU = uprev_GPU.copy()
55+
v_GPU = vprev_GPU.copy()
56+
break
57+
if cpt % 10 == 0:
58+
# we can speed up the process by checking for the error only all
59+
# the 10th iterations
60+
K_GPU.mult_by_col(u_GPU, target=tmp_GPU)
61+
tmp_GPU.mult_by_row(v_GPU.transpose(), target=tmp_GPU)
62+
63+
bcopy_GPU = b_GPU.copy().transpose()
64+
bcopy_GPU.add_sums(tmp_GPU, axis=0, beta=-1)
65+
err = bcopy_GPU.euclid_norm()**2
66+
if log:
67+
log['err'].append(err)
68+
69+
if verbose:
70+
if cpt % 200 == 0:
71+
print('{:5s}|{:12s}'.format('It.', 'Err')+'\n'+'-'*19)
72+
print('{:5d}|{:8e}|'.format(cpt, err))
73+
cpt += 1
74+
if log:
75+
log['u'] = u_GPU.asarray()
76+
log['v'] = v_GPU.asarray()
77+
78+
K_GPU.mult_by_col(u_GPU, target=K_GPU)
79+
K_GPU.mult_by_row(v_GPU.transpose(), target=K_GPU)
80+
if log:
81+
return K_GPU.asarray(), log
82+
else:
83+
return K_GPU.asarray()

ot/gpu/da.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Domain adaptation with optimal transport and GPU
4+
"""
5+
6+
import numpy as np
7+
from ..utils import unif
8+
from ..da import OTDA
9+
from .bregman import sinkhorn
10+
import cudamat
11+
12+
13+
def pairwiseEuclideanGPU(a, b, returnAsGPU=False, squared=False):
14+
# a is shape (n, f) and b shape (m, f). Return matrix c of shape (n, m).
15+
# First compute in c_GPU the squared euclidean distance. And return its
16+
# square root. At each cell [i,j] of c, we want to have
17+
# sum{k in range(f)} ( (a[i,k] - b[j,k])^2 ). We know that
18+
# (a-b)^2 = a^2 -2ab +b^2. Thus we want to have in each cell of c:
19+
# sum{k in range(f)} ( a[i,k]^2 -2a[i,k]b[j,k] +b[j,k]^2).
20+
21+
a_GPU = cudamat.CUDAMatrix(a)
22+
b_GPU = cudamat.CUDAMatrix(b)
23+
24+
# Multiply a by b transpose to obtain in each cell [i,j] of c the
25+
# value sum{k in range(f)} ( a[i,k]b[j,k] )
26+
c_GPU = cudamat.dot(a_GPU, b_GPU.transpose())
27+
# multiply by -2 to have sum{k in range(f)} ( -2a[i,k]b[j,k] )
28+
c_GPU.mult(-2)
29+
30+
# Compute the vectors of the sum of squared elements.
31+
a_GPU = cudamat.pow(a_GPU, 2).sum(axis=1)
32+
b_GPU = cudamat.pow(b_GPU, 2).sum(axis=1)
33+
34+
# Add the vectors in each columns (respectivly rows) of c.
35+
# sum{k in range(f)} ( a[i,k]^2 -2a[i,k]b[j,k] )
36+
c_GPU.add_col_vec(a_GPU)
37+
# sum{k in range(f)} ( a[i,k]^2 -2a[i,k]b[j,k] +b[j,k]^2)
38+
c_GPU.add_row_vec(b_GPU.transpose())
39+
40+
if not squared:
41+
c_GPU = cudamat.sqrt(c_GPU)
42+
43+
if returnAsGPU:
44+
return c_GPU
45+
else:
46+
return c_GPU.asarray()
47+
48+
49+
class OTDA_GPU(OTDA):
50+
def normalizeM(self, norm):
51+
if norm == "median":
52+
self.M_GPU.divide(float(np.median(self.M_GPU.asarray())))
53+
elif norm == "max":
54+
self.M_GPU.divide(float(np.max(self.M_GPU.asarray())))
55+
elif norm == "log":
56+
self.M_GPU.add(1)
57+
cudamat.log(self.M_GPU)
58+
elif norm == "loglog":
59+
self.M_GPU.add(1)
60+
cudamat.log(self.M_GPU)
61+
self.M_GPU.add(1)
62+
cudamat.log(self.M_GPU)
63+
64+
65+
class OTDA_sinkhorn(OTDA_GPU):
66+
def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
67+
cudamat.init()
68+
xs = np.asarray(xs, dtype=np.float64)
69+
xt = np.asarray(xt, dtype=np.float64)
70+
71+
self.xs = xs
72+
self.xt = xt
73+
74+
if wt is None:
75+
wt = unif(xt.shape[0])
76+
if ws is None:
77+
ws = unif(xs.shape[0])
78+
79+
self.ws = ws
80+
self.wt = wt
81+
82+
self.M_GPU = pairwiseEuclideanGPU(xs, xt, returnAsGPU=True,
83+
squared=True)
84+
self.normalizeM(norm)
85+
self.G = sinkhorn(ws, wt, self.M_GPU, reg, **kwargs)
86+
self.computed = True

0 commit comments

Comments
 (0)