Skip to content

Commit df3c776

Browse files
Move get_idx_list_per_classes to dataset class and address other PR comments.
Signed-off-by: Albert van Houten <albert.van.houten@intel.com>
1 parent 785f5e6 commit df3c776

File tree

12 files changed

+30
-71
lines changed

12 files changed

+30
-71
lines changed

lib/src/otx/data/dataset/base.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
from __future__ import annotations
77

88
from abc import abstractmethod
9+
from collections import defaultdict
910
from collections.abc import Iterable
1011
from contextlib import contextmanager
1112
from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Union
1213

1314
import cv2
1415
import numpy as np
15-
from datumaro.components.annotation import AnnotationType
16+
from datumaro.components.annotation import AnnotationType, LabelCategories
1617
from datumaro.util.image import IMAGE_BACKEND, IMAGE_COLOR_CHANNEL, ImageBackend
1718
from datumaro.util.image import ImageColorChannel as DatumaroImageColorChannel
1819
from torch.utils.data import Dataset
@@ -196,3 +197,18 @@ def _get_item_impl(self, idx: int) -> OTXDataItem | None:
196197
def collate_fn(self) -> Callable:
197198
"""Collection function to collect KeypointDetDataEntity into KeypointDetBatchDataEntity in data loader."""
198199
return OTXDataItem.collate_fn
200+
201+
def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int | str, list[int]]:
    """Group dataset item indices by the labels of their annotations.

    Args:
        use_string_label: If True, key the result by label name, looked up in the
            label categories of ``self.dm_subset``. Otherwise key it by the raw
            ``ann.label`` value.

    Returns:
        A ``defaultdict(list)`` mapping each label to the indices of the items in
        ``self.dm_subset`` that have at least one annotation with that label.
        Each index is listed once per label, in iteration order. Looking up a
        label that never occurred yields an empty list.
    """
    stats: dict[int | str, list[int]] = defaultdict(list)
    # The label categories belong to the subset, not to an item, so fetch them
    # once up front instead of once per annotation.
    label_items = (
        self.dm_subset.categories().get(AnnotationType.label, LabelCategories()).items
        if use_string_label
        else None
    )
    for item_idx, item in enumerate(self.dm_subset):
        for ann in item.annotations:
            key = ann.label if label_items is None else label_items[ann.label].name
            stats[key].append(item_idx)
    # An item with several annotations of the same label was appended once per
    # annotation; dict.fromkeys drops the repeats while keeping first-seen order.
    for key, indices in stats.items():
        stats[key] = list(dict.fromkeys(indices))
    return stats

lib/src/otx/data/dataset/base_new.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,5 +157,5 @@ def collate_fn(self) -> Callable:
157157
return _default_collate_fn
158158

159159
@abc.abstractmethod
160-
def get_idx_list_per_classes(self) -> dict[int, list[int]]:
160+
def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]:
161161
"""Get a dictionary with class labels as keys and lists of corresponding sample indices as values."""

lib/src/otx/data/dataset/classification_new.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@ def __init__(self, **kwargs) -> None:
2121
kwargs["sample_type"] = ClassificationSample
2222
super().__init__(**kwargs)
2323

24-
def get_idx_list_per_classes(self) -> dict[int, list[int]]:
24+
def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]:
2525
"""Get index list per class."""
2626
idx_list_per_classes: dict[int, list[int]] = {}
2727
for idx in range(len(self)):
2828
item = self.dm_subset[idx]
2929
label_id = item.label.item()
30+
if use_string_label:
31+
label_id = self.label_info.labels[label_id]
3032
if label_id not in idx_list_per_classes:
3133
idx_list_per_classes[label_id] = []
3234
idx_list_per_classes[label_id].append(idx)

lib/src/otx/data/factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,13 @@ def create(
8282
categories = cls._get_label_categories(dm_subset, data_format)
8383
dataset = DatasetNew(ClassificationSample, categories={"label": categories})
8484
for item in dm_subset:
85-
if len(item.media.data.shape) == 3: # TODO: Account for grayscale images
85+
if len(item.media.data.shape) == 3: # TODO(albert): Account for grayscale images
8686
dataset.append(ClassificationSample.from_dm_item(item))
8787
common_kwargs["dm_subset"] = dataset
8888
return OTXMulticlassClsDataset(**common_kwargs)
8989

9090
if task == OTXTaskType.MULTI_LABEL_CLS:
91-
from otx.data.dataset.classification import OTXMultilabelClsDataset
91+
from .dataset.classification import OTXMultilabelClsDataset
9292

9393
return OTXMultilabelClsDataset(**common_kwargs)
9494

lib/src/otx/data/samplers/balanced_sampler.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,8 @@
99
from typing import TYPE_CHECKING
1010

1111
import torch
12-
from datumaro import DatasetSubset
13-
from datumaro.experimental import Dataset as NewDataset
1412
from torch.utils.data import Sampler
1513

16-
from otx.data.utils import get_idx_list_per_classes
17-
1814
if TYPE_CHECKING:
1915
from otx.data.dataset.base import OTXDataset
2016
from otx.data.dataset.base_new import OTXDataset as OTXDatasetNew
@@ -64,11 +60,7 @@ def __init__(
6460
super().__init__(dataset)
6561

6662
# img_indices: dict[label: list[idx]]
67-
ann_stats: dict[int | str, list[int]]
68-
if isinstance(dataset.dm_subset, DatasetSubset):
69-
ann_stats = get_idx_list_per_classes(dataset.dm_subset)
70-
elif isinstance(dataset.dm_subset, NewDataset):
71-
ann_stats = dataset.get_idx_list_per_classes() # type: ignore[attr-defined]
63+
ann_stats = dataset.get_idx_list_per_classes()
7264

7365
self.img_indices = {k: torch.tensor(v, dtype=torch.int64) for k, v in ann_stats.items() if len(v) > 0}
7466
self.num_cls = len(self.img_indices.keys())

lib/src/otx/data/samplers/class_incremental_sampler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from torch.utils.data import Sampler
1313

1414
from otx.data.dataset.base import OTXDataset
15-
from otx.data.utils import get_idx_list_per_classes
1615

1716

1817
class ClassIncrementalSampler(Sampler):
@@ -65,7 +64,7 @@ def __init__(
6564
super().__init__(dataset)
6665

6766
# Need to split new classes dataset indices & old classes dataset indices
68-
ann_stats = get_idx_list_per_classes(dataset.dm_subset, True)
67+
ann_stats = dataset.get_idx_list_per_classes(True)
6968
new_indices, old_indices = [], []
7069
for cls in new_classes:
7170
new_indices.extend(ann_stats[cls])

lib/src/otx/data/utils/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
adapt_input_size_to_dataset,
88
adapt_tile_config,
99
get_adaptive_num_workers,
10-
get_idx_list_per_classes,
1110
import_object_from_module,
1211
instantiate_sampler,
1312
)
@@ -17,6 +16,5 @@
1716
"adapt_input_size_to_dataset",
1817
"instantiate_sampler",
1918
"get_adaptive_num_workers",
20-
"get_idx_list_per_classes",
2119
"import_object_from_module",
2220
]

lib/src/otx/data/utils/utils.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@
1515
import cv2
1616
import numpy as np
1717
import torch
18-
from datumaro.components.annotation import AnnotationType, Bbox, ExtractedMask, LabelCategories, Polygon
18+
from datumaro.components.annotation import AnnotationType, Bbox, ExtractedMask, Polygon
1919
from datumaro.components.annotation import Shape as _Shape
2020

2121
from otx.types import OTXTaskType
2222
from otx.utils.device import is_xpu_available
2323

2424
if TYPE_CHECKING:
25-
from datumaro import Dataset as DmDataset
2625
from datumaro import DatasetSubset
2726
from torch.utils.data import Dataset, Sampler
2827

@@ -322,22 +321,6 @@ def get_adaptive_num_workers(num_dataloader: int = 1) -> int | None:
322321
return min(cpu_count() // (num_dataloader * num_devices), 8) # max available num_workers is 8
323322

324323

325-
def get_idx_list_per_classes(dm_dataset: DmDataset, use_string_label: bool = False) -> dict[int | str, list[int]]:
326-
"""Compute class statistics."""
327-
stats: dict[int | str, list[int]] = defaultdict(list)
328-
for item_idx, item in enumerate(dm_dataset):
329-
for ann in item.annotations:
330-
if use_string_label:
331-
labels = dm_dataset.categories().get(AnnotationType.label, LabelCategories())
332-
stats[labels.items[ann.label].name].append(item_idx)
333-
else:
334-
stats[ann.label].append(item_idx)
335-
# Remove duplicates in label stats idx: O(n)
336-
for k in stats:
337-
stats[k] = list(dict.fromkeys(stats[k]))
338-
return stats
339-
340-
341324
def import_object_from_module(obj_path: str) -> Any: # noqa: ANN401
342325
"""Get object from import format string."""
343326
module_name, obj_name = obj_path.rsplit(".", 1)

lib/tests/unit/data/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242

4343
@pytest.fixture(params=["bytes", "file"])
44-
def fxt_dm_item(requeset, tmpdir) -> DatasetItem:
44+
def fxt_dm_item(request, tmpdir) -> DatasetItem:
4545
np_img = np.zeros(shape=(10, 10, 3), dtype=np.uint8)
4646
np_img[:, :, 0] = 0 # Set 0 for B channel
4747
np_img[:, :, 1] = 1 # Set 1 for G channel

lib/tests/unit/data/samplers/test_balanced_sampler.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from otx.data.dataset.base import OTXDataset
1414
from otx.data.samplers.balanced_sampler import BalancedSampler
15-
from otx.data.utils import get_idx_list_per_classes
1615

1716

1817
@pytest.fixture()
@@ -81,7 +80,7 @@ def test_sampler_iter_with_multiple_replicas(self, fxt_imbalanced_dataset):
8180

8281
def test_compute_class_statistics(self, fxt_imbalanced_dataset):
8382
# Compute class statistics
84-
stats = get_idx_list_per_classes(fxt_imbalanced_dataset.dm_subset)
83+
stats = fxt_imbalanced_dataset.get_idx_list_per_classes()
8584

8685
# Check the expected results
8786
assert stats == {0: list(range(100)), 1: list(range(100, 108))}
@@ -90,7 +89,7 @@ def test_sampler_iter_per_class(self, fxt_imbalanced_dataset):
9089
batch_size = 4
9190
sampler = BalancedSampler(fxt_imbalanced_dataset)
9291

93-
stats = get_idx_list_per_classes(fxt_imbalanced_dataset.dm_subset)
92+
stats = fxt_imbalanced_dataset.get_idx_list_per_classes()
9493
class_0_idx = stats[0]
9594
class_1_idx = stats[1]
9695
list_iter = list(iter(sampler))

0 commit comments

Comments
 (0)