2 changes: 0 additions & 2 deletions tests/singlecard/test_offline_inference.py
@@ -60,8 +60,6 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:


@pytest.mark.parametrize("model", MULTIMODALITY_MODELS)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="qwen2.5_vl is not supported on v1")
def test_multimodal(model, prompt_template, vllm_runner):
image = ImageAsset("cherry_blossom") \
.pil_image.convert("RGB")
273 changes: 268 additions & 5 deletions vllm_ascend/worker/model_runner_v1.py
@@ -42,8 +42,11 @@
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
@@ -61,6 +64,9 @@
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.utils import (gather_mm_placeholders,
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention import AttentionMaskBuilder
@@ -362,6 +368,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
@@ -374,6 +381,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
if req_index is not None:
removed_req_indices.append(req_index)

# Free the cached encoder outputs.
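# NOTE: self.encoder_cache maps req_id -> {input_id -> encoder output};
# once all encoder inputs of a request have been freed, the request's
# entry is dropped from the cache as well.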
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)

# Remove the unscheduled requests from the persistent batch.
# NOTE(woosuk): The unscheduled requests are either preempted requests
# or running requests that are not scheduled in this step. We remove
@@ -415,6 +430,43 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
lora_request=new_req_data.lora_request,
)

# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.uses_mrope:
image_grid_thw = []
video_grid_thw = []
second_per_grid_ts = []
audio_feature_lengths = []
use_audio_in_video = False
for mm_input in self.requests[req_id].mm_inputs:
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend(
mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend(
mm_input["video_grid_thw"].tolist())
if mm_input.get("second_per_grid_ts") is not None:
second_per_grid_ts.extend(
mm_input["second_per_grid_ts"])
if mm_input.get("audio_feature_lengths") is not None:
audio_feature_lengths.extend(
mm_input["audio_feature_lengths"])
if mm_input.get("use_audio_in_video") is True:
use_audio_in_video = True

hf_config = self.model_config.hf_config

self.requests[req_id].mrope_positions, \
self.requests[req_id].mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
self.requests[req_id].prompt_token_ids,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
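# NOTE: for Qwen2-VL-style models, mrope_positions is expected to be a
# (3, num_prompt_tokens) tensor of temporal/height/width position ids,
# and mrope_position_delta is the offset used to extend positions past
# the prompt (see _calc_mrope_positions below).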

req_ids_to_add.append(req_id)

# Update the states of the running/resumed requests.
@@ -535,6 +587,166 @@ def _make_attention_mask(self, seq_lens, query_lens, position,
else:
return None

def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
mrope_pos_ptr = 0
for index, req_id in enumerate(self.input_batch.req_ids):
req = self.requests[req_id]
assert req.mrope_positions is not None

num_computed_tokens = \
self.input_batch.num_computed_tokens_cpu[index]
num_scheduled_tokens = \
scheduler_output.num_scheduled_tokens[req_id]
num_prompt_tokens = len(req.prompt_token_ids)

if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
prompt_part_len = max(0,
num_prompt_tokens - num_computed_tokens)
completion_part_len = max(
0, num_scheduled_tokens - prompt_part_len)
else:
prompt_part_len = num_scheduled_tokens
completion_part_len = 0

assert num_scheduled_tokens == prompt_part_len + completion_part_len
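# Example (hypothetical numbers): with num_prompt_tokens=8,
# num_computed_tokens=6 and num_scheduled_tokens=4, the first
# prompt_part_len = 8 - 6 = 2 scheduled tokens reuse the pre-computed
# prompt positions, and the remaining completion_part_len = 4 - 2 = 2
# tokens have their positions computed on the fly below.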

if prompt_part_len > 0:
# prompt's mrope_positions are pre-computed
dst_start = mrope_pos_ptr
dst_end = mrope_pos_ptr + prompt_part_len
src_start = num_computed_tokens
src_end = num_computed_tokens + prompt_part_len

self.mrope_positions_cpu[:, dst_start:dst_end] = \
req.mrope_positions[:,src_start:src_end]

mrope_pos_ptr += prompt_part_len

if completion_part_len > 0:
# compute completion's mrope_positions on-the-fly
dst_start = mrope_pos_ptr
dst_end = mrope_pos_ptr + completion_part_len

self.mrope_positions_cpu[:, dst_start:dst_end] = \
MRotaryEmbedding.get_next_input_positions_tensor(
req.mrope_position_delta,
context_len=num_computed_tokens +
prompt_part_len,
seq_len=num_computed_tokens +
prompt_part_len +
completion_part_len,
)

mrope_pos_ptr += completion_part_len

def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return

# Batch the multi-modal inputs.
mm_inputs = list[MultiModalKwargs]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]

for mm_input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))

# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
# we process it separately to preserve item order.
# FIXME(ywang96): This is a hacky way to deal with multiple modalities
# in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
device=self.device)

# Run the encoder.
# `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size)
# in case feature_size is fixed across all multimodal items.
# 2. A list or tuple (length: num_items) of tensors, each of shape
# (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items.
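# Example (hypothetical sizes): two images with a fixed feature size
# of 1024 and hidden size 4096 come back as a (2, 1024, 4096) tensor,
# while two images with different feature sizes come back as a list
# of (feature_size_i, 4096) tensors.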
curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs)

sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=len(grouped_mm_inputs),
)

for output in curr_group_outputs:
encoder_outputs.append(output)

# Cache the encoder outputs.
for (req_id, input_id, pos_info), output in zip(
req_ids_pos,
encoder_outputs,
):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}

self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)

def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length

# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
# num_computed_tokens + num_scheduled_tokens) and
# [start_pos, start_pos + num_encoder_tokens)
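# Example (hypothetical numbers): num_computed_tokens=10,
# num_scheduled_tokens=5 and a placeholder at start_pos=12 with
# num_encoder_tokens=6 give ranges [10, 15) and [12, 18); they
# overlap, so below we take encoder_output[0:3]
# (start_idx = max(10 - 12, 0) = 0, end_idx = min(10 - 12 + 5, 6) = 3).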
if start_pos >= num_computed_tokens + num_scheduled_tokens:
# The encoder output is not needed in this step.
break
if start_pos + num_encoder_tokens <= num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue

start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens)
assert start_idx < end_idx
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]

if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]

mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
return mm_embeds

def _process_reqs(
self,
scheduler_output: "SchedulerOutput",
@@ -594,6 +806,17 @@ def _process_reqs(
arange,
out=positions_np)

# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)

if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
non_blocking=True)

self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
positions = self.positions[:num_input_tokens]
@@ -706,6 +929,43 @@ def _process_reqs(
input_ids = self.input_ids[:padded_batch_size]
positions = self.positions[:padded_batch_size]

# Prepare multimodal embeddings and M-RoPE positions for multimodal LLMs.
num_input_tokens = total_num_scheduled_tokens
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
if self.is_multimodal_model:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
else:
mm_embeds = []

if self.is_multimodal_model:
Collaborator:
I think this if case can be merged into L782 as well.

Collaborator (Author):
I think these can indeed be merged, but upstream vLLM keeps them separate. The design seems meant to keep two concepts distinct: running the encoder (the ViT part) of the multimodal model, and converting the placeholder embeddings into the actual image embeddings. Should I merge the two if cases, or stay consistent with upstream vLLM?

Collaborator:
Makes sense; keeping it the same as upstream is better for maintenance.

# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids = self.input_ids[:num_input_tokens]
if mm_embeds:
inputs_embeds = self.model.get_input_embeddings(
input_ids, mm_embeds)
else:
inputs_embeds = self.model.get_input_embeddings(input_ids)
# TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds[:num_input_tokens].copy_(inputs_embeds)
inputs_embeds = self.inputs_embeds[:num_input_tokens]
input_ids = None
else:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
else:
positions = self.positions[:num_input_tokens]

# Run forward pass
with set_forward_context(attn_metadata,
self.vllm_config,
@@ -722,7 +982,7 @@ def _process_reqs(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
else:
@@ -731,7 +991,7 @@ def _process_reqs(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
inputs_embeds=inputs_embeds,
**model_kwargs,
)

@@ -1214,8 +1474,11 @@ def _dummy_run(
return hidden_states

def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache.
self._profile_multimodal()
# FIXME: Profile with the multimodal encoder & encoder cache.
# The current _profile_multimodal() uses the PyTorch SDPA backend,
# which does not support windowed/full attention to reduce memcpy
# operations, so it runs out of memory. We therefore skip
# self._profile_multimodal() for now.
# self._profile_multimodal()
Collaborator (@wangxiyuan, May 12, 2025):
Can _profile_multimodal be deleted directly, or should we add more checks like the CUDA path in vLLM?

Collaborator (Author):
I think it can just be deleted for now. The dummy data created during the forward pass has a large peak memory usage, which easily causes out-of-memory errors on a 910B3 when I use the same profiling code as gpu_model_runner.py (I tried that code on my machine but did not commit it in this pull request). With the xformers backend it does not run out of memory on my RTX 4090. I still have not found the root cause, but every VisionBlock forward in the vision transformer increases HBM usage by a large amount, whereas in upstream vLLM the same VisionBlock loop only increases my RTX 4090's memory slightly. On the 910B3 this loop eventually pushes memory usage past 64 GB and raises an out-of-memory error.

Collaborator:
OK, so we should leave the code here. Let's revisit it once the OOM problem is solved. @zouyida2052 please double-check this as well. Thanks.


# For profile, have maximum num_reqs and that collectively have
# maximum num_tokens.