Skip to content

Commit 8032d0b

Browse files
committed
Add sync_dist to SSL Models
1 parent 06e9589 commit 8032d0b

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

src/pytorch_tabular/ssl_models/base_model.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,11 @@ def _setup_metrics(self):
136136
pass
137137

138138
@abstractmethod
139-
def calculate_loss(self, output, tag):
139+
def calculate_loss(self, output, tag, sync_dist):
140140
pass
141141

142142
@abstractmethod
143-
def calculate_metrics(self, output, tag):
143+
def calculate_metrics(self, output, tag, sync_dist):
144144
pass
145145

146146
@abstractmethod
@@ -167,15 +167,15 @@ def training_step(self, batch, batch_idx):
167167
def validation_step(self, batch, batch_idx):
168168
with torch.no_grad():
169169
output = self.forward(batch)
170-
self.calculate_loss(output, tag="valid")
171-
self.calculate_metrics(output, tag="valid")
170+
self.calculate_loss(output, tag="valid", sync_dist=True)
171+
self.calculate_metrics(output, tag="valid", sync_dist=True)
172172
return output
173173

174174
def test_step(self, batch, batch_idx):
175175
with torch.no_grad():
176176
output = self.forward(batch)
177-
self.calculate_loss(output, tag="test")
178-
self.calculate_metrics(output, tag="test")
177+
self.calculate_loss(output, tag="test", sync_dist=True)
178+
self.calculate_metrics(output, tag="test", sync_dist=True)
179179
return output
180180

181181
def on_validation_epoch_end(self) -> None:

src/pytorch_tabular/ssl_models/dae/dae.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def forward(self, x: Dict):
200200
else:
201201
return z.features
202202

203-
def calculate_loss(self, output, tag):
203+
def calculate_loss(self, output, tag, sync_dist=False):
204204
total_loss = 0
205205
for type_, out in output.items():
206206
if type_ == "categorical":
@@ -220,6 +220,7 @@ def calculate_loss(self, output, tag):
220220
on_step=False,
221221
logger=True,
222222
prog_bar=False,
223+
sync_dist=sync_dist,
223224
)
224225
total_loss += loss
225226
self.log(
@@ -230,10 +231,11 @@ def calculate_loss(self, output, tag):
230231
# on_step=False,
231232
logger=True,
232233
prog_bar=True,
234+
sync_dist=sync_dist,
233235
)
234236
return total_loss
235237

236-
def calculate_metrics(self, output, tag):
238+
def calculate_metrics(self, output, tag, sync_dist=False):
237239
pass
238240

239241
def featurize(self, x: Dict):

0 commit comments

Comments
 (0)