
Commit 3df1171

[V1] Resolve failed concurrent structured output requests
Closes vllm-project#19493
Closes vllm-project#18376
Related to vllm-project#18780

Several people have noticed errors when using both the `xgrammar` and `guidance` backends: we would start generating invalid tokens for a request, and those tokens would be continuously rejected by the backend currently in use. The conditions seemed to be:

- Only impacts certain models
- Occurs with concurrent structured output requests

After further investigation, once an easy way to reproduce was provided via vllm-project#19493, I identified more details about the failure:

- When the failure occurred in my test with a concurrency of 2, whichever request came in first was always successful; it was the second request that would fail.

Debugging further showed that the bitmask was not being applied correctly, but only for that second request. In the GPU model runner, this translates to the 2nd row of the bitmask tensor and the 2nd row of the logits tensor. I could see that a couple of bytes were left unmasked. I suspect the issue appears to be model specific because of the vocabulary and which tokens were left unmasked, but I have not verified this part for sure.

The reason it occurred with both structured output backends is that we use the `xgrammar` library's implementation of applying the bitmask in all cases. On CUDA, xgrammar uses a Triton kernel for applying the bitmask by default. I identified that forcing it to use the `torch.compile` implementation instead resolves the problem. The torch implementation is what xgrammar's logic selects for all other accelerator types, so it seems fine to just force the use of that implementation here.

I have not yet narrowed down the problem in the Triton kernel, but this change works around it for vLLM. We can move back to xgrammar's wrapper, which chooses the implementation to use, once we can verify everything is working properly again.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
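For context on the failure mode: applying a grammar bitmask means forcing the logits of every token the grammar currently disallows to -inf before sampling, so only grammar-valid tokens can be chosen; if even a few positions in one request's bitmask row are skipped, the sampler can emit tokens the backend then repeatedly rejects. Below is a rough PyTorch sketch of the idea — illustrative only, not xgrammar's kernel; the helper name and the 32-tokens-per-int32, set-bit-means-allowed packing are assumptions made for the example:

import torch

def apply_token_bitmask_reference(logits: torch.Tensor,
                                  bitmask: torch.Tensor,
                                  indices: list[int]) -> None:
    # logits: [num_reqs, vocab_size] float tensor.
    # bitmask: [len(indices), ceil(vocab_size / 32)] int32 tensor, where bit j of
    # word w marks token (w * 32 + j) as allowed (assumed packing, for illustration).
    # indices: the logits row that each bitmask row applies to.
    vocab_size = logits.shape[-1]
    shifts = torch.arange(32, device=bitmask.device)
    for mask_row, logits_row in enumerate(indices):
        words = bitmask[mask_row].unsqueeze(-1)                    # [num_words, 1]
        allowed = ((words >> shifts) & 1).bool().reshape(-1)[:vocab_size]
        # Disallowed tokens get -inf so they can never be sampled.
        logits[logits_row][~allowed] = float("-inf")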
1 parent 29fa5ca commit 3df1171

File tree

1 file changed: +5 −1 lines

vllm/v1/worker/gpu_model_runner.py

Lines changed: 5 additions & 1 deletion
@@ -65,11 +65,15 @@
 
 if TYPE_CHECKING:
     import xgrammar as xgr
+    import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile  # noqa: E501
 
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
     from vllm.v1.core.sched.output import SchedulerOutput
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
+    xgr_torch_compile = LazyLoader(
+        "xgr_torch_compile", globals(),
+        "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile")
 
 logger = init_logger(__name__)
 
@@ -1102,7 +1106,7 @@ def apply_grammar_bitmask(
         # so we receive it in that format.
         grammar_bitmask = torch.from_numpy(grammar_bitmask)
 
-        xgr.apply_token_bitmask_inplace(
+        xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(
             logits,
             grammar_bitmask.to(self.device, non_blocking=True),
             indices=out_indices,
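In short, the change bypasses xgrammar's top-level `apply_token_bitmask_inplace` wrapper, which on CUDA dispatches to a Triton kernel by default, and calls the `torch.compile`-based implementation directly; that implementation is what xgrammar already selects for non-CUDA accelerators. A minimal sketch of the call-site difference, reusing the names from the diff (`bitmask` standing in for the device-resident grammar bitmask):

# Before: the wrapper chooses the kernel itself (Triton on CUDA by default).
xgr.apply_token_bitmask_inplace(logits, bitmask, indices=out_indices)

# After: pin the torch.compile implementation until the Triton kernel issue
# is understood, mirroring what xgrammar does on non-CUDA accelerators.
xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(
    logits, bitmask, indices=out_indices)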
