@@ -2217,15 +2217,14 @@ def make_node(self, axis, *tensors):
22172217 # except for the axis dimension.
22182218 # Initialize bcastable all false, and then fill in some trues with
22192219 # the loops.
2220- ndim = tensors [0 ].type .ndim
2221- out_shape = [None ] * ndim
22222220
22232221 if not isinstance (axis , int ):
22242222 try :
22252223 axis = int (get_scalar_constant_value (axis ))
22262224 except NotScalarConstantError :
22272225 pass
22282226
2227+ ndim = tensors [0 ].type .ndim
22292228 if isinstance (axis , int ):
22302229 # Basically, broadcastable -> length 1, but the
22312230 # converse does not hold. So we permit e.g. T/F/T
@@ -2241,30 +2240,55 @@ def make_node(self, axis, *tensors):
22412240 )
22422241 if axis < 0 :
22432242 axis += ndim
2244-
2245- for x in tensors :
2246- for current_axis , s in enumerate (x .type .shape ):
2247- # Constant negative axis can no longer be negative at
2248- # this point. It safe to compare this way.
2249- if current_axis == axis :
2250- continue
2251- if s == 1 :
2252- out_shape [current_axis ] = 1
2253- try :
2254- out_shape [axis ] = None
2255- except IndexError :
2243+ if axis > ndim - 1 :
22562244 raise ValueError (
22572245 f"Axis value { axis } is out of range for the given input dimensions"
22582246 )
2247+ # NOTE: Constant negative axis can no longer be negative at this point.
2248+
2249+ in_shapes = [x .type .shape for x in tensors ]
2250+ in_ndims = [len (s ) for s in in_shapes ]
2251+ if set (in_ndims ) != {ndim }:
2252+ raise TypeError (
2253+ "Only tensors with the same number of dimensions can be joined."
2254+ f" Input ndims were: { in_ndims } ."
2255+ )
2256+
2257+ # Determine output shapes from a matrix of input shapes
2258+ in_shapes = np .array (in_shapes )
2259+ out_shape = [None ] * ndim
2260+ for d in range (ndim ):
2261+ ins = in_shapes [:, d ]
2262+ if d == axis :
2263+ # Any unknown size along the axis means we can't sum
2264+ if None in ins :
2265+ out_shape [d ] = None
2266+ else :
2267+ out_shape [d ] = sum (ins )
2268+ else :
2269+ inset = set (in_shapes [:, d ])
2270+ # Other dims must match exactly,
2271+ # or if a mix of None and ? the output will be ?
2272+ # otherwise the input shapes are incompatible.
2273+ if len (inset ) == 1 :
2274+ (out_shape [d ],) = inset
2275+ elif len (inset - {None }) == 1 :
2276+ (out_shape [d ],) = inset - {None }
2277+ else :
2278+ raise ValueError (
2279+ f"all input array dimensions other than the specified `axis` ({ axis } )"
2280+ " must match exactly, or be unknown (None),"
2281+ f" but along dimension { d } , the inputs shapes are incompatible: { ins } "
2282+ )
22592283 else :
22602284 # When the axis may vary, no dimension can be guaranteed to be
22612285 # broadcastable.
22622286 out_shape = [None ] * tensors [0 ].type .ndim
22632287
2264- if not builtins .all (x .ndim == len (out_shape ) for x in tensors ):
2265- raise TypeError (
2266- "Only tensors with the same number of dimensions can be joined"
2267- )
2288+ if not builtins .all (x .ndim == len (out_shape ) for x in tensors ):
2289+ raise TypeError (
2290+ "Only tensors with the same number of dimensions can be joined"
2291+ )
22682292
22692293 inputs = [as_tensor_variable (axis )] + list (tensors )
22702294
0 commit comments