Skip to content

Commit 6fdf5de

Browse files
committed
add linear mapping test + autopep8
1 parent 287c659 commit 6fdf5de

File tree

6 files changed

+49
-18
lines changed

6 files changed

+49
-18
lines changed

examples/plot_otda_linear_mapping.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import numpy as np
1010
import pylab as pl
1111
import ot
12-
from scipy import ndimage
1312

1413
##############################################################################
1514
# Generate data
@@ -87,8 +86,8 @@ def minmax(I):
8786

8887

8988
# Loading images
90-
I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
91-
I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
89+
I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256
90+
I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
9291

9392

9493
X1 = im2mat(I1)

ot/bregman.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
import numpy as np
1212

1313

14-
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
14+
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
15+
stopThr=1e-9, verbose=False, log=False, **kwargs):
1516
u"""
1617
Solve the entropic regularization optimal transport problem and return the OT matrix
1718
@@ -120,7 +121,8 @@ def sink():
120121
return sink()
121122

122123

123-
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
124+
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
125+
stopThr=1e-9, verbose=False, log=False, **kwargs):
124126
u"""
125127
Solve the entropic regularization optimal transport problem and return the loss
126128
@@ -233,7 +235,8 @@ def sink():
233235
return sink()
234236

235237

236-
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs):
238+
def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
239+
stopThr=1e-9, verbose=False, log=False, **kwargs):
237240
"""
238241
Solve the entropic regularization optimal transport problem and return the OT matrix
239242
@@ -403,7 +406,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, l
403406
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
404407

405408

406-
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
409+
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
410+
warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
407411
"""
408412
Solve the entropic regularization OT problem with log stabilization
409413
@@ -526,11 +530,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa
526530

527531
def get_K(alpha, beta):
528532
"""log space computation"""
529-
return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg)
533+
return np.exp(-(M - alpha.reshape((na, 1)) -
534+
beta.reshape((1, nb))) / reg)
530535

531536
def get_Gamma(alpha, beta, u, v):
532537
"""log space gamma computation"""
533-
return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
538+
return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) /
539+
reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb))))
534540

535541
# print(np.min(K))
536542

@@ -620,7 +626,8 @@ def get_Gamma(alpha, beta, u, v):
620626
return get_Gamma(alpha, beta, u, v)
621627

622628

623-
def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs):
629+
def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100,
630+
tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs):
624631
"""
625632
Solve the entropic regularization optimal transport problem with log
626633
stabilization and epsilon scaling.
@@ -739,7 +746,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
739746

740747
def get_K(alpha, beta):
741748
"""log space computation"""
742-
return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg)
749+
return np.exp(-(M - alpha.reshape((na, 1)) -
750+
beta.reshape((1, nb))) / reg)
743751

744752
# print(np.min(K))
745753
def get_reg(n): # exponential decreasing
@@ -811,7 +819,8 @@ def projC(gamma, q):
811819
return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
812820

813821

814-
def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False):
822+
def barycenter(A, M, reg, weights=None, numItermax=1000,
823+
stopThr=1e-4, verbose=False, log=False):
815824
"""Compute the entropic regularized wasserstein barycenter of distributions A
816825
817826
The function solves the following optimization problem:
@@ -904,7 +913,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=F
904913
return geometricBar(weights, UKv)
905914

906915

907-
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False):
916+
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
917+
stopThr=1e-3, verbose=False, log=False):
908918
"""
909919
Compute the unmixing of an observation with a given dictionary using Wasserstein distance
910920

ot/lp/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def emd(a, b, M, numItermax=100000, log=False):
107107
return G
108108

109109

110-
def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False, return_matrix=False):
110+
def emd2(a, b, M, processes=multiprocessing.cpu_count(),
111+
numItermax=100000, log=False, return_matrix=False):
111112
"""Solves the Earth Movers distance problem and returns the loss
112113
113114
.. math::

ot/optim.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
# The corresponding scipy function does not work for matrices
1616

1717

18-
def line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=1e-4, alpha0=0.99):
18+
def line_search_armijo(f, xk, pk, gfk, old_fval,
19+
args=(), c1=1e-4, alpha0=0.99):
1920
"""
2021
Armijo linesearch function that works with matrices
2122
@@ -71,7 +72,8 @@ def phi(alpha1):
7172
return alpha, fc[0], phi1
7273

7374

74-
def cg(a, b, M, reg, f, df, G0=None, numItermax=200, stopThr=1e-9, verbose=False, log=False):
75+
def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
76+
stopThr=1e-9, verbose=False, log=False):
7577
"""
7678
Solve the general regularized OT problem with conditional gradient
7779
@@ -202,7 +204,8 @@ def cost(G):
202204
return G
203205

204206

205-
def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax=200, stopThr=1e-9, verbose=False, log=False):
207+
def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
208+
numInnerItermax=200, stopThr=1e-9, verbose=False, log=False):
206209
"""
207210
Solve the general regularized OT problem with the generalized conditional gradient
208211

ot/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def _is_deprecated(func):
316316
closures = []
317317
is_deprecated = ('deprecated' in ''.join([c.cell_contents
318318
for c in closures
319-
if isinstance(c.cell_contents, str)]))
319+
if isinstance(c.cell_contents, str)]))
320320
return is_deprecated
321321

322322

test/test_da.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,24 @@ def test_mapping_transport_class():
444444
assert len(otda.log_.keys()) != 0
445445

446446

447+
def test_linear_mapping():
448+
449+
ns = 150
450+
nt = 200
451+
452+
Xs, ys = get_data_classif('3gauss', ns)
453+
Xt, yt = get_data_classif('3gauss2', nt)
454+
455+
A, b = ot.da.OT_mapping_linear(Xs, Xt)
456+
457+
Xst = Xs.dot(A) + b
458+
459+
Ct = np.cov(Xt.T)
460+
Cst = np.cov(Xst.T)
461+
462+
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
463+
464+
447465
def test_otda():
448466

449467
n_samples = 150 # nb samples

0 commit comments

Comments
 (0)