@@ -25,34 +25,80 @@ import numpy as np
2525
2626from libc.stdlib cimport malloc, free
2727
28- cimport cpython.pycapsule
29-
3028cnp.import_umath()
3129
30+
# One record per (ufunc, type signature) pair handled by the patch class.
ctypedef struct function_info:
    # Inner loop read from the NumPy ufunc when the record is filled in.
    cnp.PyUFuncGenericFunction original_function
    # Inner loop read from the matching ufunc of the replacement module.
    cnp.PyUFuncGenericFunction patch_function
    # malloc'ed array of `nargs` type numbers identifying the loop.
    int* signature
3635
37- ctypedef struct functions_struct:
38- int count
39- function_info* functions
40-
41-
42- cdef const char * capsule_name = " functions_cache"
4336
cdef class patch:
    """Table of NumPy ufunc inner loops and their replacement loops.

    On construction, every ufunc found in ``mu`` is paired with the ufunc of
    the same name in ``nu``.  For each type signature the replacement ufunc
    supports, the current NumPy loop, the replacement loop and the integer
    type signature are stored in a C array.  ``do_patch`` installs the
    replacement loops into the NumPy ufuncs and ``do_unpatch`` reinstalls
    the loops that were saved at construction time.

    Attributes
    ----------
    functions_dict : dict
        Maps ``(ufunc_name, type_string)`` to the index of the matching
        record in the internal C array.  Owned by the instance (read-only
        from Python) because the indices refer to this instance's array.

    NOTE(review): ``original_function`` is whatever loop NumPy has installed
    at the moment the instance is built.  If an instance is created after
    loops were already replaced, the saved "original" would be the
    replacement loop -- confirm instances are only created beforehand.
    """
    cdef int functions_count
    cdef function_info* functions
    cdef readonly dict functions_dict

    def __cinit__(self):
        cdef int total = 0
        cdef int func_number = 0
        cdef int nargs
        cdef int i
        cdef Py_ssize_t np_index
        cdef Py_ssize_t mkl_index

        self.functions_dict = {}
        self.functions = NULL
        self.functions_count = 0

        umaths = [name for name in dir(mu)
                  if isinstance(getattr(mu, name), np.ufunc)]
        for umath in umaths:
            total += getattr(mu, umath).ntypes

        if total == 0:
            # Nothing to record; avoid malloc(0), which may return NULL.
            return

        self.functions = <function_info*> malloc(total * sizeof(function_info))
        if self.functions == NULL:
            raise MemoryError("cannot allocate ufunc loop table")
        # Null every signature pointer up front so that __dealloc__ can run
        # safely even if construction fails part-way through the loops below.
        for i in range(total):
            self.functions[i].signature = NULL
        self.functions_count = total

        for umath in umaths:
            mkl_umath = getattr(mu, umath)
            np_umath = getattr(nu, umath)
            c_mkl_umath = <cnp.ufunc> mkl_umath
            c_np_umath = <cnp.ufunc> np_umath
            nargs = c_mkl_umath.nargs
            for type_sig in mkl_umath.types:
                # Raises ValueError if NumPy has no loop for this signature.
                np_index = np_umath.types.index(type_sig)
                mkl_index = mkl_umath.types.index(type_sig)

                self.functions[func_number].original_function = \
                    c_np_umath.functions[np_index]
                self.functions[func_number].patch_function = \
                    c_mkl_umath.functions[mkl_index]

                self.functions[func_number].signature = \
                    <int*> malloc(nargs * sizeof(int))
                if self.functions[func_number].signature == NULL:
                    raise MemoryError("cannot allocate ufunc type signature")
                for i in range(nargs):
                    self.functions[func_number].signature[i] = \
                        c_mkl_umath.types[mkl_index * nargs + i]

                self.functions_dict[(umath, type_sig)] = func_number
                func_number += 1

    def __dealloc__(self):
        cdef int i

        if self.functions == NULL:
            return
        for i in range(self.functions_count):
            # free(NULL) is a no-op, so unfilled records are harmless.
            free(self.functions[i].signature)
        free(self.functions)
        self.functions = NULL
        self.functions_count = 0

    cdef int _replace_loops(self, bint use_patch) except -1:
        """Install the stored patch or original loop for every record."""
        cdef cnp.PyUFuncGenericFunction previous
        cdef cnp.PyUFuncGenericFunction function
        cdef int* signature
        cdef int index

        for key, value in self.functions_dict.items():
            np_umath = getattr(nu, key[0])
            index = value
            if use_patch:
                function = self.functions[index].patch_function
            else:
                function = self.functions[index].original_function
            signature = self.functions[index].signature
            # The status result is deliberately not checked (best effort, as
            # before); every signature was located in the NumPy ufunc's type
            # list when this table was built.
            cnp.PyUFunc_ReplaceLoopBySignature(
                np_umath, function, signature, &previous)
        return 0

    def do_patch(self):
        """Install the replacement loops into the NumPy ufuncs."""
        self._replace_loops(True)

    def do_unpatch(self):
        """Reinstall the loops that were saved when this object was built."""
        self._replace_loops(False)
56102
57103
58104from threading import local as threading_local
@@ -64,103 +110,17 @@ def _is_tls_initialized():
64110
65111
def _initialize_tls():
    """Store a fresh ``patch`` object on ``_tls`` and mark it initialized."""
    patcher = patch()
    _tls.patch = patcher
    _tls.initialized = True
105115
106-
107- def _get_func_dict ():
116+
def do_patch():
    """Call ``do_patch()`` on the ``patch`` object held in ``_tls``.

    The ``_tls`` storage is initialized first if it has not been yet.
    """
    already_initialized = _is_tls_initialized()
    if not already_initialized:
        _initialize_tls()
    patcher = _tls.patch
    patcher.do_patch()
111121
112122
113- cdef function_info* _get_functions():
114- cdef function_info* functions
115- cdef functions_struct* fs
116-
def do_unpatch():
    """Call ``do_unpatch()`` on the ``patch`` object held in ``_tls``.

    The ``_tls`` storage is initialized first if it has not been yet.
    """
    already_initialized = _is_tls_initialized()
    if not already_initialized:
        _initialize_tls()
    patcher = _tls.patch
    patcher.do_unpatch()
0 commit comments