Commit fb883fc: proper documentation
1 parent 724984d

1 file changed: ot/smooth.py (+147, -8 lines)
@@ -226,7 +226,8 @@ def dual_obj_grad(alpha, beta, a, b, C, regul):
     return obj, grad_alpha, grad_beta
 
 
-def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500):
+def solve_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
+               verbose=False):
     """
     Solve the "smoothed" dual objective.
 
@@ -273,7 +274,7 @@ def _func(params):
     params_init = np.concatenate((alpha_init, beta_init))
 
     res = minimize(_func, params_init, method=method, jac=True,
-                   tol=tol, options=dict(maxiter=max_iter, disp=False))
+                   tol=tol, options=dict(maxiter=max_iter, disp=verbose))
 
     alpha = res.x[:len(a)]
     beta = res.x[len(a):]
@@ -321,7 +322,7 @@ def semi_dual_obj_grad(alpha, a, b, C, regul):
 
 
 def solve_semi_dual(a, b, C, regul, method="L-BFGS-B", tol=1e-3, max_iter=500,
-                    ):
+                    verbose=False):
     """
     Solve the "smoothed" semi-dual objective.
 
@@ -355,7 +356,7 @@ def _func(alpha):
     alpha_init = np.zeros(len(a))
 
     res = minimize(_func, alpha_init, method=method, jac=True,
-                   tol=tol, options=dict(maxiter=max_iter, disp=False))
+                   tol=tol, options=dict(maxiter=max_iter, disp=verbose))
 
     return res.x, res
 
@@ -408,7 +409,75 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
 
 
 def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
-                   numItermax=500, log=False):
+                   numItermax=500, verbose=False, log=False):
+    r"""
+    Solve the regularized OT problem in the dual and return the OT matrix
+
+    The function solves the smooth relaxed dual formulation (7) in [17]_ :
+
+    .. math::
+        \max_{\alpha,\beta}\quad a^T\alpha+b^T\beta-\sum_j\delta_\Omega(\alpha+\beta_j-\mathbf{m}_j)
+
+    where :
+
+    - :math:`\mathbf{m}_j` is the jth column of the cost matrix
+    - :math:`\delta_\Omega` is the convex conjugate of the regularization term :math:`\Omega`
+    - a and b are source and target weights (sum to 1)
+
+    The OT matrix can be reconstructed from the gradient of :math:`\delta_\Omega`
+    (see [17]_ Proposition 1).
+    The optimization algorithm uses gradient descent (L-BFGS by default).
+
+
+    Parameters
+    ----------
+    a : np.ndarray (ns,)
+        sample weights in the source domain
+    b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+        samples in the target domain, compute sinkhorn with multiple targets
+        and fixed M if b is a matrix (return OT loss + dual variables in log)
+    M : np.ndarray (ns,nt)
+        loss matrix
+    reg : float
+        Regularization term >0
+    reg_type : str
+        Regularization type, can be one of the following (default is 'l2'):
+        - 'kl' : Kullback-Leibler (~ neg-entropy used in sinkhorn [2]_)
+        - 'l2' : squared Euclidean regularization
+    method : str
+        Solver to use for scipy.optimize.minimize
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    gamma : (ns x nt) ndarray
+        Optimal transportation matrix for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+
+    References
+    ----------
+
+    .. [2] M. Cuturi, Sinkhorn Distances: Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+    .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+
+    See Also
+    --------
+    ot.lp.emd : Unregularized OT
+    ot.sinkhorn : Entropic regularized OT
+    ot.optim.cg : General regularized OT
+
+    """
 
     if reg_type.lower() in ['l2', 'squaredl2']:
         regul = SquaredL2(gamma=reg)
@@ -418,7 +487,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
         raise NotImplementedError('Unknown regularization')
 
     # solve dual
-    alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr)
+    alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,
+                                  tol=stopThr, verbose=verbose)
 
     # reconstruct transport matrix
     G = get_plan_from_dual(alpha, beta, M, regul)
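A minimal usage sketch of the dual entry point documented in the hunks above, assuming the module is importable as ot.smooth (matching the file path ot/smooth.py) and using made-up histograms and cost matrix purely for illustration:

    import numpy as np
    from ot.smooth import smooth_ot_dual  # assumes this file is importable as ot.smooth

    # Two small uniform histograms and a squared-distance cost matrix (illustrative data).
    rng = np.random.RandomState(0)
    ns, nt = 5, 6
    a = np.ones(ns) / ns
    b = np.ones(nt) / nt
    x = rng.randn(ns, 1)
    y = rng.randn(nt, 1)
    M = (x - y.T) ** 2

    # Squared-L2 regularized plan; verbose=True is forwarded to scipy's L-BFGS-B as disp=True.
    G = smooth_ot_dual(a, b, M, reg=1.0, reg_type='l2', verbose=True)
    print(G.shape)   # (5, 6)
    print(G.sum())   # roughly 1: the plan approximately matches the unit-mass marginals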
@@ -431,8 +501,77 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
 
 
 def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
-                        numItermax=500, log=False):
+                        numItermax=500, verbose=False, log=False):
+    r"""
+    Solve the regularized OT problem in the semi-dual and return the OT matrix
 
+    The function solves the smooth relaxed semi-dual formulation (10) in [17]_ :
+
+    .. math::
+        \max_{\alpha}\quad a^T\alpha-OT_\Omega^*(\alpha,b)
+
+    where :
+
+    .. math::
+        OT_\Omega^*(\alpha,b)=\sum_j b_j
+
+    - :math:`\mathbf{m}_j` is the jth column of the cost matrix
+    - :math:`OT_\Omega^*(\alpha,b)` is defined in Eq. (9) in [17]_
+    - a and b are source and target weights (sum to 1)
+
+    The OT matrix can be reconstructed using [17]_ Proposition 2.
+    The optimization algorithm uses gradient descent (L-BFGS by default).
+
+
+    Parameters
+    ----------
+    a : np.ndarray (ns,)
+        sample weights in the source domain
+    b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+        samples in the target domain, compute sinkhorn with multiple targets
+        and fixed M if b is a matrix (return OT loss + dual variables in log)
+    M : np.ndarray (ns,nt)
+        loss matrix
+    reg : float
+        Regularization term >0
+    reg_type : str
+        Regularization type, can be one of the following (default is 'l2'):
+        - 'kl' : Kullback-Leibler (~ neg-entropy used in sinkhorn [2]_)
+        - 'l2' : squared Euclidean regularization
+    method : str
+        Solver to use for scipy.optimize.minimize
+    numItermax : int, optional
+        Max number of iterations
+    stopThr : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    gamma : (ns x nt) ndarray
+        Optimal transportation matrix for the given parameters
+    log : dict
+        log dictionary returned only if log==True in parameters
+
+
+    References
+    ----------
+
+    .. [2] M. Cuturi, Sinkhorn Distances: Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+
+    .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).
+
+    See Also
+    --------
+    ot.lp.emd : Unregularized OT
+    ot.sinkhorn : Entropic regularized OT
+    ot.optim.cg : General regularized OT
+
+    """
     if reg_type.lower() in ['l2', 'squaredl2']:
         regul = SquaredL2(gamma=reg)
     elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
@@ -444,7 +583,7 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
-    alpha, res = solve_semi_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr)
+    alpha, res = solve_semi_dual(a, b, M, regul, max_iter=numItermax, tol=stopThr, verbose=verbose)
 
     # reconstruct transport matrix
     G = get_plan_from_semi_dual(alpha, b, M, regul)
 
     if log:
         log = {'alpha': alpha, 'res': res}

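And a similar sketch for the semi-dual entry point with the KL regularizer and the log output, again assuming the ot.smooth import path, illustrative data, and that, as elsewhere in POT, the function returns the plan plus the log dict when log=True:

    import numpy as np
    from ot.smooth import smooth_ot_semi_dual  # same ot.smooth import assumption as above

    rng = np.random.RandomState(1)
    ns, nt = 4, 4
    a = np.ones(ns) / ns
    b = np.ones(nt) / nt
    M = rng.rand(ns, nt)  # any non-negative loss matrix works for the sketch

    # KL (negative-entropy) regularization; verbose=True prints the L-BFGS-B progress.
    G, log = smooth_ot_semi_dual(a, b, M, reg=0.1, reg_type='kl',
                                 numItermax=500, stopThr=1e-9,
                                 verbose=True, log=True)
    print(G.shape)             # (4, 4)
    print(log['alpha'].shape)  # (4,), the semi-dual potential stored in the log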