Skip to content

Commit f089a3c

Browse files
committed
better pep8 but not solved
1 parent 4a585de commit f089a3c

File tree

8 files changed

+152
-133
lines changed

8 files changed

+152
-133
lines changed

examples/plot_barycenter_1D.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from mpl_toolkits.mplot3d import Axes3D # noqa
2626
from matplotlib.collections import PolyCollection
2727

28-
##############################################################################
28+
#
2929
# Generate data
3030
# -------------
3131

@@ -48,7 +48,7 @@
4848
M = ot.utils.dist0(n)
4949
M /= M.max()
5050

51-
##############################################################################
51+
#
5252
# Plot data
5353
# ---------
5454

@@ -60,7 +60,7 @@
6060
pl.title('Distributions')
6161
pl.tight_layout()
6262

63-
##############################################################################
63+
#
6464
# Barycenter computation
6565
# ----------------------
6666

@@ -90,7 +90,7 @@
9090
pl.title('Barycenters')
9191
pl.tight_layout()
9292

93-
##############################################################################
93+
#
9494
# Barycentric interpolation
9595
# -------------------------
9696

examples/plot_gromov.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import ot
2121

2222

23-
##############################################################################
23+
#
2424
# Sample two Gaussian distributions (2D and 3D)
2525
# ---------------------------------------------
2626
#
@@ -43,7 +43,7 @@
4343
xt = np.random.randn(n_samples, 3).dot(P) + mu_t
4444

4545

46-
##############################################################################
46+
#
4747
# Plotting the distributions
4848
# --------------------------
4949

@@ -56,7 +56,7 @@
5656
pl.show()
5757

5858

59-
##############################################################################
59+
#
6060
# Compute distance kernels, normalize them and then display
6161
# ---------------------------------------------------------
6262

@@ -74,33 +74,32 @@
7474
pl.imshow(C2)
7575
pl.show()
7676

77-
##############################################################################
77+
#
7878
# Compute Gromov-Wasserstein plans and distance
7979
# ---------------------------------------------
8080

81-
#%%
8281
p = ot.unif(n_samples)
8382
q = ot.unif(n_samples)
8483

85-
gw0,log0 = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True,log=True)
84+
gw0, log0 = ot.gromov.gromov_wasserstein(
85+
C1, C2, p, q, 'square_loss', verbose=True, log=True)
8686

87-
gw,log= ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4,log=True,verbose=True)
87+
gw, log = ot.gromov.entropic_gromov_wasserstein(
88+
C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True)
8889

8990

9091
print('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))
9192
print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))
9293

9394

94-
pl.figure(1,(10,5))
95+
pl.figure(1, (10, 5))
9596

96-
pl.subplot(1,2,1)
97+
pl.subplot(1, 2, 1)
9798
pl.imshow(gw0, cmap='jet')
98-
pl.colorbar()
9999
pl.title('Gromov Wasserstein')
100100

101-
pl.subplot(1,2,2)
102-
pl.imshow(gw0, cmap='jet')
103-
pl.colorbar()
101+
pl.subplot(1, 2, 2)
102+
pl.imshow(gw, cmap='jet')
104103
pl.title('Entropic Gromov Wasserstein')
105104

106105
pl.show()

ot/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,19 @@
1616
from . import optim
1717
from . import utils
1818
from . import datasets
19-
from . import plot
2019
from . import da
2120
from . import gromov
2221

2322
# OT functions
2423
from .lp import emd, emd2
2524
from .bregman import sinkhorn, sinkhorn2, barycenter
2625
from .da import sinkhorn_lpl1_mm
27-
from .gromov import gromov_wasserstein, gromov_wasserstein2
2826

2927
# utils functions
3028
from .utils import dist, unif, tic, toc, toq
3129

3230
__version__ = "0.4.0"
3331

3432
__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
35-
'bregman', 'lp', 'plot', 'tic', 'toc', 'toq',
33+
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
3634
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']

ot/da.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,7 @@ def distribution_estimation_uniform(X):
933933

934934

935935
class BaseTransport(BaseEstimator):
936+
936937
"""Base class for OTDA objects
937938
938939
Notes
@@ -1180,6 +1181,7 @@ class label
11801181

11811182

11821183
class SinkhornTransport(BaseTransport):
1184+
11831185
"""Domain Adapatation OT method based on Sinkhorn Algorithm
11841186
11851187
Parameters
@@ -1289,6 +1291,7 @@ class label
12891291

12901292

12911293
class EMDTransport(BaseTransport):
1294+
12921295
"""Domain Adapatation OT method based on Earth Mover's Distance
12931296
12941297
Parameters
@@ -1377,6 +1380,7 @@ class label
13771380

13781381

13791382
class SinkhornLpl1Transport(BaseTransport):
1383+
13801384
"""Domain Adapatation OT method based on sinkhorn algorithm +
13811385
LpL1 class regularization.
13821386
@@ -1486,6 +1490,7 @@ class label
14861490

14871491

14881492
class SinkhornL1l2Transport(BaseTransport):
1493+
14891494
"""Domain Adapatation OT method based on sinkhorn algorithm +
14901495
l1l2 class regularization.
14911496
@@ -1608,6 +1613,7 @@ class label
16081613

16091614

16101615
class MappingTransport(BaseEstimator):
1616+
16111617
"""MappingTransport: DA methods that aims at jointly estimating a optimal
16121618
transport coupling and the associated mapping
16131619

ot/gpu/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .bregman import sinkhorn
66

77
# Author: Remi Flamary <remi.flamary@unice.fr>
8-
# Leo Gautheron <https://github.com/aje>
8+
# Leo Gautheron <https://github.com/aje>
99
#
1010
# License: MIT License
1111

ot/gpu/da.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M_GPU, reg, eta=0.1, numItermax=10,
188188

189189

190190
class OTDA_GPU(OTDA):
191+
191192
def normalizeM(self, norm):
192193
if norm == "median":
193194
self.M_GPU.divide(float(np.median(self.M_GPU.asarray())))
@@ -204,6 +205,7 @@ def normalizeM(self, norm):
204205

205206

206207
class OTDA_sinkhorn(OTDA_GPU):
208+
207209
def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
208210
cudamat.init()
209211
xs = np.asarray(xs, dtype=np.float64)
@@ -228,6 +230,7 @@ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
228230

229231

230232
class OTDA_lpl1(OTDA_GPU):
233+
231234
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
232235
**kwargs):
233236
cudamat.init()

0 commit comments

Comments
 (0)