
Commit cf96366

Authored by paulyu12
[Bugfix][LoRA][Patch] Fix the LoRA inference bug after upstream vLLM codebase changed (#2560)
### What this PR does / why we need it?

The merge of upstream vllm-project/vllm#22592 caused a LoRA inference bug in vllm-ascend. The details are as follows: according to [torch_npu/npu/_stream_check.py](https://github.com/Ascend/pytorch/blob/863b9071cbdf47023c12c246e3efa9c6e2285fc6/torch_npu/npu/_stream_check.py#L74), NPU device type tensors have the attributes is_cuda=True and is_npu=True. As a result, vLLM's apply_repetition_penalties function takes the "if logits.is_cuda and logits.is_contiguous()" branch and calls the custom op implemented in CUDA, which is not compatible with NPU.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

pytest -sv tests/e2e/singlecard/test_ilama_lora.py
pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@fe8d7b6

---------

Signed-off-by: paulyu12 <paulyu0307@gmail.com>
Signed-off-by: paulyu12 <507435917@qq.com>
Co-authored-by: paulyu12 <paulyu0307@gmail.com>
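For context, below is a minimal, self-contained sketch of the dispatch problem the PR description refers to. The `*_stub` functions are placeholders introduced for this illustration only; they are not vLLM's actual implementations.

```python
# A minimal sketch of the dispatch problem described above.
# The *_stub functions are placeholders for this illustration; the real
# implementations live in vLLM's _custom_ops module.
import torch


def apply_repetition_penalties_torch_stub(logits: torch.Tensor,
                                          prompt_mask: torch.Tensor,
                                          output_mask: torch.Tensor,
                                          repetition_penalties: torch.Tensor) -> None:
    """Placeholder for the pure-PyTorch fallback path (works on any backend)."""


def apply_repetition_penalties_cuda_stub(logits: torch.Tensor,
                                         prompt_mask: torch.Tensor,
                                         output_mask: torch.Tensor,
                                         repetition_penalties: torch.Tensor) -> None:
    """Placeholder for the CUDA custom op path; unusable on Ascend NPU."""
    raise RuntimeError("CUDA custom op is not available on this backend")


def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
                               output_mask: torch.Tensor,
                               repetition_penalties: torch.Tensor) -> None:
    # Upstream vLLM dispatches on logits.is_cuda. torch_npu reports
    # is_cuda=True for NPU tensors as well, so an Ascend deployment falls
    # into the CUDA branch; the patch in this commit forces the torch path.
    if logits.is_cuda and logits.is_contiguous():
        apply_repetition_penalties_cuda_stub(logits, prompt_mask, output_mask,
                                             repetition_penalties)
    else:
        apply_repetition_penalties_torch_stub(logits, prompt_mask, output_mask,
                                              repetition_penalties)
```

The patch in this commit sidesteps that dispatch entirely by rebinding vllm._custom_ops.apply_repetition_penalties to the pure-PyTorch path.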
Parent: 1191a64 · Commit: cf96366

File tree

2 files changed (+27 lines, −0 lines)


vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,4 +17,5 @@
 
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_linear  # noqa
+import vllm_ascend.patch.worker.patch_common.patch_logits  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
vllm_ascend/patch/worker/patch_common/patch_logits.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+import torch
+import vllm
+from vllm._custom_ops import apply_repetition_penalties_torch
+
+
+def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
+                               output_mask: torch.Tensor,
+                               repetition_penalties: torch.Tensor) -> None:
+    """Apply repetition penalties to logits in-place.
+
+    Args:
+        logits: The logits tensor of shape [num_seqs, vocab_size].
+        prompt_mask: A boolean tensor indicating which tokens appear in the prompt.
+        output_mask: A boolean tensor indicating which tokens appear in the output.
+        repetition_penalties: The repetition penalties of shape (num_seqs, ).
+    """
+    apply_repetition_penalties_torch(logits, prompt_mask, output_mask,
+                                     repetition_penalties)
+
+
+# NPU device type tensors have the attributes is_cuda=True and is_npu=True; see
+# https://github.com/Ascend/pytorch/blob/863b9071cbdf47023c12c246e3efa9c6e2285fc6/torch_npu/npu/_stream_check.py#L74
+# This causes vLLM's apply_repetition_penalties to take the "if logits.is_cuda" branch and
+# call the custom op implemented in CUDA, which is not compatible with NPU.
+# Reference: https://github.com/vllm-project/vllm/blob/f66673a39d9f364194c249f28098cad8a5584ccb/vllm/_custom_ops.py#L314
+vllm._custom_ops.apply_repetition_penalties = apply_repetition_penalties
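For a quick sanity check of the patched entry point, here is a hedged usage sketch. It assumes vLLM and vllm-ascend are importable in the environment; the tensor shapes follow the docstring above.

```python
# Quick sanity check of the patched entry point. Assumes vLLM and vllm-ascend
# are importable; shapes follow the docstring in patch_logits.py. On CPU the
# patched function simply forwards to the pure-PyTorch implementation.
import torch
import vllm
import vllm_ascend.patch.worker.patch_common.patch_logits  # noqa: F401  (applies the patch)

num_seqs, vocab_size = 2, 8
logits = torch.randn(num_seqs, vocab_size)
prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
prompt_mask[0, 3] = True   # token 3 appeared in sequence 0's prompt
output_mask[1, 5] = True   # token 5 appeared in sequence 1's output
repetition_penalties = torch.full((num_seqs,), 1.2)

# In-place update: the penalized tokens' logits are adjusted by the penalty factor.
vllm._custom_ops.apply_repetition_penalties(logits, prompt_mask, output_mask,
                                            repetition_penalties)
print(logits[0, 3].item(), logits[1, 5].item())
```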
