Skip to content

Commit e100d1a

Browse files
committed
Resolve formatting issues
Signed-off-by: weisirui-eng <weisirui@h-partners.com>
1 parent 3759c80 commit e100d1a

File tree

2 files changed: 1 addition and 20 deletions.

2 files changed: 1 addition and 20 deletions.

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
21
import torch
32
import torch.nn as nn
43
from vllm.attention.layer import Attention
@@ -49,7 +48,6 @@ def __init__(
4948
dtype=self.runner.dtype,
5049
device=self.device)
5150

52-
5351
# We need +1 here because the arange is used to set query_start_loc,
5452
# which has one more element than batch_size.
5553
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -102,8 +100,6 @@ def dummy_run(self,
102100
moe_comm_method = self.runner._select_moe_comm_method(
103101
num_tokens, with_prefill)
104102

105-
106-
107103
if skip_attn:
108104
attn_metadata = None
109105
else:
@@ -289,7 +285,6 @@ def _propose(
289285
# Replace the last token with the next token.
290286
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
291287

292-
293288
self.input_ids[last_token_indices] = next_token_ids
294289

295290
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
@@ -344,8 +339,6 @@ def _propose(
344339
for layer_name in self.attn_layer_name:
345340
attn_metadata[layer_name] = attn_metadata_mtp
346341

347-
348-
349342
self.positions[:num_tokens] = target_positions
350343
self.hidden_states[:num_tokens] = target_hidden_states
351344

@@ -379,7 +372,6 @@ def _propose(
379372
model_kwargs = {}
380373
model_kwargs["attn_metadata"] = attn_metadata
381374

382-
383375
hidden_states = self.model(
384376
input_ids=self.input_ids[:num_input_tokens],
385377
positions=self.positions[:num_input_tokens],
@@ -418,10 +410,8 @@ def _propose(
418410
if step == self.num_speculative_tokens - 1 or with_prefill:
419411
break
420412

421-
422413
attn_metadata_i = attn_metadata[self.attn_layer_name[0]]
423414

424-
425415
if step == 0:
426416
positions = target_positions[last_token_indices]
427417
hidden_states = hidden_states[last_token_indices]
@@ -432,7 +422,6 @@ def _propose(
432422
if attn_metadata_i.num_decode_tokens != 0:
433423
attn_metadata_i.num_decode_tokens = batch_size
434424

435-
436425
input_ids = draft_token_ids_list[-1].int()
437426
positions += 1
438427

@@ -489,7 +478,6 @@ def _propose(
489478
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
490479
return draft_token_ids
491480

492-
493481
# TODO Using torch instead of triton may result in poor performance
494482
def _prepare_input_kernel(self, out_ptr: torch.Tensor,
495483
cu_query_lens: torch.Tensor,

vllm_ascend/torchair/mtp_torchair_proposer.py

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from vllm.v1.core.sched.output import SchedulerOutput
1616
from vllm.v1.sample.metadata import SamplingMetadata
1717
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
18+
1819
from vllm_ascend.ascend_config import get_ascend_config
1920
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
2021
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
@@ -56,7 +57,6 @@ def load_model(self, model) -> None:
5657
self.model = TorchairDeepSeekMTP(
5758
vllm_config=self.vllm_config).to(target_device)
5859

59-
6060
draft_attn_layer_names = (
6161
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
6262
target_attn_layer_names)
@@ -329,16 +329,12 @@ def _propose(
329329
num_computed_tokens_cpu=None,
330330
seq_lens=None)
331331

332-
333-
334-
335332
attn_metadata = self.runner.attn_metadata_builder.build(
336333
0, common_attn_metadata, self.runner.get_model())
337334

338335
self.positions[:num_tokens] = target_positions
339336
self.hidden_states[:num_tokens] = target_hidden_states
340337

341-
342338
# torchair mode can reuse self.runner.num_tokens_across_dp
343339
num_tokens_across_dp = self.runner.num_tokens_across_dp
344340
with_prefill = self.runner.with_prefill
@@ -420,7 +416,6 @@ def _propose(
420416
if step == self.num_speculative_tokens - 1 or with_prefill:
421417
break
422418

423-
424419
attn_metadata_i = attn_metadata
425420

426421
if step == 0:
@@ -546,5 +541,3 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int):
546541
config=config,
547542
ge_cache=False)
548543
return self.torchair_compiled_models[batch_size]
549-
550-

0 commit comments

Comments
 (0)