
# Author: Remi Flamary <remi.flamary@unice.fr>
#         Minhui Huang <mhhuang@ucdavis.edu>
+#         Jakub Zadrozny <jakub.r.zadrozny@gmail.com>
#
# License: MIT License

@@ -43,6 +44,28 @@ def sinkhorn(w1, w2, M, reg, k):
    return G


+def logsumexp(M, axis):
+    r"""Log-sum-exp reduction compatible with autograd (no numpy implementation)
+    """
+    amax = np.amax(M, axis=axis, keepdims=True)
+    return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis)
+
+
+def sinkhorn_log(w1, w2, M, reg, k):
+    r"""Sinkhorn algorithm in log-domain with a fixed number of iterations (autograd)
+    """
+    Mr = -M / reg
+    ui = np.zeros((M.shape[0],))
+    vi = np.zeros((M.shape[1],))
+    log_w1 = np.log(w1)
+    log_w2 = np.log(w2)
+    for i in range(k):
+        vi = log_w2 - logsumexp(Mr + ui[:, None], 0)
+        ui = log_w1 - logsumexp(Mr + vi[None, :], 1)
+    G = np.exp(ui[:, None] + Mr + vi[None, :])
+    return G
+
+
def split_classes(X, y):
    r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}`
    """
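A minimal sketch of what the two added helpers buy numerically, assuming the definitions above are in scope. The toy marginals and cost matrix are invented for illustration; `np` here is plain NumPy, while in `ot/dr.py` it is the `autograd.numpy` wrapper imported at the top of the file (not shown in this diff).

import numpy as np

# The max-shift in logsumexp avoids overflow: the naive form returns inf
# for large inputs, the shifted form does not.
x = np.array([[1000.0, 1000.0]])
print(np.log(np.sum(np.exp(x), axis=1)))  # [inf]  (overflow)
print(logsumexp(x, axis=1))               # [1000.693...] = 1000 + log(2)

# Toy marginals and random cost matrix (illustrative only).
rng = np.random.RandomState(0)
w1 = np.ones(3) / 3
w2 = np.ones(4) / 4
M = rng.rand(3, 4)

# With a small regularization, exp(-M / reg) underflows to 0 in the
# standard scaling iterations; the log-domain solver stays finite.
G = sinkhorn_log(w1, w2, M, reg=1e-3, k=100)
print(np.isfinite(G).all())  # True
print(G.sum(axis=1))         # ~ w1: the final update enforces row marginals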
@@ -110,7 +133,7 @@ def proj(X):
    return Popt, proj


-def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
+def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter=100, verbose=0, P0=None, normalize=False):
    r"""
    Wasserstein Discriminant Analysis :ref:`[11] <references-wda>`

@@ -126,6 +149,14 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
    - :math:`W` is the entropic regularized Wasserstein distance
    - :math:`\mathbf{X}^i` are the samples in the dataset corresponding to class :math:`i`

+    **Choosing a Sinkhorn solver**
+
+    By default, and when using a regularization parameter that is not too
+    small, the default sinkhorn solver should be enough. If you need a small
+    regularization to get sharper (sparser) transport plans, you should use
+    the :py:func:`ot.dr.sinkhorn_log` solver, which avoids numerical
+    errors but can be slow in practice.
+
    Parameters
    ----------
    X : ndarray, shape (n, d)
@@ -139,6 +170,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
    solver : None | str, optional
        None for steepest descent, 'TrustRegions' for the trust-regions
        algorithm, or any solver from pymanopt.solvers
+    sinkhorn_method : str
+        method used for the Sinkhorn solver, either 'sinkhorn' or 'sinkhorn_log'
    P0 : ndarray, shape (d, p)
        Initial starting point for projection.
    normalize : bool, optional
@@ -161,6 +194,13 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False):
        Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
    """  # noqa

+    if sinkhorn_method.lower() == 'sinkhorn':
+        sinkhorn_solver = sinkhorn
+    elif sinkhorn_method.lower() == 'sinkhorn_log':
+        sinkhorn_solver = sinkhorn_log
+    else:
+        raise ValueError("Unknown Sinkhorn method '%s'." % sinkhorn_method)
+
    mx = np.mean(X)
    X -= mx.reshape((1, -1))

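A usage sketch for the new keyword, with made-up data (not from the commit); `ot.dr` requires `pymanopt` and `autograd` to be installed:

import numpy as np
import ot

# Toy data: two Gaussian classes that differ only along the first coordinate.
rng = np.random.RandomState(42)
n = 50
X = np.vstack([rng.randn(n, 10), rng.randn(n, 10) + 4 * np.eye(10)[0]])
y = np.array([0] * n + [1] * n)

# With a small regularization, select the log-domain solver added above.
Popt, proj = ot.dr.wda(X, y, p=2, reg=1e-2, k=10,
                       sinkhorn_method='sinkhorn_log')
Xp = proj(X)  # samples projected onto the learned p-dimensional subspace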
@@ -193,7 +233,7 @@ def cost(P):
            for j, xj in enumerate(xc[i:]):
                xj = np.dot(xj, P)
                M = dist(xi, xj)
-                G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k)
+                G = sinkhorn_solver(wc[i], wc[j + i], M, reg * regmean[i, j], k)
                if j == 0:
                    loss_w += np.sum(G * M)
                else:
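For context on the loop: since `xc[i:]` starts with class `i` itself, `j == 0` is the within-class pair `(i, i)`, while `j > 0` gives between-class pairs, and `reg * regmean[i, j]` rescales the regularization per class pair. Reconstructing from the docstring above (the diff cuts off before the function's return), these terms accumulate into the ratio that the Riemannian solver minimizes:

.. math::

    \frac{\sum\limits_{i} W(P \mathbf{X}^i, P \mathbf{X}^i)}{\sum\limits_{i, j > i} W(P \mathbf{X}^i, P \mathbf{X}^j)}

where :math:`W` is the entropic regularized Wasserstein distance computed by the selected Sinkhorn solver.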