Skip to content

Commit 74745fc

Browse files
arjungtensorflow-copybara
authored andcommitted
Add a library for graph building with unit tests.
This is in line with the rest of the modules in NSL (and in TF), and also allows us to write unit tests. PiperOrigin-RevId: 272276186
1 parent c6c47c6 commit 74745fc

File tree

5 files changed

+266
-77
lines changed

5 files changed

+266
-77
lines changed

neural_structured_learning/tools/BUILD

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ py_library(
2929
srcs = ["__init__.py"],
3030
srcs_version = "PY2AND3",
3131
deps = [
32-
":build_graph_lib",
32+
":graph_builder",
3333
":graph_utils",
3434
":pack_nbrs_lib",
3535
],
@@ -56,25 +56,41 @@ py_test(
5656
)
5757

5858
py_library(
59-
name = "build_graph_lib",
60-
srcs = ["build_graph.py"],
59+
name = "graph_builder",
60+
srcs = ["graph_builder.py"],
6161
srcs_version = "PY2AND3",
6262
deps = [
6363
":graph_utils",
64-
# package absl:app
65-
# package absl/flags
6664
# package absl/logging
6765
# package numpy
6866
# package six
6967
# package tensorflow
7068
],
7169
)
7270

71+
py_test(
72+
name = "graph_builder_test",
73+
srcs = ["graph_builder_test.py"],
74+
srcs_version = "PY2AND3",
75+
deps = [
76+
":graph_builder",
77+
":graph_utils",
78+
# package protobuf,
79+
# package absl/testing:absltest
80+
# package tensorflow
81+
],
82+
)
83+
7384
py_binary(
74-
name = "build_graph",
75-
srcs = ["build_graph.py"],
85+
name = "graph_builder_main",
86+
srcs = ["graph_builder_main.py"],
7687
python_version = "PY3",
77-
deps = [":build_graph_lib"],
88+
deps = [
89+
":graph_builder",
90+
# package absl:app
91+
# package absl/flags
92+
# package tensorflow
93+
],
7894
)
7995

8096
py_library(
@@ -103,6 +119,8 @@ py_binary(
103119
srcs = ["build_docs.py"],
104120
python_version = "PY3",
105121
deps = [
122+
# package absl:app
123+
# package absl/flags
106124
"//neural_structured_learning",
107125
# package tensorflow_docs/api_generator
108126
],

neural_structured_learning/tools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tools and APIs for preparing data for Neural Structured Learning."""
22

3-
import neural_structured_learning.tools.build_graph
3+
from neural_structured_learning.tools.graph_builder import build_graph
44
from neural_structured_learning.tools.graph_utils import add_edge
55
from neural_structured_learning.tools.graph_utils import add_undirected_edges
66
from neural_structured_learning.tools.graph_utils import read_tsv_graph

neural_structured_learning/tools/build_graph.py renamed to neural_structured_learning/tools/graph_builder.py

Lines changed: 52 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -12,38 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
r"""Program to build a graph based on dense input features (embeddings).
15+
r"""Library to build a graph based on dense input features (embeddings).
1616
17-
USAGE:
18-
19-
`python build_graph.py` [*flags*] *input_features.tfr ... output_graph.tsv*
20-
21-
This program reads input instances from one or more TFRecord files, each
22-
containing `tf.train.Example` protos. Each input example is expected to
23-
contain at least these 2 features:
24-
25-
* `id`: A singleton `bytes_list` feature that identifies each Example.
26-
* `embedding`: A `float_list` feature that contains the (dense) embedding of
27-
each example.
28-
29-
`id` and `embedding` are not necessarily the literal feature names; if your
30-
features have different names, you can use the `--id_feature_name` and
31-
`--embedding_feature_name` flags to specify them, respectively.
32-
33-
The program then computes the cosine similarity between all pairs of input
34-
examples based on their associated embeddings. An edge is written to the
35-
*output_graph.tsv* file for each pair whose similarity is at least as large as
36-
the value of the `--similarity_threshold` flag's value. Each output edge is
37-
represented by a line in the *output_graph.tsv* file with the following form:
38-
39-
```
40-
source_id<TAB>target_id<TAB>edge_weight
41-
```
42-
43-
All edges in the output will be symmetric (i.e., if edge `A--w-->B` exists in
44-
the output, then so will edge `B--w-->A`).
45-
46-
For details about this program's flags, run `python build_graph.py --help`.
17+
A python-based program for graph building also exists on
18+
[GitHub](https://github.com/tensorflow/neural-structured-learning/tree/master/neural_structured_learning/tools/graph_builder_main.py).
4719
"""
4820

4921
from __future__ import absolute_import
@@ -54,8 +26,6 @@
5426
import itertools
5527
import time
5628

57-
from absl import app
58-
from absl import flags
5929
from absl import logging
6030
from neural_structured_learning.tools import graph_utils
6131
import numpy as np
@@ -71,7 +41,8 @@ def _read_tfrecord_examples(filenames, id_feature_name, embedding_feature_name):
7141
"""Reads and returns the embeddings stored in the Examples in `filename`.
7242
7343
Args:
74-
filenames: A list of names of TFRecord files containing tensorflow.Examples.
44+
filenames: A list of names of TFRecord files containing `tf.train.Example`
45+
objects.
7546
id_feature_name: Name of the feature that identifies the Example's ID.
7647
embedding_feature_name: Name of the feature that identifies the Example's
7748
embedding.
@@ -162,39 +133,52 @@ def _add_edges(embeddings, threshold, g):
162133
edge_cnt, (time.time() - start_time))
163134

164135

165-
def _main(argv):
166-
"""Main function for running the build_graph program."""
167-
flag = flags.FLAGS
168-
flag.showprefixforinfo = False
169-
if len(argv) < 3:
170-
raise app.UsageError(
171-
'Invalid number of arguments; expected 2 or more, got %d' %
172-
(len(argv) - 1))
136+
def build_graph(embedding_files,
137+
output_graph_path,
138+
similarity_threshold=0.8,
139+
id_feature_name='id',
140+
embedding_feature_name='embedding'):
141+
"""Builds a graph based on dense embeddings and persists it in TSV format.
173142
174-
embeddings = _read_tfrecord_examples(argv[1:-1], flag.id_feature_name,
175-
flag.embedding_feature_name)
143+
This function reads input instances from one or more TFRecord files, each
144+
containing `tf.train.Example` protos. Each input example is expected to
145+
contain at least the following 2 features:
146+
147+
* `id`: A singleton `bytes_list` feature that identifies each example.
148+
* `embedding`: A `float_list` feature that contains the (dense) embedding of
149+
each example.
150+
151+
`id` and `embedding` are not necessarily the literal feature names; if your
152+
features have different names, you can specify them using the
153+
`id_feature_name` and `embedding_feature_name` arguments, respectively.
154+
155+
This function then computes the cosine similarity between all pairs of input
156+
examples based on their associated embeddings. An edge is written to the TSV
157+
file named by `output_graph_path` for each pair whose similarity is at least
158+
as large as `similarity_threshold`. Each output edge is represented by a TSV
159+
line in the `output_graph_path` file with the following form:
160+
161+
```
162+
source_id<TAB>target_id<TAB>edge_weight
163+
```
164+
165+
All edges in the output will be symmetric (i.e., if edge `A--w-->B` exists in
166+
the output, then so will edge `B--w-->A`).
167+
168+
Args:
169+
embedding_files: A list of names of TFRecord files containing
170+
`tf.train.Example` objects, which in turn contain dense embeddings.
171+
output_graph_path: Name of the file to which the output graph in TSV format
172+
should be written.
173+
similarity_threshold: Threshold used to determine which edges to retain in
174+
the resulting graph.
175+
id_feature_name: The name of the feature in the input `tf.train.Example`
176+
objects representing the ID of examples.
177+
embedding_feature_name: The name of the feature in the input
178+
`tf.train.Example` objects representing the embedding of examples.
179+
"""
180+
embeddings = _read_tfrecord_examples(embedding_files, id_feature_name,
181+
embedding_feature_name)
176182
graph = collections.defaultdict(dict)
177-
_add_edges(embeddings, flag.similarity_threshold, graph)
178-
graph_utils.write_tsv_graph(argv[-1], graph)
179-
180-
181-
if __name__ == '__main__':
182-
flags.DEFINE_string(
183-
'id_feature_name', 'id',
184-
"""Name of the singleton bytes_list feature in each input Example
185-
whose value is the Example's ID."""
186-
)
187-
flags.DEFINE_string(
188-
'embedding_feature_name', 'embedding',
189-
"""Name of the float_list feature in each input Example
190-
whose value is the Example's (dense) embedding."""
191-
)
192-
flags.DEFINE_float(
193-
'similarity_threshold', 0.8,
194-
"""Lower bound on the cosine similarity required for an edge
195-
to be created between two nodes."""
196-
)
197-
198-
# Ensure TF 2.0 behavior even if TF 1.X is installed.
199-
tf.compat.v1.enable_v2_behavior()
200-
app.run(_main)
183+
_add_edges(embeddings, similarity_threshold, graph)
184+
graph_utils.write_tsv_graph(output_graph_path, graph)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2019 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
r"""Program to build a graph based on dense input features (embeddings).
15+
16+
This is a wrapper around the `nsl.tools.build_graph` API. See its documentation
17+
for more details.
18+
19+
USAGE:
20+
21+
`python graph_builder_main.py` [*flags*] *input_features.tfr ...
22+
output_graph.tsv*
23+
24+
For details about this program's flags, run `python graph_builder_main.py
25+
--help`.
26+
"""
27+
28+
from __future__ import absolute_import
29+
from __future__ import division
30+
from __future__ import print_function
31+
32+
from absl import app
33+
from absl import flags
34+
from neural_structured_learning.tools import graph_builder
35+
import tensorflow as tf
36+
37+
38+
def _main(argv):
39+
"""Main function for running the graph_builder_main program."""
40+
flag = flags.FLAGS
41+
flag.showprefixforinfo = False
42+
if len(argv) < 3:
43+
raise app.UsageError(
44+
'Invalid number of arguments; expected 2 or more, got %d' %
45+
(len(argv) - 1))
46+
47+
graph_builder.build_graph(argv[1:-1], argv[-1], flag.similarity_threshold,
48+
flag.id_feature_name, flag.embedding_feature_name)
49+
50+
51+
if __name__ == '__main__':
52+
flags.DEFINE_string(
53+
'id_feature_name', 'id',
54+
"""Name of the singleton bytes_list feature in each input Example
55+
whose value is the Example's ID.""")
56+
flags.DEFINE_string(
57+
'embedding_feature_name', 'embedding',
58+
"""Name of the float_list feature in each input Example
59+
whose value is the Example's (dense) embedding.""")
60+
flags.DEFINE_float(
61+
'similarity_threshold', 0.8,
62+
"""Lower bound on the cosine similarity required for an edge
63+
to be created between two nodes.""")
64+
65+
# Ensure TF 2.0 behavior even if TF 1.X is installed.
66+
tf.compat.v1.enable_v2_behavior()
67+
app.run(_main)

0 commit comments

Comments
 (0)