Skip to content

Commit a1f09f3

Browse files
committed
add check_random_state in utils
1 parent 90efa5a commit a1f09f3

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

ot/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,26 @@ def check_params(**kwargs):
225225
return check
226226

227227

228+
def check_random_state(seed):
229+
"""Turn seed into a np.random.RandomState instance
230+
Parameters
231+
----------
232+
seed : None | int | instance of RandomState
233+
If seed is None, return the RandomState singleton used by np.random.
234+
If seed is an int, return a new RandomState instance seeded with seed.
235+
If seed is already a RandomState instance, return it.
236+
Otherwise raise ValueError.
237+
"""
238+
if seed is None or seed is np.random:
239+
return np.random.mtrand._rand
240+
if isinstance(seed, (int, np.integer)):
241+
return np.random.RandomState(seed)
242+
if isinstance(seed, np.random.RandomState):
243+
return seed
244+
raise ValueError('{} cannot be used to seed a numpy.random.RandomState'
245+
' instance'.format(seed))
246+
247+
228248
class deprecated(object):
229249

230250
"""Decorator to mark a function or class as deprecated.

0 commit comments

Comments
 (0)