 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
+from vllm_ascend.utils import vllm_version_is

 PADDING_SLOT_ID = -1

@@ -139,8 +140,6 @@ def generate_token_ids(self,
                            hidden_states: torch.Tensor = None,
                            attn_metadata=None,
                            aux_hidden_states: torch.Tensor = None):
-        if self.name == SpecDcodeType.EAGLE:
-            raise NotImplementedError("Eagle Is Not Supported Yet.")

         attn_metadata = self._get_eagle_atten_dict(scheduler_output)
         next_token_ids: list[int] = []
@@ -355,8 +354,12 @@ def _get_eagle_atten_dict(
                 decode_token_per_req=self.runner.decode_token_per_req,
                 num_computed_tokens_cpu=None,
                 seq_lens=None)
-            attn_metadata_i = self.runner.attn_metadata_builder.build(
-                common_attn_metadata, self.runner.get_model())
+            if vllm_version_is("0.10.2"):
+                builder = self.runner.attn_groups[0][0].metadata_builder
+            else:
+                builder = self.runner.attn_groups[0][0].get_metadata_builder()
+            attn_metadata_i = builder.build(0, common_attn_metadata,
+                                            self.runner.get_model())

             for layer_name in kv_cache_group_spec.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
@@ -418,16 +421,19 @@ def _propose(
         self.input_ids[:num_tokens - 1] = target_token_ids[1:]
         # Replace the last token with the next token.
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
-        self.input_ids[last_token_indices] = next_token_ids[0]
+        self.input_ids[last_token_indices] = next_token_ids
+        seq_lens = (target_positions[last_token_indices] + 1).int()

         query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
         max_query_len = query_lens.max().item()
+        attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
+            seq_lens, target_positions, self.vllm_config.model_config.dtype,
+            self.device)

         common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=self.runner.query_start_loc[:batch_size + 1],
-            query_start_loc_cpu=self.runner.query_start_loc_cpu[:batch_size +
-                                                                1],
-            seq_lens_cpu=self.runner.seq_lens_cpu,
+            query_start_loc=cu_num_tokens.to(device),
+            query_start_loc_cpu=cu_num_tokens,
+            seq_lens_cpu=seq_lens.cpu(),
             max_query_len=max_query_len,
             num_reqs=batch_size,
             num_actual_tokens=num_tokens,
@@ -436,15 +442,19 @@ def _propose(
             get_device_tensor(),
             slot_mapping=target_slot_mapping,
             positions=target_positions,
-            attn_mask=self.runner.attn_mask,
+            attn_mask=attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
             attn_state=self.runner.attn_state,
             decode_token_per_req=self.runner.decode_token_per_req,
             num_computed_tokens_cpu=None,
             seq_lens=None)
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        attn_metadata = self.runner.attn_metadata_builder.build(
-            common_attn_metadata, self.runner.model)
+        if vllm_version_is("0.10.2"):
+            builder = self.runner.attn_groups[0][0].metadata_builder
+        else:
+            builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        attn_metadata = builder.build(0, common_attn_metadata,
+                                      self.runner.get_model())
         if self.use_cuda_graph and \
                 num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
@@ -471,7 +481,10 @@ def _propose(
             hidden_states=self.hidden_states[:num_input_tokens],
         )
         sample_hidden_states = last_hidden_states[last_token_indices]
-        logits = self.model.compute_logits(sample_hidden_states, None)
+        if vllm_version_is("0.10.2"):
+            logits = self.model.compute_logits(sample_hidden_states, None)
+        else:
+            logits = self.model.compute_logits(sample_hidden_states)
         draft_token_ids = logits.argmax(dim=-1)

         # Early exit if there is only one draft token to be generated.
@@ -501,9 +514,8 @@ def _propose(
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
-
-        if self.vllm_config.speculative_config.num_speculative_tokens > 2:
-            raise ValueError("Speculative tokens > 2 are not supported yet.")
+        query_lens.fill_(1)
+        attn_metadata.query_lens = query_lens

         attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
         for now_speculative in range(
@@ -558,9 +570,8 @@ def _propose(
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:batch_size] = hidden_states
-            positions = positions_cpu.to(device)
             attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
-                attn_metadata.seq_lens, positions,
+                attn_metadata.seq_lens, positions_cpu,
                 self.vllm_config.model_config.dtype, self.device)

             attn_metadata.attn_mask = attn_mask
@@ -577,8 +588,12 @@ def _propose(
                 hidden_states=self.hidden_states[:input_batch_size],
             )
             hidden_states = hidden_states[:batch_size]
-            logits = self.model.compute_logits(last_hidden_states[:batch_size],
-                                               None)
+            if vllm_version_is("0.10.2"):
+                logits = self.model.compute_logits(
+                    last_hidden_states[:batch_size], None)
+            else:
+                logits = self.model.compute_logits(
+                    last_hidden_states[:batch_size])

             # TODO(wenlong): get more than one token for tree attention
             draft_token_ids = logits.argmax(dim=-1)
@@ -652,7 +667,8 @@ def _prepare_eagle_input_sequential(self, out_tensor: torch.Tensor,
                 dtype=torch.int32,
                 device=out_tensor.device) + offset_tensor
             values_to_store = torch.tensor(
-                index_start, dtype=torch.int32,
+                index_start + global_start_offset,
+                dtype=torch.int32,
                 device=out_tensor.device) + offset_tensor
             mask = (target_indices >= start_pos) & \
                    (target_indices < end_pos) & \
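
Note: the recurring vllm_version_is("0.10.2") gate in this change selects between the 0.10.2 builder attribute (attn_groups[0][0].metadata_builder) and the newer accessor (get_metadata_builder()), and between the two-argument and one-argument compute_logits signatures. A minimal standalone sketch of that pattern, assuming a runner object laid out as in the diff above (illustrative only, not the exact upstream API):

    from vllm_ascend.utils import vllm_version_is


    def select_metadata_builder(runner):
        # vLLM 0.10.2 exposes the builder directly; later versions hide it
        # behind an accessor method.
        group = runner.attn_groups[0][0]
        if vllm_version_is("0.10.2"):
            return group.metadata_builder
        return group.get_metadata_builder()


    def compute_logits_compat(model, hidden_states):
        # compute_logits dropped its second (sampling metadata) argument
        # after 0.10.2, so the call is gated on the installed version.
        if vllm_version_is("0.10.2"):
            return model.compute_logits(hidden_states, None)
        return model.compute_logits(hidden_states)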