@@ -256,6 +256,26 @@ def _obj_is_wrappable_as_tensor(x):
256256
257257
258258def get_scalar_constant_value (
259+ v , elemwise = True , only_process_constants = False , max_recur = 10
260+ ):
261+ """
262+ Checks whether 'v' is a scalar (ndim = 0).
263+
264+ If 'v' is a scalar then this function fetches the underlying constant by calling
265+ 'get_underlying_scalar_constant_value()'.
266+
267+ If 'v' is not a scalar, it raises a NotScalarConstantError.
268+
269+ """
270+ if isinstance (v , (Variable , np .ndarray )):
271+ if v .ndim != 0 :
272+ raise NotScalarConstantError ()
273+ return get_underlying_scalar_constant_value (
274+ v , elemwise , only_process_constants , max_recur
275+ )
276+
277+
278+ def get_underlying_scalar_constant_value (
259279 orig_v , elemwise = True , only_process_constants = False , max_recur = 10
260280):
261281 """Return the constant scalar(0-D) value underlying variable `v`.
@@ -358,7 +378,7 @@ def get_scalar_constant_value(
358378 elif isinstance (v .owner .op , CheckAndRaise ):
359379 # check if all conditions are constant and true
360380 conds = [
361- get_scalar_constant_value (c , max_recur = max_recur )
381+ get_underlying_scalar_constant_value (c , max_recur = max_recur )
362382 for c in v .owner .inputs [1 :]
363383 ]
364384 if builtins .all (0 == c .ndim and c != 0 for c in conds ):
@@ -372,7 +392,7 @@ def get_scalar_constant_value(
372392 continue
373393 if isinstance (v .owner .op , _scalar_constant_value_elemwise_ops ):
374394 const = [
375- get_scalar_constant_value (i , max_recur = max_recur )
395+ get_underlying_scalar_constant_value (i , max_recur = max_recur )
376396 for i in v .owner .inputs
377397 ]
378398 ret = [[None ]]
@@ -391,7 +411,7 @@ def get_scalar_constant_value(
391411 v .owner .op .scalar_op , _scalar_constant_value_elemwise_ops
392412 ):
393413 const = [
394- get_scalar_constant_value (i , max_recur = max_recur )
414+ get_underlying_scalar_constant_value (i , max_recur = max_recur )
395415 for i in v .owner .inputs
396416 ]
397417 ret = [[None ]]
@@ -437,7 +457,7 @@ def get_scalar_constant_value(
437457 ):
438458 idx = v .owner .op .idx_list [0 ]
439459 if isinstance (idx , Type ):
440- idx = get_scalar_constant_value (
460+ idx = get_underlying_scalar_constant_value (
441461 v .owner .inputs [1 ], max_recur = max_recur
442462 )
443463 try :
@@ -471,14 +491,14 @@ def get_scalar_constant_value(
471491 ):
472492 idx = v .owner .op .idx_list [0 ]
473493 if isinstance (idx , Type ):
474- idx = get_scalar_constant_value (
494+ idx = get_underlying_scalar_constant_value (
475495 v .owner .inputs [1 ], max_recur = max_recur
476496 )
477497 # Python 2.4 does not support indexing with numpy.integer
478498 # So we cast it.
479499 idx = int (idx )
480500 ret = v .owner .inputs [0 ].owner .inputs [idx ]
481- ret = get_scalar_constant_value (ret , max_recur = max_recur )
501+ ret = get_underlying_scalar_constant_value (ret , max_recur = max_recur )
482502 # MakeVector can cast implicitly its input in some case.
483503 return _asarray (ret , dtype = v .type .dtype )
484504
@@ -493,7 +513,7 @@ def get_scalar_constant_value(
493513 idx_list = op .idx_list
494514 idx = idx_list [0 ]
495515 if isinstance (idx , Type ):
496- idx = get_scalar_constant_value (
516+ idx = get_underlying_scalar_constant_value (
497517 owner .inputs [1 ], max_recur = max_recur
498518 )
499519 grandparent = leftmost_parent .owner .inputs [0 ]
@@ -508,7 +528,7 @@ def get_scalar_constant_value(
508528
509529 if not (idx < ndim ):
510530 msg = (
511- "get_scalar_constant_value detected "
531+ "get_underlying_scalar_constant_value detected "
512532 f"deterministic IndexError: x.shape[{ int (idx )} ] "
513533 f"when x.ndim={ int (ndim )} ."
514534 )
@@ -1570,7 +1590,7 @@ def do_constant_folding(self, fgraph, node):
15701590@_get_vector_length .register (Alloc )
15711591def _get_vector_length_Alloc (var_inst , var ):
15721592 try :
1573- return get_scalar_constant_value (var .owner .inputs [1 ])
1593+ return get_underlying_scalar_constant_value (var .owner .inputs [1 ])
15741594 except NotScalarConstantError :
15751595 raise ValueError (f"Length of { var } cannot be determined" )
15761596
@@ -1821,17 +1841,17 @@ def perform(self, node, inp, out_):
18211841
18221842def extract_constant (x , elemwise = True , only_process_constants = False ):
18231843 """
1824- This function is basically a call to tensor.get_scalar_constant_value .
1844+ This function is basically a call to tensor.get_underlying_scalar_constant_value .
18251845
18261846 The main difference is the behaviour in case of failure. While
1827- get_scalar_constant_value raises an TypeError, this function returns x,
1847+ get_underlying_scalar_constant_value raises an TypeError, this function returns x,
18281848 as a tensor if possible. If x is a ScalarVariable from a
18291849 scalar_from_tensor, we remove the conversion. If x is just a
18301850 ScalarVariable, we convert it to a tensor with tensor_from_scalar.
18311851
18321852 """
18331853 try :
1834- x = get_scalar_constant_value (x , elemwise , only_process_constants )
1854+ x = get_underlying_scalar_constant_value (x , elemwise , only_process_constants )
18351855 except NotScalarConstantError :
18361856 pass
18371857 if isinstance (x , aes .ScalarVariable ) or isinstance (
@@ -2201,7 +2221,7 @@ def make_node(self, axis, *tensors):
22012221
22022222 if not isinstance (axis , int ):
22032223 try :
2204- axis = int (get_scalar_constant_value (axis ))
2224+ axis = int (get_underlying_scalar_constant_value (axis ))
22052225 except NotScalarConstantError :
22062226 pass
22072227
@@ -2450,7 +2470,7 @@ def infer_shape(self, fgraph, node, ishapes):
24502470def _get_vector_length_Join (op , var ):
24512471 axis , * arrays = var .owner .inputs
24522472 try :
2453- axis = get_scalar_constant_value (axis )
2473+ axis = get_underlying_scalar_constant_value (axis )
24542474 assert axis == 0 and builtins .all (a .ndim == 1 for a in arrays )
24552475 return builtins .sum (get_vector_length (a ) for a in arrays )
24562476 except NotScalarConstantError :
@@ -2862,7 +2882,7 @@ def infer_shape(self, fgraph, node, i_shapes):
28622882
28632883 def is_constant_value (var , value ):
28642884 try :
2865- v = get_scalar_constant_value (var )
2885+ v = get_underlying_scalar_constant_value (var )
28662886 return np .all (v == value )
28672887 except NotScalarConstantError :
28682888 pass
@@ -3774,7 +3794,7 @@ def make_node(self, a, choices):
37743794 static_out_shape = ()
37753795 for s in out_shape :
37763796 try :
3777- s_val = pytensor .get_scalar_constant_value (s )
3797+ s_val = pytensor .get_underlying_scalar_constant (s )
37783798 except (NotScalarConstantError , AttributeError ):
37793799 s_val = None
37803800
@@ -4095,6 +4115,7 @@ def take_along_axis(arr, indices, axis=0):
40954115 "scalar_from_tensor" ,
40964116 "tensor_from_scalar" ,
40974117 "get_scalar_constant_value" ,
4118+ "get_underlying_scalar_constant_value" ,
40984119 "constant" ,
40994120 "as_tensor_variable" ,
41004121 "as_tensor" ,
0 commit comments