@@ -555,13 +555,20 @@ def get_parents(self):
555555 return [self .owner ]
556556 return []
557557
558- def eval (self , inputs_to_values = None ):
559- r"""Evaluate the `Variable`.
558+ def eval (
559+ self ,
560+ inputs_to_values : dict [Union ["Variable" , str ], Any ] | None = None ,
561+ ** kwargs ,
562+ ):
563+ r"""Evaluate the `Variable` given a set of values for its inputs.
560564
561565 Parameters
562566 ----------
563567 inputs_to_values :
564- A dictionary mapping PyTensor `Variable`\s to values.
568+ A dictionary mapping PyTensor `Variable`\s or names to values.
569+ Not needed if variable has no required inputs.
570+ kwargs :
571+ Optional keyword arguments to pass to the underlying `pytensor.function`
565572
566573 Examples
567574 --------
@@ -591,10 +598,7 @@ def eval(self, inputs_to_values=None):
591598 """
592599 from pytensor .compile .function import function
593600
594- if inputs_to_values is None :
595- inputs_to_values = {}
596-
597- def convert_string_keys_to_variables (input_to_values ):
601+ def convert_string_keys_to_variables (inputs_to_values ) -> dict ["Variable" , Any ]:
598602 new_input_to_values = {}
599603 for key , value in inputs_to_values .items ():
600604 if isinstance (key , str ):
@@ -608,19 +612,32 @@ def convert_string_keys_to_variables(input_to_values):
608612 new_input_to_values [key ] = value
609613 return new_input_to_values
610614
611- inputs_to_values = convert_string_keys_to_variables (inputs_to_values )
615+ parsed_inputs_to_values : dict [Variable , Any ] = {}
616+ if inputs_to_values is not None :
617+ parsed_inputs_to_values = convert_string_keys_to_variables (inputs_to_values )
612618
613619 if not hasattr (self , "_fn_cache" ):
614- self ._fn_cache = dict ()
620+ self ._fn_cache : dict = dict ()
615621
616- inputs = tuple (sorted (inputs_to_values .keys (), key = id ))
617- if inputs not in self ._fn_cache :
618- self ._fn_cache [inputs ] = function (inputs , self )
619- args = [inputs_to_values [param ] for param in inputs ]
622+ inputs = tuple (sorted (parsed_inputs_to_values .keys (), key = id ))
623+ cache_key = (inputs , tuple (kwargs .items ()))
624+ try :
625+ fn = self ._fn_cache [cache_key ]
626+ except (KeyError , TypeError ):
627+ fn = None
620628
621- rval = self ._fn_cache [inputs ](* args )
629+ if fn is None :
630+ fn = function (inputs , self , ** kwargs )
631+ try :
632+ self ._fn_cache [cache_key ] = fn
633+ except TypeError as exc :
634+ warnings .warn (
635+ "Keyword arguments could not be used to create a cache key for the underlying variable. "
636+ f"A function will be recompiled on every call with such keyword arguments.\n { exc } "
637+ )
622638
623- return rval
639+ args = [parsed_inputs_to_values [param ] for param in inputs ]
640+ return fn (* args )
624641
625642 def __getstate__ (self ):
626643 d = self .__dict__ .copy ()
0 commit comments