
Conversation

weisirui-eng

@weisirui-eng weisirui-eng commented Sep 19, 2025

[splitting MTP into graph mode and non-graph mode]

Signed-off-by: weisirui-eng <weisirui@h-partners.com>

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the MTP proposer by splitting it into a base MtpProposer for non-graph mode and a MtpTorchairProposer for graph mode. This is a good architectural change that improves separation of concerns. However, I've found a critical issue in the factory function get_spec_decode_method that could lead to incorrect proposer instantiation. Additionally, there is some code duplication in the new MtpTorchairProposer that should be addressed to improve maintainability.
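
For orientation, the resulting structure might be sketched roughly as follows (a sketch only: the propose method name and bodies are placeholders, not the PR's actual implementation; the class names and constructor arguments come from the diff snippets quoted below):

    # Sketch: shared, non-graph MTP logic lives in the base class;
    # the torchair (graph-mode) variant specializes it by subclassing.
    class MtpProposer:
        def __init__(self, vllm_config, device, runner):
            self.vllm_config = vllm_config
            self.device = device
            self.runner = runner

        def propose(self, *args, **kwargs):
            # eager (non-graph) speculative proposal path
            raise NotImplementedError

    class MtpTorchairProposer(MtpProposer):
        def propose(self, *args, **kwargs):
            # torchair graph-mode proposal path, reusing helpers
            # inherited from MtpProposer
            raise NotImplementedError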

Comment on lines 30 to 33
    elif method == 'deepseek_mtp' and not is_torchair_graph:
        return MtpProposer(vllm_config, device, runner)
    elif is_torchair_graph:
        return MtpTorchairProposer(vllm_config, device, runner)
Contributor


critical

The current logic for selecting the speculative decoding method is flawed. If is_torchair_graph is True, it will return MtpTorchairProposer for any method that is not ngram or eagle/eagle3. This is likely not the intended behavior. For example, if a new speculative decoding method is added in the future, and is_torchair_graph is true, it would incorrectly return MtpTorchairProposer.

I suggest refactoring the logic to be clearer and more robust by handling the deepseek_mtp method in a single block.

Suggested change

    -    elif method == 'deepseek_mtp' and not is_torchair_graph:
    -        return MtpProposer(vllm_config, device, runner)
    -    elif is_torchair_graph:
    -        return MtpTorchairProposer(vllm_config, device, runner)
    +    elif method == 'deepseek_mtp':
    +        if is_torchair_graph:
    +            return MtpTorchairProposer(vllm_config, device, runner)
    +        return MtpProposer(vllm_config, device, runner)
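
For illustration, with the suggestion applied the full dispatch could read roughly like this (a sketch: the function signature and the ngram/eagle branch class names are assumptions inferred from the comment above, not taken from the diff):

    def get_spec_decode_method(method, vllm_config, device, runner,
                               is_torchair_graph=False):
        # Dispatch on `method` first; the torchair flag only selects the
        # implementation within the deepseek_mtp branch, so an unknown
        # method can never fall through to MtpTorchairProposer.
        if method == 'ngram':
            return NgramProposer(vllm_config, device, runner)
        elif method in ('eagle', 'eagle3'):
            return EagleProposer(vllm_config, device, runner)
        elif method == 'deepseek_mtp':
            if is_torchair_graph:
                return MtpTorchairProposer(vllm_config, device, runner)
            return MtpProposer(vllm_config, device, runner)
        raise ValueError(f"unknown speculative decoding method: {method}")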

Contributor


I agree with this

Comment on lines 219 to 277
def _prepare_inputs(
    self,
    # [batch_size + 1]
    cu_target_query_lens: torch.Tensor,
    # [batch_size]
    num_rejected_tokens: torch.Tensor,
    token_ids: torch.Tensor,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    slot_mapping: torch.Tensor,
    is_torchair_graph: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           torch.Tensor, torch.Tensor]:
    # cu_target_query_lens: [0, a, a + b, a + b + c]
    # num_rejected_tokens: [n1, n2, n3]
    # num_tokens_per_req: [a - n1, b - n2, c - n3]
    # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
    # token_indices: [0, 1, ..., a - n1 - 1,
    #                 a, a + 1, ..., a + b - n2 - 1,
    #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
    # [0, a, a + b, a + b + c] -> [a, b, c]
    query_len_per_req = (cu_target_query_lens[1:] -
                         cu_target_query_lens[:-1])
    # [a, b, c] -> [a - n1, b - n2, c - n3]
    num_tokens_per_req = query_len_per_req - num_rejected_tokens
    if is_torchair_graph:
        cu_num_tokens = cu_target_query_lens
        relative_index = query_len_per_req - num_rejected_tokens - 1
        token_indices = cu_num_tokens[:-1] + relative_index
        # the seq len of each batch is padded to 1 + num_speculative_tokens,
        # so the inputs are the same as the main model's
        target_token_ids = token_ids
        target_positions = positions
        target_hidden_states = hidden_states
        target_slot_mapping = slot_mapping
    else:
        cu_num_tokens = torch.empty_like(cu_target_query_lens)
        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
        cu_num_tokens[0] = 0

        # FIXME(woosuk): Avoid synchronization.
        num_tokens = cu_num_tokens[-1].item()
        token_indices = torch.zeros(
            num_tokens,
            dtype=torch.int32,
            device=cu_num_tokens.device,
        )

        BLOCK_SIZE = 1024
        self._prepare_input_kernel(
            token_indices,
            cu_target_query_lens,
            cu_num_tokens,
            block_size=BLOCK_SIZE,
        )
        target_token_ids = token_ids[token_indices]
        target_positions = positions[token_indices]
        target_hidden_states = hidden_states[token_indices]
        target_slot_mapping = slot_mapping[token_indices]
    return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
Contributor


high

The _prepare_inputs method is identical to the one in the parent class MtpProposer. This code duplication is unnecessary and makes the code harder to maintain. Since MtpTorchairProposer inherits from MtpProposer, you can remove this method from MtpTorchairProposer and it will inherit the implementation from its parent.
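
In other words, the override could simply be dropped so the subclass falls back to the inherited implementation (a sketch, assuming the parent's is_torchair_graph flag already covers the graph-mode case):

    class MtpTorchairProposer(MtpProposer):
        # No _prepare_inputs override: MtpProposer._prepare_inputs is
        # inherited unchanged, leaving a single copy of the logic to maintain.
        ...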

Contributor


should remove the torchair-unrelated code in _prepare_inputs
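
That is, if the override is kept in MtpTorchairProposer, it could retain only the graph-mode branch, roughly like this (a sketch of the suggestion, not code from the PR):

    def _prepare_inputs(self, cu_target_query_lens, num_rejected_tokens,
                        token_ids, positions, hidden_states, slot_mapping):
        # Torchair graph mode only: each request is padded to
        # 1 + num_speculative_tokens, so the inputs pass through unchanged
        # and only the index of the last accepted token per request is needed.
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        relative_index = query_len_per_req - num_rejected_tokens - 1
        token_indices = cu_target_query_lens[:-1] + relative_index
        return (cu_target_query_lens, token_indices, token_ids, positions,
                hidden_states, slot_mapping)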

Signed-off-by: weisirui-eng <weisirui@h-partners.com>

This pull request has conflicts; please resolve them before we can evaluate the pull request.
