@@ -223,47 +223,48 @@ def gen_adv_neighbor(input_features,
223223 config ,
224224 raise_invalid_gradient = False ,
225225 gradient_tape = None ):
226- """Functional interface of _GenAdvNeighbor .
226+ """Generates adversarial neighbors for the given input and loss .
227227
228- This function provides a tensor/config-in & tensor-out functional interface
229- that does the following:
230- (a) Instantiates '_GenAdvNeighbor'
231- (b) Invokes 'gen_neighbor' method
232- (c) Returns the adversarial neighbors generated.
228+ This function implements the following operation:
229+ `adv_neighbor = input_features + adv_step_size * gradient`
230+ where `adv_step_size` is the step size (analogous to learning rate) for
231+ searching/calculating adversarial neighbor.
233232
234233 Arguments:
235- input_features: a dense (float32) tensor or a dictionary of feature names
236- and dense tensors. The shape of the tensor(s) should be either:
237- (a) pointwise samples: [batch_size, feat_len], or
238- (b) sequence samples: [batch_size, seq_len, feat_len]. Note that if the
239- `input_features` is a dictionary, only dense (`float`) tensors in it
240- will be perturbed and all other features (int, string, or sparse
241- tensors) will be kept as-is in the returning `adv_neighbor`.
242- labeled_loss: a scalar (float32) tensor calculated from true labels (or
243- supervisions).
244- config: `AdvNeighborConfig` object containing the following hyperparameters
245- for generating adversarial samples. - 'feature_mask': mask (w/ 0-1 values)
246- applied on graident. - 'adv_step_size': step size to find the adversarial
247- sample. - 'adv_grad_norm': type of tensor norm to normalize the gradient.
248- raise_invalid_gradient: (optional, default=False) a Boolean flag indicating
249- whether to raise an error when gradients cannot be computed on some input
250- features. There are three cases where gradients cannot be computed: (1)
251- The feature is a SparseTensor. (2) The feature has a non-differentiable
252- `tf.DType`, like string or integer. (3) The feature is not involved in
253- loss computation. If set to False (default), those inputs without
254- gradient will be ignored silently and not perturbed.
255- gradient_tape: a `tf.GradientTape` object watching the calculation from
234+ input_features: A `Tensor` or a dictionary of `(feature_name, Tensor)`.
235+ The shape of the tensor(s) should be either:
236+ (a) pointwise samples: `[batch_size, feat_len]`, or
237+ (b) sequence samples: `[batch_size, seq_len, feat_len]`.
238+ Note that only dense (`float`) tensors in `input_features` will be
239+ perturbed and all other features (`int`, `string`, or `SparseTensor`) will
240+ be kept as-is in the returning `adv_neighbor`.
241+ labeled_loss: A scalar tensor of floating point type calculated from true
242+ labels (or supervisions).
243+ config: A `nsl.configs.AdvNeighborConfig` object containing the following
244+ hyperparameters for generating adversarial samples.
245+ - 'feature_mask': mask (with 0-1 values) applied on the graident.
246+ - 'adv_step_size': step size to find the adversarial sample.
247+ - 'adv_grad_norm': type of tensor norm to normalize the gradient.
248+ raise_invalid_gradient: (optional) A Boolean flag indicating whether to
249+ raise an error when gradients cannot be computed on any input feature.
250+ There are three cases where this error may happen:
251+ (1) The feature is a `SparseTensor`.
252+ (2) The feature has a non-differentiable `dtype`, like string or integer.
253+ (3) The feature is not involved in loss computation.
254+ If set to `False` (default), those inputs without gradient will be ignored
255+ silently and not perturbed.
256+ gradient_tape: A `tf.GradientTape` object watching the calculation from
256257 `input_features` to `labeled_loss`. Can be omitted if running in graph
257258 mode.
258259
259260 Returns:
260- adv_neighbor: the perturbed example, with the same shape and structure as
261- input_features
262- adv_weight: a dense (float32) tensor with shape of [batch_size, 1],
263- representing the weight for each neighbor
261+ adv_neighbor: The perturbed example, with the same shape and structure as
262+ ` input_features`.
263+ adv_weight: A dense `Tensor` with shape of ` [batch_size, 1]` ,
264+ representing the weight for each neighbor.
264265
265266 Raises:
266- ValueError: if `raise_invalid_gradient` is set and some of the input
267+ ValueError: In case of `raise_invalid_gradient` is set and some of the input
267268 features cannot be perturbed. See `raise_invalid_gradient` for situations
268269 where this can happen.
269270 """
0 commit comments