79 changes: 71 additions & 8 deletions vllm_ascend/worker/model_runner_v1.py
@@ -67,7 +67,8 @@
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LazyLoader, cdiv, get_dtype_size,
is_pin_memory_available)
is_pin_memory_available,
length_from_prompt_token_ids_or_embeds)
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import \
reorder_batch_to_split_decodes_and_prefills
@@ -285,11 +286,14 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):

self.is_multimodal_model = self.model_config.is_multimodal_model
self.is_pooling_model = self.model_config.pooler_config is not None
self.enable_prompt_embeds = self.model_config.enable_prompt_embeds
if self.is_multimodal_model:
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.model_config.get_hidden_size()),
dtype=self.dtype,
device=self.device)
self.inputs_embeds = self._make_buffer(self.max_num_tokens,
self.model_config.get_hidden_size(),
dtype=self.dtype,
numpy=False)
Comment on lines 298 to 304 (Contributor, critical):

The inputs_embeds buffer is only initialized for multimodal models. However, it is also required when enable_prompt_embeds is true for non-multimodal models. Without this initialization, an AttributeError will be raised when self.inputs_embeds is accessed later in methods like _dummy_run or _prepare_input_ids. The condition should be updated to include self.enable_prompt_embeds.

Suggested change:

if self.is_multimodal_model or self.enable_prompt_embeds:
self.inputs_embeds = self._make_buffer(self.max_num_tokens,
self.model_config.get_hidden_size(),
dtype=self.dtype,
numpy=False)
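
A note for the rest of this diff: the buffer returned by self._make_buffer(...) is driven through .cpu, .gpu, and .copy_to_gpu(...) in _prepare_input_ids, _prepare_inputs, and _dummy_run below. A minimal sketch of such a paired host/device buffer, inferred only from those call sites rather than from the actual vLLM helper, could look like this:

import torch

class PairedCpuGpuBuffer:
    # Sketch of a paired host/device buffer matching the .cpu / .gpu /
    # .copy_to_gpu usage in this diff; not the actual vLLM implementation.
    def __init__(self, *shape: int, dtype: torch.dtype,
                 device: torch.device) -> None:
        self.cpu = torch.zeros(shape, dtype=dtype, device="cpu")
        self.gpu = torch.zeros(shape, dtype=dtype, device=device)

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # Copy the first n rows host -> device; non_blocking keeps the host
        # thread free while the transfer is in flight.
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]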

self.is_token_ids = self._make_buffer(self.max_num_tokens,
dtype=torch.bool)

# Set up Attention
self.attn_backend = get_attn_backend(
@@ -572,6 +576,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt_embeds=new_req_data.prompt_embeds,
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
@@ -843,7 +848,8 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
self.input_batch.num_computed_tokens_cpu[index]
num_scheduled_tokens = \
scheduler_output.num_scheduled_tokens[req_id]
num_prompt_tokens = len(req.prompt_token_ids)
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
req.prompt_token_ids, req.prompt_embeds)

if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
prompt_part_len = max(0,
@@ -1016,6 +1022,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens],
non_blocking=True)
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
return

# Async scheduling case, where some decode requests from the previous
@@ -1043,6 +1051,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens],
non_blocking=True)
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
if num_commmon_tokens == 0:
# No requests in common with the previous iteration
# So input_ids_cpu will have all the input ids.
@@ -1056,6 +1066,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
0],
non_blocking=True)
self.is_token_ids.gpu[:num_commmon_tokens] = True
return
# Upload the index tensors asynchronously
# so the scatter can be non-blocking.
@@ -1195,15 +1206,60 @@ def _prepare_inputs(
# where M is the max_model_len.
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])

token_indices_tensor = torch.from_numpy(token_indices)
# Prepare input_ids.
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices),
token_indices_tensor,
out=self.input_ids_cpu[:total_num_scheduled_tokens])
is_token_ids = self.input_batch.is_token_ids.flatten()
torch.index_select(
is_token_ids,
0,
token_indices_tensor,
out=self.is_token_ids.cpu[:total_num_scheduled_tokens])

# Because we did not pre-allocate a massive prompt_embeds CPU tensor on
# the InputBatch, we need to fill in the prompt embeds into the expected
# spots in the model runner's pre-allocated inputs_embeds buffer.
if self.input_batch.req_prompt_embeds:
output_idx = 0
for req_idx in range(num_reqs):
num_sched = num_scheduled_tokens[req_idx]

# Skip if this request doesn't have embeddings
if req_idx not in self.input_batch.req_prompt_embeds:
output_idx += num_sched
continue

# Skip if no tokens scheduled
if num_sched <= 0:
output_idx += num_sched
continue

req_embeds = self.input_batch.req_prompt_embeds[req_idx]
start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]

# Skip if trying to read beyond available embeddings
if start_pos >= req_embeds.shape[0]:
output_idx += num_sched
continue

# Copy available embeddings
end_pos = start_pos + num_sched
actual_end = min(end_pos, req_embeds.shape[0])
actual_num_sched = actual_end - start_pos

if actual_num_sched > 0:
self.inputs_embeds.cpu[output_idx:output_idx +
actual_num_sched].copy_(
req_embeds[start_pos:actual_end]
)

output_idx += num_sched

# Prepare some information for building Attention-Metadata
# Compute and commit slot mapping
@@ -1985,6 +2041,7 @@ def execute_model(

self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
req_id = self.input_batch.req_ids[req_idx]
@@ -2200,6 +2257,9 @@ def _dummy_run(
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
elif self.enable_prompt_embeds:
input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
@@ -3070,6 +3130,9 @@ def _get_prompt_logprobs_dict(

# Get metadata for this request.
request = self.requests[req_id]
if request.prompt_token_ids is None:
# Prompt logprobs are incompatible with prompt embeddings
continue
num_prompt_tokens = len(request.prompt_token_ids)
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
self.device, non_blocking=True)
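Both files in this PR lean on the newly imported length_from_prompt_token_ids_or_embeds helper from vllm.utils. Judging only from its call sites here, it resolves the prompt length from whichever input form a request carries; a plausible sketch (the real helper in vllm.utils may differ) is:

from typing import Optional

import torch

def length_from_prompt_token_ids_or_embeds(
        prompt_token_ids: Optional[list[int]],
        prompt_embeds: Optional[torch.Tensor]) -> int:
    # A request provides either token ids or precomputed embeddings of shape
    # (num_prompt_tokens, hidden_size); the prompt length comes from whichever
    # one is present.
    if prompt_token_ids is not None:
        return len(prompt_token_ids)
    assert prompt_embeds is not None, "request has neither token ids nor embeds"
    return prompt_embeds.shape[0]
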
55 changes: 48 additions & 7 deletions vllm_ascend/worker/npu_input_batch.py
@@ -28,7 +28,7 @@
PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
@@ -45,7 +45,7 @@
class CachedRequestState:

req_id: str
prompt_token_ids: list[int]
prompt_token_ids: Optional[list[int]]
mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange]
mm_hashes: list[str]
@@ -61,9 +61,11 @@ class CachedRequestState:
mrope_position_delta: Optional[int] = None

lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None

def __post_init__(self):
self.num_prompt_tokens = len(self.prompt_token_ids)
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)

@property
def num_tokens(self) -> int:
@@ -78,6 +80,10 @@ def mm_inputs(self) -> list[MultiModalKwargs]:

def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
if self.prompt_token_ids is None:
raise ValueError(
f"Tried to access token index {idx}, but that token was "
"provided via prompt_embeds, and its ID is unknown.")
return self.prompt_token_ids[idx]
else:
return self.output_token_ids[idx - self.num_prompt_tokens]
@@ -122,6 +128,14 @@ def __init__(
pin_memory=False,
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
device="cpu",
dtype=bool,
pin_memory=False)
# Store prompt embeddings per request to avoid OOM from large upfront
# allocation if max_model_len is big.
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
@@ -326,15 +340,23 @@ def add_request(
self.req_id_to_index[req_id] = req_index

# Copy the prompt token ids and output token ids.
num_prompt_tokens = len(request.prompt_token_ids)
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds)
self.num_prompt_tokens[req_index] = num_prompt_tokens
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
if request.prompt_token_ids is not None:
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
self.is_token_ids[req_index, :num_prompt_tokens] = True
else:
self.is_token_ids[req_index, :num_prompt_tokens] = False
if request.prompt_embeds is not None:
self.req_prompt_embeds[req_index] = request.prompt_embeds
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
# Number of token ids in token_ids_cpu.
self.is_token_ids[req_index, start_idx:end_idx] = True
# Number of token ids in prompt (token_ids_cpu or prompt_embeds).
# NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens
# Number of tokens without spec decode tokens.
@@ -534,6 +556,20 @@ def swap_states(self, i1: int, i2: int) -> None:
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp

self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]

# Swap prompt embeddings if they exist
embeds_i1 = self.req_prompt_embeds.get(i1)
embeds_i2 = self.req_prompt_embeds.get(i2)
if embeds_i1 is not None:
self.req_prompt_embeds[i2] = embeds_i1
else:
self.req_prompt_embeds.pop(i2, None)
if embeds_i2 is not None:
self.req_prompt_embeds[i1] = embeds_i2
else:
self.req_prompt_embeds.pop(i1, None)

swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)

@@ -612,6 +648,11 @@ def condense(self) -> None:
num_tokens = self.num_tokens[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens]
self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
last_req_index, :num_tokens]
if last_req_index in self.req_prompt_embeds:
self.req_prompt_embeds[
empty_index] = self.req_prompt_embeds.pop(last_req_index)
self.num_tokens[empty_index] = num_tokens
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index]
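The explicit req_prompt_embeds swap added to swap_states follows the same pattern as the existing swap_dict_values utility used there for generators and bad_words_token_ids. Assuming that utility simply exchanges the values stored under two keys and drops an entry when the other side has none, the pattern looks like this (function name is illustrative, not the vllm.utils implementation):

def swap_values_between_keys(d: dict, i1: int, i2: int) -> None:
    # Exchange the entries stored under keys i1 and i2; an entry present on
    # only one side moves to the other key, mirroring the manual
    # req_prompt_embeds handling in swap_states above.
    v1, v2 = d.pop(i1, None), d.pop(i2, None)
    if v1 is not None:
        d[i2] = v1
    if v2 is not None:
        d[i1] = v2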