Commit c223200

[0.9.1][BUGFIX] [mtp][pd] FIX mtp torchair bug (#2610)
### What this PR does / why we need it?

In the PD disaggregation scenario, the first inference token produced on the D node after it receives the KV cache falls back to eager mode.

Fixes: when running MTP in torchair graph mode with prefill/decode disaggregation, if every request processed by the D node has just been transmitted from the P node, the torchair graph breaks.

Reason: during PD disaggregation, the P node transmits only the KV cache and the prompt to the D node, not the tokens it actually inferred (neither the main-model tokens nor the MTP tokens). The D node therefore treats such a request as one without MTP tokens (seq_len=1). The community build has no graph-mode issue because its attention uses seq_len=1 for every batch in the decode phase. We hit the issue because our graph mode pads assuming 2 tokens per request: when some requests have seq_len=1 and some have seq_len=2, padding is appended at the end of the batch, but if every request received by the D node has seq_len=1, the padding cannot satisfy the constraints of the attention FIA operator.

Solution:
1. The KV consumer uses extra torchair graph padding to avoid breaking the FIA graph constraints (the approach implemented in this PR).
2. The KV producer sends the correct tokens to the KV consumer, so the graph-mode constraints are never violated and the logic matches PD mixed deployment. Since we use the community scheduler, this requires patching the vLLM scheduler; in theory it should perform better. (Maybe later.)

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
1 parent 128c120 commit c223200
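For readers unfamiliar with the failure mode, here is a minimal standalone sketch (not code from this PR) of the token-count mismatch described above, assuming the torchair graph is padded for 2 tokens per request (1 main-model token + 1 MTP token), as stated in the description:

```python
# Hypothetical illustration of the seq_len mismatch on the D node (not code from this PR).
DECODE_TOKEN_PER_REQ = 2  # 1 main-model token + 1 MTP token per request

def actual_tokens(seq_lens):
    """Tokens the D node really has to process in one decode step."""
    return sum(seq_lens)

def graph_expected_tokens(num_reqs):
    """Tokens the captured torchair graph was padded for."""
    return num_reqs * DECODE_TOKEN_PER_REQ

# Mixed batch: some requests already carry an MTP token (seq_len=2), some were
# just handed over from the P node with only prompt + KV cache (seq_len=1).
# The gap can be filled by padding at the end of the batch.
mixed = [2, 2, 1, 2]
print(actual_tokens(mixed), graph_expected_tokens(len(mixed)))          # 7 vs 8

# Worst case: every request on the D node just arrived from the P node, so all
# seq_lens are 1 and half of the expected tokens are missing; this is the case
# that broke the FIA operator's padding constraints before this fix.
all_fresh = [1, 1, 1, 1]
print(actual_tokens(all_fresh), graph_expected_tokens(len(all_fresh)))  # 4 vs 8
```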

File tree: 1 file changed (+31, -12)

vllm_ascend/worker/model_runner_v1.py
Lines changed: 31 additions & 12 deletions
```diff
@@ -363,6 +363,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.use_cached_kv_cache_bytes = ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
         self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
 
+        # kv role
+        self.is_kv_producer = False
+        self.is_kv_consumer = False
+        if vllm_config.kv_transfer_config is not None:
+            self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
+            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
+
         if ascend_config.torchair_graph_config.graph_batch_sizes_init:
             self.init_torchair_graph_batch_sizes()
 
@@ -394,13 +401,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
         self.in_profile_run = False
 
-        # kv role
-        self.is_kv_producer = False
-        self.is_kv_consumer = False
-        if vllm_config.kv_transfer_config is not None:
-            self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
-            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
-
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.
@@ -2339,7 +2339,13 @@ def select_torchair_padded_batch_size(self, batch_size: int):
     def check_torchair_graph_batch_sizes(self):
         # return graph_batch_sizes according to the number of tokens
         # first pad according to the number of requests
-        if len(self.torchair_graph_batch_sizes) == 0:
+        if self.is_kv_consumer:
+            # pd disaggregation scenario may incorrectly calculate the batch, so we force set it to max_num_reqs
+            self.torchair_graph_batch_sizes = [self.max_num_reqs]
+            logger.warning(
+                "is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs]"
+            )
+        elif len(self.torchair_graph_batch_sizes) == 0:
             self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
         else:
             self.torchair_graph_batch_sizes = sorted(
@@ -2355,10 +2361,23 @@ def check_torchair_graph_batch_sizes(self):
             self.torchair_graph_batch_sizes.append(self.max_num_reqs)
 
         # we need to make sure that we can deal with max_num_req when `self.decode_token_per_req` is not 1
-        self.torchair_graph_batch_sizes = [
-            graph_batch_size * self.decode_token_per_req
-            for graph_batch_size in self.torchair_graph_batch_sizes
-        ]
+        if self.decode_token_per_req > 1:
+            # pd disaggregation scenario need redundant_batch_sizes to avoid each batch's seq_len exceed 16 tokens
+            if self.is_kv_consumer:
+                FIA_SEQ_LEN_LIMIT = 16
+                self.torchair_graph_batch_sizes = [
+                    (graph_batch_size +
+                     math.ceil(graph_batch_size / FIA_SEQ_LEN_LIMIT) +
+                     math.ceil(graph_batch_size * self.decode_token_per_req /
+                               FIA_SEQ_LEN_LIMIT / FIA_SEQ_LEN_LIMIT)) *
+                    self.decode_token_per_req
+                    for graph_batch_size in self.torchair_graph_batch_sizes
+                ]
+            else:
+                self.torchair_graph_batch_sizes = [
+                    graph_batch_size * self.decode_token_per_req
+                    for graph_batch_size in self.torchair_graph_batch_sizes
+                ]
 
         # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
         tp_size = self.parallel_config.tensor_parallel_size
```
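To make the redundant padding arithmetic concrete, here is a small standalone sketch. The formula mirrors the kv-consumer branch in the diff above; the value of decode_token_per_req (2, i.e. one MTP token) and the example batch sizes are illustrative assumptions, and the intent stated in the comment is taken from the in-diff comment about the 16-token FIA limit.

```python
import math

# Standalone sketch of the kv-consumer padding from the diff above; the input
# batch sizes and decode_token_per_req below are illustrative assumptions.
FIA_SEQ_LEN_LIMIT = 16      # per-batch seq_len limit of the FIA attention operator
decode_token_per_req = 2    # 1 main-model token + 1 MTP token

def padded_graph_batch_size(graph_batch_size: int) -> int:
    # Same expression as the kv-consumer branch: add redundant request slots on
    # top of graph_batch_size before multiplying by decode_token_per_req, so an
    # all-seq_len=1 batch can still be padded without any FIA batch exceeding
    # FIA_SEQ_LEN_LIMIT tokens.
    return (graph_batch_size +
            math.ceil(graph_batch_size / FIA_SEQ_LEN_LIMIT) +
            math.ceil(graph_batch_size * decode_token_per_req /
                      FIA_SEQ_LEN_LIMIT / FIA_SEQ_LEN_LIMIT)) * decode_token_per_req

for b in (1, 16, 32, 64):
    print(f"graph_batch_size={b:>3} -> padded token budget {padded_graph_batch_size(b)}")
# graph_batch_size=  1 -> padded token budget 6
# graph_batch_size= 16 -> padded token budget 36
# graph_batch_size= 32 -> padded token budget 70
# graph_batch_size= 64 -> padded token budget 138
```

Compared with the non-consumer branch, which simply multiplies each graph batch size by decode_token_per_req, the kv-consumer branch reserves a few extra request slots per group of FIA_SEQ_LEN_LIMIT requests before multiplying.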
