|
@@ -23,6 +23,56 @@
 
 logger = init_logger(__name__)
 
+_FUSED_MOE_WORKSPACE_CACHE: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}
+
+
+def _get_fused_moe_workspace(
+    *,
+    hidden_states: torch.Tensor,
+    m: int,
+    top_k: int,
+    intermediate_size_times_2: int,
+    hidden_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Return long-lived workspaces for CUDA graph capture/replay.
+
+    vLLM's fused MoE allocates intermediate tensors inside the forward call. In
+    eager mode that is fine, but CUDA graph replay needs captured tensor
+    addresses to remain valid. Keeping the workspace tensors in a module-level
+    cache mirrors vLLM's static-buffer usage in graph mode and avoids replaying
+    kernels against allocator-reclaimed temporary buffers.
+    """
+    key = (
+        hidden_states.device.type,
+        hidden_states.device.index,
+        hidden_states.dtype,
+        int(m),
+        int(top_k),
+        int(intermediate_size_times_2),
+        int(hidden_size),
+    )
+    workspace = _FUSED_MOE_WORKSPACE_CACHE.get(key)
+    if workspace is None:
+        workspace = (
+            torch.empty(
+                (m, top_k, intermediate_size_times_2),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            ),
+            torch.empty(
+                (m * top_k, intermediate_size_times_2 // 2),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            ),
+            torch.empty(
+                (m, top_k, hidden_size),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            ),
+        )
+        _FUSED_MOE_WORKSPACE_CACHE[key] = workspace
+    return workspace
+
 
 @triton.jit
 def fused_moe_kernel_gptq_awq(
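To make the cache's contract concrete, here is a minimal standalone sketch, not vLLM code (get_workspace and _CACHE are illustrative names): repeated lookups with the same key hand back the same tensors, so the device addresses a CUDA graph captures stay valid on every replay. It runs on CPU purely for demonstration.

# Minimal sketch of the module-level workspace cache pattern added above.
# get_workspace and _CACHE are illustrative names, not vLLM APIs.
from typing import Dict, Tuple

import torch

_CACHE: Dict[Tuple, torch.Tensor] = {}


def get_workspace(m: int, n: int, device: torch.device,
                  dtype: torch.dtype) -> torch.Tensor:
    key = (device.type, device.index, dtype, m, n)
    buf = _CACHE.get(key)
    if buf is None:
        # First request for this shape/device/dtype: allocate once and keep
        # the tensor alive for the lifetime of the process.
        buf = torch.empty((m, n), device=device, dtype=dtype)
        _CACHE[key] = buf
    return buf


a = get_workspace(4, 8, torch.device("cpu"), torch.float32)
b = get_workspace(4, 8, torch.device("cpu"), torch.float32)
# Same storage both times: a graph captured against this buffer keeps
# writing to a valid address on replay.
assert a.data_ptr() == b.data_ptr()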
@@ -1178,15 +1228,15 @@ def fused_experts_impl(hidden_states: torch.Tensor,
 
     config = get_config_func(M)
 
-    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+    intermediate_cache1, intermediate_cache2, intermediate_cache3 = (
+        _get_fused_moe_workspace(
+            hidden_states=hidden_states,
+            m=M,
+            top_k=topk_ids.shape[1],
+            intermediate_size_times_2=N,
+            hidden_size=w2.shape[1],
+        )
+    )
 
     if hidden_states.dtype == torch.bfloat16:
         compute_type = tl.bfloat16
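
For context on why the per-call torch.empty temporaries above are replaced with long-lived workspaces: a CUDA graph replays its kernels against the exact buffer addresses that were captured. Below is a hedged sketch of that behavior using PyTorch's public CUDA graph API; it assumes a CUDA device is available and follows the warmup-then-capture pattern from the PyTorch docs.

# Sketch: CUDA graph replay reuses the captured buffer addresses, which is
# why the fused MoE intermediates must not be per-call temporaries.
import torch

if torch.cuda.is_available():
    static_in = torch.zeros(8, device="cuda")
    static_out = torch.zeros(8, device="cuda")

    # Warm up on a side stream before capture, per the PyTorch docs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(static_in * 2)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # The captured kernels write into static_out's storage.
        static_out.copy_(static_in * 2)

    static_in.fill_(3.0)
    g.replay()  # re-runs the captured kernels against the same addresses
    torch.cuda.synchronize()
    assert float(static_out[0]) == 6.0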
|