Skip to content

Commit aad3635

Browse files
Neural-Link Teamtensorflow-copybara
authored andcommitted
Strengthens the input_maker_lib_test.py tests.
PiperOrigin-RevId: 272992405
1 parent b2bf5c1 commit aad3635

File tree

1 file changed

+75
-57
lines changed

1 file changed

+75
-57
lines changed

neural_structured_learning/tools/input_maker_lib_test.py

Lines changed: 75 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ def _node_as_neighbor(example, neighbor_id, edge_weight):
130130
return result
131131

132132

133-
def _write_training_examples(training_examples_file):
134-
"""Writes training examples to the specified file."""
135-
with tf.io.TFRecordWriter(training_examples_file) as writer:
136-
for example in [_example_a(), _example_b(), _example_c()]:
133+
def _write_examples(examples_file, examples):
134+
"""Writes the given `examples` to the TFRecord file named `examples_file`."""
135+
with tf.io.TFRecordWriter(examples_file) as writer:
136+
for example in examples:
137137
writer.write(example.SerializeToString())
138138

139139

@@ -168,52 +168,29 @@ def _augmented_a_undirected_two_nbrs():
168168
return _augmented_a_directed_two_nbrs()
169169

170170

171-
def _augmented_b_directed_one_nbr():
172-
"""Returns an augmented `tf.train.Example` instance for node B."""
173-
augmented_b = _example_b()
174-
augmented_b.MergeFrom(_node_as_neighbor(_example_c(), 0, 1.0))
175-
augmented_b.MergeFrom(_num_neighbors_example(1))
176-
return augmented_b
177-
178-
179-
def _augmented_b_directed_two_nbrs():
180-
"""Returns an augmented `tf.train.Example` instance for node B."""
181-
augmented_b = _example_b()
182-
augmented_b.MergeFrom(_node_as_neighbor(_example_c(), 0, 1.0))
183-
augmented_b.MergeFrom(_node_as_neighbor(_example_a(), 1, 0.4))
184-
augmented_b.MergeFrom(_num_neighbors_example(2))
185-
return augmented_b
186-
187-
188-
def _augmented_b_undirected_one_nbr():
189-
"""Returns an augmented `tf.train.Example` instance for node B."""
190-
return _augmented_b_directed_one_nbr()
191-
192-
193-
def _augmented_b_undirected_two_nbrs():
194-
"""Returns an augmented `tf.train.Example` instance for node B."""
195-
augmented_b = _example_b()
196-
augmented_b.MergeFrom(_node_as_neighbor(_example_c(), 0, 1.0))
197-
augmented_b.MergeFrom(_node_as_neighbor(_example_a(), 1, 0.5))
198-
augmented_b.MergeFrom(_num_neighbors_example(2))
199-
return augmented_b
200-
201-
202171
def _augmented_c_directed():
203172
"""Returns an augmented `tf.train.Example` instance for node C."""
204173
augmented_c = _example_c()
205174
augmented_c.MergeFrom(_num_neighbors_example(0))
206175
return augmented_c
207176

208177

209-
def _augmented_c_undirected_one_nbr():
210-
"""Returns an augmented `tf.train.Example` instance for node C."""
178+
def _augmented_c_undirected_one_nbr_b():
179+
"""Returns an augmented `tf.train.Example` instance for node C with nbr B."""
211180
augmented_c = _example_c()
212181
augmented_c.MergeFrom(_node_as_neighbor(_example_b(), 0, 1.0))
213182
augmented_c.MergeFrom(_num_neighbors_example(1))
214183
return augmented_c
215184

216185

186+
def _augmented_c_undirected_one_nbr_a():
187+
"""Returns an augmented `tf.train.Example` instance for node C with nbr A."""
188+
augmented_c = _example_c()
189+
augmented_c.MergeFrom(_node_as_neighbor(_example_a(), 0, 0.9))
190+
augmented_c.MergeFrom(_num_neighbors_example(1))
191+
return augmented_c
192+
193+
217194
def _augmented_c_undirected_two_nbrs():
218195
"""Returns an augmented `tf.train.Example` instance for node C."""
219196
augmented_c = _example_c()
@@ -227,31 +204,75 @@ class PackNbrsTest(absltest.TestCase):
227204

228205
def setUp(self):
229206
super(PackNbrsTest, self).setUp()
230-
self._graph_path = self._create_graph_file()
207+
# Write graph edges (as a TSV file).
208+
self._graph_path = self._create_tmp_file('graph.tsv')
231209
graph_utils.write_tsv_graph(self._graph_path, _GRAPH)
232-
self._training_examples_path = self._create_training_examples_file()
233-
_write_training_examples(self._training_examples_path)
234-
self._output_nsl_training_data_path = self._create_nsl_training_data_file()
235-
236-
def _create_training_examples_file(self):
237-
return self.create_tempfile('train_data.tfr').full_path
210+
# Write labeled training Examples.
211+
self._training_examples_path = self._create_tmp_file('train_data.tfr')
212+
_write_examples(self._training_examples_path, [_example_a(), _example_c()])
213+
# Write unlabeled neighbor Examples.
214+
self._neighbor_examples_path = self._create_tmp_file('neighbor_data.tfr')
215+
_write_examples(self._neighbor_examples_path, [_example_b()])
216+
# Create output file
217+
self._output_nsl_training_data_path = self._create_tmp_file(
218+
'nsl_train_data.tfr')
219+
220+
def _create_tmp_file(self, filename):
221+
return self.create_tempfile(filename).full_path
222+
223+
def testDirectedGraphUnlimitedNbrsNoNeighborExamples(self):
224+
"""Tests pack_nbrs() with an empty second argument (neighbor examples).
225+
226+
In this case, the edge A-->B is dangling because there will be no Example
227+
named "B" in the input.
228+
"""
229+
input_maker_lib.pack_nbrs(
230+
self._training_examples_path,
231+
'',
232+
self._graph_path,
233+
self._output_nsl_training_data_path,
234+
add_undirected_edges=False)
235+
expected_nsl_train_data = {
236+
# Node A has only one neighbor, namely C.
237+
'A': _augmented_a_directed_one_nbr(),
238+
# C has no neighbors in the directed case.
239+
'C': _augmented_c_directed()
240+
}
241+
actual_nsl_train_data = _read_tfrecord_examples(
242+
self._output_nsl_training_data_path)
243+
self.assertDictEqual(actual_nsl_train_data, expected_nsl_train_data)
238244

239-
def _create_nsl_training_data_file(self):
240-
return self.create_tempfile('nsl_train_data.tfr').full_path
245+
def testUndirectedGraphUnlimitedNbrsNoNeighborExamples(self):
246+
"""Tests pack_nbrs() with an empty second argument (neighbor examples).
241247
242-
def _create_graph_file(self):
243-
return self.create_tempfile('graph.tsv').full_path
248+
In this case, the edge A-->B is dangling because there will be no Example
249+
named "B" in the input.
250+
"""
251+
input_maker_lib.pack_nbrs(
252+
self._training_examples_path,
253+
'',
254+
self._graph_path,
255+
self._output_nsl_training_data_path,
256+
add_undirected_edges=True)
257+
expected_nsl_train_data = {
258+
# Node A has only one neighbor, namely C.
259+
'A': _augmented_a_directed_one_nbr(),
260+
# C's only neighbor in the undirected case is A.
261+
'C': _augmented_c_undirected_one_nbr_a()
262+
}
263+
actual_nsl_train_data = _read_tfrecord_examples(
264+
self._output_nsl_training_data_path)
265+
self.assertDictEqual(actual_nsl_train_data, expected_nsl_train_data)
244266

245267
def testDirectedGraphUnlimitedNbrs(self):
246268
input_maker_lib.pack_nbrs(
247269
self._training_examples_path,
248-
'',
270+
self._neighbor_examples_path,
249271
self._graph_path,
250272
self._output_nsl_training_data_path,
251273
add_undirected_edges=False)
252274
expected_nsl_train_data = {
253275
'A': _augmented_a_directed_two_nbrs(),
254-
'B': _augmented_b_directed_two_nbrs(),
255276
'C': _augmented_c_directed()
256277
}
257278
actual_nsl_train_data = _read_tfrecord_examples(
@@ -261,14 +282,13 @@ def testDirectedGraphUnlimitedNbrs(self):
261282
def testDirectedGraphLimitedNbrs(self):
262283
input_maker_lib.pack_nbrs(
263284
self._training_examples_path,
264-
'',
285+
self._neighbor_examples_path,
265286
self._graph_path,
266287
self._output_nsl_training_data_path,
267288
add_undirected_edges=False,
268289
max_nbrs=1)
269290
expected_nsl_train_data = {
270291
'A': _augmented_a_directed_one_nbr(),
271-
'B': _augmented_b_directed_one_nbr(),
272292
'C': _augmented_c_directed()
273293
}
274294
actual_nsl_train_data = _read_tfrecord_examples(
@@ -278,13 +298,12 @@ def testDirectedGraphLimitedNbrs(self):
278298
def testUndirectedGraphUnlimitedNbrs(self):
279299
input_maker_lib.pack_nbrs(
280300
self._training_examples_path,
281-
'',
301+
self._neighbor_examples_path,
282302
self._graph_path,
283303
self._output_nsl_training_data_path,
284304
add_undirected_edges=True)
285305
expected_nsl_train_data = {
286306
'A': _augmented_a_undirected_two_nbrs(),
287-
'B': _augmented_b_undirected_two_nbrs(),
288307
'C': _augmented_c_undirected_two_nbrs()
289308
}
290309
actual_nsl_train_data = _read_tfrecord_examples(
@@ -294,15 +313,14 @@ def testUndirectedGraphUnlimitedNbrs(self):
294313
def testUndirectedGraphLimitedNbrs(self):
295314
input_maker_lib.pack_nbrs(
296315
self._training_examples_path,
297-
'',
316+
self._neighbor_examples_path,
298317
self._graph_path,
299318
self._output_nsl_training_data_path,
300319
add_undirected_edges=True,
301320
max_nbrs=1)
302321
expected_nsl_train_data = {
303322
'A': _augmented_a_undirected_one_nbr(),
304-
'B': _augmented_b_undirected_one_nbr(),
305-
'C': _augmented_c_undirected_one_nbr()
323+
'C': _augmented_c_undirected_one_nbr_b()
306324
}
307325
actual_nsl_train_data = _read_tfrecord_examples(
308326
self._output_nsl_training_data_path)

0 commit comments

Comments
 (0)