|
6 | 6 | from __future__ import annotations
|
7 | 7 |
|
8 | 8 | from abc import abstractmethod
|
| 9 | +from collections import defaultdict |
9 | 10 | from collections.abc import Iterable
|
10 | 11 | from contextlib import contextmanager
|
11 | 12 | from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Union
|
12 | 13 |
|
13 | 14 | import cv2
|
14 | 15 | import numpy as np
|
15 |
| -from datumaro.components.annotation import AnnotationType |
| 16 | +from datumaro.components.annotation import AnnotationType, LabelCategories |
16 | 17 | from datumaro.util.image import IMAGE_BACKEND, IMAGE_COLOR_CHANNEL, ImageBackend
|
17 | 18 | from datumaro.util.image import ImageColorChannel as DatumaroImageColorChannel
|
18 | 19 | from torch.utils.data import Dataset
|
@@ -196,3 +197,18 @@ def _get_item_impl(self, idx: int) -> OTXDataItem | None:
|
196 | 197 | def collate_fn(self) -> Callable:
|
197 | 198 | """Collection function to collect KeypointDetDataEntity into KeypointDetBatchDataEntity in data loader."""
|
198 | 199 | return OTXDataItem.collate_fn
|
| 200 | + |
| 201 | + def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int | str, list[int]]: |
| 202 | + """Compute class statistics.""" |
| 203 | + stats: dict[int | str, list[int]] = defaultdict(list) |
| 204 | + for item_idx, item in enumerate(self.dm_subset): |
| 205 | + for ann in item.annotations: |
| 206 | + if use_string_label: |
| 207 | + labels = self.dm_subset.categories().get(AnnotationType.label, LabelCategories()) |
| 208 | + stats[labels.items[ann.label].name].append(item_idx) |
| 209 | + else: |
| 210 | + stats[ann.label].append(item_idx) |
| 211 | + # Remove duplicates in label stats idx: O(n) |
| 212 | + for k in stats: |
| 213 | + stats[k] = list(dict.fromkeys(stats[k])) |
| 214 | + return stats |
0 commit comments