From 9f02114e64cc505c605674a7b5d5789ebe3efc6c Mon Sep 17 00:00:00 2001 From: samet-akcay Date: Wed, 25 Jun 2025 13:38:08 +0100 Subject: [PATCH 01/11] fix(data): enable pin_memory for DataLoader instances across the codebase This commit updates various DataLoader instances in the project to enable the `pin_memory` option, enhancing performance for data loading on GPU. Changes were made in the following files: - `examples/api/02_data/mvtecad2.py`: Updated train and test DataLoader configurations. - `src/anomalib/cli/cli.py`: Modified datamodule DataLoader to include `pin_memory=True`. - `src/anomalib/data/datamodules/image/mvtecad2.py`: Added `pin_memory=True` to evaluation DataLoader. - `src/anomalib/engine/engine.py`: Updated DataLoader for datasets to utilize `pin_memory=True`. - `src/anomalib/models/image/winclip/lightning_model.py`: Enabled `pin_memory` for reference dataset DataLoader. - `tools/inference/lightning_inference.py`: Adjusted inference DataLoader to include `pin_memory=True`. These changes aim to optimize memory usage and improve data transfer speeds during model training and inference. Signed-off-by: samet-akcay --- examples/api/02_data/mvtecad2.py | 16 ++++++++++++++-- src/anomalib/cli/cli.py | 5 +---- src/anomalib/data/datamodules/image/mvtecad2.py | 1 + src/anomalib/engine/engine.py | 8 ++------ .../models/image/winclip/lightning_model.py | 2 +- tools/inference/lightning_inference.py | 2 +- 6 files changed, 20 insertions(+), 14 deletions(-) diff --git a/examples/api/02_data/mvtecad2.py b/examples/api/02_data/mvtecad2.py index f241983696..d8f2ed6080 100644 --- a/examples/api/02_data/mvtecad2.py +++ b/examples/api/02_data/mvtecad2.py @@ -93,8 +93,20 @@ ) # Create dataloaders -train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=train_dataset.collate_fn) -test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=test_dataset.collate_fn) +train_loader = DataLoader( + train_dataset, + batch_size=4, + shuffle=True, + collate_fn=train_dataset.collate_fn, + pin_memory=True, +) +test_loader = DataLoader( + test_dataset, + batch_size=4, + shuffle=False, + collate_fn=test_dataset.collate_fn, + pin_memory=True, +) # Get some sample images train_samples = next(iter(train_loader)) diff --git a/src/anomalib/cli/cli.py 
b/src/anomalib/cli/cli.py index 4601f9a85d..a06f797bfb 100644 --- a/src/anomalib/cli/cli.py +++ b/src/anomalib/cli/cli.py @@ -304,10 +304,7 @@ def instantiate_classes(self) -> None: self.config_init = self.parser.instantiate_classes(self.config) self.datamodule = self._get(self.config_init, "data") if isinstance(self.datamodule, Dataset): - # Let PyTorch handle pin_memory automatically - # This ensures optimal behavior for both CPU and GPU users - # nosemgrep: trailofbits.python.automatic-memory-pinning.automatic-memory-pinning # noqa: ERA001 - self.datamodule = DataLoader(self.datamodule, collate_fn=self.datamodule.collate_fn) + self.datamodule = DataLoader(self.datamodule, collate_fn=self.datamodule.collate_fn, pin_memory=True) self.model = self._get(self.config_init, "model") self._configure_optimizers_method_to_model() self.instantiate_engine() diff --git a/src/anomalib/data/datamodules/image/mvtecad2.py b/src/anomalib/data/datamodules/image/mvtecad2.py index 9f8ba5ac1f..9daff86f0b 100644 --- a/src/anomalib/data/datamodules/image/mvtecad2.py +++ b/src/anomalib/data/datamodules/image/mvtecad2.py @@ -257,4 +257,5 @@ def test_dataloader(self, test_type: str | TestType | None = None) -> EVAL_DATAL batch_size=self.eval_batch_size, num_workers=self.num_workers, collate_fn=dataset.collate_fn, + pin_memory=True, ) diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index 0c31023467..cd71e6a241 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -648,14 +648,10 @@ def predict( msg = f"Unknown type for dataloaders {type(dataloaders)}" raise TypeError(msg) if dataset is not None: - # Let PyTorch handle pin_memory automatically - # This ensures optimal behavior for both CPU and GPU users - # nosemgrep: trailofbits.python.automatic-memory-pinning.automatic-memory-pinning # noqa: ERA001 - dataloaders.append(DataLoader(dataset, collate_fn=dataset.collate_fn)) + dataloaders.append(DataLoader(dataset, 
collate_fn=dataset.collate_fn, pin_memory=True)) if data_path is not None: dataset = PredictDataset(data_path) - # nosemgrep: trailofbits.python.automatic-memory-pinning.automatic-memory-pinning # noqa: ERA001 - dataloaders.append(DataLoader(dataset, collate_fn=dataset.collate_fn)) + dataloaders.append(DataLoader(dataset, collate_fn=dataset.collate_fn, pin_memory=True)) dataloaders = dataloaders or None if self._should_run_validation(model or self.model, ckpt_path): diff --git a/src/anomalib/models/image/winclip/lightning_model.py b/src/anomalib/models/image/winclip/lightning_model.py index fb6702e66c..ee38611ee3 100644 --- a/src/anomalib/models/image/winclip/lightning_model.py +++ b/src/anomalib/models/image/winclip/lightning_model.py @@ -149,7 +149,7 @@ def setup(self, stage: str) -> None: self.few_shot_source, transform=self.pre_processor.test_transform if self.pre_processor else None, ) - dataloader = DataLoader(reference_dataset, batch_size=1, shuffle=False) + dataloader = DataLoader(reference_dataset, batch_size=1, shuffle=False, pin_memory=True) else: logger.info("Collecting reference images from training dataset") dataloader = self.trainer.datamodule.train_dataloader() diff --git a/tools/inference/lightning_inference.py b/tools/inference/lightning_inference.py index f92c0d2ab6..3aee61f435 100644 --- a/tools/inference/lightning_inference.py +++ b/tools/inference/lightning_inference.py @@ -53,7 +53,7 @@ def infer(args: Namespace) -> None: # create the dataset dataset = PredictDataset(**args.data) - dataloader = DataLoader(dataset, collate_fn=dataset.collate_fn) + dataloader = DataLoader(dataset, collate_fn=dataset.collate_fn, pin_memory=True) engine.predict(model=model, dataloaders=[dataloader], ckpt_path=args.ckpt_path) From 58da4c2eedebc8a6d22cc72bedb31faab988d1a0 Mon Sep 17 00:00:00 2001 From: samet-akcay Date: Wed, 25 Jun 2025 13:43:28 +0100 Subject: [PATCH 02/11] refactor(models): streamline decoder retrieval in function This commit refactors the 
`get_decoder` function in `de_resnet.py` to utilize a dictionary mapping for decoder architectures, improving readability and maintainability. The previous conditional checks have been replaced with a more efficient approach, enhancing the overall structure of the code. Signed-off-by: samet-akcay --- .../components/de_resnet.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/anomalib/models/image/reverse_distillation/components/de_resnet.py b/src/anomalib/models/image/reverse_distillation/components/de_resnet.py index 44dcf4c0d6..a01c056dbc 100644 --- a/src/anomalib/models/image/reverse_distillation/components/de_resnet.py +++ b/src/anomalib/models/image/reverse_distillation/components/de_resnet.py @@ -495,18 +495,20 @@ def get_decoder(name: str) -> ResNet: Returns: ResNet: Decoder ResNet architecture. """ - if name in { - "resnet18", - "resnet34", - "resnet50", - "resnet101", - "resnet152", - "resnext50_32x4d", - "resnext101_32x8d", - "wide_resnet50_2", - "wide_resnet101_2", - }: - decoder = globals()[f"de_{name}"] + decoder_map = { + "resnet18": de_resnet18, + "resnet34": de_resnet34, + "resnet50": de_resnet50, + "resnet101": de_resnet101, + "resnet152": de_resnet152, + "resnext50_32x4d": de_resnext50_32x4d, + "resnext101_32x8d": de_resnext101_32x8d, + "wide_resnet50_2": de_wide_resnet50_2, + "wide_resnet101_2": de_wide_resnet101_2, + } + + if name in decoder_map: + decoder = decoder_map[name] else: msg = f"Decoder with architecture {name} not supported" raise ValueError(msg) From 6b07c77e0af8a55befbb91cdea05870e008c5164 Mon Sep 17 00:00:00 2001 From: samet-akcay Date: Wed, 25 Jun 2025 13:52:50 +0100 Subject: [PATCH 03/11] fix(model): update logit_scale initialization to use torch.log for consistency This commit modifies the initialization of the logit_scale parameter in the CLIP model to utilize torch.log instead of np.log. This change ensures consistency in tensor operations and improves compatibility with PyTorch's computation graph. 
Signed-off-by: samet-akcay --- src/anomalib/models/video/ai_vad/clip/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/models/video/ai_vad/clip/model.py b/src/anomalib/models/video/ai_vad/clip/model.py index 06ee36922e..664556533c 100644 --- a/src/anomalib/models/video/ai_vad/clip/model.py +++ b/src/anomalib/models/video/ai_vad/clip/model.py @@ -317,7 +317,7 @@ def __init__( self.ln_final = LayerNorm(transformer_width) self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) - self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07))) self.initialize_parameters() From c0525deafe46502049d6b2aeac00142681bd72ba Mon Sep 17 00:00:00 2001 From: samet-akcay Date: Wed, 25 Jun 2025 13:56:48 +0100 Subject: [PATCH 04/11] fix(model): update anomaly map generation to use torch tensors for calculations This commit modifies the anomaly map generation logic to utilize PyTorch tensors instead of NumPy arrays for various calculations. This change enhances compatibility with the PyTorch computation graph and improves performance by leveraging GPU acceleration. Key updates include the conversion of statistical calculations and tensor operations to use PyTorch functions, ensuring consistency in tensor handling throughout the code. 
Signed-off-by: samet-akcay --- src/anomalib/models/image/uflow/anomaly_map.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/anomalib/models/image/uflow/anomaly_map.py b/src/anomalib/models/image/uflow/anomaly_map.py index 4457bd17e5..03606f0d35 100644 --- a/src/anomalib/models/image/uflow/anomaly_map.py +++ b/src/anomalib/models/image/uflow/anomaly_map.py @@ -190,27 +190,27 @@ def binomial_test( torch.Tensor: Log probability tensor of shape ``(batch_size, 1, height, width)`` """ - tau = st.chi2.ppf(probability_thr, 1) - half_win = np.max([int(window_size // 2), 1]) + tau = torch.tensor(st.chi2.ppf(probability_thr, 1)) + half_win = max(int(window_size // 2), 1) n_chann = z.shape[1] # Candidates z2 = F.pad(z**2, tuple(4 * [half_win]), "reflect").detach().cpu() z2_unfold_h = z2.unfold(-2, 2 * half_win + 1, 1) - z2_unfold_hw = z2_unfold_h.unfold(-2, 2 * half_win + 1, 1).numpy() - observed_candidates_k = np.sum(z2_unfold_hw >= tau, axis=(-2, -1)) + z2_unfold_hw = z2_unfold_h.unfold(-2, 2 * half_win + 1, 1) + observed_candidates_k = torch.sum(z2_unfold_hw >= tau, dim=(-2, -1)) # All volume together - observed_candidates = np.sum(observed_candidates_k, axis=1, keepdims=True) + observed_candidates = torch.sum(observed_candidates_k, dim=1, keepdim=True) x = observed_candidates / n_chann n = int((2 * half_win + 1) ** 2) # Low precision if not high_precision: - log_prob = torch.tensor(st.binom.logsf(x, n, 1 - probability_thr) / np.log(10)) - # High precision - good and slow + log_prob = torch.tensor(st.binom.logsf(x.numpy(), n, 1 - probability_thr) / torch.log(torch.tensor(10.0))) else: + # High precision - good and slow to_mp = np.frompyfunc(mp.mpf, 1, 1) mpn = mp.mpf(n) mpp = probability_thr @@ -222,7 +222,7 @@ def integral(tensor: torch.Tensor) -> torch.Tensor: return integrate.quad(binomial_density, tensor, n)[0] integral_array = np.vectorize(integral) - prob = integral_array(x) + prob = integral_array(x.numpy()) log_prob = 
torch.tensor(np.log10(prob)) return log_prob From cdf8450b089ab31ec729878ddd7ef3b9e94e1a49 Mon Sep 17 00:00:00 2001 From: samet-akcay Date: Wed, 25 Jun 2025 15:24:24 +0100 Subject: [PATCH 05/11] refactor(model): enhance anomaly map generation with PyTorch for statistical calculations This commit refactors the anomaly map generation logic to replace NumPy-based statistical calculations with PyTorch equivalents, specifically using the `torch.distributions.Normal` distribution for computing tau. Additionally, it improves precision handling by allowing the use of float64 in high precision mode. The changes streamline the computation process and maintain compatibility with the PyTorch computation graph. Signed-off-by: samet-akcay --- .../models/image/uflow/anomaly_map.py | 46 ++++++++----------- 1 file changed, 18 insertions(+), 28 deletions(-) diff --git a/src/anomalib/models/image/uflow/anomaly_map.py b/src/anomalib/models/image/uflow/anomaly_map.py index 03606f0d35..bb4943107f 100644 --- a/src/anomalib/models/image/uflow/anomaly_map.py +++ b/src/anomalib/models/image/uflow/anomaly_map.py @@ -15,22 +15,19 @@ See Also: - :class:`AnomalyMapGenerator`: Main class for generating anomaly maps - - :func:`compute_anomaly_map`: Function to generate anomaly maps from latents """ # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import numpy as np +import math + import scipy.stats as st import torch import torch.nn.functional as F # noqa: N812 -from mpmath import binomial, mp from omegaconf import ListConfig -from scipy import integrate from torch import Tensor, nn - -mp.dps = 15 # Set precision for NFA computation (in case of high_precision=True) +from torch.distributions import Normal class AnomalyMapGenerator(nn.Module): @@ -61,7 +58,6 @@ class AnomalyMapGenerator(nn.Module): torch.Size([1, 1, 256, 256]) See Also: - - :func:`compute_anomaly_map`: Main method for likelihood-based maps - :func:`compute_anomaly_mask`: Optional method for NFA-based segmentation """ @@ 
-190,13 +186,21 @@ def binomial_test( torch.Tensor: Log probability tensor of shape ``(batch_size, 1, height, width)`` """ - tau = torch.tensor(st.chi2.ppf(probability_thr, 1)) + # Calculate tau using pure PyTorch + normal_dist = Normal(0, 1) + p_adjusted = (probability_thr + 1) / 2 + tau = normal_dist.icdf(torch.tensor(p_adjusted)) ** 2 half_win = max(int(window_size // 2), 1) n_chann = z.shape[1] + # Use float64 for high precision mode + dtype = torch.float64 if high_precision else torch.float32 + z = z.to(dtype) + tau = tau.to(dtype) + # Candidates - z2 = F.pad(z**2, tuple(4 * [half_win]), "reflect").detach().cpu() + z2 = F.pad(z**2, tuple(4 * [half_win]), "reflect") z2_unfold_h = z2.unfold(-2, 2 * half_win + 1, 1) z2_unfold_hw = z2_unfold_h.unfold(-2, 2 * half_win + 1, 1) observed_candidates_k = torch.sum(z2_unfold_hw >= tau, dim=(-2, -1)) @@ -206,23 +210,9 @@ def binomial_test( x = observed_candidates / n_chann n = int((2 * half_win + 1) ** 2) - # Low precision - if not high_precision: - log_prob = torch.tensor(st.binom.logsf(x.numpy(), n, 1 - probability_thr) / torch.log(torch.tensor(10.0))) - else: - # High precision - good and slow - to_mp = np.frompyfunc(mp.mpf, 1, 1) - mpn = mp.mpf(n) - mpp = probability_thr - - def binomial_density(tensor: torch.tensor) -> torch.Tensor: - return binomial(mpn, to_mp(tensor)) * (1 - mpp) ** tensor * mpp ** (mpn - tensor) - - def integral(tensor: torch.Tensor) -> torch.Tensor: - return integrate.quad(binomial_density, tensor, n)[0] - - integral_array = np.vectorize(integral) - prob = integral_array(x.numpy()) - log_prob = torch.tensor(np.log10(prob)) + # Use scipy for the binomial test as PyTorch does not have a stable/direct equivalent. 
+ # nosemgrep: trailofbits.python.numpy-in-pytorch-modules.numpy-in-pytorch-modules + x_np = x.detach().cpu().numpy() + log_prob_np = st.binom.logsf(x_np, n, 1 - probability_thr) / math.log(10) - return log_prob + return torch.from_numpy(log_prob_np).to(z.device) From f43b2ea199af0b0ea8af72ebebf89e66e5dd3cad Mon Sep 17 00:00:00 2001 From: samet-akcay Date: Wed, 25 Jun 2025 15:25:14 +0100 Subject: [PATCH 06/11] chore(license): update license year Signed-off-by: samet-akcay --- src/anomalib/models/image/uflow/anomaly_map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/models/image/uflow/anomaly_map.py b/src/anomalib/models/image/uflow/anomaly_map.py index bb4943107f..f27723d10b 100644 --- a/src/anomalib/models/image/uflow/anomaly_map.py +++ b/src/anomalib/models/image/uflow/anomaly_map.py @@ -17,7 +17,7 @@ - :class:`AnomalyMapGenerator`: Main class for generating anomaly maps """ -# Copyright (C) 2023-2024 Intel Corporation +# Copyright (C) 2023-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import math From f237dcb1c3751869d4687cd635765753f3dde702 Mon Sep 17 00:00:00 2001 From: samet-akcay Date: Wed, 25 Jun 2025 15:35:31 +0100 Subject: [PATCH 07/11] fix(download): enhance URL validation and update download logic This commit improves the URL validation in the download function to ensure only http and https schemes are allowed. Additionally, it adds comments to clarify the safety of using `urlretrieve` under these conditions, enhancing code readability and security awareness. Signed-off-by: samet-akcay --- src/anomalib/data/utils/download.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anomalib/data/utils/download.py b/src/anomalib/data/utils/download.py index 050ce177e2..1d2ac472b0 100644 --- a/src/anomalib/data/utils/download.py +++ b/src/anomalib/data/utils/download.py @@ -309,6 +309,7 @@ def download_and_extract(root: Path, info: DownloadInfo) -> None: # audit url. 
allowing only http:// or https:// if info.url.startswith("http://") or info.url.startswith("https://"): with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=info.name) as progress_bar: + # nosemgrep: python.lang.security.audit.dynamic-urllib-use-detected.dynamic-urllib-use-detected # noqa: ERA001, E501 urlretrieve( # noqa: S310 # nosec B310 url=f"{info.url}", filename=downloaded_file_path, From b69c785d5a74db402cc4cd1b6158180979795718 Mon Sep 17 00:00:00 2001 From: samet-akcay Date: Fri, 12 Sep 2025 12:14:46 +0100 Subject: [PATCH 08/11] =?UTF-8?q?=F0=9F=94=A7=20chore:=20update=20copyrigh?= =?UTF-8?q?t=20year=20and=20add=20keys=20method=20to=20NumpyBatch=20and=20?= =?UTF-8?q?Batch=20classes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit updates the copyright year in the Numpy and Torch base files to reflect 2024-2025. Additionally, it introduces a new `keys` method in both the `NumpyBatch` and `Batch` classes, allowing users to retrieve field names with an option to include or exclude fields with None values. The `__getitem__` method is also added to enable dictionary-like access to field values. Signed-off-by: Samet Akcay <samet.akcay@intel.com> --- src/anomalib/data/dataclasses/numpy/base.py | 57 +++++++++++++++++++- src/anomalib/data/dataclasses/torch/base.py | 60 ++++++++++++++++++++- 2 files changed, 114 insertions(+), 3 deletions(-) diff --git a/src/anomalib/data/dataclasses/numpy/base.py b/src/anomalib/data/dataclasses/numpy/base.py index 0c03b4f8ae..248230284f 100644 --- a/src/anomalib/data/dataclasses/numpy/base.py +++ b/src/anomalib/data/dataclasses/numpy/base.py @@ -1,4 +1,4 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Numpy-based dataclasses for Anomalib. 
@@ -13,6 +13,7 @@ """ from dataclasses import dataclass +from typing import Any import numpy as np @@ -55,3 +56,57 @@ class NumpyBatch(_GenericBatch[np.ndarray, np.ndarray, np.ndarray, list[str]]): Where ``B`` represents the batch dimension that is prepended to all tensor-like fields. """ + + def keys(self, include_none: bool = True) -> list[str]: + """Return a list of field names in the NumpyBatch. + + Args: + include_none: If True, returns all possible field names including those + that are None. If False, returns only field names that have non-None values. + Defaults to True for backward compatibility. + + Returns: + List of field names that can be accessed on this NumpyBatch instance. + When include_none=True, includes all fields from the input, output, and any + additional field classes that the specific batch type inherits from. + When include_none=False, includes only fields with actual data. + + Example: + >>> batch = NumpyBatch(image=np.random.rand(2, 224, 224, 3)) + >>> all_keys = batch.keys() # Default: include_none=True + >>> 'pred_score' in all_keys # True (even though it's None) + True + >>> set_keys = batch.keys(include_none=False) + >>> 'pred_score' in set_keys # False (because it's None) + False + """ + from dataclasses import fields + + if include_none: + return [field.name for field in fields(self)] + + return [field.name for field in fields(self) if getattr(self, field.name) is not None] + + def __getitem__(self, key: str) -> Any: # noqa: ANN401 + """Get a field value using dictionary-like syntax. + + Args: + key: Field name to access. + + Returns: + The value of the specified field. + + Raises: + KeyError: If the field name is not found. 
+ + Example: + >>> batch = NumpyBatch(image=np.random.rand(2, 224, 224, 3)) + >>> batch["image"].shape + (2, 224, 224, 3) + >>> batch["gt_label"] + None + """ + if not hasattr(self, key): + msg = f"Field '{key}' not found in {self.__class__.__name__}" + raise KeyError(msg) + return getattr(self, key) diff --git a/src/anomalib/data/dataclasses/torch/base.py b/src/anomalib/data/dataclasses/torch/base.py index b0260f6ef8..51bf0b0433 100644 --- a/src/anomalib/data/dataclasses/torch/base.py +++ b/src/anomalib/data/dataclasses/torch/base.py @@ -1,4 +1,4 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Torch-based dataclasses for Anomalib. @@ -13,7 +13,7 @@ from collections.abc import Callable from dataclasses import dataclass, fields -from typing import ClassVar, Generic, NamedTuple, TypeVar +from typing import Any, ClassVar, Generic, NamedTuple, TypeVar import torch from torchvision.tv_tensors import Mask @@ -138,3 +138,59 @@ class Batch(Generic[ImageT], _GenericBatch[torch.Tensor, ImageT, Mask, list[str] This class is typically subclassed to create more specific batch types (e.g., ``ImageBatch``, ``VideoBatch``) with additional fields and methods. """ + + def keys(self, include_none: bool = True) -> list[str]: + """Return a list of field names in the Batch. + + Args: + include_none: If True, returns all possible field names including those + that are None. If False, returns only field names that have non-None values. + Defaults to True for backward compatibility. + + Returns: + List of field names that can be accessed on this Batch instance. + When include_none=True, includes all fields from the input, output, and any + additional field classes that the specific batch type inherits from. + When include_none=False, includes only fields with actual data. 
+ + Example: + >>> # Using any batch subclass (e.g., ImageBatch) + >>> batch = Batch(image=torch.rand(2, 3, 224, 224)) + >>> all_keys = batch.keys() # Default: include_none=True + >>> 'pred_score' in all_keys # True (even though it's None) + True + >>> set_keys = batch.keys(include_none=False) + >>> 'pred_score' in set_keys # False (because it's None) + False + """ + from dataclasses import fields + + if include_none: + return [field.name for field in fields(self)] + + return [field.name for field in fields(self) if getattr(self, field.name) is not None] + + def __getitem__(self, key: str) -> Any: # noqa: ANN401 + """Get a field value using dictionary-like syntax. + + Args: + key: Field name to access. + + Returns: + The value of the specified field. + + Raises: + KeyError: If the field name is not found. + + Example: + >>> # Using any batch subclass (e.g., ImageBatch) + >>> batch = Batch(image=torch.rand(2, 3, 224, 224)) + >>> batch["image"].shape + torch.Size([2, 3, 224, 224]) + >>> batch["gt_label"] + None + """ + if not hasattr(self, key): + msg = f"Field '{key}' not found in {self.__class__.__name__}" + raise KeyError(msg) + return getattr(self, key) From 857961eb130c6797cffa9c452dbf648380271d89 Mon Sep 17 00:00:00 2001 From: samet-akcay Date: Fri, 12 Sep 2025 12:20:30 +0100 Subject: [PATCH 09/11] fix mypy Signed-off-by: samet-akcay --- .../visualization/image/item_visualizer.py | 4 +- .../visualization/image/visualizer.py | 162 +++++++++++++++++- 2 files changed, 160 insertions(+), 6 deletions(-) diff --git a/src/anomalib/visualization/image/item_visualizer.py b/src/anomalib/visualization/image/item_visualizer.py index fe1666379d..2be8de4cc1 100644 --- a/src/anomalib/visualization/image/item_visualizer.py +++ b/src/anomalib/visualization/image/item_visualizer.py @@ -1,4 +1,4 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """ImageItem visualization module. 
@@ -347,8 +347,6 @@ def visualize_image_item( field_value = getattr(item, field, None) if field_value is not None: image = get_visualize_function(field)(field_value, **fields_config.get(field, {})) - else: - logger.warning(f"Field '{field}' is None in item. Skipping visualization.") if image: field_images[field] = image.resize(field_size) diff --git a/src/anomalib/visualization/image/visualizer.py b/src/anomalib/visualization/image/visualizer.py index 64057b2df4..277a67835b 100644 --- a/src/anomalib/visualization/image/visualizer.py +++ b/src/anomalib/visualization/image/visualizer.py @@ -32,11 +32,13 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +from PIL import Image + # Only import types during type checking to avoid circular imports if TYPE_CHECKING: from lightning.pytorch import Trainer - from anomalib.data import ImageBatch + from anomalib.data import ImageBatch, ImageItem, NumpyImageBatch, NumpyImageItem from anomalib.models import AnomalibModule from anomalib.utils.path import generate_output_filename @@ -184,6 +186,157 @@ def __init__( self.text_config = {**DEFAULT_TEXT_CONFIG, **(text_config or {})} self.output_dir = output_dir + def visualize_image( + self, + predictions: "ImageItem | NumpyImageItem | ImageBatch | NumpyImageBatch", + ) -> Image.Image | list[Image.Image | None] | None: + """Visualize image predictions. + + This method visualizes anomaly detection predictions intelligently: + - For single items or single-item batches: returns a single image + - For multi-item batches: returns a list of images + + Args: + predictions: The image prediction(s) to visualize. 
Can be: + - ``ImageItem``: Single torch-based image item + - ``NumpyImageItem``: Single numpy-based image item + - ``ImageBatch``: Batch of torch-based image items + - ``NumpyImageBatch``: Batch of numpy-based image items + + Returns: + - For single items or single-item batches: ``Image.Image`` or ``None`` + - For multi-item batches: ``list[Image.Image | None]`` + + Examples: + Visualize a torch-based image item: + + >>> from anomalib.data import ImageItem + >>> import torch + >>> item = ImageItem( + ... image=torch.rand(3, 224, 224), + ... anomaly_map=torch.rand(224, 224), + ... pred_mask=torch.rand(224, 224) > 0.5 + ... ) + >>> visualizer = ImageVisualizer() + >>> result = visualizer.visualize_image(item) + >>> isinstance(result, Image.Image) or result is None + True + + Visualize a numpy-based image item: + + >>> from anomalib.data import NumpyImageItem + >>> import numpy as np + >>> item = NumpyImageItem( + ... image=np.random.rand(224, 224, 3), + ... anomaly_map=np.random.rand(224, 224), + ... pred_mask=np.random.rand(224, 224) > 0.5 + ... ) + >>> result = visualizer.visualize_image(item) + >>> isinstance(result, Image.Image) or result is None + True + + Visualize a batch with one image (returns single image, not list): + + >>> from anomalib.data import ImageBatch + >>> single_batch = ImageBatch( + ... image=torch.rand(1, 3, 224, 224), + ... anomaly_map=torch.rand(1, 224, 224) + ... ) + >>> result = visualizer.visualize_image(single_batch) + >>> isinstance(result, Image.Image) or result is None + True + + Visualize a batch with multiple images (returns list): + + >>> multi_batch = ImageBatch( + ... image=torch.rand(3, 3, 224, 224), + ... anomaly_map=torch.rand(3, 224, 224) + ... ) + >>> results = visualizer.visualize_image(multi_batch) + >>> isinstance(results, list) and len(results) == 3 + True + + Note: + - The method uses the same configuration (fields, overlays, etc.) as specified + during initialization of the ``ImageVisualizer``. 
+ - If an item cannot be visualized (e.g., missing required fields), the + corresponding result will be ``None``. + - This method now behaves identically to the ``__call__`` method. + """ + # Import here to avoid circular imports + from anomalib.data import ImageBatch, ImageItem, NumpyImageBatch, NumpyImageItem + + # Handle single items + if isinstance(predictions, (ImageItem, NumpyImageItem)): + return visualize_image_item( + predictions, + fields=self.fields, + overlay_fields=self.overlay_fields, + field_size=self.field_size, + fields_config=self.fields_config, + overlay_fields_config=self.overlay_fields_config, + text_config=self.text_config, + ) + + # Handle batches + if isinstance(predictions, (ImageBatch, NumpyImageBatch)): + batch_size = len(predictions) + + # Single-item batch - return single image + if batch_size == 1: + image_item = next(iter(predictions)) + return visualize_image_item( + image_item, # type: ignore[arg-type] + fields=self.fields, + overlay_fields=self.overlay_fields, + field_size=self.field_size, + fields_config=self.fields_config, + overlay_fields_config=self.overlay_fields_config, + text_config=self.text_config, + ) + + # Multi-item batch - return list of images + results = [] + for image_item in predictions: + visualization = visualize_image_item( + image_item, # type: ignore[arg-type] + fields=self.fields, + overlay_fields=self.overlay_fields, + field_size=self.field_size, + fields_config=self.fields_config, + overlay_fields_config=self.overlay_fields_config, + text_config=self.text_config, + ) + results.append(visualization) + return results + + msg = ( + f"Unsupported input type: {type(predictions)}. " + "Expected ImageItem, NumpyImageItem, ImageBatch, or NumpyImageBatch." + ) + raise TypeError(msg) + + def __call__( + self, + predictions: "ImageItem | NumpyImageItem | ImageBatch | NumpyImageBatch", + ) -> Image.Image | list[Image.Image | None] | None: + """Make the visualizer callable. 
+ + This method allows the visualizer to be used as a callable object. + It behaves identically to the ``visualize_image`` method. + + Args: + predictions: The predictions to visualize. Same as ``visualize_image``. + + Returns: + Same as ``visualize_image`` method. + + Examples: + >>> visualizer = ImageVisualizer() + >>> result = visualizer(predictions) # Equivalent to visualizer.visualize_image(predictions) + """ + return self.visualize_image(predictions) + def on_test_batch_end( self, trainer: "Trainer", @@ -212,11 +365,14 @@ def on_test_batch_end( if image is not None: # Get the dataset name and category to save the image + datamodule = getattr(trainer, "datamodule", None) + dataset_name = getattr(datamodule, "name", "") if datamodule else "" + category = getattr(datamodule, "category", "") if datamodule else "" filename = generate_output_filename( input_path=item.image_path or "", output_path=self.output_dir, - dataset_name=getattr(trainer.datamodule, "name", "") or "", - category=getattr(trainer.datamodule, "category", "") or "", + dataset_name=dataset_name or "", + category=category or "", ) # Save the image to the specified filename From ca6c7ee6068778c0e4883c7003e623f362f6f885 Mon Sep 17 00:00:00 2001 From: samet-akcay Date: Fri, 12 Sep 2025 13:58:39 +0100 Subject: [PATCH 10/11] refactor: improve dataset name and category handling in ImageVisualizer This commit updates the ImageVisualizer class to handle dataset name and category attributes more robustly by using None as the default value instead of empty strings. This change enhances clarity and ensures that the filename generation logic remains consistent when these attributes are not present. 
Signed-off-by: Samet Akcay <samet.akcay@intel.com> --- src/anomalib/visualization/image/visualizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/anomalib/visualization/image/visualizer.py b/src/anomalib/visualization/image/visualizer.py index 277a67835b..dd9ed8f086 100644 --- a/src/anomalib/visualization/image/visualizer.py +++ b/src/anomalib/visualization/image/visualizer.py @@ -366,13 +366,13 @@ def on_test_batch_end( if image is not None: # Get the dataset name and category to save the image datamodule = getattr(trainer, "datamodule", None) - dataset_name = getattr(datamodule, "name", "") if datamodule else "" - category = getattr(datamodule, "category", "") if datamodule else "" + dataset_name = getattr(datamodule, "name", None) if datamodule else None + category = getattr(datamodule, "category", None) if datamodule else None filename = generate_output_filename( input_path=item.image_path or "", output_path=self.output_dir, - dataset_name=dataset_name or "", - category=category or "", + dataset_name=dataset_name, + category=category, ) # Save the image to the specified filename From 2fd4b730bed6037ad696e25a290ee65d18446c70 Mon Sep 17 00:00:00 2001 From: samet-akcay Date: Mon, 15 Sep 2025 11:47:40 +0100 Subject: [PATCH 11/11] refactor: unify visualization method naming in ImageVisualizer This commit refactors the ImageVisualizer class by renaming the `visualize_image` method to `visualize`, streamlining the interface for users. The updated method signature and examples in the docstring reflect this change, ensuring consistency in how predictions are visualized. Additionally, the copyright year has been updated to 2024-2025. 
Signed-off-by: Samet Akcay --- .../visualization/image/visualizer.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/anomalib/visualization/image/visualizer.py b/src/anomalib/visualization/image/visualizer.py index dd9ed8f086..c4fa568ddd 100644 --- a/src/anomalib/visualization/image/visualizer.py +++ b/src/anomalib/visualization/image/visualizer.py @@ -1,4 +1,4 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Image visualization module for anomaly detection. @@ -16,7 +16,7 @@ >>> # Create visualizer with default settings >>> visualizer = ImageVisualizer() >>> # Generate visualization - >>> vis_result = visualizer.visualize(image=img, pred_mask=mask) + >>> vis_result = visualizer.visualize(predictions) The module ensures consistent visualization by: - Providing standardized field configurations @@ -186,7 +186,7 @@ def __init__( self.text_config = {**DEFAULT_TEXT_CONFIG, **(text_config or {})} self.output_dir = output_dir - def visualize_image( + def visualize( self, predictions: "ImageItem | NumpyImageItem | ImageBatch | NumpyImageBatch", ) -> Image.Image | list[Image.Image | None] | None: @@ -218,7 +218,7 @@ def visualize_image( ... pred_mask=torch.rand(224, 224) > 0.5 ... ) >>> visualizer = ImageVisualizer() - >>> result = visualizer.visualize_image(item) + >>> result = visualizer.visualize(item) >>> isinstance(result, Image.Image) or result is None True @@ -231,7 +231,7 @@ def visualize_image( ... anomaly_map=np.random.rand(224, 224), ... pred_mask=np.random.rand(224, 224) > 0.5 ... ) - >>> result = visualizer.visualize_image(item) + >>> result = visualizer.visualize(item) >>> isinstance(result, Image.Image) or result is None True @@ -242,7 +242,7 @@ def visualize_image( ... image=torch.rand(1, 3, 224, 224), ... anomaly_map=torch.rand(1, 224, 224) ... 
) - >>> result = visualizer.visualize_image(single_batch) + >>> result = visualizer.visualize(single_batch) >>> isinstance(result, Image.Image) or result is None True @@ -252,7 +252,7 @@ def visualize_image( ... image=torch.rand(3, 3, 224, 224), ... anomaly_map=torch.rand(3, 224, 224) ... ) - >>> results = visualizer.visualize_image(multi_batch) + >>> results = visualizer.visualize(multi_batch) >>> isinstance(results, list) and len(results) == 3 True @@ -323,19 +323,19 @@ def __call__( """Make the visualizer callable. This method allows the visualizer to be used as a callable object. - It behaves identically to the ``visualize_image`` method. + It behaves identically to the ``visualize`` method. Args: - predictions: The predictions to visualize. Same as ``visualize_image``. + predictions: The predictions to visualize. Same as ``visualize``. Returns: - Same as ``visualize_image`` method. + Same as ``visualize`` method. Examples: >>> visualizer = ImageVisualizer() - >>> result = visualizer(predictions) # Equivalent to visualizer.visualize_image(predictions) + >>> result = visualizer(predictions) # Equivalent to visualizer.visualize(predictions) """ - return self.visualize_image(predictions) + return self.visualize(predictions) def on_test_batch_end( self,