 from typing import cast
 
 import numpy as np
-from scipy.optimize import minimize as scipy_minimize
-from scipy.optimize import minimize_scalar as scipy_minimize_scalar
-from scipy.optimize import root as scipy_root
-from scipy.optimize import root_scalar as scipy_root_scalar
 
 import pytensor.scalar as ps
-from pytensor import Variable, function, graph_replace
+from pytensor.compile.function import function
 from pytensor.gradient import grad, hessian, jacobian
 from pytensor.graph import Apply, Constant, FunctionGraph
 from pytensor.graph.basic import ancestors, truncated_graph_inputs
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
+from pytensor.graph.replace import graph_replace
 from pytensor.tensor.basic import (
     atleast_2d,
     concatenate,
 )
 from pytensor.tensor.math import dot
 from pytensor.tensor.slinalg import solve
-from pytensor.tensor.variable import TensorVariable
+from pytensor.tensor.variable import TensorVariable, Variable
+
+
+# scipy.optimize can be slow to import, and will not be used by most users
+# We import scipy.optimize lazily inside optimization perform methods to avoid this.
+optimize = None
 
 
 _log = logging.getLogger(__name__)
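The lazy-import idiom set up above defers the cost of `import scipy.optimize` from module import time to the first execution of an optimization Op. A minimal self-contained sketch of the same pattern (illustrative, not code from this diff):

```python
# Minimal sketch of the lazy-import idiom: a module-level name starts as None
# and is rebound to the real module on first use inside the function.
optimize = None


def first_use(x):
    global optimize
    if optimize is None:
        # Paid once, on the first call, instead of at module import time.
        import scipy.optimize as optimize
    return optimize.minimize_scalar(lambda t: (t - x) ** 2).x
```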
@@ -352,8 +354,6 @@ def implict_optimization_grads(
 
 
 class MinimizeScalarOp(ScipyScalarWrapperOp):
-    __props__ = ("method",)
-
     def __init__(
         self,
         x: Variable,
@@ -377,15 +377,22 @@ def __init__(
         self._fn = None
         self._fn_wrapped = None
 
+    def __str__(self):
+        return f"{self.__class__.__name__}(method={self.method})"
+
     def perform(self, node, inputs, outputs):
+        global optimize
+        if optimize is None:
+            import scipy.optimize as optimize
+
         f = self.fn_wrapped
         f.clear_cache()
 
         # minimize_scalar doesn't take x0 as an argument. The Op still needs this input (to symbolically determine
         # the args of the objective function), but it is not used in the optimization.
         x0, *args = inputs
 
-        res = scipy_minimize_scalar(
+        res = optimize.minimize_scalar(
             fun=f.value,
             args=tuple(args),
             method=self.method,
@@ -426,6 +433,27 @@ def minimize_scalar(
 ):
     """
     Minimize a scalar objective function using scipy.optimize.minimize_scalar.
+
+    Parameters
+    ----------
+    objective : TensorVariable
+        The objective function to minimize. This should be a PyTensor variable representing a scalar value.
+    x : TensorVariable
+        The variable with respect to which the objective function is minimized. It must be a scalar and an
+        input to the computational graph of `objective`.
+    method : str, optional
+        The optimization method to use. Default is "brent". See `scipy.optimize.minimize_scalar` for other options.
+    optimizer_kwargs : dict, optional
+        Additional keyword arguments to pass to `scipy.optimize.minimize_scalar`.
+
+    Returns
+    -------
+    solution : TensorVariable
+        Value of `x` that minimizes `objective(x, *args)`. If the success flag is False, this will be the
+        final state returned by the minimization routine, not necessarily a minimum.
+    success : TensorVariable
+        Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
+        value, based on the requested convergence criteria.
     """
 
     args = _find_optimization_parameters(objective, x)
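For context, a hypothetical usage sketch of the new two-output API (assuming this module is exposed as `pytensor.tensor.optimize`; the variable names are illustrative):

```python
# Hypothetical usage sketch, assuming this module lands as pytensor.tensor.optimize.
import pytensor.tensor as pt
from pytensor import function
from pytensor.tensor.optimize import minimize_scalar

x = pt.scalar("x")
a = pt.scalar("a")
objective = (x - a) ** 2  # minimized at x == a

x_star, success = minimize_scalar(objective, x)
fn = function([x, a], [x_star, success])
# x's runtime value is only a dummy initial state here; expect approx (3.0, True)
print(fn(0.0, 3.0))
```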
@@ -438,12 +466,14 @@ def minimize_scalar(
         optimizer_kwargs=optimizer_kwargs,
     )
 
-    return minimize_scalar_op(x, *args)
+    solution, success = cast(
+        tuple[TensorVariable, TensorVariable], minimize_scalar_op(x, *args)
+    )
 
+    return solution, success
 
-class MinimizeOp(ScipyWrapperOp):
-    __props__ = ("method", "jac", "hess", "hessp")
 
+class MinimizeOp(ScipyWrapperOp):
     def __init__(
         self,
         x: Variable,
@@ -487,11 +517,24 @@ def __init__(
         self._fn = None
         self._fn_wrapped = None
 
+    def __str__(self):
+        str_args = ", ".join(
+            [
+                f"{arg}={getattr(self, arg)}"
+                for arg in ["method", "jac", "hess", "hessp"]
+            ]
+        )
+        return f"{self.__class__.__name__}({str_args})"
+
     def perform(self, node, inputs, outputs):
+        global optimize
+        if optimize is None:
+            import scipy.optimize as optimize
+
         f = self.fn_wrapped
         x0, *args = inputs
 
-        res = scipy_minimize(
+        res = optimize.minimize(
             fun=f.value_and_grad if self.jac else f.value,
             jac=self.jac,
             x0=x0,
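The `fun=f.value_and_grad if self.jac else f.value` branch relies on a standard SciPy convention: with `jac=True`, `scipy.optimize.minimize` treats `fun` as returning the objective value and its gradient in a single call. A standalone sketch of that convention:

```python
# Standalone sketch of scipy's jac=True convention: fun returns (value, gradient).
import numpy as np
from scipy.optimize import minimize


def value_and_grad(x):
    value = np.sum((x - 1.0) ** 2)
    grad = 2.0 * (x - 1.0)
    return value, grad


res = minimize(value_and_grad, x0=np.zeros(3), jac=True, method="L-BFGS-B")
print(res.x)  # approx [1., 1., 1.]
```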
@@ -538,7 +581,7 @@ def minimize(
     jac: bool = True,
     hess: bool = False,
     optimizer_kwargs: dict | None = None,
-):
+) -> tuple[TensorVariable, TensorVariable]:
     """
     Minimize a scalar objective function using scipy.optimize.minimize.
 
@@ -563,9 +606,13 @@ def minimize(
 
     Returns
     -------
-    TensorVariable
-        The optimized value of x that minimizes the objective function.
+    solution : TensorVariable
+        The optimized value of the vector of inputs `x` that minimizes `objective(x, *args)`. If the success flag
+        is False, this will be the final state of the minimization routine, but not necessarily a minimum.
 
+    success : TensorVariable
+        Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
+        value, based on the requested convergence criteria.
     """
     args = _find_optimization_parameters(objective, x)
 
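A hypothetical usage sketch for the vector case (again assuming the `pytensor.tensor.optimize` module path). Because `solution` is symbolic, downstream graphs, including gradients through the optimum, can be built on top of it:

```python
# Hypothetical usage sketch, assuming this module lands as pytensor.tensor.optimize.
import numpy as np
import pytensor.tensor as pt
from pytensor import function
from pytensor.tensor.optimize import minimize

x = pt.vector("x")
A = pt.matrix("A")
b = pt.vector("b")
objective = ((A @ x - b) ** 2).sum()  # least squares in x

x_star, success = minimize(objective, x, method="L-BFGS-B")
fn = function([x, A, b], [x_star, success])
print(fn(np.zeros(2), np.eye(2), np.array([1.0, -2.0])))  # approx [1., -2.], True
```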
@@ -579,12 +626,14 @@ def minimize(
         optimizer_kwargs=optimizer_kwargs,
     )
 
-    return minimize_op(x, *args)
+    solution, success = cast(
+        tuple[TensorVariable, TensorVariable], minimize_op(x, *args)
+    )
+
+    return solution, success
 
 
 class RootScalarOp(ScipyScalarWrapperOp):
-    __props__ = ("method", "jac", "hess")
-
     def __init__(
         self,
         variables,
@@ -633,14 +682,24 @@ def __init__(
         self._fn = None
         self._fn_wrapped = None
 
+    def __str__(self):
+        str_args = ", ".join(
+            [f"{arg}={getattr(self, arg)}" for arg in ["method", "jac", "hess"]]
+        )
+        return f"{self.__class__.__name__}({str_args})"
+
     def perform(self, node, inputs, outputs):
+        global optimize
+        if optimize is None:
+            import scipy.optimize as optimize
+
         f = self.fn_wrapped
         f.clear_cache()
         # f.copy_x = True
 
         variables, *args = inputs
 
-        res = scipy_root_scalar(
+        res = optimize.root_scalar(
             f=f.value,
             fprime=f.grad if self.jac else None,
             fprime2=f.hess if self.hess else None,
@@ -676,19 +735,48 @@ def L_op(self, inputs, outputs, output_grads):
 
 def root_scalar(
     equation: TensorVariable,
-    variables: TensorVariable,
+    variable: TensorVariable,
     method: str = "secant",
     jac: bool = False,
     hess: bool = False,
     optimizer_kwargs: dict | None = None,
-):
+) -> tuple[TensorVariable, TensorVariable]:
     """
     Find roots of a scalar equation using scipy.optimize.root_scalar.
+
+    Parameters
+    ----------
+    equation : TensorVariable
+        The equation for which to find roots. This should be a PyTensor variable representing a single equation in one
+        variable. The function will find `variable` such that `equation(variable, *args) = 0`.
+    variable : TensorVariable
+        The variable with respect to which the equation is solved. It must be a scalar and an input to the
+        computational graph of `equation`.
+    method : str, optional
+        The root-finding method to use. Default is "secant". See `scipy.optimize.root_scalar` for other options.
+    jac : bool, optional
+        Whether to compute and use the first derivative of the equation with respect to `variable`.
+        Default is False. Some methods require this.
+    hess : bool, optional
+        Whether to compute and use the second derivative of the equation with respect to `variable`.
+        Default is False. Some methods require this.
+    optimizer_kwargs : dict, optional
+        Additional keyword arguments to pass to `scipy.optimize.root_scalar`.
+
+    Returns
+    -------
+    solution : TensorVariable
+        The final state of the root-finding routine. When `success` is True, this is the value of `variable` that
+        causes `equation` to evaluate to zero. Otherwise it is the final state returned by the root-finding
+        routine, but not necessarily a root.
+
+    success : TensorVariable
+        Boolean indicating whether the root-finding was successful. If True, the solution is a root of the equation.
     """
-    args = _find_optimization_parameters(equation, variables)
+    args = _find_optimization_parameters(equation, variable)
 
     root_scalar_op = RootScalarOp(
-        variables,
+        variable,
         *args,
         equation=equation,
         method=method,
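A hypothetical usage sketch (same module-path assumption as above); Newton's method needs the first derivative, hence `jac=True`:

```python
# Hypothetical usage sketch, assuming this module lands as pytensor.tensor.optimize.
import pytensor.tensor as pt
from pytensor import function
from pytensor.tensor.optimize import root_scalar

x = pt.scalar("x")
a = pt.scalar("a")
equation = pt.exp(-x) - a * x  # root where exp(-x) == a * x

x_root, success = root_scalar(equation, x, method="newton", jac=True)
fn = function([x, a], [x_root, success])
# With a = 1, the root is the omega constant, approx 0.567
print(fn(0.5, 1.0))
```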
@@ -697,7 +785,11 @@ def root_scalar(
         optimizer_kwargs=optimizer_kwargs,
     )
 
-    return root_scalar_op(variables, *args)
+    solution, success = cast(
+        tuple[TensorVariable, TensorVariable], root_scalar_op(variable, *args)
+    )
+
+    return solution, success
 
 
 class RootOp(ScipyWrapperOp):
@@ -734,6 +826,12 @@ def __init__(
         self._fn = None
         self._fn_wrapped = None
 
+    def __str__(self):
+        str_args = ", ".join(
+            [f"{arg}={getattr(self, arg)}" for arg in ["method", "jac"]]
+        )
+        return f"{self.__class__.__name__}({str_args})"
+
     def build_fn(self):
         outputs = self.inner_outputs
         variables, *args = self.inner_inputs
@@ -761,13 +859,17 @@ def build_fn(self):
         self._fn_wrapped = LRUCache1(fn)
 
     def perform(self, node, inputs, outputs):
+        global optimize
+        if optimize is None:
+            import scipy.optimize as optimize
+
         f = self.fn_wrapped
         f.clear_cache()
         f.copy_x = True
 
         variables, *args = inputs
 
-        res = scipy_root(
+        res = optimize.root(
             fun=f,
             jac=self.jac,
             x0=variables,
@@ -815,8 +917,36 @@ def root(
     method: str = "hybr",
     jac: bool = True,
     optimizer_kwargs: dict | None = None,
-):
-    """Find roots of a system of equations using scipy.optimize.root."""
+) -> tuple[TensorVariable, TensorVariable]:
+    """
+    Find roots of a system of equations using scipy.optimize.root.
+
+    Parameters
+    ----------
+    equations : TensorVariable
+        The system of equations for which to find roots. This should be a PyTensor variable representing a
+        vector (or scalar) value. The function will find `variables` such that `equations(variables, *args) = 0`.
+    variables : TensorVariable
+        The variable(s) with respect to which the system of equations is solved. It must be an input to the
+        computational graph of `equations` and have the same number of dimensions as `equations`.
+    method : str, optional
+        The root-finding method to use. Default is "hybr". See `scipy.optimize.root` for other options.
+    jac : bool, optional
+        Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
+        Default is True. Most methods require this.
+    optimizer_kwargs : dict, optional
+        Additional keyword arguments to pass to `scipy.optimize.root`.
+
+    Returns
+    -------
+    solution : TensorVariable
+        The final state of the root-finding routine. When `success` is True, this is the value of `variables` that
+        causes all `equations` to evaluate to zero. Otherwise it is the final state returned by the root-finding
+        routine, but not necessarily a root.
+
+    success : TensorVariable
+        Boolean indicating whether the root-finding was successful. If True, the solution is a root of the equations.
+    """
 
     args = _find_optimization_parameters(equations, variables)
 
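And a hypothetical usage sketch for the vector case (same module-path assumption); here the runtime value of `v` is used as the initial guess:

```python
# Hypothetical usage sketch, assuming this module lands as pytensor.tensor.optimize.
import numpy as np
import pytensor.tensor as pt
from pytensor import function
from pytensor.tensor.optimize import root

v = pt.vector("v")
# Intersect the line x + 2y = 2 with the unit circle x**2 + y**2 = 1.
equations = pt.stack([v[0] + 2 * v[1] - 2, v[0] ** 2 + v[1] ** 2 - 1])

v_root, success = root(equations, v)
fn = function([v], [v_root, success])
print(fn(np.array([1.0, 1.0])))  # one root of the 2x2 system
```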
@@ -829,7 +959,11 @@ def root(
         optimizer_kwargs=optimizer_kwargs,
     )
 
-    return root_op(variables, *args)
+    solution, success = cast(
+        tuple[TensorVariable, TensorVariable], root_op(variables, *args)
+    )
+
+    return solution, success
 
 
 __all__ = ["minimize_scalar", "minimize", "root_scalar", "root"]