1818# The corresponding scipy function does not work for matrices
1919
2020
21- def line_search_armijo (f , xk , pk , gfk , old_fval ,
22- args = (), c1 = 1e-4 , alpha0 = 0.99 ):
21+ def line_search_armijo (
22+ f , xk , pk , gfk , old_fval , args = (), c1 = 1e-4 ,
23+ alpha0 = 0.99 , alpha_min = None , alpha_max = None
24+ ):
2325 r"""
2426 Armijo linesearch function that works with matrices
2527
@@ -44,6 +46,10 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
4446 :math:`c_1` const in armijo rule (>0)
4547 alpha0 : float, optional
4648 initial step (>0)
49+ alpha_min : float, optional
50+ minimum value for alpha
51+ alpha_max : float, optional
52+ maximum value for alpha
4753
4854 Returns
4955 -------
@@ -80,13 +86,15 @@ def phi(alpha1):
8086 if alpha is None :
8187 return 0. , fc [0 ], phi0
8288 else :
83- # scalar_search_armijo can return alpha > 1
84- alpha = min ( 1 , alpha )
89+ if alpha_min is not None or alpha_max is not None :
90+ alpha = np . clip ( alpha , alpha_min , alpha_max )
8591 return alpha , fc [0 ], phi1
8692
8793
88- def solve_linesearch (cost , G , deltaG , Mi , f_val ,
89- armijo = True , C1 = None , C2 = None , reg = None , Gc = None , constC = None , M = None ):
94+ def solve_linesearch (
95+ cost , G , deltaG , Mi , f_val , armijo = True , C1 = None , C2 = None ,
96+ reg = None , Gc = None , constC = None , M = None , alpha_min = None , alpha_max = None
97+ ):
9098 """
9199 Solve the linesearch in the FW iterations
92100
@@ -117,6 +125,10 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
117125 Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
118126 M : array-like (ns,nt), optional
119127 Cost matrix between the features. Only used and necessary when armijo=False
128+ alpha_min : float, optional
129+ Minimum value for alpha
130+ alpha_max : float, optional
131+ Maximum value for alpha
120132
121133 Returns
122134 -------
@@ -136,7 +148,9 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
136148 International Conference on Machine Learning (ICML). 2019.
137149 """
138150 if armijo :
139- alpha , fc , f_val = line_search_armijo (cost , G , deltaG , Mi , f_val )
151+ alpha , fc , f_val = line_search_armijo (
152+ cost , G , deltaG , Mi , f_val , alpha_min = alpha_min , alpha_max = alpha_max
153+ )
140154 else : # requires symmetric matrices
141155 G , deltaG , C1 , C2 , constC , M = list_to_array (G , deltaG , C1 , C2 , constC , M )
142156 if isinstance (M , int ) or isinstance (M , float ):
@@ -150,6 +164,8 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
150164 c = cost (G )
151165
152166 alpha = solve_1d_linesearch_quad (a , b , c )
167+ if alpha_min is not None or alpha_max is not None :
168+ alpha = np .clip (alpha , alpha_min , alpha_max )
153169 fc = None
154170 f_val = cost (G + alpha * deltaG )
155171
@@ -274,7 +290,10 @@ def cost(G):
274290 deltaG = Gc - G
275291
276292 # line search
277- alpha , fc , f_val = solve_linesearch (cost , G , deltaG , Mi , f_val , reg = reg , M = M , Gc = Gc , ** kwargs )
293+ alpha , fc , f_val = solve_linesearch (
294+ cost , G , deltaG , Mi , f_val , reg = reg , M = M , Gc = Gc ,
295+ alpha_min = 0. , alpha_max = 1. , ** kwargs
296+ )
278297
279298 G = G + alpha * deltaG
280299
@@ -420,7 +439,9 @@ def cost(G):
420439
421440 # line search
422441 dcost = Mi + reg1 * (1 + nx .log (G )) # ??
423- alpha , fc , f_val = line_search_armijo (cost , G , deltaG , dcost , f_val )
442+ alpha , fc , f_val = line_search_armijo (
443+ cost , G , deltaG , dcost , f_val , alpha_min = 0. , alpha_max = 1.
444+ )
424445
425446 G = G + alpha * deltaG
426447
0 commit comments