@@ -135,7 +135,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
 
 
 def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
-       stopThr=1e-9, verbose=False, log=False, **kwargs):
+       stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
     """
     Solve the general regularized OT problem with conditional gradient
 
@@ -173,7 +173,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
     numItermax : int, optional
         Max number of iterations
     stopThr : float, optional
-        Stop threshol on error (>0)
+        Stop threshold on the relative variation (>0)
+    stopThr2 : float, optional
+        Stop threshold on the absolute variation (>0)
     verbose : bool, optional
         Print information along iterations
     log : bool, optional
@@ -249,8 +251,9 @@ def cost(G):
         if it >= numItermax:
             loop = 0
 
-        delta_fval = (f_val - old_fval) / abs(f_val)
-        if abs(delta_fval) < stopThr:
+        abs_delta_fval = abs(f_val - old_fval)
+        relative_delta_fval = abs_delta_fval / abs(f_val)
+        if relative_delta_fval < stopThr and abs_delta_fval < stopThr2:
             loop = 0
 
         if log:
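The new convergence test stops the conditional gradient loop only once the relative variation falls below `stopThr` and the absolute variation falls below `stopThr2`. A minimal standalone sketch of that test (the helper name `converged` and the sample values are illustrative, not part of the patch):

```python
def converged(f_val, old_fval, stopThr=1e-9, stopThr2=1e-9):
    # Absolute variation of the objective between two iterations.
    abs_delta_fval = abs(f_val - old_fval)
    # Relative variation, as in the patch (assumes f_val != 0).
    relative_delta_fval = abs_delta_fval / abs(f_val)
    # Both criteria must hold before the loop sets loop = 0.
    return relative_delta_fval < stopThr and abs_delta_fval < stopThr2


# A tiny absolute change is not enough on its own when the objective is tiny:
print(converged(1e-12, 2e-12))      # False: relative variation is 1.0
print(converged(1.0, 1.0 + 1e-12))  # True: both variations are about 1e-12
```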
@@ -259,8 +262,8 @@ def cost(G):
         if verbose:
             if it % 20 == 0:
-                print('{:5s}|{:12s}|{:8s}'.format(
-                    'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
-            print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval))
+                print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+                    'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32)
+            print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
 
     if log:
         return G, log
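For context, a hedged usage sketch of the extended `cg` signature. It assumes this file is POT's `ot.optim` module (so the function is reachable as `ot.optim.cg`) and picks a quadratic regularizer `f(G) = 0.5 * ||G||^2` purely for illustration; neither assumption comes from the diff itself.

```python
import numpy as np
import ot  # assumption: Python Optimal Transport, exposing ot.optim.cg

rng = np.random.RandomState(0)
a = ot.unif(5)      # uniform source histogram
b = ot.unif(5)      # uniform target histogram
M = rng.rand(5, 5)  # random cost matrix


def f(G):
    # Illustrative quadratic regularization term.
    return 0.5 * np.sum(G ** 2)


def df(G):
    # Gradient of the term above.
    return G


# stopThr bounds the relative variation, stopThr2 the absolute variation;
# the conditional gradient loop stops only once both thresholds are met.
G = ot.optim.cg(a, b, M, reg=1e-1, f=f, df=df,
                stopThr=1e-9, stopThr2=1e-9, verbose=True)
```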
@@ -269,7 +272,7 @@ def cost(G):
 
 
 def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
-        numInnerItermax=200, stopThr=1e-9, verbose=False, log=False):
+        numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False):
     """
     Solve the general regularized OT problem with the generalized conditional gradient
 
@@ -312,7 +315,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
     numInnerItermax : int, optional
         Max number of iterations of Sinkhorn
     stopThr : float, optional
-        Stop threshol on error (>0)
+        Stop threshold on the relative variation (>0)
+    stopThr2 : float, optional
+        Stop threshold on the absolute variation (>0)
     verbose : bool, optional
         Print information along iterations
     log : bool, optional
@@ -386,8 +391,10 @@ def cost(G):
         if it >= numItermax:
             loop = 0
 
-        delta_fval = (f_val - old_fval) / abs(f_val)
-        if abs(delta_fval) < stopThr:
+        abs_delta_fval = abs(f_val - old_fval)
+        relative_delta_fval = abs_delta_fval / abs(f_val)
+
+        if relative_delta_fval < stopThr and abs_delta_fval < stopThr2:
             loop = 0
 
         if log:
@@ -396,8 +403,8 @@ def cost(G):
         if verbose:
             if it % 20 == 0:
-                print('{:5s}|{:12s}|{:8s}'.format(
-                    'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
-            print('{:5d}|{:8e}|{:8e}'.format(it, f_val, delta_fval))
+                print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
+                    'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32)
+            print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
 
     if log:
         return G, log
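`gcg` gets the same pair of thresholds. A hedged call sketch; reading `reg1` as the entropic weight handled by the inner Sinkhorn iterations (see `numInnerItermax`) and `reg2` as the weight of `f` follows POT's documentation rather than anything shown in this diff.

```python
import numpy as np
import ot  # assumption: Python Optimal Transport, exposing ot.optim.gcg

rng = np.random.RandomState(0)
a, b = ot.unif(5), ot.unif(5)
M = rng.rand(5, 5)

# reg1: entropic term solved by the inner Sinkhorn iterations;
# reg2: weight of the f/df regularizer (assumed semantics, see note above).
G = ot.optim.gcg(a, b, M, reg1=1e-1, reg2=1e-1,
                 f=lambda G: 0.5 * np.sum(G ** 2), df=lambda G: G,
                 numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9)

print(np.sum(G * M))  # linear part of the transport cost
```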