@@ -130,10 +130,10 @@ def _node_as_neighbor(example, neighbor_id, edge_weight):
130130 return result
131131
132132
133- def _write_training_examples ( training_examples_file ):
134- """Writes training examples to the specified file."""
135- with tf .io .TFRecordWriter (training_examples_file ) as writer :
136- for example in [ _example_a (), _example_b (), _example_c ()] :
133+ def _write_examples ( examples_file , examples ):
134+ """Writes the given ` examples` to the TFRecord file named `examples_file` ."""
135+ with tf .io .TFRecordWriter (examples_file ) as writer :
136+ for example in examples :
137137 writer .write (example .SerializeToString ())
138138
139139
@@ -168,52 +168,29 @@ def _augmented_a_undirected_two_nbrs():
168168 return _augmented_a_directed_two_nbrs ()
169169
170170
171- def _augmented_b_directed_one_nbr ():
172- """Returns an augmented `tf.train.Example` instance for node B."""
173- augmented_b = _example_b ()
174- augmented_b .MergeFrom (_node_as_neighbor (_example_c (), 0 , 1.0 ))
175- augmented_b .MergeFrom (_num_neighbors_example (1 ))
176- return augmented_b
177-
178-
179- def _augmented_b_directed_two_nbrs ():
180- """Returns an augmented `tf.train.Example` instance for node B."""
181- augmented_b = _example_b ()
182- augmented_b .MergeFrom (_node_as_neighbor (_example_c (), 0 , 1.0 ))
183- augmented_b .MergeFrom (_node_as_neighbor (_example_a (), 1 , 0.4 ))
184- augmented_b .MergeFrom (_num_neighbors_example (2 ))
185- return augmented_b
186-
187-
188- def _augmented_b_undirected_one_nbr ():
189- """Returns an augmented `tf.train.Example` instance for node B."""
190- return _augmented_b_directed_one_nbr ()
191-
192-
193- def _augmented_b_undirected_two_nbrs ():
194- """Returns an augmented `tf.train.Example` instance for node B."""
195- augmented_b = _example_b ()
196- augmented_b .MergeFrom (_node_as_neighbor (_example_c (), 0 , 1.0 ))
197- augmented_b .MergeFrom (_node_as_neighbor (_example_a (), 1 , 0.5 ))
198- augmented_b .MergeFrom (_num_neighbors_example (2 ))
199- return augmented_b
200-
201-
202171def _augmented_c_directed ():
203172 """Returns an augmented `tf.train.Example` instance for node C."""
204173 augmented_c = _example_c ()
205174 augmented_c .MergeFrom (_num_neighbors_example (0 ))
206175 return augmented_c
207176
208177
209- def _augmented_c_undirected_one_nbr ():
210- """Returns an augmented `tf.train.Example` instance for node C."""
178+ def _augmented_c_undirected_one_nbr_b ():
179+ """Returns an augmented `tf.train.Example` instance for node C with nbr B ."""
211180 augmented_c = _example_c ()
212181 augmented_c .MergeFrom (_node_as_neighbor (_example_b (), 0 , 1.0 ))
213182 augmented_c .MergeFrom (_num_neighbors_example (1 ))
214183 return augmented_c
215184
216185
186+ def _augmented_c_undirected_one_nbr_a ():
187+ """Returns an augmented `tf.train.Example` instance for node C with nbr A."""
188+ augmented_c = _example_c ()
189+ augmented_c .MergeFrom (_node_as_neighbor (_example_a (), 0 , 0.9 ))
190+ augmented_c .MergeFrom (_num_neighbors_example (1 ))
191+ return augmented_c
192+
193+
217194def _augmented_c_undirected_two_nbrs ():
218195 """Returns an augmented `tf.train.Example` instance for node C."""
219196 augmented_c = _example_c ()
@@ -227,31 +204,75 @@ class PackNbrsTest(absltest.TestCase):
227204
228205 def setUp (self ):
229206 super (PackNbrsTest , self ).setUp ()
230- self ._graph_path = self ._create_graph_file ()
207+ # Write graph edges (as a TSV file).
208+ self ._graph_path = self ._create_tmp_file ('graph.tsv' )
231209 graph_utils .write_tsv_graph (self ._graph_path , _GRAPH )
232- self ._training_examples_path = self ._create_training_examples_file ()
233- _write_training_examples (self ._training_examples_path )
234- self ._output_nsl_training_data_path = self ._create_nsl_training_data_file ()
235-
236- def _create_training_examples_file (self ):
237- return self .create_tempfile ('train_data.tfr' ).full_path
210+ # Write labeled training Examples.
211+ self ._training_examples_path = self ._create_tmp_file ('train_data.tfr' )
212+ _write_examples (self ._training_examples_path , [_example_a (), _example_c ()])
213+ # Write unlabeled neighbor Examples.
214+ self ._neighbor_examples_path = self ._create_tmp_file ('neighbor_data.tfr' )
215+ _write_examples (self ._neighbor_examples_path , [_example_b ()])
216+ # Create output file
217+ self ._output_nsl_training_data_path = self ._create_tmp_file (
218+ 'nsl_train_data.tfr' )
219+
220+ def _create_tmp_file (self , filename ):
221+ return self .create_tempfile (filename ).full_path
222+
223+ def testDirectedGraphUnlimitedNbrsNoNeighborExamples (self ):
224+ """Tests pack_nbrs() with an empty second argument (neighbor examples).
225+
226+ In this case, the edge A-->B is dangling because there will be no Example
227+ named "B" in the input.
228+ """
229+ input_maker_lib .pack_nbrs (
230+ self ._training_examples_path ,
231+ '' ,
232+ self ._graph_path ,
233+ self ._output_nsl_training_data_path ,
234+ add_undirected_edges = False )
235+ expected_nsl_train_data = {
236+ # Node A has only one neighbor, namely C.
237+ 'A' : _augmented_a_directed_one_nbr (),
238+ # C has no neighbors in the directed case.
239+ 'C' : _augmented_c_directed ()
240+ }
241+ actual_nsl_train_data = _read_tfrecord_examples (
242+ self ._output_nsl_training_data_path )
243+ self .assertDictEqual (actual_nsl_train_data , expected_nsl_train_data )
238244
239- def _create_nsl_training_data_file (self ):
240- return self . create_tempfile ( 'nsl_train_data.tfr' ). full_path
245+ def testUndirectedGraphUnlimitedNbrsNoNeighborExamples (self ):
246+ """Tests pack_nbrs() with an empty second argument (neighbor examples).
241247
242- def _create_graph_file (self ):
243- return self .create_tempfile ('graph.tsv' ).full_path
248+ In this case, the edge A-->B is dangling because there will be no Example
249+ named "B" in the input.
250+ """
251+ input_maker_lib .pack_nbrs (
252+ self ._training_examples_path ,
253+ '' ,
254+ self ._graph_path ,
255+ self ._output_nsl_training_data_path ,
256+ add_undirected_edges = True )
257+ expected_nsl_train_data = {
258+ # Node A has only one neighbor, namely C.
259+ 'A' : _augmented_a_directed_one_nbr (),
260+ # C's only neighbor in the undirected case is A.
261+ 'C' : _augmented_c_undirected_one_nbr_a ()
262+ }
263+ actual_nsl_train_data = _read_tfrecord_examples (
264+ self ._output_nsl_training_data_path )
265+ self .assertDictEqual (actual_nsl_train_data , expected_nsl_train_data )
244266
245267 def testDirectedGraphUnlimitedNbrs (self ):
246268 input_maker_lib .pack_nbrs (
247269 self ._training_examples_path ,
248- '' ,
270+ self . _neighbor_examples_path ,
249271 self ._graph_path ,
250272 self ._output_nsl_training_data_path ,
251273 add_undirected_edges = False )
252274 expected_nsl_train_data = {
253275 'A' : _augmented_a_directed_two_nbrs (),
254- 'B' : _augmented_b_directed_two_nbrs (),
255276 'C' : _augmented_c_directed ()
256277 }
257278 actual_nsl_train_data = _read_tfrecord_examples (
@@ -261,14 +282,13 @@ def testDirectedGraphUnlimitedNbrs(self):
261282 def testDirectedGraphLimitedNbrs (self ):
262283 input_maker_lib .pack_nbrs (
263284 self ._training_examples_path ,
264- '' ,
285+ self . _neighbor_examples_path ,
265286 self ._graph_path ,
266287 self ._output_nsl_training_data_path ,
267288 add_undirected_edges = False ,
268289 max_nbrs = 1 )
269290 expected_nsl_train_data = {
270291 'A' : _augmented_a_directed_one_nbr (),
271- 'B' : _augmented_b_directed_one_nbr (),
272292 'C' : _augmented_c_directed ()
273293 }
274294 actual_nsl_train_data = _read_tfrecord_examples (
@@ -278,13 +298,12 @@ def testDirectedGraphLimitedNbrs(self):
278298 def testUndirectedGraphUnlimitedNbrs (self ):
279299 input_maker_lib .pack_nbrs (
280300 self ._training_examples_path ,
281- '' ,
301+ self . _neighbor_examples_path ,
282302 self ._graph_path ,
283303 self ._output_nsl_training_data_path ,
284304 add_undirected_edges = True )
285305 expected_nsl_train_data = {
286306 'A' : _augmented_a_undirected_two_nbrs (),
287- 'B' : _augmented_b_undirected_two_nbrs (),
288307 'C' : _augmented_c_undirected_two_nbrs ()
289308 }
290309 actual_nsl_train_data = _read_tfrecord_examples (
@@ -294,15 +313,14 @@ def testUndirectedGraphUnlimitedNbrs(self):
294313 def testUndirectedGraphLimitedNbrs (self ):
295314 input_maker_lib .pack_nbrs (
296315 self ._training_examples_path ,
297- '' ,
316+ self . _neighbor_examples_path ,
298317 self ._graph_path ,
299318 self ._output_nsl_training_data_path ,
300319 add_undirected_edges = True ,
301320 max_nbrs = 1 )
302321 expected_nsl_train_data = {
303322 'A' : _augmented_a_undirected_one_nbr (),
304- 'B' : _augmented_b_undirected_one_nbr (),
305- 'C' : _augmented_c_undirected_one_nbr ()
323+ 'C' : _augmented_c_undirected_one_nbr_b ()
306324 }
307325 actual_nsl_train_data = _read_tfrecord_examples (
308326 self ._output_nsl_training_data_path )
0 commit comments