File tree Expand file tree Collapse file tree 2 files changed +10
-5
lines changed
Expand file tree Collapse file tree 2 files changed +10
-5
lines changed Original file line number Diff line number Diff line change 44from typing import cast
55
66import numpy as np
7+ from numpy .core .numeric import normalize_axis_tuple # type: ignore
78
89import pytensor
910from pytensor .gradient import DisconnectedType
@@ -994,9 +995,7 @@ def specify_broadcastable(x, *axes):
994995 if not axes :
995996 return x
996997
997- if max (axes ) >= x .type .ndim :
998- raise ValueError ("Trying to specify broadcastable of non-existent dimension" )
999-
998+ axes = normalize_axis_tuple (axes , x .type .ndim )
1000999 shape_info = [1 if i in axes else s for i , s in enumerate (x .type .shape )]
10011000 return specify_shape (x , shape_info )
10021001
Original file line number Diff line number Diff line change @@ -562,16 +562,22 @@ def test_basic(self):
562562 x = matrix ()
563563 assert specify_broadcastable (x , 0 ).type .shape == (1 , None )
564564 assert specify_broadcastable (x , 1 ).type .shape == (None , 1 )
565+ assert specify_broadcastable (x , - 1 ).type .shape == (None , 1 )
565566 assert specify_broadcastable (x , 0 , 1 ).type .shape == (1 , 1 )
566567
567568 x = row ()
568569 assert specify_broadcastable (x , 0 ) is x
569570 assert specify_broadcastable (x , 1 ) is not x
571+ assert specify_broadcastable (x , - 2 ) is x
570572
571573 def test_validation (self ):
572574 x = matrix ()
573- with pytest .raises (ValueError , match = "^Trying to specify broadcastable of*" ):
574- specify_broadcastable (x , 2 )
575+ axis = 2
576+ with pytest .raises (
577+ ValueError ,
578+ match = f"axis { axis } is out of bounds for array of dimension { axis } " ,
579+ ):
580+ specify_broadcastable (x , axis )
575581
576582
577583class TestRopLop (RopLopChecker ):
You can’t perform that action at this time.
0 commit comments