@@ -36,7 +36,7 @@ def get_1D_gauss(n, m, s):
3636 return h / h .sum ()
3737
3838
39- def get_2D_samples_gauss (n , m , sigma ):
39+ def get_2D_samples_gauss (n , m , sigma , random_state = None ):
4040 """return n samples drawn from 2D gaussian N(m,sigma)
4141
4242 Parameters
@@ -48,25 +48,31 @@ def get_2D_samples_gauss(n, m, sigma):
4848 mean value of the gaussian distribution
4949 sigma : np.array (2,2)
5050 covariance matrix of the gaussian distribution
51-
51+ random_state : int, RandomState instance or None, optional (default=None)
52+ If int, random_state is the seed used by the random number generator;
53+ If RandomState instance, random_state is the random number generator;
54+ If None, the random number generator is the RandomState instance used
55+ by `np.random`.
5256
5357 Returns
5458 -------
5559 X : np.array (n,2)
5660 n samples drawn from N(m,sigma)
5761
5862 """
63+
64+ generator = check_random_state (random_state )
5965 if np .isscalar (sigma ):
6066 sigma = np .array ([sigma , ])
6167 if len (sigma ) > 1 :
6268 P = sp .linalg .sqrtm (sigma )
63- res = np . random .randn (n , 2 ).dot (P ) + m
69+ res = generator .randn (n , 2 ).dot (P ) + m
6470 else :
65- res = np . random .randn (n , 2 ) * np .sqrt (sigma ) + m
71+ res = generator .randn (n , 2 ) * np .sqrt (sigma ) + m
6672 return res
6773
6874
69- def get_data_classif (dataset , n , nz = .5 , theta = 0 , ** kwargs ):
75+ def get_data_classif (dataset , n , nz = .5 , theta = 0 , random_state = None , ** kwargs ):
7076 """ dataset generation for classification problems
7177
7278 Parameters
@@ -78,7 +84,11 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
7884 number of training samples
7985 nz : float
8086 noise level (>0)
81-
87+ random_state : int, RandomState instance or None, optional (default=None)
88+ If int, random_state is the seed used by the random number generator;
89+ If RandomState instance, random_state is the random number generator;
90+ If None, the random number generator is the RandomState instance used
91+ by `np.random`.
8292
8393 Returns
8494 -------
@@ -88,6 +98,9 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
8898 labels of the samples
8999
90100 """
101+
102+ generator = check_random_state (random_state )
103+
91104 if dataset .lower () == '3gauss' :
92105 y = np .floor ((np .arange (n ) * 1.0 / n * 3 )) + 1
93106 x = np .zeros ((n , 2 ))
@@ -99,8 +112,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
99112 x [y == 3 , 0 ] = 1.
100113 x [y == 3 , 1 ] = 0
101114
102- x [y != 3 , :] += 1.5 * nz * np . random .randn (sum (y != 3 ), 2 )
103- x [y == 3 , :] += 2 * nz * np . random .randn (sum (y == 3 ), 2 )
115+ x [y != 3 , :] += 1.5 * nz * generator .randn (sum (y != 3 ), 2 )
116+ x [y == 3 , :] += 2 * nz * generator .randn (sum (y == 3 ), 2 )
104117
105118 elif dataset .lower () == '3gauss2' :
106119 y = np .floor ((np .arange (n ) * 1.0 / n * 3 )) + 1
@@ -114,8 +127,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
114127 x [y == 3 , 0 ] = 2.
115128 x [y == 3 , 1 ] = 0
116129
117- x [y != 3 , :] += nz * np . random .randn (sum (y != 3 ), 2 )
118- x [y == 3 , :] += 2 * nz * np . random .randn (sum (y == 3 ), 2 )
130+ x [y != 3 , :] += nz * generator .randn (sum (y != 3 ), 2 )
131+ x [y == 3 , :] += 2 * nz * generator .randn (sum (y == 3 ), 2 )
119132
120133 elif dataset .lower () == 'gaussrot' :
121134 rot = np .array (
@@ -127,8 +140,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, **kwargs):
127140 n2 = np .sum (y == 2 )
128141 x = np .zeros ((n , 2 ))
129142
130- x [y == 1 , :] = get_2D_samples_gauss (n1 , m1 , nz )
131- x [y == 2 , :] = get_2D_samples_gauss (n2 , m2 , nz )
143+ x [y == 1 , :] = get_2D_samples_gauss (n1 , m1 , nz , random_state = generator )
144+ x [y == 2 , :] = get_2D_samples_gauss (n2 , m2 , nz , random_state = generator )
132145
133146 x = x .dot (rot )
134147
0 commit comments