Skip to content

Commit 313883e

Browse files
committed
feat: implement workspace caching for fused MoE in vllm_fuse_moe.py
- Added a new function `_get_fused_moe_workspace` that manages long-lived workspaces for CUDA graph capture/replay, improving memory management during model execution.
- Replaced direct tensor allocations in `fused_experts_impl` with calls to the new workspace function, keeping captured tensor addresses valid across graph replays and reducing allocation overhead.
1 parent 8a60621 commit 313883e

1 file changed

Lines changed: 59 additions & 9 deletions

File tree

diffulex_kernel/python/vllm_fuse_moe.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,56 @@
2323

2424
logger = init_logger(__name__)
2525

26+
_FUSED_MOE_WORKSPACE_CACHE: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}
27+
28+
29+
def _get_fused_moe_workspace(
30+
*,
31+
hidden_states: torch.Tensor,
32+
m: int,
33+
top_k: int,
34+
intermediate_size_times_2: int,
35+
hidden_size: int,
36+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
37+
"""Return long-lived workspaces for CUDA graph capture/replay.
38+
39+
vLLM's fused MoE allocates intermediate tensors inside the forward call. In
40+
eager mode that is fine, but CUDA graph replay needs captured tensor
41+
addresses to remain valid. Keeping the workspace tensors in a module-level
42+
cache mirrors vLLM's static-buffer usage in graph mode and avoids replaying
43+
kernels against allocator-reclaimed temporary buffers.
44+
"""
45+
key = (
46+
hidden_states.device.type,
47+
hidden_states.device.index,
48+
hidden_states.dtype,
49+
int(m),
50+
int(top_k),
51+
int(intermediate_size_times_2),
52+
int(hidden_size),
53+
)
54+
workspace = _FUSED_MOE_WORKSPACE_CACHE.get(key)
55+
if workspace is None:
56+
workspace = (
57+
torch.empty(
58+
(m, top_k, intermediate_size_times_2),
59+
device=hidden_states.device,
60+
dtype=hidden_states.dtype,
61+
),
62+
torch.empty(
63+
(m * top_k, intermediate_size_times_2 // 2),
64+
device=hidden_states.device,
65+
dtype=hidden_states.dtype,
66+
),
67+
torch.empty(
68+
(m, top_k, hidden_size),
69+
device=hidden_states.device,
70+
dtype=hidden_states.dtype,
71+
),
72+
)
73+
_FUSED_MOE_WORKSPACE_CACHE[key] = workspace
74+
return workspace
75+
2676

2777
@triton.jit
2878
def fused_moe_kernel_gptq_awq(
@@ -1178,15 +1228,15 @@ def fused_experts_impl(hidden_states: torch.Tensor,
11781228

11791229
config = get_config_func(M)
11801230

1181-
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
1182-
device=hidden_states.device,
1183-
dtype=hidden_states.dtype)
1184-
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
1185-
device=hidden_states.device,
1186-
dtype=hidden_states.dtype)
1187-
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
1188-
device=hidden_states.device,
1189-
dtype=hidden_states.dtype)
1231+
intermediate_cache1, intermediate_cache2, intermediate_cache3 = (
1232+
_get_fused_moe_workspace(
1233+
hidden_states=hidden_states,
1234+
m=M,
1235+
top_k=topk_ids.shape[1],
1236+
intermediate_size_times_2=N,
1237+
hidden_size=w2.shape[1],
1238+
)
1239+
)
11901240

11911241
if hidden_states.dtype == torch.bfloat16:
11921242
compute_type = tl.bfloat16

0 commit comments

Comments
 (0)