@@ -174,14 +174,14 @@ def embedding_fn(features, unused_mode):
174174 """
175175
176176 input_fn = single_example_input_fn (
177- example , input_shape = [1 ], max_neighbors = 1 )
177+ example , input_shape = [1 ], max_neighbors = 0 )
178178 predictions = graph_reg_est .predict (input_fn = input_fn )
179179 predicted_scores = [x ['predictions' ] for x in predictions ]
180180 self .assertAllClose ([[3.0 ]], predicted_scores )
181181
182- def train_and_check_params (self , example , max_neighbors , weight , bias ,
183- expected_grad_from_weight ,
184- expected_grad_from_bias ):
182+ def _train_and_check_params (self , example , max_neighbors , weight , bias ,
183+ expected_grad_from_weight ,
184+ expected_grad_from_bias ):
185185 """Runs training for one step and verifies gradient-based updates."""
186186
187187 def embedding_fn (features , unused_mode ):
@@ -261,7 +261,8 @@ def test_graph_reg_wrapper_one_neighbor_with_training(self):
261261 # which includes the supervised loss as well as the graph loss.
262262 orig_pred = np .dot (x0 , weight ) + bias # [9.0]
263263
264- # Based on the implementation of embedding_fn inside train_and_check_params.
264+ # Based on the implementation of embedding_fn inside
265+ # _train_and_check_params.
265266 x0_embedding = np .dot (x0 , weight )
266267 neighbor0_embedding = np .dot (neighbor0 , weight )
267268
@@ -271,8 +272,8 @@ def test_graph_reg_wrapper_one_neighbor_with_training(self):
271272 neighbor0 ).T # [[2.5], [1.5]]
272273 orig_grad_b = 2 * (orig_pred - y0 ).reshape ((1 ,)) # [2.0]
273274
274- self .train_and_check_params (example , 1 , weight , bias , orig_grad_w ,
275- orig_grad_b )
275+ self ._train_and_check_params (example , 1 , weight , bias , orig_grad_w ,
276+ orig_grad_b )
276277
277278 @test_util .run_v1_only ('Requires tf.get_variable' )
278279 def test_graph_reg_wrapper_two_neighbors_with_training (self ):
@@ -318,7 +319,8 @@ def test_graph_reg_wrapper_two_neighbors_with_training(self):
318319 # which includes the supervised loss as well as the graph loss.
319320 orig_pred = np .dot (x0 , weight ) + bias # [9.0]
320321
321- # Based on the implementation of embedding_fn inside train_and_check_params.
322+ # Based on the implementation of embedding_fn inside
323+ # _train_and_check_params.
322324 x0_embedding = np .dot (x0 , weight )
323325 neighbor0_embedding = np .dot (neighbor0 , weight )
324326 neighbor1_embedding = np .dot (neighbor1 , weight )
@@ -338,8 +340,101 @@ def test_graph_reg_wrapper_two_neighbors_with_training(self):
338340 orig_grad_w = grad_w_supervised_loss + grad_w_graph_loss
339341 orig_grad_b = 2 * (orig_pred - y0 ).reshape ((1 ,)) # [2.0]
340342
341- self .train_and_check_params (example , 2 , weight , bias , orig_grad_w ,
342- orig_grad_b )
343+ self ._train_and_check_params (example , 2 , weight , bias , orig_grad_w ,
344+ orig_grad_b )
345+
346+ def _train_and_check_eval_results (self , train_example , test_example ,
347+ max_neighbors , weight , bias ):
348+ """Verifies evaluation results for the graph-regularized model."""
349+
350+ def embedding_fn (features , unused_mode ):
351+ # Computes y = w*x
352+ with tf .variable_scope (
353+ tf .get_variable_scope (),
354+ reuse = tf .AUTO_REUSE ,
355+ auxiliary_name_scope = False ):
356+ weight_tensor = tf .reshape (
357+ tf .get_variable (
358+ WEIGHT_VARIABLE ,
359+ shape = [2 , 1 ],
360+ partitioner = tf .fixed_size_partitioner (1 )),
361+ shape = [- 1 , 2 ])
362+
363+ x_tensor = tf .reshape (features [FEATURE_NAME ], shape = [- 1 , 2 ])
364+ return tf .reduce_sum (
365+ tf .multiply (weight_tensor , x_tensor ), 1 , keep_dims = True )
366+
367+ def optimizer_fn ():
368+ return tf .train .GradientDescentOptimizer (LEARNING_RATE )
369+
370+ base_est = self .build_linear_regressor (
371+ weight = weight , weight_shape = [2 , 1 ], bias = bias , bias_shape = [1 ])
372+
373+ graph_reg_config = nsl_configs .make_graph_reg_config (
374+ max_neighbors = max_neighbors , multiplier = 1 )
375+ graph_reg_est = nsl_estimator .add_graph_regularization (
376+ base_est , embedding_fn , optimizer_fn , graph_reg_config = graph_reg_config )
377+
378+ train_input_fn = single_example_input_fn (
379+ train_example , input_shape = [2 ], max_neighbors = max_neighbors )
380+ graph_reg_est .train (input_fn = train_input_fn , steps = 1 )
381+
382+ # Evaluating the graph-regularized model should yield the same results
383+ # as evaluating the base model because model parameters are shared.
384+ eval_input_fn = single_example_input_fn (
385+ test_example , input_shape = [2 ], max_neighbors = 0 )
386+ graph_eval_results = graph_reg_est .evaluate (input_fn = eval_input_fn )
387+ base_eval_results = base_est .evaluate (input_fn = eval_input_fn )
388+ self .assertAllClose (base_eval_results , graph_eval_results )
389+
390+ @test_util .run_v1_only ('Requires tf.get_variable' )
391+ def test_graph_reg_model_evaluate (self ):
392+ weight = np .array ([[4.0 ], [- 3.0 ]])
393+ bias = np .array ([0.0 ], dtype = np .float32 )
394+
395+ train_example = """
396+ features {
397+ feature {
398+ key: "x"
399+ value: { float_list { value: [ 2.0, 3.0 ] } }
400+ }
401+ feature {
402+ key: "NL_nbr_0_x"
403+ value: { float_list { value: [ 2.5, 3.0 ] } }
404+ }
405+ feature {
406+ key: "NL_nbr_0_weight"
407+ value: { float_list { value: 1.0 } }
408+ }
409+ feature {
410+ key: "NL_nbr_1_x"
411+ value: { float_list { value: [ 2.0, 2.0 ] } }
412+ }
413+ feature {
414+ key: "NL_nbr_1_weight"
415+ value: { float_list { value: 1.0 } }
416+ }
417+ feature {
418+ key: "y"
419+ value: { float_list { value: 0.0 } }
420+ }
421+ }
422+ """
423+
424+ test_example = """
425+ features {
426+ feature {
427+ key: "x"
428+ value: { float_list { value: [ 4.0, 2.0 ] } }
429+ }
430+ feature {
431+ key: "y"
432+ value: { float_list { value: 4.0 } }
433+ }
434+ }
435+ """
436+ self ._train_and_check_eval_results (
437+ train_example , test_example , max_neighbors = 2 , weight = weight , bias = bias )
343438
344439
345440if __name__ == '__main__' :
0 commit comments