Skip to content

Commit 935d25a

Browse files
DualityGaptensorflow-copybara
authored andcommitted
Open source the NSL APIs to build TF Estimators that enable graph as well as adversarial regularization.
NSL Estimator APIs and test cases currently support only TF 1.x (and not TF 2.0 yet). PiperOrigin-RevId: 273627369
1 parent 1c57b09 commit 935d25a

File tree

9 files changed

+824
-0
lines changed

9 files changed

+824
-0
lines changed

neural_structured_learning/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ py_library(
2727
name = "neural_structured_learning",
2828
srcs = ["__init__.py"],
2929
deps = [
30+
"//neural_structured_learning/estimator",
3031
"//neural_structured_learning/keras",
3132
"//neural_structured_learning/lib",
3233
"//neural_structured_learning/tools",

neural_structured_learning/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Subpackages of Neural Structured Learning."""
22

33
from neural_structured_learning import configs
4+
from neural_structured_learning import estimator
45
from neural_structured_learning import keras
56
from neural_structured_learning import lib
67
from neural_structured_learning import tools
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2019 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Description:
16+
# Build rules for Estimator APIs in Neural Structured Learning.
17+
18+
# Placeholder for internal Python version compatibility macro.
19+
20+
package(
21+
default_visibility = ["//visibility:public"],
22+
licenses = ["notice"], # Apache 2.0
23+
)
24+
25+
exports_files(["LICENSE"])
26+
27+
py_library(
28+
name = "estimator",
29+
srcs = ["__init__.py"],
30+
srcs_version = "PY2AND3",
31+
deps = [
32+
":adversarial_regularization",
33+
":graph_regularization",
34+
],
35+
)
36+
37+
py_library(
38+
name = "adversarial_regularization",
39+
srcs = ["adversarial_regularization.py"],
40+
srcs_version = "PY2AND3",
41+
deps = [
42+
"//neural_structured_learning/configs",
43+
"//neural_structured_learning/lib",
44+
# package tensorflow
45+
],
46+
)
47+
48+
py_test(
49+
name = "adversarial_regularization_test",
50+
srcs = ["adversarial_regularization_test.py"],
51+
srcs_version = "PY2AND3",
52+
deps = [
53+
":estimator",
54+
"//neural_structured_learning/configs",
55+
# package numpy
56+
# package tensorflow
57+
# package tensorflow framework_test_lib,
58+
],
59+
)
60+
61+
py_library(
62+
name = "graph_regularization",
63+
srcs = ["graph_regularization.py"],
64+
srcs_version = "PY2AND3",
65+
deps = [
66+
"//neural_structured_learning/configs",
67+
"//neural_structured_learning/lib:distances",
68+
"//neural_structured_learning/lib:utils",
69+
# package tensorflow
70+
],
71+
)
72+
73+
py_test(
74+
name = "graph_regularization_test",
75+
srcs = ["graph_regularization_test.py"],
76+
srcs_version = "PY2AND3",
77+
deps = [
78+
":estimator",
79+
# package protobuf,
80+
"//neural_structured_learning/configs",
81+
# package numpy
82+
# package tensorflow
83+
# package tensorflow framework_test_lib,
84+
],
85+
)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Estimator APIs for Neural Structured Learning.
2+
3+
The implementations here are provided on a strict "as is" basis, without
4+
warranties or conditions of any kind. Also, these implementations may not be
5+
compatible with certain TensorFlow versions (such as 2.0 or above).
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
r"""Estimator APIs for Neural Structured Learning.
2+
3+
The current NSL Estimator APIs may not be compatible with certain TensorFlow
4+
versions (such as 2.0 or above).
5+
"""
6+
7+
from neural_structured_learning.estimator.adversarial_regularization import add_adversarial_regularization
8+
from neural_structured_learning.estimator.graph_regularization import add_graph_regularization
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright 2019 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""A wrapper function to enable adversarial regularization to an Estimator."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import neural_structured_learning.configs as nsl_configs
22+
import neural_structured_learning.lib as nsl_lib
23+
24+
import tensorflow as tf
25+
26+
27+
def add_adversarial_regularization(estimator,
28+
optimizer_fn=None,
29+
adv_config=None):
30+
"""Adds adversarial regularization to a `tf.estimator.Estimator`.
31+
32+
Args:
33+
estimator: An object of type `tf.estimator.Estimator`.
34+
optimizer_fn: A function that accepts no arguments and returns an instance
35+
of `tf.train.Optimizer`.
36+
adv_config: An instance of `nsl.configs.AdvRegConfig` that specifies various
37+
hyperparameters for adversarial regularization.
38+
39+
Returns:
40+
A modified `tf.estimator.Estimator` object with adversarial regularization
41+
incorporated into its loss.
42+
"""
43+
44+
if not adv_config:
45+
adv_config = nsl_configs.AdvRegConfig()
46+
47+
base_model_fn = estimator._model_fn # pylint: disable=protected-access
48+
49+
def adv_model_fn(features, labels, mode, params=None, config=None):
50+
"""The adversarial-regularized model_fn.
51+
52+
Args:
53+
features: This is the first item returned from the `input_fn` passed to
54+
`train`, `evaluate`, and `predict`. This should be a single `tf.Tensor`
55+
or `dict` of same.
56+
labels: This is the second item returned from the `input_fn` passed to
57+
`train`, `evaluate`, and `predict`. This should be a single `tf.Tensor`
58+
or dict of same (for multi-head models). If mode is
59+
`tf.estimator.ModeKeys.PREDICT`, `labels=None` will be passed. If the
60+
`model_fn`'s signature does not accept `mode`, the `model_fn` must still
61+
be able to handle `labels=None`.
62+
mode: Optional. Specifies if this is training, evaluation, or prediction.
63+
See `tf.estimator.ModeKeys`.
64+
params: Optional `dict` of hyperparameters. Will receive what is passed to
65+
Estimator in the `params` parameter. This allows users to configure
66+
Estimators from hyper parameter tuning.
67+
config: Optional `estimator.RunConfig` object. Will receive what is passed
68+
to Estimator as its `config` parameter, or a default value. Allows
69+
setting up things in the model_fn based on configuration such as
70+
`num_ps_replicas`, or `model_dir`. Unused currently.
71+
72+
Returns:
73+
A `tf.EstimatorSpec` whose loss incorporates graph-based regularization.
74+
"""
75+
76+
# Uses the same variable scope for calculating the original objective and
77+
# adversarial regularization.
78+
with tf.compat.v1.variable_scope(tf.compat.v1.get_variable_scope(),
79+
reuse=tf.compat.v1.AUTO_REUSE,
80+
auxiliary_name_scope=False):
81+
# If no 'params' is passed, then it is possible for base_model_fn not to
82+
# accept a 'params' argument. See documentation for tf.estimator.Estimator
83+
# for additional context.
84+
if params:
85+
original_spec = base_model_fn(features, labels, mode, params, config)
86+
else:
87+
original_spec = base_model_fn(features, labels, mode, config)
88+
89+
# Adversarial regularization only happens in training.
90+
if mode != tf.estimator.ModeKeys.TRAIN:
91+
return original_spec
92+
93+
adv_neighbor, _ = nsl_lib.gen_adv_neighbor(features, original_spec.loss,
94+
adv_config.adv_neighbor_config)
95+
96+
# Runs the base model again to compute loss on adv_neighbor.
97+
adv_spec = base_model_fn(adv_neighbor, labels, mode, config)
98+
99+
final_loss = original_spec.loss + adv_config.multiplier * adv_spec.loss
100+
101+
if not optimizer_fn:
102+
# Default to the Adagrad optimizer, the same as canned DNNEstimator.
103+
optimizer = tf.train.AdagradOptimizer(learning_rate=0.05)
104+
else:
105+
optimizer = optimizer_fn()
106+
107+
final_train_op = optimizer.minimize(
108+
loss=final_loss, global_step=tf.compat.v1.train.get_global_step())
109+
110+
return original_spec._replace(loss=final_loss, train_op=final_train_op)
111+
112+
# Replaces the model_fn while keeps other fields/methods in the estimator.
113+
estimator._model_fn = adv_model_fn # pylint: disable=protected-access
114+
return estimator
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright 2019 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for nsl.estimator.adversarial_regularization."""
15+
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import os
21+
import shutil
22+
import tempfile
23+
24+
import neural_structured_learning.configs as nsl_configs
25+
import neural_structured_learning.estimator as nsl_estimator
26+
27+
import numpy as np
28+
import tensorflow as tf
29+
30+
from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import
31+
32+
33+
FEATURE_NAME = 'x'
34+
WEIGHT_VARIABLE = 'linear/linear_model/' + FEATURE_NAME + '/weights'
35+
BIAS_VARIABLE = 'linear/linear_model/bias_weights'
36+
37+
38+
def single_batch_input_fn(features, labels=None):
39+
def input_fn():
40+
inputs = features if labels is None else (features, labels)
41+
dataset = tf.data.Dataset.from_tensor_slices(inputs)
42+
return dataset.batch(len(features))
43+
return input_fn
44+
45+
46+
class AdversarialRegularizationTest(tf.test.TestCase):
47+
48+
def setUp(self):
49+
super(AdversarialRegularizationTest, self).setUp()
50+
self.model_dir = tempfile.mkdtemp()
51+
52+
def tearDown(self):
53+
if self.model_dir:
54+
shutil.rmtree(self.model_dir)
55+
super(AdversarialRegularizationTest, self).tearDown()
56+
57+
def build_linear_regressor(self, weight, bias):
58+
with tf.Graph().as_default():
59+
tf.Variable(weight, name=WEIGHT_VARIABLE)
60+
tf.Variable(bias, name=BIAS_VARIABLE)
61+
tf.Variable(100, name=tf.GraphKeys.GLOBAL_STEP, dtype=tf.int64)
62+
63+
with tf.Session() as sess:
64+
sess.run([tf.global_variables_initializer()])
65+
tf.train.Saver().save(sess, os.path.join(self.model_dir, 'model.ckpt'))
66+
67+
fc = tf.feature_column.numeric_column(FEATURE_NAME,
68+
shape=np.array(weight).shape)
69+
return tf.estimator.LinearRegressor(
70+
feature_columns=(fc,), model_dir=self.model_dir, optimizer='SGD')
71+
72+
@test_util.run_v1_only('Requires tf.GraphKeys')
73+
def test_adversarial_wrapper_not_affecting_predictions(self):
74+
# base model: y = x + 2
75+
base_est = self.build_linear_regressor(weight=[[1.0]], bias=[2.0])
76+
adv_est = nsl_estimator.add_adversarial_regularization(base_est)
77+
input_fn = single_batch_input_fn({FEATURE_NAME: np.array([[1.0], [2.0]])})
78+
predictions = adv_est.predict(input_fn=input_fn)
79+
predicted_scores = [x['predictions'] for x in predictions]
80+
self.assertAllClose([[3.0], [4.0]], predicted_scores)
81+
82+
@test_util.run_v1_only('Requires tf.GraphKeys')
83+
def test_adversarial_wrapper_adds_regularization(self):
84+
# base model: y = w*x+b = 4*x1 + 3*x2 + 2
85+
weight = np.array([[4.0], [3.0]], dtype=np.float32)
86+
bias = np.array([2.0], dtype=np.float32)
87+
x0, y0 = np.array([[1.0, 1.0]]), np.array([8.0])
88+
adv_step_size = 0.1
89+
learning_rate = 0.01
90+
91+
base_est = self.build_linear_regressor(weight=weight, bias=bias)
92+
adv_config = nsl_configs.make_adv_reg_config(
93+
multiplier=1.0, # equal weight on original and adv examples
94+
adv_step_size=adv_step_size)
95+
adv_est = nsl_estimator.add_adversarial_regularization(
96+
base_est,
97+
optimizer_fn=lambda: tf.train.GradientDescentOptimizer(learning_rate),
98+
adv_config=adv_config)
99+
input_fn = single_batch_input_fn({FEATURE_NAME: x0}, y0)
100+
adv_est.train(input_fn=input_fn, steps=1)
101+
102+
# Computes the gradients on original and adversarial examples.
103+
orig_pred = np.dot(x0, weight) + bias # [9.0]
104+
orig_grad_w = 2 * (orig_pred - y0) * x0.T # [[2.0], [2.0]]
105+
orig_grad_b = 2 * (orig_pred - y0).reshape((1,)) # [2.0]
106+
grad_x = 2 * (orig_pred - y0) * weight.T # [[8.0, 6.0]]
107+
perturbation = adv_step_size * grad_x / np.linalg.norm(grad_x)
108+
x_adv = x0 + perturbation # [[1.08, 1.06]]
109+
adv_pred = np.dot(x_adv, weight) + bias # [9.5]
110+
adv_grad_w = 2 * (adv_pred - y0) * x_adv.T # [[3.24], [3.18]]
111+
adv_grad_b = 2 * (adv_pred - y0).reshape((1,)) # [3.0]
112+
113+
new_bias = bias - learning_rate * (orig_grad_b + adv_grad_b)
114+
new_weight = weight - learning_rate * (orig_grad_w + adv_grad_w)
115+
self.assertAllClose(new_bias, adv_est.get_variable_value(BIAS_VARIABLE))
116+
self.assertAllClose(new_weight, adv_est.get_variable_value(WEIGHT_VARIABLE))
117+
118+
119+
if __name__ == '__main__':
120+
tf.test.main()

0 commit comments

Comments
 (0)