@@ -1,12 +1,15 @@
 import torch
 import torch_npu
-from vllm.config import LogprobsMode
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
 from vllm.v1.sample.sampler import Sampler
 
-from vllm_ascend.utils import is_310p
+from vllm_ascend.utils import is_310p, vllm_version_is
 
-DEFAULT_LOGPROBS_MODE = LogprobsMode.RAW_LOGPROBS
+if vllm_version_is("0.10.2"):
+    from vllm.config import LogprobsMode
+    DEFAULT_LOGPROBS_MODE = LogprobsMode.RAW_LOGPROBS
+else:
+    DEFAULT_LOGPROBS_MODE = "raw_logprobs"
 
 
 class AscendSampler(Sampler):
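For context on the module-level gate above: vllm_version_is comes from vllm_ascend.utils and reports whether the installed vLLM matches a given release. A minimal sketch of such a check, assuming a plain comparison against vllm.__version__ (the real helper may handle dev builds and overrides differently):

import vllm

def vllm_version_is(target: str) -> bool:
    # Sketch only: exact match against the installed vLLM release string.
    return vllm.__version__ == target

Under vLLM 0.10.2 the module keeps the enum default; on newer releases it falls back to the plain string "raw_logprobs". The second hunk applies the same gate inside the sampler's forward path: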
@@ -65,10 +68,18 @@ def forward_native(self, logits, generators, k, p):
         """Override pytorch native implementation to torch_npu"""
         logits = self._apply_top_k_top_p(logits, k, p)
         logits_to_return = None
-        if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
-            logits_to_return = logits
-        elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
-            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
+        if vllm_version_is("0.10.2"):
+            if self.logprobs_mode == LogprobsMode.PROCESSED_LOGITS:
+                logits_to_return = logits
+            elif self.logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS:
+                logits_to_return = logits.log_softmax(dim=-1,
+                                                      dtype=torch.float32)
+        else:
+            if self.logprobs_mode == "processed_logits":
+                logits_to_return = logits
+            elif self.logprobs_mode == "processed_logprobs":
+                logits_to_return = logits.log_softmax(dim=-1,
+                                                      dtype=torch.float32)
 
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators), logits_to_return
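The duplicated branch in forward_native exists because logprobs_mode changed type across vLLM releases: 0.10.2 exposes it as the LogprobsMode enum from vllm.config, while later versions use bare strings. A small illustration of why a single string-based comparison could not serve both versions (the enum below is a stand-in written for this sketch, not vLLM's actual class):

from enum import Enum

class LogprobsMode(Enum):
    # Stand-in mirroring the member names used in the diff.
    PROCESSED_LOGITS = 0
    PROCESSED_LOGPROBS = 1

# A plain Enum member never compares equal to a string, so the string
# checks written for newer vLLM would silently miss on 0.10.2:
assert LogprobsMode.PROCESSED_LOGITS != "processed_logits"

In both branches the processed logits (or their log-softmax) are returned alongside the sampled tokens so callers can surface per-token logprobs.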