Skip to content

Commit 9421ddd

Browse files
committed
Doc+armijo
1 parent 94d2fe5 commit 9421ddd

File tree

2 files changed

+31
-30
lines changed

2 files changed

+31
-30
lines changed

ot/gromov.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
3333
* C2 : Metric cost matrix in the target space
3434
* T : A coupling between those two spaces
3535
36-
The square-loss function L(a,b)=(1/2)*|a-b|^2 is read as :
36+
The square-loss function L(a,b)=|a-b|^2 is read as :
3737
L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
38-
* f1(a)=(a^2)/2
39-
* f2(b)=(b^2)/2
38+
* f1(a)=(a^2)
39+
* f2(b)=(b^2)
4040
* h1(a)=a
41-
* h2(b)=b
41+
* h2(b)=2*b
4242
4343
The kl-loss function L(a,b)=a*log(a/b)-a+b is read as :
4444
L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
@@ -269,7 +269,7 @@ def update_kl_loss(p, lambdas, T, Cs):
269269
return np.exp(np.divide(tmpsum, ppt))
270270

271271

272-
def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
272+
def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
273273
"""
274274
Returns the gromov-wasserstein transport between (C1,p) and (C2,q)
275275
@@ -307,8 +307,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs)
307307
Print information along iterations
308308
log : bool, optional
309309
record log if True
310-
amijo : bool, optional
311-
If True the steps of the line-search is found via an amijo research. Else closed form is used.
310+
armijo : bool, optional
311+
If True the steps of the line-search are found via an armijo search. Else closed form is used.
312312
If there are convergence issues use False.
313313
**kwargs : dict
314314
parameters can be directly passed to the ot.optim.cg solver
@@ -344,14 +344,14 @@ def df(G):
344344
return gwggrad(constC, hC1, hC2, G)
345345

346346
if log:
347-
res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
347+
res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
348348
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
349349
return res, log
350350
else:
351-
return cg(p, q, 0, 1, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
351+
return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
352352

353353

354-
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, amijo=False, **kwargs):
354+
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, **kwargs):
355355
"""
356356
Computes the FGW distance between two graphs see [3]
357357
.. math::
@@ -363,6 +363,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
363363
- M is the (ns,nt) metric cost matrix
364364
- :math:`f` is the regularization term (and df is its gradient)
365365
- a and b are source and target weights (sum to 1)
366+
- L is a loss function to account for the misfit between the similarity matrices
366367
The algorithm used for solving the problem is conditional gradient as discussed in [1]_
367368
Parameters
368369
----------
@@ -386,8 +387,8 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
386387
Print information along iterations
387388
log : bool, optional
388389
record log if True
389-
amijo : bool, optional
390-
If True the steps of the line-search is found via an amijo research. Else closed form is used.
390+
armijo : bool, optional
391+
If True the steps of the line-search are found via an armijo search. Else closed form is used.
391392
If there are convergence issues use False.
392393
**kwargs : dict
393394
parameters can be directly passed to the ot.optim.cg solver
@@ -415,10 +416,10 @@ def f(G):
415416
def df(G):
416417
return gwggrad(constC, hC1, hC2, G)
417418

418-
return cg(p, q, M, alpha, f, df, G0, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
419+
return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
419420

420421

421-
def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs):
422+
def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
422423
"""
423424
Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
424425
@@ -456,8 +457,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, amijo=False, **kwargs
456457
Print information along iterations
457458
log : bool, optional
458459
record log if True
459-
amijo : bool, optional
460-
If True the steps of the line-search is found via an amijo research. Else closed form is used.
460+
armijo : bool, optional
461+
If True the steps of the line-search are found via an armijo search. Else closed form is used.
461462
If there are convergence issues use False.
462463
Returns
463464
-------
@@ -487,7 +488,7 @@ def f(G):
487488

488489
def df(G):
489490
return gwggrad(constC, hC1, hC2, G)
490-
res, log = cg(p, q, 0, 1, f, df, G0, log=True, amijo=amijo, C1=C1, C2=C2, constC=constC, **kwargs)
491+
res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
491492
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
492493
log['T'] = res
493494
if log:
@@ -890,7 +891,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
890891
p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
891892
verbose=False, log=True, init_C=None, init_X=None):
892893
"""
893-
Compute the fgw barycenter as presented eq (5) in [3].
894+
Compute the fgw barycenter as presented eq (5) in [24].
894895
----------
895896
N : integer
896897
Desired number of samples of the target barycenter
@@ -1065,7 +1066,7 @@ def update_sructure_matrix(p, lambdas, T, Cs):
10651066

10661067
def update_feature_matrix(lambdas, Ys, Ts, p):
10671068
"""
1068-
Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [3]
1069+
Updates the feature with respect to the S Ts couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [24]
10691070
calculated at each iteration
10701071
Parameters
10711072
----------

ot/optim.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,13 @@ def phi(alpha1):
7373

7474

7575
def do_linesearch(cost, G, deltaG, Mi, f_val,
76-
amijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
76+
armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
7777
"""
7878
Solve the linesearch in the FW iterations
7979
Parameters
8080
----------
8181
cost : method
82-
The FGW cost
82+
Cost in the FW for the linesearch
8383
G : ndarray, shape(ns,nt)
8484
The transport map at a given iteration of the FW
8585
deltaG : ndarray (ns,nt)
@@ -88,21 +88,21 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
8888
Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
8989
f_val : float
9090
Value of the cost at G
91-
amijo : bool, optionnal
92-
If True the steps of the line-search is found via an amijo research. Else closed form is used.
91+
armijo : bool, optional
92+
If True the steps of the line-search are found via an armijo search. Else closed form is used.
9393
If there are convergence issues use False.
9494
C1 : ndarray (ns,ns), optional
95-
Structure matrix in the source domain. Only used when amijo=False
95+
Structure matrix in the source domain. Only used when armijo=False
9696
C2 : ndarray (nt,nt), optional
97-
Structure matrix in the target domain. Only used when amijo=False
97+
Structure matrix in the target domain. Only used when armijo=False
9898
reg : float, optional
99-
Regularization parameter. Corresponds to the alpha parameter of FGW. Only used when amijo=False
99+
Regularization parameter. Only used when armijo=False
100100
Gc : ndarray (ns,nt)
101-
Optimal map found by linearization in the FW algorithm. Only used when amijo=False
101+
Optimal map found by linearization in the FW algorithm. Only used when armijo=False
102102
constC : ndarray (ns,nt)
103-
Constant for the gromov cost. See [3]. Only used when amijo=False
103+
Constant for the gromov cost. See [24]. Only used when armijo=False
104104
M : ndarray (ns,nt), optional
105-
Cost matrix between the features. Only used when amijo=False
105+
Cost matrix between the features. Only used when armijo=False
106106
Returns
107107
-------
108108
alpha : float
@@ -118,7 +118,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
118118
"Optimal Transport for structured data with application on graphs"
119119
International Conference on Machine Learning (ICML). 2019.
120120
"""
121-
if amijo:
121+
if armijo:
122122
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
123123
else: # requires symetric matrices
124124
dot1 = np.dot(C1, deltaG)

0 commit comments

Comments
 (0)