Skip to content

Commit b08e8cc

Browse files
author
donglaiw
committed
add pretrained model dict check
1 parent 0d763b3 commit b08e8cc

File tree

4 files changed

+8
-76
lines changed

4 files changed

+8
-76
lines changed

configs/MitoEM/MitoEM-BC.yaml

Lines changed: 0 additions & 16 deletions
This file was deleted.

configs/MitoEM/MitoEM-BCD.yaml

Lines changed: 0 additions & 17 deletions
This file was deleted.

configs/MitoEM/MitoEM-Base.yaml

Lines changed: 0 additions & 40 deletions
This file was deleted.

connectomics/engine/trainer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from yacs.config import CfgNode
1212

1313
import torch
14+
torch.backends.cuda.matmul.allow_tf32 = False
15+
torch.backends.cudnn.allow_tf32 = False
1416
from torch.cuda.amp import autocast, GradScaler
1517

1618
from .base import TrainerBase
@@ -390,9 +392,12 @@ def update_checkpoint(self, checkpoint: Optional[str] = None):
390392
if not model_dict.keys() == pretrained_dict.keys():
391393
warnings.warn("Module keys in model.state_dict() do not exactly "
392394
"match the keys in pretrained_dict!")
393-
for key in model_dict.keys():
394-
if not key in pretrained_dict:
395-
print(key)
395+
key_missing = [key for key in model_dict.keys() if not key in pretrained_dict.keys()]
396+
if len(key_missing) != 0:
397+
print('missing keys (%d): '%len(key_missing), key_missing)
398+
key_unused = [key for key in pretrained_dict.keys() if not key in model_dict.keys()]
399+
if len(key_unused) != 0:
400+
print('unused keys (%d): '%len(key_unused), key_unused)
396401

397402
# 1. filter out unnecessary keys by name
398403
pretrained_dict = {k: v for k,

0 commit comments

Comments
 (0)