@@ -48,7 +48,7 @@ def _deaugment_averaging(x: Tensor, reduction: MaybeStrOrCallable) -> Tensor:
48
48
Helper method to average predictions of TTA-ed model.
49
49
This function assumes TTA dimension is 0, e.g [T, B, C, Ci, Cj, ..]
50
50
Args:
51
- x:
51
+ x: Input tensor of shape [T, B, ... ]
52
52
reduction: Reduction mode ("sum", "mean", "gmean", "hmean", function, None)
53
53
54
54
Returns:
@@ -64,6 +64,11 @@ def _deaugment_averaging(x: Tensor, reduction: MaybeStrOrCallable) -> Tensor:
64
64
x = F .harmonic_mean (x , dim = 0 )
65
65
elif callable (reduction ):
66
66
x = reduction (x , dim = 0 )
67
+ elif reduction in {None , "None" , "none" }:
68
+ pass
69
+ else :
70
+ raise KeyError (f"Unsupported reduction mode { reduction } " )
71
+
67
72
return x
68
73
69
74
@@ -94,10 +99,7 @@ def fivecrop_image_augment(image: Tensor, crop_size: Tuple[int, int]) -> Tensor:
94
99
center_crop_x = (image_width - crop_width ) // 2
95
100
crop_cc = image [..., center_crop_y : center_crop_y + crop_height , center_crop_x : center_crop_x + crop_width ]
96
101
97
- return torch .cat (
98
- [crop_tl , crop_tr , crop_bl , crop_br , crop_cc ],
99
- dim = 0 ,
100
- )
102
+ return torch .cat ([crop_tl , crop_tr , crop_bl , crop_br , crop_cc ], dim = 0 ,)
101
103
102
104
103
105
def fivecrop_label_deaugment (logits : Tensor , reduction : MaybeStrOrCallable = "mean" ) -> Tensor :
@@ -275,15 +277,7 @@ def d2_image_augment(image: Tensor) -> Tensor:
275
277
- Vertically-flipped tensor
276
278
277
279
"""
278
- return torch .cat (
279
- [
280
- image ,
281
- F .torch_rot180 (image ),
282
- F .torch_fliplr (image ),
283
- F .torch_flipud (image ),
284
- ],
285
- dim = 0 ,
286
- )
280
+ return torch .cat ([image , F .torch_rot180 (image ), F .torch_fliplr (image ), F .torch_flipud (image ),], dim = 0 ,)
287
281
288
282
289
283
def d2_image_deaugment (image : Tensor , reduction : MaybeStrOrCallable = "mean" ) -> Tensor :
@@ -302,12 +296,7 @@ def d2_image_deaugment(image: Tensor, reduction: MaybeStrOrCallable = "mean") ->
302
296
b1 , b2 , b3 , b4 = torch .chunk (image , 4 )
303
297
304
298
image : Tensor = torch .stack (
305
- [
306
- b1 ,
307
- F .torch_rot180 (b2 ),
308
- F .torch_fliplr (b3 ),
309
- F .torch_flipud (b4 ),
310
- ]
299
+ [b1 , F .torch_rot180 (b2 ), F .torch_fliplr (b3 ), F .torch_flipud (b4 ),]
311
300
)
312
301
313
302
return _deaugment_averaging (image , reduction = reduction )
@@ -440,10 +429,7 @@ def flips_image_augment(image: Tensor) -> Tensor:
440
429
return torch .cat ([image , F .torch_fliplr (image ), F .torch_flipud (image )], dim = 0 )
441
430
442
431
443
- def flips_image_deaugment (
444
- image : Tensor ,
445
- reduction : MaybeStrOrCallable = "mean" ,
446
- ) -> Tensor :
432
+ def flips_image_deaugment (image : Tensor , reduction : MaybeStrOrCallable = "mean" ,) -> Tensor :
447
433
"""
448
434
Deaugment input tensor (output of the model) assuming the input was flip-augmented image (See flips_augment).
449
435
Args:
@@ -464,10 +450,7 @@ def flips_image_deaugment(
464
450
return _deaugment_averaging (image , reduction = reduction )
465
451
466
452
467
- def fliplr_labels_deaugment (
468
- logits : Tensor ,
469
- reduction : MaybeStrOrCallable = "mean" ,
470
- ) -> Tensor :
453
+ def fliplr_labels_deaugment (logits : Tensor , reduction : MaybeStrOrCallable = "mean" ,) -> Tensor :
471
454
"""
472
455
Deaugment input tensor (output of the model) assuming the input was fliplr-augmented image (See fliplr_image_augment).
473
456
Args:
@@ -485,10 +468,7 @@ def fliplr_labels_deaugment(
485
468
return _deaugment_averaging (logits , reduction = reduction )
486
469
487
470
488
- def flips_labels_deaugment (
489
- logits : Tensor ,
490
- reduction : MaybeStrOrCallable = "mean" ,
491
- ) -> Tensor :
471
+ def flips_labels_deaugment (logits : Tensor , reduction : MaybeStrOrCallable = "mean" ,) -> Tensor :
492
472
"""
493
473
Deaugment input tensor (output of the model) assuming the input was flip-augmented image (See flips_image_augment).
494
474
Args:
@@ -543,9 +523,7 @@ def ms_image_augment(
543
523
544
524
545
525
def ms_labels_deaugment (
546
- logits : List [Tensor ],
547
- size_offsets : List [Union [int , Tuple [int , int ]]],
548
- reduction : MaybeStrOrCallable = "mean" ,
526
+ logits : List [Tensor ], size_offsets : List [Union [int , Tuple [int , int ]]], reduction : MaybeStrOrCallable = "mean" ,
549
527
):
550
528
"""
551
529
Deaugment logits
0 commit comments