File tree Expand file tree Collapse file tree 2 files changed +34
-0
lines changed
Expand file tree Collapse file tree 2 files changed +34
-0
lines changed Original file line number Diff line number Diff line change @@ -597,6 +597,22 @@ def eval(self, inputs_to_values=None):
597597 if inputs_to_values is None :
598598 inputs_to_values = {}
599599
600+ def convert_string_keys_to_variables (input_to_values ):
601+ new_input_to_values = {}
602+ for key , value in inputs_to_values .items ():
603+ if isinstance (key , str ):
604+ matching_vars = get_var_by_name ([self ], key )
605+ if not matching_vars :
606+ raise Exception (f"{ key } not found in graph" )
607+ elif len (matching_vars ) > 1 :
608+ raise Exception (f"Found multiple variables with name { key } " )
609+ new_input_to_values [matching_vars [0 ]] = value
610+ else :
611+ new_input_to_values [key ] = value
612+ return new_input_to_values
613+
614+ inputs_to_values = convert_string_keys_to_variables (inputs_to_values )
615+
600616 if not hasattr (self , "_fn_cache" ):
601617 self ._fn_cache = dict ()
602618
Original file line number Diff line number Diff line change @@ -302,6 +302,24 @@ def test_eval(self):
302302 pickle .loads (pickle .dumps (self .w )), "_fn_cache"
303303 ), "temporary functions must not be serialized"
304304
305+ def test_eval_with_strings (self ):
306+ assert self .w .eval ({"x" : 1.0 , self .y : 2.0 }) == 6.0
307+ assert self .w .eval ({self .z : 3 }) == 6.0
308+
309+ def test_eval_with_strings_multiple_matches (self ):
310+ e = scalars ("e" )
311+ t = e + 1
312+ t .name = "e"
313+ with pytest .raises (Exception , match = "Found multiple variables with name e" ):
314+ t .eval ({"e" : 1 })
315+
316+ def test_eval_with_strings_no_match (self ):
317+ e = scalars ("e" )
318+ t = e + 1
319+ t .name = "p"
320+ with pytest .raises (Exception , match = "o not found in graph" ):
321+ t .eval ({"o" : 1 })
322+
305323
306324class TestAutoName :
307325 def test_auto_name (self ):
You can’t perform that action at this time.
0 commit comments