Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/pix2struct.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ The original code can be found [here](https://github.com/google-research/pix2str
[[autodoc]] Pix2StructImageProcessor
- preprocess

## Pix2StructImageProcessorFast

[[autodoc]] Pix2StructImageProcessorFast
- preprocess

## Pix2StructTextModel

[[autodoc]] Pix2StructTextModel
Expand Down
2 changes: 0 additions & 2 deletions src/transformers/cli/add_fast_image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ def add_fast_image_processor(
image_processor_name = re.findall(r"class (\w*ImageProcessor)", content_base_file)
if not image_processor_name:
raise ValueError(f"No ImageProcessor class found in {image_processing_module_file}")
elif len(image_processor_name) > 1:
raise ValueError(f"Multiple ImageProcessor classes found in {image_processing_module_file}")

image_processor_name = image_processor_name[0]
fast_image_processor_name = image_processor_name + "Fast"
Expand Down
123 changes: 123 additions & 0 deletions src/transformers/image_processing_utils_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,129 @@ def divide_to_patches(

@auto_docstring
class BaseImageProcessorFast(BaseImageProcessor):
r"""
Base class for fast image processors using PyTorch and TorchVision for image transformations.

This class provides a complete implementation for standard image preprocessing operations (resize, crop, rescale,
normalize) with GPU support and batch processing optimizations. Most image processors can be implemented by simply
setting class attributes; only processors requiring custom logic need to override methods.

Basic Implementation
--------------------

For processors that only need standard operations (resize, center crop, rescale, normalize), define class
attributes:

class MyImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = IMAGENET_DEFAULT_MEAN
image_std = IMAGENET_DEFAULT_STD
size = {"height": 224, "width": 224}
do_resize = True
do_rescale = True
do_normalize = True

Custom Processing
-----------------

Override `_preprocess` (most common):
For custom image processing logic, override `_preprocess`. This method receives a list of torch tensors with
channel dimension first and should return a BatchFeature. Use `group_images_by_shape` and `reorder_images` for
efficient batch processing:

def _preprocess(
self,
images: list[torch.Tensor],
do_resize: bool,
size: SizeDict,
# ... other parameters
**kwargs,
) -> BatchFeature:
# Group images by shape for batched operations
grouped_images, indices = group_images_by_shape(images)
processed_groups = {}

for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(stacked_images, size)
# Custom processing here
processed_groups[shape] = stacked_images

processed_images = reorder_images(processed_groups, indices)
return BatchFeature(data={"pixel_values": torch.stack(processed_images)})

Override `_preprocess_image_like_inputs` (for additional inputs):
For processors handling multiple input types (e.g., images + segmentation maps), override this method:

def _preprocess_image_like_inputs(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
do_convert_rgb: bool,
input_data_format: ChannelDimension,
device: Optional[torch.device] = None,
**kwargs,
) -> BatchFeature:
images = self._prepare_image_like_inputs(images, do_convert_rgb, input_data_format, device)
batch_feature = self._preprocess(images, **kwargs)

if segmentation_maps is not None:
# Process segmentation maps separately
maps = self._prepare_image_like_inputs(segmentation_maps, ...)
batch_feature["labels"] = self._preprocess(maps, ...)

return batch_feature

Override `_further_process_kwargs` (for custom kwargs formatting):
To format custom kwargs before validation:

def _further_process_kwargs(self, custom_param=None, **kwargs):
kwargs = super()._further_process_kwargs(**kwargs)
if custom_param is not None:
kwargs["custom_param"] = self._format_custom_param(custom_param)
return kwargs

Override `_validate_preprocess_kwargs` (for custom validation):
To add custom validation logic:

def _validate_preprocess_kwargs(self, custom_param=None, **kwargs):
super()._validate_preprocess_kwargs(**kwargs)
if custom_param is not None and custom_param < 0:
raise ValueError("custom_param must be non-negative")

Override `_prepare_images_structure` (for nested inputs):
By default, nested image lists are flattened. Override to preserve structure:

def _prepare_images_structure(self, images, expected_ndims=3):
# Custom logic to handle nested structure
return images # Return as-is or with custom processing

Custom Parameters
-----------------

To add parameters beyond `ImagesKwargs`, create a custom kwargs class and set it as `valid_kwargs`:

class MyImageProcessorKwargs(ImagesKwargs):
custom_param: Optional[int] = None
another_param: Optional[bool] = None

class MyImageProcessorFast(BaseImageProcessorFast):
valid_kwargs = MyImageProcessorKwargs
custom_param = 10 # default value

def _preprocess(self, images, custom_param, **kwargs):
# Use custom_param in processing
...

Key Notes
---------

- Images in `_preprocess` are always torch tensors with channel dimension first, regardless of input format
- Arguments not provided by users default to class attribute values
- Use batch processing utilities (`group_images_by_shape`, `reorder_images`) for GPU efficiency
- Image loading, format conversion, and argument handling are automatic - focus only on processing logic
"""

resample = None
image_mean = None
image_std = None
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/auto/image_processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@
("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
("perception_lm", (None, "PerceptionLMImageProcessorFast")),
("phi4_multimodal", (None, "Phi4MultimodalImageProcessorFast")),
("pix2struct", ("Pix2StructImageProcessor", None)),
("pix2struct", ("Pix2StructImageProcessor", "Pix2StructImageProcessorFast")),
("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
("poolformer", ("PoolFormerImageProcessor", "PoolFormerImageProcessorFast")),
("prompt_depth_anything", ("PromptDepthAnythingImageProcessor", "PromptDepthAnythingImageProcessorFast")),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/pix2struct/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
if TYPE_CHECKING:
from .configuration_pix2struct import *
from .image_processing_pix2struct import *
from .image_processing_pix2struct_fast import *
from .modeling_pix2struct import *
from .processing_pix2struct import *
else:
Expand Down
Loading