
[VLMs] add helpers to get multimodal encodings #37743

Open · wants to merge 4 commits into base: main
20 changes: 19 additions & 1 deletion src/transformers/models/aria/modeling_aria.py
@@ -1386,8 +1386,26 @@ def get_image_features(
self,
pixel_values: torch.FloatTensor,
pixel_mask: Optional[torch.FloatTensor] = None,
vision_feature_layer: int = -1,
vision_feature_layer: Optional[int] = None,
):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.

Args:
pixel_values (`torch.FloatTensor`):
The tensors corresponding to the input images.
pixel_mask (`torch.FloatTensor`, *optional*):
The tensors corresponding to the input image mask.
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
image_outputs = self.vision_tower(
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
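For reference, a minimal usage sketch of the helper above (not part of the diff). The checkpoint id, the image path, and the image processor returning a pixel mask are assumptions:

```python
import torch
from PIL import Image
from transformers import AriaForConditionalGeneration, AutoProcessor

model_id = "rhymes-ai/Aria"  # assumed checkpoint id
processor = AutoProcessor.from_pretrained(model_id)
model = AriaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Preprocess one image with the image processor directly; a pixel mask is assumed
# to be part of its output (pass None otherwise).
image_inputs = processor.image_processor(images=Image.open("example.jpg"), return_tensors="pt")

# Depending on the transformers version the helper may live on `model` or on `model.model`.
image_features = model.get_image_features(
    pixel_values=image_inputs["pixel_values"].to(model.device, model.dtype),
    pixel_mask=image_inputs.get("pixel_mask"),
    vision_feature_layer=None,  # None now falls back to config.vision_feature_layer
)
# image_features: (num_images, image_length, embed_dim)
```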
20 changes: 19 additions & 1 deletion src/transformers/models/aria/modular_aria.py
@@ -1441,8 +1441,26 @@ def get_image_features(
self,
pixel_values: torch.FloatTensor,
pixel_mask: Optional[torch.FloatTensor] = None,
vision_feature_layer: int = -1,
vision_feature_layer: Optional[int] = None,
):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.

Args:
pixel_values (`torch.FloatTensor`):
The tensors corresponding to the input images.
pixel_mask (`torch.FloatTensor`, *optional*):
The tensors corresponding to the input image mask.
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
image_outputs = self.vision_tower(
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
17 changes: 13 additions & 4 deletions src/transformers/models/aya_vision/modeling_aya_vision.py
@@ -299,8 +299,8 @@ def get_decoder(self):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
**kwargs,
):
"""
@@ -309,16 +309,25 @@ def get_image_features(
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
vision_feature_layer (`Union[int, List[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)

if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")

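A hedged sketch of what the relaxed signature allows (not part of the diff), assuming the `CohereForAI/aya-vision-8b` checkpoint, a local image, and that the helper is exposed on the top-level model; both vision arguments can now be omitted and fall back to the config:

```python
import torch
from PIL import Image
from transformers import AutoProcessor, AyaVisionForConditionalGeneration

model_id = "CohereForAI/aya-vision-8b"  # assumed checkpoint id
processor = AutoProcessor.from_pretrained(model_id)
model = AyaVisionForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)

pixel_values = processor.image_processor(images=Image.open("example.jpg"), return_tensors="pt")["pixel_values"]
pixel_values = pixel_values.to(model.device, model.dtype)

# Before this change both vision arguments were required; now the config defaults apply.
image_features = model.get_image_features(pixel_values)

# Equivalent explicit call:
image_features = model.get_image_features(
    pixel_values,
    vision_feature_layer=model.config.vision_feature_layer,
    vision_feature_select_strategy=model.config.vision_feature_select_strategy,
)
```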
78 changes: 48 additions & 30 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -2165,6 +2165,50 @@ def _preprocess_accelerate(self):
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility

def get_image_features(
self,
pixel_values: torch.FloatTensor,
interpolate_pos_encoding: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
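interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings of the vision model.
return_dict (`bool`, *optional*, defaults to `False`):
Whether to also return the vision encoder and Q-Former outputs alongside the projected image features.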
"""
# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=True,
)
image_embeds = vision_outputs[0]

# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs[0]

# Qformer is kept in fp32, we downcast the output back if needed
if query_output.dtype != image_embeds.dtype:
query_output = query_output.to(image_embeds.dtype)

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
if return_dict:
return language_model_inputs, vision_outputs, query_outputs
return language_model_inputs

@add_start_docstrings_to_model_forward(BLIP_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Blip2ForConditionalGenerationModelOutput, config_class=Blip2VisionConfig)
def forward(
@@ -2245,37 +2289,11 @@ def forward(
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
language_model_inputs, vision_outputs, query_outputs = self.get_image_features(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=True
)
image_embeds = vision_outputs[0]

# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0]

# Qformer is kept in fp32, we downcast the output back if needed
if query_output.dtype != image_embeds.dtype:
query_output = query_output.to(image_embeds.dtype)

# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
vision_outputs = vision_outputs.to_tuple() if not return_dict else vision_outputs
query_outputs = query_outputs.to_tuple() if not return_dict else query_outputs
language_model_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
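Illustrative sketch of the extracted BLIP-2 helper (not part of the diff), assuming the `Salesforce/blip2-opt-2.7b` checkpoint and a local image; it shows both return modes:

```python
import torch
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

model_id = "Salesforce/blip2-opt-2.7b"  # assumed checkpoint id
processor = Blip2Processor.from_pretrained(model_id)
model = Blip2ForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)

pixel_values = processor(images=Image.open("example.jpg"), return_tensors="pt").pixel_values
pixel_values = pixel_values.to(model.device, torch.float16)

# Projected Q-Former queries, shape (batch_size, num_query_tokens, lm_hidden_size).
language_model_inputs = model.get_image_features(pixel_values)

# With return_dict=True the raw vision and Q-Former outputs are returned as well,
# which is what `forward` now consumes internally.
language_model_inputs, vision_outputs, query_outputs = model.get_image_features(
    pixel_values, return_dict=True
)
```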
8 changes: 7 additions & 1 deletion src/transformers/models/chameleon/modeling_chameleon.py
@@ -1221,6 +1221,12 @@ def set_input_embeddings(self, value):
self.embed_tokens = value

def get_image_tokens(self, pixel_values: torch.FloatTensor):
logger.warning(
"`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete image tokens, use `model.get_image_features()` instead."
)
return self.get_image_features(pixel_values)

def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Tokenizes images into discrete tokens with VQGAN module. Converts
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
@@ -1279,7 +1285,7 @@ def forward(
)

if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values)
image_tokens = self.get_image_features(pixel_values)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
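A sketch of the renamed Chameleon helper (not part of the diff), assuming the `facebook/chameleon-7b` checkpoint, a local image, and that the helper lives on the base model; unlike the continuous-embedding models above, Chameleon returns discrete token ids:

```python
import torch
from PIL import Image
from transformers import ChameleonForConditionalGeneration, ChameleonProcessor

model_id = "facebook/chameleon-7b"  # assumed checkpoint id
processor = ChameleonProcessor.from_pretrained(model_id)
model = ChameleonForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)

pixel_values = processor.image_processor(images=Image.open("example.jpg"), return_tensors="pt")["pixel_values"]
pixel_values = pixel_values.to(model.device, model.dtype)

# Discrete BPE token ids from the VQGAN quantizer, wrapped with "boi"/"eoi" tokens;
# the helper sits on the base model here, hence `model.model`.
image_tokens = model.model.get_image_features(pixel_values)

# The old name still works but now emits the deprecation warning added above.
image_tokens = model.model.get_image_tokens(pixel_values)
```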
8 changes: 7 additions & 1 deletion src/transformers/models/emu3/modeling_emu3.py
@@ -1802,6 +1802,12 @@ def set_input_embeddings(self, value):
self.text_model.set_input_embeddings(value)

def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
logger.warning(
"`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete image tokens, use `model.get_image_features()` instead."
)
return self.get_image_features(pixel_values, image_sizes)

def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
"""
Tokenizes images into discrete tokens with VQGAN module. Converts
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
@@ -1922,7 +1928,7 @@ def forward(
)

if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
image_tokens = self.get_image_features(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
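The Emu3 variant follows the same rename but additionally requires `image_sizes`; the sketch below (not part of the diff) assumes the `BAAI/Emu3-Chat-hf` checkpoint, a local image, and that the image processor returns `image_sizes`:

```python
from PIL import Image
from transformers import Emu3ForConditionalGeneration, Emu3Processor

model_id = "BAAI/Emu3-Chat-hf"  # assumed checkpoint id
processor = Emu3Processor.from_pretrained(model_id)
model = Emu3ForConditionalGeneration.from_pretrained(model_id)

image_inputs = processor.image_processor(images=Image.open("example.jpg"), return_tensors="pt")

# Discrete image token ids; the helper sits on the base model, hence `model.model`.
image_tokens = model.model.get_image_features(
    pixel_values=image_inputs["pixel_values"],
    image_sizes=image_inputs["image_sizes"],
)
```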
8 changes: 7 additions & 1 deletion src/transformers/models/emu3/modular_emu3.py
@@ -1141,6 +1141,12 @@ def set_input_embeddings(self, value):
self.text_model.set_input_embeddings(value)

def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
logger.warning(
"`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete image tokens, use `model.get_image_features()` instead."
)
return self.get_image_features(pixel_values, image_sizes)

def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
"""
Tokenizes images into discrete tokens with VQGAN module. Converts
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
@@ -1261,7 +1267,7 @@ def forward(
)

if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
image_tokens = self.get_image_features(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
25 changes: 18 additions & 7 deletions src/transformers/models/fuyu/modeling_fuyu.py
@@ -220,9 +220,25 @@ def gather_continuous_embeddings(
f"Number of continuous embeddings {continuous_embeddings[batch_idx].shape=} does not match "
f"number of continuous token ids {src_indices.shape=} in batch element {batch_idx}."
)
output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices]
output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices].to(
output_embeddings.device
)
return output_embeddings

def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
"""
patch_embeddings = [
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0)
for patch in pixel_values
]
return patch_embeddings

@add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -308,12 +324,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if image_patches is not None and past_key_values is None:
patch_embeddings = [
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype))
.squeeze(0)
.to(inputs_embeds.device)
for patch in image_patches
]
patch_embeddings = self.get_image_features(image_patches)
inputs_embeds = self.gather_continuous_embeddings(
word_embeddings=inputs_embeds,
continuous_embeddings=patch_embeddings,
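Finally, a sketch of the extracted Fuyu helper (not part of the diff), assuming the `adept/fuyu-8b` checkpoint and a local image; it returns a list with one patch-embedding tensor per image, which `gather_continuous_embeddings` then scatters into the word embeddings:

```python
import torch
from PIL import Image
from transformers import FuyuForCausalLM, FuyuProcessor

model_id = "adept/fuyu-8b"  # assumed checkpoint id
processor = FuyuProcessor.from_pretrained(model_id)
model = FuyuForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = processor(images=Image.open("example.jpg"), text="Describe the image.\n", return_tensors="pt")

# One patch-embedding tensor per image; the dtype cast happens inside the helper,
# while device placement is now handled in `gather_continuous_embeddings`.
patch_embeddings = model.get_image_features(inputs["image_patches"])
```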