2323# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2424# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525
26+ import pytest
2627import numpy as np
2728import mkl_umath ._ufuncs as mu
2829import numpy .core .umath as nu
@@ -49,11 +50,8 @@ def get_args(args_str):
4950 return tuple (args )
5051
5152umaths = [i for i in dir (mu ) if isinstance (getattr (mu , i ), np .ufunc )]
52-
5353umaths .remove ('arccosh' ) # expects input greater than 1
5454
55- # dictionary with test cases
56- # (umath, types) : args
5755generated_cases = {}
5856for umath in umaths :
5957 mkl_umath = getattr (mu , umath )
@@ -64,29 +62,30 @@ def get_args(args_str):
6462 generated_cases [(umath , type )] = args
6563
6664additional_cases = {
67- ('arccosh' , 'f->f' ) : (np .single (np .random .random_sample () + 1 ),),
68- ('arccosh' , 'd->d' ) : (np .double (np .random .random_sample () + 1 ),),
65+ ('arccosh' , 'f->f' ): (np .single (np .random .random_sample () + 1 ),),
66+ ('arccosh' , 'd->d' ): (np .double (np .random .random_sample () + 1 ),),
6967}
7068
71- test_cases = {}
72- for d in (generated_cases , additional_cases ):
73- test_cases .update (d )
69+ test_cases = {** generated_cases , ** additional_cases }
7470
75- for case in test_cases :
76- umath = case [ 0 ]
77- type = case [ 1 ]
71+ @ pytest . mark . parametrize ( " case" , list ( test_cases . keys ()))
72+ def test_umath ( case ):
73+ umath , type = case
7874 args = test_cases [case ]
7975 mkl_umath = getattr (mu , umath )
8076 np_umath = getattr (nu , umath )
8177 print ('*' * 80 )
82- print (umath , type )
83- print ("args" , args )
78+ print (f"Testing { umath } with type { type } " )
79+ print ("args:" , args )
80+
8481 mkl_res = mkl_umath (* args )
8582 np_res = np_umath (* args )
86- print ("mkl res" , mkl_res )
87- print ("npy res" , np_res )
88-
89- assert np .allclose (mkl_res , np_res )
83+
84+ print ("mkl res:" , mkl_res )
85+ print ("npy res:" , np_res )
86+
87+ assert np .allclose (mkl_res , np_res ), f"Results for { umath } do not match"
9088
91- print ("Test cases count:" , len (test_cases ))
92- print ("All looks good!" )
89+ def test_cases_count ():
90+ print ("Test cases count:" , len (test_cases ))
91+ assert len (test_cases ) > 0 , "No test cases found"
0 commit comments