diff --git a/src/anomalib/data/dataclasses/numpy/base.py b/src/anomalib/data/dataclasses/numpy/base.py index 0c03b4f8ae..248230284f 100644 --- a/src/anomalib/data/dataclasses/numpy/base.py +++ b/src/anomalib/data/dataclasses/numpy/base.py @@ -1,4 +1,4 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Numpy-based dataclasses for Anomalib. @@ -13,6 +13,7 @@ """ from dataclasses import dataclass +from typing import Any import numpy as np @@ -55,3 +56,57 @@ class NumpyBatch(_GenericBatch[np.ndarray, np.ndarray, np.ndarray, list[str]]): Where ``B`` represents the batch dimension that is prepended to all tensor-like fields. """ + + def keys(self, include_none: bool = True) -> list[str]: + """Return a list of field names in the NumpyBatch. + + Args: + include_none: If True, returns all possible field names including those + that are None. If False, returns only field names that have non-None values. + Defaults to True for backward compatibility. + + Returns: + List of field names that can be accessed on this NumpyBatch instance. + When include_none=True, includes all fields from the input, output, and any + additional field classes that the specific batch type inherits from. + When include_none=False, includes only fields with actual data. + + Example: + >>> batch = NumpyBatch(image=np.random.rand(2, 224, 224, 3)) + >>> all_keys = batch.keys() # Default: include_none=True + >>> 'pred_score' in all_keys # True (even though it's None) + True + >>> set_keys = batch.keys(include_none=False) + >>> 'pred_score' in set_keys # False (because it's None) + False + """ + from dataclasses import fields + + if include_none: + return [field.name for field in fields(self)] + + return [field.name for field in fields(self) if getattr(self, field.name) is not None] + + def __getitem__(self, key: str) -> Any: # noqa: ANN401 + """Get a field value using dictionary-like syntax. + + Args: + key: Field name to access. + + Returns: + The value of the specified field. + + Raises: + KeyError: If the field name is not found. + + Example: + >>> batch = NumpyBatch(image=np.random.rand(2, 224, 224, 3)) + >>> batch["image"].shape + (2, 224, 224, 3) + >>> batch["gt_label"] + None + """ + if not hasattr(self, key): + msg = f"Field '{key}' not found in {self.__class__.__name__}" + raise KeyError(msg) + return getattr(self, key) diff --git a/src/anomalib/data/dataclasses/torch/base.py b/src/anomalib/data/dataclasses/torch/base.py index b0260f6ef8..51bf0b0433 100644 --- a/src/anomalib/data/dataclasses/torch/base.py +++ b/src/anomalib/data/dataclasses/torch/base.py @@ -1,4 +1,4 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Torch-based dataclasses for Anomalib. @@ -13,7 +13,7 @@ from collections.abc import Callable from dataclasses import dataclass, fields -from typing import ClassVar, Generic, NamedTuple, TypeVar +from typing import Any, ClassVar, Generic, NamedTuple, TypeVar import torch from torchvision.tv_tensors import Mask @@ -138,3 +138,59 @@ class Batch(Generic[ImageT], _GenericBatch[torch.Tensor, ImageT, Mask, list[str] This class is typically subclassed to create more specific batch types (e.g., ``ImageBatch``, ``VideoBatch``) with additional fields and methods. """ + + def keys(self, include_none: bool = True) -> list[str]: + """Return a list of field names in the Batch. + + Args: + include_none: If True, returns all possible field names including those + that are None. If False, returns only field names that have non-None values. + Defaults to True for backward compatibility. + + Returns: + List of field names that can be accessed on this Batch instance. + When include_none=True, includes all fields from the input, output, and any + additional field classes that the specific batch type inherits from. + When include_none=False, includes only fields with actual data. + + Example: + >>> # Using any batch subclass (e.g., ImageBatch) + >>> batch = Batch(image=torch.rand(2, 3, 224, 224)) + >>> all_keys = batch.keys() # Default: include_none=True + >>> 'pred_score' in all_keys # True (even though it's None) + True + >>> set_keys = batch.keys(include_none=False) + >>> 'pred_score' in set_keys # False (because it's None) + False + """ + from dataclasses import fields + + if include_none: + return [field.name for field in fields(self)] + + return [field.name for field in fields(self) if getattr(self, field.name) is not None] + + def __getitem__(self, key: str) -> Any: # noqa: ANN401 + """Get a field value using dictionary-like syntax. + + Args: + key: Field name to access. + + Returns: + The value of the specified field. + + Raises: + KeyError: If the field name is not found. + + Example: + >>> # Using any batch subclass (e.g., ImageBatch) + >>> batch = Batch(image=torch.rand(2, 3, 224, 224)) + >>> batch["image"].shape + torch.Size([2, 3, 224, 224]) + >>> batch["gt_label"] + None + """ + if not hasattr(self, key): + msg = f"Field '{key}' not found in {self.__class__.__name__}" + raise KeyError(msg) + return getattr(self, key) diff --git a/src/anomalib/visualization/image/item_visualizer.py b/src/anomalib/visualization/image/item_visualizer.py index fe1666379d..2be8de4cc1 100644 --- a/src/anomalib/visualization/image/item_visualizer.py +++ b/src/anomalib/visualization/image/item_visualizer.py @@ -1,4 +1,4 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """ImageItem visualization module. @@ -347,8 +347,6 @@ def visualize_image_item( field_value = getattr(item, field, None) if field_value is not None: image = get_visualize_function(field)(field_value, **fields_config.get(field, {})) - else: - logger.warning(f"Field '{field}' is None in item. Skipping visualization.") if image: field_images[field] = image.resize(field_size) diff --git a/src/anomalib/visualization/image/visualizer.py b/src/anomalib/visualization/image/visualizer.py index 64057b2df4..c4fa568ddd 100644 --- a/src/anomalib/visualization/image/visualizer.py +++ b/src/anomalib/visualization/image/visualizer.py @@ -1,4 +1,4 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Image visualization module for anomaly detection. @@ -16,7 +16,7 @@ >>> # Create visualizer with default settings >>> visualizer = ImageVisualizer() >>> # Generate visualization - >>> vis_result = visualizer.visualize(image=img, pred_mask=mask) + >>> vis_result = visualizer.visualize(predictions) The module ensures consistent visualization by: - Providing standardized field configurations @@ -32,11 +32,13 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from PIL import Image + # Only import types during type checking to avoid circular imports if TYPE_CHECKING: from lightning.pytorch import Trainer - from anomalib.data import ImageBatch + from anomalib.data import ImageBatch, ImageItem, NumpyImageBatch, NumpyImageItem from anomalib.models import AnomalibModule from anomalib.utils.path import generate_output_filename @@ -184,6 +186,157 @@ def __init__( self.text_config = {**DEFAULT_TEXT_CONFIG, **(text_config or {})} self.output_dir = output_dir + def visualize( + self, + predictions: "ImageItem | NumpyImageItem | ImageBatch | NumpyImageBatch", + ) -> Image.Image | list[Image.Image | None] | None: + """Visualize image predictions. + + This method visualizes anomaly detection predictions intelligently: + - For single items or single-item batches: returns a single image + - For multi-item batches: returns a list of images + + Args: + predictions: The image prediction(s) to visualize. Can be: + - ``ImageItem``: Single torch-based image item + - ``NumpyImageItem``: Single numpy-based image item + - ``ImageBatch``: Batch of torch-based image items + - ``NumpyImageBatch``: Batch of numpy-based image items + + Returns: + - For single items or single-item batches: ``Image.Image`` or ``None`` + - For multi-item batches: ``list[Image.Image | None]`` + + Examples: + Visualize a torch-based image item: + + >>> from anomalib.data import ImageItem + >>> import torch + >>> item = ImageItem( + ... image=torch.rand(3, 224, 224), + ... anomaly_map=torch.rand(224, 224), + ... pred_mask=torch.rand(224, 224) > 0.5 + ... ) + >>> visualizer = ImageVisualizer() + >>> result = visualizer.visualize(item) + >>> isinstance(result, Image.Image) or result is None + True + + Visualize a numpy-based image item: + + >>> from anomalib.data import NumpyImageItem + >>> import numpy as np + >>> item = NumpyImageItem( + ... image=np.random.rand(224, 224, 3), + ... anomaly_map=np.random.rand(224, 224), + ... pred_mask=np.random.rand(224, 224) > 0.5 + ... ) + >>> result = visualizer.visualize(item) + >>> isinstance(result, Image.Image) or result is None + True + + Visualize a batch with one image (returns single image, not list): + + >>> from anomalib.data import ImageBatch + >>> single_batch = ImageBatch( + ... image=torch.rand(1, 3, 224, 224), + ... anomaly_map=torch.rand(1, 224, 224) + ... ) + >>> result = visualizer.visualize(single_batch) + >>> isinstance(result, Image.Image) or result is None + True + + Visualize a batch with multiple images (returns list): + + >>> multi_batch = ImageBatch( + ... image=torch.rand(3, 3, 224, 224), + ... anomaly_map=torch.rand(3, 224, 224) + ... ) + >>> results = visualizer.visualize(multi_batch) + >>> isinstance(results, list) and len(results) == 3 + True + + Note: + - The method uses the same configuration (fields, overlays, etc.) as specified + during initialization of the ``ImageVisualizer``. + - If an item cannot be visualized (e.g., missing required fields), the + corresponding result will be ``None``. + - This method now behaves identically to the ``__call__`` method. + """ + # Import here to avoid circular imports + from anomalib.data import ImageBatch, ImageItem, NumpyImageBatch, NumpyImageItem + + # Handle single items + if isinstance(predictions, (ImageItem, NumpyImageItem)): + return visualize_image_item( + predictions, + fields=self.fields, + overlay_fields=self.overlay_fields, + field_size=self.field_size, + fields_config=self.fields_config, + overlay_fields_config=self.overlay_fields_config, + text_config=self.text_config, + ) + + # Handle batches + if isinstance(predictions, (ImageBatch, NumpyImageBatch)): + batch_size = len(predictions) + + # Single-item batch - return single image + if batch_size == 1: + image_item = next(iter(predictions)) + return visualize_image_item( + image_item, # type: ignore[arg-type] + fields=self.fields, + overlay_fields=self.overlay_fields, + field_size=self.field_size, + fields_config=self.fields_config, + overlay_fields_config=self.overlay_fields_config, + text_config=self.text_config, + ) + + # Multi-item batch - return list of images + results = [] + for image_item in predictions: + visualization = visualize_image_item( + image_item, # type: ignore[arg-type] + fields=self.fields, + overlay_fields=self.overlay_fields, + field_size=self.field_size, + fields_config=self.fields_config, + overlay_fields_config=self.overlay_fields_config, + text_config=self.text_config, + ) + results.append(visualization) + return results + + msg = ( + f"Unsupported input type: {type(predictions)}. " + "Expected ImageItem, NumpyImageItem, ImageBatch, or NumpyImageBatch." + ) + raise TypeError(msg) + + def __call__( + self, + predictions: "ImageItem | NumpyImageItem | ImageBatch | NumpyImageBatch", + ) -> Image.Image | list[Image.Image | None] | None: + """Make the visualizer callable. + + This method allows the visualizer to be used as a callable object. + It behaves identically to the ``visualize`` method. + + Args: + predictions: The predictions to visualize. Same as ``visualize``. + + Returns: + Same as ``visualize`` method. + + Examples: + >>> visualizer = ImageVisualizer() + >>> result = visualizer(predictions) # Equivalent to visualizer.visualize(predictions) + """ + return self.visualize(predictions) + def on_test_batch_end( self, trainer: "Trainer", @@ -212,11 +365,14 @@ def on_test_batch_end( if image is not None: # Get the dataset name and category to save the image + datamodule = getattr(trainer, "datamodule", None) + dataset_name = getattr(datamodule, "name", None) if datamodule else None + category = getattr(datamodule, "category", None) if datamodule else None filename = generate_output_filename( input_path=item.image_path or "", output_path=self.output_dir, - dataset_name=getattr(trainer.datamodule, "name", "") or "", - category=getattr(trainer.datamodule, "category", "") or "", + dataset_name=dataset_name, + category=category, ) # Save the image to the specified filename