Skip to content

Commit 1735463

Browse files
russellb and minpeter
authored and committed
[V1] Resolve failed concurrent structured output requests (vllm-project#19565)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent 1b09728 commit 1735463

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,15 @@
 if TYPE_CHECKING:
     import xgrammar as xgr
+    import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile  # noqa: E501

     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
     from vllm.v1.core.sched.output import SchedulerOutput
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
+    xgr_torch_compile = LazyLoader(
+        "xgr_torch_compile", globals(),
+        "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile")

 logger = init_logger(__name__)


@@ -1103,7 +1107,10 @@ def apply_grammar_bitmask(
     # so we receive it in that format.
     grammar_bitmask = torch.from_numpy(grammar_bitmask)

-    xgr.apply_token_bitmask_inplace(
+    # Force use of the torch.compile implementation from xgrammar to work
+    # around issues with the Triton kernel in concurrent structured output
+    # scenarios. See PR #19565 and issues #19493, #18376 for details.
+    xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(
         logits,
         grammar_bitmask.to(self.device, non_blocking=True),
         indices=out_indices,

0 commit comments

Comments
 (0)