Skip to content

Commit f65073f

Browse files
committed
comlete documentation
1 parent 9a9b354 commit f65073f

File tree

3 files changed

+27
-13
lines changed

3 files changed

+27
-13
lines changed

ot/lp/__init__.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828

2929

3030
def center_ot_dual(alpha0, beta0, a=None, b=None):
31-
r"""Center dual OT potentials wrt theirs weights
31+
r"""Center dual OT potentials w.r.t. theirs weights
3232
3333
The main idea of this function is to find unique dual potentials
34-
that ensure some kind of centering/fairness. It will help have
34+
that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). It will help having
3535
stability when multiple calling of the OT solver with small changes.
3636
3737
Basically we add another constraint to the potential that will not
@@ -91,7 +91,15 @@ def center_ot_dual(alpha0, beta0, a=None, b=None):
9191
def estimate_dual_null_weights(alpha0, beta0, a, b, M):
9292
r"""Estimate feasible values for 0-weighted dual potentials
9393
94-
The feasible values are computed efficiently bjt rather coarsely.
94+
The feasible values are computed efficiently but rather coarsely.
95+
96+
.. warning::
97+
This function is necessary because the C++ solver in emd_c
98+
discards all samples in the distributions with
99+
zeros weights. This means that while the primal variable (transport
100+
matrix) is exact, the solver only returns feasible dual potentials
101+
on the samples with weights different from zero.
102+
95103
First we compute the constraints violations:
96104
97105
.. math::
@@ -113,11 +121,11 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
113121
114122
\beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0
115123
116-
In the end the dual potential are centred using function
124+
In the end the dual potentials are centered using function
117125
:ref:`center_ot_dual`.
118126
119127
Note that all those updates do not change the objective value of the
120-
solution but provide dual potential that do not violate the constraints.
128+
solution but provide dual potentials that do not violate the constraints.
121129
122130
Parameters
123131
----------
@@ -130,9 +138,9 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
130138
beta0 : (nt,) numpy.ndarray, float64
131139
Target dual potential
132140
a : (ns,) numpy.ndarray, float64
133-
Source histogram (uniform weight if empty list)
141+
Source distribution (uniform weights if empty list)
134142
b : (nt,) numpy.ndarray, float64
135-
Target histogram (uniform weight if empty list)
143+
Target distribution (uniform weights if empty list)
136144
M : (ns,nt) numpy.ndarray, float64
137145
Loss matrix (c-order array with type float64)
138146
@@ -150,11 +158,11 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
150158
bsel = b != 0
151159

152160
# compute dual constraints violation
153-
Viol = alpha0[:, None] + beta0[None, :] - M
161+
constraint_violation = alpha0[:, None] + beta0[None, :] - M
154162

155-
# Compute worst violation per line and columns
156-
aviol = np.max(Viol, 1)
157-
bviol = np.max(Viol, 0)
163+
# Compute largest violation per line and columns
164+
aviol = np.max(constraint_violation, 1)
165+
bviol = np.max(constraint_violation, 0)
158166

159167
# update corrects violation of
160168
alpha_up = -1 * ~asel * np.maximum(aviol, 0)

ot/lp/emd_wrap.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
6666
.. warning::
6767
Note that the M matrix needs to be a C-order :py.cls:`numpy.array`
6868
69+
.. warning::
70+
The C++ solver discards all samples in the distributions with
71+
zeros weights. This means that while the primal variable (transport
72+
matrix) is exact, the solver only returns feasible dual potentials
73+
on the samples with weights different from zero.
74+
6975
Parameters
7076
----------
7177
a : (ns,) numpy.ndarray, float64

test/test_ot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,9 +338,9 @@ def test_dual_variables():
338338
np.testing.assert_almost_equal(cost1, log['cost'])
339339
check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost'])
340340

341-
viol = log['u'][:, None] + log['v'][None, :] - M
341+
constraint_violation = log['u'][:, None] + log['v'][None, :] - M
342342

343-
assert viol.max() < 1e-8
343+
assert constraint_violation.max() < 1e-8
344344

345345

346346
def check_duality_gap(a, b, M, G, u, v, cost):

0 commit comments

Comments
 (0)