@@ -37,12 +37,15 @@ ctypedef struct function_info:
3737cdef class patch:
3838 cdef int functions_count
3939 cdef function_info* functions
40+ cdef bint _is_patched
4041
41- functions_dict = {}
42+ functions_dict = dict ()
4243
4344 def __cinit__ (self ):
4445 cdef int pi, oi
4546
47+ self ._is_patched = False
48+
4649 umaths = [i for i in dir (mu) if isinstance (getattr (mu, i), np.ufunc)]
4750 self .functions_count = 0
4851 for umath in umaths:
@@ -95,7 +98,9 @@ cdef class patch:
9598 index = self .functions_dict[func]
9699 function = self .functions[index].patch_function
97100 signature = self .functions[index].signature
98- res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, & temp)
101+ res = cnp.PyUFunc_ReplaceLoopBySignature(< cnp.ufunc> np_umath, function, signature, & temp)
102+
103+ self ._is_patched = True
99104
100105 def do_unpatch (self ):
101106 cdef int res
@@ -110,6 +115,10 @@ cdef class patch:
110115 signature = self .functions[index].signature
111116 res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, & temp)
112117
118+ self ._is_patched = False
119+
120+ def is_patched (self ):
121+ return self ._is_patched
113122
114123from threading import local as threading_local
115124_tls = threading_local()
@@ -123,14 +132,40 @@ def _initialize_tls():
123132 _tls.patch = patch()
124133 _tls.initialized = True
125134
126-
127- def do_patch ():
135+
136+ def use_in_numpy ():
137+ '''
138+ Enables using of mkl_umath in Numpy.
139+ '''
128140 if not _is_tls_initialized():
129141 _initialize_tls()
130142 _tls.patch.do_patch()
131143
132144
133- def do_unpatch ():
145+ def restore ():
146+ '''
147+ Disables using of mkl_umath in Numpy.
148+ '''
134149 if not _is_tls_initialized():
135150 _initialize_tls()
136151 _tls.patch.do_unpatch()
152+
153+
154+ def is_patched ():
155+ '''
156+ Returns whether Numpy has been patched with mkl_umath.
157+ '''
158+ if not _is_tls_initialized():
159+ _initialize_tls()
160+ _tls.patch.is_patched()
161+
162+ from contextlib import ContextDecorator
163+
164+ class mkl_umath (ContextDecorator ):
165+ def __enter__ (self ):
166+ use_in_numpy()
167+ return self
168+
169+ def __exit__ (self , *exc ):
170+ restore()
171+ return False
0 commit comments