@@ -1667,18 +1667,15 @@ def test_broadcastable_flag_assignment_mixed_otheraxes(self):
16671667 a = self .shared (a_val , shape = (None , None , 1 ))
16681668 b = self .shared (b_val , shape = (1 , None , 1 ))
16691669 c = self .join_op (1 , a , b )
1670- assert c .type .shape [0 ] == 1 and c .type .shape [2 ] == 1
1671- assert c .type .shape [1 ] != 1
1670+ assert c .type .shape == (1 , None , 1 )
16721671
16731672 # Opt can remplace the int by an PyTensor constant
16741673 c = self .join_op (constant (1 ), a , b )
1675- assert c .type .shape [0 ] == 1 and c .type .shape [2 ] == 1
1676- assert c .type .shape [1 ] != 1
1674+ assert c .type .shape == (1 , None , 1 )
16771675
16781676 # In case futur opt insert other useless stuff
16791677 c = self .join_op (cast (constant (1 ), dtype = "int32" ), a , b )
1680- assert c .type .shape [0 ] == 1 and c .type .shape [2 ] == 1
1681- assert c .type .shape [1 ] != 1
1678+ assert c .type .shape == (1 , None , 1 )
16821679
16831680 f = function ([], c , mode = self .mode )
16841681 topo = f .maker .fgraph .toposort ()
@@ -1783,15 +1780,21 @@ def test_broadcastable_flags_many_dims_and_inputs(self):
17831780 c = TensorType (dtype = self .floatX , shape = (1 , None , None , None , None , None ))()
17841781 d = TensorType (dtype = self .floatX , shape = (1 , None , 1 , 1 , None , 1 ))()
17851782 e = TensorType (dtype = self .floatX , shape = (1 , None , 1 , None , None , 1 ))()
1783+
17861784 f = self .join_op (0 , a , b , c , d , e )
17871785 fb = tuple (s == 1 for s in f .type .shape )
1788- assert not fb [0 ] and fb [1 ] and fb [2 ] and fb [3 ] and not fb [4 ] and fb [5 ]
1786+ assert f .type .shape == (5 , 1 , 1 , 1 , None , 1 )
1787+ assert fb == (False , True , True , True , False , True )
1788+
17891789 g = self .join_op (1 , a , b , c , d , e )
17901790 gb = tuple (s == 1 for s in g .type .shape )
1791- assert gb [0 ] and not gb [1 ] and gb [2 ] and gb [3 ] and not gb [4 ] and gb [5 ]
1791+ assert g .type .shape == (1 , None , 1 , 1 , None , 1 )
1792+ assert gb == (True , False , True , True , False , True )
1793+
17921794 h = self .join_op (4 , a , b , c , d , e )
17931795 hb = tuple (s == 1 for s in h .type .shape )
1794- assert hb [0 ] and hb [1 ] and hb [2 ] and hb [3 ] and not hb [4 ] and hb [5 ]
1796+ assert h .type .shape == (1 , 1 , 1 , 1 , None , 1 )
1797+ assert hb == (True , True , True , True , False , True )
17951798
17961799 f = function ([a , b , c , d , e ], f , mode = self .mode )
17971800 topo = f .maker .fgraph .toposort ()
@@ -1903,7 +1906,7 @@ def test_mixed_ndim_error(self):
19031906 rng = np .random .default_rng (seed = utt .fetch_seed ())
19041907 v = self .shared (rng .random (4 ).astype (self .floatX ))
19051908 m = self .shared (rng .random ((4 , 4 )).astype (self .floatX ))
1906- with pytest .raises (TypeError ):
1909+ with pytest .raises (TypeError , match = "same number of dimensions" ):
19071910 self .join_op (0 , v , m )
19081911
19091912 def test_split_0elem (self ):
0 commit comments