Skip to content

Commit d432038

Browse files
committed
relative+absolute loss
1 parent 9421ddd commit d432038

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

ot/optim.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
135135

136136

137137
def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
138-
stopThr=1e-9, verbose=False, log=False, **kwargs):
138+
stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
139139
"""
140140
Solve the general regularized OT problem with conditional gradient
141141
@@ -173,7 +173,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
173173
numItermax : int, optional
174174
Max number of iterations
175175
stopThr : float, optional
176-
Stop threshol on error (>0)
176+
Stop threshol on the relative variation (>0)
177+
stopThr2 : float, optional
178+
Stop threshol on the absolute variation (>0)
177179
verbose : bool, optional
178180
Print information along iterations
179181
log : bool, optional
@@ -249,8 +251,9 @@ def cost(G):
249251
if it >= numItermax:
250252
loop = 0
251253

252-
delta_fval = (f_val - old_fval) / abs(f_val)
253-
if abs(delta_fval) < stopThr:
254+
abs_delta_fval = abs(f_val - old_fval)
255+
relative_delta_fval = abs_delta_fval / abs(f_val)
256+
if relative_delta_fval < stopThr and abs_delta_fval < stopThr2:
254257
loop = 0
255258

256259
if log:
@@ -259,8 +262,8 @@ def cost(G):
259262
if verbose:
260263
if it % 20 == 0:
261264
print('{:5s}|{:12s}|{:8s}'.format(
262-
'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
263-
print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval))
265+
'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32)
266+
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
264267

265268
if log:
266269
return G, log
@@ -269,7 +272,7 @@ def cost(G):
269272

270273

271274
def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
272-
numInnerItermax=200, stopThr=1e-9, verbose=False, log=False):
275+
numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False):
273276
"""
274277
Solve the general regularized OT problem with the generalized conditional gradient
275278
@@ -312,7 +315,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
312315
numInnerItermax : int, optional
313316
Max number of iterations of Sinkhorn
314317
stopThr : float, optional
315-
Stop threshol on error (>0)
318+
Stop threshol on the relative variation (>0)
319+
stopThr2 : float, optional
320+
Stop threshol on the absolute variation (>0)
316321
verbose : bool, optional
317322
Print information along iterations
318323
log : bool, optional
@@ -386,8 +391,10 @@ def cost(G):
386391
if it >= numItermax:
387392
loop = 0
388393

389-
delta_fval = (f_val - old_fval) / abs(f_val)
390-
if abs(delta_fval) < stopThr:
394+
abs_delta_fval = abs(f_val - old_fval)
395+
relative_delta_fval = abs_delta_fval / abs(f_val)
396+
397+
if relative_delta_fval < stopThr and abs_delta_fval < stopThr2:
391398
loop = 0
392399

393400
if log:
@@ -396,8 +403,8 @@ def cost(G):
396403
if verbose:
397404
if it % 20 == 0:
398405
print('{:5s}|{:12s}|{:8s}'.format(
399-
'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
400-
print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval))
406+
'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32)
407+
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
401408

402409
if log:
403410
return G, log

0 commit comments

Comments
 (0)