Skip to content

Commit fc9923d

Browse files
committed
add tests for ot.uils
1 parent 5efdf00 commit fc9923d

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

test/test_utils.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import ot
99
import numpy as np
10+
import sys
1011

1112

1213
def test_parmap():
@@ -123,3 +124,79 @@ def test_clean_zeros():
123124

124125
assert len(a) == n - nz
125126
assert len(b) == n - nz2
127+
128+
129+
def test_cost_normalization():
130+
131+
C = np.random.rand(10, 10)
132+
133+
# does nothing
134+
M0 = ot.utils.cost_normalization(C)
135+
np.testing.assert_allclose(C, M0)
136+
137+
M = ot.utils.cost_normalization(C, 'median')
138+
np.testing.assert_allclose(np.median(M), 1)
139+
140+
M = ot.utils.cost_normalization(C, 'max')
141+
np.testing.assert_allclose(M.max(), 1)
142+
143+
M = ot.utils.cost_normalization(C, 'log')
144+
np.testing.assert_allclose(M.max(), np.log(1 + C).max())
145+
146+
M = ot.utils.cost_normalization(C, 'loglog')
147+
np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max())
148+
149+
150+
def test_check_params():
151+
152+
res1 = ot.utils.check_params(first='OK', second=20)
153+
assert res1 is True
154+
155+
res0 = ot.utils.check_params(first='OK', second=None)
156+
assert res0 is False
157+
158+
159+
def test_deprecated_func():
160+
161+
@ot.utils.deprecated('deprecated text for fun')
162+
def fun():
163+
pass
164+
165+
def fun2():
166+
pass
167+
168+
@ot.utils.deprecated('deprecated text for class')
169+
class Class():
170+
pass
171+
172+
if sys.version_info < (3, 5):
173+
print('Not tested')
174+
else:
175+
assert ot.utils._is_deprecated(fun) is True
176+
177+
assert ot.utils._is_deprecated(fun2) is False
178+
179+
180+
def test_BaseEstimator():
181+
182+
class Class(ot.utils.BaseEstimator):
183+
184+
def __init__(self, first='spam', second='eggs'):
185+
186+
self.first = first
187+
self.second = second
188+
189+
cl = Class()
190+
191+
names = cl._get_param_names()
192+
assert 'first' in names
193+
assert 'second' in names
194+
195+
params = cl.get_params()
196+
assert 'first' in params
197+
assert 'second' in params
198+
199+
params['first'] = 'spam again'
200+
cl.set_params(**params)
201+
202+
assert cl.first == 'spam again'

0 commit comments

Comments
 (0)