@@ -233,14 +233,8 @@ def make_node(self, *inputs):
233233
234234class ScipyScalarWrapperOp (ScipyWrapperOp ):
235235 def build_fn (self ):
236- """
237- This is overloaded because scipy converts scalar inputs to lists, changing the return type. The
238- wrapper function logic is there to handle this.
239- """
240-
241- # We have no control over the inputs to the scipy inner function for scalar_minimize. As a result,
242- # we need to adjust the graph to work with what scipy will be passing into the inner function --
243- # always scalar, and always float64
236+ # We need to adjust the graph to work with what scipy will be passing into the inner function --
237+ # always scalar array of float64 type
244238 x , * args = self .inner_inputs
245239 new_root_x = ps .float64 (name = "x_scalar" )
246240 new_x = tensor_from_scalar (new_root_x .astype (x .type .dtype ))
@@ -255,6 +249,24 @@ def build_fn(self):
255249 self ._fn_wrapped = LRUCache1 (fn )
256250
257251
252+ class ScipyVectorWrapperOp (ScipyWrapperOp ):
253+ def build_fn (self ):
254+ # We need to adjust the graph to work with what scipy will be passing into the inner function --
255+ # always a vector array with size of at least 1
256+ x , * args = self .inner_inputs
257+ if x .type .shape != ():
258+ return super ().build_fn ()
259+
260+ new_root_x = x [None ].type ()
261+ new_x = new_root_x .squeeze ()
262+ new_outputs = graph_replace (self .inner_outputs , {x : new_x })
263+ self ._fn = fn = function ([new_root_x , * args ], new_outputs , trust_input = True )
264+
265+ # Do this reassignment to see the compiled graph in the dprint
266+ # self.fgraph = fn.maker.fgraph
267+ self ._fn_wrapped = LRUCache1 (fn )
268+
269+
258270def scalar_implict_optimization_grads (
259271 inner_fx : Variable ,
260272 inner_x : Variable ,
@@ -474,7 +486,7 @@ def minimize_scalar(
474486 return solution , success
475487
476488
477- class MinimizeOp (ScipyWrapperOp ):
489+ class MinimizeOp (ScipyVectorWrapperOp ):
478490 def __init__ (
479491 self ,
480492 x : Variable ,
@@ -808,7 +820,7 @@ def root_scalar(
808820 return solution , success
809821
810822
811- class RootOp (ScipyWrapperOp ):
823+ class RootOp (ScipyVectorWrapperOp ):
812824 __props__ = ("method" , "jac" )
813825
814826 def __init__ (
0 commit comments