Skip to content

Commit 55cfd84

Browse files
Neural-Link Teamtensorflow-copybara
authored andcommitted
Changes nsl.tools.add_edge() to return a boolean result.
PiperOrigin-RevId: 317868388
1 parent b01695c commit 55cfd84

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

neural_structured_learning/tools/graph_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,16 @@ def add_edge(graph, edge):
6363
supplied, it defaults to 1.0.
6464
6565
Returns:
66-
`None`. Instead, this function has a side-effect on the `graph` argument.
66+
`True` if and only if a new edge was added to `graph`.
6767
"""
6868
source = edge[0]
69-
if source not in graph: graph[source] = {}
70-
t_dict = graph[source]
7169
target = edge[1]
7270
weight = float(edge[2]) if len(edge) > 2 else 1.0
73-
if target not in t_dict or weight > t_dict[target]:
71+
t_dict = graph.setdefault(source, {})
72+
is_new_edge = target not in t_dict
73+
if is_new_edge or weight > t_dict[target]:
7474
t_dict[target] = weight
75+
return is_new_edge
7576

7677

7778
def add_undirected_edges(graph):

neural_structured_learning/tools/graph_utils_test.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@ class GraphUtilsTest(absltest.TestCase):
2929

3030
def testAddEdge(self):
3131
graph = {}
32-
graph_utils.add_edge(graph, ['A', 'B', '0.5'])
33-
graph_utils.add_edge(graph, ['A', 'C', 0.7]) # Tests that the edge
34-
graph_utils.add_edge(graph, ['A', 'C', 0.9]) # ...with maximal weight
35-
graph_utils.add_edge(graph, ['A', 'C', 0.8]) # ...is used.
36-
graph_utils.add_edge(graph, ('B', 'A', '0.4'))
37-
graph_utils.add_edge(graph, ('B', 'C')) # Tests default weight
38-
graph_utils.add_edge(graph, ('D', 'A', 0.75))
32+
self.assertTrue(graph_utils.add_edge(graph, ['A', 'B', '0.5']))
33+
# The next 3 calls test that the edge with maximal weight is used.
34+
self.assertTrue(graph_utils.add_edge(graph, ['A', 'C', 0.7]))
35+
self.assertFalse(graph_utils.add_edge(graph, ['A', 'C', 0.9]))
36+
self.assertFalse(graph_utils.add_edge(graph, ['A', 'C', 0.8]))
37+
self.assertTrue(graph_utils.add_edge(graph, ('B', 'A', '0.4')))
38+
# Tests that when no weight is specified, it defaults to 1.0.
39+
self.assertTrue(graph_utils.add_edge(graph, ('B', 'C')))
40+
self.assertTrue(graph_utils.add_edge(graph, ('D', 'A', 0.75)))
3941
self.assertDictEqual(graph, GRAPH)
4042

4143
def testAddUndirectedEdges(self):

0 commit comments

Comments
 (0)