Skip to content

Commit 047f738

Browse files
committed
update explain_mode
1 parent 231f253 commit 047f738

File tree

4 files changed

+32
-19
lines changed

4 files changed

+32
-19
lines changed

src/otx/core/model/detection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ def forward_tiles(self, inputs: OTXTileBatchDataEntity[DetBatchDataEntity]) -> D
241241
inputs.imgs_info,
242242
self.num_classes,
243243
self.tile_config,
244+
self.explain_mode,
244245
)
245246
for batch_tile_attrs, batch_tile_input in inputs.unbind():
246247
output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input)

src/otx/core/model/instance_segmentation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ def forward_tiles(self, inputs: OTXTileBatchDataEntity[InstanceSegBatchDataEntit
232232
inputs.imgs_info,
233233
self.num_classes,
234234
self.tile_config,
235+
self.explain_mode,
235236
)
236237
for batch_tile_attrs, batch_tile_input in inputs.unbind():
237238
output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input)

src/otx/core/model/segmentation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def forward_tiles(self, inputs: OTXTileBatchDataEntity[SegBatchDataEntity]) -> S
245245
inputs.imgs_info,
246246
self.num_classes,
247247
self.tile_config,
248+
self.explain_mode,
248249
)
249250
for batch_tile_attrs, batch_tile_input in inputs.unbind():
250251
tile_size = batch_tile_attrs[0]["tile_size"]

src/otx/core/utils/tile_merge.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,25 @@ class TileMerge(Generic[T_OTXDataEntity, T_OTXBatchPredEntity]):
2828
2929
Args:
3030
img_infos (list[ImageInfo]): Original image information before tiling.
31-
iou_threshold (float, optional): IoU threshold for non-maximum suppression. Defaults to 0.45.
32-
max_num_instances (int, optional): Maximum number of instances to keep. Defaults to 500.
33-
31+
num_classes (int): Number of classes.
32+
tile_config (TileConfig): Tile configuration.
33+
explain_mode (bool, optional): Whether or not tiles have explain features. Default: False.
3434
"""
3535

3636
def __init__(
3737
self,
3838
img_infos: list[ImageInfo],
3939
num_classes: int,
4040
tile_config: TileConfig,
41+
explain_mode: bool = False,
4142
) -> None:
4243
self.img_infos = img_infos
4344
self.num_classes = num_classes
4445
self.tile_size = tile_config.tile_size
4546
self.iou_threshold = tile_config.iou_threshold
4647
self.max_num_instances = tile_config.max_num_instances
4748
self.with_full_img = tile_config.with_full_img
49+
self.explain_mode = explain_mode
4850

4951
@abstractmethod
5052
def _merge_entities(
@@ -116,10 +118,10 @@ def merge(
116118
"""
117119
entities_to_merge = defaultdict(list)
118120
img_ids = []
119-
explain_mode = len(batch_tile_preds[0].feature_vector) > 0
121+
explain_mode = self.explain_mode
120122

121-
for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs):
122-
batch_size = tile_preds.batch_size
123+
for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True):
124+
batch_size = len(tile_attrs)
123125
saliency_maps = tile_preds.saliency_map if explain_mode else [[] for _ in range(batch_size)]
124126
feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(batch_size)]
125127
for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_s_map, tile_f_vect in zip(
@@ -130,6 +132,7 @@ def merge(
130132
tile_preds.scores,
131133
saliency_maps,
132134
feature_vectors,
135+
strict=True,
133136
):
134137
offset_x, offset_y, _, _ = tile_attr["roi"]
135138
tile_bboxes[:, 0::2] += offset_x
@@ -155,7 +158,7 @@ def merge(
155158

156159
return [
157160
self._merge_entities(image_info, entities_to_merge[img_id], explain_mode)
158-
for img_id, image_info in zip(img_ids, self.img_infos)
161+
for img_id, image_info in zip(img_ids, self.img_infos, strict=True)
159162
]
160163

161164
def _merge_entities(
@@ -316,10 +319,10 @@ def merge(
316319
"""
317320
entities_to_merge = defaultdict(list)
318321
img_ids = []
319-
explain_mode = len(batch_tile_preds[0].feature_vector) > 0
322+
explain_mode = self.explain_mode
320323

321-
for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs):
322-
feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(tile_preds.batch_size)]
324+
for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs, strict=True):
325+
feature_vectors = tile_preds.feature_vector if explain_mode else [[] for _ in range(len(tile_attrs))]
323326
for tile_attr, tile_img_info, tile_bboxes, tile_labels, tile_scores, tile_masks, tile_f_vect in zip(
324327
tile_attrs,
325328
tile_preds.imgs_info,
@@ -328,6 +331,7 @@ def merge(
328331
tile_preds.scores,
329332
tile_preds.masks,
330333
feature_vectors,
334+
strict=True,
331335
):
332336
keep_indices = tile_masks.to_sparse().sum((1, 2)).to_dense() > 0
333337
keep_indices = keep_indices.nonzero(as_tuple=True)[0]
@@ -362,7 +366,7 @@ def merge(
362366

363367
return [
364368
self._merge_entities(image_info, entities_to_merge[img_id], explain_mode)
365-
for img_id, image_info in zip(img_ids, self.img_infos)
369+
for img_id, image_info in zip(img_ids, self.img_infos, strict=True)
366370
]
367371

368372
def _merge_entities(
@@ -454,6 +458,18 @@ def get_saliency_maps_from_masks(
454458
class SegmentationTileMerge(TileMerge):
455459
"""Semantic segmentation tile merge."""
456460

461+
def __init__(
462+
self,
463+
img_infos: list[ImageInfo],
464+
num_classes: int,
465+
tile_config: TileConfig,
466+
explain_mode: bool = False,
467+
) -> None:
468+
super().__init__(img_infos, num_classes, tile_config, explain_mode)
469+
if explain_mode:
470+
msg = "Explain mode is not supported for segmentation"
471+
raise ValueError(msg)
472+
457473
def merge(
458474
self,
459475
batch_tile_preds: list[SegBatchPredEntity],
@@ -470,7 +486,7 @@ def merge(
470486
"""
471487
entities_to_merge = defaultdict(list)
472488
img_ids = []
473-
explain_mode = len(batch_tile_preds[0].feature_vector) > 0
489+
explain_mode = self.explain_mode
474490

475491
for tile_preds, tile_attrs in zip(batch_tile_preds, batch_tile_attrs):
476492
batch_size = tile_preds.batch_size
@@ -538,15 +554,9 @@ def _merge_entities(
538554
]
539555
full_logits_mask = full_logits_mask / vote_mask.unsqueeze(0)
540556

541-
seg_pred_entity = SegPredEntity(
557+
return SegPredEntity(
542558
image=torch.empty(img_size),
543559
img_info=img_info,
544560
masks=full_logits_mask.argmax(0).unsqueeze(0),
545561
score=[],
546562
)
547-
548-
if explain_mode:
549-
msg = "Explain mode is not supported for segmentation task."
550-
raise NotImplementedError(msg)
551-
552-
return seg_pred_entity

0 commit comments

Comments
 (0)