From 7d9bd805f1e4f96a4cdb9b53f63ad49c2480a647 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 4 Nov 2025 20:13:46 +0000 Subject: [PATCH] Add pix2struct fast image processor --- docs/source/en/model_doc/pix2struct.md | 5 + .../cli/add_fast_image_processor.py | 2 - .../image_processing_utils_fast.py | 123 +++++ .../models/auto/image_processing_auto.py | 2 +- .../models/pix2struct/__init__.py | 1 + .../image_processing_pix2struct_fast.py | 339 ++++++++++++ src/transformers/utils/auto_docstring.py | 8 + .../test_image_processing_pix2struct.py | 486 +++++++++++------- 8 files changed, 763 insertions(+), 203 deletions(-) create mode 100644 src/transformers/models/pix2struct/image_processing_pix2struct_fast.py diff --git a/docs/source/en/model_doc/pix2struct.md b/docs/source/en/model_doc/pix2struct.md index 6894ba7bb593..6a68b6381a01 100644 --- a/docs/source/en/model_doc/pix2struct.md +++ b/docs/source/en/model_doc/pix2struct.md @@ -65,6 +65,11 @@ The original code can be found [here](https://github.com/google-research/pix2str [[autodoc]] Pix2StructImageProcessor - preprocess +## Pix2StructImageProcessorFast + +[[autodoc]] Pix2StructImageProcessorFast + - preprocess + ## Pix2StructTextModel [[autodoc]] Pix2StructTextModel diff --git a/src/transformers/cli/add_fast_image_processor.py b/src/transformers/cli/add_fast_image_processor.py index c1aaa5bdaaed..8f37c962528c 100644 --- a/src/transformers/cli/add_fast_image_processor.py +++ b/src/transformers/cli/add_fast_image_processor.py @@ -59,8 +59,6 @@ def add_fast_image_processor( image_processor_name = re.findall(r"class (\w*ImageProcessor)", content_base_file) if not image_processor_name: raise ValueError(f"No ImageProcessor class found in {image_processing_module_file}") - elif len(image_processor_name) > 1: - raise ValueError(f"Multiple ImageProcessor classes found in {image_processing_module_file}") image_processor_name = image_processor_name[0] fast_image_processor_name = image_processor_name + "Fast" diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index f675da162079..2525378f8c40 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -166,6 +166,129 @@ def divide_to_patches( @auto_docstring class BaseImageProcessorFast(BaseImageProcessor): + r""" + Base class for fast image processors using PyTorch and TorchVision for image transformations. + + This class provides a complete implementation for standard image preprocessing operations (resize, crop, rescale, + normalize) with GPU support and batch processing optimizations. Most image processors can be implemented by simply + setting class attributes; only processors requiring custom logic need to override methods. + + Basic Implementation + -------------------- + + For processors that only need standard operations (resize, center crop, rescale, normalize), define class + attributes: + + class MyImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"height": 224, "width": 224} + do_resize = True + do_rescale = True + do_normalize = True + + Custom Processing + ----------------- + + Override `_preprocess` (most common): + For custom image processing logic, override `_preprocess`. This method receives a list of torch tensors with + channel dimension first and should return a BatchFeature. 
Use `group_images_by_shape` and `reorder_images` for
+    efficient batch processing:
+
+        def _preprocess(
+            self,
+            images: list[torch.Tensor],
+            do_resize: bool,
+            size: SizeDict,
+            # ... other parameters
+            **kwargs,
+        ) -> BatchFeature:
+            # Group images by shape for batched operations
+            grouped_images, indices = group_images_by_shape(images)
+            processed_groups = {}
+
+            for shape, stacked_images in grouped_images.items():
+                if do_resize:
+                    stacked_images = self.resize(stacked_images, size)
+                # Custom processing here
+                processed_groups[shape] = stacked_images
+
+            processed_images = reorder_images(processed_groups, indices)
+            return BatchFeature(data={"pixel_values": torch.stack(processed_images)})
+
+    Override `_preprocess_image_like_inputs` (for additional inputs):
+        For processors handling multiple input types (e.g., images + segmentation maps), override this method:
+
+        def _preprocess_image_like_inputs(
+            self,
+            images: ImageInput,
+            segmentation_maps: Optional[ImageInput] = None,
+            *,
+            do_convert_rgb: bool,
+            input_data_format: ChannelDimension,
+            device: Optional[torch.device] = None,
+            **kwargs,
+        ) -> BatchFeature:
+            images = self._prepare_image_like_inputs(images, do_convert_rgb, input_data_format, device)
+            batch_feature = self._preprocess(images, **kwargs)
+
+            if segmentation_maps is not None:
+                # Process segmentation maps separately
+                maps = self._prepare_image_like_inputs(segmentation_maps, ...)
+                batch_feature["labels"] = self._preprocess(maps, ...)
+
+            return batch_feature
+
+    Override `_further_process_kwargs` (for custom kwargs formatting):
+        To format custom kwargs before validation:
+
+        def _further_process_kwargs(self, custom_param=None, **kwargs):
+            kwargs = super()._further_process_kwargs(**kwargs)
+            if custom_param is not None:
+                kwargs["custom_param"] = self._format_custom_param(custom_param)
+            return kwargs
+
+    Override `_validate_preprocess_kwargs` (for custom validation):
+        To add custom validation logic:
+
+        def _validate_preprocess_kwargs(self, custom_param=None, **kwargs):
+            super()._validate_preprocess_kwargs(**kwargs)
+            if custom_param is not None and custom_param < 0:
+                raise ValueError("custom_param must be non-negative")
+
+    Override `_prepare_images_structure` (for nested inputs):
+        By default, nested image lists are flattened. Override to preserve structure:
+
+        def _prepare_images_structure(self, images, expected_ndims=3):
+            # Custom logic to handle nested structure
+            return images  # Return as-is or with custom processing
+
+    Custom Parameters
+    -----------------
+
+    To add parameters beyond `ImagesKwargs`, create a custom kwargs class and set it as `valid_kwargs`:
+
+        class MyImageProcessorKwargs(ImagesKwargs):
+            custom_param: Optional[int] = None
+            another_param: Optional[bool] = None
+
+        class MyImageProcessorFast(BaseImageProcessorFast):
+            valid_kwargs = MyImageProcessorKwargs
+            custom_param = 10  # default value
+
+            def _preprocess(self, images, custom_param, **kwargs):
+                # Use custom_param in processing
+                ...
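+
+    Usage
+    -----
+
+    A minimal sketch of calling a fast processor. `MyImageProcessorFast` and `custom_param` are the illustrative
+    names from the examples above, not real library classes:
+
+        import torch
+
+        processor = MyImageProcessorFast()
+        images = [torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)]
+        features = processor(images, custom_param=5, return_tensors="pt")
+        features["pixel_values"].shape  # e.g. torch.Size([1, 3, 224, 224])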
+ + Key Notes + --------- + + - Images in `_preprocess` are always torch tensors with channel dimension first, regardless of input format + - Arguments not provided by users default to class attribute values + - Use batch processing utilities (`group_images_by_shape`, `reorder_images`) for GPU efficiency + - Image loading, format conversion, and argument handling are automatic - focus only on processing logic + """ + resample = None image_mean = None image_std = None diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 6c9b69db4555..b5f7ff76d64d 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -149,7 +149,7 @@ ("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")), ("perception_lm", (None, "PerceptionLMImageProcessorFast")), ("phi4_multimodal", (None, "Phi4MultimodalImageProcessorFast")), - ("pix2struct", ("Pix2StructImageProcessor", None)), + ("pix2struct", ("Pix2StructImageProcessor", "Pix2StructImageProcessorFast")), ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")), ("poolformer", ("PoolFormerImageProcessor", "PoolFormerImageProcessorFast")), ("prompt_depth_anything", ("PromptDepthAnythingImageProcessor", "PromptDepthAnythingImageProcessorFast")), diff --git a/src/transformers/models/pix2struct/__init__.py b/src/transformers/models/pix2struct/__init__.py index aa645dff0494..7cf28e634b13 100644 --- a/src/transformers/models/pix2struct/__init__.py +++ b/src/transformers/models/pix2struct/__init__.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from .configuration_pix2struct import * from .image_processing_pix2struct import * + from .image_processing_pix2struct_fast import * from .modeling_pix2struct import * from .processing_pix2struct import * else: diff --git a/src/transformers/models/pix2struct/image_processing_pix2struct_fast.py b/src/transformers/models/pix2struct/image_processing_pix2struct_fast.py new file mode 100644 index 000000000000..8dcea3c9ab27 --- /dev/null +++ b/src/transformers/models/pix2struct/image_processing_pix2struct_fast.py @@ -0,0 +1,339 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Fast Image processor class for Pix2Struct."""
+
+from typing import Optional, Union
+
+import torch
+from PIL import Image
+from torchvision.transforms.v2 import functional as F
+
+from ...image_processing_utils import BatchFeature, get_size_dict
+from ...image_processing_utils_fast import BaseImageProcessorFast
+from ...image_transforms import group_images_by_shape, reorder_images
+from ...image_utils import ChannelDimension, ImageInput, SizeDict
+from ...processing_utils import Unpack
+from ...utils import TensorType, auto_docstring
+from .image_processing_pix2struct import Pix2StructImageProcessorKwargs, render_text
+
+
+# Disabled under torch.compile as it causes issues with compilation
+@torch.compiler.disable
+def torch_extract_patches(image_tensor, patch_height, patch_width):
+    """
+    Extract patches from an image tensor. Returns a tensor of shape (batch, rows, columns, patch_height*patch_width*channels).
+
+    Args:
+        image_tensor (`torch.Tensor`):
+            Image tensor of shape (batch, channels, height, width).
+        patch_height (`int`):
+            Height of the patches to extract.
+        patch_width (`int`):
+            Width of the patches to extract.
+    """
+    batch_size, channels, height, width = image_tensor.shape
+    patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
+    patches = patches.reshape(batch_size, channels, patch_height, patch_width, -1)
+    patches = patches.permute(0, 4, 2, 3, 1).reshape(
+        batch_size, height // patch_height, width // patch_width, channels * patch_height * patch_width
+    )
+    return patches
+
+
+@auto_docstring
+class Pix2StructImageProcessorFast(BaseImageProcessorFast):
+    rescale_factor = None
+    do_normalize = True
+    do_convert_rgb = True
+    patch_size = {"height": 16, "width": 16}
+    max_patches = 2048
+    is_vqa = False
+    valid_kwargs = Pix2StructImageProcessorKwargs
+    model_input_names = ["flattened_patches", "attention_mask"]
+
+    def _further_process_kwargs(
+        self,
+        patch_size: Optional[dict[str, int]] = None,
+        **kwargs,
+    ) -> dict:
+        """
+        Process custom Pix2Struct kwargs, specifically converting patch_size to a SizeDict.
+        """
+        # Call super to handle standard kwargs processing
+        kwargs = super()._further_process_kwargs(**kwargs)
+        # Convert patch_size to a SizeDict for attribute-style access
+        kwargs["patch_size"] = SizeDict(**get_size_dict(size=patch_size, param_name="patch_size"))
+
+        return kwargs
+
+    def _validate_preprocess_kwargs(self, **kwargs):
+        """
+        Skip standard validation as Pix2Struct uses custom preprocessing.
+        """
+        # Pix2Struct doesn't use the standard resize/rescale/normalize parameters,
+        # so we skip the default validation
+        pass
+
+    def render_header(
+        self,
+        image: torch.Tensor,
+        header: str,
+        font_bytes: Optional[bytes] = None,
+        font_path: Optional[str] = None,
+    ) -> torch.Tensor:
+        """
+        Render header text above an image. The image is converted to PIL for text rendering and returned as a tensor.
+
+        Args:
+            image (`torch.Tensor`):
+                Image tensor in channel-first format (C, H, W).
+            header (`str`):
+                Header text to render.
+            font_bytes (`bytes`, *optional*):
+                Font bytes to use for rendering.
+            font_path (`str`, *optional*):
+                Path to a font file to use for rendering.
+
+        Returns:
+            `torch.Tensor`: Image with header in channel-first format (C, H, W).
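+
+        Example (a minimal sketch; when `font_bytes`/`font_path` are not given, `render_text` falls back to a
+        default font that may be downloaded from the Hugging Face Hub on first use):
+
+            processor = Pix2StructImageProcessorFast()
+            image = torch.rand(3, 224, 224)  # float image in [0, 1]
+            with_header = processor.render_header(image, "What is the title of this figure?")
+            # with_header is (3, H', W'): the rendered text pasted above the resized image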
+ """ + device = image.device + dtype = image.dtype + + # Convert tensor to PIL (channel-first to channel-last for PIL) + if image.dtype == torch.uint8: + image_pil = F.to_pil_image(image) + else: + # If float, convert to uint8 first + image_uint8 = (image * 255).clamp(0, 255).to(torch.uint8) + image_pil = F.to_pil_image(image_uint8) + + # Render header text as PIL image + header_image = render_text(header, font_bytes=font_bytes, font_path=font_path) + + # Calculate new dimensions + new_width = max(header_image.width, image_pil.width) + new_height = int(image_pil.height * (new_width / image_pil.width)) + new_header_height = int(header_image.height * (new_width / header_image.width)) + + # Create new image and paste header and original image + new_image = Image.new("RGB", (new_width, new_height + new_header_height), "white") + new_image.paste(header_image.resize((new_width, new_header_height)), (0, 0)) + new_image.paste(image_pil.resize((new_width, new_height)), (0, new_header_height)) + + # Convert back to tensor (channel-first) + result = F.pil_to_tensor(new_image).to(device) + + # Convert back to original dtype if needed + if dtype != torch.uint8: + result = result.float() / 255.0 + + return result + + def normalize(self, images: torch.Tensor) -> torch.Tensor: + """ + Normalize batched images using per-image mean and standard deviation. + + Args: + images (`torch.Tensor`): + Batched float image tensor of shape (B, C, H, W). + + Returns: + `torch.Tensor`: Normalized images of shape (B, C, H, W). + """ + # Compute mean and std per image along spatial and channel dimensions + mean = images.mean(dim=(1, 2, 3), keepdim=True) # Shape: (B, 1, 1, 1) + std = images.std(dim=(1, 2, 3), keepdim=True) # Shape: (B, 1, 1, 1) + + num_elements_per_image = images.shape[1] * images.shape[2] * images.shape[3] + min_std = 1.0 / num_elements_per_image**0.5 + adjusted_stddev = torch.maximum(std, torch.tensor(min_std, device=std.device)) + + return (images - mean) / adjusted_stddev + + def extract_flattened_patches( + self, + images: torch.Tensor, + max_patches: int, + patch_size: SizeDict, + ) -> torch.Tensor: + """ + Extract flattened patches from a batch of images. + + Args: + images (`torch.Tensor`): + Batched images tensor of shape (batch, channels, height, width). + max_patches (`int`): + Maximum number of patches to extract. + patch_size (`dict[str, int]`): + Dictionary containing patch height and width. + + Returns: + `torch.Tensor`: Batched flattened patches with row/column IDs of shape (batch, max_patches, patch_dim). 
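+
+        Example (illustrative shapes): with 16x16 patches and 3-channel input, each flattened patch holds
+        2 + 16 * 16 * 3 = 770 values (row ID, column ID, then pixel values):
+
+            processor = Pix2StructImageProcessorFast()
+            images = torch.rand(2, 3, 224, 224)
+            patches = processor.extract_flattened_patches(
+                images, max_patches=1024, patch_size=SizeDict(height=16, width=16)
+            )
+            # patches.shape == (2, 1024, 770); rows beyond the real patch count are zero padding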
+ """ + patch_height, patch_width = patch_size.height, patch_size.width + batch_size, channels, image_height, image_width = images.shape + + # Calculate scale to maximize patches while respecting max_patches + scale = (max_patches * (patch_height / image_height) * (patch_width / image_width)) ** 0.5 + num_feasible_rows = max(min(int(scale * image_height / patch_height), max_patches), 1) + num_feasible_cols = max(min(int(scale * image_width / patch_width), max_patches), 1) + resized_height = max(num_feasible_rows * patch_height, 1) + resized_width = max(num_feasible_cols * patch_width, 1) + + # Resize images (batched) using parent class method + resize_size = SizeDict(height=resized_height, width=resized_width) + images = self.resize( + image=images, size=resize_size, interpolation=F.InterpolationMode.BILINEAR, antialias=True + ) + + # Extract patches: [batch, rows, columns, patch_height * patch_width * channels] + patches = torch_extract_patches(images, patch_height, patch_width) + + batch_size, rows, columns, depth = patches.shape + + # Reshape to [batch, rows * columns, depth] + patches = patches.reshape(batch_size, rows * columns, depth) + + # Create row and column IDs + row_ids = ( + torch.arange(rows, device=images.device).reshape(rows, 1).repeat(1, columns).reshape(1, rows * columns, 1) + ) + col_ids = ( + torch.arange(columns, device=images.device) + .reshape(1, columns) + .repeat(rows, 1) + .reshape(1, rows * columns, 1) + ) + + # Expand to batch size + row_ids = row_ids.expand(batch_size, -1, -1) + col_ids = col_ids.expand(batch_size, -1, -1) + + # Offset by 1 so IDs don't contain zeros (which represent padding) + row_ids = (row_ids + 1).float() + col_ids = (col_ids + 1).float() + + # Concatenate row_ids, col_ids, and patches: [batch, rows * columns, 2 + depth] + result = torch.cat([row_ids, col_ids, patches], dim=-1) + + # Pad to max_patches: [batch, max_patches, 2 + depth] + result = torch.nn.functional.pad(result, [0, 0, 0, max_patches - (rows * columns)]).float() + + return result + + @auto_docstring + def preprocess( + self, + images: ImageInput, + header_text: Optional[Union[str, list[str]]] = None, + **kwargs: Unpack[Pix2StructImageProcessorKwargs], + ) -> BatchFeature: + r""" + header_text (`Union[str, list[str]]`, *optional*): + Text to render as a header. Only has an effect if `image_processor.is_vqa` is `True`. + """ + return super().preprocess(images, header_text=header_text, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + header_text: Optional[Union[str, list[str]]] = None, + do_convert_rgb: bool = True, + input_data_format: ChannelDimension = ChannelDimension.FIRST, + device: Optional[Union[str, torch.device]] = None, + **kwargs: Unpack[Pix2StructImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess images for Pix2Struct. 
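+
+        Converts the inputs to channel-first torch tensors and, when `is_vqa` is enabled, renders `header_text`
+        above each image before patch extraction.
+
+        Example (a minimal sketch; assumes an `is_vqa` processor and a PIL or tensor `image`):
+
+            processor = Pix2StructImageProcessorFast(is_vqa=True)
+            features = processor(image, header_text="What is the title?", return_tensors="pt")
+            # features["flattened_patches"]: (1, max_patches, 2 + patch_height * patch_width * channels)
+            # features["attention_mask"]: (1, max_patches), 1.0 for real patches and 0.0 for padding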
+ """ + # Prepare images (converts to torch tensors) + images = self._prepare_image_like_inputs( + images=images, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + device=device, + ) + + # Handle VQA mode with header rendering + is_vqa = kwargs.get("is_vqa", self.is_vqa) + if is_vqa: + if header_text is None: + raise ValueError("A header text must be provided for VQA models.") + + font_bytes = kwargs.pop("font_bytes", None) + font_path = kwargs.pop("font_path", None) + + if isinstance(header_text, str): + header_text = [header_text] * len(images) + + # Render headers using torch-native method + images = [ + self.render_header(image, header_text[i], font_bytes=font_bytes, font_path=font_path) + for i, image in enumerate(images) + ] + + return self._preprocess(images, **kwargs) + + def _preprocess( + self, + images: list[torch.Tensor], + do_normalize: bool, + max_patches: int, + patch_size: SizeDict, + return_tensors: Optional[Union[str, TensorType]], + disable_grouping: bool, + **kwargs, + ) -> BatchFeature: + """ + Preprocess images to extract flattened patches. + """ + # Group images by shape first for efficient batch processing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + + flattened_patches_grouped = {} + attention_masks_grouped = {} + + for shape, stacked_images in grouped_images.items(): + # Convert to float if needed (for resize and other operations) + if stacked_images.dtype == torch.uint8: + stacked_images = stacked_images.float() + + # Normalize batched images with per-image mean and std + if do_normalize: + stacked_images = self.normalize(stacked_images) + + patches = self.extract_flattened_patches( + images=stacked_images, max_patches=max_patches, patch_size=patch_size + ) + masks = (patches.sum(dim=-1) != 0).float() + + flattened_patches_grouped[shape] = patches + attention_masks_grouped[shape] = masks + + flattened_patches = reorder_images(flattened_patches_grouped, grouped_images_index) + attention_masks = reorder_images(attention_masks_grouped, grouped_images_index) + + # Stack if return_tensors is set + if return_tensors: + flattened_patches = torch.stack(flattened_patches, dim=0) + attention_masks = torch.stack(attention_masks, dim=0) + + return BatchFeature( + data={"flattened_patches": flattened_patches, "attention_mask": attention_masks}, + tensor_type=return_tensors, + ) + + +__all__ = ["Pix2StructImageProcessorFast"] diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index d68317e6c903..b67d9fa11760 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -240,6 +240,14 @@ class ImageProcessorArgs: "shape": None, } + image_seq_length = { + "description": """ + The number of image tokens to be used for each image in the input. + Added for backward compatibility but this should be set as a processor attribute in future models. 
+ """, + "shape": None, + } + class ModelArgs: labels = { diff --git a/tests/models/pix2struct/test_image_processing_pix2struct.py b/tests/models/pix2struct/test_image_processing_pix2struct.py index e300850a474c..0fb3019cf9da 100644 --- a/tests/models/pix2struct/test_image_processing_pix2struct.py +++ b/tests/models/pix2struct/test_image_processing_pix2struct.py @@ -16,10 +16,11 @@ import unittest import numpy as np +from packaging import version from transformers.image_utils import load_image -from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_torch_accelerator, require_vision, slow, torch_device +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs from ...test_processing_common import url_to_local_path @@ -33,6 +34,9 @@ from transformers import Pix2StructImageProcessor + if is_torchvision_available(): + from transformers import Pix2StructImageProcessorFast + class Pix2StructImageProcessingTester: def __init__( @@ -87,6 +91,7 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F @require_vision class Pix2StructImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = Pix2StructImageProcessor if is_vision_available() else None + fast_image_processing_class = Pix2StructImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -97,198 +102,265 @@ def image_processor_dict(self): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processor = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processor, "do_normalize")) - self.assertTrue(hasattr(image_processor, "do_convert_rgb")) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processor, "do_normalize")) + self.assertTrue(hasattr(image_processor, "do_convert_rgb")) def test_expected_patches(self): dummy_image = self.image_processor_tester.prepare_dummy_image() - image_processor = self.image_processing_class(**self.image_processor_dict) - max_patch = 2048 + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + max_patch = 2048 - inputs = image_processor(dummy_image, return_tensors="pt", max_patches=max_patch) - torch.testing.assert_close(inputs.flattened_patches.mean(), torch.tensor(0.0606), rtol=1e-3, atol=1e-3) + inputs = image_processor(dummy_image, return_tensors="pt", max_patches=max_patch) + torch.testing.assert_close(inputs.flattened_patches.mean(), torch.tensor(0.0606), rtol=1e-3, atol=1e-3) def test_call_pil(self): - # Initialize image_processor - image_processor = self.image_processing_class(**self.image_processor_dict) - # create random PIL images - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) - for image in image_inputs: - self.assertIsInstance(image, Image.Image) - - # Test not batched input - expected_hidden_dim = ( - (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) - * self.image_processor_tester.num_channels - ) + 2 - - for max_patch in 
self.image_processor_tester.max_patches: + for image_processing_class in self.image_processor_list: + # Initialize image_processor + image_processor = image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + # Test not batched input - encoded_images = image_processor( - image_inputs[0], return_tensors="pt", max_patches=max_patch - ).flattened_patches - self.assertEqual( - encoded_images.shape, - (1, max_patch, expected_hidden_dim), - ) - - # Test batched - encoded_images = image_processor( - image_inputs, return_tensors="pt", max_patches=max_patch - ).flattened_patches - self.assertEqual( - encoded_images.shape, - (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), - ) + expected_hidden_dim = ( + (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) + * self.image_processor_tester.num_channels + ) + 2 + + for max_patch in self.image_processor_tester.max_patches: + # Test not batched input + encoded_images = image_processor( + image_inputs[0], return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (1, max_patch, expected_hidden_dim), + ) + + # Test batched + encoded_images = image_processor( + image_inputs, return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), + ) def test_call_vqa(self): - # Initialize image_processor - image_processor = self.image_processing_class(**self.image_processor_dict) - # create random PIL images - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) - for image in image_inputs: - self.assertIsInstance(image, Image.Image) - - # Test not batched input - expected_hidden_dim = ( - (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) - * self.image_processor_tester.num_channels - ) + 2 - - image_processor.is_vqa = True - - for max_patch in self.image_processor_tester.max_patches: + for image_processing_class in self.image_processor_list: + # Initialize image_processor + image_processor = image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + # Test not batched input - with self.assertRaises(ValueError): + expected_hidden_dim = ( + (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) + * self.image_processor_tester.num_channels + ) + 2 + + image_processor.is_vqa = True + + for max_patch in self.image_processor_tester.max_patches: + # Test not batched input + with self.assertRaises(ValueError): + encoded_images = image_processor( + image_inputs[0], return_tensors="pt", max_patches=max_patch + ).flattened_patches + + dummy_text = "Hello" + encoded_images = image_processor( - image_inputs[0], return_tensors="pt", max_patches=max_patch + image_inputs[0], return_tensors="pt", max_patches=max_patch, header_text=dummy_text ).flattened_patches + self.assertEqual( + encoded_images.shape, + (1, max_patch, expected_hidden_dim), + ) - dummy_text = "Hello" - - encoded_images = image_processor( - image_inputs[0], 
return_tensors="pt", max_patches=max_patch, header_text=dummy_text - ).flattened_patches - self.assertEqual( - encoded_images.shape, - (1, max_patch, expected_hidden_dim), - ) - - # Test batched - encoded_images = image_processor( - image_inputs, return_tensors="pt", max_patches=max_patch, header_text=dummy_text - ).flattened_patches - self.assertEqual( - encoded_images.shape, - (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), - ) + # Test batched + encoded_images = image_processor( + image_inputs, return_tensors="pt", max_patches=max_patch, header_text=dummy_text + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), + ) def test_call_numpy(self): - # Initialize image_processor - image_processor = self.image_processing_class(**self.image_processor_dict) - # create random numpy tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) - for image in image_inputs: - self.assertIsInstance(image, np.ndarray) - - expected_hidden_dim = ( - (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) - * self.image_processor_tester.num_channels - ) + 2 - - for max_patch in self.image_processor_tester.max_patches: - # Test not batched input - encoded_images = image_processor( - image_inputs[0], return_tensors="pt", max_patches=max_patch - ).flattened_patches - self.assertEqual( - encoded_images.shape, - (1, max_patch, expected_hidden_dim), - ) - - # Test batched - encoded_images = image_processor( - image_inputs, return_tensors="pt", max_patches=max_patch - ).flattened_patches - self.assertEqual( - encoded_images.shape, - (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), - ) + for image_processing_class in self.image_processor_list: + # Initialize image_processor + image_processor = image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + expected_hidden_dim = ( + (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) + * self.image_processor_tester.num_channels + ) + 2 + + for max_patch in self.image_processor_tester.max_patches: + # Test not batched input + encoded_images = image_processor( + image_inputs[0], return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (1, max_patch, expected_hidden_dim), + ) + + # Test batched + encoded_images = image_processor( + image_inputs, return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), + ) def test_call_numpy_4_channels(self): - # Initialize image_processor - image_processor = self.image_processing_class(**self.image_processor_dict) - # create random numpy tensors - self.image_processor_tester.num_channels = 4 - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) - for image in image_inputs: - self.assertIsInstance(image, np.ndarray) - - expected_hidden_dim = ( - (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) - * self.image_processor_tester.num_channels - ) + 2 - - for max_patch in 
self.image_processor_tester.max_patches: - # Test not batched input - encoded_images = image_processor( - image_inputs[0], return_tensors="pt", max_patches=max_patch, input_data_format="channels_last" - ).flattened_patches - self.assertEqual( - encoded_images.shape, - (1, max_patch, expected_hidden_dim), - ) - - # Test batched - encoded_images = image_processor( - image_inputs, return_tensors="pt", max_patches=max_patch, input_data_format="channels_last" - ).flattened_patches - self.assertEqual( - encoded_images.shape, - (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), - ) - self.image_processor_tester.num_channels = 3 + for image_processing_class in self.image_processor_list: + # Initialize image_processor + image_processor = image_processing_class(**self.image_processor_dict) + # create random numpy tensors + self.image_processor_tester.num_channels = 4 + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + expected_hidden_dim = ( + (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) + * self.image_processor_tester.num_channels + ) + 2 + + for max_patch in self.image_processor_tester.max_patches: + # Test not batched input + encoded_images = image_processor( + image_inputs[0], return_tensors="pt", max_patches=max_patch, input_data_format="channels_last" + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (1, max_patch, expected_hidden_dim), + ) + + # Test batched + encoded_images = image_processor( + image_inputs, return_tensors="pt", max_patches=max_patch, input_data_format="channels_last" + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), + ) + self.image_processor_tester.num_channels = 3 def test_call_pytorch(self): - # Initialize image_processor - image_processor = self.image_processing_class(**self.image_processor_dict) - # create random PyTorch tensors - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) - for image in image_inputs: - self.assertIsInstance(image, torch.Tensor) - - # Test not batched input - expected_hidden_dim = ( - (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) - * self.image_processor_tester.num_channels - ) + 2 - - for max_patch in self.image_processor_tester.max_patches: + for image_processing_class in self.image_processor_list: + # Initialize image_processor + image_processor = image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + # Test not batched input - encoded_images = image_processor( - image_inputs[0], return_tensors="pt", max_patches=max_patch - ).flattened_patches - self.assertEqual( - encoded_images.shape, - (1, max_patch, expected_hidden_dim), - ) - - # Test batched - encoded_images = image_processor( - image_inputs, return_tensors="pt", max_patches=max_patch - ).flattened_patches - self.assertEqual( - encoded_images.shape, - (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), - ) + expected_hidden_dim = ( + (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) + * 
self.image_processor_tester.num_channels + ) + 2 + + for max_patch in self.image_processor_tester.max_patches: + # Test not batched input + encoded_images = image_processor( + image_inputs[0], return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (1, max_patch, expected_hidden_dim), + ) + + # Test batched + encoded_images = image_processor( + image_inputs, return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), + ) + + @require_vision + @require_torch + def test_slow_fast_equivalence(self): + dummy_image = self.image_processor_tester.prepare_dummy_image() + + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_image, return_tensors="pt", max_patches=2048) + encoding_fast = image_processor_fast(dummy_image, return_tensors="pt", max_patches=2048) + # Pix2Struct uses flattened_patches instead of pixel_values + self._assert_slow_fast_tensors_equivalence(encoding_slow.flattened_patches, encoding_fast.flattened_patches) + + @require_vision + @require_torch + def test_slow_fast_equivalence_batched(self): + dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_images, return_tensors="pt", max_patches=2048) + encoding_fast = image_processor_fast(dummy_images, return_tensors="pt", max_patches=2048) + # Pix2Struct uses flattened_patches instead of pixel_values + self._assert_slow_fast_tensors_equivalence(encoding_slow.flattened_patches, encoding_fast.flattened_patches) + + @slow + @require_torch_accelerator + @require_vision + def test_can_compile_fast_image_processor(self): + if self.fast_image_processing_class is None: + self.skipTest("Skipping compilation test as fast image processor is not defined") + if version.parse(torch.__version__) < version.parse("2.3"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + torch.compiler.reset() + input_image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8) + image_processor = self.fast_image_processing_class(**self.image_processor_dict) + output_eager = image_processor(input_image, device=torch_device, return_tensors="pt") + + image_processor = torch.compile(image_processor, mode="reduce-overhead") + output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt") + # Pix2Struct uses flattened_patches instead of pixel_values + 
self._assert_slow_fast_tensors_equivalence( + output_eager.flattened_patches, output_compiled.flattened_patches, atol=1e-4, rtol=1e-4, mean_atol=1e-5 + ) @require_torch @require_vision class Pix2StructImageProcessingTestFourChannels(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = Pix2StructImageProcessor if is_vision_available() else None + fast_image_processing_class = Pix2StructImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() @@ -300,42 +372,44 @@ def image_processor_dict(self): return self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): - image_processor = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processor, "do_normalize")) - self.assertTrue(hasattr(image_processor, "do_convert_rgb")) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processor, "do_normalize")) + self.assertTrue(hasattr(image_processor, "do_convert_rgb")) def test_call_pil(self): - # Initialize image_processor - image_processor = self.image_processing_class(**self.image_processor_dict) - # create random PIL images - image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) - for image in image_inputs: - self.assertIsInstance(image, Image.Image) - - # Test not batched input - expected_hidden_dim = ( - (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) - * (self.image_processor_tester.num_channels - 1) - ) + 2 - - for max_patch in self.image_processor_tester.max_patches: + for image_processing_class in self.image_processor_list: + # Initialize image_processor + image_processor = image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + # Test not batched input - encoded_images = image_processor( - image_inputs[0], return_tensors="pt", max_patches=max_patch - ).flattened_patches - self.assertEqual( - encoded_images.shape, - (1, max_patch, expected_hidden_dim), - ) - - # Test batched - encoded_images = image_processor( - image_inputs, return_tensors="pt", max_patches=max_patch - ).flattened_patches - self.assertEqual( - encoded_images.shape, - (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), - ) + expected_hidden_dim = ( + (self.image_processor_tester.patch_size["height"] * self.image_processor_tester.patch_size["width"]) + * (self.image_processor_tester.num_channels - 1) + ) + 2 + + for max_patch in self.image_processor_tester.max_patches: + # Test not batched input + encoded_images = image_processor( + image_inputs[0], return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (1, max_patch, expected_hidden_dim), + ) + + # Test batched + encoded_images = image_processor( + image_inputs, return_tensors="pt", max_patches=max_patch + ).flattened_patches + self.assertEqual( + encoded_images.shape, + (self.image_processor_tester.batch_size, max_patch, expected_hidden_dim), + ) @unittest.skip(reason="Pix2StructImageProcessor does not support 4 channels yet") # FIXME Amy def test_call_numpy(self): @@ -350,3 +424,15 @@ def test_call_pytorch(self): ) # FIXME Amy def test_call_numpy_4_channels(self): return 
super().test_call_torch() + + @unittest.skip(reason="Pix2StructImageProcessor does not support 4 channels yet") + def test_slow_fast_equivalence(self): + pass + + @unittest.skip(reason="Pix2StructImageProcessor does not support 4 channels yet") + def test_slow_fast_equivalence_batched(self): + pass + + @unittest.skip(reason="Pix2StructImageProcessor does not support 4 channels yet") + def test_can_compile_fast_image_processor(self): + pass