
Commit fb95157

Author: Leo Gautheron (committed)

more changes from feedback
in addition, add the possibility to normalize the cost matrix through the function fit

1 parent 9c5cc82 commit fb95157
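A minimal usage sketch of the new norm option on the CPU side (not part of the commit; the synthetic xs/xt arrays and the reg value are illustrative assumptions, while OTDA_sinkhorn, fit, norm and interp come from ot/da.py below):

import numpy as np
import ot

# toy source / target samples (illustrative only)
np.random.seed(0)
xs = np.random.randn(20, 2)
xt = np.random.randn(30, 2) + 1.0

da = ot.da.OTDA_sinkhorn()
# norm may be None (default), "median", "max", "log" or "loglog";
# it rescales the cost matrix M before the sinkhorn solver is run
da.fit(xs, xt, reg=1e-1, norm="median")
xst = da.interp()   # transported source samples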

File tree

4 files changed (+48 / -24 lines)


ot/da.py

Lines changed: 22 additions & 4 deletions
@@ -606,7 +606,7 @@ def __init__(self,metric='sqeuclidean'):
         self.computed=False
 
 
-    def fit(self,xs,xt,ws=None,wt=None):
+    def fit(self,xs,xt,ws=None,wt=None,norm=None):
         """ Fit domain adaptation between samples is xs and xt (with optional weights)"""
         self.xs=xs
         self.xt=xt
@@ -620,6 +620,7 @@ def fit(self,xs,xt,ws=None,wt=None):
         self.wt=wt
 
         self.M=dist(xs,xt,metric=self.metric)
+        self.normalizeM(norm)
         self.G=emd(ws,wt,self.M)
         self.computed=True
 
@@ -684,11 +685,25 @@ def predict(self,x,direction=1):
         xf=self.interp(direction)# interp the source samples
         return xf[idx,:]+x-x0[idx,:] # aply the delta to the interpolation
 
+    def normalizeM(self, norm):
+        """
+        It may help to normalize the cost matrix self.M if there are numerical
+        errors during the sinkhorn based algorithms.
+        """
+        if norm == "median":
+            self.M /= float(np.median(self.M))
+        elif norm == "max":
+            self.M /= float(np.max(self.M))
+        elif norm == "log":
+            self.M = np.log(1 + self.M)
+        elif norm == "loglog":
+            self.M = np.log(1 + np.log(1 + self.M))
+
 
 class OTDA_sinkhorn(OTDA):
     """Class for domain adaptation with optimal transport with entropic regularization"""
 
-    def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
+    def fit(self,xs,xt,reg=1,ws=None,wt=None,norm=None,**kwargs):
         """ Fit regularized domain adaptation between samples is xs and xt (with optional weights)"""
         self.xs=xs
         self.xt=xt
@@ -702,6 +717,7 @@ def fit(self,xs,xt,reg=1,ws=None,wt=None,**kwargs):
         self.wt=wt
 
         self.M=dist(xs,xt,metric=self.metric)
+        self.normalizeM(norm)
         self.G=sinkhorn(ws,wt,self.M,reg,**kwargs)
         self.computed=True
 
@@ -710,7 +726,7 @@ class OTDA_lpl1(OTDA):
     """Class for domain adaptation with optimal transport with entropic and group regularization"""
 
 
-    def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
+    def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,norm=None,**kwargs):
         """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit parameters"""
         self.xs=xs
         self.xt=xt
@@ -724,14 +740,15 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
         self.wt=wt
 
         self.M=dist(xs,xt,metric=self.metric)
+        self.normalizeM(norm)
         self.G=sinkhorn_lpl1_mm(ws,ys,wt,self.M,reg,eta,**kwargs)
         self.computed=True
 
 class OTDA_l1l2(OTDA):
     """Class for domain adaptation with optimal transport with entropic and group lasso regularization"""
 
 
-    def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
+    def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,norm=None,**kwargs):
         """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit parameters"""
         self.xs=xs
         self.xt=xt
@@ -745,6 +762,7 @@ def fit(self,xs,ys,xt,reg=1,eta=1,ws=None,wt=None,**kwargs):
         self.wt=wt
 
         self.M=dist(xs,xt,metric=self.metric)
+        self.normalizeM(norm)
         self.G=sinkhorn_l1l2_gl(ws,ys,wt,self.M,reg,eta,**kwargs)
         self.computed=True

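To see concretely what the new normalizeM does to a cost matrix, here is a small self-contained NumPy sketch that mirrors its four branches (the helper name normalize_cost and the example matrix are illustrative, not part of the commit):

import numpy as np

def normalize_cost(M, norm=None):
    # mirrors OTDA.normalizeM: rescale the cost matrix so that the
    # exp(-M/reg) terms inside sinkhorn are less prone to under/overflow
    if norm == "median":
        return M / float(np.median(M))
    elif norm == "max":
        return M / float(np.max(M))
    elif norm == "log":
        return np.log(1 + M)
    elif norm == "loglog":
        return np.log(1 + np.log(1 + M))
    return M

M = np.array([[1.0, 4.0], [9.0, 16.0]])
print(normalize_cost(M, "max"))     # entries scaled into [0, 1]
print(normalize_cost(M, "loglog"))  # costs damped twice by log(1 + .)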
ot/gpu/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -2,5 +2,6 @@
 
 from . import bregman
 from . import da
+from .bregman import sinkhorn
 
-__all__ = ["bregman", "da"]
+__all__ = ["bregman", "da", "sinkhorn"]

ot/gpu/bregman.py

Lines changed: 2 additions & 2 deletions
@@ -4,10 +4,11 @@
 """
 
 import numpy as np
+import cudamat
 
 
 def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
-             log=False, cudamat=None):
+             log=False):
     # init data
     Nini = len(a)
     Nfin = len(b)
@@ -74,7 +75,6 @@ def sinkhorn(a, b, M_GPU, reg, numItermax=1000, stopThr=1e-9, verbose=False,
         log['u'] = u_GPU.asarray()
         log['v'] = v_GPU.asarray()
 
-    # print('err=',err,' cpt=',cpt)
     K_GPU.mult_by_col(u_GPU, target=K_GPU)
     K_GPU.mult_by_row(v_GPU.transpose(), target=K_GPU)
     if log:

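With the module-level import of cudamat above and the re-export added in ot/gpu/__init__.py, the GPU solver can be called as ot.gpu.sinkhorn without passing a cudamat handle. A rough sketch, assuming an NVIDIA GPU and the cudamat package are available (the toy sizes and reg value are illustrative):

import numpy as np
import cudamat
import ot.gpu

cudamat.init()
n = 5
a = np.ones(n) / n   # uniform source weights
b = np.ones(n) / n   # uniform target weights
M_GPU = cudamat.CUDAMatrix(np.random.rand(n, n))  # cost matrix on the GPU
# same function as ot.gpu.bregman.sinkhorn; the cudamat argument is gone
G = ot.gpu.sinkhorn(a, b, M_GPU, 1e-1)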
ot/gpu/da.py

Lines changed: 22 additions & 17 deletions
@@ -7,9 +7,10 @@
 from ..utils import unif
 from ..da import OTDA
 from .bregman import sinkhorn
+import cudamat
 
 
-def pairwiseEuclideanGPU(a, b, returnAsGPU=False, squared=False, cudamat=None):
+def pairwiseEuclideanGPU(a, b, returnAsGPU=False, squared=False):
     # a is shape (n, f) and b shape (m, f). Return matrix c of shape (n, m).
     # First compute in c_GPU the squared euclidean distance. And return its
     # square root. At each cell [i,j] of c, we want to have
@@ -45,9 +46,24 @@ def pairwiseEuclideanGPU(a, b, returnAsGPU=False, squared=False, cudamat=None):
     return c_GPU.asarray()
 
 
-class OTDA_sinkhorn_GPU(OTDA):
+class OTDA_GPU(OTDA):
+    def normalizeM(self, norm):
+        if norm == "median":
+            self.M_GPU.divide(float(np.median(self.M_GPU.asarray())))
+        elif norm == "max":
+            self.M_GPU.divide(float(np.max(self.M_GPU.asarray())))
+        elif norm == "log":
+            self.M_GPU.add(1)
+            cudamat.log(self.M_GPU)
+        elif norm == "loglog":
+            self.M_GPU.add(1)
+            cudamat.log(self.M_GPU)
+            self.M_GPU.add(1)
+            cudamat.log(self.M_GPU)
+
+
+class OTDA_sinkhorn(OTDA_GPU):
     def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
-        import cudamat
         cudamat.init()
         xs = np.asarray(xs, dtype=np.float64)
         xt = np.asarray(xt, dtype=np.float64)
@@ -64,18 +80,7 @@ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
         self.wt = wt
 
         self.M_GPU = pairwiseEuclideanGPU(xs, xt, returnAsGPU=True,
-                                          squared=True, cudamat=cudamat)
-
-        if norm == "median":
-            self.M_GPU.divide(float(np.median(self.M_GPU.asarray())))
-        elif norm == "max":
-            self.M_GPU.divide(float(np.max(self.M_GPU.asarray())))
-        elif norm == "log":
-            M = np.log(1 + self.M_GPU.asarray())
-            self.M_GPU = cudamat.CUDAMatrix(M)
-        elif norm == "loglog":
-            M = np.log(1 + np.log(1 + self.M_GPU.asarray()))
-            self.M_GPU = cudamat.CUDAMatrix(M)
-
-        self.G = sinkhorn(ws, wt, self.M_GPU, reg, cudamat=cudamat, **kwargs)
+                                          squared=True)
+        self.normalizeM(norm)
+        self.G = sinkhorn(ws, wt, self.M_GPU, reg, **kwargs)
         self.computed = True

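A corresponding sketch for the renamed GPU domain-adaptation class (again assuming a CUDA device plus the cudamat package; the sample data and reg value are illustrative):

import numpy as np
from ot.gpu.da import OTDA_sinkhorn

xs = np.random.randn(200, 3)         # source samples
xt = np.random.randn(300, 3) + 1.0   # target samples

da = OTDA_sinkhorn()
# the cost matrix is built and normalized directly on the GPU via
# OTDA_GPU.normalizeM before the sinkhorn solver is called
da.fit(xs, xt, reg=1e-1, norm="log")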