Skip to content

Commit 5c21c88

Browse files
committed
Only load best checkpoint on rank zero in distributed
training
1 parent c589f71 commit 5c21c88

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/pytorch_tabular/tabular_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from pytorch_lightning.tuner.tuning import Tuner
3333
from pytorch_lightning.utilities.model_summary import summarize
34+
from pytorch_lightning.utilities.rank_zero import rank_zero_only
3435
from rich import print as rich_print
3536
from rich.pretty import pprint
3637
from sklearn.base import TransformerMixin
@@ -1522,6 +1523,7 @@ def add_noise(module, input, output):
15221523
)
15231524
return pred_df
15241525

1526+
@rank_zero_only
15251527
def load_best_model(self) -> None:
15261528
"""Loads the best model after training is done."""
15271529
if self.trainer.checkpoint_callback is not None:

0 commit comments

Comments
 (0)