Skip to content

Commit 1f445fd

Browse files
Neural-Link Teamtensorflow-copybara
authored andcommitted
Adds experimental graph neural network module with graph-regularizer and graph convolution.
PiperOrigin-RevId: 389226767
1 parent b4de7c1 commit 1f445fd

File tree

8 files changed

+477
-0
lines changed

8 files changed

+477
-0
lines changed

neural_structured_learning/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ py_library(
3131
":version",
3232
"//neural_structured_learning/configs",
3333
"//neural_structured_learning/estimator",
34+
"//neural_structured_learning/experimental",
3435
"//neural_structured_learning/keras",
3536
"//neural_structured_learning/lib",
3637
"//neural_structured_learning/tools",

neural_structured_learning/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from neural_structured_learning import configs
44
from neural_structured_learning import estimator
5+
from neural_structured_learning import experimental
56
from neural_structured_learning import keras
67
from neural_structured_learning import lib
78
from neural_structured_learning import tools
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package(
2+
licenses = ["notice"], # Apache 2.0
3+
)
4+
5+
exports_files(["LICENSE"])
6+
7+
py_binary(
8+
name = "graph_keras_mlp_cora",
9+
srcs = ["graph_keras_mlp_cora.py"],
10+
python_version = "PY3",
11+
deps = [
12+
# package absl:app
13+
# package absl/flags
14+
# package absl/logging
15+
# package attr
16+
"//neural_structured_learning",
17+
# package tensorflow
18+
],
19+
)
20+
21+
py_binary(
22+
name = "graph_nets_cora_graph_regularization",
23+
srcs = ["graph_nets_cora_graph_regularization.py"],
24+
python_version = "PY3",
25+
deps = [
26+
# package absl:app
27+
# package absl/flags
28+
# package graph_nets
29+
"//neural_structured_learning",
30+
"//neural_structured_learning/experimental:gnn",
31+
# package tensorflow
32+
],
33+
)
34+
35+
py_binary(
36+
name = "graph_nets_cora_gcn",
37+
srcs = ["graph_nets_cora_gcn.py"],
38+
python_version = "PY3",
39+
deps = [
40+
# package absl:app
41+
# package absl/flags
42+
# package graph_nets
43+
"//neural_structured_learning",
44+
"//neural_structured_learning/experimental:gnn",
45+
# package tensorflow
46+
],
47+
)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2019 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Example of an NSL GNN."""
15+
from absl import app
16+
from absl import flags
17+
import graph_nets
18+
import neural_structured_learning as nsl
19+
from neural_structured_learning.experimental import gnn
20+
import tensorflow as tf
21+
22+
flags.DEFINE_string(
23+
'train_examples_path',
24+
None,
25+
'Path to training examples.')
26+
flags.DEFINE_string('eval_examples_path',
27+
None,
28+
'Path to evaluation examples.')
29+
30+
FLAGS = flags.FLAGS
31+
32+
33+
def main(argv):
34+
del argv
35+
neighbor_config = nsl.configs.GraphNeighborConfig(max_neighbors=3)
36+
train_dataset = gnn.make_cora_dataset(
37+
FLAGS.train_examples_path, shuffle=True, neighbor_config=neighbor_config)
38+
eval_dataset = gnn.make_cora_dataset(FLAGS.eval_examples_path, batch_size=32)
39+
40+
model = gnn.GraphConvolutionalNodeClassifier(
41+
seq_length=tf.data.experimental.get_structure(train_dataset)[0]
42+
['words'].shape[-1],
43+
num_classes=7)
44+
model.compile(
45+
optimizer=tf.keras.optimizers.Adam(),
46+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
47+
metrics=[
48+
tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True),
49+
tf.keras.metrics.SparseCategoricalAccuracy(),
50+
tf.keras.metrics.SparseTopKCategoricalAccuracy(2),
51+
])
52+
model.fit(train_dataset, epochs=30, validation_data=eval_dataset)
53+
54+
55+
if __name__ == '__main__':
56+
graph_nets.compat.set_sonnet_version('2')
57+
app.run(main)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2019 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Example of Graph Regularization with an NSL GNN."""
15+
import functools
16+
17+
from absl import app
18+
from absl import flags
19+
import graph_nets
20+
import neural_structured_learning as nsl
21+
from neural_structured_learning.experimental import gnn
22+
import tensorflow as tf
23+
24+
flags.DEFINE_string(
25+
'train_examples_path',
26+
None,
27+
'Path to training examples.')
28+
flags.DEFINE_string('eval_examples_path',
29+
None,
30+
'Path to evaluation examples.')
31+
32+
FLAGS = flags.FLAGS
33+
34+
35+
class NodeClassifier(tf.keras.Model):
36+
"""Classifier model for nodes."""
37+
38+
def __init__(self,
39+
seq_length,
40+
num_classes,
41+
hidden_units=None,
42+
dropout_rate=0.5,
43+
**kwargs):
44+
inputs = tf.keras.Input(shape=(seq_length,), dtype=tf.int64, name='words')
45+
x = tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32))(inputs)
46+
for num_units in (hidden_units or [50, 50]):
47+
x = tf.keras.layers.Dense(num_units, activation='relu')(x)
48+
x = tf.keras.layers.Dropout(dropout_rate)(x)
49+
outputs = tf.keras.layers.Dense(num_classes)(x)
50+
super(NodeClassifier, self).__init__(inputs, outputs, **kwargs)
51+
52+
53+
def main(argv):
54+
del argv
55+
graph_reg_config = nsl.configs.GraphRegConfig(
56+
neighbor_config=nsl.configs.GraphNeighborConfig(max_neighbors=3),
57+
multiplier=0.1,
58+
distance_config=nsl.configs.DistanceConfig(
59+
distance_type=nsl.configs.DistanceType.L2,
60+
reduction=tf.compat.v1.losses.Reduction.NONE,
61+
sum_over_axis=-1))
62+
63+
train_dataset = gnn.make_cora_dataset(
64+
FLAGS.train_examples_path,
65+
shuffle=True,
66+
neighbor_config=graph_reg_config.neighbor_config)
67+
eval_dataset = gnn.make_cora_dataset(FLAGS.eval_examples_path, batch_size=32)
68+
69+
model = gnn.GraphRegularizationModel(
70+
config=graph_reg_config,
71+
node_model_fn=functools.partial(
72+
NodeClassifier,
73+
seq_length=tf.data.experimental.get_structure(train_dataset)[0]
74+
['words'].shape[-1],
75+
num_classes=7))
76+
model.compile(
77+
optimizer=tf.keras.optimizers.Adam(),
78+
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
79+
metrics=[
80+
tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True),
81+
tf.keras.metrics.SparseCategoricalAccuracy(),
82+
tf.keras.metrics.SparseTopKCategoricalAccuracy(2),
83+
])
84+
model.fit(train_dataset, epochs=30, validation_data=eval_dataset)
85+
86+
87+
if __name__ == '__main__':
88+
graph_nets.compat.set_sonnet_version('2')
89+
app.run(main)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# TODO(ppham27): describe this package.
2+
# Placeholder for internal Python strict compatibility macro.
3+
4+
package(
5+
default_visibility = ["//neural_structured_learning:__subpackages__"],
6+
licenses = ["notice"], # Apache 2.0
7+
)
8+
9+
exports_files(["LICENSE"])
10+
11+
py_library(
12+
name = "experimental",
13+
srcs = ["__init__.py"],
14+
srcs_version = "PY3",
15+
deps = [
16+
":gnn",
17+
],
18+
)
19+
20+
py_library(
21+
name = "gnn",
22+
srcs = ["gnn.py"],
23+
srcs_version = "PY3",
24+
deps = [
25+
# package graph_nets
26+
"//neural_structured_learning/configs",
27+
"//neural_structured_learning/keras",
28+
"//neural_structured_learning/lib",
29+
# package sonnet/v2
30+
# package tensorflow
31+
],
32+
)

neural_structured_learning/experimental/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)