
Commit 9dfec99

lint: simplify used tools (#431)
* lint: simplify used tools

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3f0a15c commit 9dfec99

Note: large commits have some content hidden by default, so only a subset of the changed files is shown below.

41 files changed: +161 −57 lines

.pre-commit-config.yaml

Lines changed: 5 additions & 18 deletions

@@ -26,24 +26,12 @@ repos:
       - id: check-docstring-first
       - id: detect-private-key

-  - repo: https://github.com/asottile/pyupgrade
-    rev: v3.15.1
-    hooks:
-      - id: pyupgrade
-        args: ["--py38-plus"]
-        name: Upgrade code
-
   - repo: https://github.com/PyCQA/docformatter
     rev: v1.7.5
     hooks:
       - id: docformatter
-        args: [--in-place, --wrap-summaries=120, --wrap-descriptions=120]
-
-  - repo: https://github.com/psf/black
-    rev: 24.2.0
-    hooks:
-      - id: black
-        name: Black code
+        additional_dependencies: [tomli]
+        args: ["--in-place"]

   - repo: https://github.com/executablebooks/mdformat
     rev: 0.7.17
@@ -58,20 +46,19 @@ repos:
           docs/|
           README.md
         )
-  - repo: https://github.com/asottile/yesqa
-    rev: v1.5.0
-    hooks:
-      - id: yesqa

   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.2.2
     hooks:
       - id: ruff
         args: ["--fix"]
+      - id: ruff-format
+      - id: ruff

   - repo: https://github.com/pre-commit/mirrors-prettier
     rev: v4.0.0-alpha.8
     hooks:
       - id: prettier
+        files: \.(json|yml|yaml|toml)
         # https://prettier.io/docs/en/options.html#print-width
         args: ["--print-width=120"]

.prettierignore

Lines changed: 0 additions & 2 deletions
This file was deleted.

examples/covertype_classification.py

Lines changed: 3 additions & 1 deletion

@@ -92,7 +92,9 @@
     normalize_continuous_features=True,
 )
 head_config = LinearHeadConfig(
-    layers="", dropout=0.1, initialization="kaiming"  # No additional layer in head, just a mapping layer to output_dim
+    layers="",
+    dropout=0.1,
+    initialization="kaiming",  # No additional layer in head, just a mapping layer to output_dim
 ).__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)
 model_config = CategoryEmbeddingModelConfig(
     task="classification",
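The only change here is mechanical: the one-line `LinearHeadConfig(...)` call exceeded the 120-character limit once its trailing comment was counted, so the formatter (now `ruff-format`, black-compatible) split it into one argument per line with a trailing comma. A hypothetical call formatted the same way:

```python
def make_head_config(layers: str, dropout: float, initialization: str) -> dict:
    """A stand-in (hypothetical) for LinearHeadConfig, just to show the layout."""
    return {"layers": layers, "dropout": dropout, "initialization": initialization}


# black-style output: when a call (including its inline comment) no longer fits
# in line-length = 120, each argument moves to its own line and a trailing
# comma is added after the last one.
head_config = make_head_config(
    layers="",
    dropout=0.1,
    initialization="kaiming",  # no extra layers in the head, just a mapping to output_dim
)
```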

pyproject.toml

Lines changed: 10 additions & 6 deletions

@@ -1,16 +1,13 @@
-[tool.black]
-# https://github.com/psf/black
-line-length = 120
-exclude = "(.eggs|.git|.hg|.mypy_cache|.venv|_build|buck-out|build|dist)"
-
-
 [tool.ruff]
+target-version = "py38"
 line-length = 120
 # Enable Pyflakes `E` and `F` codes by default.
 select = [
     "E", "W",  # see: https://pypi.org/project/pycodestyle
     "F",  # see: https://pypi.org/project/pyflakes
     "I",  # isort
+    "UP",  # see: https://docs.astral.sh/ruff/rules/#pyupgrade-up
+    "RUF100",  # yesqa
     # "D",  # see: https://pypi.org/project/pydocstyle
     # "N",  # see: https://pypi.org/project/pep8-naming
 ]
@@ -45,3 +42,10 @@ ignore-init-module-imports = true
 [tool.ruff.pydocstyle]
 # Use Google-style docstrings.
 convention = "google"
+
+[tool.docformatter]
+recursive = true
+# this need to be shorter as some docstings are r"""...
+wrap-summaries = 119
+wrap-descriptions = 120
+blank = true
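The new `[tool.docformatter]` table explains most of the remaining diff: `blank = true` has docformatter insert a blank line before the closing quotes of a multi-line docstring, which matches the `+` blank lines repeated through every Python file below. A hypothetical before/after (not code from this repo):

```python
# Hypothetical before/after of docformatter with blank = true: a blank line
# is inserted before the closing quotes of any multi-line docstring.


def before(path: str) -> None:
    """Save the encoder.

    Args:
        path (str): path to save the encoder
    """


def after(path: str) -> None:
    """Save the encoder.

    Args:
        path (str): path to save the encoder

    """
```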

src/pytorch_tabular/categorical_encoders.py

Lines changed: 10 additions & 0 deletions

@@ -51,6 +51,7 @@ def transform(self, X):
         :return: encoded DataFrame of shape (n_samples, n_features), initial categorical columns are dropped, and
             replaced with encoded columns. DataFrame passed in argument is unchanged.
         :rtype: pandas.DataFrame
+
         """
         if not self._mapping:
             raise ValueError("`fit` method must be called before `transform`.")
@@ -80,6 +81,7 @@ def fit_transform(self, X, y=None):
         :return: encoded DataFrame of shape (n_samples, n_features), initial categorical columns are dropped, and
             replaced with encoded columns. DataFrame passed in argument is unchanged.
         :rtype: pandas.DataFrame
+
         """
         self.fit(X, y)
         return self.transform(X)
@@ -104,6 +106,7 @@ def save_as_object_file(self, path):

         Args:
             path (str): path to save the encoder
+
         """
         if not self._mapping:
             raise ValueError("`fit` method must be called before `save_as_object_file`.")
@@ -114,6 +117,7 @@ def load_from_object_file(self, path):

         Args:
             path (str): path to load the encoder
+
         """
         for k, v in pickle.load(open(path, "rb")).items():
             setattr(self, k, v)
@@ -131,6 +135,7 @@ def __init__(self, cols=None, handle_unseen="impute", handle_missing="impute"):
             'ignore' - skip unseen categories
             'impute' - impute new categories to a predefined value, which is same as NAN_CATEGORY
         :return: None
+
         """
         self._input_check("handle_unseen", handle_unseen, ["error", "ignore", "impute"])
         self._input_check("handle_missing", handle_missing, ["error", "impute"])
@@ -141,6 +146,7 @@ def fit(self, X, y=None):

         :param pandas.DataFrame X: DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.
         :return: None
+
         """
         self._before_fit_check(X, y)
         if self.handle_missing == "error":
@@ -161,6 +167,7 @@ def __init__(self, tabular_model):

         Args:
             tabular_model (TabularModel): The trained TabularModel object
+
         """
         self._categorical_encoder = tabular_model.datamodule.categorical_encoder
         self.cols = tabular_model.model.hparams.categorical_cols
@@ -198,6 +205,7 @@ def fit(self, X, y=None):
         """Just for compatibility.

         Does not do anything
+
         """
         return self

@@ -213,6 +221,7 @@ def transform(self, X: DataFrame, y=None) -> DataFrame:

         Returns:
             DataFrame: The encoded dataframe
+
         """
         if not self._mapping:
             raise ValueError(
@@ -245,6 +254,7 @@ def fit_transform(self, X: DataFrame, y=None) -> DataFrame:

         Returns:
             DataFrame: The encoded dataframe
+
         """
         self.fit(X, y)
         return self.transform(X)

src/pytorch_tabular/config/config.py

Lines changed: 8 additions & 0 deletions

@@ -94,6 +94,7 @@ class DataConfig:

         handle_missing_values (bool): Whether to handle missing values in categorical columns as
             unknown
+
     """

     target: Optional[List[str]] = field(
@@ -201,6 +202,7 @@ class InferredConfig:
             list of tuples (cardinality, embedding_dim).

         embedded_cat_dim (int): The number of features or dimensions of the embedded categorical features
+
     """

     categorical_dim: int = field(
@@ -341,6 +343,7 @@ class TrainerConfig:

         trainer_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch Lightning Trainer. See
             https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.html#pytorch_lightning.trainer.Trainer
+
     """

     batch_size: int = field(default=64, metadata={"help": "Number of samples in each batch of training"})
@@ -575,6 +578,7 @@ class ExperimentConfig:
         log_logits (bool): Turn this on to log the logits as a histogram in W&B

         exp_log_freq (int): step count between logging of gradients and parameters.
+
     """

     project_name: str = field(
@@ -651,6 +655,7 @@ class OptimizerConfig:

         lr_scheduler_monitor_metric (Optional[str]): Used with ReduceLROnPlateau, where the plateau is
             decided based on this metric
+
     """

     optimizer: str = field(
@@ -703,6 +708,7 @@ def __init__(
         Args:
             exp_version_manager (str, optional): The path of the yml file which acts as version control.
                 Defaults to ".pt_tmp/exp_version_manager.yml".
+
         """
         super().__init__()
         self._exp_version_manager = exp_version_manager
@@ -776,6 +782,7 @@ class ModelConfig:
             not apply any restrictions

         seed (int): The seed for reproducibility. Defaults to 42
+
     """

     task: str = field(
@@ -956,6 +963,7 @@ class SSLModelConfig:
         learning_rate (float): The learning rate of the model. Defaults to 1e-3

         seed (int): The seed for reproducibility. Defaults to 42
+
     """

     task: str = field(init=False, default="ssl")

src/pytorch_tabular/feature_extractor.py

Lines changed: 6 additions & 0 deletions

@@ -26,6 +26,7 @@ def __init__(self, tabular_model, extract_keys=["backbone_features"], drop_origi
             tabular_model (TabularModel): The trained TabularModel object
             extract_keys (list, optional): The keys of the features to extract. Defaults to ["backbone_features"].
             drop_original (bool, optional): Whether to drop the original columns. Defaults to True.
+
         """
         assert not (
             isinstance(tabular_model.model, NODEModel)
@@ -40,6 +41,7 @@ def fit(self, X, y=None):
         """Just for compatibility.

         Does not do anything
+
         """
         return self

@@ -55,6 +57,7 @@ def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:

         Returns:
             pd.DataFrame: The encoded dataframe
+
         """

         X_encoded = X.copy(deep=True)
@@ -99,6 +102,7 @@ def fit_transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:

         Returns:
             pd.DataFrame: The encoded dataframe
+
         """
         self.fit(X, y)
         return self.transform(X)
@@ -108,6 +112,7 @@ def save_as_object_file(self, path):

         Args:
             path (str): The path to save the file
+
         """
         if not self._mapping:
             raise ValueError("`fit` method must be called before `save_as_object_file`.")
@@ -118,6 +123,7 @@ def load_from_object_file(self, path):

         Args:
             path (str): The path to load the file from
+
         """
         for k, v in pickle.load(open(path, "rb")).items():
             setattr(self, k, v)

src/pytorch_tabular/models/autoint/autoint.py

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@ def __init__(self, config: DictConfig):

         Args:
             config (DictConfig): config of the model
+
         """
         super().__init__()
         self.hparams = config

src/pytorch_tabular/models/autoint/config.py

Lines changed: 1 addition & 0 deletions

@@ -108,6 +108,7 @@ class AutoIntConfig(ModelConfig):
             not apply any restrictions

         seed (int): The seed for reproducibility. Defaults to 42
+
     """

     attn_embed_dim: int = field(

src/pytorch_tabular/models/base_model.py

Lines changed: 12 additions & 1 deletion

@@ -49,6 +49,7 @@ def safe_merge_config(config: DictConfig, inferred_config: DictConfig) -> DictCo

     Returns:
         The merged configuration.
+
     """
     # using base config values if exist
     inferred_config.embedding_dims = config.get("embedding_dims") or inferred_config.embedding_dims
@@ -90,6 +91,7 @@ def __init__(
                 A custom optimizer as callable or string to be imported. Defaults to None.
             custom_optimizer_params (Dict, optional): A dictionary of custom optimizer parameters. Defaults to {}.
             kwargs (Dict, optional): Additional keyword arguments.
+
         """
         super().__init__()
         assert "inferred_config" in kwargs, "inferred_config not found in initialization arguments"
@@ -231,6 +233,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso

         Returns:
             torch.Tensor: The loss value
+
         """
         y_hat = output["logits"]
         reg_terms = [k for k, v in output.items() if "regularization" in k]
@@ -287,6 +290,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L

         Returns:
             List[torch.Tensor]: The list of metric values
+
         """
         metrics = []
         for metric, metric_str, prob_inp, metric_params in zip(
@@ -349,13 +353,15 @@ def embed_input(self, x: Dict) -> torch.Tensor:
         return self.embedding_layer(x)

     def apply_output_sigmoid_scaling(self, y_hat: torch.Tensor) -> torch.Tensor:
-        """Applies sigmoid scaling to the output of the model if the task is regression and the target range is defined.
+        """Applies sigmoid scaling to the output of the model if the task is regression and the target range is
+        defined.

         Args:
             y_hat (torch.Tensor): The output of the model

         Returns:
             torch.Tensor: The output of the model with sigmoid scaling applied
+
         """
         if (self.hparams.task == "regression") and (self.hparams.target_range is not None):
             for i in range(self.hparams.output_dim):
@@ -373,6 +379,7 @@ def pack_output(self, y_hat: torch.Tensor, backbone_features: torch.tensor) -> D

         Returns:
             The packed output of the model
+
         """
         # if self.head is the Identity function it means that we cannot extract backbone features,
         # because the model cannot be divide in backbone and head (i.e. TabNet)
@@ -388,6 +395,7 @@ def compute_head(self, backbone_features: Tensor) -> Dict[str, Any]:

         Returns:
             The output of the model
+
         """
         y_hat = self.head(backbone_features)
         y_hat = self.apply_output_sigmoid_scaling(y_hat)
@@ -398,6 +406,7 @@ def forward(self, x: Dict) -> Dict[str, Any]:

         Args:
             x (Dict): The input of the model with 'continuous' and 'categorical' keys
+
         """
         x = self.embed_input(x)
         x = self.compute_backbone(x)
@@ -413,6 +422,7 @@ def predict(self, x: Dict, ret_model_output: bool = False) -> Union[torch.Tensor

         Returns:
             The output of the model
+
         """
         assert self.hparams.task != "ssl", "It's not allowed to use the method predict in case of ssl task"
         ret_value = self.forward(x)
@@ -427,6 +437,7 @@ def extract_embedding(self):
         """Extracts the embedding of the model.

         This is used in `CategoricalEmbeddingTransformer`
+
         """
         if self.hparams.categorical_dim > 0:
             if not isinstance(self.embedding_layer, PreEncoded1dLayer):
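The `apply_output_sigmoid_scaling` hunk is the one place where `wrap-summaries = 119` visibly bites: a summary that fit the old 120-character budget now wraps onto a second line. A hypothetical illustration of why the limit is one below `line-length` — my reading of the `some docstings are r"""` comment in pyproject.toml:

```python
# Hypothetical illustration (not from this repo): a raw docstring's `r` prefix
# shifts the summary text one column right, so a summary budget equal to
# line-length could overflow by one character on r"""...""" docstrings.
# Capping summaries at 119 keeps both variants under 120 columns.


def plain() -> None:
    """A summary here may run out to the full line length, since the quotes open at column 5."""


def raw() -> None:
    r"""The same summary sits one column further right because of the r prefix."""
```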
