|
7 | 7 |
|
8 | 8 | import ot |
9 | 9 | import numpy as np |
| 10 | +import sys |
10 | 11 |
|
11 | 12 |
|
12 | 13 | def test_parmap(): |
@@ -123,3 +124,79 @@ def test_clean_zeros(): |
123 | 124 |
|
124 | 125 | assert len(a) == n - nz |
125 | 126 | assert len(b) == n - nz2 |
| 127 | + |
| 128 | + |
| 129 | +def test_cost_normalization(): |
| 130 | + |
| 131 | + C = np.random.rand(10, 10) |
| 132 | + |
| 133 | + # does nothing |
| 134 | + M0 = ot.utils.cost_normalization(C) |
| 135 | + np.testing.assert_allclose(C, M0) |
| 136 | + |
| 137 | + M = ot.utils.cost_normalization(C, 'median') |
| 138 | + np.testing.assert_allclose(np.median(M), 1) |
| 139 | + |
| 140 | + M = ot.utils.cost_normalization(C, 'max') |
| 141 | + np.testing.assert_allclose(M.max(), 1) |
| 142 | + |
| 143 | + M = ot.utils.cost_normalization(C, 'log') |
| 144 | + np.testing.assert_allclose(M.max(), np.log(1 + C).max()) |
| 145 | + |
| 146 | + M = ot.utils.cost_normalization(C, 'loglog') |
| 147 | + np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max()) |
| 148 | + |
| 149 | + |
| 150 | +def test_check_params(): |
| 151 | + |
| 152 | + res1 = ot.utils.check_params(first='OK', second=20) |
| 153 | + assert res1 is True |
| 154 | + |
| 155 | + res0 = ot.utils.check_params(first='OK', second=None) |
| 156 | + assert res0 is False |
| 157 | + |
| 158 | + |
| 159 | +def test_deprecated_func(): |
| 160 | + |
| 161 | + @ot.utils.deprecated('deprecated text for fun') |
| 162 | + def fun(): |
| 163 | + pass |
| 164 | + |
| 165 | + def fun2(): |
| 166 | + pass |
| 167 | + |
| 168 | + @ot.utils.deprecated('deprecated text for class') |
| 169 | + class Class(): |
| 170 | + pass |
| 171 | + |
| 172 | + if sys.version_info < (3, 5): |
| 173 | + print('Not tested') |
| 174 | + else: |
| 175 | + assert ot.utils._is_deprecated(fun) is True |
| 176 | + |
| 177 | + assert ot.utils._is_deprecated(fun2) is False |
| 178 | + |
| 179 | + |
| 180 | +def test_BaseEstimator(): |
| 181 | + |
| 182 | + class Class(ot.utils.BaseEstimator): |
| 183 | + |
| 184 | + def __init__(self, first='spam', second='eggs'): |
| 185 | + |
| 186 | + self.first = first |
| 187 | + self.second = second |
| 188 | + |
| 189 | + cl = Class() |
| 190 | + |
| 191 | + names = cl._get_param_names() |
| 192 | + assert 'first' in names |
| 193 | + assert 'second' in names |
| 194 | + |
| 195 | + params = cl.get_params() |
| 196 | + assert 'first' in params |
| 197 | + assert 'second' in params |
| 198 | + |
| 199 | + params['first'] = 'spam again' |
| 200 | + cl.set_params(**params) |
| 201 | + |
| 202 | + assert cl.first == 'spam again' |
0 commit comments