From 04330788a1dffcc3e4459bea2f9071352133eda0 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 1 Dec 2025 12:47:32 +0100
Subject: [PATCH 1/2] Numba Blockwise: Fix OpFromGraph as core_op

---
 pytensor/compile/builders.py              |  9 +++++++--
 pytensor/link/numba/dispatch/blockwise.py |  1 -
 pytensor/link/numba/dispatch/elemwise.py  |  1 -
 tests/link/numba/test_compile_ops.py      | 21 ++++++++++++++++++++-
 4 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py
index c76d7db33b..055497a8ff 100644
--- a/pytensor/compile/builders.py
+++ b/pytensor/compile/builders.py
@@ -4,6 +4,7 @@
 from collections.abc import Callable, Sequence
 from copy import copy
 from functools import partial
+from itertools import chain
 from typing import Union, cast
 
 from pytensor.compile.function import function
@@ -47,11 +48,15 @@ def infer_shape(outs, inputs, input_shapes):
         assert len(inp_shp) == inp.type.ndim
 
     shape_feature = ShapeFeature()
-    shape_feature.on_attach(FunctionGraph([], []))
+    fgraph = FunctionGraph([], [], features=[shape_feature])
+    for v in chain.from_iterable(s for s in input_shapes if s is not None):
+        # Import input_shape nodes, as for some graphs ShapeFeature assumes these were seen before
+        if (node := v.owner) is not None:
+            fgraph.import_node(node, import_missing=True)
 
     # Initialize shape_of with the input shapes
     for inp, inp_shp in zip(inputs, input_shapes, strict=True):
-        shape_feature.set_shape(inp, inp_shp)
+        shape_feature.set_shape(inp, inp_shp, override=True)
 
     def local_traverse(out):
         """
diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py
index 17f68045d7..e0b086e89c 100644
--- a/pytensor/link/numba/dispatch/blockwise.py
+++ b/pytensor/link/numba/dispatch/blockwise.py
@@ -36,7 +36,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
     core_op_fn, core_op_key = numba_funcify_and_cache_key(
         core_op,
         node=core_node,
-        parent_node=node,
         **kwargs,
     )
     core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index 9b2c9f514c..ef5fcc7196 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -273,7 +273,6 @@ def numba_funcify_Elemwise(op, node, **kwargs):
     scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key(
         op.scalar_op,
         node=scalar_node,
-        parent_node=node,
         **kwargs,
     )
 
diff --git a/tests/link/numba/test_compile_ops.py b/tests/link/numba/test_compile_ops.py
index 918b4324d3..5e263762ae 100644
--- a/tests/link/numba/test_compile_ops.py
+++ b/tests/link/numba/test_compile_ops.py
@@ -4,9 +4,10 @@
 from pytensor import OpFromGraph, config, function, ifelse
 from pytensor import tensor as pt
 from pytensor.compile import ViewOp
+from pytensor.graph import vectorize_graph
 from pytensor.raise_op import assert_op
 from pytensor.scalar import Add
-from pytensor.tensor import matrix
+from pytensor.tensor import dmatrix, dtensor3, matrix
 from pytensor.tensor.elemwise import Elemwise
 from tests.link.numba.test_basic import compare_numba_and_py
 
@@ -171,6 +172,24 @@ def test_ofg_aliased_outputs():
     np.testing.assert_allclose(res, np.ones((2, 2)))
 
 
+def test_ofg_elemwise_regression():
+    # Regression test for https://github.com/pymc-devs/pytensor/issues/1507
+    x = dmatrix("x", shape=(None, None))
+    z = OpFromGraph(
+        inputs=[x],
+        outputs=[x + 1],
+    )(x)
+
+    x_batched = dtensor3("X_batched", shape=(None, None, None))
+    z_batched = vectorize_graph(z, {x: x_batched})
+    compare_numba_and_py(
+        [x_batched],
+        [z_batched],
+        [np.random.normal(size=(3, 2, 4))],
+        eval_obj_mode=False,
+    )
+
+
 def test_check_and_raise():
     x = pt.vector()
     x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
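
A minimal end-to-end reproduction of the scenario this patch fixes, adapted
from the test_ofg_elemwise_regression test added above. The exact failure
mode before the fix is an assumption based on the touched infer_shape code
and the linked issue #1507, not stated in the commit message:

    import numpy as np
    from pytensor import OpFromGraph, function
    from pytensor.graph import vectorize_graph
    from pytensor.tensor import dmatrix, dtensor3

    # An OpFromGraph core op batched into a Blockwise via vectorize_graph
    x = dmatrix("x", shape=(None, None))
    z = OpFromGraph(inputs=[x], outputs=[x + 1])(x)
    x_batched = dtensor3("x_batched", shape=(None, None, None))
    z_batched = vectorize_graph(z, {x: x_batched})

    # Compiling with the numba backend apparently failed during shape
    # inference of the OpFromGraph core op; after this patch it runs
    fn = function([x_batched], z_batched, mode="NUMBA")
    np.testing.assert_allclose(fn(np.zeros((3, 2, 4))), np.ones((3, 2, 4)))
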
From e187a6ebff57d09a05e62516eebcce0f617124c3 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Sun, 30 Nov 2025 15:51:32 +0100
Subject: [PATCH 2/2] Numba Alloc: Patch so it works inside a Blockwise

---
 pytensor/link/numba/dispatch/tensor_basic.py |  8 +++++---
 tests/link/numba/test_blockwise.py           | 14 ++++++++++++--
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py
index 5d86a525b1..b7529c6d03 100644
--- a/pytensor/link/numba/dispatch/tensor_basic.py
+++ b/pytensor/link/numba/dispatch/tensor_basic.py
@@ -74,13 +74,13 @@ def numba_funcify_Alloc(op, node, **kwargs):
                 f'if val.shape[{-i - 1}] == 1 and scalar_shape[{-i - 1}] != 1: raise ValueError("{Alloc._runtime_broadcast_error_msg}")'
             )
     check_runtime_broadcast_src = indent("\n".join(check_runtime_broadcast), " " * 4)
-
+    dtype = node.inputs[0].type.dtype
     alloc_def_src = f"""
 def alloc(val, {", ".join(shape_var_names)}):
 {shapes_to_items_src}
     scalar_shape = {create_tuple_string(shape_var_item_names)}
 {check_runtime_broadcast_src}
-    res = np.empty(scalar_shape, dtype=val.dtype)
+    res = np.empty(scalar_shape, dtype=np.{dtype})
     res[...] = val
     return res
 """
@@ -88,10 +88,12 @@ def alloc(val, {", ".join(shape_var_names)}):
         alloc_def_src,
         "alloc",
         globals() | {"np": np},
+        write_to_disk=True,
     )
 
+    cache_version = -1
     cache_key = sha256(
-        str((type(op), node.inputs[0].type.broadcastable)).encode()
+        str((type(op), node.inputs[0].type.broadcastable, cache_version)).encode()
     ).hexdigest()
 
     return numba_basic.numba_njit(alloc_fn), cache_key
diff --git a/tests/link/numba/test_blockwise.py b/tests/link/numba/test_blockwise.py
index 702efe6ed9..59df988747 100644
--- a/tests/link/numba/test_blockwise.py
+++ b/tests/link/numba/test_blockwise.py
@@ -2,8 +2,8 @@
 import pytest
 
 from pytensor import function
-from pytensor.tensor import tensor, tensor3
-from pytensor.tensor.basic import ARange
+from pytensor.tensor import lvector, tensor, tensor3
+from pytensor.tensor.basic import Alloc, ARange, constant
 from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
 from pytensor.tensor.nlinalg import SVD, Det
 from pytensor.tensor.slinalg import Cholesky, cholesky
@@ -70,3 +70,13 @@ def test_repeated_args():
     final_node = fn.maker.fgraph.outputs[0].owner
     assert isinstance(final_node.op, BlockwiseWithCoreShape)
     assert final_node.inputs[0] is final_node.inputs[1]
+
+
+def test_blockwise_alloc():
+    val = lvector("val")
+    out = Blockwise(Alloc(), signature="(),(),()->(2,3)")(
+        val, constant(2, dtype="int64"), constant(3, dtype="int64")
+    )
+    assert out.type.ndim == 3
+
+    compare_numba_and_py([val], [out], [np.arange(5)], eval_obj_mode=False)
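
For reference, a minimal sketch of the case patch 2 enables, following the
test_blockwise_alloc test above. It assumes a PyTensor build with numba
installed and that mode="NUMBA" selects the numba backend:

    import numpy as np
    from pytensor import function
    from pytensor.tensor import lvector
    from pytensor.tensor.basic import Alloc, constant
    from pytensor.tensor.blockwise import Blockwise

    # Alloc as the core op of a Blockwise: each scalar in `val` is
    # broadcast into its own (2, 3) core output
    val = lvector("val")
    out = Blockwise(Alloc(), signature="(),(),()->(2,3)")(
        val, constant(2, dtype="int64"), constant(3, dtype="int64")
    )
    fn = function([val], out, mode="NUMBA")
    res = fn(np.arange(5))
    assert res.shape == (5, 2, 3)
    assert (res[3] == 3).all()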