Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
9f02114
fix(data): enable pin_memory for DataLoader instances across the code…
samet-akcay Jun 25, 2025
58da4c2
refactor(models): streamline decoder retrieval in function
samet-akcay Jun 25, 2025
6b07c77
fix(model): update logit_scale initialization to use torch.log for co…
samet-akcay Jun 25, 2025
c0525de
fix(model): update anomaly map generation to use torch tensors for ca…
samet-akcay Jun 25, 2025
cdf8450
refactor(model): enhance anomaly map generation with PyTorch for stat…
samet-akcay Jun 25, 2025
f43b2ea
chore(license): update license year
samet-akcay Jun 25, 2025
f237dcb
fix(download): enhance URL validation and update download logic
samet-akcay Jun 25, 2025
249bf24
Merge branch 'main' of github.com:open-edge-platform/anomalib
samet-akcay Jun 25, 2025
0fe6ec6
Merge branch 'main' of github.com:open-edge-platform/anomalib
samet-akcay Jul 7, 2025
5387953
Merge branch 'main' of github.com:open-edge-platform/anomalib
samet-akcay Jul 8, 2025
0c612a9
Merge branch 'main' of github.com:open-edge-platform/anomalib
samet-akcay Jul 9, 2025
dd5502a
Merge branch 'main' of github.com:open-edge-platform/anomalib
samet-akcay Jul 9, 2025
b69c785
πŸ”§ chore: update copyright year and add keys method to NumpyBatch and …
samet-akcay Sep 12, 2025
857961e
fix mypy
samet-akcay Sep 12, 2025
ca6c7ee
refactor: improve dataset name and category handling in ImageVisualizer
samet-akcay Sep 12, 2025
ea9b946
Merge branch 'main' into docs/data/add-backward-compatibility
samet-akcay Sep 12, 2025
2fd4b73
refactor: unify visualization method naming in ImageVisualizer
samet-akcay Sep 15, 2025
737dcaf
Merge branch 'docs/data/add-backward-compatibility' of github.com:sam…
samet-akcay Sep 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion src/anomalib/data/dataclasses/numpy/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Numpy-based dataclasses for Anomalib.
Expand All @@ -13,6 +13,7 @@
"""

from dataclasses import dataclass
from typing import Any

import numpy as np

Expand Down Expand Up @@ -55,3 +56,57 @@ class NumpyBatch(_GenericBatch[np.ndarray, np.ndarray, np.ndarray, list[str]]):
Where ``B`` represents the batch dimension that is prepended to all
tensor-like fields.
"""

def keys(self, include_none: bool = True) -> list[str]:
"""Return a list of field names in the NumpyBatch.

Args:
include_none: If True, returns all possible field names including those
that are None. If False, returns only field names that have non-None values.
Defaults to True for backward compatibility.

Returns:
List of field names that can be accessed on this NumpyBatch instance.
When include_none=True, includes all fields from the input, output, and any
additional field classes that the specific batch type inherits from.
When include_none=False, includes only fields with actual data.

Example:
>>> batch = NumpyBatch(image=np.random.rand(2, 224, 224, 3))
>>> all_keys = batch.keys() # Default: include_none=True
>>> 'pred_score' in all_keys # True (even though it's None)
True
>>> set_keys = batch.keys(include_none=False)
>>> 'pred_score' in set_keys # False (because it's None)
False
"""
from dataclasses import fields

if include_none:
return [field.name for field in fields(self)]

return [field.name for field in fields(self) if getattr(self, field.name) is not None]

def __getitem__(self, key: str) -> Any: # noqa: ANN401
"""Get a field value using dictionary-like syntax.

Args:
key: Field name to access.

Returns:
The value of the specified field.

Raises:
KeyError: If the field name is not found.

Example:
>>> batch = NumpyBatch(image=np.random.rand(2, 224, 224, 3))
>>> batch["image"].shape
(2, 224, 224, 3)
>>> batch["gt_label"]
None
"""
if not hasattr(self, key):
msg = f"Field '{key}' not found in {self.__class__.__name__}"
raise KeyError(msg)
return getattr(self, key)
60 changes: 58 additions & 2 deletions src/anomalib/data/dataclasses/torch/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Torch-based dataclasses for Anomalib.
Expand All @@ -13,7 +13,7 @@

from collections.abc import Callable
from dataclasses import dataclass, fields
from typing import ClassVar, Generic, NamedTuple, TypeVar
from typing import Any, ClassVar, Generic, NamedTuple, TypeVar

import torch
from torchvision.tv_tensors import Mask
Expand Down Expand Up @@ -138,3 +138,59 @@ class Batch(Generic[ImageT], _GenericBatch[torch.Tensor, ImageT, Mask, list[str]
This class is typically subclassed to create more specific batch types
(e.g., ``ImageBatch``, ``VideoBatch``) with additional fields and methods.
"""

def keys(self, include_none: bool = True) -> list[str]:
"""Return a list of field names in the Batch.

Args:
include_none: If True, returns all possible field names including those
that are None. If False, returns only field names that have non-None values.
Defaults to True for backward compatibility.

Returns:
List of field names that can be accessed on this Batch instance.
When include_none=True, includes all fields from the input, output, and any
additional field classes that the specific batch type inherits from.
When include_none=False, includes only fields with actual data.

Example:
>>> # Using any batch subclass (e.g., ImageBatch)
>>> batch = Batch(image=torch.rand(2, 3, 224, 224))
>>> all_keys = batch.keys() # Default: include_none=True
>>> 'pred_score' in all_keys # True (even though it's None)
True
>>> set_keys = batch.keys(include_none=False)
>>> 'pred_score' in set_keys # False (because it's None)
False
"""
from dataclasses import fields

if include_none:
return [field.name for field in fields(self)]

return [field.name for field in fields(self) if getattr(self, field.name) is not None]

def __getitem__(self, key: str) -> Any: # noqa: ANN401
"""Get a field value using dictionary-like syntax.

Args:
key: Field name to access.

Returns:
The value of the specified field.

Raises:
KeyError: If the field name is not found.

Example:
>>> # Using any batch subclass (e.g., ImageBatch)
>>> batch = Batch(image=torch.rand(2, 3, 224, 224))
>>> batch["image"].shape
torch.Size([2, 3, 224, 224])
>>> batch["gt_label"]
None
"""
if not hasattr(self, key):
msg = f"Field '{key}' not found in {self.__class__.__name__}"
raise KeyError(msg)
return getattr(self, key)
4 changes: 1 addition & 3 deletions src/anomalib/visualization/image/item_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2024-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""ImageItem visualization module.
Expand Down Expand Up @@ -347,8 +347,6 @@ def visualize_image_item(
field_value = getattr(item, field, None)
if field_value is not None:
image = get_visualize_function(field)(field_value, **fields_config.get(field, {}))
else:
logger.warning(f"Field '{field}' is None in item. Skipping visualization.")
if image:
field_images[field] = image.resize(field_size)

Expand Down
162 changes: 159 additions & 3 deletions src/anomalib/visualization/image/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any

from PIL import Image

# Only import types during type checking to avoid circular imports
if TYPE_CHECKING:
from lightning.pytorch import Trainer

from anomalib.data import ImageBatch
from anomalib.data import ImageBatch, ImageItem, NumpyImageBatch, NumpyImageItem
from anomalib.models import AnomalibModule

from anomalib.utils.path import generate_output_filename
Expand Down Expand Up @@ -184,6 +186,157 @@ def __init__(
self.text_config = {**DEFAULT_TEXT_CONFIG, **(text_config or {})}
self.output_dir = output_dir

def visualize_image(
self,
predictions: "ImageItem | NumpyImageItem | ImageBatch | NumpyImageBatch",
) -> Image.Image | list[Image.Image | None] | None:
"""Visualize image predictions.

This method visualizes anomaly detection predictions intelligently:
- For single items or single-item batches: returns a single image
- For multi-item batches: returns a list of images

Args:
predictions: The image prediction(s) to visualize. Can be:
- ``ImageItem``: Single torch-based image item
- ``NumpyImageItem``: Single numpy-based image item
- ``ImageBatch``: Batch of torch-based image items
- ``NumpyImageBatch``: Batch of numpy-based image items

Returns:
- For single items or single-item batches: ``Image.Image`` or ``None``
- For multi-item batches: ``list[Image.Image | None]``

Examples:
Visualize a torch-based image item:

>>> from anomalib.data import ImageItem
>>> import torch
>>> item = ImageItem(
... image=torch.rand(3, 224, 224),
... anomaly_map=torch.rand(224, 224),
... pred_mask=torch.rand(224, 224) > 0.5
... )
>>> visualizer = ImageVisualizer()
>>> result = visualizer.visualize_image(item)
>>> isinstance(result, Image.Image) or result is None
True

Visualize a numpy-based image item:

>>> from anomalib.data import NumpyImageItem
>>> import numpy as np
>>> item = NumpyImageItem(
... image=np.random.rand(224, 224, 3),
... anomaly_map=np.random.rand(224, 224),
... pred_mask=np.random.rand(224, 224) > 0.5
... )
>>> result = visualizer.visualize_image(item)
>>> isinstance(result, Image.Image) or result is None
True

Visualize a batch with one image (returns single image, not list):

>>> from anomalib.data import ImageBatch
>>> single_batch = ImageBatch(
... image=torch.rand(1, 3, 224, 224),
... anomaly_map=torch.rand(1, 224, 224)
... )
>>> result = visualizer.visualize_image(single_batch)
>>> isinstance(result, Image.Image) or result is None
True

Visualize a batch with multiple images (returns list):

>>> multi_batch = ImageBatch(
... image=torch.rand(3, 3, 224, 224),
... anomaly_map=torch.rand(3, 224, 224)
... )
>>> results = visualizer.visualize_image(multi_batch)
>>> isinstance(results, list) and len(results) == 3
True

Note:
- The method uses the same configuration (fields, overlays, etc.) as specified
during initialization of the ``ImageVisualizer``.
- If an item cannot be visualized (e.g., missing required fields), the
corresponding result will be ``None``.
- This method now behaves identically to the ``__call__`` method.
"""
# Import here to avoid circular imports
from anomalib.data import ImageBatch, ImageItem, NumpyImageBatch, NumpyImageItem

# Handle single items
if isinstance(predictions, (ImageItem, NumpyImageItem)):
return visualize_image_item(
predictions,
fields=self.fields,
overlay_fields=self.overlay_fields,
field_size=self.field_size,
fields_config=self.fields_config,
overlay_fields_config=self.overlay_fields_config,
text_config=self.text_config,
)

# Handle batches
if isinstance(predictions, (ImageBatch, NumpyImageBatch)):
batch_size = len(predictions)

# Single-item batch - return single image
if batch_size == 1:
image_item = next(iter(predictions))
return visualize_image_item(
image_item, # type: ignore[arg-type]
fields=self.fields,
overlay_fields=self.overlay_fields,
field_size=self.field_size,
fields_config=self.fields_config,
overlay_fields_config=self.overlay_fields_config,
text_config=self.text_config,
)

# Multi-item batch - return list of images
results = []
for image_item in predictions:
visualization = visualize_image_item(
image_item, # type: ignore[arg-type]
fields=self.fields,
overlay_fields=self.overlay_fields,
field_size=self.field_size,
fields_config=self.fields_config,
overlay_fields_config=self.overlay_fields_config,
text_config=self.text_config,
)
results.append(visualization)
return results

msg = (
f"Unsupported input type: {type(predictions)}. "
"Expected ImageItem, NumpyImageItem, ImageBatch, or NumpyImageBatch."
)
raise TypeError(msg)

def __call__(
self,
predictions: "ImageItem | NumpyImageItem | ImageBatch | NumpyImageBatch",
) -> Image.Image | list[Image.Image | None] | None:
"""Make the visualizer callable.

This method allows the visualizer to be used as a callable object.
It behaves identically to the ``visualize_image`` method.

Args:
predictions: The predictions to visualize. Same as ``visualize_image``.

Returns:
Same as ``visualize_image`` method.

Examples:
>>> visualizer = ImageVisualizer()
>>> result = visualizer(predictions) # Equivalent to visualizer.visualize_image(predictions)
"""
return self.visualize_image(predictions)

def on_test_batch_end(
self,
trainer: "Trainer",
Expand Down Expand Up @@ -212,11 +365,14 @@ def on_test_batch_end(

if image is not None:
# Get the dataset name and category to save the image
datamodule = getattr(trainer, "datamodule", None)
dataset_name = getattr(datamodule, "name", "") if datamodule else ""
category = getattr(datamodule, "category", "") if datamodule else ""
filename = generate_output_filename(
input_path=item.image_path or "",
output_path=self.output_dir,
dataset_name=getattr(trainer.datamodule, "name", "") or "",
category=getattr(trainer.datamodule, "category", "") or "",
dataset_name=dataset_name or "",
category=category or "",
)

# Save the image to the specified filename
Expand Down
Loading