Skip to content

Commit fde3d59

Browse files
committed
add random_state
1 parent a1f09f3 commit fde3d59

File tree

1 file changed

+25
-12
lines changed

1 file changed

+25
-12
lines changed

ot/datasets.py

Lines changed: 25 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ def get_1D_gauss(n, m, s):
3636
return h / h.sum()
3737

3838

39-
def get_2D_samples_gauss(n, m, sigma):
39+
def get_2D_samples_gauss(n, m, sigma, random_state=None):
4040
"""return n samples drawn from 2D gaussian N(m,sigma)
4141
4242
Parameters
@@ -48,25 +48,31 @@ def get_2D_samples_gauss(n, m, sigma):
4848
mean value of the gaussian distribution
4949
sigma : np.array (2,2)
5050
covariance matrix of the gaussian distribution
51-
51+
random_state : int, RandomState instance or None, optional (default=None)
52+
If int, random_state is the seed used by the random number generator;
53+
If RandomState instance, random_state is the random number generator;
54+
If None, the random number generator is the RandomState instance used
55+
by `np.random`.
5256
5357
Returns
5458
-------
5559
X : np.array (n,2)
5660
n samples drawn from N(m,sigma)
5761
5862
"""
63+
64+
generator = check_random_state(random_state)
5965
if np.isscalar(sigma):
6066
sigma = np.array([sigma, ])
6167
if len(sigma) > 1:
6268
P = sp.linalg.sqrtm(sigma)
63-
res = np.random.randn(n, 2).dot(P) + m
69+
res = generator.randn(n, 2).dot(P) + m
6470
else:
65-
res = np.random.randn(n, 2) * np.sqrt(sigma) + m
71+
res = generator.randn(n, 2) * np.sqrt(sigma) + m
6672
return res
6773

6874

69-
def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
75+
def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
7076
""" dataset generation for classification problems
7177
7278
Parameters
@@ -78,7 +84,11 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
7884
number of training samples
7985
nz : float
8086
noise level (>0)
81-
87+
random_state : int, RandomState instance or None, optional (default=None)
88+
If int, random_state is the seed used by the random number generator;
89+
If RandomState instance, random_state is the random number generator;
90+
If None, the random number generator is the RandomState instance used
91+
by `np.random`.
8292
8393
Returns
8494
-------
@@ -88,6 +98,9 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
8898
labels of the samples
8999
90100
"""
101+
102+
generator = check_random_state(random_state)
103+
91104
if dataset.lower() == '3gauss':
92105
y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
93106
x = np.zeros((n, 2))
@@ -99,8 +112,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
99112
x[y == 3, 0] = 1.
100113
x[y == 3, 1] = 0
101114

102-
x[y != 3, :] += 1.5 * nz * np.random.randn(sum(y != 3), 2)
103-
x[y == 3, :] += 2 * nz * np.random.randn(sum(y == 3), 2)
115+
x[y != 3, :] += 1.5 * nz * generator.randn(sum(y != 3), 2)
116+
x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2)
104117

105118
elif dataset.lower() == '3gauss2':
106119
y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
@@ -114,8 +127,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
114127
x[y == 3, 0] = 2.
115128
x[y == 3, 1] = 0
116129

117-
x[y != 3, :] += nz * np.random.randn(sum(y != 3), 2)
118-
x[y == 3, :] += 2 * nz * np.random.randn(sum(y == 3), 2)
130+
x[y != 3, :] += nz * generator.randn(sum(y != 3), 2)
131+
x[y == 3, :] += 2 * nz * generator.randn(sum(y == 3), 2)
119132

120133
elif dataset.lower() == 'gaussrot':
121134
rot = np.array(
@@ -127,8 +140,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
127140
n2 = np.sum(y == 2)
128141
x = np.zeros((n, 2))
129142

130-
x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz)
131-
x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz)
143+
x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz,random_state=generator)
144+
x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz,random_state=generator)
132145

133146
x = x.dot(rot)
134147

0 commit comments

Comments
 (0)