31 changes: 31 additions & 0 deletions vllm_ascend/attention/utils.py
@@ -106,6 +106,13 @@ class AscendCommonAttentionMetadata:
prefill_context_parallel_metadata: Optional[
AscendPrefillContextParallelMetadata] = None

max_seq_len: int = -1

def batch_size(self) -> int:
return self.seq_lens_cpu.shape[0]

def query_lens(self) -> torch.Tensor:
return self.query_start_loc[1:] - self.query_start_loc[:-1]

def split_decodes_and_prefills(
common_attn_metadata: AscendCommonAttentionMetadata,
@@ -212,3 +219,27 @@ def transdata(nd_mat, block_size: tuple = (16, 16)):
nz_mat,
(nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
return nz_mat

def extend_flat_seqs(
seqs: torch.Tensor,
end_locs: torch.Tensor,
new_vals: torch.Tensor
) -> torch.Tensor:
"""
This function appends a single new value into multiple sequences
that are stored in a flat format. E.g.
[x1, x2, y1] and [x3, y2] become [x1, x2, x3, y1, y2]
"""
new_len = seqs.shape[0] + new_vals.shape[0]
new_seqs = torch.zeros(new_len, device=seqs.device, dtype=seqs.dtype)
# indices for previous seqs
start_locs = end_locs[:-1] + 1
seqs_new_idxs = torch.ones_like(seqs)
seqs_new_idxs[start_locs] += 1
seqs_new_idxs = seqs_new_idxs.cumsum(0) - 1
# indices for new values
new_val_idxs = end_locs + 1 + torch.arange(new_vals.shape[0], device=seqs.device)
# assign seqs and new vals
new_seqs[seqs_new_idxs] = seqs
new_seqs[new_val_idxs] = new_vals
return new_seqs
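
A minimal usage sketch of the new extend_flat_seqs helper (the tensor values are illustrative, chosen to mirror the docstring example):

import torch
from vllm_ascend.attention.utils import extend_flat_seqs

seqs = torch.tensor([11, 12, 21])      # flat storage of [11, 12] and [21]
end_locs = torch.tensor([1, 2])        # index of the last element of each sequence
new_vals = torch.tensor([13, 22])      # one value to append per sequence
extend_flat_seqs(seqs, end_locs, new_vals)  # tensor([11, 12, 13, 21, 22])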
5 changes: 5 additions & 0 deletions vllm_ascend/core/schedule_config.py
@@ -27,6 +27,7 @@
class AscendSchedulerConfig(SchedulerConfig):
enable_chunked_prefill: bool = False
max_long_partial_prefills: int = 1
max_num_partial_prefills: int = 1
long_prefill_token_threshold: int = MAX_INT
policy: str = "fcfs"
scheduler_cls: Union[str, Type[object]] = (
@@ -47,6 +48,7 @@ def initialize_from_config(
# Override default values into original SchedulerConfig
scheduler_config["enable_chunked_prefill"] = False
scheduler_config["max_long_partial_prefills"] = None
scheduler_config["max_num_partial_prefills"] = None
scheduler_config["long_prefill_token_threshold"] = None
scheduler_config["policy"] = "fcfs"
scheduler_config["scheduler_cls"] = (
@@ -78,6 +80,9 @@ def __post_init__(self, *args) -> None:
self.max_long_partial_prefills = 1
self.long_prefill_token_threshold = MAX_INT

if self.max_num_partial_prefills is None:
self.max_num_partial_prefills = 1

if self.long_prefill_token_threshold is None or \
self.long_prefill_token_threshold <= 0:
if self.max_model_len is None:
5 changes: 0 additions & 5 deletions vllm_ascend/patch/platform/patch_config.py
@@ -155,11 +155,6 @@ def __post_init__(self):
)
else:
self.method = "draft_model"
raise NotImplementedError(
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.")

# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
3 changes: 3 additions & 0 deletions vllm_ascend/spec_decode/__init__.py
@@ -20,6 +20,7 @@
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
from vllm_ascend.spec_decode.draft_proposer import DraftModelProposer


def get_spec_decode_method(method,
@@ -35,6 +36,8 @@ def get_spec_decode_method(method,
if is_torchair_graph:
return TorchairMtpProposer(vllm_config, device, runner)
return MtpProposer(vllm_config, device, runner)
elif method == 'draft_model':
return DraftModelProposer(vllm_config, device, runner)
else:
raise ValueError("Unknown speculative decoding method: "
f"{method}")
275 changes: 275 additions & 0 deletions vllm_ascend/spec_decode/draft_proposer.py
@@ -0,0 +1,275 @@
from dataclasses import dataclass, replace
from typing import Any

import torch

from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.speculative import SpeculativeConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID

from vllm_ascend.spec_decode.eagle_proposer import SpecDecodeBaseProposer
from vllm_ascend.attention.attention_v1 import AscendMetadata
from vllm_ascend.attention.utils import extend_flat_seqs

logger = init_logger(__name__)


class DraftModelProposer(SpecDecodeBaseProposer):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
runner=None,
):
super().__init__(
vllm_config=vllm_config,
device=device,
pass_hidden_states_to_model=False,
runner=runner,
)
self.draft_model_config = vllm_config.speculative_config.draft_model_config
self._raise_if_mrope()
self._raise_if_padded_drafter_batch()
self._raise_if_vocab_size_mismatch()
self._raise_if_draft_tp_mismatch()


def generate_token_ids(self,
valid_sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata = None,
scheduler_output: SchedulerOutput = None,
spec_decode_metadata: SpecDecodeMetadata = None,
positions: torch.Tensor = None,
num_scheduled_tokens: int = 0,
hidden_states: torch.Tensor = None,
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):

attn_metadata = self._get_atten_dict(scheduler_output)
attn_metadata = attn_metadata[self.attn_layer_name]
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.runner.input_batch.req_ids[i]
req_state = self.runner.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])

next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)

if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
cu_num_tokens = attn_metadata.query_start_loc
else:
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
cu_num_tokens, token_indices = self.prepare_inputs(

Check failure on line 91 in vllm_ascend/spec_decode/draft_proposer.py
GitHub Actions / lint / pre-commit
"DraftModelProposer" has no attribute "prepare_inputs" [attr-defined]
attn_metadata.query_start_loc, num_rejected_tokens,
num_tokens)
target_token_ids = self.runner.input_ids[token_indices]
target_positions = positions[token_indices]

(target_token_ids, target_positions,
target_slot_mapping, cu_num_tokens) = merge_next_token_ids_into_token_ids(
input_token_ids=target_token_ids,
input_positions=target_positions,
cad=attn_metadata,
next_token_ids=next_token_ids,
block_size=self.block_size,
max_model_len=self.vllm_config.model_config.max_model_len,
arange=self.arange,
cu_num_tokens=cu_num_tokens)

draft_token_ids = self._propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=None,
target_slot_mapping=target_slot_mapping.to(torch.int32),
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()

return spec_token_ids



def _raise_if_mrope(self):
if self.draft_model_config.uses_mrope:
raise NotImplementedError(
"Speculative Decoding with draft models does not support M-RoPE yet"
)

def _raise_if_padded_drafter_batch(self):
if not self.vllm_config.speculative_config.disable_padded_drafter_batch:
raise NotImplementedError(
"Speculative Decoding with draft models does not support "
"padded drafter batch yet. Please pass --disable-padded-drafter-batch "
"in the speculative_config."
)

def _raise_if_vocab_size_mismatch(self):
speculative_config = self.vllm_config.speculative_config
if (
speculative_config.method == "draft_model"
and speculative_config.target_model_config is not None
and speculative_config.draft_model_config is not None
):
target_vocab_size = speculative_config.target_model_config.get_vocab_size()
draft_vocab_size = speculative_config.draft_model_config.get_vocab_size()
if target_vocab_size != draft_vocab_size:
raise ValueError(
f"Target and draft model should have the same vocabulary size. "
f"Target model vocab_size={target_vocab_size}. "
f"Draft model vocab_size={draft_vocab_size}. "
f"Using models with different tokenizers can cause out-of-bounds "
f"errors during speculative decoding."
)

def _raise_if_draft_tp_mismatch(self):
# Note(Tomas Ruiz) If we run the target model with TP > 1 and
# the draft model with TP = 1, then the different TP ranks collide.
# Specifically when all ranks compile the draft model on rank 0
# (because TP=1), then the torch compile cache is overwritten and corrupted.
# We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
# To prevent this error, we assert that both TP sizes must be the same.
spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config
tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
if draft_tp != tgt_tp:
raise ValueError(
f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
f"must be the same. Got {draft_tp} and {tgt_tp}. "
"Please pass 'draft_tensor_parallel_size' in the speculative_config."
)

def set_input_ids_first_pass(
self,
target_token_ids: torch.Tensor,
next_token_ids: torch.Tensor,
num_tokens: int,
last_token_indices: torch.Tensor,
) -> None:
self.input_ids[:num_tokens] = target_token_ids

def load_model(self, target_model: Any) -> None:
"""Takes target_model to satisfy the type checker."""

# This must be computed before loading the draft model
# because that mutates the forward_context of the vllm_config
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys()
)

from vllm.compilation.backends import set_model_tag

draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(
target_model_vllm_config=self.vllm_config
)
logger.info(
"Starting to load draft model %s. TP=%d, rank=%d",
draft_vllm_config.model_config.model,
draft_vllm_config.parallel_config.tensor_parallel_size,
draft_vllm_config.parallel_config.rank,
)
with set_model_tag("draft_model"):
self.model = get_model(vllm_config=draft_vllm_config, prefix="draft_model")

# This must be computed after loading the draft model
# because that mutates the forward_context of the vllm_config
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys()
- target_attn_layer_names
)
self.attn_layer_name = next(iter(draft_attn_layer_names))

def create_vllm_config_for_draft_model(
target_model_vllm_config: VllmConfig,
) -> VllmConfig:
"""The vllm_config is configured for the target model, e.g.
its quant_config and parallel_config. But the draft model is potentially
quantized differently, and has potentially different tensor_parallel_size.
This function creates a new vllm_config configured for the draft model.
The vllm_config is useful when loading the draft model with get_model().
"""
old = target_model_vllm_config
new_parallel_config = replace(old.speculative_config.draft_parallel_config,
rank=old.parallel_config.rank
)

new: VllmConfig = replace(old,
quant_config=None, # quant_config is recomputed in __init__()
model_config=old.speculative_config.draft_model_config,
parallel_config=new_parallel_config,
)
return new

def merge_next_token_ids_into_token_ids(
input_token_ids: torch.Tensor,
input_positions: torch.Tensor,
cad: AscendMetadata,
next_token_ids: torch.Tensor,
block_size: int,
max_model_len: int,
arange: torch.Tensor,
cu_num_tokens
):
"""
Merges the next token ids with the existing token ids into a flat sequence.
Does the same for the positions, computes new slot mapping,
and updates the common_attn_metadata. The inputs are not modified in-place.
"""
query_end_locs = cu_num_tokens[1:] - 1
new_token_ids = extend_flat_seqs(
seqs=input_token_ids, end_locs=query_end_locs, new_vals=next_token_ids
)
logger.warning("new_token_ids: {}".format(new_token_ids))
Contributor review comment (high):
This log message appears to be for debugging. Using logger.warning for diagnostic information can flood the logs and obscure actual warnings. Please use logger.debug instead.
Suggested change: logger.debug("new_token_ids: {}".format(new_token_ids))


# append new positions
positions_to_append = input_positions[query_end_locs] + 1
new_positions = extend_flat_seqs(
seqs=input_positions, end_locs=query_end_locs, new_vals=positions_to_append
)
# recompute slot mapping
batch_size, n_blocks_per_req = cad.block_tables.shape
req_indices = torch.arange(batch_size, device=cad.query_start_loc.device)

query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
req_indices = torch.repeat_interleave(req_indices, query_lens.to(cad.query_start_loc.device) + 1)
block_table_indices = req_indices * n_blocks_per_req + new_positions // block_size
block_nums = cad.block_tables.view(-1)[block_table_indices]
block_offsets = new_positions % block_size
new_slot_mapping = block_nums * block_size + block_offsets
# Mask out the position ids that exceed the max model length.
exceeds_max_model_len = new_positions >= max_model_len
new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)

cu_num_tokens = cu_num_tokens + arange[: len(cu_num_tokens)]
return (new_token_ids, new_positions, new_slot_mapping, cu_num_tokens)
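
As a quick sanity check of the cu_num_tokens update above (made-up values): every request gains exactly one appended token, so the i-th boundary shifts right by i.

import torch

cu_num_tokens = torch.tensor([0, 2, 3])      # two requests with 2 and 1 tokens
arange = torch.arange(cu_num_tokens.numel())
cu_num_tokens + arange                        # tensor([0, 3, 5]) -> lengths 3 and 2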