@@ -35,7 +35,7 @@ def all(cls):
3535
3636@attr .s
3737class AdvNeighborConfig (object ):
38- """AdvNeighborConfig contains configs for generating adversarial neighbors.
38+ """Contains configuration for generating adversarial neighbors.
3939
4040 Attributes:
4141 feature_mask: mask (w/ 0-1 values) applied on gradient. The shape should be
@@ -54,7 +54,7 @@ class AdvNeighborConfig(object):
5454
5555@attr .s
5656class AdvRegConfig (object ):
57- """AdvRegConfig contains configs for adversarial regularization.
57+ """Contains configuration for adversarial regularization.
5858
5959 Attributes:
6060 multiplier: multiplier to adversarial regularization loss. Default set to
@@ -71,22 +71,20 @@ def make_adv_reg_config(
7171 feature_mask = attr .fields (AdvNeighborConfig ).feature_mask .default ,
7272 adv_step_size = attr .fields (AdvNeighborConfig ).adv_step_size .default ,
7373 adv_grad_norm = attr .fields (AdvNeighborConfig ).adv_grad_norm .default ):
74- """Creates AdvRegConfig object.
74+ """Creates an `nsl.configs. AdvRegConfig` object.
7575
7676 Args:
77- multiplier: multiplier to adversarial regularization loss. Default set to
78- 0.2.
79- feature_mask: mask (w/ 0-1 values) applied on gradient. The shape should be
80- the same as (or broadcastable to) input features. If set to None, no
77+ multiplier: multiplier to adversarial regularization loss. Defaults to 0.2.
78+ feature_mask: mask (w/ 0-1 values) applied on the gradient. The shape should
79+ be the same as (or broadcastable to) input features. If set to `None`, no
8180 feature mask will be applied.
82- adv_step_size: step size to find the adversarial sample. Default set to
83- 0.001.
81+ adv_step_size: step size to find the adversarial sample. Defaults to 0.001.
8482 adv_grad_norm: type of tensor norm to normalize the gradient. Input will be
85- converted to `NormType` when applicable (e.g., 'l2' -> NormType.L2).
86- Default set to L2 norm.
83+ converted to `NormType` when applicable (e.g., a value of 'l2' will be
84+ converted to `nsl.configs.NormType.L2`). Defaults to L2 norm.
8785
8886 Returns:
89- An AdvRegConfig object.
87+ An `nsl.configs. AdvRegConfig` object.
9088 """
9189 return AdvRegConfig (
9290 multiplier = multiplier ,
@@ -110,7 +108,7 @@ def all(cls):
110108
111109@attr .s
112110class AdvTargetConfig (object ):
113- """AdvTargetConfig contains configs for selecting targets to be attacked.
111+ """Contains configuration for selecting targets to be attacked.
114112
115113 Attributes:
116114 target_method: type of adversarial targeting method. The value needs to be
@@ -142,20 +140,21 @@ def all(cls):
142140
143141@attr .s
144142class DistanceConfig (object ):
145- """DistanceConfig contains configs for computing distances.
143+ """Contains configuration for computing distances between tensors .
146144
147145 Attributes:
148146 distance_type: type of distance function. Input type will be converted to
149- 'DistanceType' when applicable (e.g., 'l2' -> DistanceType.L2). Default
150- set to L2 norm.
151- reduction: type of distance reduction. See tf.compat.v1.losses.Reduction for
152- details. Default set to tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS.
147+ the appropriate `nsl.configs.DistanceType` value (e.g., the value 'l2' is
148+ converted to `nsl.configs.DistanceType.L2`). Defaults to the L2 norm.
149+ reduction: type of distance reduction. See ` tf.compat.v1.losses.Reduction`
150+ for details. Defaults to ` tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS` .
153151 sum_over_axis: the distance is the sum over the difference along the axis.
154- Default set to None.
152+ See `nsl.lib.pairwise_distance_wrapper` for how this field is used.
153+ Defaults to `None`.
155154 transform_fn: type of transform function to be applied on each side before
156155 computing the pairwise distance. Input type will be converted to
157- ' TransformType' when applicable (e.g., 'softmax' ->
158- TransformType.SOFTMAX). Default set to 'none'.
156+ `nsl.configs. TransformType` when applicable (e.g., the value 'softmax'
157+ maps to `nsl.configs. TransformType.SOFTMAX` ). Defaults to 'none'.
159158 """
160159 distance_type = attr .ib (converter = DistanceType , default = DistanceType .L2 )
161160 reduction = attr .ib (
@@ -177,7 +176,7 @@ def all(cls):
177176
178177@attr .s
179178class DecayConfig (object ):
180- """DecayConfig contains configs for computing decayed value.
179+ """Contains configuration for computing decayed value.
181180
182181 Attributes:
183182 decay_steps: A scalar int32 or int64 Tensor or a Python number. How often to
@@ -207,7 +206,7 @@ def all(cls):
207206
208207@attr .s
209208class IntegrationConfig (object ):
210- """IntegrationConfig contains configs for computing multimodal integration.
209+ """Contains configuration for computing multimodal integration.
211210
212211 Attributes:
213212 integration_type: Type of integration function to apply.
@@ -222,7 +221,7 @@ class IntegrationConfig(object):
222221
223222@attr .s
224223class VirtualAdvConfig (object ):
225- """VirtualAdvConfig contains configs for virtual adversarial training.
224+ """Contains configuration for virtual adversarial training.
226225
227226 Attributes:
228227 adv_neighbor_config: an AdvNeighborConfig object for generating virtual
@@ -245,7 +244,7 @@ class VirtualAdvConfig(object):
245244
246245@attr .s
247246class GraphNeighborConfig (object ):
248- """GraphNeighborConfig specifies neighbor attributes for graph regularization.
247+ """Specifies neighbor attributes for graph regularization.
249248
250249 Attributes:
251250 prefix: The prefix in feature names that identifies neighbor-specific
@@ -268,21 +267,78 @@ class GraphNeighborConfig(object):
268267
269268@attr .s
270269class GraphRegConfig (object ):
271- """GraphRegConfig contains the configuration for graph regularization.
270+ """Contains the configuration for graph regularization.
272271
273272 Attributes:
274273 neighbor_config: An instance of `GraphNeighborConfig` that describes
275274 neighbor attributes for graph regularization.
276275 multiplier: The multiplier or weight factor applied on the graph
277- regularization loss term. Defaults to 0.01. This value has to be greater
278- than or equal to 0 .
276+ regularization loss term. This value has to be non-negative. Defaults to
277+ 0.01 .
279278 distance_config: An instance of `DistanceConfig` to calculate the graph
280- regularization loss term. Defaults to `DistanceConfig()`.
279+ regularization loss term. Defaults to `nsl.configs. DistanceConfig()`.
281280 """
282281 neighbor_config = attr .ib (default = GraphNeighborConfig ())
283282 multiplier = attr .ib (default = 0.01 )
284283 distance_config = attr .ib (default = DistanceConfig ())
285284
286285
286+ def make_graph_reg_config (
287+ neighbor_prefix = attr .fields (GraphNeighborConfig ).prefix .default ,
288+ neighbor_weight_suffix = attr .fields (
289+ GraphNeighborConfig ).weight_suffix .default ,
290+ max_neighbors = attr .fields (GraphNeighborConfig ).max_neighbors .default ,
291+ multiplier = attr .fields (GraphRegConfig ).multiplier .default ,
292+ distance_type = attr .fields (DistanceConfig ).distance_type .default ,
293+ reduction = attr .fields (DistanceConfig ).reduction .default ,
294+ sum_over_axis = attr .fields (DistanceConfig ).sum_over_axis .default ,
295+ transform_fn = attr .fields (DistanceConfig ).transform_fn .default ):
296+ """Creates an `nsl.configs.GraphRegConfig` object.
297+
298+ Args:
299+ neighbor_prefix: The prefix in feature names that identifies
300+ neighbor-specific features. Defaults to 'NL_nbr_'.
301+ neighbor_weight_suffix: The suffix in feature names that identifies the
302+ neighbor weight value. Defaults to '_weight'. Note that neighbor weight
303+ features will have `prefix` as a prefix and `weight_suffix` as a suffix.
304+ For example, based on the default values of `prefix` and `weight_suffix`,
305+ a valid neighbor weight feature is 'NL_nbr_0_weight', where 0 corresponds
306+ to the first neighbor of the sample.
307+ max_neighbors: The maximum number of neighbors to be used for graph
308+ regularization. Defaults to 0, which disables graph regularization. Note
309+ that this value has to be less than or equal to the actual number of
310+ neighbors in each sample.
311+ multiplier: The multiplier or weight factor applied on the graph
312+ regularization loss term. This value has to be non-negative. Defaults to
313+ 0.01.
314+ distance_type: type of distance function. Input type will be converted to
315+ the appropriate `nsl.configs.DistanceType` value (e.g., the value 'l2' is
316+ converted to `nsl.configs.DistanceType.L2`). Defaults to the L2 norm.
317+ reduction: type of distance reduction. See `tf.compat.v1.losses.Reduction`
318+ for details. Defaults to `tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS`.
319+ sum_over_axis: the distance is the sum over the difference along the axis.
320+ See `nsl.lib.pairwise_distance_wrapper` for how this field is used.
321+ Defaults to `None`.
322+ transform_fn: type of transform function to be applied on each side before
323+ computing the pairwise distance. Input type will be converted to
324+ `nsl.configs.TransformType` when applicable (e.g., the value 'softmax'
325+ maps to `nsl.configs.TransformType.SOFTMAX`). Defaults to 'none'.
326+
327+ Returns:
328+ An `nsl.configs.GraphRegConfig` object.
329+ """
330+ return GraphRegConfig (
331+ neighbor_config = GraphNeighborConfig (
332+ prefix = neighbor_prefix ,
333+ weight_suffix = neighbor_weight_suffix ,
334+ max_neighbors = max_neighbors ),
335+ multiplier = multiplier ,
336+ distance_config = DistanceConfig (
337+ distance_type = distance_type ,
338+ reduction = reduction ,
339+ sum_over_axis = sum_over_axis ,
340+ transform_fn = transform_fn ))
341+
342+
287343DEFAULT_DISTANCE_PARAMS = attr .asdict (DistanceConfig ())
288344DEFAULT_ADVERSARIAL_PARAMS = attr .asdict (AdvNeighborConfig ())
0 commit comments