Skip to content

Commit f1f116e

Browse files
committed
Add sync_dist=True to all calls to self.log() in
`validation_step()` and `test_step()` to support distributed training
1 parent c790615 commit f1f116e

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

src/pytorch_tabular/models/base_model.py

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -244,13 +244,14 @@ def _setup_metrics(self):
244244
else:
245245
self.metrics = self.custom_metrics
246246

247-
def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
247+
def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str, sync_dist: bool = False) -> torch.Tensor:
248248
"""Calculates the loss for the model.
249249
250250
Args:
251251
output (Dict): The output dictionary from the model
252252
y (torch.Tensor): The target tensor
253253
tag (str): The tag to use for logging
254+
sync_dist (bool): enable distributed sync of logs
254255
255256
Returns:
256257
torch.Tensor: The loss value
@@ -270,6 +271,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
270271
on_step=False,
271272
logger=True,
272273
prog_bar=False,
274+
sync_dist=sync_dist,
273275
)
274276
if self.hparams.task == "regression":
275277
computed_loss = reg_loss
@@ -284,6 +286,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
284286
on_step=False,
285287
logger=True,
286288
prog_bar=False,
289+
sync_dist=sync_dist,
287290
)
288291
else:
289292
# TODO loss fails with batch size of 1?
@@ -301,6 +304,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
301304
on_step=False,
302305
logger=True,
303306
prog_bar=False,
307+
sync_dist=sync_dist,
304308
)
305309
start_index = end_index
306310
self.log(
@@ -311,10 +315,13 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso
311315
# on_step=False,
312316
logger=True,
313317
prog_bar=True,
318+
sync_dist=sync_dist,
314319
)
315320
return computed_loss
316321

317-
def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> List[torch.Tensor]:
322+
def calculate_metrics(
323+
self, y: torch.Tensor, y_hat: torch.Tensor, tag: str, sync_dist: bool = False
324+
) -> List[torch.Tensor]:
318325
"""Calculates the metrics for the model.
319326
320327
Args:
@@ -324,6 +331,8 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
324331
325332
tag (str): The tag to use for logging
326333
334+
sync_dist (bool): enable distributed sync of logs
335+
327336
Returns:
328337
List[torch.Tensor]: The list of metric values
329338
@@ -356,6 +365,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
356365
on_step=False,
357366
logger=True,
358367
prog_bar=False,
368+
sync_dist=sync_dist,
359369
)
360370
_metrics.append(_metric)
361371
avg_metric = torch.stack(_metrics, dim=0).sum()
@@ -379,6 +389,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
379389
on_step=False,
380390
logger=True,
381391
prog_bar=False,
392+
sync_dist=sync_dist,
382393
)
383394
_metrics.append(_metric)
384395
start_index = end_index
@@ -391,6 +402,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L
391402
on_step=False,
392403
logger=True,
393404
prog_bar=True,
405+
sync_dist=sync_dist,
394406
)
395407
return metrics
396408

@@ -523,19 +535,19 @@ def validation_step(self, batch, batch_idx):
523535
# fetched from the batch
524536
y = batch["target"] if y is None else y
525537
y_hat = output["logits"]
526-
self.calculate_loss(output, y, tag="valid")
527-
self.calculate_metrics(y, y_hat, tag="valid")
538+
self.calculate_loss(output, y, tag="valid", sync_dist=True)
539+
self.calculate_metrics(y, y_hat, tag="valid", sync_dist=True)
528540
return y_hat, y
529541

530542
def test_step(self, batch, batch_idx):
531543
with torch.no_grad():
532544
output, y = self.forward_pass(batch)
533-
# y is not None for SSL task.Rest of the tasks target is
545+
# y is not None for SSL task. Rest of the tasks target is
534546
# fetched from the batch
535547
y = batch["target"] if y is None else y
536548
y_hat = output["logits"]
537-
self.calculate_loss(output, y, tag="test")
538-
self.calculate_metrics(y, y_hat, tag="test")
549+
self.calculate_loss(output, y, tag="test", sync_dist=True)
550+
self.calculate_metrics(y, y_hat, tag="test", sync_dist=True)
539551
return y_hat, y
540552

541553
def configure_optimizers(self):

0 commit comments

Comments (0)