Skip to content

Commit 06eabe7

Browse files
committed
pep8 + working tests
1 parent fde3d59 commit 06eabe7

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

ot/datasets.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import numpy as np
1111
import scipy as sp
12+
from .utils import check_random_state
1213

1314

1415
def get_1D_gauss(n, m, s):
@@ -60,7 +61,7 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None):
6061
n samples drawn from N(m,sigma)
6162
6263
"""
63-
64+
6465
generator = check_random_state(random_state)
6566
if np.isscalar(sigma):
6667
sigma = np.array([sigma, ])
@@ -98,9 +99,9 @@ def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
9899
labels of the samples
99100
100101
"""
101-
102+
102103
generator = check_random_state(random_state)
103-
104+
104105
if dataset.lower() == '3gauss':
105106
y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
106107
x = np.zeros((n, 2))
@@ -140,8 +141,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
140141
n2 = np.sum(y == 2)
141142
x = np.zeros((n, 2))
142143

143-
x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz,random_state=generator)
144-
x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz,random_state=generator)
144+
x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz, random_state=generator)
145+
x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz, random_state=generator)
145146

146147
x = x.dot(rot)
147148

0 commit comments

Comments
 (0)