Skip to content

Commit 910495a

Browse files
authored
Remove image caching during training to reduce complexity and RAM consumption (#4401)
* delete mem cache * fix linter * remove unsued transforms * fix linter * fix unit test * fix API tests * remove psutil
1 parent 30df4ea commit 910495a

35 files changed

+6
-977
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ dependencies = [
3030
"omegaconf==2.3.0",
3131
"rich==14.0.0",
3232
"jsonargparse==4.30.0",
33-
"psutil==7.0.0", # Mem cache needs system checks
3433
"ftfy==6.3.1",
3534
"regex==2024.11.6",
3635
"importlib_resources==6.5.2",

src/otx/core/data/dataset/anomaly.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
from otx.core.data.dataset.base import OTXDataset, Transforms
2323
from otx.core.data.entity.base import ImageInfo
24-
from otx.core.data.mem_cache import NULL_MEM_CACHE_HANDLER, MemCacheHandlerBase
2524
from otx.core.types.image import ImageColorChannel
2625
from otx.core.types.label import AnomalyLabelInfo
2726
from otx.core.types.task import OTXTaskType
@@ -43,8 +42,6 @@ def __init__(
4342
task_type: OTXTaskType,
4443
dm_subset: DmDataset,
4544
transforms: Transforms,
46-
mem_cache_handler: MemCacheHandlerBase = NULL_MEM_CACHE_HANDLER,
47-
mem_cache_img_max_size: tuple[int, int] | None = None,
4845
max_refetch: int = 1000,
4946
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
5047
stack_images: bool = True,
@@ -55,8 +52,6 @@ def __init__(
5552
super().__init__(
5653
dm_subset,
5754
transforms,
58-
mem_cache_handler,
59-
mem_cache_img_max_size,
6055
max_refetch,
6156
image_color_channel,
6257
stack_images,

src/otx/core/data/dataset/base.py

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@
1313
import cv2
1414
import numpy as np
1515
from datumaro.components.annotation import AnnotationType
16-
from datumaro.components.media import ImageFromFile
1716
from datumaro.util.image import IMAGE_BACKEND, IMAGE_COLOR_CHANNEL, ImageBackend
1817
from datumaro.util.image import ImageColorChannel as DatumaroImageColorChannel
1918
from torch.utils.data import Dataset
2019

2120
from otx.core.data.entity.base import T_OTXDataEntity
22-
from otx.core.data.mem_cache import NULL_MEM_CACHE_HANDLER
2321
from otx.core.data.transform_libs.torchvision import Compose
2422
from otx.core.types.image import ImageColorChannel
2523
from otx.core.types.label import LabelInfo, NullLabelInfo
@@ -28,7 +26,6 @@
2826
if TYPE_CHECKING:
2927
from datumaro import DatasetSubset, Image
3028

31-
from otx.core.data.mem_cache import MemCacheHandlerBase
3229

3330
Transforms = Union[Compose, Callable, List[Callable], dict[str, Compose | Callable | List[Callable]]]
3431

@@ -66,8 +63,6 @@ class OTXDataset(Dataset):
6663
Args:
6764
dm_subset: Datumaro subset of a dataset
6865
transforms: Transforms to apply on images
69-
mem_cache_handler: Handler of the images cache
70-
mem_cache_img_max_size: Max size of images to put in cache
7166
max_refetch: Maximum number of images to fetch in cache
7267
image_color_channel: Color channel of images
7368
stack_images: Whether or not to stack images in collate function in OTXBatchData entity.
@@ -79,8 +74,6 @@ def __init__(
7974
self,
8075
dm_subset: DatasetSubset,
8176
transforms: Transforms,
82-
mem_cache_handler: MemCacheHandlerBase = NULL_MEM_CACHE_HANDLER,
83-
mem_cache_img_max_size: tuple[int, int] | None = None,
8477
max_refetch: int = 1000,
8578
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
8679
stack_images: bool = True,
@@ -89,8 +82,6 @@ def __init__(
8982
) -> None:
9083
self.dm_subset = dm_subset
9184
self.transforms = transforms
92-
self.mem_cache_handler = mem_cache_handler
93-
self.mem_cache_img_max_size = mem_cache_img_max_size
9485
self.max_refetch = max_refetch
9586
self.image_color_channel = image_color_channel
9687
self.stack_images = stack_images
@@ -166,12 +157,7 @@ def _get_img_data_and_shape(
166157
Returns:
167158
The image data, shape, and ROI meta information
168159
"""
169-
key = img.path if isinstance(img, ImageFromFile) else id(img)
170160
roi_meta = None
171-
# check if the image is already in the cache
172-
img_data, roi_meta = self.mem_cache_handler.get(key=key)
173-
if img_data is not None:
174-
return img_data, img_data.shape[:2], roi_meta
175161

176162
with image_decode_context():
177163
img_data = (
@@ -201,58 +187,8 @@ def _get_img_data_and_shape(
201187
img_data = img_data[y1:y2, x1:x2]
202188
roi_meta = {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "orig_image_shape": (h, w)}
203189

204-
img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8), meta=roi_meta)
205-
206190
return img_data, img_data.shape[:2], roi_meta
207191

208-
def _cache_img(self, key: str | int, img_data: np.ndarray, meta: dict[str, Any] | None = None) -> np.ndarray:
209-
"""Cache an image after resizing.
210-
211-
If there is available space in the memory pool, the input image is cached.
212-
Before caching, the input image is resized if it is larger than the maximum image size
213-
specified by the memory caching handler.
214-
Otherwise, the input image is directly cached.
215-
After caching, the processed image data is returned.
216-
217-
Args:
218-
key: The key associated with the image.
219-
img_data: The image data to be cached.
220-
221-
Returns:
222-
The resized image if it was resized. Otherwise, the original image.
223-
"""
224-
if self.mem_cache_handler.frozen:
225-
return img_data
226-
227-
if self.mem_cache_img_max_size is None:
228-
self.mem_cache_handler.put(key=key, data=img_data, meta=meta)
229-
return img_data
230-
231-
height, width = img_data.shape[:2]
232-
max_height, max_width = self.mem_cache_img_max_size
233-
234-
if height <= max_height and width <= max_width:
235-
self.mem_cache_handler.put(key=key, data=img_data, meta=meta)
236-
return img_data
237-
238-
# Preserve the image size ratio and fit to max_height or max_width
239-
# e.g. (1000 / 2000 = 0.5, 1000 / 1000 = 1.0) => 0.5
240-
# h, w = 2000 * 0.5 => 1000, 1000 * 0.5 => 500, bounded by max_height
241-
min_scale = min(max_height / height, max_width / width)
242-
new_height, new_width = int(min_scale * height), int(min_scale * width)
243-
resized_img = cv2.resize(
244-
src=img_data,
245-
dsize=(new_width, new_height),
246-
interpolation=cv2.INTER_LINEAR,
247-
)
248-
249-
self.mem_cache_handler.put(
250-
key=key,
251-
data=resized_img,
252-
meta=meta,
253-
)
254-
return resized_img
255-
256192
@abstractmethod
257193
def _get_item_impl(self, idx: int) -> OTXDataItem | None:
258194
pass

src/otx/core/data/dataset/keypoint_detection.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from torchvision.transforms.v2.functional import to_dtype, to_image
1616

1717
from otx.core.data.entity.base import ImageInfo
18-
from otx.core.data.mem_cache import NULL_MEM_CACHE_HANDLER, MemCacheHandlerBase
1918
from otx.core.data.transform_libs.torchvision import Compose
2019
from otx.core.types.image import ImageColorChannel
2120
from otx.core.types.label import LabelInfo
@@ -33,8 +32,6 @@ def __init__(
3332
self,
3433
dm_subset: DatasetSubset,
3534
transforms: Transforms,
36-
mem_cache_handler: MemCacheHandlerBase = NULL_MEM_CACHE_HANDLER,
37-
mem_cache_img_max_size: tuple[int, int] | None = None,
3835
max_refetch: int = 1000,
3936
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
4037
stack_images: bool = True,
@@ -44,8 +41,6 @@ def __init__(
4441
super().__init__(
4542
dm_subset,
4643
transforms,
47-
mem_cache_handler,
48-
mem_cache_img_max_size,
4944
max_refetch,
5045
image_color_channel,
5146
stack_images,

src/otx/core/data/dataset/segmentation.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from torchvision.transforms.v2.functional import to_dtype, to_image
1616

1717
from otx.core.data.entity.base import ImageInfo
18-
from otx.core.data.mem_cache import NULL_MEM_CACHE_HANDLER, MemCacheHandlerBase
1918
from otx.core.types.image import ImageColorChannel
2019
from otx.core.types.label import SegLabelInfo
2120
from otx.data.torch import OTXDataItem
@@ -161,8 +160,6 @@ def __init__(
161160
self,
162161
dm_subset: DmDataset,
163162
transforms: Transforms,
164-
mem_cache_handler: MemCacheHandlerBase = NULL_MEM_CACHE_HANDLER,
165-
mem_cache_img_max_size: tuple[int, int] | None = None,
166163
max_refetch: int = 1000,
167164
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
168165
to_tv_image: bool = True,
@@ -172,8 +169,6 @@ def __init__(
172169
super().__init__(
173170
dm_subset,
174171
transforms,
175-
mem_cache_handler,
176-
mem_cache_img_max_size,
177172
max_refetch,
178173
image_color_channel,
179174
to_tv_image,

src/otx/core/data/dataset/tile.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,6 @@ def __init__(self, dataset: OTXDataset, tile_config: TileConfig) -> None:
288288
super().__init__(
289289
dataset.dm_subset,
290290
dataset.transforms,
291-
dataset.mem_cache_handler,
292-
dataset.mem_cache_img_max_size,
293291
dataset.max_refetch,
294292
dataset.image_color_channel,
295293
dataset.stack_images,

src/otx/core/data/factory.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from datumaro import Dataset as DmDataset
1818

1919
from otx.core.config.data import SubsetConfig
20-
from otx.core.data.mem_cache import MemCacheHandlerBase
2120

2221

2322
__all__ = ["TransformLibFactory", "OTXDatasetFactory"]
@@ -46,9 +45,7 @@ def create(
4645
task: OTXTaskType,
4746
dm_subset: DmDataset,
4847
cfg_subset: SubsetConfig,
49-
mem_cache_handler: MemCacheHandlerBase,
5048
data_format: str,
51-
mem_cache_img_max_size: tuple[int, int] | None = None,
5249
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
5350
include_polygons: bool = False,
5451
ignore_index: int = 255,
@@ -59,8 +56,6 @@ def create(
5956
"dm_subset": dm_subset,
6057
"transforms": transforms,
6158
"data_format": data_format,
62-
"mem_cache_handler": mem_cache_handler,
63-
"mem_cache_img_max_size": mem_cache_img_max_size,
6459
"image_color_channel": image_color_channel,
6560
"to_tv_image": cfg_subset.to_tv_image,
6661
}

0 commit comments

Comments
 (0)