1
-
2
1
import torch
3
2
import torch .nn as nn
4
3
from vllm .attention .layer import Attention
@@ -49,7 +48,6 @@ def __init__(
49
48
dtype = self .runner .dtype ,
50
49
device = self .device )
51
50
52
-
53
51
# We need +1 here because the arange is used to set query_start_loc,
54
52
# which has one more element than batch_size.
55
53
self .arange = torch .arange (vllm_config .scheduler_config .max_num_seqs +
@@ -102,8 +100,6 @@ def dummy_run(self,
102
100
moe_comm_method = self .runner ._select_moe_comm_method (
103
101
num_tokens , with_prefill )
104
102
105
-
106
-
107
103
if skip_attn :
108
104
attn_metadata = None
109
105
else :
@@ -289,7 +285,6 @@ def _propose(
289
285
# Replace the last token with the next token.
290
286
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
291
287
292
-
293
288
self .input_ids [last_token_indices ] = next_token_ids
294
289
295
290
query_lens = cu_num_tokens [1 :] - cu_num_tokens [:- 1 ]
@@ -344,8 +339,6 @@ def _propose(
344
339
for layer_name in self .attn_layer_name :
345
340
attn_metadata [layer_name ] = attn_metadata_mtp
346
341
347
-
348
-
349
342
self .positions [:num_tokens ] = target_positions
350
343
self .hidden_states [:num_tokens ] = target_hidden_states
351
344
@@ -379,7 +372,6 @@ def _propose(
379
372
model_kwargs = {}
380
373
model_kwargs ["attn_metadata" ] = attn_metadata
381
374
382
-
383
375
hidden_states = self .model (
384
376
input_ids = self .input_ids [:num_input_tokens ],
385
377
positions = self .positions [:num_input_tokens ],
@@ -418,10 +410,8 @@ def _propose(
418
410
if step == self .num_speculative_tokens - 1 or with_prefill :
419
411
break
420
412
421
-
422
413
attn_metadata_i = attn_metadata [self .attn_layer_name [0 ]]
423
414
424
-
425
415
if step == 0 :
426
416
positions = target_positions [last_token_indices ]
427
417
hidden_states = hidden_states [last_token_indices ]
@@ -432,7 +422,6 @@ def _propose(
432
422
if attn_metadata_i .num_decode_tokens != 0 :
433
423
attn_metadata_i .num_decode_tokens = batch_size
434
424
435
-
436
425
input_ids = draft_token_ids_list [- 1 ].int ()
437
426
positions += 1
438
427
@@ -489,7 +478,6 @@ def _propose(
489
478
draft_token_ids = torch .stack (draft_token_ids_list , dim = 1 )
490
479
return draft_token_ids
491
480
492
-
493
481
# TODO Using torch instead of triton may result in poor performance
494
482
def _prepare_input_kernel (self , out_ptr : torch .Tensor ,
495
483
cu_query_lens : torch .Tensor ,
0 commit comments