splitting MTP into graph mode and non-graph mode #3030
base: main
Conversation
Signed-off-by: weisirui-eng <weisirui@h-partners.com>
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request refactors the MTP proposer by splitting it into a base MtpProposer for non-graph mode and a MtpTorchairProposer for graph mode. This is a good architectural change that improves separation of concerns. However, I've found a critical issue in the factory function get_spec_decode_method that could lead to incorrect proposer instantiation. Additionally, there is some code duplication in the new MtpTorchairProposer that should be addressed to improve maintainability.
vllm_ascend/spec_decode/__init__.py
Outdated
elif method == 'deepseek_mtp' and not is_torchair_graph:
    return MtpProposer(vllm_config, device, runner)
elif is_torchair_graph:
    return MtpTorchairProposer(vllm_config, device, runner)
The current logic for selecting the speculative decoding method is flawed. If is_torchair_graph is True, it will return MtpTorchairProposer for any method that is not ngram or eagle/eagle3. This is likely not the intended behavior. For example, if a new speculative decoding method is added in the future and is_torchair_graph is true, it would incorrectly return MtpTorchairProposer.
I suggest refactoring the logic to be clearer and more robust by handling the deepseek_mtp method in a single block.
Suggested change:
-elif method == 'deepseek_mtp' and not is_torchair_graph:
-    return MtpProposer(vllm_config, device, runner)
-elif is_torchair_graph:
-    return MtpTorchairProposer(vllm_config, device, runner)
+elif method == 'deepseek_mtp':
+    if is_torchair_graph:
+        return MtpTorchairProposer(vllm_config, device, runner)
+    return MtpProposer(vllm_config, device, runner)
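For context, a minimal sketch of how the whole factory could read with this fix applied. The ngram and eagle branches, the exact signature, and the final error case are assumptions for illustration, not code from this PR:

def get_spec_decode_method(method, vllm_config, device, runner,
                           is_torchair_graph=False):
    # Hypothetical surrounding branches; only the deepseek_mtp block
    # reflects the suggestion above.
    if method == 'ngram':
        return NgramProposer(vllm_config, device, runner)
    elif method in ('eagle', 'eagle3'):
        return EagleProposer(vllm_config, device, runner)
    elif method == 'deepseek_mtp':
        # All MTP dispatch in one block: graph mode selects the torchair
        # subclass, every other configuration gets the base proposer.
        if is_torchair_graph:
            return MtpTorchairProposer(vllm_config, device, runner)
        return MtpProposer(vllm_config, device, runner)
    raise ValueError(f"Unknown speculative decoding method: {method}")

With this shape, an unknown method fails loudly instead of silently falling through to the torchair proposer whenever is_torchair_graph is set.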
I agree with this
def _prepare_inputs(
    self,
    # [batch_size + 1]
    cu_target_query_lens: torch.Tensor,
    # [batch_size]
    num_rejected_tokens: torch.Tensor,
    token_ids: torch.Tensor,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    slot_mapping: torch.Tensor,
    is_torchair_graph: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           torch.Tensor, torch.Tensor]:
    # cu_target_query_lens: [0, a, a + b, a + b + c]
    # num_rejected_tokens: [n1, n2, n3]
    # num_tokens_per_req: [a - n1, b - n2, c - n3]
    # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
    # token_indices: [0, 1, ..., a - n1 - 1,
    #                 a, a + 1, ..., a + b - n2 - 1,
    #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
    # [0, a, a + b, a + b + c] -> [a, b, c]
    query_len_per_req = (cu_target_query_lens[1:] -
                         cu_target_query_lens[:-1])
    # [a, b, c] -> [a - n1, b - n2, c - n3]
    num_tokens_per_req = query_len_per_req - num_rejected_tokens
    if is_torchair_graph:
        cu_num_tokens = cu_target_query_lens
        relative_index = query_len_per_req - num_rejected_tokens - 1
        token_indices = cu_num_tokens[:-1] + relative_index
        # The seq len of each batch is padded to
        # 1 + num_speculative_tokens, so the input is the same as the
        # main model's.
        target_token_ids = token_ids
        target_positions = positions
        target_hidden_states = hidden_states
        target_slot_mapping = slot_mapping
    else:
        cu_num_tokens = torch.empty_like(cu_target_query_lens)
        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
        cu_num_tokens[0] = 0

        # FIXME(woosuk): Avoid synchronization.
        num_tokens = cu_num_tokens[-1].item()
        token_indices = torch.zeros(
            num_tokens,
            dtype=torch.int32,
            device=cu_num_tokens.device,
        )

        BLOCK_SIZE = 1024
        self._prepare_input_kernel(
            token_indices,
            cu_target_query_lens,
            cu_num_tokens,
            block_size=BLOCK_SIZE,
        )
        target_token_ids = token_ids[token_indices]
        target_positions = positions[token_indices]
        target_hidden_states = hidden_states[token_indices]
        target_slot_mapping = slot_mapping[token_indices]
    return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
The _prepare_inputs method is identical to the one in the parent class MtpProposer. This code duplication is unnecessary and makes the code harder to maintain. Since MtpTorchairProposer inherits from MtpProposer, you can remove this method from MtpTorchairProposer and it will inherit the implementation from its parent.
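A minimal sketch of that deduplication, assuming only the class names from this PR (bodies elided):

class MtpProposer:
    def _prepare_inputs(self, cu_target_query_lens, num_rejected_tokens,
                        token_ids, positions, hidden_states, slot_mapping,
                        is_torchair_graph=False):
        ...  # single shared implementation, as shown above

class MtpTorchairProposer(MtpProposer):
    # No _prepare_inputs override: the parent implementation is found
    # automatically via Python's method resolution order.
    ...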
The torchair-unrelated code should be removed from _prepare_inputs, as sketched below.
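One way to follow this suggestion: if the subclass keeps an override at all, it could carry only the torchair branch, dropping the is_torchair_graph flag and the gather/kernel path. A sketch under that assumption (the split itself is illustrative, not code from this PR):

class MtpTorchairProposer(MtpProposer):
    def _prepare_inputs(self, cu_target_query_lens, num_rejected_tokens,
                        token_ids, positions, hidden_states, slot_mapping):
        # In graph mode every sequence is padded to
        # 1 + num_speculative_tokens, so the drafter consumes the same
        # tensors as the main model; only the index of the last accepted
        # token per request needs computing.
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        relative_index = query_len_per_req - num_rejected_tokens - 1
        token_indices = cu_target_query_lens[:-1] + relative_index
        return (cu_target_query_lens, token_indices, token_ids, positions,
                hidden_states, slot_mapping)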
This pull request has conflicts, please resolve those before we can evaluate the pull request.