|
9 | 9 |
|
10 | 10 | import numpy as np |
11 | 11 | import scipy as sp |
| 12 | +from .utils import check_random_state |
12 | 13 |
|
13 | 14 |
|
14 | 15 | def get_1D_gauss(n, m, s): |
@@ -60,7 +61,7 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None): |
60 | 61 | n samples drawn from N(m,sigma) |
61 | 62 |
|
62 | 63 | """ |
63 | | - |
| 64 | + |
64 | 65 | generator = check_random_state(random_state) |
65 | 66 | if np.isscalar(sigma): |
66 | 67 | sigma = np.array([sigma, ]) |
@@ -98,9 +99,9 @@ def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): |
98 | 99 | labels of the samples |
99 | 100 |
|
100 | 101 | """ |
101 | | - |
| 102 | + |
102 | 103 | generator = check_random_state(random_state) |
103 | | - |
| 104 | + |
104 | 105 | if dataset.lower() == '3gauss': |
105 | 106 | y = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 |
106 | 107 | x = np.zeros((n, 2)) |
@@ -140,8 +141,8 @@ def get_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): |
140 | 141 | n2 = np.sum(y == 2) |
141 | 142 | x = np.zeros((n, 2)) |
142 | 143 |
|
143 | | - x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz,random_state=generator) |
144 | | - x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz,random_state=generator) |
| 144 | + x[y == 1, :] = get_2D_samples_gauss(n1, m1, nz, random_state=generator) |
| 145 | + x[y == 2, :] = get_2D_samples_gauss(n2, m2, nz, random_state=generator) |
145 | 146 |
|
146 | 147 | x = x.dot(rot) |
147 | 148 |
|
|
0 commit comments