
Commit 567500a

[Fix] Fix support Prithvi online inference using tensor

Signed-off-by: Michele Gazzetti <michele.gazzetti1@ibm.com>

1 parent: f6d7aad

File tree

4 files changed: 28 additions, 3 deletions


vllm/entrypoints/chat_utils.py

Lines changed: 16 additions & 0 deletions

@@ -140,6 +140,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
     ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
     CustomChatCompletionContentSimpleImageParam,
     ChatCompletionContentPartImageEmbedsParam,
+    ChatCompletionContentPartTensorsParam,
     CustomChatCompletionContentSimpleAudioParam,
     CustomChatCompletionContentSimpleVideoParam, str]
 
@@ -583,6 +584,8 @@ def _placeholder_str(self, modality: ModalityStr,
                 return self._cached_token_str(self._tokenizer,
                                               hf_config.video_token_index)
             raise TypeError(f"Unknown {modality} model type: {model_type}")
+        elif modality == "tensors":
+            return None
         else:
             raise TypeError(f"Unknown modality: {modality}")
 
@@ -641,6 +644,13 @@ def all_mm_data(self) -> Optional[MultiModalDataDict]:
                 raise ValueError(\
                     "Only one message can have {'type': 'image_embeds'}")
             mm_inputs["image"] = image_embeds_lst[0]
+
+        if "tensors" in items_by_modality:
+            tensors_lst = items_by_modality["tensors"]
+            if len(tensors_lst) > 1:
+                raise ValueError(\
+                    "Only one message can have {'type': 'tensors'}")
+            mm_inputs["tensors"] = tensors_lst[0]
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
         if "audio" in items_by_modality:
@@ -674,6 +684,12 @@ async def all_mm_data(self) -> Optional[MultiModalDataDict]:
                 raise ValueError(
                     "Only one message can have {'type': 'image_embeds'}")
             mm_inputs["image"] = image_embeds_lst[0]
+        if "tensors" in items_by_modality:
+            tensors_lst = items_by_modality["tensors"]
+            if len(tensors_lst) > 1:
+                raise ValueError(\
+                    "Only one message can have {'type': 'tensors'}")
+            mm_inputs["tensors"] = tensors_lst[0]
         if "image" in items_by_modality:
             mm_inputs["image"] = items_by_modality["image"]  # A list of images
         if "audio" in items_by_modality:

vllm/entrypoints/openai/protocol.py

Lines changed: 1 addition & 1 deletion

@@ -1109,7 +1109,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
     model: Optional[str] = None
     messages: list[ChatCompletionMessageParam]
 
-    encoding_format: Literal["float", "base64", "tensors"] = "float"
+    encoding_format: Literal["float", "base64", "tensor"] = "float"
     dimensions: Optional[int] = None
     user: Optional[str] = None
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
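
With the literal fixed to "tensor", a pooling request can select the new encoding. A minimal sketch, assuming vLLM's /pooling route and reusing the message dict from the sketch above; the model name is a placeholder:

    import requests

    body = {
        "model": "my-prithvi-model",  # placeholder model name
        "messages": [message],        # chat message from the previous sketch
        "encoding_format": "tensor",
    }
    response = requests.post("http://localhost:8000/pooling", json=body)
    print(response.json())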

vllm/entrypoints/openai/serving_pooling.py

Lines changed: 2 additions & 2 deletions

@@ -44,9 +44,9 @@ def _get_data(
         pt_float32 = output.data.to(dtype=torch.float32)
         pooling_bytes = np.array(pt_float32, dtype="float32").tobytes()
         return base64.b64encode(pooling_bytes).decode("utf-8")
-    elif encoding_format == "tensors":
+    elif encoding_format == "tensor":
         tensor_encoding_io = ImageEmbeddingMediaIO()
-        tensor_encoding_io.encode_base64(output.data)
+        return tensor_encoding_io.encode_base64(output.data)
 
     assert_never(encoding_format)
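
Before this fix, the tensor branch computed the encoding but fell through to assert_never because the return was missing (and the literal disagreed with the request schema). The returned payload is the raw numpy buffer from encode_base64, so the client must know dtype and shape out of band; a minimal decoding sketch, assuming float32:

    import base64

    import numpy as np

    def decode_raw_tensor(b64_data: str, dtype=np.float32) -> np.ndarray:
        # encode_base64 ships raw array bytes; dtype and shape are out of band
        return np.frombuffer(base64.b64decode(b64_data), dtype=dtype)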

vllm/multimodal/image.py

Lines changed: 9 additions & 0 deletions

@@ -95,3 +95,12 @@ def load_file(self, filepath: Path) -> torch.Tensor:
 
     def encode_base64(self, media: torch.Tensor) -> str:
         return base64.b64encode(media.numpy()).decode('utf-8')
+
+    # currently not used, but it makes it easy for users to reconstruct
+    # the result tensor without knowledge of the array shape
+    def encode_tensor(self, media: torch.Tensor) -> str:
+        buffer_tiff = BytesIO()
+        torch.save(media.data, buffer_tiff)
+        buffer_tiff.seek(0)
+        binary_data = buffer_tiff.read()
+        return base64.b64encode(binary_data).decode('utf-8')
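
As the in-diff comment says, encode_tensor wraps the tensor with torch.save, so the payload carries shape and dtype and a client can rebuild the tensor with torch.load alone. A round-trip sketch (decode_tensor is a hypothetical client helper, not part of this commit):

    import base64
    from io import BytesIO

    import torch

    def decode_tensor(b64_data: str) -> torch.Tensor:
        # torch.save embedded shape and dtype, so torch.load restores both
        return torch.load(BytesIO(base64.b64decode(b64_data)))

    original = torch.randn(2, 3)
    buffer = BytesIO()
    torch.save(original, buffer)
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    assert torch.equal(decode_tensor(encoded), original)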
