diff --git a/tests/e2e/singlecard/test_completion_with_prompt_embeds.py b/tests/e2e/singlecard/test_completion_with_prompt_embeds.py new file mode 100644 index 0000000000..02ba41d72d --- /dev/null +++ b/tests/e2e/singlecard/test_completion_with_prompt_embeds.py @@ -0,0 +1,413 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import base64 +import io + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio +import torch +# downloading lora to test lora requests +from openai import BadRequestError +from transformers import AutoConfig + +from ..utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "facebook/opt-125m" + +CONFIG = AutoConfig.from_pretrained(MODEL_NAME) + +class RemoteOpenAIServer: + DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key + + def _start_server(self, model: str, vllm_serve_args: list[str], + env_dict: Optional[dict[str, str]]) -> None: + """Subclasses override this method to customize server process launch + """ + env = os.environ.copy() + # the current process might initialize cuda, + # to be safe, we should use spawn method + env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + if env_dict is not None: + env.update(env_dict) + self.proc: subprocess.Popen = subprocess.Popen( + ["vllm", "serve", model, *vllm_serve_args], + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + + def __init__(self, + model: str, + vllm_serve_args: list[str], + *, + env_dict: Optional[dict[str, str]] = None, + seed: Optional[int] = 0, + auto_port: bool = True, + max_wait_seconds: Optional[float] = None, + override_hf_configs: Optional[dict[str, Any]] = None) -> None: + if auto_port: + if "-p" in vllm_serve_args or "--port" in vllm_serve_args: + raise ValueError("You have manually specified the port " + "when `auto_port=True`.") + + # No need for a port if using unix sockets + if "--uds" not in vllm_serve_args: + # Don't mutate the input args + vllm_serve_args = vllm_serve_args + [ + "--port", str(get_open_port()) + ] + if seed is not None: + if "--seed" in vllm_serve_args: + raise ValueError("You have manually specified the seed " + f"when `seed={seed}`.") + + vllm_serve_args = vllm_serve_args + ["--seed", str(seed)] + + if override_hf_configs is not None: + vllm_serve_args = vllm_serve_args + [ + "--hf-overrides", + json.dumps(override_hf_configs) + ] + + parser = FlexibleArgumentParser( + description="vLLM's remote OpenAI server.") + subparsers = parser.add_subparsers(required=False, dest="subparser") + parser = ServeSubcommand().subparser_init(subparsers) + args = parser.parse_args(["--model", model, *vllm_serve_args]) + self.uds = args.uds + if args.uds: + self.host = None + self.port = None + else: + self.host = str(args.host or 'localhost') + self.port = int(args.port) + + self.show_hidden_metrics = \ + args.show_hidden_metrics_for_version is not None + + # download the model before starting the server to avoid timeout + is_local = os.path.isdir(model) + if not is_local: + engine_args = AsyncEngineArgs.from_cli_args(args) + model_config = engine_args.create_model_config() + load_config = engine_args.create_load_config() + + model_loader = get_model_loader(load_config) + model_loader.download_model(model_config) + + self._start_server(model, vllm_serve_args, env_dict) + max_wait_seconds = max_wait_seconds or 240 + self._wait_for_server(url=self.url_for("health"), + timeout=max_wait_seconds) + + def 
__enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.proc.terminate() + try: + self.proc.wait(8) + except subprocess.TimeoutExpired: + # force kill if needed + self.proc.kill() + + def _poll(self) -> Optional[int]: + """Subclasses override this method to customize process polling""" + return self.proc.poll() + + def _wait_for_server(self, *, url: str, timeout: float): + # run health check + start = time.time() + client = (httpx.Client(transport=httpx.HTTPTransport( + uds=self.uds)) if self.uds else requests) + while True: + try: + if client.get(url).status_code == 200: + break + except Exception: + # this exception can only be raised by requests.get, + # which means the server is not ready yet. + # the stack trace is not useful, so we suppress it + # by using `raise from None`. + result = self._poll() + if result is not None and result != 0: + raise RuntimeError("Server exited unexpectedly.") from None + + time.sleep(0.5) + if time.time() - start > timeout: + raise RuntimeError( + "Server failed to start in time.") from None + + @property + def url_root(self) -> str: + return (f"http://{self.uds.split('/')[-1]}" + if self.uds else f"http://{self.host}:{self.port}") + + def url_for(self, *parts: str) -> str: + return self.url_root + "/" + "/".join(parts) + + def get_client(self, **kwargs): + if "timeout" not in kwargs: + kwargs["timeout"] = 600 + return openai.OpenAI( + base_url=self.url_for("v1"), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs, + ) + + def get_async_client(self, **kwargs): + if "timeout" not in kwargs: + kwargs["timeout"] = 600 + return openai.AsyncOpenAI(base_url=self.url_for("v1"), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs) + + +@pytest.fixture(scope="module") +def default_server_args() -> list[str]: + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager", + # Prompt Embeds server args + "--enable-prompt-embeds", + ] + + +EXAMPLE_PROMPTS = [ + "Hello, my name is", + "What is an LLM?", +] + + +def _encode_embeds(embeds: torch.Tensor): + buffer = io.BytesIO() + torch.save(embeds, buffer) + return base64.b64encode(buffer.getvalue()).decode('utf-8') + + +@pytest.fixture(scope="module") +def example_prompt_embeds(hf_runner): + """Create example embeddings and return them as base64 encoded string.""" + with hf_runner(MODEL_NAME) as hf_model: + example_embeddings = hf_model.get_prompt_embeddings(EXAMPLE_PROMPTS) + + return [_encode_embeds(item) for item in example_embeddings] + + +@pytest.fixture(scope="module", + params=["", "--disable-frontend-multiprocessing"]) +def server_with_prompt_embeds(default_server_args, request): + if request.param: + default_server_args.append(request.param) + + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client_with_prompt_embeds(server_with_prompt_embeds): + async with server_with_prompt_embeds.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_with_prompt_embeds( + example_prompt_embeds, + client_with_prompt_embeds: openai.AsyncOpenAI, + model_name: str, +): + encoded_embeds, encoded_embeds2 = example_prompt_embeds + + # Test case: Single prompt embeds input + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + 
prompt="",  # Add empty prompt as required parameter
+        max_tokens=5,
+        temperature=0.0,
+        extra_body={"prompt_embeds": encoded_embeds})
+    assert len(completion.choices[0].text) >= 1
+    assert completion.choices[0].prompt_logprobs is None
+
+    # Test case: batch completion with prompt_embeds
+    completion = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="",  # Add empty prompt as required parameter
+        max_tokens=5,
+        temperature=0.0,
+        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
+    assert len(completion.choices) == 2
+    assert len(completion.choices[0].text) >= 1
+    assert len(completion.choices[1].text) >= 1
+
+    # Test case: streaming with prompt_embeds
+    single_completion = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="",  # Add empty prompt as required parameter
+        max_tokens=5,
+        temperature=0.0,
+        extra_body={"prompt_embeds": encoded_embeds})
+    single_output = single_completion.choices[0].text
+
+    stream = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="",  # Add empty prompt as required parameter
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+        extra_body={"prompt_embeds": encoded_embeds})
+    chunks = []
+    finish_reason_count = 0
+    async for chunk in stream:
+        chunks.append(chunk.choices[0].text)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    assert finish_reason_count == 1
+    assert chunk.choices[0].finish_reason == "length"
+    assert chunk.choices[0].text
+    assert "".join(chunks) == single_output
+
+    # Test case: batch streaming with prompt_embeds
+    stream = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="",  # Add empty prompt as required parameter
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
+    chunks_stream_embeds: list[list[str]] = [[], []]
+    finish_reason_count = 0
+    async for chunk in stream:
+        chunks_stream_embeds[chunk.choices[0].index].append(
+            chunk.choices[0].text)
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason_count += 1
+    assert finish_reason_count == 2
+    assert chunk.choices[0].finish_reason == "length"
+    assert chunk.choices[0].text
+    assert len(chunks_stream_embeds[0]) > 0
+    assert len(chunks_stream_embeds[1]) > 0
+
+    # Test case: mixed text and prompt_embeds
+    completion_mixed = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="This is a prompt",
+        max_tokens=5,
+        temperature=0.0,
+        extra_body={"prompt_embeds": encoded_embeds})
+    assert len(completion_mixed.choices) == 2
+    completion_text_only = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="This is a prompt",
+        max_tokens=5,
+        temperature=0.0,
+    )
+    completion_embeds_only = await client_with_prompt_embeds.completions.create(
+        model=model_name,
+        prompt="",
+        max_tokens=5,
+        temperature=0.0,
+        extra_body={"prompt_embeds": encoded_embeds})
+    # Embeddings responses should be handled first
+    assert completion_mixed.choices[0].text == completion_embeds_only.choices[
+        0].text
+    assert completion_mixed.choices[1].text == completion_text_only.choices[
+        0].text
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_completions_errors_with_prompt_embeds(
+        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
+    # Test error case: invalid prompt_embeds
+    with pytest.raises(BadRequestError):
+        await client_with_prompt_embeds.completions.create(
prompt="", + model=model_name, + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": "invalid_base64"}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("logprobs_arg", [1, 0]) +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_with_logprobs_and_prompt_embeds( + example_prompt_embeds, + client_with_prompt_embeds: openai.AsyncOpenAI, + logprobs_arg: int, + model_name: str, +): + encoded_embeds, encoded_embeds2 = example_prompt_embeds + + # Test case: Logprobs using prompt_embeds + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + echo=False, + logprobs=logprobs_arg, + extra_body={"prompt_embeds": encoded_embeds}) + + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) == 5 + assert len(logprobs.token_logprobs) == 5 + assert len(logprobs.top_logprobs) == 5 + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) == 5 + + # Test case: Log probs with batch completion and prompt_embeds + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + echo=False, + logprobs=logprobs_arg, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + + assert len(completion.choices) == 2 + for choice in completion.choices: + logprobs = choice.logprobs + assert logprobs is not None + assert len(logprobs.text_offset) == 5 + assert len(logprobs.token_logprobs) == 5 + assert len(logprobs.top_logprobs) == 5 + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, + 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) == 5 + + +@pytest.mark.asyncio +async def test_prompt_logprobs_raises_error( + example_prompt_embeds, + client_with_prompt_embeds: openai.AsyncOpenAI, +): + encoded_embeds, _ = example_prompt_embeds + + with pytest.raises(BadRequestError, match="not compatible"): + await client_with_prompt_embeds.completions.create( + model=MODEL_NAME, + prompt="", + max_tokens=5, + temperature=0.0, + extra_body={ + "prompt_embeds": encoded_embeds, + "prompt_logprobs": True + }, + ) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 279b767241..532f424a7c 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -18,14 +18,25 @@ # import functools +import json import os import signal +import subprocess +import sys +import time from collections.abc import Sequence -from typing import Callable +from typing import Any, Callable, Optional +import httpx +import openai +import requests import torch import torch.nn.functional as F from typing_extensions import ParamSpec +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.entrypoints.cli.serve import ServeSubcommand +from vllm.model_executor.model_loader import get_model_loader +from vllm.utils import FlexibleArgumentParser, get_open_port _P = ParamSpec("_P") @@ -104,3 +115,152 @@ def check_embeddings_close( f"\n{name_1}:\t{embeddings_1[:16]!r}") assert sim >= 1 - tol, fail_msg + + +class RemoteOpenAIServer: + DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key + + def _start_server(self, model: str, vllm_serve_args: list[str], + env_dict: Optional[dict[str, str]]) -> None: + """Subclasses override this method to customize server process launch + """ 
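+        # Launch the server through the real `vllm serve` CLI in a separate
+        # process; stdout/stderr are inherited so the server logs show up
+        # directly in the test output.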
+ env = os.environ.copy() + # the current process might initialize cuda, + # to be safe, we should use spawn method + env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + if env_dict is not None: + env.update(env_dict) + self.proc: subprocess.Popen = subprocess.Popen( + ["vllm", "serve", model, *vllm_serve_args], + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + + def __init__(self, + model: str, + vllm_serve_args: list[str], + *, + env_dict: Optional[dict[str, str]] = None, + seed: Optional[int] = 0, + auto_port: bool = True, + max_wait_seconds: Optional[float] = None, + override_hf_configs: Optional[dict[str, Any]] = None) -> None: + if auto_port: + if "-p" in vllm_serve_args or "--port" in vllm_serve_args: + raise ValueError("You have manually specified the port " + "when `auto_port=True`.") + + # No need for a port if using unix sockets + if "--uds" not in vllm_serve_args: + # Don't mutate the input args + vllm_serve_args = vllm_serve_args + [ + "--port", str(get_open_port()) + ] + if seed is not None: + if "--seed" in vllm_serve_args: + raise ValueError("You have manually specified the seed " + f"when `seed={seed}`.") + + vllm_serve_args = vllm_serve_args + ["--seed", str(seed)] + + if override_hf_configs is not None: + vllm_serve_args = vllm_serve_args + [ + "--hf-overrides", + json.dumps(override_hf_configs) + ] + + parser = FlexibleArgumentParser( + description="vLLM's remote OpenAI server.") + subparsers = parser.add_subparsers(required=False, dest="subparser") + parser = ServeSubcommand().subparser_init(subparsers) + args = parser.parse_args(["--model", model, *vllm_serve_args]) + self.uds = args.uds + if args.uds: + self.host = None + self.port = None + else: + self.host = str(args.host or 'localhost') + self.port = int(args.port) + + self.show_hidden_metrics = \ + args.show_hidden_metrics_for_version is not None + + # download the model before starting the server to avoid timeout + is_local = os.path.isdir(model) + if not is_local: + engine_args = AsyncEngineArgs.from_cli_args(args) + model_config = engine_args.create_model_config() + load_config = engine_args.create_load_config() + + model_loader = get_model_loader(load_config) + model_loader.download_model(model_config) + + self._start_server(model, vllm_serve_args, env_dict) + max_wait_seconds = max_wait_seconds or 240 + self._wait_for_server(url=self.url_for("health"), + timeout=max_wait_seconds) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.proc.terminate() + try: + self.proc.wait(8) + except subprocess.TimeoutExpired: + # force kill if needed + self.proc.kill() + + def _poll(self) -> Optional[int]: + """Subclasses override this method to customize process polling""" + return self.proc.poll() + + def _wait_for_server(self, *, url: str, timeout: float): + # run health check + start = time.time() + client = (httpx.Client(transport=httpx.HTTPTransport( + uds=self.uds)) if self.uds else requests) + while True: + try: + if client.get(url).status_code == 200: + break + except Exception: + # this exception can only be raised by requests.get, + # which means the server is not ready yet. + # the stack trace is not useful, so we suppress it + # by using `raise from None`. 
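+                # Before sleeping and retrying, check whether the server
+                # process has already exited so the test fails fast instead
+                # of spinning until the timeout.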
+ result = self._poll() + if result is not None and result != 0: + raise RuntimeError("Server exited unexpectedly.") from None + + time.sleep(0.5) + if time.time() - start > timeout: + raise RuntimeError( + "Server failed to start in time.") from None + + @property + def url_root(self) -> str: + return (f"http://{self.uds.split('/')[-1]}" + if self.uds else f"http://{self.host}:{self.port}") + + def url_for(self, *parts: str) -> str: + return self.url_root + "/" + "/".join(parts) + + def get_client(self, **kwargs): + if "timeout" not in kwargs: + kwargs["timeout"] = 600 + return openai.OpenAI( + base_url=self.url_for("v1"), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs, + ) + + def get_async_client(self, **kwargs): + if "timeout" not in kwargs: + kwargs["timeout"] = 600 + return openai.AsyncOpenAI(base_url=self.url_for("v1"), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index ff055e47b6..b55f47ec48 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -68,7 +68,8 @@ from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, LazyLoader, cdiv, get_dtype_size, - is_pin_memory_available) + is_pin_memory_available, + length_from_prompt_token_ids_or_embeds) from vllm.utils.jsontree import json_map_leaves from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( @@ -301,11 +302,16 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.is_multimodal_model = self.model_config.is_multimodal_model self.is_pooling_model = self.model_config.pooler_config is not None - if self.is_multimodal_model: - self.inputs_embeds = torch.zeros( - (self.max_num_tokens, self.model_config.get_hidden_size()), + self.enable_prompt_embeds = self.model_config.enable_prompt_embeds + if self.is_multimodal_model or self.enable_prompt_embeds: + self.inputs_embeds = self._make_buffer( + self.max_num_tokens, + self.model_config.get_hidden_size(), dtype=self.dtype, - device=self.device) + numpy=False) + self.is_token_ids = self._make_buffer(self.max_num_tokens, + dtype=torch.bool) + # Set up Attention if vllm_version_is("0.10.2"): self.attn_backend = get_attn_backend( @@ -611,6 +617,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, + prompt_embeds=new_req_data.prompt_embeds, sampling_params=sampling_params, pooling_params=pooling_params, generator=generator, @@ -919,7 +926,8 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): self.input_batch.num_computed_tokens_cpu[index] num_scheduled_tokens = \ scheduler_output.num_scheduled_tokens[req_id] - num_prompt_tokens = len(req.prompt_token_ids) + num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + req.prompt_token_ids, req.prompt_embeds) if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: prompt_part_len = max(0, @@ -1151,6 +1159,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) + self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) return # Async scheduling case, where some decode requests 
from the previous @@ -1178,6 +1188,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) + self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) if num_commmon_tokens == 0: # No requests in common with the previous iteration # So input_ids_cpu will have all the input ids. @@ -1191,6 +1203,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0], non_blocking=True) + self.is_token_ids.gpu[:num_commmon_tokens] = True return # Upload the index tensors asynchronously # so the scatter can be non-blocking. @@ -1330,15 +1343,60 @@ def _prepare_inputs( # where M is the max_model_len. token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]) - + token_indices_tensor = torch.from_numpy(token_indices) # Prepare input_ids. # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), 0, - torch.from_numpy(token_indices), + token_indices_tensor, out=self.input_ids_cpu[:total_num_scheduled_tokens]) + is_token_ids = self.input_batch.is_token_ids.flatten() + torch.index_select( + is_token_ids, + 0, + token_indices_tensor, + out=self.is_token_ids.cpu[:total_num_scheduled_tokens]) + + # Because we did not pre-allocate a massive prompt_embeds CPU tensor on + # the InputBatch, we need to fill in the prompt embeds into the expected + # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor. + if self.input_batch.req_prompt_embeds: + output_idx = 0 + for req_idx in range(num_reqs): + num_sched = num_scheduled_tokens[req_idx] + + # Skip if this request doesn't have embeddings + if req_idx not in self.input_batch.req_prompt_embeds: + output_idx += num_sched + continue + + # Skip if no tokens scheduled + if num_sched <= 0: + output_idx += num_sched + continue + + req_embeds = self.input_batch.req_prompt_embeds[req_idx] + start_pos = self.input_batch.num_computed_tokens_cpu[req_idx] + + # Skip if trying to read beyond available embeddings + if start_pos >= req_embeds.shape[0]: + output_idx += num_sched + continue + + # Copy available embeddings + end_pos = start_pos + num_sched + actual_end = min(end_pos, req_embeds.shape[0]) + actual_num_sched = actual_end - start_pos + + if actual_num_sched > 0: + self.inputs_embeds.cpu[output_idx:output_idx + + actual_num_sched].copy_( + req_embeds[start_pos:actual_end] + ) + + output_idx += num_sched # Prepare some information for building Attention-Metadata # Compute and commit slot mapping @@ -2154,6 +2212,8 @@ def execute_model( self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.is_token_ids[req_idx, + start_idx:end_idx] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx req_id = self.input_batch.req_ids[req_idx] @@ -2410,6 +2470,9 @@ def _dummy_run( if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] + elif self.enable_prompt_embeds: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens] else: input_ids = self.input_ids[:num_tokens] inputs_embeds = None @@ -3483,6 +3546,9 @@ def _get_prompt_logprobs_dict( # Get metadata for this request. 
request = self.requests[req_id] + if request.prompt_token_ids is None: + # Prompt logprobs is incompatible with prompt embeddings + continue num_prompt_tokens = len(request.prompt_token_ids) prompt_token_ids = torch.tensor(request.prompt_token_ids).to( self.device, non_blocking=True) @@ -3566,3 +3632,18 @@ def get_supported_pooling_tasks(self): def _build_drafter_prepare_inputs_torchair_param(self): return False + + def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: + # This is a short term mitigation for issue mentioned in + # https://github.com/vllm-project/vllm/issues/22754. + # `tolist` would trigger a npu wise stream sync, which + # would block other copy ops from other npu streams. + # A npu event sync would avoid such a situation. Since + # this is in the critical path of every single model + # forward loop, this has caused perf issue for a disagg + # setup. + pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]] + pinned.copy_(sampled_token_ids, non_blocking=True) + self.transfer_event.record() + self.transfer_event.synchronize() + return pinned.tolist() \ No newline at end of file diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index d1ebd023c3..01ef3b3276 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -29,7 +29,7 @@ MultiModalKwargsItems, PlaceholderRange) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import swap_dict_values +from vllm.utils import length_from_prompt_token_ids_or_embeds,swap_dict_values from vllm.v1.outputs import LogprobsTensors from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, @@ -47,7 +47,7 @@ class CachedRequestState: req_id: str - prompt_token_ids: list[int] + prompt_token_ids: Optional[list[int]] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] generator: Optional[torch.Generator] @@ -66,9 +66,11 @@ class CachedRequestState: mm_hashes: Optional[list[PlaceholderRange]] = None lora_request: Optional[LoRARequest] = None + prompt_embeds: Optional[torch.Tensor] = None def __post_init__(self): - self.num_prompt_tokens = len(self.prompt_token_ids) + self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + self.prompt_token_ids, self.prompt_embeds) @property def num_tokens(self) -> int: @@ -93,6 +95,10 @@ def mm_inputs(self) -> list[MultiModalKwargsItems]: def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: + if self.prompt_token_ids is None: + raise ValueError( + f"Tried to access token index {idx}, but that token was " + "provided via prompt_embeds, and its ID is unknown.") return self.prompt_token_ids[idx] else: return self.output_token_ids[idx - self.num_prompt_tokens] @@ -137,6 +143,14 @@ def __init__( pin_memory=False, ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() + self.is_token_ids = torch.zeros((max_num_reqs, max_model_len), + device="cpu", + dtype=bool, + pin_memory=False) + # Store prompt embeddings per request to avoid OOM from large upfront + # allocation if max_model_len is big. 
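+        # Only requests that actually provide prompt_embeds get an entry in
+        # this dict, so the common token-ids-only path allocates nothing extra.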
+ # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size) + self.req_prompt_embeds: dict[int, torch.Tensor] = {} self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) @@ -341,15 +355,23 @@ def add_request( self.req_id_to_index[req_id] = req_index # Copy the prompt token ids and output token ids. - num_prompt_tokens = len(request.prompt_token_ids) + num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + request.prompt_token_ids, request.prompt_embeds) self.num_prompt_tokens[req_index] = num_prompt_tokens - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids start_idx = num_prompt_tokens end_idx = start_idx + len(request.output_token_ids) + if request.prompt_token_ids is not None: + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + self.is_token_ids[req_index, :num_prompt_tokens] = True + else: + self.is_token_ids[req_index, :num_prompt_tokens] = False + if request.prompt_embeds is not None: + self.req_prompt_embeds[req_index] = request.prompt_embeds self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids - # Number of token ids in token_ids_cpu. + self.is_token_ids[req_index, start_idx:end_idx] = True + # Number of token ids in prompt (token_ids_cpu or prompt_embeds). # NOTE(woosuk): This may include spec decode tokens. self.num_tokens[req_index] = request.num_tokens # Number of tokens without spec decode tokens. @@ -549,6 +571,20 @@ def swap_states(self, i1: int, i2: int) -> None: self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] self.token_ids_cpu[i2, ...] = tmp + self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...] + + # Swap prompt embeddings if they exist + embeds_i1 = self.req_prompt_embeds.get(i1) + embeds_i2 = self.req_prompt_embeds.get(i2) + if embeds_i1 is not None: + self.req_prompt_embeds[i2] = embeds_i1 + else: + self.req_prompt_embeds.pop(i2, None) + if embeds_i2 is not None: + self.req_prompt_embeds[i1] = embeds_i2 + else: + self.req_prompt_embeds.pop(i1, None) + swap_dict_values(self.generators, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2) @@ -627,6 +663,11 @@ def condense(self) -> None: num_tokens = self.num_tokens[last_req_index] self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ last_req_index, :num_tokens] + self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[ + last_req_index, :num_tokens] + if last_req_index in self.req_prompt_embeds: + self.req_prompt_embeds[ + empty_index] = self.req_prompt_embeds.pop(last_req_index) self.num_tokens[empty_index] = num_tokens self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ last_req_index]
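A minimal client-side sketch of the flow the new e2e test exercises (a hypothetical standalone script: the model name and dummy API key mirror the test constants, while the port and the way the embeddings are built are assumptions). Prompt embeddings are serialized with torch.save, base64-encoded, and sent via extra_body["prompt_embeds"], exactly as _encode_embeds() does in the test, against a server started with --enable-prompt-embeds as in the default_server_args fixture.

    import base64
    import io

    import torch
    from openai import OpenAI
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL = "facebook/opt-125m"


    def encode_prompt_embeds(embeds: torch.Tensor) -> str:
        # Same encoding as _encode_embeds() in the test: torch.save + base64.
        buffer = io.BytesIO()
        torch.save(embeds, buffer)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")


    # Build (num_tokens, hidden_size) embeddings from the HF embedding table
    # (the test gets them from hf_runner.get_prompt_embeddings instead).
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    model = AutoModelForCausalLM.from_pretrained(MODEL)
    input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
    with torch.no_grad():
        embeds = model.get_input_embeddings()(input_ids).squeeze(0)

    # Assumes a local server launched with something like:
    #   vllm serve facebook/opt-125m --enable-prompt-embeds --port 8000
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")
    completion = client.completions.create(
        model=MODEL,
        prompt="",  # empty prompt; the embeddings carry the actual input
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encode_prompt_embeds(embeds)},
    )
    print(completion.choices[0].text)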