@@ -49,6 +49,7 @@ def safe_merge_config(config: DictConfig, inferred_config: DictConfig) -> DictCo
 
     Returns:
         The merged configuration.
+
     """
     # use base config values if they exist
     inferred_config.embedding_dims = config.get("embedding_dims") or inferred_config.embedding_dims
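For context, a minimal sketch of the `config.get(...) or fallback` merge pattern used above, assuming OmegaConf is installed; the key and values are illustrative, not the full config schema:

```python
from omegaconf import OmegaConf

config = OmegaConf.create({"embedding_dims": None})                # base config (user-supplied)
inferred_config = OmegaConf.create({"embedding_dims": [[10, 5]]})  # inferred from data

# A falsy base value (None, empty list) falls back to the inferred value;
# a truthy base value wins.
merged = config.get("embedding_dims") or inferred_config.embedding_dims
print(merged)  # [[10, 5]]
```

Note that `or` treats any falsy value (`0`, `[]`) as missing, which is the intended behavior here but worth keeping in mind when adding new config keys.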
@@ -90,6 +91,7 @@ def __init__(
                 A custom optimizer as callable or string to be imported. Defaults to None.
             custom_optimizer_params (Dict, optional): A dictionary of custom optimizer parameters. Defaults to {}.
             kwargs (Dict, optional): Additional keyword arguments.
+
         """
         super().__init__()
         assert "inferred_config" in kwargs, "inferred_config not found in initialization arguments"
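A toy stand-in for the constructor guard above (the class name is hypothetical): required settings arrive via `**kwargs` and are validated eagerly rather than typed in the signature.

```python
class ConfigCheckedModule:
    """Toy stand-in for the kwargs guard shown in the diff above."""

    def __init__(self, **kwargs):
        # fail fast at construction time if the required setting is absent
        assert "inferred_config" in kwargs, "inferred_config not found in initialization arguments"
        self.inferred_config = kwargs["inferred_config"]
```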
@@ -231,6 +233,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
 
         Returns:
             torch.Tensor: The loss value
+
         """
         y_hat = output["logits"]
         reg_terms = [k for k, v in output.items() if "regularization" in k]
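The key-filtering line above shows how auxiliary regularization terms ride along in the output dict. A hedged sketch of summing them into the loss; the term names here are made up:

```python
import torch

output = {
    "logits": torch.randn(4, 1),
    "sparsity_regularization": torch.tensor(0.01),   # hypothetical aux term
    "embedding_regularization": torch.tensor(0.02),  # hypothetical aux term
}
# collect every key that carries a regularization penalty
reg_terms = [k for k, v in output.items() if "regularization" in k]
reg_loss = sum(output[k] for k in reg_terms)
print(reg_loss)  # tensor(0.0300)
```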
@@ -287,6 +290,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
 
         Returns:
             List[torch.Tensor]: The list of metric values
+
         """
         metrics = []
         for metric, metric_str, prob_inp, metric_params in zip(
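The `zip(...)` above walks four parallel lists: metric callables, their display names, whether each takes probabilities, and per-metric parameters. A self-contained sketch of the same pattern with stand-in metrics (all names here are illustrative):

```python
import torch

def mse(y_hat, y):
    return torch.mean((y_hat - y) ** 2)

def mae(y_hat, y):
    return torch.mean(torch.abs(y_hat - y))

metric_fns, metric_strs = [mse, mae], ["mse", "mae"]
prob_inputs, metric_param_list = [False, False], [{}, {}]

y, y_hat = torch.randn(8, 1), torch.randn(8, 1)
computed = [
    (name, fn(y_hat, y, **params))
    for fn, name, prob_inp, params in zip(metric_fns, metric_strs, prob_inputs, metric_param_list)
]
```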
@@ -349,13 +353,15 @@ def embed_input(self, x: Dict) -> torch.Tensor:
         return self.embedding_layer(x)
 
     def apply_output_sigmoid_scaling(self, y_hat: torch.Tensor) -> torch.Tensor:
-        """Applies sigmoid scaling to the output of the model if the task is regression and the target range is defined.
+        """Applies sigmoid scaling to the output of the model if the task is regression and the target range is
+        defined.
 
         Args:
             y_hat (torch.Tensor): The output of the model
 
         Returns:
             torch.Tensor: The output of the model with sigmoid scaling applied
+
         """
         if (self.hparams.task == "regression") and (self.hparams.target_range is not None):
             for i in range(self.hparams.output_dim):
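A minimal standalone sketch of the sigmoid scaling this method performs, assuming `target_range` is a list of `(low, high)` pairs, one per output dimension; it mirrors the per-dimension loop above but is not the library's exact code:

```python
import torch

def scale_to_range(y_hat: torch.Tensor, target_range) -> torch.Tensor:
    cols = []
    for i, (low, high) in enumerate(target_range):
        # sigmoid squashes to (0, 1); the affine map stretches that to (low, high)
        cols.append(low + (high - low) * torch.sigmoid(y_hat[:, i]))
    return torch.stack(cols, dim=1)

y_hat = torch.randn(4, 2)
scaled = scale_to_range(y_hat, [(0.0, 1.0), (10.0, 20.0)])
```

This keeps regression outputs inside the configured bounds smoothly, without the zero-gradient regions a hard clamp would introduce.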
@@ -373,6 +379,7 @@ def pack_output(self, y_hat: torch.Tensor, backbone_features: torch.tensor) -> D
 
         Returns:
             The packed output of the model
+
         """
         # if self.head is the Identity function, it means we cannot extract backbone features,
         # because the model cannot be divided into a backbone and a head (e.g. TabNet)
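Following the comment above, a hedged sketch of the packing logic: when the head is `nn.Identity` there is no meaningful backbone/head split, so only the logits are returned. The function body is an assumption based on that comment, not the library's exact code:

```python
import torch
from torch import nn

def pack_output(head: nn.Module, y_hat: torch.Tensor, backbone_features: torch.Tensor) -> dict:
    if isinstance(head, nn.Identity):
        return {"logits": y_hat}  # no separable backbone features to expose
    return {"logits": y_hat, "backbone_features": backbone_features}
```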
@@ -388,6 +395,7 @@ def compute_head(self, backbone_features: Tensor) -> Dict[str, Any]:
 
         Returns:
             The output of the model
+
         """
         y_hat = self.head(backbone_features)
         y_hat = self.apply_output_sigmoid_scaling(y_hat)
@@ -398,6 +406,7 @@ def forward(self, x: Dict) -> Dict[str, Any]:
 
         Args:
             x (Dict): The input of the model with 'continuous' and 'categorical' keys
+
         """
         x = self.embed_input(x)
         x = self.compute_backbone(x)
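Putting the pieces together, a toy model showing the embed -> backbone -> head composition that `forward` performs; every module here is a placeholder for illustration, not the library's internals:

```python
import torch
from torch import nn

class TinyTabularModel(nn.Module):
    def __init__(self, n_continuous: int, hidden: int = 16, out: int = 1):
        super().__init__()
        self.embedding_layer = nn.Linear(n_continuous, hidden)  # stand-in for embed_input
        self.backbone = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU())
        self.head = nn.Linear(hidden, out)

    def forward(self, x: dict) -> dict:
        h = self.embedding_layer(x["continuous"])
        h = self.backbone(h)
        return {"logits": self.head(h), "backbone_features": h}

model = TinyTabularModel(n_continuous=3)
out = model({"continuous": torch.randn(4, 3)})  # out["logits"].shape == (4, 1)
```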
@@ -413,6 +422,7 @@ def predict(self, x: Dict, ret_model_output: bool = False) -> Union[torch.Tensor
 
         Returns:
             The output of the model
+
         """
         assert self.hparams.task != "ssl", "predict is not supported for the ssl task"
         ret_value = self.forward(x)
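For illustration, a self-contained toy that mimics only the guard-then-forward shape of `predict` visible above; the class, its fields, and the return convention are assumptions:

```python
import torch
from torch import nn

class TinyPredictor(nn.Module):
    """Toy stand-in; mimics the ssl guard followed by a forward call."""

    def __init__(self):
        super().__init__()
        self.task = "regression"  # stand-in for hparams.task
        self.linear = nn.Linear(3, 1)

    def forward(self, x: dict) -> dict:
        return {"logits": self.linear(x["continuous"])}

    def predict(self, x: dict, ret_model_output: bool = False):
        assert self.task != "ssl", "predict is not supported for the ssl task"
        ret_value = self.forward(x)
        return (ret_value["logits"], ret_value) if ret_model_output else ret_value["logits"]

preds = TinyPredictor().predict({"continuous": torch.randn(4, 3)})
```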
@@ -427,6 +437,7 @@ def extract_embedding(self):
427437 """Extracts the embedding of the model.
428438
429439 This is used in `CategoricalEmbeddingTransformer`
440+
430441 """
431442 if self .hparams .categorical_dim > 0 :
432443 if not isinstance (self .embedding_layer , PreEncoded1dLayer ):