Skip to content

Commit 6484c9e

Browse files
committed
Tests + contributions
1 parent 11c2c26 commit 6484c9e

File tree

4 files changed

+89
-4
lines changed

4 files changed

+89
-4
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ The contributors to this library are:
164164
* Erwan Vautier (Gromov-Wasserstein)
165165
* [Kilian Fatras](https://kilianfatras.github.io/)
166166
* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
167+
* [Vayer Titouan](https://tvayer.github.io/)
167168

168169
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
169170

ot/gromov.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,10 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
926926
"Optimal Transport for structured data with application on graphs"
927927
International Conference on Machine Learning (ICML). 2019.
928928
"""
929+
930+
class UndefinedParameter(Exception):
931+
pass
932+
929933
S = len(Cs)
930934
d = Ys[0].shape[1] #dimension on the node features
931935
if p is None:
@@ -938,7 +942,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
938942

939943
if fixed_structure:
940944
if init_C is None:
941-
C=Cs[0]
945+
raise UndefinedParameter('If C is fixed it must be initialized')
942946
else:
943947
C=init_C
944948
else:
@@ -950,7 +954,7 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
950954

951955
if fixed_features:
952956
if init_X is None:
953-
X=Ys[0]
957+
raise UndefinedParameter('If X is fixed it must be initialized')
954958
else :
955959
X= init_X
956960
else:
@@ -1004,13 +1008,13 @@ def fgw_barycenters(N,Ys,Cs,ps,lambdas,alpha,fixed_structure=False,fixed_feature
10041008
# Cs is ns,ns
10051009
# p is N,1
10061010
# ps is ns,1
1007-
1011+
10081012
T = [fused_gromov_wasserstein((1-alpha)*Ms[s],C,Cs[s],p,ps[s],loss_fun,alpha,numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
10091013

10101014
# T is N,ns
10111015

10121016
log_['Ts_iter'].append(T)
1013-
err_feature = np.linalg.norm(X - Xprev.reshape(d,N))
1017+
err_feature = np.linalg.norm(X - Xprev.reshape(N,d))
10141018
err_structure = np.linalg.norm(C - Cprev)
10151019

10161020
if log:

test/test_gromov.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,78 @@ def test_gromov_entropic_barycenter():
143143
'kl_loss', 2e-3,
144144
max_iter=100, tol=1e-3)
145145
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
146+
147+
def test_fgw():
148+
n_samples = 50 # nb samples
149+
150+
mu_s = np.array([0, 0])
151+
cov_s = np.array([[1, 0], [0, 1]])
152+
153+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
154+
155+
xt = xs[::-1].copy()
156+
157+
ys = np.random.randn(xs.shape[0],2)
158+
yt= ys[::-1].copy()
159+
160+
p = ot.unif(n_samples)
161+
q = ot.unif(n_samples)
162+
163+
C1 = ot.dist(xs, xs)
164+
C2 = ot.dist(xt, xt)
165+
166+
C1 /= C1.max()
167+
C2 /= C2.max()
168+
169+
M=ot.dist(ys,yt)
170+
M/=M.max()
171+
172+
G = ot.gromov.fused_gromov_wasserstein(M,C1, C2, p, q, 'square_loss',alpha=0.5)
173+
174+
# check constratints
175+
np.testing.assert_allclose(
176+
p, G.sum(1), atol=1e-04) # cf convergence fgw
177+
np.testing.assert_allclose(
178+
q, G.sum(0), atol=1e-04) # cf convergence fgw
179+
180+
181+
def test_fgw_barycenter():
182+
183+
ns = 50
184+
nt = 60
185+
186+
Xs, ys = ot.datasets.make_data_classif('3gauss', ns)
187+
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt)
188+
189+
ys = np.random.randn(Xs.shape[0],2)
190+
yt= np.random.randn(Xt.shape[0],2)
191+
192+
C1 = ot.dist(Xs)
193+
C2 = ot.dist(Xt)
194+
195+
n_samples = 3
196+
X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5,
197+
fixed_structure=False,fixed_features=False,
198+
p=ot.unif(n_samples),loss_fun='square_loss',
199+
max_iter=100, tol=1e-3)
200+
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
201+
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
202+
203+
xalea = np.random.randn(n_samples, 2)
204+
init_C = ot.dist(xalea, xalea)
205+
206+
X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],ps=[ot.unif(ns), ot.unif(nt)],lambdas=[.5, .5],alpha=0.5,
207+
fixed_structure=True,init_C=init_C,fixed_features=False,
208+
p=ot.unif(n_samples),loss_fun='square_loss',
209+
max_iter=100, tol=1e-3)
210+
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
211+
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
212+
213+
init_X=np.random.randn(n_samples,ys.shape[1])
214+
215+
X,C,log = ot.gromov.fgw_barycenters(n_samples,[ys,yt] ,[C1, C2],[ot.unif(ns), ot.unif(nt)],[.5, .5],0.5,
216+
fixed_structure=False,fixed_features=True, init_X=init_X,
217+
p=ot.unif(n_samples),loss_fun='square_loss',
218+
max_iter=100, tol=1e-3)
219+
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
220+
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))

test/test_optim.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,8 @@ def df(G):
6565

6666
np.testing.assert_allclose(a, G.sum(1), atol=1e-05)
6767
np.testing.assert_allclose(b, G.sum(0), atol=1e-05)
68+
69+
def test_solve_1d_linesearch_quad_funct():
70+
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1,-1,0),0.5)
71+
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,5,0),0)
72+
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1,0.5,0),1)

0 commit comments

Comments
 (0)