@@ -82,16 +82,15 @@ def _tta_3d(self, model, data):
8282 out = None
8383 cc = 0
8484
85- if self .num_aug == None :
86- opts = itertools .product (
87- (False , ), (False , ), (False , ), (False , ))
88- elif self .num_aug == 4 :
85+ opts = itertools .product (
86+ (False , ), (False , ), (False , ), (False , ))
87+ if self .num_aug == 4 :
8988 opts = itertools .product (
9089 (False , True ), (False , True ), (False , ), (False , ))
9190 elif self .num_aug == 8 :
9291 opts = itertools .product (
9392 (False , True ), (False , True ), (False , ), (False , True ))
94- else :
93+ elif self . num_aug == 16 :
9594 opts = itertools .product (
9695 (False , True ), (False , True ), (False , True ), (False , True ))
9796
@@ -210,7 +209,7 @@ def update_name(self, name):
210209 r"""Update the name of the output file to indicate applied test-time augmentations.
211210 """
212211 extension = "_"
213- if self .num_aug is None :
212+ if self .num_aug is None
214213 return name
215214 elif self .num_aug == 4 :
216215 extension += "xy"
0 commit comments