@@ -28,23 +28,25 @@ class TileMerge(Generic[T_OTXDataEntity, T_OTXBatchPredEntity]):
28
28
29
29
Args:
30
30
img_infos (list[ImageInfo]): Original image information before tiling.
31
- iou_threshold (float, optional ): IoU threshold for non-maximum suppression. Defaults to 0.45 .
32
- max_num_instances (int, optional ): Maximum number of instances to keep. Defaults to 500 .
33
-
31
+ num_classes (int ): Number of classes .
32
+ tile_config (TileConfig ): Tile configuration .
33
+ explain_mode (bool, optional): Whether or not tiles have explain features. Default: False.
34
34
"""
35
35
36
36
def __init__(
    self,
    img_infos: list[ImageInfo],
    num_classes: int,
    tile_config: TileConfig,
    explain_mode: bool = False,
) -> None:
    """Store the merge inputs and the tiling settings used while merging.

    Args:
        img_infos (list[ImageInfo]): Original image information before tiling.
        num_classes (int): Number of classes.
        tile_config (TileConfig): Tile configuration. Its ``tile_size``,
            ``iou_threshold``, ``max_num_instances`` and ``with_full_img``
            attributes are copied onto this instance.
        explain_mode (bool, optional): Whether or not tiles have explain
            features. Defaults to False.
    """
    # Values handed in directly by the caller.
    self.img_infos, self.num_classes = img_infos, num_classes

    # Snapshot the tiling settings so later code reads them from the
    # instance instead of going back to the config object.
    cfg = tile_config
    self.tile_size = cfg.tile_size
    self.iou_threshold = cfg.iou_threshold
    self.max_num_instances = cfg.max_num_instances
    self.with_full_img = cfg.with_full_img

    # Kept as an explicit flag on the instance for later use when merging.
    self.explain_mode = explain_mode
48
50
49
51
@abstractmethod
50
52
def _merge_entities (
@@ -116,10 +118,10 @@ def merge(
116
118
"""
117
119
entities_to_merge = defaultdict (list )
118
120
img_ids = []
119
- explain_mode = len ( batch_tile_preds [ 0 ]. feature_vector ) > 0
121
+ explain_mode = self . explain_mode
120
122
121
- for tile_preds , tile_attrs in zip (batch_tile_preds , batch_tile_attrs ):
122
- batch_size = tile_preds . batch_size
123
+ for tile_preds , tile_attrs in zip (batch_tile_preds , batch_tile_attrs , strict = True ):
124
+ batch_size = len ( tile_attrs )
123
125
saliency_maps = tile_preds .saliency_map if explain_mode else [[] for _ in range (batch_size )]
124
126
feature_vectors = tile_preds .feature_vector if explain_mode else [[] for _ in range (batch_size )]
125
127
for tile_attr , tile_img_info , tile_bboxes , tile_labels , tile_scores , tile_s_map , tile_f_vect in zip (
@@ -130,6 +132,7 @@ def merge(
130
132
tile_preds .scores ,
131
133
saliency_maps ,
132
134
feature_vectors ,
135
+ strict = True ,
133
136
):
134
137
offset_x , offset_y , _ , _ = tile_attr ["roi" ]
135
138
tile_bboxes [:, 0 ::2 ] += offset_x
@@ -155,7 +158,7 @@ def merge(
155
158
156
159
return [
157
160
self ._merge_entities (image_info , entities_to_merge [img_id ], explain_mode )
158
- for img_id , image_info in zip (img_ids , self .img_infos )
161
+ for img_id , image_info in zip (img_ids , self .img_infos , strict = True )
159
162
]
160
163
161
164
def _merge_entities (
@@ -316,10 +319,10 @@ def merge(
316
319
"""
317
320
entities_to_merge = defaultdict (list )
318
321
img_ids = []
319
- explain_mode = len ( batch_tile_preds [ 0 ]. feature_vector ) > 0
322
+ explain_mode = self . explain_mode
320
323
321
- for tile_preds , tile_attrs in zip (batch_tile_preds , batch_tile_attrs ):
322
- feature_vectors = tile_preds .feature_vector if explain_mode else [[] for _ in range (tile_preds . batch_size )]
324
+ for tile_preds , tile_attrs in zip (batch_tile_preds , batch_tile_attrs , strict = True ):
325
+ feature_vectors = tile_preds .feature_vector if explain_mode else [[] for _ in range (len ( tile_attrs ) )]
323
326
for tile_attr , tile_img_info , tile_bboxes , tile_labels , tile_scores , tile_masks , tile_f_vect in zip (
324
327
tile_attrs ,
325
328
tile_preds .imgs_info ,
@@ -328,6 +331,7 @@ def merge(
328
331
tile_preds .scores ,
329
332
tile_preds .masks ,
330
333
feature_vectors ,
334
+ strict = True ,
331
335
):
332
336
keep_indices = tile_masks .to_sparse ().sum ((1 , 2 )).to_dense () > 0
333
337
keep_indices = keep_indices .nonzero (as_tuple = True )[0 ]
@@ -362,7 +366,7 @@ def merge(
362
366
363
367
return [
364
368
self ._merge_entities (image_info , entities_to_merge [img_id ], explain_mode )
365
- for img_id , image_info in zip (img_ids , self .img_infos )
369
+ for img_id , image_info in zip (img_ids , self .img_infos , strict = True )
366
370
]
367
371
368
372
def _merge_entities (
@@ -454,6 +458,18 @@ def get_saliency_maps_from_masks(
454
458
class SegmentationTileMerge (TileMerge ):
455
459
"""Semantic segmentation tile merge."""
456
460
461
def __init__(
    self,
    img_infos: list[ImageInfo],
    num_classes: int,
    tile_config: TileConfig,
    explain_mode: bool = False,
) -> None:
    """Initialise the segmentation tile merger and reject explain mode.

    Args:
        img_infos (list[ImageInfo]): Original image information before tiling.
        num_classes (int): Number of classes.
        tile_config (TileConfig): Tile configuration.
        explain_mode (bool, optional): Must be falsy for this class.
            Defaults to False.

    Raises:
        ValueError: If ``explain_mode`` is truthy.
    """
    # The parent initialiser runs first, exactly as before, so any error it
    # raises on bad arguments still takes precedence over the check below.
    super().__init__(img_infos, num_classes, tile_config, explain_mode)

    if not explain_mode:
        return

    # Fail at construction time rather than later during merging.
    msg = "Explain mode is not supported for segmentation"
    raise ValueError(msg)
472
+
457
473
def merge (
458
474
self ,
459
475
batch_tile_preds : list [SegBatchPredEntity ],
@@ -470,7 +486,7 @@ def merge(
470
486
"""
471
487
entities_to_merge = defaultdict (list )
472
488
img_ids = []
473
- explain_mode = len ( batch_tile_preds [ 0 ]. feature_vector ) > 0
489
+ explain_mode = self . explain_mode
474
490
475
491
for tile_preds , tile_attrs in zip (batch_tile_preds , batch_tile_attrs ):
476
492
batch_size = tile_preds .batch_size
@@ -538,15 +554,9 @@ def _merge_entities(
538
554
]
539
555
full_logits_mask = full_logits_mask / vote_mask .unsqueeze (0 )
540
556
541
- seg_pred_entity = SegPredEntity (
557
+ return SegPredEntity (
542
558
image = torch .empty (img_size ),
543
559
img_info = img_info ,
544
560
masks = full_logits_mask .argmax (0 ).unsqueeze (0 ),
545
561
score = [],
546
562
)
547
-
548
- if explain_mode :
549
- msg = "Explain mode is not supported for segmentation task."
550
- raise NotImplementedError (msg )
551
-
552
- return seg_pred_entity
0 commit comments