|
19 | 19 |
|
20 | 20 | import math
|
21 | 21 | import types
|
22 |
| -from typing import Optional |
| 22 | +from typing import Optional, Dict |
23 | 23 |
|
24 | 24 | import torch
|
25 | 25 | import torch.distributed as dist
|
@@ -51,7 +51,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
51 | 51 | def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
52 | 52 | super().__init__(vllm_config, device)
|
53 | 53 | self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
|
54 |
| - vllm_config, device) |
| 54 | + None, None, vllm_config, device) |
55 | 55 |
|
56 | 56 | ascend_config = get_ascend_config()
|
57 | 57 | self.new_kv_cache_bytes = -1
|
@@ -135,7 +135,8 @@ def _generate_dummy_run_hidden_states(self, with_prefill,
|
135 | 135 | is_torchair_compile, input_ids,
|
136 | 136 | positions, attn_metadata, num_tokens,
|
137 | 137 | intermediate_tensors, inputs_embeds):
|
138 |
| - |
| 138 | + if attn_metadata is not None and isinstance(attn_metadata, Dict): |
| 139 | + attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] |
139 | 140 | if not with_prefill:
|
140 | 141 | # Only mark static while compiling
|
141 | 142 | if is_torchair_compile:
|
@@ -281,6 +282,8 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
|
281 | 282 | input_ids, positions,
|
282 | 283 | intermediate_tensors,
|
283 | 284 | inputs_embeds):
|
| 285 | + if attn_metadata is not None and isinstance(attn_metadata, Dict): |
| 286 | + attn_metadata = attn_metadata['model.layers.0.self_attn.attn'] |
284 | 287 | model_kwargs = {
|
285 | 288 | "kv_caches": self.kv_caches,
|
286 | 289 | "attn_metadata": attn_metadata
|
|
0 commit comments