
Commit 22cda11

Handle upcasting of scalar to vector arrays by scipy vector optimizers
1 parent e126020 commit 22cda11
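
Background: scipy's vector optimizers (`scipy.optimize.minimize` and `scipy.optimize.root`) flatten `x0` to a 1-d float array before calling the objective, so a 0-d starting point reaches the inner function as a length-1 vector. A minimal standalone sketch (not from this commit) demonstrating the upcasting this commit works around:

# Minimal sketch (not from this commit): scipy.optimize.minimize flattens x0
# to a 1-d array, so the objective always receives a vector, even when the
# starting point is 0-d.
import numpy as np
from scipy.optimize import minimize

seen_shapes = []

def objective(x):
    seen_shapes.append(x.shape)  # record the shape scipy actually passes in
    return (x**2).sum()

minimize(objective, x0=np.array(5.0))  # 0-d starting point
print(seen_shapes[0])  # (1,) -- the scalar was upcast to a length-1 vector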

2 files changed: +51 -10 lines

pytensor/tensor/optimize.py

Lines changed: 22 additions & 10 deletions
@@ -233,14 +233,8 @@ def make_node(self, *inputs):
 
 class ScipyScalarWrapperOp(ScipyWrapperOp):
     def build_fn(self):
-        """
-        This is overloaded because scipy converts scalar inputs to lists, changing the return type. The
-        wrapper function logic is there to handle this.
-        """
-
-        # We have no control over the inputs to the scipy inner function for scalar_minimize. As a result,
-        # we need to adjust the graph to work with what scipy will be passing into the inner function --
-        # always scalar, and always float64
+        # We need to adjust the graph to work with what scipy will be passing into the inner function --
+        # always a scalar array of float64 type
         x, *args = self.inner_inputs
         new_root_x = ps.float64(name="x_scalar")
         new_x = tensor_from_scalar(new_root_x.astype(x.type.dtype))
@@ -255,6 +249,24 @@ def build_fn(self):
         self._fn_wrapped = LRUCache1(fn)
 
 
+class ScipyVectorWrapperOp(ScipyWrapperOp):
+    def build_fn(self):
+        # We need to adjust the graph to work with what scipy will be passing into the inner function --
+        # always a vector array with size of at least 1
+        x, *args = self.inner_inputs
+        if x.type.shape != ():
+            return super().build_fn()
+
+        new_root_x = x[None].type()
+        new_x = new_root_x.squeeze()
+        new_outputs = graph_replace(self.inner_outputs, {x: new_x})
+        self._fn = fn = function([new_root_x, *args], new_outputs, trust_input=True)
+
+        # Do this reassignment to see the compiled graph in the dprint
+        # self.fgraph = fn.maker.fgraph
+        self._fn_wrapped = LRUCache1(fn)
+
+
 def scalar_implict_optimization_grads(
     inner_fx: Variable,
     inner_x: Variable,
@@ -474,7 +486,7 @@ def minimize_scalar(
     return solution, success
 
 
-class MinimizeOp(ScipyWrapperOp):
+class MinimizeOp(ScipyVectorWrapperOp):
     def __init__(
         self,
         x: Variable,
@@ -808,7 +820,7 @@ def root_scalar(
     return solution, success
 
 
-class RootOp(ScipyWrapperOp):
+class RootOp(ScipyVectorWrapperOp):
     __props__ = ("method", "jac")
 
     def __init__(
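
The key move in `ScipyVectorWrapperOp.build_fn`: when the variable being optimized is 0-d, the inner graph is recompiled against a new length-1 vector root, with a `squeeze` inserted so the rest of the graph still sees a scalar; the compiled function then matches the 1-d array scipy passes in. A standalone sketch of the same rewrite (hypothetical example graph, assuming float64; not the committed code):

# Standalone sketch of the rewrite build_fn performs (hypothetical example
# graph, assuming float64): swap the 0-d root x for the squeeze of a new
# 1-d root, so the compiled function accepts the vector scipy provides.
import numpy as np
import pytensor.tensor as pt
from pytensor import function
from pytensor.graph.replace import graph_replace

x = pt.dscalar("x")           # original 0-d root
out = x**2                    # stand-in for the inner objective graph

new_root_x = x[None].type()   # fresh variable of the (1,)-shaped type
new_x = new_root_x.squeeze()  # back to 0-d inside the graph
[new_out] = graph_replace([out], {x: new_x})

fn = function([new_root_x], new_out, trust_input=True)
print(fn(np.array([5.0])))    # accepts the 1-d array scipy will pass: 25.0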

tests/tensor/test_optimize.py

Lines changed: 29 additions & 0 deletions
@@ -4,6 +4,8 @@
 import pytensor
 import pytensor.tensor as pt
 from pytensor import config, function
+from pytensor.graph import Apply, Op
+from pytensor.tensor import scalar
 from pytensor.tensor.optimize import minimize, minimize_scalar, root, root_scalar
 from tests import unittest_tools as utt
 
@@ -219,3 +221,30 @@ def root_fn(x, a, b):
     utt.verify_grad(
         root_fn, [x0, a_val, b_val], eps=1e-6 if floatX == "float64" else 1e-3
     )
+
+
+@pytest.mark.parametrize("optimize_op", (minimize, root))
+def test_minimize_0d(optimize_op):
+    # Scipy vector minimizers upcast 0d x to 1d. We need to work around this.
+
+    class AssertScalar(Op):
+        view_map = {0: [0]}
+
+        def make_node(self, x):
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, output_storage):
+            [x] = inputs
+            assert x.ndim == 0
+            output_storage[0][0] = x
+
+        def L_op(self, inputs, outputs, out_grads):
+            return out_grads
+
+    x = scalar("x")
+    x_check = AssertScalar()(x)
+    opt_x, _ = optimize_op(x_check**2, x)
+    opt_x_res = opt_x.eval({x: np.array(5, dtype=x.type.dtype)})
+    np.testing.assert_allclose(
+        opt_x_res, 0, atol=1e-15 if floatX == "float64" else 1e-6
+    )
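
With both `MinimizeOp` and `RootOp` now deriving from `ScipyVectorWrapperOp`, a user-level call on a 0-d variable goes through the squeeze path above. A usage sketch (assuming this commit is applied; the value supplied for `x` at eval time serves as the starting point):

# Usage sketch (assuming this commit is applied): minimize over a 0-d
# variable, which scipy would otherwise upcast to a length-1 vector.
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

x = pt.dscalar("x")
x_star, success = minimize((x - 3.0) ** 2, x)
print(x_star.eval({x: np.array(0.0)}))  # converges to ~3.0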
